Allow TTS requests to resolve in the background (#90944)
This commit is contained in:
parent
59a02cd08c
commit
86e9f6643f
4 changed files with 163 additions and 39 deletions
|
@ -4,11 +4,16 @@ from hass_nabucasa import Cloud
|
|||
from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider
|
||||
from homeassistant.components.tts import (
|
||||
ATTR_AUDIO_OUTPUT,
|
||||
CONF_LANG,
|
||||
PLATFORM_SCHEMA,
|
||||
Provider,
|
||||
)
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
CONF_GENDER = "gender"
|
||||
ATTR_GENDER = "gender"
|
||||
|
||||
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
||||
|
||||
|
@ -18,8 +23,8 @@ def validate_lang(value):
|
|||
if (lang := value.get(CONF_LANG)) is None:
|
||||
return value
|
||||
|
||||
if (gender := value.get(CONF_GENDER)) is None:
|
||||
gender = value[CONF_GENDER] = next(
|
||||
if (gender := value.get(ATTR_GENDER)) is None:
|
||||
gender = value[ATTR_GENDER] = next(
|
||||
(chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None
|
||||
)
|
||||
|
||||
|
@ -33,7 +38,7 @@ PLATFORM_SCHEMA = vol.All(
|
|||
PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_LANG): str,
|
||||
vol.Optional(CONF_GENDER): str,
|
||||
vol.Optional(ATTR_GENDER): str,
|
||||
}
|
||||
),
|
||||
validate_lang,
|
||||
|
@ -49,7 +54,7 @@ async def async_get_engine(hass, config, discovery_info=None):
|
|||
gender = None
|
||||
else:
|
||||
language = config[CONF_LANG]
|
||||
gender = config[CONF_GENDER]
|
||||
gender = config[ATTR_GENDER]
|
||||
|
||||
return CloudProvider(cloud, language, gender)
|
||||
|
||||
|
@ -87,12 +92,15 @@ class CloudProvider(Provider):
|
|||
@property
|
||||
def supported_options(self):
|
||||
"""Return list of supported options like voice, emotion."""
|
||||
return [CONF_GENDER]
|
||||
return [ATTR_GENDER, ATTR_AUDIO_OUTPUT]
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
"""Return a dict include default options."""
|
||||
return {CONF_GENDER: self._gender}
|
||||
return {
|
||||
ATTR_GENDER: self._gender,
|
||||
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
|
||||
}
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
"""Load TTS from NabuCasa Cloud."""
|
||||
|
@ -101,10 +109,10 @@ class CloudProvider(Provider):
|
|||
data = await self.cloud.voice.process_tts(
|
||||
message,
|
||||
language,
|
||||
gender=options[CONF_GENDER],
|
||||
output=AudioOutput.MP3,
|
||||
gender=options[ATTR_GENDER],
|
||||
output=options[ATTR_AUDIO_OUTPUT],
|
||||
)
|
||||
except VoiceError:
|
||||
return (None, None)
|
||||
|
||||
return ("mp3", data)
|
||||
return (str(options[ATTR_AUDIO_OUTPUT]), data)
|
||||
|
|
|
@ -59,6 +59,7 @@ ATTR_LANGUAGE = "language"
|
|||
ATTR_MESSAGE = "message"
|
||||
ATTR_OPTIONS = "options"
|
||||
ATTR_PLATFORM = "platform"
|
||||
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||
|
||||
BASE_URL_KEY = "tts_base_url"
|
||||
|
||||
|
@ -134,6 +135,7 @@ class TTSCache(TypedDict):
|
|||
|
||||
filename: str
|
||||
voice: bytes
|
||||
pending: asyncio.Task | None
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -495,8 +497,11 @@ class SpeechManager:
|
|||
)
|
||||
|
||||
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
||||
data = self.mem_cache[cache_key]["voice"]
|
||||
return extension, data
|
||||
cached = self.mem_cache[cache_key]
|
||||
if pending := cached.get("pending"):
|
||||
await pending
|
||||
cached = self.mem_cache[cache_key]
|
||||
return extension, cached["voice"]
|
||||
|
||||
@callback
|
||||
def _generate_cache_key(
|
||||
|
@ -527,30 +532,62 @@ class SpeechManager:
|
|||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
extension, data = await provider.async_get_tts_audio(message, language, options)
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
||||
if options is not None and ATTR_AUDIO_OUTPUT in options:
|
||||
expected_extension = options[ATTR_AUDIO_OUTPUT]
|
||||
else:
|
||||
expected_extension = None
|
||||
|
||||
# Create file infos
|
||||
filename = f"{cache_key}.{extension}".lower()
|
||||
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename):
|
||||
raise HomeAssistantError(
|
||||
f"TTS filename '{filename}' from {engine} is invalid!"
|
||||
async def get_tts_data() -> str:
|
||||
"""Handle data available."""
|
||||
extension, data = await provider.async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
|
||||
# Save to memory
|
||||
if extension == "mp3":
|
||||
data = self.write_tags(filename, data, provider, message, language, options)
|
||||
self._async_store_to_memcache(cache_key, filename, data)
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
||||
|
||||
if cache:
|
||||
self.hass.async_create_task(
|
||||
self._async_save_tts_audio(cache_key, filename, data)
|
||||
)
|
||||
# Create file infos
|
||||
filename = f"{cache_key}.{extension}".lower()
|
||||
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename):
|
||||
raise HomeAssistantError(
|
||||
f"TTS filename '{filename}' from {engine} is invalid!"
|
||||
)
|
||||
|
||||
# Save to memory
|
||||
if extension == "mp3":
|
||||
data = self.write_tags(
|
||||
filename, data, provider, message, language, options
|
||||
)
|
||||
self._async_store_to_memcache(cache_key, filename, data)
|
||||
|
||||
if cache:
|
||||
self.hass.async_create_task(
|
||||
self._async_save_tts_audio(cache_key, filename, data)
|
||||
)
|
||||
|
||||
return filename
|
||||
|
||||
audio_task = self.hass.async_create_task(get_tts_data())
|
||||
|
||||
if expected_extension is None:
|
||||
return await audio_task
|
||||
|
||||
def handle_error(_future: asyncio.Future) -> None:
|
||||
"""Handle error."""
|
||||
if audio_task.exception():
|
||||
self.mem_cache.pop(cache_key, None)
|
||||
|
||||
audio_task.add_done_callback(handle_error)
|
||||
|
||||
filename = f"{cache_key}.{expected_extension}".lower()
|
||||
self.mem_cache[cache_key] = {
|
||||
"filename": filename,
|
||||
"voice": b"",
|
||||
"pending": audio_task,
|
||||
}
|
||||
return filename
|
||||
|
||||
async def _async_save_tts_audio(
|
||||
|
@ -601,7 +638,11 @@ class SpeechManager:
|
|||
self, cache_key: str, filename: str, data: bytes
|
||||
) -> None:
|
||||
"""Store data to memcache and set timer to remove it."""
|
||||
self.mem_cache[cache_key] = {"filename": filename, "voice": data}
|
||||
self.mem_cache[cache_key] = {
|
||||
"filename": filename,
|
||||
"voice": data,
|
||||
"pending": None,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_remove_from_mem() -> None:
|
||||
|
@ -628,7 +669,11 @@ class SpeechManager:
|
|||
await self._async_file_to_mem(cache_key)
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
return content, self.mem_cache[cache_key]["voice"]
|
||||
cached = self.mem_cache[cache_key]
|
||||
if pending := cached.get("pending"):
|
||||
await pending
|
||||
cached = self.mem_cache[cache_key]
|
||||
return content, cached["voice"]
|
||||
|
||||
@staticmethod
|
||||
def write_tags(
|
||||
|
|
|
@ -58,17 +58,17 @@ async def test_prefs_default_voice(
|
|||
)
|
||||
|
||||
assert provider_pref.default_language == "en-US"
|
||||
assert provider_pref.default_options == {"gender": "female"}
|
||||
assert provider_pref.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||
assert provider_conf.default_language == "fr-FR"
|
||||
assert provider_conf.default_options == {"gender": "female"}
|
||||
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||
|
||||
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert provider_pref.default_language == "nl-NL"
|
||||
assert provider_pref.default_options == {"gender": "male"}
|
||||
assert provider_pref.default_options == {"gender": "male", "audio_output": "mp3"}
|
||||
assert provider_conf.default_language == "fr-FR"
|
||||
assert provider_conf.default_options == {"gender": "female"}
|
||||
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||
|
||||
|
||||
async def test_provider_properties(cloud_with_prefs) -> None:
|
||||
|
@ -76,7 +76,7 @@ async def test_provider_properties(cloud_with_prefs) -> None:
|
|||
provider = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||
)
|
||||
assert provider.supported_options == ["gender"]
|
||||
assert provider.supported_options == ["gender", "audio_output"]
|
||||
assert "nl-NL" in provider.supported_languages
|
||||
|
||||
|
||||
|
@ -85,5 +85,5 @@ async def test_get_tts_audio(cloud_with_prefs) -> None:
|
|||
provider = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||
)
|
||||
assert provider.supported_options == ["gender"]
|
||||
assert provider.supported_options == ["gender", "audio_output"]
|
||||
assert "nl-NL" in provider.supported_languages
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""The tests for the TTS component."""
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
@ -996,3 +997,73 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
|
|||
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
|
||||
"""Test async fetching of data."""
|
||||
tts_audio = asyncio.Future()
|
||||
|
||||
class ProviderWithAsyncFetching(MockProvider):
|
||||
"""Provider that supports audio output option."""
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return [tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
@property
|
||||
def default_options(self) -> dict[str, str]:
|
||||
"""Return a dict including the default options."""
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> tts.TtsAudioType:
|
||||
return ("mp3", await tts_audio)
|
||||
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS(ProviderWithAsyncFetching))
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
|
||||
# Test async_get_media_source_audio
|
||||
media_source_id = tts.generate_media_source_id(
|
||||
hass, "test message", "test", "en", None, None
|
||||
)
|
||||
|
||||
task = hass.async_create_task(
|
||||
tts.async_get_media_source_audio(hass, media_source_id)
|
||||
)
|
||||
task2 = hass.async_create_task(
|
||||
tts.async_get_media_source_audio(hass, media_source_id)
|
||||
)
|
||||
|
||||
url = await get_media_source_url(hass, media_source_id)
|
||||
client = await hass_client()
|
||||
client_get_task = hass.async_create_task(client.get(url))
|
||||
|
||||
# Make sure that tasks are waiting for our future to resolve
|
||||
done, pending = await asyncio.wait((task, task2, client_get_task), timeout=0.1)
|
||||
assert len(done) == 0
|
||||
assert len(pending) == 3
|
||||
|
||||
tts_audio.set_result(b"test")
|
||||
|
||||
assert await task == ("mp3", b"test")
|
||||
assert await task2 == ("mp3", b"test")
|
||||
|
||||
req = await client_get_task
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.read() == b"test"
|
||||
|
||||
# Test error is not cached
|
||||
media_source_id = tts.generate_media_source_id(
|
||||
hass, "test message 2", "test", "en", None, None
|
||||
)
|
||||
tts_audio = asyncio.Future()
|
||||
tts_audio.set_exception(HomeAssistantError("test error"))
|
||||
with pytest.raises(HomeAssistantError):
|
||||
assert await tts.async_get_media_source_audio(hass, media_source_id)
|
||||
|
||||
tts_audio = asyncio.Future()
|
||||
tts_audio.set_result(b"test 2")
|
||||
await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2")
|
||||
|
|
Loading…
Add table
Reference in a new issue