Allow uploading media to media folder (#66143)

This commit is contained in:
Paulus Schoutsen 2022-02-10 08:03:14 -08:00 committed by GitHub
parent 0fb2c78b6d
commit dd48f1e6fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 221 additions and 2 deletions

View file

@ -1,21 +1,29 @@
"""Local Media Source Implementation.""" """Local Media Source Implementation."""
from __future__ import annotations from __future__ import annotations
import logging
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from aiohttp import web from aiohttp import web
from aiohttp.web_request import FileField
from aioshutil import shutil
import voluptuous as vol
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
from homeassistant.components.media_player.errors import BrowseError from homeassistant.components.media_player.errors import BrowseError
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import raise_if_invalid_path from homeassistant.exceptions import Unauthorized
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .error import Unresolvable from .error import Unresolvable
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
MAX_UPLOAD_SIZE = 1024 * 1024 * 10
LOGGER = logging.getLogger(__name__)
@callback @callback
def async_setup(hass: HomeAssistant) -> None: def async_setup(hass: HomeAssistant) -> None:
@ -23,6 +31,7 @@ def async_setup(hass: HomeAssistant) -> None:
source = LocalSource(hass) source = LocalSource(hass)
hass.data[DOMAIN][DOMAIN] = source hass.data[DOMAIN][DOMAIN] = source
hass.http.register_view(LocalMediaView(hass, source)) hass.http.register_view(LocalMediaView(hass, source))
hass.http.register_view(UploadMediaView(hass, source))
class LocalSource(MediaSource): class LocalSource(MediaSource):
@ -43,11 +52,14 @@ class LocalSource(MediaSource):
@callback @callback
def async_parse_identifier(self, item: MediaSourceItem) -> tuple[str, str]: def async_parse_identifier(self, item: MediaSourceItem) -> tuple[str, str]:
"""Parse identifier.""" """Parse identifier."""
if item.domain != DOMAIN:
raise Unresolvable("Unknown domain.")
if not item.identifier: if not item.identifier:
# Empty source_dir_id and location # Empty source_dir_id and location
return "", "" return "", ""
source_dir_id, location = item.identifier.split("/", 1) source_dir_id, _, location = item.identifier.partition("/")
if source_dir_id not in self.hass.config.media_dirs: if source_dir_id not in self.hass.config.media_dirs:
raise Unresolvable("Unknown source directory.") raise Unresolvable("Unknown source directory.")
@ -217,3 +229,88 @@ class LocalMediaView(HomeAssistantView):
raise web.HTTPNotFound() raise web.HTTPNotFound()
return web.FileResponse(media_path) return web.FileResponse(media_path)
class UploadMediaView(HomeAssistantView):
"""View to upload images."""
url = "/api/media_source/local_source/upload"
name = "api:media_source:local_source:upload"
def __init__(self, hass: HomeAssistant, source: LocalSource) -> None:
"""Initialize the media view."""
self.hass = hass
self.source = source
self.schema = vol.Schema(
{
"media_content_id": str,
"file": FileField,
}
)
async def post(self, request: web.Request) -> web.Response:
"""Handle upload."""
if not request["hass_user"].is_admin:
raise Unauthorized()
# Increase max payload
request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access
try:
data = self.schema(dict(await request.post()))
except vol.Invalid as err:
LOGGER.error("Received invalid upload data: %s", err)
raise web.HTTPBadRequest() from err
try:
item = MediaSourceItem.from_uri(self.hass, data["media_content_id"])
except ValueError as err:
LOGGER.error("Received invalid upload data: %s", err)
raise web.HTTPBadRequest() from err
try:
source_dir_id, location = self.source.async_parse_identifier(item)
except Unresolvable as err:
LOGGER.error("Invalid local source ID")
raise web.HTTPBadRequest() from err
uploaded_file: FileField = data["file"]
if not uploaded_file.content_type.startswith(("image/", "video/")):
LOGGER.error("Content type not allowed")
raise vol.Invalid("Only images and video are allowed")
try:
raise_if_invalid_filename(uploaded_file.filename)
except ValueError as err:
LOGGER.error("Invalid filename")
raise web.HTTPBadRequest() from err
try:
await self.hass.async_add_executor_job(
self._move_file,
self.source.async_full_path(source_dir_id, location),
uploaded_file,
)
except ValueError as err:
LOGGER.error("Moving upload failed: %s", err)
raise web.HTTPBadRequest() from err
return self.json(
{"media_content_id": f"{data['media_content_id']}/{uploaded_file.filename}"}
)
def _move_file( # pylint: disable=no-self-use
self, target_dir: Path, uploaded_file: FileField
) -> None:
"""Move file to target."""
if not target_dir.is_dir():
raise ValueError("Target is not an existing directory")
target_path = target_dir / uploaded_file.filename
target_path.relative_to(target_dir)
raise_if_invalid_path(str(target_path))
with target_path.open("wb") as target_fp:
shutil.copyfileobj(uploaded_file.file, target_fp)

View file

@ -1,5 +1,9 @@
"""Test Local Media Source.""" """Test Local Media Source."""
from http import HTTPStatus from http import HTTPStatus
import io
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch
import pytest import pytest
@ -9,6 +13,20 @@ from homeassistant.config import async_process_ha_core_config
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@pytest.fixture
async def temp_dir(hass):
"""Return a temp dir."""
with TemporaryDirectory() as tmpdirname:
target_dir = Path(tmpdirname) / "another_subdir"
target_dir.mkdir()
await async_process_ha_core_config(
hass, {"media_dirs": {"test_dir": str(target_dir)}}
)
assert await async_setup_component(hass, const.DOMAIN, {})
yield str(target_dir)
async def test_async_browse_media(hass): async def test_async_browse_media(hass):
"""Test browse media.""" """Test browse media."""
local_media = hass.config.path("media") local_media = hass.config.path("media")
@ -102,3 +120,107 @@ async def test_media_view(hass, hass_client):
resp = await client.get("/media/recordings/test.mp3") resp = await client.get("/media/recordings/test.mp3")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
async def test_upload_view(hass, hass_client, temp_dir, hass_admin_user):
"""Allow uploading media."""
img = (Path(__file__).parent.parent / "image/logo.png").read_bytes()
def get_file(name):
pic = io.BytesIO(img)
pic.name = name
return pic
client = await hass_client()
# Test normal upload
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("logo.png"),
},
)
assert res.status == 200
assert (Path(temp_dir) / "logo.png").is_file()
# Test with bad media source ID
for bad_id in (
# Subdir doesn't exist
"media-source://media_source/test_dir/some-other-dir",
# Main dir doesn't exist
"media-source://media_source/test_dir2",
# Location is invalid
"media-source://media_source/test_dir/..",
# Domain != media_source
"media-source://nest/test_dir/.",
# Completely something else
"http://bla",
):
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": bad_id,
"file": get_file("bad-source-id.png"),
},
)
assert res.status == 400
assert not (Path(temp_dir) / "bad-source-id.png").is_file()
# Test invalid POST data
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("invalid-data.png"),
"incorrect": "format",
},
)
assert res.status == 400
assert not (Path(temp_dir) / "invalid-data.png").is_file()
# Test invalid content type
text_file = io.BytesIO(b"Hello world")
text_file.name = "hello.txt"
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": text_file,
},
)
assert res.status == 400
assert not (Path(temp_dir) / "hello.txt").is_file()
# Test invalid filename
with patch(
"aiohttp.formdata.guess_filename", return_value="../invalid-filename.png"
):
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("../invalid-filename.png"),
},
)
assert res.status == 400
assert not (Path(temp_dir) / "../invalid-filename.png").is_file()
# Remove admin access
hass_admin_user.groups = []
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("no-admin-test.png"),
},
)
assert res.status == 401
assert not (Path(temp_dir) / "no-admin-test.png").is_file()