Add a media source to TTS (#66483)

This commit is contained in:
Paulus Schoutsen 2022-02-14 08:54:12 -08:00 committed by GitHub
parent 013d227113
commit 8456c6416e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 283 additions and 75 deletions

View file

@ -9,6 +9,7 @@ import io
import logging
import mimetypes
import os
from pathlib import Path
import re
from typing import TYPE_CHECKING, Optional, cast
@ -39,10 +40,11 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.yaml import load_yaml
from .const import DOMAIN
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
@ -69,7 +71,6 @@ CONF_FIELDS = "fields"
DEFAULT_CACHE = True
DEFAULT_CACHE_DIR = "tts"
DEFAULT_TIME_MEMORY = 300
DOMAIN = "tts"
MEM_CACHE_FILENAME = "filename"
MEM_CACHE_VOICE = "voice"
@ -135,12 +136,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.exception("Error on cache init")
return False
hass.data[DOMAIN] = tts
hass.http.register_view(TextToSpeechView(tts))
hass.http.register_view(TextToSpeechUrlView(tts))
# Load service descriptions from tts/services.yaml
integration = await async_get_integration(hass, DOMAIN)
services_yaml = integration.file_path / "services.yaml"
services_yaml = Path(__file__).parent / "services.yaml"
services_dict = cast(
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
)
@ -343,7 +344,11 @@ class SpeechManager:
This method is a coroutine.
"""
provider = self.providers[engine]
provider = self.providers.get(engine)
if provider is None:
raise HomeAssistantError(f"Provider {engine} not found")
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
use_cache = cache if cache is not None else self.use_cache

View file

@ -0,0 +1,3 @@
"""Text-to-speech constants."""
DOMAIN = "tts"

View file

@ -0,0 +1,109 @@
"""Text-to-speech media source."""
from __future__ import annotations
import mimetypes
from typing import TYPE_CHECKING
from yarl import URL
from homeassistant.components.media_player.const import MEDIA_CLASS_APP
from homeassistant.components.media_player.errors import BrowseError
from homeassistant.components.media_source.error import Unresolvable
from homeassistant.components.media_source.models import (
BrowseMediaSource,
MediaSource,
MediaSourceItem,
PlayMedia,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from .const import DOMAIN
if TYPE_CHECKING:
from . import SpeechManager
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
"""Set up tts media source."""
return TTSMediaSource(hass)
class TTSMediaSource(MediaSource):
"""Provide text-to-speech providers as media sources."""
name: str = "Text to Speech"
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize TTSMediaSource."""
super().__init__(DOMAIN)
self.hass = hass
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
parsed = URL(item.identifier)
if "message" not in parsed.query:
raise Unresolvable("No message specified.")
options = dict(parsed.query)
kwargs = {
"engine": parsed.name,
"message": options.pop("message"),
"language": options.pop("language", None),
"options": options,
}
manager: SpeechManager = self.hass.data[DOMAIN]
try:
url = await manager.async_get_url_path(**kwargs) # type: ignore
except HomeAssistantError as err:
raise Unresolvable(str(err)) from err
mime_type = mimetypes.guess_type(url)[0] or "audio/mpeg"
return PlayMedia(url, mime_type)
async def async_browse_media(
self,
item: MediaSourceItem,
) -> BrowseMediaSource:
"""Return media."""
if item.identifier:
provider, _, _ = item.identifier.partition("?")
return self._provider_item(provider)
# Root. List providers.
manager: SpeechManager = self.hass.data[DOMAIN]
children = [self._provider_item(provider) for provider in manager.providers]
return BrowseMediaSource(
domain=DOMAIN,
identifier=None,
media_class=MEDIA_CLASS_APP,
media_content_type="",
title=self.name,
can_play=False,
can_expand=True,
children_media_class=MEDIA_CLASS_APP,
children=children,
)
@callback
def _provider_item(self, provider_domain: str) -> BrowseMediaSource:
"""Return provider item."""
manager: SpeechManager = self.hass.data[DOMAIN]
provider = manager.providers.get(provider_domain)
if provider is None:
raise BrowseError("Unknown provider")
return BrowseMediaSource(
domain=DOMAIN,
identifier=provider_domain,
media_class=MEDIA_CLASS_APP,
media_content_type="provider",
title=provider.name,
thumbnail=f"https://brands.home-assistant.io/_/{provider_domain}/logo.png",
can_play=False,
can_expand=True,
)

View file

@ -16,7 +16,7 @@ from homeassistant.config import async_process_ha_core_config
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
from tests.components.tts.test_init import mutagen_mock # noqa: F401
from tests.components.tts.conftest import mutagen_mock # noqa: F401
@pytest.fixture(autouse=True)

View file

@ -2,9 +2,12 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
"""
from unittest.mock import patch
import pytest
from homeassistant.components.tts import _get_cache_files
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
@ -16,3 +19,55 @@ def pytest_runtest_makereport(item, call):
# set a report attribute for each phase of a call, which can
# be "setup", "call", "teardown"
setattr(item, f"rep_{rep.when}", rep)
@pytest.fixture(autouse=True)
def mock_get_cache_files():
"""Mock the list TTS cache function."""
with patch(
"homeassistant.components.tts._get_cache_files", return_value={}
) as mock_cache_files:
yield mock_cache_files
@pytest.fixture(autouse=True)
def mock_init_cache_dir():
"""Mock the TTS cache dir in memory."""
with patch(
"homeassistant.components.tts._init_tts_cache_dir",
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
) as mock_cache_dir:
yield mock_cache_dir
@pytest.fixture
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
"""Mock the TTS cache dir with empty dir."""
mock_init_cache_dir.side_effect = None
mock_init_cache_dir.return_value = str(tmp_path)
# Restore original get cache files behavior, we're working with a real dir.
mock_get_cache_files.side_effect = _get_cache_files
yield tmp_path
if request.node.rep_call.passed:
return
# Print contents of dir if failed
print("Content of dir for", request.node.nodeid)
for fil in tmp_path.iterdir():
print(fil.relative_to(tmp_path))
# To show the log.
assert False
@pytest.fixture(autouse=True)
def mutagen_mock():
"""Mock writing tags."""
with patch(
"homeassistant.components.tts.SpeechManager.write_tags",
side_effect=lambda *args: args[1],
) as mock_write_tags:
yield mock_write_tags

View file

@ -5,6 +5,7 @@ from unittest.mock import PropertyMock, patch
import pytest
import yarl
from homeassistant.components import tts
from homeassistant.components.demo.tts import DemoProvider
from homeassistant.components.media_player.const import (
ATTR_MEDIA_CONTENT_ID,
@ -13,13 +14,13 @@ from homeassistant.components.media_player.const import (
MEDIA_TYPE_MUSIC,
SERVICE_PLAY_MEDIA,
)
import homeassistant.components.tts as tts
from homeassistant.components.tts import _get_cache_files
from homeassistant.config import async_process_ha_core_config
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
def relative_url(url):
"""Convert an absolute url to a relative one."""
@ -32,58 +33,6 @@ def demo_provider():
return DemoProvider("en")
@pytest.fixture(autouse=True)
def mock_get_cache_files():
"""Mock the list TTS cache function."""
with patch(
"homeassistant.components.tts._get_cache_files", return_value={}
) as mock_cache_files:
yield mock_cache_files
@pytest.fixture(autouse=True)
def mock_init_cache_dir():
"""Mock the TTS cache dir in memory."""
with patch(
"homeassistant.components.tts._init_tts_cache_dir",
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
) as mock_cache_dir:
yield mock_cache_dir
@pytest.fixture
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
"""Mock the TTS cache dir with empty dir."""
mock_init_cache_dir.side_effect = None
mock_init_cache_dir.return_value = str(tmp_path)
# Restore original get cache files behavior, we're working with a real dir.
mock_get_cache_files.side_effect = _get_cache_files
yield tmp_path
if request.node.rep_call.passed:
return
# Print contents of dir if failed
print("Content of dir for", request.node.nodeid)
for fil in tmp_path.iterdir():
print(fil.relative_to(tmp_path))
# To show the log.
assert False
@pytest.fixture()
def mutagen_mock():
"""Mock writing tags."""
with patch(
"homeassistant.components.tts.SpeechManager.write_tags",
side_effect=lambda *args: args[1],
):
yield
@pytest.fixture(autouse=True)
async def internal_url_mock(hass):
"""Mock internal URL of the instance."""
@ -730,7 +679,7 @@ async def test_tags_with_wave(hass, demo_provider):
+ "22 56 00 00 88 58 01 00 04 00 10 00 64 61 74 61 00 00 00 00"
)
tagged_data = tts.SpeechManager.write_tags(
tagged_data = ORIG_WRITE_TAGS(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_demo.wav",
demo_data,
demo_provider,

View file

@ -0,0 +1,99 @@
"""Tests for TTS media source."""
from unittest.mock import patch
import pytest
from homeassistant.components import media_source
from homeassistant.components.media_player.errors import BrowseError
from homeassistant.setup import async_setup_component
@pytest.fixture(autouse=True)
async def mock_get_tts_audio(hass):
"""Set up media source."""
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(
hass,
"tts",
{
"tts": {
"platform": "demo",
}
},
)
with patch(
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
return_value=("mp3", b""),
) as mock_get_tts:
yield mock_get_tts
async def test_browsing(hass):
"""Test browsing TTS media source."""
item = await media_source.async_browse_media(hass, "media-source://tts")
assert item is not None
assert item.title == "Text to Speech"
assert len(item.children) == 1
assert item.can_play is False
assert item.can_expand is True
item_child = await media_source.async_browse_media(
hass, item.children[0].media_content_id
)
assert item_child is not None
assert item_child.title == "Demo"
assert item_child.children is None
assert item_child.can_play is False
assert item_child.can_expand is True
with pytest.raises(BrowseError):
await media_source.async_browse_media(hass, "media-source://tts/non-existing")
async def test_resolving(hass, mock_get_tts_audio):
"""Test resolving."""
media = await media_source.async_resolve_media(
hass, "media-source://tts/demo?message=Hello%20World"
)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World"
assert language == "en"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
# Pass language and options
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass, "media-source://tts/demo?message=Bye%20World&language=de&voice=Paulus"
)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Bye World"
assert language == "de"
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
async def test_resolving_errors(hass):
"""Test resolving."""
# No message added
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "media-source://tts/demo")
# Non-existing provider
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(
hass, "media-source://tts/non-existing?message=bla"
)
# Non-existing option
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(
hass, "media-source://tts/non-existing?message=bla&non_existing_option=bla"
)

View file

@ -1,6 +1,4 @@
"""The tests for the TTS component."""
from unittest.mock import patch
import pytest
import yarl
@ -22,16 +20,6 @@ def relative_url(url):
return str(yarl.URL(url).relative())
@pytest.fixture(autouse=True)
def mutagen_mock():
"""Mock writing tags."""
with patch(
"homeassistant.components.tts.SpeechManager.write_tags",
side_effect=lambda *args: args[1],
):
yield
@pytest.fixture(autouse=True)
async def internal_url_mock(hass):
"""Mock internal URL of the instance."""

View file

@ -15,7 +15,7 @@ import homeassistant.components.tts as tts
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.test_init import mutagen_mock # noqa: F401
from tests.components.tts.conftest import mutagen_mock # noqa: F401
URL = "https://api.voicerss.org/"
FORM_DATA = {

View file

@ -14,7 +14,7 @@ import homeassistant.components.tts as tts
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.test_init import ( # noqa: F401, pylint: disable=unused-import
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mutagen_mock,
)