Fix defaults for cloud STT/TTS (#121229)

* Fix defaults for cloud STT/TTS

* Prefer entity over legacy provider

* Remove unrealistic tests

* Add tests which show cloud stt/tts entity is preferred

---------

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Paulus Schoutsen 2024-08-26 19:39:09 +02:00 committed by GitHub
parent 547dbf77aa
commit 156948c496
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 144 additions and 34 deletions

View file

@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback @callback
def async_default_engine(hass: HomeAssistant) -> str | None: def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine.""" """Return the domain or entity id of the default engine."""
return next( component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
iter(hass.states.async_entity_ids(DOMAIN)), None
) or async_default_provider(hass) default_entity_id: str | None = None
for entity in component.entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
if default_entity_id is None:
default_entity_id = entity.entity_id
return default_entity_id or async_default_provider(hass)
@callback @callback

View file

@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER] manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if "cloud" in manager.providers: default_entity_id: str | None = None
return "cloud"
entity = next(iter(component.entities), None) for entity in component.entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
if entity is not None: if default_entity_id is None:
return entity.entity_id default_entity_id = entity.entity_id
return next(iter(manager.providers), None) return default_entity_id or next(iter(manager.providers), None)
@callback @callback

View file

@ -1,6 +1,7 @@
"""Test STT component setup.""" """Test STT component setup."""
from collections.abc import AsyncIterable, Generator from collections.abc import AsyncIterable, Generator, Iterable
from contextlib import ExitStack
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -122,20 +123,23 @@ class STTFlow(ConfigFlow):
"""Test flow.""" """Test flow."""
@pytest.fixture(name="config_flow_test_domain") @pytest.fixture(name="config_flow_test_domains")
def config_flow_test_domain_fixture() -> str: def config_flow_test_domain_fixture() -> Iterable[str]:
"""Test domain fixture.""" """Test domain fixture."""
return TEST_DOMAIN return (TEST_DOMAIN,)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def config_flow_fixture( def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domain: str hass: HomeAssistant, config_flow_test_domains: Iterable[str]
) -> Generator[None]: ) -> Generator[None]:
"""Mock config flow.""" """Mock config flow."""
mock_platform(hass, f"{config_flow_test_domain}.config_flow") for domain in config_flow_test_domains:
mock_platform(hass, f"{domain}.config_flow")
with mock_config_flow(config_flow_test_domain, STTFlow): with ExitStack() as stack:
for domain in config_flow_test_domains:
stack.enter_context(mock_config_flow(domain, STTFlow))
yield yield
@ -496,21 +500,25 @@ async def test_default_engine_entity(
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}" assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"]) @pytest.mark.parametrize("config_flow_test_domains", [("new_test",)])
async def test_default_engine_prefer_entity( async def test_default_engine_prefer_entity(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider_entity: MockProviderEntity, mock_provider_entity: MockProviderEntity,
mock_provider: MockProvider, mock_provider: MockProvider,
config_flow_test_domain: str, config_flow_test_domains: str,
) -> None: ) -> None:
"""Test async_default_engine.""" """Test async_default_engine.
In this tests there's an entity and a legacy provider.
The test asserts async_default_engine returns the entity.
"""
mock_provider_entity.url_path = "stt.new_test" mock_provider_entity.url_path = "stt.new_test"
mock_provider_entity._attr_name = "New test" mock_provider_entity._attr_name = "New test"
await mock_setup(hass, tmp_path, mock_provider) await mock_setup(hass, tmp_path, mock_provider)
await mock_config_entry_setup( await mock_config_entry_setup(
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0]
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity(
assert async_default_engine(hass) == "stt.new_test" assert async_default_engine(hass) == "stt.new_test"
@pytest.mark.parametrize(
"config_flow_test_domains",
[
# Test different setup order to ensure the default is not influenced
# by setup order.
("cloud", "new_test"),
("new_test", "cloud"),
],
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
In this tests there's an entity from domain cloud, an entity from domain new_test
and a legacy provider.
The test asserts async_default_engine returns the entity from domain cloud.
"""
await mock_setup(hass, tmp_path, mock_provider)
for domain in config_flow_test_domains:
entity = MockProviderEntity()
entity.url_path = f"stt.{domain}"
entity._attr_name = f"{domain} STT entity"
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
await hass.async_block_till_done()
for domain in config_flow_test_domains:
entity_engine = async_get_speech_to_text_engine(
hass, f"stt.{domain}_stt_entity"
)
assert entity_engine is not None
assert entity_engine.name == f"{domain} STT entity"
provider_engine = async_get_speech_to_text_engine(hass, "test")
assert provider_engine is not None
assert provider_engine.name == "test"
assert async_default_engine(hass) == "stt.cloud_stt_entity"
async def test_get_engine_legacy( async def test_get_engine_legacy(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
) -> None: ) -> None:

View file

@ -215,7 +215,9 @@ async def mock_setup(
async def mock_config_entry_setup( async def mock_config_entry_setup(
hass: HomeAssistant, tts_entity: MockTTSEntity hass: HomeAssistant,
tts_entity: MockTTSEntity,
test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Set up a test tts platform via config entry.""" """Set up a test tts platform via config entry."""
@ -236,7 +238,7 @@ async def mock_config_entry_setup(
mock_integration( mock_integration(
hass, hass,
MockModule( MockModule(
TEST_DOMAIN, test_domain,
async_setup_entry=async_setup_entry_init, async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init, async_unload_entry=async_unload_entry_init,
), ),
@ -251,9 +253,9 @@ async def mock_config_entry_setup(
async_add_entities([tts_entity]) async_add_entities([tts_entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform) loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform) mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN) config_entry = MockConfigEntry(domain=test_domain)
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -3,7 +3,8 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
""" """
from collections.abc import Generator from collections.abc import Generator, Iterable
from contextlib import ExitStack
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow):
"""Test flow.""" """Test flow."""
@pytest.fixture(autouse=True) @pytest.fixture(name="config_flow_test_domains")
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]: def config_flow_test_domain_fixture() -> Iterable[str]:
"""Mock config flow.""" """Test domain fixture."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow") return (TEST_DOMAIN,)
with mock_config_flow(TEST_DOMAIN, TTSFlow):
@pytest.fixture(autouse=True)
def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
) -> Generator[None]:
"""Mock config flow."""
for domain in config_flow_test_domains:
mock_platform(hass, f"{domain}.config_flow")
with ExitStack() as stack:
for domain in config_flow_test_domains:
stack.enter_context(mock_config_flow(domain, TTSFlow))
yield yield

View file

@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
): ):
assert tts.async_resolve_engine(hass, None) is None assert tts.async_resolve_engine(hass, None) is None
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
assert tts.async_resolve_engine(hass, None) == "cloud"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "engine_id"), ("setup", "engine_id"),
@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity(
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
mock_provider: MockProvider, mock_provider: MockProvider,
) -> None: ) -> None:
"""Test async_default_engine.""" """Test async_default_engine.
In this tests there's an entity and a legacy provider.
The test asserts async_default_engine returns the entity.
"""
mock_tts_entity._attr_name = "New test" mock_tts_entity._attr_name = "New test"
await mock_setup(hass, mock_provider) await mock_setup(hass, mock_provider)
@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity(
provider_engine = tts.async_resolve_engine(hass, "test") provider_engine = tts.async_resolve_engine(hass, "test")
assert provider_engine == "test" assert provider_engine == "test"
assert tts.async_default_engine(hass) == "tts.new_test" assert tts.async_default_engine(hass) == "tts.new_test"
@pytest.mark.parametrize(
"config_flow_test_domains",
[
# Test different setup order to ensure the default is not influenced
# by setup order.
("cloud", "new_test"),
("new_test", "cloud"),
],
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
mock_provider: MockProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
In this tests there's an entity from domain cloud, an entity from domain new_test
and a legacy provider.
The test asserts async_default_engine returns the entity from domain cloud.
"""
await mock_setup(hass, mock_provider)
for domain in config_flow_test_domains:
entity = MockTTSEntity(DEFAULT_LANG)
entity._attr_name = f"{domain} TTS entity"
await mock_config_entry_setup(hass, entity, test_domain=domain)
await hass.async_block_till_done()
for domain in config_flow_test_domains:
entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity")
assert entity_engine == f"tts.{domain}_tts_entity"
provider_engine = tts.async_resolve_engine(hass, "test")
assert provider_engine == "test"
assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"