diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index 7174683fd18..430e46fd890 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -6,6 +6,6 @@ "dependencies": ["assist_pipeline"], "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.4.0"], + "requirements": ["wyoming==1.5.0"], "zeroconf": ["_wyoming._tcp.local."] } diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index 78f57ff4b01..8e7586534f5 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -10,6 +10,7 @@ from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.client import AsyncTcpClient from wyoming.error import Error +from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite from wyoming.tts import Synthesize, SynthesizeVoice @@ -29,6 +30,9 @@ _LOGGER = logging.getLogger() _SAMPLES_PER_CHUNK: Final = 1024 _RECONNECT_SECONDS: Final = 10 _RESTART_SECONDS: Final = 3 +_PING_TIMEOUT: Final = 5 +_PING_SEND_DELAY: Final = 2 +_PIPELINE_FINISH_TIMEOUT: Final = 1 # Wyoming stage -> Assist stage _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { @@ -54,6 +58,7 @@ class WyomingSatellite: self._client: AsyncTcpClient | None = None self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1) self._is_pipeline_running = False + self._pipeline_ended_event = asyncio.Event() self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() self._pipeline_id: str | None = None self._muted_changed_event = asyncio.Event() @@ -77,9 +82,9 @@ class WyomingSatellite: return # Connect and run pipeline loop - await self._run_once() + await self._connect_and_loop() except asyncio.CancelledError: - raise + raise # don't restart except Exception: # pylint: disable=broad-exception-caught await self.on_restart() finally: @@ -142,8 +147,8 @@ class WyomingSatellite: # Cancel any running pipeline self._audio_queue.put_nowait(None) - async def _run_once(self) -> None: - """Run pipelines until an error occurs.""" + async def _connect_and_loop(self) -> None: + """Connect to satellite and run pipelines until an error occurs.""" self.device.set_is_active(False) while self.is_running and (not self.device.is_muted): @@ -163,27 +168,94 @@ class WyomingSatellite: # Tell satellite that we're ready await self._client.write_event(RunSatellite().event()) - # Wait until we get RunPipeline event - run_pipeline: RunPipeline | None = None + # Run until stopped or muted while self.is_running and (not self.device.is_muted): - run_event = await self._client.read_event() - if run_event is None: - raise ConnectionResetError("Satellite disconnected") + await self._run_pipeline_loop() - if RunPipeline.is_type(run_event.type): - run_pipeline = RunPipeline.from_event(run_event) - break + async def _run_pipeline_loop(self) -> None: + """Run a pipeline one or more times.""" + assert self._client is not None + run_pipeline: RunPipeline | None = None + send_ping = True - _LOGGER.debug("Unexpected event from satellite: %s", run_event) + # Read events and check for pipeline end in parallel + pipeline_ended_task = self.hass.async_create_background_task( + self._pipeline_ended_event.wait(), "satellite pipeline ended" + ) + client_event_task = self.hass.async_create_background_task( + self._client.read_event(), "satellite event read" + ) + pending = {pipeline_ended_task, client_event_task} - assert run_pipeline is not None + while self.is_running and (not self.device.is_muted): + if send_ping: + # Ensure satellite is still connected + send_ping = False + self.hass.async_create_background_task( + self._send_delayed_ping(), "ping satellite" + ) + + async with asyncio.timeout(_PING_TIMEOUT): + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + if pipeline_ended_task in done: + # Pipeline run end event was received + _LOGGER.debug("Pipeline finished") + self._pipeline_ended_event.clear() + pipeline_ended_task = self.hass.async_create_background_task( + self._pipeline_ended_event.wait(), "satellite pipeline ended" + ) + pending.add(pipeline_ended_task) + + if (run_pipeline is not None) and run_pipeline.restart_on_end: + # Automatically restart pipeline. + # Used with "always on" streaming satellites. + self._run_pipeline_once(run_pipeline) + continue + + if client_event_task not in done: + continue + + client_event = client_event_task.result() + if client_event is None: + raise ConnectionResetError("Satellite disconnected") + + if Pong.is_type(client_event.type): + # Satellite is still there, send next ping + send_ping = True + elif Ping.is_type(client_event.type): + # Respond to ping from satellite + ping = Ping.from_event(client_event) + await self._client.write_event(Pong(text=ping.text).event()) + elif RunPipeline.is_type(client_event.type): + # Satellite requested pipeline run + run_pipeline = RunPipeline.from_event(client_event) + self._run_pipeline_once(run_pipeline) + elif ( + AudioChunk.is_type(client_event.type) and self._is_pipeline_running + ): + # Microphone audio + chunk = AudioChunk.from_event(client_event) + chunk = self._chunk_converter.convert(chunk) + self._audio_queue.put_nowait(chunk.audio) + elif AudioStop.is_type(client_event.type) and self._is_pipeline_running: + # Stop pipeline + _LOGGER.debug("Client requested pipeline to stop") + self._audio_queue.put_nowait(b"") + else: + _LOGGER.debug("Unexpected event from satellite: %s", client_event) + + # Next event + client_event_task = self.hass.async_create_background_task( + self._client.read_event(), "satellite event read" + ) + pending.add(client_event_task) + + def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None: + """Run a pipeline once.""" _LOGGER.debug("Received run information: %s", run_pipeline) - if (not self.is_running) or self.device.is_muted: - # Run was cancelled or satellite was disabled while waiting for - # RunPipeline event. - return - start_stage = _STAGES.get(run_pipeline.start_stage) end_stage = _STAGES.get(run_pipeline.end_stage) @@ -193,79 +265,64 @@ class WyomingSatellite: if end_stage is None: raise ValueError(f"Invalid end stage: {end_stage}") - # Each loop is a pipeline run - while self.is_running and (not self.device.is_muted): - # Use select to get pipeline each time in case it's changed - pipeline_id = pipeline_select.get_chosen_pipeline( + pipeline_id = pipeline_select.get_chosen_pipeline( + self.hass, + DOMAIN, + self.device.satellite_id, + ) + pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id) + assert pipeline is not None + + # We will push audio in through a queue + self._audio_queue = asyncio.Queue() + stt_stream = self._stt_stream() + + # Start pipeline running + _LOGGER.debug( + "Starting pipeline %s from %s to %s", + pipeline.name, + start_stage, + end_stage, + ) + self._is_pipeline_running = True + self._pipeline_ended_event.clear() + self.hass.async_create_background_task( + assist_pipeline.async_pipeline_from_audio_stream( self.hass, - DOMAIN, - self.device.satellite_id, - ) - pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id) - assert pipeline is not None + context=Context(), + event_callback=self._event_callback, + stt_metadata=stt.SpeechMetadata( + language=pipeline.language, + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=stt_stream, + start_stage=start_stage, + end_stage=end_stage, + tts_audio_output="wav", + pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings( + noise_suppression_level=self.device.noise_suppression_level, + auto_gain_dbfs=self.device.auto_gain, + volume_multiplier=self.device.volume_multiplier, + ), + device_id=self.device.device_id, + ), + name="wyoming satellite pipeline", + ) - # We will push audio in through a queue - self._audio_queue = asyncio.Queue() - stt_stream = self._stt_stream() + async def _send_delayed_ping(self) -> None: + """Send ping to satellite after a delay.""" + assert self._client is not None - # Start pipeline running - _LOGGER.debug( - "Starting pipeline %s from %s to %s", - pipeline.name, - start_stage, - end_stage, - ) - self._is_pipeline_running = True - _pipeline_task = asyncio.create_task( - assist_pipeline.async_pipeline_from_audio_stream( - self.hass, - context=Context(), - event_callback=self._event_callback, - stt_metadata=stt.SpeechMetadata( - language=pipeline.language, - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=stt_stream, - start_stage=start_stage, - end_stage=end_stage, - tts_audio_output="wav", - pipeline_id=pipeline_id, - audio_settings=assist_pipeline.AudioSettings( - noise_suppression_level=self.device.noise_suppression_level, - auto_gain_dbfs=self.device.auto_gain, - volume_multiplier=self.device.volume_multiplier, - ), - device_id=self.device.device_id, - ) - ) - - # Run until pipeline is complete or cancelled with an empty audio chunk - while self._is_pipeline_running: - client_event = await self._client.read_event() - if client_event is None: - raise ConnectionResetError("Satellite disconnected") - - if AudioChunk.is_type(client_event.type): - # Microphone audio - chunk = AudioChunk.from_event(client_event) - chunk = self._chunk_converter.convert(chunk) - self._audio_queue.put_nowait(chunk.audio) - elif AudioStop.is_type(client_event.type): - # Stop pipeline - _LOGGER.debug("Client requested pipeline to stop") - self._audio_queue.put_nowait(b"") - break - else: - _LOGGER.debug("Unexpected event from satellite: %s", client_event) - - # Ensure task finishes - await _pipeline_task - - _LOGGER.debug("Pipeline finished") + try: + await asyncio.sleep(_PING_SEND_DELAY) + await self._client.write_event(Ping().event()) + except ConnectionError: + pass # handled with timeout def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: """Translate pipeline events into Wyoming events.""" @@ -274,6 +331,7 @@ class WyomingSatellite: if event.type == assist_pipeline.PipelineEventType.RUN_END: # Pipeline run is complete self._is_pipeline_running = False + self._pipeline_ended_event.set() self.device.set_is_active(False) elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: self.hass.add_job(self._client.write_event(Detect().event())) @@ -413,10 +471,13 @@ class WyomingSatellite: async def _stt_stream(self) -> AsyncGenerator[bytes, None]: """Yield audio chunks from a queue.""" - is_first_chunk = True - while chunk := await self._audio_queue.get(): - if is_first_chunk: - is_first_chunk = False - _LOGGER.debug("Receiving audio from satellite") + try: + is_first_chunk = True + while chunk := await self._audio_queue.get(): + if is_first_chunk: + is_first_chunk = False + _LOGGER.debug("Receiving audio from satellite") - yield chunk + yield chunk + except asyncio.CancelledError: + pass # ignore diff --git a/requirements_all.txt b/requirements_all.txt index d58a208774a..d7df00ae9b7 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2821,7 +2821,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.4.0 +wyoming==1.5.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 4cc6cea4b0c..b6862948c0f 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2141,7 +2141,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.4.0 +wyoming==1.5.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 07a6aa8925e..b6564afcfe9 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -2,7 +2,9 @@ from __future__ import annotations import asyncio +from collections.abc import Callable import io +from typing import Any from unittest.mock import patch import wave @@ -10,6 +12,7 @@ from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.error import Error from wyoming.event import Event +from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite from wyoming.tts import Synthesize @@ -100,6 +103,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): self.error_event = asyncio.Event() self.error: Error | None = None + self.pong_event = asyncio.Event() + self.pong: Pong | None = None + + self.ping_event = asyncio.Event() + self.ping: Ping | None = None + self._mic_audio_chunk = AudioChunk( rate=16000, width=2, channels=1, audio=b"chunk" ).event() @@ -142,6 +151,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): elif Error.is_type(event.type): self.error = Error.from_event(event) self.error_event.set() + elif Pong.is_type(event.type): + self.pong = Pong.from_event(event) + self.pong_event.set() + elif Ping.is_type(event.type): + self.ping = Ping.from_event(event) + self.ping_event.set() async def read_event(self) -> Event | None: """Receive.""" @@ -150,6 +165,10 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): # Keep sending audio chunks instead of None return event or self._mic_audio_chunk + def inject_event(self, event: Event) -> None: + """Put an event in as the next response.""" + self.responses = [event] + self.responses + async def test_satellite_pipeline(hass: HomeAssistant) -> None: """Test running a pipeline with a satellite.""" @@ -157,10 +176,37 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: events = [ RunPipeline( - start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + start_stage=PipelineStage.WAKE, + end_stage=PipelineStage.TTS, + restart_on_end=True, ).event(), ] + pipeline_kwargs: dict[str, Any] = {} + pipeline_event_callback: Callable[ + [assist_pipeline.PipelineEvent], None + ] | None = None + run_pipeline_called = asyncio.Event() + audio_chunk_received = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_kwargs, pipeline_event_callback + pipeline_kwargs = kwargs + pipeline_event_callback = event_callback + + run_pipeline_called.set() + async for chunk in stt_stream: + if chunk: + audio_chunk_received.set() + break + with patch( "homeassistant.components.wyoming.data.load_wyoming_info", return_value=SATELLITE_INFO, @@ -169,10 +215,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: SatelliteAsyncTcpClient(events), ) as mock_client, patch( "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - ) as mock_run_pipeline, patch( + async_pipeline_from_audio_stream, + ), patch( "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", return_value=("wav", get_test_wav()), - ): + ), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0): entry = await setup_config_entry(hass) device: SatelliteDevice = hass.data[wyoming.DOMAIN][ entry.entry_id @@ -182,12 +229,39 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: await mock_client.connect_event.wait() await mock_client.run_satellite_event.wait() - mock_run_pipeline.assert_called_once() - event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] - assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + # Reset so we can check the pipeline is automatically restarted below + run_pipeline_called.clear() + + assert pipeline_event_callback is not None + assert pipeline_kwargs.get("device_id") == device.device_id + + # Test a ping + mock_client.inject_event(Ping("test-ping").event()) + + # Pong is expected with the same text + async with asyncio.timeout(1): + await mock_client.pong_event.wait() + + assert mock_client.pong is not None + assert mock_client.pong.text == "test-ping" + + # The client should have received the first ping + async with asyncio.timeout(1): + await mock_client.ping_event.wait() + + assert mock_client.ping is not None + + # Reset and send a pong back. + # We will get a second ping by the end of the test. + mock_client.ping_event.clear() + mock_client.ping = None + mock_client.inject_event(Pong().event()) # Start detecting wake word - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.WAKE_WORD_START ) @@ -198,8 +272,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert not device.is_active assert not device.is_muted + # Push in some audio + mock_client.inject_event( + AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event() + ) + # Wake word is detected - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.WAKE_WORD_END, {"wake_word_output": {"wake_word_id": "test_wake_word"}}, @@ -215,7 +294,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert device.is_active # Speech-to-text started - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.STT_START, {"metadata": {"language": "en"}}, @@ -227,8 +306,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.transcribe is not None assert mock_client.transcribe.language == "en" + # Push in some audio + mock_client.inject_event( + AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event() + ) + # User started speaking - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234} ) @@ -240,7 +324,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.voice_started.timestamp == 1234 # User stopped speaking - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678} ) @@ -252,7 +336,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.voice_stopped.timestamp == 5678 # Speech-to-text transcription - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.STT_END, {"stt_output": {"text": "test transcript"}}, @@ -265,7 +349,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.transcript.text == "test transcript" # Text-to-speech text - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.TTS_START, { @@ -283,7 +367,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.synthesize.voice.name == "test voice" # Text-to-speech media - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent( assist_pipeline.PipelineEventType.TTS_END, {"tts_output": {"media_id": "test media id"}}, @@ -302,11 +386,21 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.tts_audio_chunk.audio == b"123" # Pipeline finished - event_callback( + pipeline_event_callback( assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) ) assert not device.is_active + # The client should have received another ping by now + async with asyncio.timeout(1): + await mock_client.ping_event.wait() + + assert mock_client.ping is not None + + # Pipeline should automatically restart + async with asyncio.timeout(1): + await run_pipeline_called.wait() + # Stop the satellite await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() @@ -317,6 +411,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: on_muted_event = asyncio.Event() original_make_satellite = wyoming._make_satellite + original_on_muted = wyoming.satellite.WyomingSatellite.on_muted def make_muted_satellite( hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService @@ -327,6 +422,14 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: return satellite async def on_muted(self): + # Trigger original function + self._muted_changed_event.set() + await original_on_muted(self) + + # Ensure satellite stops + self.is_running = False + + # Proceed with test self.device.set_is_muted(False) on_muted_event.set() @@ -339,16 +442,23 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted", on_muted, ): - await setup_config_entry(hass) + entry = await setup_config_entry(hass) async with asyncio.timeout(1): await on_muted_event.wait() + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + async def test_satellite_restart(hass: HomeAssistant) -> None: """Test pipeline loop restart after unexpected error.""" on_restart_event = asyncio.Event() + original_on_restart = wyoming.satellite.WyomingSatellite.on_restart + async def on_restart(self): + await original_on_restart(self) self.stop() on_restart_event.set() @@ -356,12 +466,12 @@ async def test_satellite_restart(hass: HomeAssistant) -> None: "homeassistant.components.wyoming.data.load_wyoming_info", return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._run_once", + "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop", side_effect=RuntimeError(), ), patch( "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", on_restart, - ): + ), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0): await setup_config_entry(hass) async with asyncio.timeout(1): await on_restart_event.wait() @@ -373,7 +483,11 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: reconnect_event = asyncio.Event() stopped_event = asyncio.Event() + original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect + async def on_reconnect(self): + await original_on_reconnect(self) + nonlocal num_reconnects num_reconnects += 1 if num_reconnects >= 2: @@ -395,7 +509,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: ), patch( "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", on_stopped, - ): + ), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0): await setup_config_entry(hass) async with asyncio.timeout(1): await reconnect_event.wait() @@ -519,3 +633,338 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None: assert mock_client.error is not None assert mock_client.error.text == "test message" assert mock_client.error.code == "test code" + + +async def test_tts_not_wav(hass: HomeAssistant) -> None: + """Test satellite receiving non-WAV audio from text-to-speech.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts + error_event = asyncio.Event() + + async def _stream_tts(self, media_id): + try: + await original_stream_tts(self, media_id) + except ValueError: + error_event.set() + + events = [ + RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(), + ] + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline, patch( + "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + return_value=("mp3", bytes(1)), + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts", + _stream_tts, + ): + entry = await setup_config_entry(hass) + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + mock_run_pipeline.assert_called_once() + event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] + + # Text-to-speech text + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_START, + { + "tts_input": "test text to speak", + "voice": "test voice", + }, + ) + ) + async with asyncio.timeout(1): + await mock_client.synthesize_event.wait() + + # Text-to-speech media + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_END, + {"tts_output": {"media_id": "test media id"}}, + ) + ) + + # Expect error because only WAV is supported + async with asyncio.timeout(1): + await error_event.wait() + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + +async def test_pipeline_changed(hass: HomeAssistant) -> None: + """Test that changing the pipeline setting stops the current pipeline.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + pipeline_event_callback: Callable[ + [assist_pipeline.PipelineEvent], None + ] | None = None + run_pipeline_called = asyncio.Event() + pipeline_stopped = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_event_callback + pipeline_event_callback = event_callback + + run_pipeline_called.set() + async for _chunk in stt_stream: + pass + + pipeline_stopped.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ): + entry = await setup_config_entry(hass) + device: SatelliteDevice = hass.data[wyoming.DOMAIN][ + entry.entry_id + ].satellite.device + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + # Pipeline has started + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + assert pipeline_event_callback is not None + + # Change pipelines + device.set_pipeline_name("different pipeline") + + # Running pipeline should be cancelled + async with asyncio.timeout(1): + await pipeline_stopped.wait() + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + +async def test_audio_settings_changed(hass: HomeAssistant) -> None: + """Test that changing audio settings stops the current pipeline.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + pipeline_event_callback: Callable[ + [assist_pipeline.PipelineEvent], None + ] | None = None + run_pipeline_called = asyncio.Event() + pipeline_stopped = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_event_callback + pipeline_event_callback = event_callback + + run_pipeline_called.set() + async for _chunk in stt_stream: + pass + + pipeline_stopped.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ): + entry = await setup_config_entry(hass) + device: SatelliteDevice = hass.data[wyoming.DOMAIN][ + entry.entry_id + ].satellite.device + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + # Pipeline has started + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + assert pipeline_event_callback is not None + + # Change audio setting + device.set_noise_suppression_level(1) + + # Running pipeline should be cancelled + async with asyncio.timeout(1): + await pipeline_stopped.wait() + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + +async def test_invalid_stages(hass: HomeAssistant) -> None: + """Test error when providing invalid pipeline stages.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once + start_stage_event = asyncio.Event() + end_stage_event = asyncio.Event() + + def _run_pipeline_once(self, run_pipeline): + # Set bad start stage + run_pipeline.start_stage = PipelineStage.INTENT + run_pipeline.end_stage = PipelineStage.TTS + + try: + original_run_pipeline_once(self, run_pipeline) + except ValueError: + start_stage_event.set() + + # Set bad end stage + run_pipeline.start_stage = PipelineStage.WAKE + run_pipeline.end_stage = PipelineStage.INTENT + + try: + original_run_pipeline_once(self, run_pipeline) + except ValueError: + end_stage_event.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once", + _run_pipeline_once, + ): + entry = await setup_config_entry(hass) + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + async with asyncio.timeout(1): + await start_stage_event.wait() + await end_stage_event.wait() + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + +async def test_client_stops_pipeline(hass: HomeAssistant) -> None: + """Test that an AudioStop message stops the current pipeline.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + pipeline_event_callback: Callable[ + [assist_pipeline.PipelineEvent], None + ] | None = None + run_pipeline_called = asyncio.Event() + pipeline_stopped = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_event_callback + pipeline_event_callback = event_callback + + run_pipeline_called.set() + async for _chunk in stt_stream: + pass + + pipeline_stopped.set() + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ): + entry = await setup_config_entry(hass) + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + # Pipeline has started + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + assert pipeline_event_callback is not None + + # Client sends stop message + mock_client.inject_event(AudioStop().event()) + + # Running pipeline should be cancelled + async with asyncio.timeout(1): + await pipeline_stopped.wait() + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done()