Improved local media ID handling (#67083)

This commit is contained in:
Paulus Schoutsen 2022-02-22 23:39:54 -08:00 committed by GitHub
parent c76d2c4283
commit fda3877852
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 29 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import dataclasses
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from urllib.parse import quote from urllib.parse import quote
@ -165,15 +166,17 @@ async def websocket_resolve_media(
"""Resolve media.""" """Resolve media."""
try: try:
media = await async_resolve_media(hass, msg["media_content_id"]) media = await async_resolve_media(hass, msg["media_content_id"])
url = media.url
except Unresolvable as err: except Unresolvable as err:
connection.send_error(msg["id"], "resolve_media_failed", str(err)) connection.send_error(msg["id"], "resolve_media_failed", str(err))
else: return
if url[0] == "/":
url = async_sign_path( data = dataclasses.asdict(media)
if data["url"][0] == "/":
data["url"] = async_sign_path(
hass, hass,
quote(url), quote(data["url"]),
timedelta(seconds=msg["expires"]), timedelta(seconds=msg["expires"]),
) )
connection.send_result(msg["id"], {"url": url, "mime_type": media.mime_type}) connection.send_result(msg["id"], data)

View file

@ -56,10 +56,6 @@ class LocalSource(MediaSource):
if item.domain != DOMAIN: if item.domain != DOMAIN:
raise Unresolvable("Unknown domain.") raise Unresolvable("Unknown domain.")
if not item.identifier:
# Empty source_dir_id and location
return "", ""
source_dir_id, _, location = item.identifier.partition("/") source_dir_id, _, location = item.identifier.partition("/")
if source_dir_id not in self.hass.config.media_dirs: if source_dir_id not in self.hass.config.media_dirs:
raise Unresolvable("Unknown source directory.") raise Unresolvable("Unknown source directory.")
@ -74,36 +70,39 @@ class LocalSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url.""" """Resolve media to a url."""
source_dir_id, location = self.async_parse_identifier(item) source_dir_id, location = self.async_parse_identifier(item)
if source_dir_id == "" or source_dir_id not in self.hass.config.media_dirs: path = self.async_full_path(source_dir_id, location)
raise Unresolvable("Unknown source directory.") mime_type, _ = mimetypes.guess_type(str(path))
mime_type, _ = mimetypes.guess_type(
str(self.async_full_path(source_dir_id, location))
)
assert isinstance(mime_type, str) assert isinstance(mime_type, str)
return PlayMedia(f"/media/{item.identifier}", mime_type) return PlayMedia(f"/media/{item.identifier}", mime_type)
async def async_browse_media(self, item: MediaSourceItem) -> BrowseMediaSource: async def async_browse_media(self, item: MediaSourceItem) -> BrowseMediaSource:
"""Return media.""" """Return media."""
if item.identifier:
try: try:
source_dir_id, location = self.async_parse_identifier(item) source_dir_id, location = self.async_parse_identifier(item)
except Unresolvable as err: except Unresolvable as err:
raise BrowseError(str(err)) from err raise BrowseError(str(err)) from err
else:
source_dir_id, location = None, ""
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
self._browse_media, source_dir_id, location self._browse_media, source_dir_id, location
) )
return result return result
def _browse_media(self, source_dir_id: str, location: str) -> BrowseMediaSource: def _browse_media(
self, source_dir_id: str | None, location: str
) -> BrowseMediaSource:
"""Browse media.""" """Browse media."""
# If only one media dir is configured, use that as the local media root # If only one media dir is configured, use that as the local media root
if source_dir_id == "" and len(self.hass.config.media_dirs) == 1: if source_dir_id is None and len(self.hass.config.media_dirs) == 1:
source_dir_id = list(self.hass.config.media_dirs)[0] source_dir_id = list(self.hass.config.media_dirs)[0]
# Multiple folder, root is requested # Multiple folder, root is requested
if source_dir_id == "": if source_dir_id is None:
if location: if location:
raise BrowseError("Folder not found.") raise BrowseError("Folder not found.")

View file

@ -1,8 +1,8 @@
"""Test Media Source initialization.""" """Test Media Source initialization."""
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from urllib.parse import quote
import pytest import pytest
import yarl
from homeassistant.components import media_source from homeassistant.components import media_source
from homeassistant.components.media_player import MEDIA_CLASS_DIRECTORY, BrowseError from homeassistant.components.media_player import MEDIA_CLASS_DIRECTORY, BrowseError
@ -159,7 +159,10 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
media = media_source.models.PlayMedia(f"/media/local/{filename}", "audio/mpeg") media = media_source.models.PlayMedia(
f"/media/local/{filename}",
"audio/mpeg",
)
with patch( with patch(
"homeassistant.components.media_source.async_resolve_media", "homeassistant.components.media_source.async_resolve_media",
@ -177,9 +180,13 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
assert msg["success"] assert msg["success"]
assert msg["id"] == 1 assert msg["id"] == 1
assert msg["result"]["url"].startswith(quote(media.url))
assert msg["result"]["mime_type"] == media.mime_type assert msg["result"]["mime_type"] == media.mime_type
# Validate url is signed.
parsed = yarl.URL(msg["result"]["url"])
assert parsed.path == getattr(media, "url")
assert "authSig" in parsed.query
with patch( with patch(
"homeassistant.components.media_source.async_resolve_media", "homeassistant.components.media_source.async_resolve_media",
side_effect=media_source.Unresolvable("test"), side_effect=media_source.Unresolvable("test"),