Add a media source to TTS (#66483)
This commit is contained in:
parent
013d227113
commit
8456c6416e
10 changed files with 283 additions and 75 deletions
|
@ -9,6 +9,7 @@ import io
|
|||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
|
@ -39,10 +40,11 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -69,7 +71,6 @@ CONF_FIELDS = "fields"
|
|||
DEFAULT_CACHE = True
|
||||
DEFAULT_CACHE_DIR = "tts"
|
||||
DEFAULT_TIME_MEMORY = 300
|
||||
DOMAIN = "tts"
|
||||
|
||||
MEM_CACHE_FILENAME = "filename"
|
||||
MEM_CACHE_VOICE = "voice"
|
||||
|
@ -135,12 +136,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
_LOGGER.exception("Error on cache init")
|
||||
return False
|
||||
|
||||
hass.data[DOMAIN] = tts
|
||||
hass.http.register_view(TextToSpeechView(tts))
|
||||
hass.http.register_view(TextToSpeechUrlView(tts))
|
||||
|
||||
# Load service descriptions from tts/services.yaml
|
||||
integration = await async_get_integration(hass, DOMAIN)
|
||||
services_yaml = integration.file_path / "services.yaml"
|
||||
services_yaml = Path(__file__).parent / "services.yaml"
|
||||
services_dict = cast(
|
||||
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
||||
)
|
||||
|
@ -343,7 +344,11 @@ class SpeechManager:
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
provider = self.providers.get(engine)
|
||||
|
||||
if provider is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
||||
use_cache = cache if cache is not None else self.use_cache
|
||||
|
||||
|
|
3
homeassistant/components/tts/const.py
Normal file
3
homeassistant/components/tts/const.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Text-to-speech constants."""
|
||||
|
||||
DOMAIN = "tts"
|
109
homeassistant/components/tts/media_source.py
Normal file
109
homeassistant/components/tts/media_source.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
"""Text-to-speech media source."""
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.components.media_player.const import MEDIA_CLASS_APP
|
||||
from homeassistant.components.media_player.errors import BrowseError
|
||||
from homeassistant.components.media_source.error import Unresolvable
|
||||
from homeassistant.components.media_source.models import (
|
||||
BrowseMediaSource,
|
||||
MediaSource,
|
||||
MediaSourceItem,
|
||||
PlayMedia,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
||||
"""Set up tts media source."""
|
||||
return TTSMediaSource(hass)
|
||||
|
||||
|
||||
class TTSMediaSource(MediaSource):
|
||||
"""Provide text-to-speech providers as media sources."""
|
||||
|
||||
name: str = "Text to Speech"
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize TTSMediaSource."""
|
||||
super().__init__(DOMAIN)
|
||||
self.hass = hass
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
parsed = URL(item.identifier)
|
||||
if "message" not in parsed.query:
|
||||
raise Unresolvable("No message specified.")
|
||||
|
||||
options = dict(parsed.query)
|
||||
kwargs = {
|
||||
"engine": parsed.name,
|
||||
"message": options.pop("message"),
|
||||
"language": options.pop("language", None),
|
||||
"options": options,
|
||||
}
|
||||
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
|
||||
try:
|
||||
url = await manager.async_get_url_path(**kwargs) # type: ignore
|
||||
except HomeAssistantError as err:
|
||||
raise Unresolvable(str(err)) from err
|
||||
|
||||
mime_type = mimetypes.guess_type(url)[0] or "audio/mpeg"
|
||||
|
||||
return PlayMedia(url, mime_type)
|
||||
|
||||
async def async_browse_media(
|
||||
self,
|
||||
item: MediaSourceItem,
|
||||
) -> BrowseMediaSource:
|
||||
"""Return media."""
|
||||
if item.identifier:
|
||||
provider, _, _ = item.identifier.partition("?")
|
||||
return self._provider_item(provider)
|
||||
|
||||
# Root. List providers.
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
children = [self._provider_item(provider) for provider in manager.providers]
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=None,
|
||||
media_class=MEDIA_CLASS_APP,
|
||||
media_content_type="",
|
||||
title=self.name,
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
children_media_class=MEDIA_CLASS_APP,
|
||||
children=children,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _provider_item(self, provider_domain: str) -> BrowseMediaSource:
|
||||
"""Return provider item."""
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
provider = manager.providers.get(provider_domain)
|
||||
|
||||
if provider is None:
|
||||
raise BrowseError("Unknown provider")
|
||||
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=provider_domain,
|
||||
media_class=MEDIA_CLASS_APP,
|
||||
media_content_type="provider",
|
||||
title=provider.name,
|
||||
thumbnail=f"https://brands.home-assistant.io/_/{provider_domain}/logo.png",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
)
|
|
@ -16,7 +16,7 @@ from homeassistant.config import async_process_ha_core_config
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
from tests.components.tts.test_init import mutagen_mock # noqa: F401
|
||||
from tests.components.tts.conftest import mutagen_mock # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
@ -2,9 +2,12 @@
|
|||
|
||||
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.tts import _get_cache_files
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
def pytest_runtest_makereport(item, call):
|
||||
|
@ -16,3 +19,55 @@ def pytest_runtest_makereport(item, call):
|
|||
# set a report attribute for each phase of a call, which can
|
||||
# be "setup", "call", "teardown"
|
||||
setattr(item, f"rep_{rep.when}", rep)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_get_cache_files():
|
||||
"""Mock the list TTS cache function."""
|
||||
with patch(
|
||||
"homeassistant.components.tts._get_cache_files", return_value={}
|
||||
) as mock_cache_files:
|
||||
yield mock_cache_files
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_init_cache_dir():
|
||||
"""Mock the TTS cache dir in memory."""
|
||||
with patch(
|
||||
"homeassistant.components.tts._init_tts_cache_dir",
|
||||
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
|
||||
) as mock_cache_dir:
|
||||
yield mock_cache_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
mock_init_cache_dir.side_effect = None
|
||||
mock_init_cache_dir.return_value = str(tmp_path)
|
||||
|
||||
# Restore original get cache files behavior, we're working with a real dir.
|
||||
mock_get_cache_files.side_effect = _get_cache_files
|
||||
|
||||
yield tmp_path
|
||||
|
||||
if request.node.rep_call.passed:
|
||||
return
|
||||
|
||||
# Print contents of dir if failed
|
||||
print("Content of dir for", request.node.nodeid)
|
||||
for fil in tmp_path.iterdir():
|
||||
print(fil.relative_to(tmp_path))
|
||||
|
||||
# To show the log.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mutagen_mock():
|
||||
"""Mock writing tags."""
|
||||
with patch(
|
||||
"homeassistant.components.tts.SpeechManager.write_tags",
|
||||
side_effect=lambda *args: args[1],
|
||||
) as mock_write_tags:
|
||||
yield mock_write_tags
|
||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import PropertyMock, patch
|
|||
import pytest
|
||||
import yarl
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.demo.tts import DemoProvider
|
||||
from homeassistant.components.media_player.const import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
|
@ -13,13 +14,13 @@ from homeassistant.components.media_player.const import (
|
|||
MEDIA_TYPE_MUSIC,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
import homeassistant.components.tts as tts
|
||||
from homeassistant.components.tts import _get_cache_files
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
|
||||
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
||||
|
||||
|
||||
def relative_url(url):
|
||||
"""Convert an absolute url to a relative one."""
|
||||
|
@ -32,58 +33,6 @@ def demo_provider():
|
|||
return DemoProvider("en")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_get_cache_files():
|
||||
"""Mock the list TTS cache function."""
|
||||
with patch(
|
||||
"homeassistant.components.tts._get_cache_files", return_value={}
|
||||
) as mock_cache_files:
|
||||
yield mock_cache_files
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_init_cache_dir():
|
||||
"""Mock the TTS cache dir in memory."""
|
||||
with patch(
|
||||
"homeassistant.components.tts._init_tts_cache_dir",
|
||||
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
|
||||
) as mock_cache_dir:
|
||||
yield mock_cache_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
mock_init_cache_dir.side_effect = None
|
||||
mock_init_cache_dir.return_value = str(tmp_path)
|
||||
|
||||
# Restore original get cache files behavior, we're working with a real dir.
|
||||
mock_get_cache_files.side_effect = _get_cache_files
|
||||
|
||||
yield tmp_path
|
||||
|
||||
if request.node.rep_call.passed:
|
||||
return
|
||||
|
||||
# Print contents of dir if failed
|
||||
print("Content of dir for", request.node.nodeid)
|
||||
for fil in tmp_path.iterdir():
|
||||
print(fil.relative_to(tmp_path))
|
||||
|
||||
# To show the log.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mutagen_mock():
|
||||
"""Mock writing tags."""
|
||||
with patch(
|
||||
"homeassistant.components.tts.SpeechManager.write_tags",
|
||||
side_effect=lambda *args: args[1],
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass):
|
||||
"""Mock internal URL of the instance."""
|
||||
|
@ -730,7 +679,7 @@ async def test_tags_with_wave(hass, demo_provider):
|
|||
+ "22 56 00 00 88 58 01 00 04 00 10 00 64 61 74 61 00 00 00 00"
|
||||
)
|
||||
|
||||
tagged_data = tts.SpeechManager.write_tags(
|
||||
tagged_data = ORIG_WRITE_TAGS(
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_demo.wav",
|
||||
demo_data,
|
||||
demo_provider,
|
||||
|
|
99
tests/components/tts/test_media_source.py
Normal file
99
tests/components/tts/test_media_source.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
"""Tests for TTS media source."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.media_player.errors import BrowseError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_get_tts_audio(hass):
|
||||
"""Set up media source."""
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"tts",
|
||||
{
|
||||
"tts": {
|
||||
"platform": "demo",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
|
||||
return_value=("mp3", b""),
|
||||
) as mock_get_tts:
|
||||
yield mock_get_tts
|
||||
|
||||
|
||||
async def test_browsing(hass):
|
||||
"""Test browsing TTS media source."""
|
||||
item = await media_source.async_browse_media(hass, "media-source://tts")
|
||||
assert item is not None
|
||||
assert item.title == "Text to Speech"
|
||||
assert len(item.children) == 1
|
||||
assert item.can_play is False
|
||||
assert item.can_expand is True
|
||||
|
||||
item_child = await media_source.async_browse_media(
|
||||
hass, item.children[0].media_content_id
|
||||
)
|
||||
assert item_child is not None
|
||||
assert item_child.title == "Demo"
|
||||
assert item_child.children is None
|
||||
assert item_child.can_play is False
|
||||
assert item_child.can_expand is True
|
||||
|
||||
with pytest.raises(BrowseError):
|
||||
await media_source.async_browse_media(hass, "media-source://tts/non-existing")
|
||||
|
||||
|
||||
async def test_resolving(hass, mock_get_tts_audio):
|
||||
"""Test resolving."""
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/demo?message=Hello%20World"
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Hello World"
|
||||
assert language == "en"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
|
||||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/demo?message=Bye%20World&language=de&voice=Paulus"
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Bye World"
|
||||
assert language == "de"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
|
||||
|
||||
|
||||
async def test_resolving_errors(hass):
|
||||
"""Test resolving."""
|
||||
# No message added
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(hass, "media-source://tts/demo")
|
||||
|
||||
# Non-existing provider
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/non-existing?message=bla"
|
||||
)
|
||||
|
||||
# Non-existing option
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/non-existing?message=bla&non_existing_option=bla"
|
||||
)
|
|
@ -1,6 +1,4 @@
|
|||
"""The tests for the TTS component."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yarl
|
||||
|
||||
|
@ -22,16 +20,6 @@ def relative_url(url):
|
|||
return str(yarl.URL(url).relative())
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mutagen_mock():
|
||||
"""Mock writing tags."""
|
||||
with patch(
|
||||
"homeassistant.components.tts.SpeechManager.write_tags",
|
||||
side_effect=lambda *args: args[1],
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass):
|
||||
"""Mock internal URL of the instance."""
|
||||
|
|
|
@ -15,7 +15,7 @@ import homeassistant.components.tts as tts
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
from tests.components.tts.test_init import mutagen_mock # noqa: F401
|
||||
from tests.components.tts.conftest import mutagen_mock # noqa: F401
|
||||
|
||||
URL = "https://api.voicerss.org/"
|
||||
FORM_DATA = {
|
||||
|
|
|
@ -14,7 +14,7 @@ import homeassistant.components.tts as tts
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
from tests.components.tts.test_init import ( # noqa: F401, pylint: disable=unused-import
|
||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||
mutagen_mock,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue