Fix defaults for cloud STT/TTS (#121229)

* Fix defaults for cloud STT/TTS

* Prefer entity over legacy provider

* Remove unrealistic tests

* Add tests which show cloud stt/tts entity is preferred

---------

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Paulus Schoutsen 2024-08-26 19:39:09 +02:00 committed by GitHub
parent 547dbf77aa
commit 156948c496
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 144 additions and 34 deletions

View file

@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
return next(
iter(hass.states.async_entity_ids(DOMAIN)), None
) or async_default_provider(hass)
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
default_entity_id: str | None = None
for entity in component.entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
if default_entity_id is None:
default_entity_id = entity.entity_id
return default_entity_id or async_default_provider(hass)
@callback

View file

@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if "cloud" in manager.providers:
return "cloud"
default_entity_id: str | None = None
entity = next(iter(component.entities), None)
for entity in component.entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
if entity is not None:
return entity.entity_id
if default_entity_id is None:
default_entity_id = entity.entity_id
return next(iter(manager.providers), None)
return default_entity_id or next(iter(manager.providers), None)
@callback

View file

@ -1,6 +1,7 @@
"""Test STT component setup."""
from collections.abc import AsyncIterable, Generator
from collections.abc import AsyncIterable, Generator, Iterable
from contextlib import ExitStack
from http import HTTPStatus
from pathlib import Path
from unittest.mock import AsyncMock
@ -122,20 +123,23 @@ class STTFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(name="config_flow_test_domain")
def config_flow_test_domain_fixture() -> str:
@pytest.fixture(name="config_flow_test_domains")
def config_flow_test_domain_fixture() -> Iterable[str]:
"""Test domain fixture."""
return TEST_DOMAIN
return (TEST_DOMAIN,)
@pytest.fixture(autouse=True)
def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domain: str
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
) -> Generator[None]:
"""Mock config flow."""
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
for domain in config_flow_test_domains:
mock_platform(hass, f"{domain}.config_flow")
with mock_config_flow(config_flow_test_domain, STTFlow):
with ExitStack() as stack:
for domain in config_flow_test_domains:
stack.enter_context(mock_config_flow(domain, STTFlow))
yield
@ -496,21 +500,25 @@ async def test_default_engine_entity(
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
@pytest.mark.parametrize("config_flow_test_domains", [("new_test",)])
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider: MockProvider,
config_flow_test_domain: str,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine."""
"""Test async_default_engine.
In this tests there's an entity and a legacy provider.
The test asserts async_default_engine returns the entity.
"""
mock_provider_entity.url_path = "stt.new_test"
mock_provider_entity._attr_name = "New test"
await mock_setup(hass, tmp_path, mock_provider)
await mock_config_entry_setup(
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0]
)
await hass.async_block_till_done()
@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity(
assert async_default_engine(hass) == "stt.new_test"
@pytest.mark.parametrize(
"config_flow_test_domains",
[
# Test different setup order to ensure the default is not influenced
# by setup order.
("cloud", "new_test"),
("new_test", "cloud"),
],
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
In this tests there's an entity from domain cloud, an entity from domain new_test
and a legacy provider.
The test asserts async_default_engine returns the entity from domain cloud.
"""
await mock_setup(hass, tmp_path, mock_provider)
for domain in config_flow_test_domains:
entity = MockProviderEntity()
entity.url_path = f"stt.{domain}"
entity._attr_name = f"{domain} STT entity"
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
await hass.async_block_till_done()
for domain in config_flow_test_domains:
entity_engine = async_get_speech_to_text_engine(
hass, f"stt.{domain}_stt_entity"
)
assert entity_engine is not None
assert entity_engine.name == f"{domain} STT entity"
provider_engine = async_get_speech_to_text_engine(hass, "test")
assert provider_engine is not None
assert provider_engine.name == "test"
assert async_default_engine(hass) == "stt.cloud_stt_entity"
async def test_get_engine_legacy(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
) -> None:

View file

@ -215,7 +215,9 @@ async def mock_setup(
async def mock_config_entry_setup(
hass: HomeAssistant, tts_entity: MockTTSEntity
hass: HomeAssistant,
tts_entity: MockTTSEntity,
test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry:
"""Set up a test tts platform via config entry."""
@ -236,7 +238,7 @@ async def mock_config_entry_setup(
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
test_domain,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
@ -251,9 +253,9 @@ async def mock_config_entry_setup(
async_add_entities([tts_entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry = MockConfigEntry(domain=test_domain)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()

View file

@ -3,7 +3,8 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
"""
from collections.abc import Generator
from collections.abc import Generator, Iterable
from contextlib import ExitStack
from pathlib import Path
from unittest.mock import MagicMock
@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
@pytest.fixture(name="config_flow_test_domains")
def config_flow_test_domain_fixture() -> Iterable[str]:
"""Test domain fixture."""
return (TEST_DOMAIN,)
with mock_config_flow(TEST_DOMAIN, TTSFlow):
@pytest.fixture(autouse=True)
def config_flow_fixture(
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
) -> Generator[None]:
"""Mock config flow."""
for domain in config_flow_test_domains:
mock_platform(hass, f"{domain}.config_flow")
with ExitStack() as stack:
for domain in config_flow_test_domains:
stack.enter_context(mock_config_flow(domain, TTSFlow))
yield

View file

@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
):
assert tts.async_resolve_engine(hass, None) is None
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
assert tts.async_resolve_engine(hass, None) == "cloud"
@pytest.mark.parametrize(
("setup", "engine_id"),
@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity(
mock_tts_entity: MockTTSEntity,
mock_provider: MockProvider,
) -> None:
"""Test async_default_engine."""
"""Test async_default_engine.
In this tests there's an entity and a legacy provider.
The test asserts async_default_engine returns the entity.
"""
mock_tts_entity._attr_name = "New test"
await mock_setup(hass, mock_provider)
@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity(
provider_engine = tts.async_resolve_engine(hass, "test")
assert provider_engine == "test"
assert tts.async_default_engine(hass) == "tts.new_test"
@pytest.mark.parametrize(
"config_flow_test_domains",
[
# Test different setup order to ensure the default is not influenced
# by setup order.
("cloud", "new_test"),
("new_test", "cloud"),
],
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
mock_provider: MockProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
In this tests there's an entity from domain cloud, an entity from domain new_test
and a legacy provider.
The test asserts async_default_engine returns the entity from domain cloud.
"""
await mock_setup(hass, mock_provider)
for domain in config_flow_test_domains:
entity = MockTTSEntity(DEFAULT_LANG)
entity._attr_name = f"{domain} TTS entity"
await mock_config_entry_setup(hass, entity, test_domain=domain)
await hass.async_block_till_done()
for domain in config_flow_test_domains:
entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity")
assert entity_engine == f"tts.{domain}_tts_entity"
provider_engine = tts.async_resolve_engine(hass, "test")
assert provider_engine == "test"
assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"