From c3717f8182d0eb7e176efc08ccb5be61d91b9a27 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 26 Mar 2023 22:41:17 -0400 Subject: [PATCH] Clean up voice assistant integration (#90239) * Clean up voice assistant * Reinstate auto-removed imports * Resample STT audio from 44.1Khz to 16Khz * Energy based VAD for prototyping --------- Co-authored-by: Michael Hansen --- homeassistant/components/cloud/stt.py | 6 +- .../components/voice_assistant/pipeline.py | 150 ++++++------ .../voice_assistant/websocket_api.py | 62 ++++- .../snapshots/test_websocket.ambr | 210 +++++++++++++++++ .../voice_assistant/test_websocket.py | 216 +++++------------- 5 files changed, 407 insertions(+), 237 deletions(-) create mode 100644 tests/components/voice_assistant/snapshots/test_websocket.ambr diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index bdce055c3c4..13062db57d6 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import AsyncIterable +import logging from hass_nabucasa import Cloud from hass_nabucasa.voice import VoiceError @@ -20,6 +21,8 @@ from homeassistant.components.stt import ( from .const import DOMAIN +_LOGGER = logging.getLogger(__name__) + SUPPORT_LANGUAGES = [ "da-DK", "de-DE", @@ -102,7 +105,8 @@ class CloudProvider(Provider): result = await self.cloud.voice.process_stt( stream, content, metadata.language ) - except VoiceError: + except VoiceError as err: + _LOGGER.debug("Voice error: %s", err) return SpeechResult(None, SpeechResultState.ERROR) # Return Speech as Text diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py index 0070154bd40..806a603f5e5 100644 --- a/homeassistant/components/voice_assistant/pipeline.py +++ b/homeassistant/components/voice_assistant/pipeline.py @@ -150,6 +150,7 @@ class PipelineRun: end_stage: PipelineStage event_callback: Callable[[PipelineEvent], None] language: str = None # type: ignore[assignment] + runner_data: Any | None = None def __post_init__(self): """Set language for pipeline.""" @@ -163,15 +164,14 @@ class PipelineRun: def start(self): """Emit run start event.""" - self.event_callback( - PipelineEvent( - PipelineEventType.RUN_START, - { - "pipeline": self.pipeline.name, - "language": self.language, - }, - ) - ) + data = { + "pipeline": self.pipeline.name, + "language": self.language, + } + if self.runner_data is not None: + data["runner_data"] = self.runner_data + + self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data)) def end(self): """Emit run end event.""" @@ -200,41 +200,45 @@ class PipelineRun: try: # Load provider - stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine) + stt_provider: stt.Provider = stt.async_get_provider( + self.hass, self.pipeline.stt_engine + ) assert stt_provider is not None except Exception as src_error: - stt_error = SpeechToTextError( + _LOGGER.exception("No speech to text provider for %s", engine) + raise SpeechToTextError( code="stt-provider-missing", message=f"No speech to text provider for: {engine}", + ) from src_error + + if not stt_provider.check_metadata(metadata): + raise SpeechToTextError( + code="stt-provider-unsupported-metadata", + message=f"Provider {engine} does not support input speech to text metadata", ) - _LOGGER.exception(stt_error.message) - self.event_callback( - PipelineEvent( - PipelineEventType.ERROR, - {"code": stt_error.code, "message": stt_error.message}, - ) - ) - raise stt_error from src_error try: # Transcribe audio stream result = await stt_provider.async_process_audio_stream(metadata, stream) - assert (result.text is not None) and ( - result.result == stt.SpeechResultState.SUCCESS - ) except Exception as src_error: - stt_error = SpeechToTextError( + _LOGGER.exception("Unexpected error during speech to text") + raise SpeechToTextError( code="stt-stream-failed", message="Unexpected error during speech to text", + ) from src_error + + _LOGGER.debug("speech-to-text result %s", result) + + if result.result != stt.SpeechResultState.SUCCESS: + raise SpeechToTextError( + code="stt-stream-failed", + message="Speech to text failed", ) - _LOGGER.exception(stt_error.message) - self.event_callback( - PipelineEvent( - PipelineEventType.ERROR, - {"code": stt_error.code, "message": stt_error.message}, - ) + + if not result.text: + raise SpeechToTextError( + code="stt-no-text-recognized", message="No text recognized" ) - raise stt_error from src_error self.event_callback( PipelineEvent( @@ -273,18 +277,13 @@ class PipelineRun: agent_id=self.pipeline.conversation_engine, ) except Exception as src_error: - intent_error = IntentRecognitionError( + _LOGGER.exception("Unexpected error during intent recognition") + raise IntentRecognitionError( code="intent-failed", message="Unexpected error during intent recognition", - ) - _LOGGER.exception(intent_error.message) - self.event_callback( - PipelineEvent( - PipelineEventType.ERROR, - {"code": intent_error.code, "message": intent_error.message}, - ) - ) - raise intent_error from src_error + ) from src_error + + _LOGGER.debug("conversation result %s", conversation_result) self.event_callback( PipelineEvent( @@ -320,18 +319,13 @@ class PipelineRun: ), ) except Exception as src_error: - tts_error = TextToSpeechError( + _LOGGER.exception("Unexpected error during text to speech") + raise TextToSpeechError( code="tts-failed", message="Unexpected error during text to speech", - ) - _LOGGER.exception(tts_error.message) - self.event_callback( - PipelineEvent( - PipelineEventType.ERROR, - {"code": tts_error.code, "message": tts_error.message}, - ) - ) - raise tts_error from src_error + ) from src_error + + _LOGGER.debug("TTS result %s", tts_media) self.event_callback( PipelineEvent( @@ -377,31 +371,41 @@ class PipelineInput: run.start() current_stage = run.start_stage - # Speech to text - intent_input = self.intent_input - if current_stage == PipelineStage.STT: - assert self.stt_metadata is not None - assert self.stt_stream is not None - intent_input = await run.speech_to_text( - self.stt_metadata, - self.stt_stream, - ) - current_stage = PipelineStage.INTENT - - if run.end_stage != PipelineStage.STT: - tts_input = self.tts_input - - if current_stage == PipelineStage.INTENT: - assert intent_input is not None - tts_input = await run.recognize_intent( - intent_input, self.conversation_id + try: + # Speech to text + intent_input = self.intent_input + if current_stage == PipelineStage.STT: + assert self.stt_metadata is not None + assert self.stt_stream is not None + intent_input = await run.speech_to_text( + self.stt_metadata, + self.stt_stream, ) - current_stage = PipelineStage.TTS + current_stage = PipelineStage.INTENT - if run.end_stage != PipelineStage.INTENT: - if current_stage == PipelineStage.TTS: - assert tts_input is not None - await run.text_to_speech(tts_input) + if run.end_stage != PipelineStage.STT: + tts_input = self.tts_input + + if current_stage == PipelineStage.INTENT: + assert intent_input is not None + tts_input = await run.recognize_intent( + intent_input, self.conversation_id + ) + current_stage = PipelineStage.TTS + + if run.end_stage != PipelineStage.INTENT: + if current_stage == PipelineStage.TTS: + assert tts_input is not None + await run.text_to_speech(tts_input) + + except PipelineError as err: + run.event_callback( + PipelineEvent( + PipelineEventType.ERROR, + {"code": err.code, "message": err.message}, + ) + ) + return run.end() diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py index cc4799f13e7..28cafb7a355 100644 --- a/homeassistant/components/voice_assistant/websocket_api.py +++ b/homeassistant/components/voice_assistant/websocket_api.py @@ -1,5 +1,6 @@ """Voice Assistant Websocket API.""" import asyncio +import audioop from collections.abc import Callable import logging from typing import Any @@ -12,6 +13,8 @@ from homeassistant.core import HomeAssistant, callback from .pipeline import ( DEFAULT_TIMEOUT, PipelineError, + PipelineEvent, + PipelineEventType, PipelineInput, PipelineRun, PipelineStage, @@ -20,6 +23,10 @@ from .pipeline import ( _LOGGER = logging.getLogger(__name__) +_VAD_ENERGY_THRESHOLD = 1000 +_VAD_SPEECH_FRAMES = 25 +_VAD_SILENCE_FRAMES = 25 + @callback def async_register_websocket_api(hass: HomeAssistant) -> None: @@ -27,6 +34,17 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_run) +def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float: + """Compute RMS of debiased audio.""" + energy = -audioop.rms(audio_data, width) + energy_bytes = bytes([energy & 0xFF, (energy >> 8) & 0xFF]) + debiased_energy = audioop.rms( + audioop.add(audio_data, energy_bytes * (len(audio_data) // width), width), width + ) + + return debiased_energy + + @websocket_api.websocket_command( { vol.Required("type"): "voice_assistant/run", @@ -49,6 +67,11 @@ async def websocket_run( ) -> None: """Run a pipeline.""" language = msg.get("language", hass.config.language) + + # Temporary workaround for language codes + if language == "en": + language = "en-US" + pipeline_id = msg.get("pipeline") pipeline = async_get_pipeline( hass, @@ -79,8 +102,32 @@ async def websocket_run( audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue() async def stt_stream(): + state = None + speech_count = 0 + in_voice_command = False + # Yield until we receive an empty chunk while chunk := await audio_queue.get(): + chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state) + is_speech = _get_debiased_energy(chunk) > _VAD_ENERGY_THRESHOLD + + if in_voice_command: + if is_speech: + speech_count += 1 + else: + speech_count -= 1 + + if speech_count <= -_VAD_SILENCE_FRAMES: + _LOGGER.info("Voice command stopped") + break + else: + if is_speech: + speech_count += 1 + + if speech_count >= _VAD_SPEECH_FRAMES: + in_voice_command = True + _LOGGER.info("Voice command started") + yield chunk def handle_binary(_hass, _connection, data: bytes): @@ -119,6 +166,9 @@ async def websocket_run( event_callback=lambda event: connection.send_event( msg["id"], event.as_dict() ), + runner_data={ + "stt_binary_handler_id": handler_id, + }, ), timeout=timeout, ) @@ -130,16 +180,20 @@ async def websocket_run( # Confirm subscription connection.send_result(msg["id"]) - if handler_id is not None: - # Send handler id to client - connection.send_event(msg["id"], {"handler_id": handler_id}) - try: # Task contains a timeout await run_task except PipelineError as error: # Report more specific error when possible connection.send_error(msg["id"], error.code, error.message) + except asyncio.TimeoutError: + connection.send_event( + msg["id"], + PipelineEvent( + PipelineEventType.ERROR, + {"code": "timeout", "message": "Timeout running pipeline"}, + ), + ) finally: if unregister_handler is not None: # Unregister binary handler diff --git a/tests/components/voice_assistant/snapshots/test_websocket.ambr b/tests/components/voice_assistant/snapshots/test_websocket.ambr new file mode 100644 index 00000000000..07934df6c4c --- /dev/null +++ b/tests/components/voice_assistant/snapshots/test_websocket.ambr @@ -0,0 +1,210 @@ +# serializer version: 1 +# name: test_audio_pipeline + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + }), + }) +# --- +# name: test_audio_pipeline.1 + dict({ + 'engine': 'default', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline.2 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_audio_pipeline.3 + dict({ + 'engine': 'default', + 'intent_input': 'test transcript', + }) +# --- +# name: test_audio_pipeline.4 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_audio_pipeline.5 + dict({ + 'engine': 'default', + 'tts_input': "Sorry, I couldn't understand that", + }) +# --- +# name: test_audio_pipeline.6 + dict({ + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_demo.mp3', + }), + }) +# --- +# name: test_intent_failed + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': None, + }), + }) +# --- +# name: test_intent_failed.1 + dict({ + 'engine': 'default', + 'intent_input': 'Are the lights on?', + }) +# --- +# name: test_intent_timeout + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': None, + }), + }) +# --- +# name: test_intent_timeout.1 + dict({ + 'engine': 'default', + 'intent_input': 'Are the lights on?', + }) +# --- +# name: test_intent_timeout.2 + dict({ + 'code': 'timeout', + 'message': 'Timeout running pipeline', + }) +# --- +# name: test_stt_provider_missing + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + }), + }) +# --- +# name: test_stt_provider_missing.1 + dict({ + 'engine': 'default', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_stt_stream_failed + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + }), + }) +# --- +# name: test_stt_stream_failed.1 + dict({ + 'engine': 'default', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_text_only_pipeline + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': None, + }), + }) +# --- +# name: test_text_only_pipeline.1 + dict({ + 'engine': 'default', + 'intent_input': 'Are the lights on?', + }) +# --- +# name: test_text_only_pipeline.2 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_text_pipeline_timeout + dict({ + 'code': 'timeout', + 'message': 'Timeout running pipeline', + }) +# --- +# name: test_tts_failed + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': None, + }), + }) +# --- +# name: test_tts_failed.1 + dict({ + 'engine': 'default', + 'tts_input': 'Lights are on.', + }) +# --- diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py index a1ba8b5f7cb..f02122a3e7f 100644 --- a/tests/components/voice_assistant/test_websocket.py +++ b/tests/components/voice_assistant/test_websocket.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterable from unittest.mock import MagicMock, patch import pytest +from syrupy.assertion import SnapshotAssertion from homeassistant.components import stt from homeassistant.core import HomeAssistant @@ -29,7 +30,7 @@ class MockSttProvider(stt.Provider): @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" - return [self.hass.config.language] + return ["en-US"] @property def supported_formats(self) -> list[stt.AudioFormats]: @@ -64,7 +65,11 @@ class MockSttProvider(stt.Provider): @pytest.fixture(autouse=True) -async def init_components(hass): +async def init_components( + hass: HomeAssistant, + mock_get_cache_files, # noqa: F811 + mock_init_cache_dir, # noqa: F811 +): """Initialize relevant components with empty configs.""" assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component( @@ -93,6 +98,7 @@ async def init_components(hass): async def test_text_only_pipeline( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with text input (no STT/TTS).""" client = await hass_ws_client(hass) @@ -114,38 +120,16 @@ async def test_text_only_pipeline( # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" - assert msg["event"]["data"] == { - "engine": "default", - "intent_input": "Are the lights on?", - } + assert msg["event"]["data"] == snapshot msg = await client.receive_json() assert msg["event"]["type"] == "intent-end" - assert msg["event"]["data"] == { - "intent_output": { - "response": { - "speech": { - "plain": { - "speech": "Sorry, I couldn't understand that", - "extra_data": None, - } - }, - "card": {}, - "language": "en", - "response_type": "error", - "data": {"code": "no_intent_match"}, - }, - "conversation_id": None, - } - } + assert msg["event"]["data"] == snapshot # run end msg = await client.receive_json() @@ -154,8 +138,7 @@ async def test_text_only_pipeline( async def test_audio_pipeline( - hass: HomeAssistant, - hass_ws_client: WebSocketGenerator, + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion ) -> None: """Test events from a pipeline run with audio input/output.""" client = await hass_ws_client(hass) @@ -173,86 +156,40 @@ async def test_audio_pipeline( msg = await client.receive_json() assert msg["success"] - # handler id - msg = await client.receive_json() - assert msg["event"]["handler_id"] == 1 - # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # stt msg = await client.receive_json() assert msg["event"]["type"] == "stt-start" - assert msg["event"]["data"] == { - "engine": "default", - "metadata": { - "bit_rate": 16, - "channel": 1, - "codec": "pcm", - "format": "wav", - "language": "en", - "sample_rate": 16000, - }, - } + assert msg["event"]["data"] == snapshot # End of audio stream (handler id + empty payload) await client.send_bytes(b"1") msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" - assert msg["event"]["data"] == { - "stt_output": {"text": _TRANSCRIPT}, - } + assert msg["event"]["data"] == snapshot # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" - assert msg["event"]["data"] == { - "engine": "default", - "intent_input": _TRANSCRIPT, - } + assert msg["event"]["data"] == snapshot msg = await client.receive_json() assert msg["event"]["type"] == "intent-end" - assert msg["event"]["data"] == { - "intent_output": { - "response": { - "speech": { - "plain": { - "speech": "Sorry, I couldn't understand that", - "extra_data": None, - } - }, - "card": {}, - "language": "en", - "response_type": "error", - "data": {"code": "no_intent_match"}, - }, - "conversation_id": None, - } - } + assert msg["event"]["data"] == snapshot # text to speech msg = await client.receive_json() assert msg["event"]["type"] == "tts-start" - assert msg["event"]["data"] == { - "engine": "default", - "tts_input": "Sorry, I couldn't understand that", - } + assert msg["event"]["data"] == snapshot msg = await client.receive_json() assert msg["event"]["type"] == "tts-end" - assert msg["event"]["data"] == { - "tts_output": { - "url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3", - "mime_type": "audio/mpeg", - }, - } + assert msg["event"]["data"] == snapshot # run end msg = await client.receive_json() @@ -261,7 +198,10 @@ async def test_audio_pipeline( async def test_intent_timeout( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test partial pipeline run with conversation agent timeout.""" client = await hass_ws_client(hass) @@ -291,27 +231,24 @@ async def test_intent_timeout( # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" - assert msg["event"]["data"] == { - "engine": "default", - "intent_input": "Are the lights on?", - } + assert msg["event"]["data"] == snapshot # timeout error msg = await client.receive_json() - assert not msg["success"] - assert msg["error"]["code"] == "timeout" + assert msg["event"]["type"] == "error" + assert msg["event"]["data"] == snapshot async def test_text_pipeline_timeout( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test text-only pipeline run with immediate timeout.""" client = await hass_ws_client(hass) @@ -340,12 +277,15 @@ async def test_text_pipeline_timeout( # timeout error msg = await client.receive_json() - assert not msg["success"] - assert msg["error"]["code"] == "timeout" + assert msg["event"]["type"] == "error" + assert msg["event"]["data"] == snapshot async def test_intent_failed( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test text-only pipeline run with conversation agent error.""" client = await hass_ws_client(hass) @@ -371,18 +311,12 @@ async def test_intent_failed( # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # intent start msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" - assert msg["event"]["data"] == { - "engine": "default", - "intent_input": "Are the lights on?", - } + assert msg["event"]["data"] == snapshot # intent error msg = await client.receive_json() @@ -391,7 +325,10 @@ async def test_intent_failed( async def test_audio_pipeline_timeout( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test audio pipeline run with immediate timeout.""" client = await hass_ws_client(hass) @@ -417,19 +354,16 @@ async def test_audio_pipeline_timeout( msg = await client.receive_json() assert msg["success"] - # handler id - msg = await client.receive_json() - assert msg["event"]["handler_id"] == 1 - # timeout error msg = await client.receive_json() - assert not msg["success"] - assert msg["error"]["code"] == "timeout" + assert msg["event"]["type"] == "error" + assert msg["event"]["data"]["code"] == "timeout" async def test_stt_provider_missing( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" with patch( @@ -451,32 +385,15 @@ async def test_stt_provider_missing( msg = await client.receive_json() assert msg["success"] - # handler id - msg = await client.receive_json() - assert msg["event"]["handler_id"] == 1 - # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # stt msg = await client.receive_json() assert msg["event"]["type"] == "stt-start" - assert msg["event"]["data"] == { - "engine": "default", - "metadata": { - "bit_rate": 16, - "channel": 1, - "codec": "pcm", - "format": "wav", - "language": "en", - "sample_rate": 16000, - }, - } + assert msg["event"]["data"] == snapshot # End of audio stream (handler id + empty payload) await client.send_bytes(b"1") @@ -490,6 +407,7 @@ async def test_stt_provider_missing( async def test_stt_stream_failed( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" with patch( @@ -511,32 +429,15 @@ async def test_stt_stream_failed( msg = await client.receive_json() assert msg["success"] - # handler id - msg = await client.receive_json() - assert msg["event"]["handler_id"] == 1 - # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # stt msg = await client.receive_json() assert msg["event"]["type"] == "stt-start" - assert msg["event"]["data"] == { - "engine": "default", - "metadata": { - "bit_rate": 16, - "channel": 1, - "codec": "pcm", - "format": "wav", - "language": "en", - "sample_rate": 16000, - }, - } + assert msg["event"]["data"] == snapshot # End of audio stream (handler id + empty payload) await client.send_bytes(b"1") @@ -548,7 +449,10 @@ async def test_stt_stream_failed( async def test_tts_failed( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, ) -> None: """Test pipeline run with text to speech error.""" client = await hass_ws_client(hass) @@ -574,18 +478,12 @@ async def test_tts_failed( # run start msg = await client.receive_json() assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == { - "pipeline": hass.config.language, - "language": hass.config.language, - } + assert msg["event"]["data"] == snapshot # tts start msg = await client.receive_json() assert msg["event"]["type"] == "tts-start" - assert msg["event"]["data"] == { - "engine": "default", - "tts_input": "Lights are on.", - } + assert msg["event"]["data"] == snapshot # tts error msg = await client.receive_json()