From 156948c4960b34e579a2047eb7b570960767c2fa Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 26 Aug 2024 19:39:09 +0200 Subject: [PATCH] Fix defaults for cloud STT/TTS (#121229) * Fix defaults for cloud STT/TTS * Prefer entity over legacy provider * Remove unrealistic tests * Add tests which show cloud stt/tts entity is preferred --------- Co-authored-by: Erik --- homeassistant/components/stt/__init__.py | 15 ++++- homeassistant/components/tts/__init__.py | 13 +++-- tests/components/stt/test_init.py | 72 ++++++++++++++++++++---- tests/components/tts/common.py | 10 ++-- tests/components/tts/conftest.py | 24 ++++++-- tests/components/tts/test_init.py | 44 +++++++++++++-- 6 files changed, 144 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 227c92f2b98..f6c38c1e0b7 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) @callback def async_default_engine(hass: HomeAssistant) -> str | None: """Return the domain or entity id of the default engine.""" - return next( - iter(hass.states.async_entity_ids(DOMAIN)), None - ) or async_default_provider(hass) + component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] + + default_entity_id: str | None = None + + for entity in component.entities: + if entity.platform and entity.platform.platform_name == "cloud": + return entity.entity_id + + if default_entity_id is None: + default_entity_id = entity.entity_id + + return default_entity_id or async_default_provider(hass) @callback diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 5286b01f67f..583db4472d4 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None: component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - if "cloud" in manager.providers: - return "cloud" + default_entity_id: str | None = None - entity = next(iter(component.entities), None) + for entity in component.entities: + if entity.platform and entity.platform.platform_name == "cloud": + return entity.entity_id - if entity is not None: - return entity.entity_id + if default_entity_id is None: + default_entity_id = entity.entity_id - return next(iter(manager.providers), None) + return default_entity_id or next(iter(manager.providers), None) @callback diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index a42ac44112e..e5d75d3c4a5 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,6 +1,7 @@ """Test STT component setup.""" -from collections.abc import AsyncIterable, Generator +from collections.abc import AsyncIterable, Generator, Iterable +from contextlib import ExitStack from http import HTTPStatus from pathlib import Path from unittest.mock import AsyncMock @@ -122,20 +123,23 @@ class STTFlow(ConfigFlow): """Test flow.""" -@pytest.fixture(name="config_flow_test_domain") -def config_flow_test_domain_fixture() -> str: +@pytest.fixture(name="config_flow_test_domains") +def config_flow_test_domain_fixture() -> Iterable[str]: """Test domain fixture.""" - return TEST_DOMAIN + return (TEST_DOMAIN,) @pytest.fixture(autouse=True) def config_flow_fixture( - hass: HomeAssistant, config_flow_test_domain: str + hass: HomeAssistant, config_flow_test_domains: Iterable[str] ) -> Generator[None]: """Mock config flow.""" - mock_platform(hass, f"{config_flow_test_domain}.config_flow") + for domain in config_flow_test_domains: + mock_platform(hass, f"{domain}.config_flow") - with mock_config_flow(config_flow_test_domain, STTFlow): + with ExitStack() as stack: + for domain in config_flow_test_domains: + stack.enter_context(mock_config_flow(domain, STTFlow)) yield @@ -496,21 +500,25 @@ async def test_default_engine_entity( assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}" -@pytest.mark.parametrize("config_flow_test_domain", ["new_test"]) +@pytest.mark.parametrize("config_flow_test_domains", [("new_test",)]) async def test_default_engine_prefer_entity( hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity, mock_provider: MockProvider, - config_flow_test_domain: str, + config_flow_test_domains: str, ) -> None: - """Test async_default_engine.""" + """Test async_default_engine. + + In this tests there's an entity and a legacy provider. + The test asserts async_default_engine returns the entity. + """ mock_provider_entity.url_path = "stt.new_test" mock_provider_entity._attr_name = "New test" await mock_setup(hass, tmp_path, mock_provider) await mock_config_entry_setup( - hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain + hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0] ) await hass.async_block_till_done() @@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity( assert async_default_engine(hass) == "stt.new_test" +@pytest.mark.parametrize( + "config_flow_test_domains", + [ + # Test different setup order to ensure the default is not influenced + # by setup order. + ("cloud", "new_test"), + ("new_test", "cloud"), + ], +) +async def test_default_engine_prefer_cloud_entity( + hass: HomeAssistant, + tmp_path: Path, + mock_provider: MockProvider, + config_flow_test_domains: str, +) -> None: + """Test async_default_engine. + + In this tests there's an entity from domain cloud, an entity from domain new_test + and a legacy provider. + The test asserts async_default_engine returns the entity from domain cloud. + """ + await mock_setup(hass, tmp_path, mock_provider) + for domain in config_flow_test_domains: + entity = MockProviderEntity() + entity.url_path = f"stt.{domain}" + entity._attr_name = f"{domain} STT entity" + await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain) + await hass.async_block_till_done() + + for domain in config_flow_test_domains: + entity_engine = async_get_speech_to_text_engine( + hass, f"stt.{domain}_stt_entity" + ) + assert entity_engine is not None + assert entity_engine.name == f"{domain} STT entity" + + provider_engine = async_get_speech_to_text_engine(hass, "test") + assert provider_engine is not None + assert provider_engine.name == "test" + assert async_default_engine(hass) == "stt.cloud_stt_entity" + + async def test_get_engine_legacy( hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider ) -> None: diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 71edf29721f..4acba401fad 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -215,7 +215,9 @@ async def mock_setup( async def mock_config_entry_setup( - hass: HomeAssistant, tts_entity: MockTTSEntity + hass: HomeAssistant, + tts_entity: MockTTSEntity, + test_domain: str = TEST_DOMAIN, ) -> MockConfigEntry: """Set up a test tts platform via config entry.""" @@ -236,7 +238,7 @@ async def mock_config_entry_setup( mock_integration( hass, MockModule( - TEST_DOMAIN, + test_domain, async_setup_entry=async_setup_entry_init, async_unload_entry=async_unload_entry_init, ), @@ -251,9 +253,9 @@ async def mock_config_entry_setup( async_add_entities([tts_entity]) loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform) - mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform) + mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform) - config_entry = MockConfigEntry(domain=TEST_DOMAIN) + config_entry = MockConfigEntry(domain=test_domain) config_entry.add_to_hass(hass) assert await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() diff --git a/tests/components/tts/conftest.py b/tests/components/tts/conftest.py index d9a4499f544..91ddd7742af 100644 --- a/tests/components/tts/conftest.py +++ b/tests/components/tts/conftest.py @@ -3,7 +3,8 @@ From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures """ -from collections.abc import Generator +from collections.abc import Generator, Iterable +from contextlib import ExitStack from pathlib import Path from unittest.mock import MagicMock @@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow): """Test flow.""" -@pytest.fixture(autouse=True) -def config_flow_fixture(hass: HomeAssistant) -> Generator[None]: - """Mock config flow.""" - mock_platform(hass, f"{TEST_DOMAIN}.config_flow") +@pytest.fixture(name="config_flow_test_domains") +def config_flow_test_domain_fixture() -> Iterable[str]: + """Test domain fixture.""" + return (TEST_DOMAIN,) - with mock_config_flow(TEST_DOMAIN, TTSFlow): + +@pytest.fixture(autouse=True) +def config_flow_fixture( + hass: HomeAssistant, config_flow_test_domains: Iterable[str] +) -> Generator[None]: + """Mock config flow.""" + for domain in config_flow_test_domains: + mock_platform(hass, f"{domain}.config_flow") + + with ExitStack() as stack: + for domain in config_flow_test_domains: + stack.enter_context(mock_config_flow(domain, TTSFlow)) yield diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 55ff4492e80..05c19622e84 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None ): assert tts.async_resolve_engine(hass, None) is None - with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}): - assert tts.async_resolve_engine(hass, None) == "cloud" - @pytest.mark.parametrize( ("setup", "engine_id"), @@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity( mock_tts_entity: MockTTSEntity, mock_provider: MockProvider, ) -> None: - """Test async_default_engine.""" + """Test async_default_engine. + + In this tests there's an entity and a legacy provider. + The test asserts async_default_engine returns the entity. + """ mock_tts_entity._attr_name = "New test" await mock_setup(hass, mock_provider) @@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity( provider_engine = tts.async_resolve_engine(hass, "test") assert provider_engine == "test" assert tts.async_default_engine(hass) == "tts.new_test" + + +@pytest.mark.parametrize( + "config_flow_test_domains", + [ + # Test different setup order to ensure the default is not influenced + # by setup order. + ("cloud", "new_test"), + ("new_test", "cloud"), + ], +) +async def test_default_engine_prefer_cloud_entity( + hass: HomeAssistant, + mock_provider: MockProvider, + config_flow_test_domains: str, +) -> None: + """Test async_default_engine. + + In this tests there's an entity from domain cloud, an entity from domain new_test + and a legacy provider. + The test asserts async_default_engine returns the entity from domain cloud. + """ + await mock_setup(hass, mock_provider) + for domain in config_flow_test_domains: + entity = MockTTSEntity(DEFAULT_LANG) + entity._attr_name = f"{domain} TTS entity" + await mock_config_entry_setup(hass, entity, test_domain=domain) + await hass.async_block_till_done() + + for domain in config_flow_test_domains: + entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity") + assert entity_engine == f"tts.{domain}_tts_entity" + provider_engine = tts.async_resolve_engine(hass, "test") + assert provider_engine == "test" + assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"