Remove fuzzy language matching from stt and tts (#92002)

* Remove fuzzy language matching from stt and tts

* Update tests
This commit is contained in:
Erik Montnemery 2023-04-25 17:54:42 +02:00 committed by GitHub
parent d1e6e4078c
commit 792ea92e55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 57 additions and 97 deletions

View file

@ -221,18 +221,9 @@ class SpeechToTextEntity(RestoreEntity):
@callback @callback
def check_metadata(self, metadata: SpeechMetadata) -> bool: def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider.""" """Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if ( if (
metadata.format not in self.supported_formats metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates or metadata.sample_rate not in self.supported_sample_rates

View file

@ -11,7 +11,6 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util import language as language_util
from .const import ( from .const import (
DATA_PROVIDERS, DATA_PROVIDERS,
@ -163,18 +162,9 @@ class Provider(ABC):
@callback @callback
def check_metadata(self, metadata: SpeechMetadata) -> bool: def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider.""" """Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if ( if (
metadata.format not in self.supported_formats metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates or metadata.sample_rate not in self.supported_sample_rates

View file

@ -483,20 +483,13 @@ class SpeechManager:
"""Validate and process options.""" """Validate and process options."""
# Languages # Languages
language = language or engine_instance.default_language language = language or engine_instance.default_language
if (
if language is None or engine_instance.supported_languages is None: language is None
or engine_instance.supported_languages is None
or language not in engine_instance.supported_languages
):
raise HomeAssistantError(f"Not supported language {language}") raise HomeAssistantError(f"Not supported language {language}")
if language not in engine_instance.supported_languages:
language_matches = language_util.matches(
language, engine_instance.supported_languages
)
if language_matches:
# Choose best match
language = language_matches[0]
else:
raise HomeAssistantError(f"Not supported language {language}")
# Options # Options
if (default_options := engine_instance.default_options) and options: if (default_options := engine_instance.default_options) and options:
merged_options = dict(default_options) merged_options = dict(default_options)

View file

@ -151,7 +151,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'engine': 'test', 'engine': 'test',
'language': 'en-UA', 'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that", 'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger', 'voice': 'Arnold Schwarzenegger',
}), }),
@ -160,7 +160,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-UA&voice=Arnold+Schwarzenegger", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}), }),
@ -238,7 +238,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'engine': 'test', 'engine': 'test',
'language': 'en-AU', 'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that", 'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger', 'voice': 'Arnold Schwarzenegger',
}), }),
@ -247,7 +247,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-AU&voice=Arnold+Schwarzenegger", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}), }),

View file

@ -84,9 +84,9 @@ async def test_pipeline_from_audio_stream_legacy(
"language": "en", "language": "en",
"name": "test_name", "name": "test_name",
"stt_engine": "test", "stt_engine": "test",
"stt_language": "en-UK", "stt_language": "en-US",
"tts_engine": "test", "tts_engine": "test",
"tts_language": "en-AU", "tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger", "tts_voice": "Arnold Schwarzenegger",
} }
) )
@ -150,9 +150,9 @@ async def test_pipeline_from_audio_stream_entity(
"language": "en", "language": "en",
"name": "test_name", "name": "test_name",
"stt_engine": mock_stt_provider_entity.entity_id, "stt_engine": mock_stt_provider_entity.entity_id,
"stt_language": "en-UK", "stt_language": "en-US",
"tts_engine": "test", "tts_engine": "test",
"tts_language": "en-UA", "tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger", "tts_voice": "Arnold Schwarzenegger",
} }
) )

View file

@ -54,7 +54,7 @@ class BaseProvider:
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
return ["de", "de-CH", "en-US"] return ["de", "de-CH", "en"]
@property @property
def supported_formats(self) -> list[AudioFormats]: def supported_formats(self) -> list[AudioFormats]:
@ -224,7 +224,7 @@ async def test_get_provider_info(
response = await client.get(f"/api/stt/{setup.url_path}") response = await client.get(f"/api/stt/{setup.url_path}")
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
assert await response.json() == { assert await response.json() == {
"languages": ["de", "de-CH", "en-US"], "languages": ["de", "de-CH", "en"],
"formats": ["wav", "ogg"], "formats": ["wav", "ogg"],
"codecs": ["pcm", "opus"], "codecs": ["pcm", "opus"],
"sample_rates": [16000], "sample_rates": [16000],
@ -247,7 +247,6 @@ async def test_non_existing_provider(
response = await client.get("/api/stt/not_exist") response = await client.get("/api/stt/not_exist")
assert response.status == HTTPStatus.NOT_FOUND assert response.status == HTTPStatus.NOT_FOUND
# Language en is matched with en-US
response = await client.post( response = await client.post(
"/api/stt/not_exist", "/api/stt/not_exist",
headers={ headers={
@ -270,8 +269,6 @@ async def test_stream_audio(
) -> None: ) -> None:
"""Test streaming audio and getting response.""" """Test streaming audio and getting response."""
client = await hass_client() client = await hass_client()
# Language en is matched with en-US
response = await client.post( response = await client.post(
f"/api/stt/{setup.url_path}", f"/api/stt/{setup.url_path}",
headers={ headers={
@ -404,7 +401,7 @@ async def test_ws_list_engines(
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"providers": [ "providers": [
{"engine_id": engine_id, "supported_languages": ["de", "de-CH", "en-US"]} {"engine_id": engine_id, "supported_languages": ["de", "de-CH", "en"]}
] ]
} }
@ -421,7 +418,7 @@ async def test_ws_list_engines(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}] "providers": [{"engine_id": engine_id, "supported_languages": ["en"]}]
} }
await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"}) await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"})
@ -429,7 +426,7 @@ async def test_ws_list_engines(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}] "providers": [{"engine_id": engine_id, "supported_languages": ["en"]}]
} }
await client.send_json_auto_id({"type": "stt/engine/list", "language": "de"}) await client.send_json_auto_id({"type": "stt/engine/list", "language": "de"})

View file

@ -32,7 +32,6 @@ from tests.common import (
DEFAULT_LANG = "en_US" DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test" TEST_DOMAIN = "test"
TEST_LANGUAGES = ["de", "en"]
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str: async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
@ -105,11 +104,7 @@ class MockTTS(MockPlatform):
"""A mock TTS platform.""" """A mock TTS platform."""
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{ {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(
SUPPORT_LANGUAGES + TEST_LANGUAGES
)
}
) )
def __init__(self, provider: MockProvider, **kwargs: Any) -> None: def __init__(self, provider: MockProvider, **kwargs: Any) -> None:

View file

@ -217,9 +217,9 @@ async def test_service(
).is_file() ).is_file()
# Language de is matched with de_DE
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), [(MockProvider("de"), MockTTSEntity("de"))] ("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"), ("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -346,7 +346,7 @@ async def test_service_default_special_language(
{ {
ATTR_ENTITY_ID: "media_player.something", ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
}, },
"test", "test",
), ),
@ -357,7 +357,7 @@ async def test_service_default_special_language(
ATTR_ENTITY_ID: "tts.test", ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something", tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
}, },
"tts.test", "tts.test",
), ),
@ -455,7 +455,7 @@ async def test_service_wrong_language(
{ {
ATTR_ENTITY_ID: "media_player.something", ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"voice": "alex", "age": 5}, tts.ATTR_OPTIONS: {"voice": "alex", "age": 5},
}, },
"test", "test",
@ -467,7 +467,7 @@ async def test_service_wrong_language(
ATTR_ENTITY_ID: "tts.test", ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something", tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"voice": "alex", "age": 5}, tts.ATTR_OPTIONS: {"voice": "alex", "age": 5},
}, },
"tts.test", "tts.test",
@ -541,7 +541,7 @@ class MockEntityWithDefaults(MockTTSEntity):
{ {
ATTR_ENTITY_ID: "media_player.something", ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
}, },
"test", "test",
), ),
@ -552,7 +552,7 @@ class MockEntityWithDefaults(MockTTSEntity):
ATTR_ENTITY_ID: "tts.test", ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something", tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
}, },
"tts.test", "tts.test",
), ),
@ -607,7 +607,7 @@ async def test_service_default_options(
{ {
ATTR_ENTITY_ID: "media_player.something", ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"age": 5}, tts.ATTR_OPTIONS: {"age": 5},
}, },
"test", "test",
@ -619,7 +619,7 @@ async def test_service_default_options(
ATTR_ENTITY_ID: "tts.test", ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something", tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"age": 5}, tts.ATTR_OPTIONS: {"age": 5},
}, },
"tts.test", "tts.test",
@ -674,7 +674,7 @@ async def test_merge_default_service_options(
{ {
ATTR_ENTITY_ID: "media_player.something", ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"speed": 1}, tts.ATTR_OPTIONS: {"speed": 1},
}, },
"test", "test",
@ -686,7 +686,7 @@ async def test_merge_default_service_options(
ATTR_ENTITY_ID: "tts.test", ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something", tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.", tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de", tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"speed": 1}, tts.ATTR_OPTIONS: {"speed": 1},
}, },
"tts.test", "tts.test",
@ -855,7 +855,8 @@ async def test_service_receive_voice(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), [(MockProvider("de"), MockTTSEntity("de"))] ("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"), ("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -1047,7 +1048,6 @@ async def test_setup_legacy_cache_dir(
"""Set up a TTS platform with cache and call service without cache.""" """Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language en is matched with en_US
tts_data = b"" tts_data = b""
cache_file = ( cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
@ -1084,7 +1084,6 @@ async def test_setup_cache_dir(
"""Set up a TTS platform with cache and call service without cache.""" """Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language en is matched with en_US
tts_data = b"" tts_data = b""
cache_file = empty_cache_dir / ( cache_file = empty_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
@ -1187,10 +1186,9 @@ async def test_load_cache_legacy_retrieve_without_mem_cache(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Set up component and load cache and get without mem cache.""" """Set up component and load cache and get without mem cache."""
# Language en is matched with en_US
tts_data = b"" tts_data = b""
cache_file = ( cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
) )
with open(cache_file, "wb") as voice_file: with open(cache_file, "wb") as voice_file:
@ -1200,7 +1198,7 @@ async def test_load_cache_legacy_retrieve_without_mem_cache(
client = await hass_client() client = await hass_client()
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
req = await client.get(url) req = await client.get(url)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
@ -1214,7 +1212,6 @@ async def test_load_cache_retrieve_without_mem_cache(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Set up component and load cache and get without mem cache.""" """Set up component and load cache and get without mem cache."""
# Language en is matched with en_US
tts_data = b"" tts_data = b""
cache_file = empty_cache_dir / ( cache_file = empty_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
@ -1306,7 +1303,7 @@ async def test_tags_with_wave() -> None:
) )
tagged_data = ORIG_WRITE_TAGS( tagged_data = ORIG_WRITE_TAGS(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.wav", "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
tts_data, tts_data,
"Test", "Test",
"AI person is in front of your door.", "AI person is in front of your door.",
@ -1367,9 +1364,9 @@ def test_invalid_base_url(value) -> None:
("engine", "language", "options", "cache", "result_query"), ("engine", "language", "options", "cache", "result_query"),
( (
(None, None, None, None, ""), (None, None, None, None, ""),
(None, "de", None, None, "language=de"), (None, "de_DE", None, None, "language=de_DE"),
(None, "de", {"voice": "henk"}, None, "language=de&voice=henk"), (None, "de_DE", {"voice": "henk"}, None, "language=de_DE&voice=henk"),
(None, "de", None, True, "cache=true&language=de"), (None, "de_DE", None, True, "cache=true&language=de_DE"),
), ),
) )
async def test_generate_media_source_id( async def test_generate_media_source_id(
@ -1456,11 +1453,12 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
) )
async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None: async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None:
"""Test supporting options.""" """Test supporting options."""
# Language en is matched with en_US assert await tts.async_support_options(hass, engine_id, "en_US") is True
assert await tts.async_support_options(hass, engine_id, "en") is True
assert await tts.async_support_options(hass, engine_id, "nl") is False assert await tts.async_support_options(hass, engine_id, "nl") is False
assert ( assert (
await tts.async_support_options(hass, engine_id, "en", {"invalid_option": "yo"}) await tts.async_support_options(
hass, engine_id, "en_US", {"invalid_option": "yo"}
)
is False is False
) )
@ -1496,7 +1494,7 @@ async def test_legacy_fetching_in_async(
# Test async_get_media_source_audio # Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message", "test", "en", None, None hass, "test message", "test", "en_US", None, None
) )
task = hass.async_create_task( task = hass.async_create_task(
@ -1526,7 +1524,7 @@ async def test_legacy_fetching_in_async(
# Test error is not cached # Test error is not cached
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message 2", "test", "en", None, None hass, "test message 2", "test", "en_US", None, None
) )
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error")) tts_audio.set_exception(HomeAssistantError("test error"))
@ -1569,7 +1567,7 @@ async def test_fetching_in_async(
# Test async_get_media_source_audio # Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message", "tts.test", "en", None, None hass, "test message", "tts.test", "en_US", None, None
) )
task = hass.async_create_task( task = hass.async_create_task(
@ -1599,7 +1597,7 @@ async def test_fetching_in_async(
# Test error is not cached # Test error is not cached
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message 2", "tts.test", "en", None, None hass, "test message 2", "tts.test", "en_US", None, None
) )
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error")) tts_audio.set_exception(HomeAssistantError("test error"))

View file

@ -109,7 +109,7 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
hass, hass,
"media-source://tts/test?message=Bye%20World&language=de&voice=Paulus", "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus",
None, None,
) )
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
@ -144,7 +144,7 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
hass, hass,
"media-source://tts/tts.test?message=Bye%20World&language=de&voice=Paulus", "media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus",
None, None,
) )
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")

View file

@ -104,7 +104,7 @@ async def test_setup_service(
"name": "tts_test", "name": "tts_test",
"entity_id": "tts.test", "entity_id": "tts.test",
"media_player": "media_player.demo", "media_player": "media_player.demo",
"language": "en", "language": "en_US",
}, },
} }

View file

@ -53,9 +53,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
) as mock_client: ) as mock_client:
extension, data = await tts.async_get_media_source_audio( extension, data = await tts.async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
hass, "Hello world", "tts.test_tts", hass.config.language
),
) )
assert extension == "wav" assert extension == "wav"
@ -89,7 +87,7 @@ async def test_get_tts_audio_raw(
hass, hass,
"Hello world", "Hello world",
"tts.test_tts", "tts.test_tts",
hass.config.language, "en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"}, options={tts.ATTR_AUDIO_OUTPUT: "raw"},
), ),
) )
@ -109,9 +107,7 @@ async def test_get_tts_audio_connection_lost(
), pytest.raises(HomeAssistantError): ), pytest.raises(HomeAssistantError):
await tts.async_get_media_source_audio( await tts.async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
hass, "Hello world", "tts.test_tts", hass.config.language
),
) )