Fix defaults for cloud STT/TTS (#121229)
* Fix defaults for cloud STT/TTS * Prefer entity over legacy provider * Remove unrealistic tests * Add tests which show cloud stt/tts entity is preferred --------- Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
547dbf77aa
commit
156948c496
6 changed files with 144 additions and 34 deletions
|
@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
|||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine."""
|
||||
return next(
|
||||
iter(hass.states.async_entity_ids(DOMAIN)), None
|
||||
) or async_default_provider(hass)
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
|
||||
default_entity_id: str | None = None
|
||||
|
||||
for entity in component.entities:
|
||||
if entity.platform and entity.platform.platform_name == "cloud":
|
||||
return entity.entity_id
|
||||
|
||||
if default_entity_id is None:
|
||||
default_entity_id = entity.entity_id
|
||||
|
||||
return default_entity_id or async_default_provider(hass)
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
|
|||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if "cloud" in manager.providers:
|
||||
return "cloud"
|
||||
default_entity_id: str | None = None
|
||||
|
||||
entity = next(iter(component.entities), None)
|
||||
for entity in component.entities:
|
||||
if entity.platform and entity.platform.platform_name == "cloud":
|
||||
return entity.entity_id
|
||||
|
||||
if entity is not None:
|
||||
return entity.entity_id
|
||||
if default_entity_id is None:
|
||||
default_entity_id = entity.entity_id
|
||||
|
||||
return next(iter(manager.providers), None)
|
||||
return default_entity_id or next(iter(manager.providers), None)
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Test STT component setup."""
|
||||
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from collections.abc import AsyncIterable, Generator, Iterable
|
||||
from contextlib import ExitStack
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
@ -122,20 +123,23 @@ class STTFlow(ConfigFlow):
|
|||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(name="config_flow_test_domain")
|
||||
def config_flow_test_domain_fixture() -> str:
|
||||
@pytest.fixture(name="config_flow_test_domains")
|
||||
def config_flow_test_domain_fixture() -> Iterable[str]:
|
||||
"""Test domain fixture."""
|
||||
return TEST_DOMAIN
|
||||
return (TEST_DOMAIN,)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(
|
||||
hass: HomeAssistant, config_flow_test_domain: str
|
||||
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
|
||||
) -> Generator[None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
|
||||
for domain in config_flow_test_domains:
|
||||
mock_platform(hass, f"{domain}.config_flow")
|
||||
|
||||
with mock_config_flow(config_flow_test_domain, STTFlow):
|
||||
with ExitStack() as stack:
|
||||
for domain in config_flow_test_domains:
|
||||
stack.enter_context(mock_config_flow(domain, STTFlow))
|
||||
yield
|
||||
|
||||
|
||||
|
@ -496,21 +500,25 @@ async def test_default_engine_entity(
|
|||
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
|
||||
@pytest.mark.parametrize("config_flow_test_domains", [("new_test",)])
|
||||
async def test_default_engine_prefer_entity(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
mock_provider: MockProvider,
|
||||
config_flow_test_domain: str,
|
||||
config_flow_test_domains: str,
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
"""Test async_default_engine.
|
||||
|
||||
In this tests there's an entity and a legacy provider.
|
||||
The test asserts async_default_engine returns the entity.
|
||||
"""
|
||||
mock_provider_entity.url_path = "stt.new_test"
|
||||
mock_provider_entity._attr_name = "New test"
|
||||
|
||||
await mock_setup(hass, tmp_path, mock_provider)
|
||||
await mock_config_entry_setup(
|
||||
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
|
||||
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0]
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity(
|
|||
assert async_default_engine(hass) == "stt.new_test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_flow_test_domains",
|
||||
[
|
||||
# Test different setup order to ensure the default is not influenced
|
||||
# by setup order.
|
||||
("cloud", "new_test"),
|
||||
("new_test", "cloud"),
|
||||
],
|
||||
)
|
||||
async def test_default_engine_prefer_cloud_entity(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider: MockProvider,
|
||||
config_flow_test_domains: str,
|
||||
) -> None:
|
||||
"""Test async_default_engine.
|
||||
|
||||
In this tests there's an entity from domain cloud, an entity from domain new_test
|
||||
and a legacy provider.
|
||||
The test asserts async_default_engine returns the entity from domain cloud.
|
||||
"""
|
||||
await mock_setup(hass, tmp_path, mock_provider)
|
||||
for domain in config_flow_test_domains:
|
||||
entity = MockProviderEntity()
|
||||
entity.url_path = f"stt.{domain}"
|
||||
entity._attr_name = f"{domain} STT entity"
|
||||
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
for domain in config_flow_test_domains:
|
||||
entity_engine = async_get_speech_to_text_engine(
|
||||
hass, f"stt.{domain}_stt_entity"
|
||||
)
|
||||
assert entity_engine is not None
|
||||
assert entity_engine.name == f"{domain} STT entity"
|
||||
|
||||
provider_engine = async_get_speech_to_text_engine(hass, "test")
|
||||
assert provider_engine is not None
|
||||
assert provider_engine.name == "test"
|
||||
assert async_default_engine(hass) == "stt.cloud_stt_entity"
|
||||
|
||||
|
||||
async def test_get_engine_legacy(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||
) -> None:
|
||||
|
|
|
@ -215,7 +215,9 @@ async def mock_setup(
|
|||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, tts_entity: MockTTSEntity
|
||||
hass: HomeAssistant,
|
||||
tts_entity: MockTTSEntity,
|
||||
test_domain: str = TEST_DOMAIN,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test tts platform via config entry."""
|
||||
|
||||
|
@ -236,7 +238,7 @@ async def mock_config_entry_setup(
|
|||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
test_domain,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
|
@ -251,9 +253,9 @@ async def mock_config_entry_setup(
|
|||
async_add_entities([tts_entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
|
||||
mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry = MockConfigEntry(domain=test_domain)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterable
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow):
|
|||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
@pytest.fixture(name="config_flow_test_domains")
|
||||
def config_flow_test_domain_fixture() -> Iterable[str]:
|
||||
"""Test domain fixture."""
|
||||
return (TEST_DOMAIN,)
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, TTSFlow):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(
|
||||
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
|
||||
) -> Generator[None]:
|
||||
"""Mock config flow."""
|
||||
for domain in config_flow_test_domains:
|
||||
mock_platform(hass, f"{domain}.config_flow")
|
||||
|
||||
with ExitStack() as stack:
|
||||
for domain in config_flow_test_domains:
|
||||
stack.enter_context(mock_config_flow(domain, TTSFlow))
|
||||
yield
|
||||
|
||||
|
||||
|
|
|
@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
|
|||
):
|
||||
assert tts.async_resolve_engine(hass, None) is None
|
||||
|
||||
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
|
||||
assert tts.async_resolve_engine(hass, None) == "cloud"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "engine_id"),
|
||||
|
@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity(
|
|||
mock_tts_entity: MockTTSEntity,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
"""Test async_default_engine.
|
||||
|
||||
In this tests there's an entity and a legacy provider.
|
||||
The test asserts async_default_engine returns the entity.
|
||||
"""
|
||||
mock_tts_entity._attr_name = "New test"
|
||||
|
||||
await mock_setup(hass, mock_provider)
|
||||
|
@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity(
|
|||
provider_engine = tts.async_resolve_engine(hass, "test")
|
||||
assert provider_engine == "test"
|
||||
assert tts.async_default_engine(hass) == "tts.new_test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_flow_test_domains",
|
||||
[
|
||||
# Test different setup order to ensure the default is not influenced
|
||||
# by setup order.
|
||||
("cloud", "new_test"),
|
||||
("new_test", "cloud"),
|
||||
],
|
||||
)
|
||||
async def test_default_engine_prefer_cloud_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
config_flow_test_domains: str,
|
||||
) -> None:
|
||||
"""Test async_default_engine.
|
||||
|
||||
In this tests there's an entity from domain cloud, an entity from domain new_test
|
||||
and a legacy provider.
|
||||
The test asserts async_default_engine returns the entity from domain cloud.
|
||||
"""
|
||||
await mock_setup(hass, mock_provider)
|
||||
for domain in config_flow_test_domains:
|
||||
entity = MockTTSEntity(DEFAULT_LANG)
|
||||
entity._attr_name = f"{domain} TTS entity"
|
||||
await mock_config_entry_setup(hass, entity, test_domain=domain)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
for domain in config_flow_test_domains:
|
||||
entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity")
|
||||
assert entity_engine == f"tts.{domain}_tts_entity"
|
||||
provider_engine = tts.async_resolve_engine(hass, "test")
|
||||
assert provider_engine == "test"
|
||||
assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"
|
||||
|
|
Loading…
Add table
Reference in a new issue