Increase file upload limit to 100 MB (#77117)

* Increase file upload limit to 100 MB

* Remove comment

* Add test and fix chunk processing

* Add test for wrong field

* Add review suggestions

* Use nonlocal and remove unneeded executor task

* Use Janus to process chunk uploading

* Address review comments

* Address review comments #2

* Improve tests

* Fix discovery test

* Fix tests
This commit is contained in:
Marvin Wichmann 2022-11-30 02:46:34 +01:00 committed by GitHub
parent a3ec9529ec
commit 1908feab79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 107 additions and 16 deletions

View file

@ -9,7 +9,8 @@ from pathlib import Path
import shutil
import tempfile
from aiohttp import web
from aiohttp import BodyPartReader, web
import janus
import voluptuous as vol
from homeassistant.components.http import HomeAssistantView
@ -22,9 +23,8 @@ from homeassistant.util.ulid import ulid_hex
DOMAIN = "file_upload"
# If increased, change upload view to streaming
# https://docs.aiohttp.org/en/stable/web_quickstart.html#file-uploads
MAX_SIZE = 1024 * 1024 * 10
ONE_MEGABYTE = 1024 * 1024
MAX_SIZE = 100 * ONE_MEGABYTE
TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"
@ -126,14 +126,18 @@ class FileUploadView(HomeAssistantView):
# Increase max payload
request._client_max_size = MAX_SIZE # pylint: disable=protected-access
data = await request.post()
file_field = data.get("file")
reader = await request.multipart()
file_field_reader = await reader.next()
if not isinstance(file_field, web.FileField):
if (
not isinstance(file_field_reader, BodyPartReader)
or file_field_reader.name != "file"
or file_field_reader.filename is None
):
raise vol.Invalid("Expected a file")
try:
raise_if_invalid_filename(file_field.filename)
raise_if_invalid_filename(file_field_reader.filename)
except ValueError as err:
raise web.HTTPBadRequest from err
@ -145,19 +149,39 @@ class FileUploadView(HomeAssistantView):
file_upload_data: FileUploadData = hass.data[DOMAIN]
file_dir = file_upload_data.file_dir(file_id)
queue: janus.Queue[bytes | None] = janus.Queue()
def _sync_work() -> None:
def _sync_queue_consumer(
sync_q: janus.SyncQueue[bytes | None], _file_name: str
) -> None:
file_dir.mkdir()
with (file_dir / _file_name).open("wb") as file_handle:
while True:
_chunk = sync_q.get()
if _chunk is None:
break
# MyPy forgets about the isinstance check because we're in a function scope
assert isinstance(file_field, web.FileField)
file_handle.write(_chunk)
sync_q.task_done()
with (file_dir / file_field.filename).open("wb") as target_fileobj:
shutil.copyfileobj(file_field.file, target_fileobj)
fut: asyncio.Future[None] | None = None
try:
fut = hass.async_add_executor_job(
_sync_queue_consumer,
queue.sync_q,
file_field_reader.filename,
)
await hass.async_add_executor_job(_sync_work)
while chunk := await file_field_reader.read_chunk(ONE_MEGABYTE):
queue.async_q.put_nowait(chunk)
if queue.async_q.qsize() > 5: # Allow up to 5 MB buffer size
await queue.async_q.join()
queue.async_q.put_nowait(None) # terminate queue consumer
finally:
if fut is not None:
await fut
file_upload_data.files[file_id] = file_field.filename
file_upload_data.files[file_id] = file_field_reader.filename
return self.json({"file_id": file_id})

View file

@ -2,6 +2,7 @@
"domain": "file_upload",
"name": "File Upload",
"documentation": "https://www.home-assistant.io/integrations/file_upload",
"requirements": ["janus==1.0.0"],
"dependencies": ["http"],
"codeowners": ["@home-assistant/core"],
"quality_scale": "internal",

View file

@ -25,6 +25,7 @@ home-assistant-bluetooth==1.8.1
home-assistant-frontend==20221108.0
httpx==0.23.1
ifaddr==0.1.7
janus==1.0.0
jinja2==3.1.2
lru-dict==1.1.8
orjson==3.8.1

View file

@ -964,6 +964,9 @@ iperf3==0.1.11
# homeassistant.components.gogogate2
ismartgate==4.0.4
# homeassistant.components.file_upload
janus==1.0.0
# homeassistant.components.jellyfin
jellyfin-apiclient-python==1.9.2

View file

@ -717,6 +717,9 @@ iotawattpy==0.1.0
# homeassistant.components.gogogate2
ismartgate==4.0.4
# homeassistant.components.file_upload
janus==1.0.0
# homeassistant.components.jellyfin
jellyfin-apiclient-python==1.9.2

View file

@ -0,0 +1,13 @@
"""Fixtures for FileUpload integration."""
from io import StringIO
import pytest
@pytest.fixture
def large_file_io() -> StringIO:
"""Generate a file on the fly. Simulates a large file."""
return StringIO(
2
* "Home Assistant is awesome. Open source home automation that puts local control and privacy first."
)

View file

@ -64,3 +64,49 @@ async def test_removed_on_stop(hass: HomeAssistant, hass_client, uploaded_file_d
# Test it's removed
assert not uploaded_file_dir.exists()
async def test_upload_large_file(hass: HomeAssistant, hass_client, large_file_io):
"""Test uploading large file."""
assert await async_setup_component(hass, "file_upload", {})
client = await hass_client()
with patch(
# Patch temp dir name to avoid tests fail running in parallel
"homeassistant.components.file_upload.TEMP_DIR_NAME",
file_upload.TEMP_DIR_NAME + f"-{getrandbits(10):03x}",
), patch(
# Patch one megabyte to 8 bytes to prevent having to use big files in tests
"homeassistant.components.file_upload.ONE_MEGABYTE",
8,
):
res = await client.post("/api/file_upload", data={"file": large_file_io})
assert res.status == 200
response = await res.json()
file_dir = hass.data[file_upload.DOMAIN].file_dir(response["file_id"])
assert file_dir.is_dir()
large_file_io.seek(0)
with file_upload.process_uploaded_file(hass, file_dir.name) as file_path:
assert file_path.is_file()
assert file_path.parent == file_dir
assert file_path.read_bytes() == large_file_io.read().encode("utf-8")
async def test_upload_with_wrong_key_fails(
hass: HomeAssistant, hass_client, large_file_io
):
"""Test uploading fails."""
assert await async_setup_component(hass, "file_upload", {})
client = await hass_client()
with patch(
# Patch temp dir name to avoid tests fail running in parallel
"homeassistant.components.file_upload.TEMP_DIR_NAME",
file_upload.TEMP_DIR_NAME + f"-{getrandbits(10):03x}",
):
res = await client.post("/api/file_upload", data={"wrong_key": large_file_io})
assert res.status == 400

View file

@ -401,7 +401,7 @@ async def test_discovery_requirements_mqtt(hass):
) as mock_process:
await async_get_integration_with_requirements(hass, "mqtt_comp")
assert len(mock_process.mock_calls) == 2 # mqtt also depends on http
assert len(mock_process.mock_calls) == 3 # mqtt also depends on http
assert mock_process.mock_calls[0][1][1] == mqtt.requirements