Prefer stt entity over legacy stt provider (#124625)

* Prefer stt entity over legacy stt provider

* Update assist_pipeline tests
This commit is contained in:
Erik Montnemery 2024-08-26 13:43:14 +02:00 committed by GitHub
parent 0a05cdc381
commit 7b71f024fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 38 additions and 38 deletions

View file

@ -72,9 +72,9 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
return async_default_provider(hass) or next(
return next(
iter(hass.states.async_entity_ids(DOMAIN)), None
)
) or async_default_provider(hass)
@callback

View file

@ -10,7 +10,7 @@
}),
dict({
'data': dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
@ -301,7 +301,7 @@
}),
dict({
'data': dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,

View file

@ -11,7 +11,7 @@
# ---
# name: test_audio_pipeline.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -92,7 +92,7 @@
# ---
# name: test_audio_pipeline_debug.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -185,7 +185,7 @@
# ---
# name: test_audio_pipeline_with_enhancements.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -288,7 +288,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.3
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -401,7 +401,7 @@
# ---
# name: test_device_capture.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -427,7 +427,7 @@
# ---
# name: test_device_capture_override.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -475,7 +475,7 @@
# ---
# name: test_device_capture_queue_full.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -649,7 +649,7 @@
# ---
# name: test_stt_stream_failed.1
dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': 16,
'channel': 1,

View file

@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -80,9 +80,9 @@ async def test_pipeline_from_audio_stream_auto(
)
assert process_events(events) == snapshot
assert len(mock_stt_provider.received) == 2
assert mock_stt_provider.received[0].startswith(b"part1")
assert mock_stt_provider.received[1].startswith(b"part2")
assert len(mock_stt_provider_entity.received) == 2
assert mock_stt_provider_entity.received[0].startswith(b"part1")
assert mock_stt_provider_entity.received[1].startswith(b"part2")
async def test_pipeline_from_audio_stream_legacy(
@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
snapshot: SnapshotAssertion,
@ -381,16 +381,16 @@ async def test_pipeline_from_audio_stream_wake_word(
# 2. queued audio (from mock wake word entity)
# 3. part1
# 4. part2
assert len(mock_stt_provider.received) > 3
assert len(mock_stt_provider_entity.received) > 3
first_chunk = bytes(
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
[c_byte for c in mock_stt_provider_entity.received[:-3] for c_byte in c]
)
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[-3] == b"queued audio"
assert mock_stt_provider.received[-2].startswith(b"part1")
assert mock_stt_provider.received[-1].startswith(b"part2")
assert mock_stt_provider_entity.received[-3] == b"queued audio"
assert mock_stt_provider_entity.received[-2].startswith(b"part1")
assert mock_stt_provider_entity.received[-1].startswith(b"part2")
async def test_pipeline_save_audio(

View file

@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
from .conftest import MockSttProvider, MockTTSProvider
from .conftest import MockSttProviderEntity, MockTTSProvider
from tests.common import flush_store
@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts(
@pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider,
ha_language: str,
ha_country: str | None,
@ -412,7 +412,7 @@ async def test_default_pipeline(
hass.config.language = ha_language
with (
patch.object(mock_stt_provider, "_supported_languages", MANY_LANGUAGES),
patch.object(mock_stt_provider_entity, "_supported_languages", MANY_LANGUAGES),
patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES),
):
assert await async_setup_component(hass, "assist_pipeline", {})
@ -429,7 +429,7 @@ async def test_default_pipeline(
id=pipeline.id,
language=pipeline_language,
name="Home Assistant",
stt_engine="test",
stt_engine="stt.mock_stt",
stt_language=stt_language,
tts_engine="test",
tts_language=tts_language,
@ -441,10 +441,10 @@ async def test_default_pipeline(
@pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant, mock_stt_provider: MockSttProvider
hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
@ -489,7 +489,7 @@ async def test_default_pipeline_unsupported_tts_language(
id=pipeline.id,
language="en",
name="Home Assistant",
stt_engine="test",
stt_engine="stt.mock_stt",
stt_language="en-US",
tts_engine=None,
tts_language=None,

View file

@ -682,7 +682,7 @@ async def test_stt_provider_missing(
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
with patch(
"homeassistant.components.stt.async_get_provider",
"homeassistant.components.stt.async_get_speech_to_text_entity",
return_value=None,
):
client = await hass_ws_client(hass)
@ -708,11 +708,11 @@ async def test_stt_provider_bad_metadata(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_stt_provider,
mock_stt_provider_entity,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with wrong metadata."""
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
with patch.object(mock_stt_provider_entity, "check_metadata", return_value=False):
client = await hass_ws_client(hass)
await client.send_json_auto_id(
@ -743,7 +743,7 @@ async def test_stt_stream_failed(
client = await hass_ws_client(hass)
with patch(
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
"tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream",
side_effect=RuntimeError,
):
await client.send_json_auto_id(
@ -1188,7 +1188,7 @@ async def test_get_pipeline(
"id": ANY,
"language": "en",
"name": "Home Assistant",
"stt_engine": "test",
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
@ -1213,7 +1213,7 @@ async def test_get_pipeline(
"language": "en",
"name": "Home Assistant",
# It found these defaults
"stt_engine": "test",
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
@ -1297,7 +1297,7 @@ async def test_list_pipelines(
"id": ANY,
"language": "en",
"name": "Home Assistant",
"stt_engine": "test",
"stt_engine": "stt.mock_stt",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",

View file

@ -497,7 +497,7 @@ async def test_default_engine_entity(
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
async def test_default_engine_prefer_provider(
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
@ -520,7 +520,7 @@ async def test_default_engine_prefer_provider(
provider_engine = async_get_speech_to_text_engine(hass, "test")
assert provider_engine is not None
assert provider_engine.name == "test"
assert async_default_engine(hass) == "test"
assert async_default_engine(hass) == "stt.new_test"
async def test_get_engine_legacy(