TTS Cleanup and expose get audio (#79065)
This commit is contained in:
parent
39ddc37d76
commit
697e7b3a20
4 changed files with 250 additions and 85 deletions
|
@ -34,7 +34,7 @@ from .const import (
|
||||||
URI_SCHEME_REGEX,
|
URI_SCHEME_REGEX,
|
||||||
)
|
)
|
||||||
from .error import MediaSourceError, Unresolvable
|
from .error import MediaSourceError, Unresolvable
|
||||||
from .models import BrowseMediaSource, MediaSourceItem, PlayMedia
|
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
@ -46,6 +46,7 @@ __all__ = [
|
||||||
"PlayMedia",
|
"PlayMedia",
|
||||||
"MediaSourceItem",
|
"MediaSourceItem",
|
||||||
"Unresolvable",
|
"Unresolvable",
|
||||||
|
"MediaSource",
|
||||||
"MediaSourceError",
|
"MediaSourceError",
|
||||||
"MEDIA_CLASS_MAP",
|
"MEDIA_CLASS_MAP",
|
||||||
"MEDIA_MIME_TYPES",
|
"MEDIA_MIME_TYPES",
|
||||||
|
|
|
@ -11,7 +11,7 @@ import mimetypes
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import mutagen
|
import mutagen
|
||||||
|
@ -28,7 +28,6 @@ from homeassistant.components.media_player import (
|
||||||
SERVICE_PLAY_MEDIA,
|
SERVICE_PLAY_MEDIA,
|
||||||
MediaType,
|
MediaType,
|
||||||
)
|
)
|
||||||
from homeassistant.components.media_source import generate_media_source_id
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID,
|
ATTR_ENTITY_ID,
|
||||||
CONF_DESCRIPTION,
|
CONF_DESCRIPTION,
|
||||||
|
@ -48,6 +47,7 @@ from homeassistant.util.network import normalize_url
|
||||||
from homeassistant.util.yaml import load_yaml
|
from homeassistant.util.yaml import load_yaml
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -74,9 +74,6 @@ DEFAULT_CACHE = True
|
||||||
DEFAULT_CACHE_DIR = "tts"
|
DEFAULT_CACHE_DIR = "tts"
|
||||||
DEFAULT_TIME_MEMORY = 300
|
DEFAULT_TIME_MEMORY = 300
|
||||||
|
|
||||||
MEM_CACHE_FILENAME = "filename"
|
|
||||||
MEM_CACHE_VOICE = "voice"
|
|
||||||
|
|
||||||
SERVICE_CLEAR_CACHE = "clear_cache"
|
SERVICE_CLEAR_CACHE = "clear_cache"
|
||||||
SERVICE_SAY = "say"
|
SERVICE_SAY = "say"
|
||||||
|
|
||||||
|
@ -131,6 +128,24 @@ SCHEMA_SERVICE_SAY = vol.Schema(
|
||||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||||
|
|
||||||
|
|
||||||
|
class TTSCache(TypedDict):
    """Cached TTS file."""

    # File name of the audio: the cache key followed by the file extension.
    filename: str
    # Raw audio data for that file.
    voice: bytes
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_media_source_audio(
    hass: HomeAssistant,
    media_source_id: str,
) -> tuple[str, bytes]:
    """Get TTS audio as extension, data.

    The media source ID is decoded into keyword arguments with
    media_source_id_to_kwargs and handed to the SpeechManager stored in
    hass.data[DOMAIN], whose async_get_tts_audio result is returned as is.
    """
    manager: SpeechManager = hass.data[DOMAIN]
    return await manager.async_get_tts_audio(
        **media_source_id_to_kwargs(media_source_id),
    )
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up TTS."""
|
"""Set up TTS."""
|
||||||
tts = SpeechManager(hass)
|
tts = SpeechManager(hass)
|
||||||
|
@ -197,21 +212,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
async def async_say_handle(service: ServiceCall) -> None:
|
async def async_say_handle(service: ServiceCall) -> None:
|
||||||
"""Service handle for say."""
|
"""Service handle for say."""
|
||||||
entity_ids = service.data[ATTR_ENTITY_ID]
|
entity_ids = service.data[ATTR_ENTITY_ID]
|
||||||
message = service.data[ATTR_MESSAGE]
|
|
||||||
cache = service.data.get(ATTR_CACHE)
|
|
||||||
language = service.data.get(ATTR_LANGUAGE)
|
|
||||||
options = service.data.get(ATTR_OPTIONS)
|
|
||||||
|
|
||||||
tts.process_options(p_type, language, options)
|
|
||||||
params = {
|
|
||||||
"message": message,
|
|
||||||
}
|
|
||||||
if cache is not None:
|
|
||||||
params["cache"] = "true" if cache else "false"
|
|
||||||
if language is not None:
|
|
||||||
params["language"] = language
|
|
||||||
if options is not None:
|
|
||||||
params.update(options)
|
|
||||||
|
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
DOMAIN_MP,
|
DOMAIN_MP,
|
||||||
|
@ -219,8 +219,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
{
|
{
|
||||||
ATTR_ENTITY_ID: entity_ids,
|
ATTR_ENTITY_ID: entity_ids,
|
||||||
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
||||||
DOMAIN,
|
hass,
|
||||||
str(yarl.URL.build(path=p_type, query=params)),
|
engine=p_type,
|
||||||
|
message=service.data[ATTR_MESSAGE],
|
||||||
|
language=service.data.get(ATTR_LANGUAGE),
|
||||||
|
options=service.data.get(ATTR_OPTIONS),
|
||||||
|
cache=service.data.get(ATTR_CACHE),
|
||||||
),
|
),
|
||||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||||
ATTR_MEDIA_ANNOUNCE: True,
|
ATTR_MEDIA_ANNOUNCE: True,
|
||||||
|
@ -296,7 +300,7 @@ class SpeechManager:
|
||||||
self.time_memory = DEFAULT_TIME_MEMORY
|
self.time_memory = DEFAULT_TIME_MEMORY
|
||||||
self.base_url: str | None = None
|
self.base_url: str | None = None
|
||||||
self.file_cache: dict[str, str] = {}
|
self.file_cache: dict[str, str] = {}
|
||||||
self.mem_cache: dict[str, dict[str, str | bytes]] = {}
|
self.mem_cache: dict[str, TTSCache] = {}
|
||||||
|
|
||||||
async def async_init_cache(
|
async def async_init_cache(
|
||||||
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
|
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
|
||||||
|
@ -380,10 +384,11 @@ class SpeechManager:
|
||||||
options = options or provider.default_options
|
options = options or provider.default_options
|
||||||
|
|
||||||
if options is not None:
|
if options is not None:
|
||||||
|
supported_options = provider.supported_options or []
|
||||||
invalid_opts = [
|
invalid_opts = [
|
||||||
opt_name
|
opt_name
|
||||||
for opt_name in options.keys()
|
for opt_name in options.keys()
|
||||||
if opt_name not in (provider.supported_options or [])
|
if opt_name not in supported_options
|
||||||
]
|
]
|
||||||
if invalid_opts:
|
if invalid_opts:
|
||||||
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
|
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
|
||||||
|
@ -403,25 +408,25 @@ class SpeechManager:
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
language, options = self.process_options(engine, language, options)
|
language, options = self.process_options(engine, language, options)
|
||||||
options_key = _hash_options(options) if options else "-"
|
cache_key = self._generate_cache_key(message, language, options, engine)
|
||||||
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
|
||||||
use_cache = cache if cache is not None else self.use_cache
|
use_cache = cache if cache is not None else self.use_cache
|
||||||
|
|
||||||
key = KEY_PATTERN.format(
|
|
||||||
msg_hash, language.replace("_", "-"), options_key, engine
|
|
||||||
).lower()
|
|
||||||
|
|
||||||
# Is speech already in memory
|
# Is speech already in memory
|
||||||
if key in self.mem_cache:
|
if cache_key in self.mem_cache:
|
||||||
filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
|
filename = self.mem_cache[cache_key]["filename"]
|
||||||
# Is file store in file cache
|
# Is file store in file cache
|
||||||
elif use_cache and key in self.file_cache:
|
elif use_cache and cache_key in self.file_cache:
|
||||||
filename = self.file_cache[key]
|
filename = self.file_cache[cache_key]
|
||||||
self.hass.async_create_task(self.async_file_to_mem(key))
|
self.hass.async_create_task(self._async_file_to_mem(cache_key))
|
||||||
# Load speech from provider into memory
|
# Load speech from provider into memory
|
||||||
else:
|
else:
|
||||||
filename = await self.async_get_tts_audio(
|
filename = await self._async_get_tts_audio(
|
||||||
engine, key, message, use_cache, language, options
|
engine,
|
||||||
|
cache_key,
|
||||||
|
message,
|
||||||
|
use_cache,
|
||||||
|
language,
|
||||||
|
options,
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"/api/tts_proxy/{filename}"
|
return f"/api/tts_proxy/{filename}"
|
||||||
|
@ -429,13 +434,54 @@ class SpeechManager:
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
self,
|
self,
|
||||||
engine: str,
|
engine: str,
|
||||||
key: str,
|
message: str,
|
||||||
|
cache: bool | None = None,
|
||||||
|
language: str | None = None,
|
||||||
|
options: dict | None = None,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
"""Fetch TTS audio."""
|
||||||
|
language, options = self.process_options(engine, language, options)
|
||||||
|
cache_key = self._generate_cache_key(message, language, options, engine)
|
||||||
|
use_cache = cache if cache is not None else self.use_cache
|
||||||
|
|
||||||
|
# If we have the file, load it into memory if necessary
|
||||||
|
if cache_key not in self.mem_cache:
|
||||||
|
if use_cache and cache_key in self.file_cache:
|
||||||
|
await self._async_file_to_mem(cache_key)
|
||||||
|
else:
|
||||||
|
await self._async_get_tts_audio(
|
||||||
|
engine, cache_key, message, use_cache, language, options
|
||||||
|
)
|
||||||
|
|
||||||
|
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
||||||
|
data = self.mem_cache[cache_key]["voice"]
|
||||||
|
return extension, data
|
||||||
|
|
||||||
|
@callback
def _generate_cache_key(
    self,
    message: str,
    language: str,
    options: dict | None,
    engine: str,
) -> str:
    """Generate a cache key for a message.

    The key is KEY_PATTERN filled with, in order: the SHA-1 hex digest of the
    UTF-8 encoded message, the language with underscores replaced by dashes,
    the hashed options ("-" when options is empty or None) and the engine
    name. The result is lower-cased.
    """
    options_key = _hash_options(options) if options else "-"
    msg_hash = hashlib.sha1(message.encode("utf-8")).hexdigest()
    return KEY_PATTERN.format(
        msg_hash, language.replace("_", "-"), options_key, engine
    ).lower()
|
||||||
|
|
||||||
|
async def _async_get_tts_audio(
|
||||||
|
self,
|
||||||
|
engine: str,
|
||||||
|
cache_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
cache: bool,
|
cache: bool,
|
||||||
language: str,
|
language: str,
|
||||||
options: dict | None,
|
options: dict | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Receive TTS and store for view in cache.
|
"""Receive TTS, store for view in cache and return filename.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
|
@ -446,7 +492,7 @@ class SpeechManager:
|
||||||
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
||||||
|
|
||||||
# Create file infos
|
# Create file infos
|
||||||
filename = f"{key}.{extension}".lower()
|
filename = f"{cache_key}.{extension}".lower()
|
||||||
|
|
||||||
# Validate filename
|
# Validate filename
|
||||||
if not _RE_VOICE_FILE.match(filename):
|
if not _RE_VOICE_FILE.match(filename):
|
||||||
|
@ -456,14 +502,18 @@ class SpeechManager:
|
||||||
|
|
||||||
# Save to memory
|
# Save to memory
|
||||||
data = self.write_tags(filename, data, provider, message, language, options)
|
data = self.write_tags(filename, data, provider, message, language, options)
|
||||||
self._async_store_to_memcache(key, filename, data)
|
self._async_store_to_memcache(cache_key, filename, data)
|
||||||
|
|
||||||
if cache:
|
if cache:
|
||||||
self.hass.async_create_task(self.async_save_tts_audio(key, filename, data))
|
self.hass.async_create_task(
|
||||||
|
self._async_save_tts_audio(cache_key, filename, data)
|
||||||
|
)
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
|
async def _async_save_tts_audio(
|
||||||
|
self, cache_key: str, filename: str, data: bytes
|
||||||
|
) -> None:
|
||||||
"""Store voice data to file and file_cache.
|
"""Store voice data to file and file_cache.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -477,17 +527,17 @@ class SpeechManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.hass.async_add_executor_job(save_speech)
|
await self.hass.async_add_executor_job(save_speech)
|
||||||
self.file_cache[key] = filename
|
self.file_cache[cache_key] = filename
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||||
|
|
||||||
async def async_file_to_mem(self, key: str) -> None:
|
async def _async_file_to_mem(self, cache_key: str) -> None:
|
||||||
"""Load voice from file cache into memory.
|
"""Load voice from file cache into memory.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
if not (filename := self.file_cache.get(key)):
|
if not (filename := self.file_cache.get(cache_key)):
|
||||||
raise HomeAssistantError(f"Key {key} not in file cache!")
|
raise HomeAssistantError(f"Key {cache_key} not in file cache!")
|
||||||
|
|
||||||
voice_file = os.path.join(self.cache_dir, filename)
|
voice_file = os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
|
@ -499,20 +549,22 @@ class SpeechManager:
|
||||||
try:
|
try:
|
||||||
data = await self.hass.async_add_executor_job(load_speech)
|
data = await self.hass.async_add_executor_job(load_speech)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
del self.file_cache[key]
|
del self.file_cache[cache_key]
|
||||||
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
||||||
|
|
||||||
self._async_store_to_memcache(key, filename, data)
|
self._async_store_to_memcache(cache_key, filename, data)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
|
def _async_store_to_memcache(
|
||||||
|
self, cache_key: str, filename: str, data: bytes
|
||||||
|
) -> None:
|
||||||
"""Store data to memcache and set timer to remove it."""
|
"""Store data to memcache and set timer to remove it."""
|
||||||
self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}
|
self.mem_cache[cache_key] = {"filename": filename, "voice": data}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_from_mem() -> None:
|
def async_remove_from_mem() -> None:
|
||||||
"""Cleanup memcache."""
|
"""Cleanup memcache."""
|
||||||
self.mem_cache.pop(key, None)
|
self.mem_cache.pop(cache_key, None)
|
||||||
|
|
||||||
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
|
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
|
||||||
|
|
||||||
|
@ -524,17 +576,17 @@ class SpeechManager:
|
||||||
if not (record := _RE_VOICE_FILE.match(filename.lower())):
|
if not (record := _RE_VOICE_FILE.match(filename.lower())):
|
||||||
raise HomeAssistantError("Wrong tts file format!")
|
raise HomeAssistantError("Wrong tts file format!")
|
||||||
|
|
||||||
key = KEY_PATTERN.format(
|
cache_key = KEY_PATTERN.format(
|
||||||
record.group(1), record.group(2), record.group(3), record.group(4)
|
record.group(1), record.group(2), record.group(3), record.group(4)
|
||||||
)
|
)
|
||||||
|
|
||||||
if key not in self.mem_cache:
|
if cache_key not in self.mem_cache:
|
||||||
if key not in self.file_cache:
|
if cache_key not in self.file_cache:
|
||||||
raise HomeAssistantError(f"{key} not in cache!")
|
raise HomeAssistantError(f"{cache_key} not in cache!")
|
||||||
await self.async_file_to_mem(key)
|
await self._async_file_to_mem(cache_key)
|
||||||
|
|
||||||
content, _ = mimetypes.guess_type(filename)
|
content, _ = mimetypes.guess_type(filename)
|
||||||
return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])
|
return content, self.mem_cache[cache_key]["voice"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_tags(
|
def write_tags(
|
||||||
|
|
|
@ -2,17 +2,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from homeassistant.components.media_player import BrowseError, MediaClass
|
from homeassistant.components.media_player import BrowseError, MediaClass
|
||||||
from homeassistant.components.media_source.error import Unresolvable
|
from homeassistant.components.media_source import (
|
||||||
from homeassistant.components.media_source.models import (
|
|
||||||
BrowseMediaSource,
|
BrowseMediaSource,
|
||||||
MediaSource,
|
MediaSource,
|
||||||
MediaSourceItem,
|
MediaSourceItem,
|
||||||
PlayMedia,
|
PlayMedia,
|
||||||
|
Unresolvable,
|
||||||
|
generate_media_source_id as ms_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
@ -29,6 +30,75 @@ async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
||||||
return TTSMediaSource(hass)
|
return TTSMediaSource(hass)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
def generate_media_source_id(
    hass: HomeAssistant,
    message: str,
    engine: str | None = None,
    language: str | None = None,
    options: dict | None = None,
    cache: bool | None = None,
) -> str:
    """Generate a media source ID for text-to-speech.

    When no engine is given, "cloud" is used if it is among the manager's
    providers; otherwise the first provider is used.

    Raises HomeAssistantError when no engine is given and the manager has no
    providers.
    """
    manager: SpeechManager = hass.data[DOMAIN]

    if engine is None:
        if not manager.providers:
            raise HomeAssistantError("No TTS providers available")
        if "cloud" in manager.providers:
            engine = "cloud"
        else:
            engine = next(iter(manager.providers))

    # Called for its checks only; the returned (language, options) is unused.
    manager.process_options(engine, language, options)

    params = {"message": message}
    if cache is not None:
        params["cache"] = "true" if cache else "false"
    if language is not None:
        params["language"] = language
    if options is not None:
        params.update(options)

    return ms_generate_media_source_id(
        DOMAIN,
        str(URL.build(path=engine, query=params)),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class MediaSourceOptions(TypedDict):
|
||||||
|
"""Media source options."""
|
||||||
|
|
||||||
|
engine: str
|
||||||
|
message: str
|
||||||
|
language: str | None
|
||||||
|
options: dict | None
|
||||||
|
cache: bool | None
|
||||||
|
|
||||||
|
|
||||||
|
@callback
def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
    """Turn a media source ID into options.

    The URL name becomes the engine. The "message", "language" and "cache"
    query parameters are extracted; every remaining query parameter is
    returned under "options".

    Raises Unresolvable when the ID has no "message" query parameter.
    """
    parsed = URL(media_source_id)
    if "message" not in parsed.query:
        raise Unresolvable("No message specified.")

    # Pull the reserved keys out first so that `options` is final by the time
    # it is placed in the result.
    options = dict(parsed.query)
    message = options.pop("message")
    language = options.pop("language", None)

    # Only the literal string "true" maps to True; a missing parameter stays
    # None.
    cache: bool | None = None
    if "cache" in options:
        cache = options.pop("cache") == "true"

    kwargs: MediaSourceOptions = {
        "engine": parsed.name,
        "message": message,
        "language": language,
        "options": options,
        "cache": cache,
    }
    return kwargs
|
||||||
|
|
||||||
|
|
||||||
class TTSMediaSource(MediaSource):
|
class TTSMediaSource(MediaSource):
|
||||||
"""Provide text-to-speech providers as media sources."""
|
"""Provide text-to-speech providers as media sources."""
|
||||||
|
|
||||||
|
@ -41,24 +111,12 @@ class TTSMediaSource(MediaSource):
|
||||||
|
|
||||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||||
"""Resolve media to a url."""
|
"""Resolve media to a url."""
|
||||||
parsed = URL(item.identifier)
|
|
||||||
if "message" not in parsed.query:
|
|
||||||
raise Unresolvable("No message specified.")
|
|
||||||
|
|
||||||
options = dict(parsed.query)
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"engine": parsed.name,
|
|
||||||
"message": options.pop("message"),
|
|
||||||
"language": options.pop("language", None),
|
|
||||||
"options": options,
|
|
||||||
}
|
|
||||||
if "cache" in options:
|
|
||||||
kwargs["cache"] = options.pop("cache") == "true"
|
|
||||||
|
|
||||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = await manager.async_get_url_path(**kwargs)
|
url = await manager.async_get_url_path(
|
||||||
|
**media_source_id_to_kwargs(item.identifier)
|
||||||
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
raise Unresolvable(str(err)) from err
|
raise Unresolvable(str(err)) from err
|
||||||
|
|
||||||
|
|
|
@ -49,13 +49,18 @@ async def internal_url_mock(hass):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_component_demo(hass):
|
@pytest.fixture
async def setup_tts(hass):
    """Set up the TTS component with the demo platform.

    The demo integration's async_setup is patched to return True while the
    component is being set up.
    """
    with patch("homeassistant.components.demo.async_setup", return_value=True):
        assert await async_setup_component(
            hass, tts.DOMAIN, {"tts": {"platform": "demo"}}
        )
        await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_setup_component_demo(hass, setup_tts):
|
||||||
"""Set up the demo platform with defaults."""
|
"""Set up the demo platform with defaults."""
|
||||||
config = {tts.DOMAIN: {"platform": "demo"}}
|
|
||||||
|
|
||||||
with assert_setup_component(1, tts.DOMAIN):
|
|
||||||
assert await async_setup_component(hass, tts.DOMAIN, config)
|
|
||||||
|
|
||||||
assert hass.services.has_service(tts.DOMAIN, "demo_say")
|
assert hass.services.has_service(tts.DOMAIN, "demo_say")
|
||||||
assert hass.services.has_service(tts.DOMAIN, "clear_cache")
|
assert hass.services.has_service(tts.DOMAIN, "clear_cache")
|
||||||
assert f"{tts.DOMAIN}.demo" in hass.config.components
|
assert f"{tts.DOMAIN}.demo" in hass.config.components
|
||||||
|
@ -421,12 +426,14 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
||||||
with assert_setup_component(1, tts.DOMAIN):
|
with assert_setup_component(1, tts.DOMAIN):
|
||||||
assert await async_setup_component(hass, tts.DOMAIN, config)
|
assert await async_setup_component(hass, tts.DOMAIN, config)
|
||||||
|
|
||||||
|
message = "There is someone at the door."
|
||||||
|
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
tts.DOMAIN,
|
tts.DOMAIN,
|
||||||
"demo_say",
|
"demo_say",
|
||||||
{
|
{
|
||||||
"entity_id": "media_player.something",
|
"entity_id": "media_player.something",
|
||||||
tts.ATTR_MESSAGE: "There is someone at the door.",
|
tts.ATTR_MESSAGE: message,
|
||||||
},
|
},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
|
@ -440,13 +447,19 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_demo.mp3",
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_demo.mp3",
|
||||||
demo_data,
|
demo_data,
|
||||||
demo_provider,
|
demo_provider,
|
||||||
"There is someone at the door.",
|
message,
|
||||||
"en",
|
"en",
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
assert req.status == HTTPStatus.OK
|
assert req.status == HTTPStatus.OK
|
||||||
assert await req.read() == demo_data
|
assert await req.read() == demo_data
|
||||||
|
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
|
)
|
||||||
|
assert extension == "mp3"
|
||||||
|
assert demo_data == data
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_component_and_test_service_with_receive_voice_german(
|
async def test_setup_component_and_test_service_with_receive_voice_german(
|
||||||
hass, demo_provider, hass_client
|
hass, demo_provider, hass_client
|
||||||
|
@ -736,3 +749,44 @@ def test_invalid_base_url(value):
|
||||||
"""Test we catch bad base urls."""
|
"""Test we catch bad base urls."""
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
tts.valid_base_url(value)
|
tts.valid_base_url(value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "engine,language,options,cache,result_engine,result_query",
    (
        (None, None, None, None, "demo", ""),
        (None, "de", None, None, "demo", "language=de"),
        (None, "de", {"voice": "henk"}, None, "demo", "language=de&voice=henk"),
        (None, "de", None, True, "demo", "cache=true&language=de"),
    ),
)
async def test_generate_media_source_id(
    hass, setup_tts, engine, language, options, cache, result_engine, result_query
):
    """Test generating a media source ID."""
    media_source_id = tts.generate_media_source_id(
        hass, "msg", engine, language, options, cache
    )

    assert media_source_id.startswith("media-source://tts/")
    # Everything after the last "/" is "<engine>?<query>".
    _, _, engine_query = media_source_id.rpartition("/")
    engine, _, query = engine_query.partition("?")
    assert engine == result_engine
    assert query.startswith("message=msg")
    # Skip "message=msg" (11 characters) plus the "&" separator.
    assert query[12:] == result_query
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "engine,language,options",
    (
        ("not-loaded-engine", None, None),
        (None, "unsupported-language", None),
        (None, None, {"option": "not-supported"}),
    ),
)
async def test_generate_media_source_id_invalid_options(
    hass, setup_tts, engine, language, options
):
    """Test generating a media source ID with an invalid engine, language or option."""
    with pytest.raises(HomeAssistantError):
        tts.generate_media_source_id(hass, "msg", engine, language, options, None)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue