Allow TTS requests to resolve in the background (#90944)

This commit is contained in:
Paulus Schoutsen 2023-04-06 11:42:55 -04:00 committed by GitHub
parent 59a02cd08c
commit 86e9f6643f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 39 deletions

View file

@@ -4,11 +4,16 @@ from hass_nabucasa import Cloud
from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError
import voluptuous as vol
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider
from homeassistant.components.tts import (
ATTR_AUDIO_OUTPUT,
CONF_LANG,
PLATFORM_SCHEMA,
Provider,
)
from .const import DOMAIN
CONF_GENDER = "gender"
ATTR_GENDER = "gender"
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
@@ -18,8 +23,8 @@ def validate_lang(value):
if (lang := value.get(CONF_LANG)) is None:
return value
if (gender := value.get(CONF_GENDER)) is None:
gender = value[CONF_GENDER] = next(
if (gender := value.get(ATTR_GENDER)) is None:
gender = value[ATTR_GENDER] = next(
(chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None
)
@@ -33,7 +38,7 @@ PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_LANG): str,
vol.Optional(CONF_GENDER): str,
vol.Optional(ATTR_GENDER): str,
}
),
validate_lang,
@@ -49,7 +54,7 @@ async def async_get_engine(hass, config, discovery_info=None):
gender = None
else:
language = config[CONF_LANG]
gender = config[CONF_GENDER]
gender = config[ATTR_GENDER]
return CloudProvider(cloud, language, gender)
@@ -87,12 +92,15 @@ class CloudProvider(Provider):
@property
def supported_options(self):
"""Return list of supported options like voice, emotion."""
return [CONF_GENDER]
return [ATTR_GENDER, ATTR_AUDIO_OUTPUT]
@property
def default_options(self):
"""Return a dict include default options."""
return {CONF_GENDER: self._gender}
return {
ATTR_GENDER: self._gender,
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
}
async def async_get_tts_audio(self, message, language, options=None):
"""Load TTS from NabuCasa Cloud."""
@@ -101,10 +109,10 @@ class CloudProvider(Provider):
data = await self.cloud.voice.process_tts(
message,
language,
gender=options[CONF_GENDER],
output=AudioOutput.MP3,
gender=options[ATTR_GENDER],
output=options[ATTR_AUDIO_OUTPUT],
)
except VoiceError:
return (None, None)
return ("mp3", data)
return (str(options[ATTR_AUDIO_OUTPUT]), data)

View file

@@ -59,6 +59,7 @@ ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"
ATTR_OPTIONS = "options"
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
BASE_URL_KEY = "tts_base_url"
@@ -134,6 +135,7 @@ class TTSCache(TypedDict):
filename: str
voice: bytes
pending: asyncio.Task | None
@callback
@@ -495,8 +497,11 @@ class SpeechManager:
)
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
data = self.mem_cache[cache_key]["voice"]
return extension, data
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
return extension, cached["voice"]
@callback
def _generate_cache_key(
@@ -527,30 +532,62 @@ class SpeechManager:
This method is a coroutine.
"""
provider = self.providers[engine]
extension, data = await provider.async_get_tts_audio(message, language, options)
if data is None or extension is None:
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
expected_extension = None
# Create file infos
filename = f"{cache_key}.{extension}".lower()
# Validate filename
if not _RE_VOICE_FILE.match(filename):
raise HomeAssistantError(
f"TTS filename '{filename}' from {engine} is invalid!"
async def get_tts_data() -> str:
"""Handle data available."""
extension, data = await provider.async_get_tts_audio(
message, language, options
)
# Save to memory
if extension == "mp3":
data = self.write_tags(filename, data, provider, message, language, options)
self._async_store_to_memcache(cache_key, filename, data)
if data is None or extension is None:
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
if cache:
self.hass.async_create_task(
self._async_save_tts_audio(cache_key, filename, data)
)
# Create file infos
filename = f"{cache_key}.{extension}".lower()
# Validate filename
if not _RE_VOICE_FILE.match(filename):
raise HomeAssistantError(
f"TTS filename '{filename}' from {engine} is invalid!"
)
# Save to memory
if extension == "mp3":
data = self.write_tags(
filename, data, provider, message, language, options
)
self._async_store_to_memcache(cache_key, filename, data)
if cache:
self.hass.async_create_task(
self._async_save_tts_audio(cache_key, filename, data)
)
return filename
audio_task = self.hass.async_create_task(get_tts_data())
if expected_extension is None:
return await audio_task
def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
self.mem_cache.pop(cache_key, None)
audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{expected_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
"pending": audio_task,
}
return filename
async def _async_save_tts_audio(
@@ -601,7 +638,11 @@ class SpeechManager:
self, cache_key: str, filename: str, data: bytes
) -> None:
"""Store data to memcache and set timer to remove it."""
self.mem_cache[cache_key] = {"filename": filename, "voice": data}
self.mem_cache[cache_key] = {
"filename": filename,
"voice": data,
"pending": None,
}
@callback
def async_remove_from_mem() -> None:
@@ -628,7 +669,11 @@ class SpeechManager:
await self._async_file_to_mem(cache_key)
content, _ = mimetypes.guess_type(filename)
return content, self.mem_cache[cache_key]["voice"]
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
return content, cached["voice"]
@staticmethod
def write_tags(

View file

@@ -58,17 +58,17 @@ async def test_prefs_default_voice(
)
assert provider_pref.default_language == "en-US"
assert provider_pref.default_options == {"gender": "female"}
assert provider_pref.default_options == {"gender": "female", "audio_output": "mp3"}
assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"}
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
await hass.async_block_till_done()
assert provider_pref.default_language == "nl-NL"
assert provider_pref.default_options == {"gender": "male"}
assert provider_pref.default_options == {"gender": "male", "audio_output": "mp3"}
assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"}
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
async def test_provider_properties(cloud_with_prefs) -> None:
@@ -76,7 +76,7 @@ async def test_provider_properties(cloud_with_prefs) -> None:
provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
)
assert provider.supported_options == ["gender"]
assert provider.supported_options == ["gender", "audio_output"]
assert "nl-NL" in provider.supported_languages
@@ -85,5 +85,5 @@ async def test_get_tts_audio(cloud_with_prefs) -> None:
provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
)
assert provider.supported_options == ["gender"]
assert provider.supported_options == ["gender", "audio_output"]
assert "nl-NL" in provider.supported_languages

View file

@@ -1,4 +1,5 @@ """The tests for the TTS component."""
"""The tests for the TTS component."""
import asyncio
from http import HTTPStatus
from typing import Any
from unittest.mock import patch
@@ -996,3 +997,73 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
is False
)
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
    """Test that TTS audio requests resolve while generation is still pending.

    A provider whose ``async_get_tts_audio`` awaits an externally controlled
    future simulates slow TTS generation; concurrent media-source fetches and
    an HTTP GET must all block on the same pending generation and resolve
    together once the audio is available. Errors must not be cached.
    """
    # Future controlled by the test; the provider awaits it so generation
    # stays pending until the test resolves it.
    tts_audio: asyncio.Future[bytes] = asyncio.Future()

    class ProviderWithAsyncFetching(MockProvider):
        """Provider that supports the audio output option."""

        @property
        def supported_options(self) -> list[str]:
            """Return list of supported options like voice, emotions."""
            return [tts.ATTR_AUDIO_OUTPUT]

        @property
        def default_options(self) -> dict[str, str]:
            """Return a dict including the default options."""
            return {tts.ATTR_AUDIO_OUTPUT: "mp3"}

        async def async_get_tts_audio(
            self, message: str, language: str, options: dict[str, Any] | None = None
        ) -> tts.TtsAudioType:
            # Closes over the *name* tts_audio, so rebinding it below
            # feeds each subsequent request a fresh future.
            return ("mp3", await tts_audio)

    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS(ProviderWithAsyncFetching))
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})

    # Test async_get_media_source_audio
    media_source_id = tts.generate_media_source_id(
        hass, "test message", "test", "en", None, None
    )

    task = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )
    task2 = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )

    url = await get_media_source_url(hass, media_source_id)
    client = await hass_client()
    client_get_task = hass.async_create_task(client.get(url))

    # Make sure that tasks are waiting for our future to resolve
    done, pending = await asyncio.wait((task, task2, client_get_task), timeout=0.1)
    assert len(done) == 0
    assert len(pending) == 3

    tts_audio.set_result(b"test")

    assert await task == ("mp3", b"test")
    assert await task2 == ("mp3", b"test")

    req = await client_get_task
    assert req.status == HTTPStatus.OK
    assert await req.read() == b"test"

    # Test error is not cached
    media_source_id = tts.generate_media_source_id(
        hass, "test message 2", "test", "en", None, None
    )
    tts_audio = asyncio.Future()
    tts_audio.set_exception(HomeAssistantError("test error"))
    with pytest.raises(HomeAssistantError):
        assert await tts.async_get_media_source_audio(hass, media_source_id)

    tts_audio = asyncio.Future()
    tts_audio.set_result(b"test 2")
    # Bug fix: the original line lacked `assert`, so the comparison result was
    # silently discarded and the post-error retry was never actually verified.
    assert await tts.async_get_media_source_audio(hass, media_source_id) == (
        "mp3",
        b"test 2",
    )