Fix defaults for cloud STT/TTS (#121229)
* Fix defaults for cloud STT/TTS * Prefer entity over legacy provider * Remove unrealistic tests * Add tests which show cloud stt/tts entity is preferred --------- Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
parent
547dbf77aa
commit
156948c496
6 changed files with 144 additions and 34 deletions
|
@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||||
@callback
|
@callback
|
||||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||||
"""Return the domain or entity id of the default engine."""
|
"""Return the domain or entity id of the default engine."""
|
||||||
return next(
|
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||||
iter(hass.states.async_entity_ids(DOMAIN)), None
|
|
||||||
) or async_default_provider(hass)
|
default_entity_id: str | None = None
|
||||||
|
|
||||||
|
for entity in component.entities:
|
||||||
|
if entity.platform and entity.platform.platform_name == "cloud":
|
||||||
|
return entity.entity_id
|
||||||
|
|
||||||
|
if default_entity_id is None:
|
||||||
|
default_entity_id = entity.entity_id
|
||||||
|
|
||||||
|
return default_entity_id or async_default_provider(hass)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||||
|
|
||||||
if "cloud" in manager.providers:
|
default_entity_id: str | None = None
|
||||||
return "cloud"
|
|
||||||
|
|
||||||
entity = next(iter(component.entities), None)
|
for entity in component.entities:
|
||||||
|
if entity.platform and entity.platform.platform_name == "cloud":
|
||||||
|
return entity.entity_id
|
||||||
|
|
||||||
if entity is not None:
|
if default_entity_id is None:
|
||||||
return entity.entity_id
|
default_entity_id = entity.entity_id
|
||||||
|
|
||||||
return next(iter(manager.providers), None)
|
return default_entity_id or next(iter(manager.providers), None)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
|
|
||||||
from collections.abc import AsyncIterable, Generator
|
from collections.abc import AsyncIterable, Generator, Iterable
|
||||||
|
from contextlib import ExitStack
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
@ -122,20 +123,23 @@ class STTFlow(ConfigFlow):
|
||||||
"""Test flow."""
|
"""Test flow."""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="config_flow_test_domain")
|
@pytest.fixture(name="config_flow_test_domains")
|
||||||
def config_flow_test_domain_fixture() -> str:
|
def config_flow_test_domain_fixture() -> Iterable[str]:
|
||||||
"""Test domain fixture."""
|
"""Test domain fixture."""
|
||||||
return TEST_DOMAIN
|
return (TEST_DOMAIN,)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def config_flow_fixture(
|
def config_flow_fixture(
|
||||||
hass: HomeAssistant, config_flow_test_domain: str
|
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
|
||||||
) -> Generator[None]:
|
) -> Generator[None]:
|
||||||
"""Mock config flow."""
|
"""Mock config flow."""
|
||||||
mock_platform(hass, f"{config_flow_test_domain}.config_flow")
|
for domain in config_flow_test_domains:
|
||||||
|
mock_platform(hass, f"{domain}.config_flow")
|
||||||
|
|
||||||
with mock_config_flow(config_flow_test_domain, STTFlow):
|
with ExitStack() as stack:
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
stack.enter_context(mock_config_flow(domain, STTFlow))
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@ -496,21 +500,25 @@ async def test_default_engine_entity(
|
||||||
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
|
@pytest.mark.parametrize("config_flow_test_domains", [("new_test",)])
|
||||||
async def test_default_engine_prefer_entity(
|
async def test_default_engine_prefer_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider_entity: MockProviderEntity,
|
mock_provider_entity: MockProviderEntity,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockProvider,
|
||||||
config_flow_test_domain: str,
|
config_flow_test_domains: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine."""
|
"""Test async_default_engine.
|
||||||
|
|
||||||
|
In this tests there's an entity and a legacy provider.
|
||||||
|
The test asserts async_default_engine returns the entity.
|
||||||
|
"""
|
||||||
mock_provider_entity.url_path = "stt.new_test"
|
mock_provider_entity.url_path = "stt.new_test"
|
||||||
mock_provider_entity._attr_name = "New test"
|
mock_provider_entity._attr_name = "New test"
|
||||||
|
|
||||||
await mock_setup(hass, tmp_path, mock_provider)
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
await mock_config_entry_setup(
|
await mock_config_entry_setup(
|
||||||
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
|
hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0]
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity(
|
||||||
assert async_default_engine(hass) == "stt.new_test"
|
assert async_default_engine(hass) == "stt.new_test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config_flow_test_domains",
|
||||||
|
[
|
||||||
|
# Test different setup order to ensure the default is not influenced
|
||||||
|
# by setup order.
|
||||||
|
("cloud", "new_test"),
|
||||||
|
("new_test", "cloud"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_default_engine_prefer_cloud_entity(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mock_provider: MockProvider,
|
||||||
|
config_flow_test_domains: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_default_engine.
|
||||||
|
|
||||||
|
In this tests there's an entity from domain cloud, an entity from domain new_test
|
||||||
|
and a legacy provider.
|
||||||
|
The test asserts async_default_engine returns the entity from domain cloud.
|
||||||
|
"""
|
||||||
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
entity = MockProviderEntity()
|
||||||
|
entity.url_path = f"stt.{domain}"
|
||||||
|
entity._attr_name = f"{domain} STT entity"
|
||||||
|
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
entity_engine = async_get_speech_to_text_engine(
|
||||||
|
hass, f"stt.{domain}_stt_entity"
|
||||||
|
)
|
||||||
|
assert entity_engine is not None
|
||||||
|
assert entity_engine.name == f"{domain} STT entity"
|
||||||
|
|
||||||
|
provider_engine = async_get_speech_to_text_engine(hass, "test")
|
||||||
|
assert provider_engine is not None
|
||||||
|
assert provider_engine.name == "test"
|
||||||
|
assert async_default_engine(hass) == "stt.cloud_stt_entity"
|
||||||
|
|
||||||
|
|
||||||
async def test_get_engine_legacy(
|
async def test_get_engine_legacy(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -215,7 +215,9 @@ async def mock_setup(
|
||||||
|
|
||||||
|
|
||||||
async def mock_config_entry_setup(
|
async def mock_config_entry_setup(
|
||||||
hass: HomeAssistant, tts_entity: MockTTSEntity
|
hass: HomeAssistant,
|
||||||
|
tts_entity: MockTTSEntity,
|
||||||
|
test_domain: str = TEST_DOMAIN,
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Set up a test tts platform via config entry."""
|
"""Set up a test tts platform via config entry."""
|
||||||
|
|
||||||
|
@ -236,7 +238,7 @@ async def mock_config_entry_setup(
|
||||||
mock_integration(
|
mock_integration(
|
||||||
hass,
|
hass,
|
||||||
MockModule(
|
MockModule(
|
||||||
TEST_DOMAIN,
|
test_domain,
|
||||||
async_setup_entry=async_setup_entry_init,
|
async_setup_entry=async_setup_entry_init,
|
||||||
async_unload_entry=async_unload_entry_init,
|
async_unload_entry=async_unload_entry_init,
|
||||||
),
|
),
|
||||||
|
@ -251,9 +253,9 @@ async def mock_config_entry_setup(
|
||||||
async_add_entities([tts_entity])
|
async_add_entities([tts_entity])
|
||||||
|
|
||||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||||
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
|
mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform)
|
||||||
|
|
||||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
config_entry = MockConfigEntry(domain=test_domain)
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterable
|
||||||
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow):
|
||||||
"""Test flow."""
|
"""Test flow."""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(name="config_flow_test_domains")
|
||||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
|
def config_flow_test_domain_fixture() -> Iterable[str]:
|
||||||
"""Mock config flow."""
|
"""Test domain fixture."""
|
||||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
return (TEST_DOMAIN,)
|
||||||
|
|
||||||
with mock_config_flow(TEST_DOMAIN, TTSFlow):
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def config_flow_fixture(
|
||||||
|
hass: HomeAssistant, config_flow_test_domains: Iterable[str]
|
||||||
|
) -> Generator[None]:
|
||||||
|
"""Mock config flow."""
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
mock_platform(hass, f"{domain}.config_flow")
|
||||||
|
|
||||||
|
with ExitStack() as stack:
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
stack.enter_context(mock_config_flow(domain, TTSFlow))
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
|
||||||
):
|
):
|
||||||
assert tts.async_resolve_engine(hass, None) is None
|
assert tts.async_resolve_engine(hass, None) is None
|
||||||
|
|
||||||
with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
|
|
||||||
assert tts.async_resolve_engine(hass, None) == "cloud"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("setup", "engine_id"),
|
("setup", "engine_id"),
|
||||||
|
@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity(
|
||||||
mock_tts_entity: MockTTSEntity,
|
mock_tts_entity: MockTTSEntity,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine."""
|
"""Test async_default_engine.
|
||||||
|
|
||||||
|
In this tests there's an entity and a legacy provider.
|
||||||
|
The test asserts async_default_engine returns the entity.
|
||||||
|
"""
|
||||||
mock_tts_entity._attr_name = "New test"
|
mock_tts_entity._attr_name = "New test"
|
||||||
|
|
||||||
await mock_setup(hass, mock_provider)
|
await mock_setup(hass, mock_provider)
|
||||||
|
@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity(
|
||||||
provider_engine = tts.async_resolve_engine(hass, "test")
|
provider_engine = tts.async_resolve_engine(hass, "test")
|
||||||
assert provider_engine == "test"
|
assert provider_engine == "test"
|
||||||
assert tts.async_default_engine(hass) == "tts.new_test"
|
assert tts.async_default_engine(hass) == "tts.new_test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config_flow_test_domains",
|
||||||
|
[
|
||||||
|
# Test different setup order to ensure the default is not influenced
|
||||||
|
# by setup order.
|
||||||
|
("cloud", "new_test"),
|
||||||
|
("new_test", "cloud"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_default_engine_prefer_cloud_entity(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_provider: MockProvider,
|
||||||
|
config_flow_test_domains: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_default_engine.
|
||||||
|
|
||||||
|
In this tests there's an entity from domain cloud, an entity from domain new_test
|
||||||
|
and a legacy provider.
|
||||||
|
The test asserts async_default_engine returns the entity from domain cloud.
|
||||||
|
"""
|
||||||
|
await mock_setup(hass, mock_provider)
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
entity = MockTTSEntity(DEFAULT_LANG)
|
||||||
|
entity._attr_name = f"{domain} TTS entity"
|
||||||
|
await mock_config_entry_setup(hass, entity, test_domain=domain)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
for domain in config_flow_test_domains:
|
||||||
|
entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity")
|
||||||
|
assert entity_engine == f"tts.{domain}_tts_entity"
|
||||||
|
provider_engine = tts.async_resolve_engine(hass, "test")
|
||||||
|
assert provider_engine == "test"
|
||||||
|
assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue