Prefer stt entity over legacy stt provider (#124625)
* Prefer stt entity over legacy stt provider * Update assist_pipeline tests
This commit is contained in:
parent
0a05cdc381
commit
7b71f024fb
7 changed files with 38 additions and 38 deletions
|
@ -72,9 +72,9 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
|||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine."""
|
||||
return async_default_provider(hass) or next(
|
||||
return next(
|
||||
iter(hass.states.async_entity_ids(DOMAIN)), None
|
||||
)
|
||||
) or async_default_provider(hass)
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
|
@ -301,7 +301,7 @@
|
|||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# ---
|
||||
# name: test_audio_pipeline.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -92,7 +92,7 @@
|
|||
# ---
|
||||
# name: test_audio_pipeline_debug.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -185,7 +185,7 @@
|
|||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -288,7 +288,7 @@
|
|||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.3
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -401,7 +401,7 @@
|
|||
# ---
|
||||
# name: test_device_capture.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -427,7 +427,7 @@
|
|||
# ---
|
||||
# name: test_device_capture_override.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -475,7 +475,7 @@
|
|||
# ---
|
||||
# name: test_device_capture_queue_full.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
@ -649,7 +649,7 @@
|
|||
# ---
|
||||
# name: test_stt_stream_failed.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
|
|
|
@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
|||
|
||||
async def test_pipeline_from_audio_stream_auto(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
|
@ -80,9 +80,9 @@ async def test_pipeline_from_audio_stream_auto(
|
|||
)
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
assert len(mock_stt_provider.received) == 2
|
||||
assert mock_stt_provider.received[0].startswith(b"part1")
|
||||
assert mock_stt_provider.received[1].startswith(b"part2")
|
||||
assert len(mock_stt_provider_entity.received) == 2
|
||||
assert mock_stt_provider_entity.received[0].startswith(b"part1")
|
||||
assert mock_stt_provider_entity.received[1].startswith(b"part2")
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_legacy(
|
||||
|
@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||
|
||||
async def test_pipeline_from_audio_stream_wake_word(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
|
@ -381,16 +381,16 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||
# 2. queued audio (from mock wake word entity)
|
||||
# 3. part1
|
||||
# 4. part2
|
||||
assert len(mock_stt_provider.received) > 3
|
||||
assert len(mock_stt_provider_entity.received) > 3
|
||||
|
||||
first_chunk = bytes(
|
||||
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
|
||||
[c_byte for c in mock_stt_provider_entity.received[:-3] for c_byte in c]
|
||||
)
|
||||
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
|
||||
|
||||
assert mock_stt_provider.received[-3] == b"queued audio"
|
||||
assert mock_stt_provider.received[-2].startswith(b"part1")
|
||||
assert mock_stt_provider.received[-1].startswith(b"part2")
|
||||
assert mock_stt_provider_entity.received[-3] == b"queued audio"
|
||||
assert mock_stt_provider_entity.received[-2].startswith(b"part1")
|
||||
assert mock_stt_provider_entity.received[-1].startswith(b"part2")
|
||||
|
||||
|
||||
async def test_pipeline_save_audio(
|
||||
|
|
|
@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MANY_LANGUAGES
|
||||
from .conftest import MockSttProvider, MockTTSProvider
|
||||
from .conftest import MockSttProviderEntity, MockTTSProvider
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
|
@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts(
|
|||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
async def test_default_pipeline(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
ha_language: str,
|
||||
ha_country: str | None,
|
||||
|
@ -412,7 +412,7 @@ async def test_default_pipeline(
|
|||
hass.config.language = ha_language
|
||||
|
||||
with (
|
||||
patch.object(mock_stt_provider, "_supported_languages", MANY_LANGUAGES),
|
||||
patch.object(mock_stt_provider_entity, "_supported_languages", MANY_LANGUAGES),
|
||||
patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES),
|
||||
):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
@ -429,7 +429,7 @@ async def test_default_pipeline(
|
|||
id=pipeline.id,
|
||||
language=pipeline_language,
|
||||
name="Home Assistant",
|
||||
stt_engine="test",
|
||||
stt_engine="stt.mock_stt",
|
||||
stt_language=stt_language,
|
||||
tts_engine="test",
|
||||
tts_language=tts_language,
|
||||
|
@ -441,10 +441,10 @@ async def test_default_pipeline(
|
|||
|
||||
@pytest.mark.usefixtures("init_supporting_components")
|
||||
async def test_default_pipeline_unsupported_stt_language(
|
||||
hass: HomeAssistant, mock_stt_provider: MockSttProvider
|
||||
hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
|
||||
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
@ -489,7 +489,7 @@ async def test_default_pipeline_unsupported_tts_language(
|
|||
id=pipeline.id,
|
||||
language="en",
|
||||
name="Home Assistant",
|
||||
stt_engine="test",
|
||||
stt_engine="stt.mock_stt",
|
||||
stt_language="en-US",
|
||||
tts_engine=None,
|
||||
tts_language=None,
|
||||
|
|
|
@ -682,7 +682,7 @@ async def test_stt_provider_missing(
|
|||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
with patch(
|
||||
"homeassistant.components.stt.async_get_provider",
|
||||
"homeassistant.components.stt.async_get_speech_to_text_entity",
|
||||
return_value=None,
|
||||
):
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -708,11 +708,11 @@ async def test_stt_provider_bad_metadata(
|
|||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_stt_provider,
|
||||
mock_stt_provider_entity,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with wrong metadata."""
|
||||
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
|
||||
with patch.object(mock_stt_provider_entity, "check_metadata", return_value=False):
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
|
@ -743,7 +743,7 @@ async def test_stt_stream_failed(
|
|||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
|
||||
"tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream",
|
||||
side_effect=RuntimeError,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
|
@ -1188,7 +1188,7 @@ async def test_get_pipeline(
|
|||
"id": ANY,
|
||||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "test",
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
|
@ -1213,7 +1213,7 @@ async def test_get_pipeline(
|
|||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
# It found these defaults
|
||||
"stt_engine": "test",
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
|
@ -1297,7 +1297,7 @@ async def test_list_pipelines(
|
|||
"id": ANY,
|
||||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "test",
|
||||
"stt_engine": "stt.mock_stt",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
|
|
|
@ -497,7 +497,7 @@ async def test_default_engine_entity(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
|
||||
async def test_default_engine_prefer_provider(
|
||||
async def test_default_engine_prefer_entity(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
|
@ -520,7 +520,7 @@ async def test_default_engine_prefer_provider(
|
|||
provider_engine = async_get_speech_to_text_engine(hass, "test")
|
||||
assert provider_engine is not None
|
||||
assert provider_engine.name == "test"
|
||||
assert async_default_engine(hass) == "test"
|
||||
assert async_default_engine(hass) == "stt.new_test"
|
||||
|
||||
|
||||
async def test_get_engine_legacy(
|
||||
|
|
Loading…
Add table
Reference in a new issue