Improved local media ID handling (#67083)
This commit is contained in:
parent
c76d2c4283
commit
fda3877852
3 changed files with 38 additions and 29 deletions
|
@ -2,6 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
import dataclasses
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
@ -165,15 +166,17 @@ async def websocket_resolve_media(
|
||||||
"""Resolve media."""
|
"""Resolve media."""
|
||||||
try:
|
try:
|
||||||
media = await async_resolve_media(hass, msg["media_content_id"])
|
media = await async_resolve_media(hass, msg["media_content_id"])
|
||||||
url = media.url
|
|
||||||
except Unresolvable as err:
|
except Unresolvable as err:
|
||||||
connection.send_error(msg["id"], "resolve_media_failed", str(err))
|
connection.send_error(msg["id"], "resolve_media_failed", str(err))
|
||||||
else:
|
return
|
||||||
if url[0] == "/":
|
|
||||||
url = async_sign_path(
|
data = dataclasses.asdict(media)
|
||||||
|
|
||||||
|
if data["url"][0] == "/":
|
||||||
|
data["url"] = async_sign_path(
|
||||||
hass,
|
hass,
|
||||||
quote(url),
|
quote(data["url"]),
|
||||||
timedelta(seconds=msg["expires"]),
|
timedelta(seconds=msg["expires"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
connection.send_result(msg["id"], {"url": url, "mime_type": media.mime_type})
|
connection.send_result(msg["id"], data)
|
||||||
|
|
|
@ -56,10 +56,6 @@ class LocalSource(MediaSource):
|
||||||
if item.domain != DOMAIN:
|
if item.domain != DOMAIN:
|
||||||
raise Unresolvable("Unknown domain.")
|
raise Unresolvable("Unknown domain.")
|
||||||
|
|
||||||
if not item.identifier:
|
|
||||||
# Empty source_dir_id and location
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
source_dir_id, _, location = item.identifier.partition("/")
|
source_dir_id, _, location = item.identifier.partition("/")
|
||||||
if source_dir_id not in self.hass.config.media_dirs:
|
if source_dir_id not in self.hass.config.media_dirs:
|
||||||
raise Unresolvable("Unknown source directory.")
|
raise Unresolvable("Unknown source directory.")
|
||||||
|
@ -74,36 +70,39 @@ class LocalSource(MediaSource):
|
||||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||||
"""Resolve media to a url."""
|
"""Resolve media to a url."""
|
||||||
source_dir_id, location = self.async_parse_identifier(item)
|
source_dir_id, location = self.async_parse_identifier(item)
|
||||||
if source_dir_id == "" or source_dir_id not in self.hass.config.media_dirs:
|
path = self.async_full_path(source_dir_id, location)
|
||||||
raise Unresolvable("Unknown source directory.")
|
mime_type, _ = mimetypes.guess_type(str(path))
|
||||||
|
|
||||||
mime_type, _ = mimetypes.guess_type(
|
|
||||||
str(self.async_full_path(source_dir_id, location))
|
|
||||||
)
|
|
||||||
assert isinstance(mime_type, str)
|
assert isinstance(mime_type, str)
|
||||||
return PlayMedia(f"/media/{item.identifier}", mime_type)
|
return PlayMedia(f"/media/{item.identifier}", mime_type)
|
||||||
|
|
||||||
async def async_browse_media(self, item: MediaSourceItem) -> BrowseMediaSource:
|
async def async_browse_media(self, item: MediaSourceItem) -> BrowseMediaSource:
|
||||||
"""Return media."""
|
"""Return media."""
|
||||||
|
if item.identifier:
|
||||||
try:
|
try:
|
||||||
source_dir_id, location = self.async_parse_identifier(item)
|
source_dir_id, location = self.async_parse_identifier(item)
|
||||||
except Unresolvable as err:
|
except Unresolvable as err:
|
||||||
raise BrowseError(str(err)) from err
|
raise BrowseError(str(err)) from err
|
||||||
|
|
||||||
|
else:
|
||||||
|
source_dir_id, location = None, ""
|
||||||
|
|
||||||
result = await self.hass.async_add_executor_job(
|
result = await self.hass.async_add_executor_job(
|
||||||
self._browse_media, source_dir_id, location
|
self._browse_media, source_dir_id, location
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _browse_media(self, source_dir_id: str, location: str) -> BrowseMediaSource:
|
def _browse_media(
|
||||||
|
self, source_dir_id: str | None, location: str
|
||||||
|
) -> BrowseMediaSource:
|
||||||
"""Browse media."""
|
"""Browse media."""
|
||||||
|
|
||||||
# If only one media dir is configured, use that as the local media root
|
# If only one media dir is configured, use that as the local media root
|
||||||
if source_dir_id == "" and len(self.hass.config.media_dirs) == 1:
|
if source_dir_id is None and len(self.hass.config.media_dirs) == 1:
|
||||||
source_dir_id = list(self.hass.config.media_dirs)[0]
|
source_dir_id = list(self.hass.config.media_dirs)[0]
|
||||||
|
|
||||||
# Multiple folder, root is requested
|
# Multiple folder, root is requested
|
||||||
if source_dir_id == "":
|
if source_dir_id is None:
|
||||||
if location:
|
if location:
|
||||||
raise BrowseError("Folder not found.")
|
raise BrowseError("Folder not found.")
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
"""Test Media Source initialization."""
|
"""Test Media Source initialization."""
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yarl
|
||||||
|
|
||||||
from homeassistant.components import media_source
|
from homeassistant.components import media_source
|
||||||
from homeassistant.components.media_player import MEDIA_CLASS_DIRECTORY, BrowseError
|
from homeassistant.components.media_player import MEDIA_CLASS_DIRECTORY, BrowseError
|
||||||
|
@ -159,7 +159,10 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
|
||||||
|
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
media = media_source.models.PlayMedia(f"/media/local/{filename}", "audio/mpeg")
|
media = media_source.models.PlayMedia(
|
||||||
|
f"/media/local/{filename}",
|
||||||
|
"audio/mpeg",
|
||||||
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
@ -177,9 +180,13 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
|
||||||
|
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["id"] == 1
|
assert msg["id"] == 1
|
||||||
assert msg["result"]["url"].startswith(quote(media.url))
|
|
||||||
assert msg["result"]["mime_type"] == media.mime_type
|
assert msg["result"]["mime_type"] == media.mime_type
|
||||||
|
|
||||||
|
# Validate url is signed.
|
||||||
|
parsed = yarl.URL(msg["result"]["url"])
|
||||||
|
assert parsed.path == getattr(media, "url")
|
||||||
|
assert "authSig" in parsed.query
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
side_effect=media_source.Unresolvable("test"),
|
side_effect=media_source.Unresolvable("test"),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue