TTS: allow resolving engine and test supported options (#90539)

TTS: allow resolving engine
This commit is contained in:
Paulus Schoutsen 2023-03-31 14:34:42 -04:00 committed by GitHub
parent 44eaf70625
commit 8018be28ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 8 deletions

View file

@ -136,6 +136,44 @@ class TTSCache(TypedDict):
voice: bytes voice: bytes
@callback
def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
"""Resolve engine.
Returns None if no engines found or invalid engine passed in.
"""
manager: SpeechManager = hass.data[DOMAIN]
if engine is not None:
if engine not in manager.providers:
return None
return engine
if not manager.providers:
return None
if "cloud" in manager.providers:
return "cloud"
return next(iter(manager.providers))
async def async_support_options(
hass: HomeAssistant,
engine: str,
language: str | None = None,
options: dict | None = None,
) -> bool:
"""Return if an engine supports options."""
manager: SpeechManager = hass.data[DOMAIN]
try:
manager.process_options(engine, language, options)
except HomeAssistantError:
return False
return True
async def async_get_media_source_audio( async def async_get_media_source_audio(
hass: HomeAssistant, hass: HomeAssistant,
media_source_id: str, media_source_id: str,

View file

@ -40,16 +40,12 @@ def generate_media_source_id(
cache: bool | None = None, cache: bool | None = None,
) -> str: ) -> str:
"""Generate a media source ID for text-to-speech.""" """Generate a media source ID for text-to-speech."""
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
manager: SpeechManager = hass.data[DOMAIN] manager: SpeechManager = hass.data[DOMAIN]
if engine is not None: if (engine := async_resolve_engine(hass, engine)) is None:
pass raise HomeAssistantError("Invalid TTS provider selected")
elif not manager.providers:
raise HomeAssistantError("No TTS providers available")
elif "cloud" in manager.providers:
engine = "cloud"
else:
engine = next(iter(manager.providers))
manager.process_options(engine, language, options) manager.process_options(engine, language, options)
params = { params = {

View file

@ -1,6 +1,7 @@
"""The tests for the TTS component.""" """The tests for the TTS component."""
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -972,3 +973,26 @@ async def test_generate_media_source_id_invalid_options(
"""Test generating a media source ID.""" """Test generating a media source ID."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
tts.generate_media_source_id(hass, "msg", engine, language, options, None) tts.generate_media_source_id(hass, "msg", engine, language, options, None)
def test_resolve_engine(hass: HomeAssistant, setup_tts) -> None:
"""Test resolving engine."""
assert tts.async_resolve_engine(hass, None) == "test"
assert tts.async_resolve_engine(hass, "test") == "test"
assert tts.async_resolve_engine(hass, "non-existing") is None
with patch.dict(hass.data[tts.DOMAIN].providers, {}, clear=True):
assert tts.async_resolve_engine(hass, "test") is None
with patch.dict(hass.data[tts.DOMAIN].providers, {"cloud": object()}):
assert tts.async_resolve_engine(hass, None) == "cloud"
async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
"""Test supporting options."""
assert await tts.async_support_options(hass, "test", "en") is True
assert await tts.async_support_options(hass, "test", "nl") is False
assert (
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
is False
)