From 48498844678126c1c1472758155170a5cf919ca6 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 11 Dec 2023 10:18:46 -0600 Subject: [PATCH] Disconnect before reconnecting to satellite (#105500) Disconnect before reconnecting --- homeassistant/components/wyoming/satellite.py | 26 ++++++++++++++++--- tests/components/wyoming/test_satellite.py | 23 ++++++++++++---- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index 0e8e5d62f4b..45d882dd1e2 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -70,11 +70,11 @@ class WyomingSatellite: while self.is_running: try: # Check if satellite has been disabled - if not self.device.is_enabled: + while not self.device.is_enabled: await self.on_disabled() if not self.is_running: # Satellite was stopped while waiting to be enabled - break + return # Connect and run pipeline loop await self._run_once() @@ -86,7 +86,7 @@ class WyomingSatellite: # Ensure sensor is off self.device.set_is_active(False) - await self.on_stopped() + await self.on_stopped() def stop(self) -> None: """Signal satellite task to stop running.""" @@ -129,6 +129,7 @@ class WyomingSatellite: self._audio_queue.put_nowait(None) self._enabled_changed_event.set() + self._enabled_changed_event.clear() def _pipeline_changed(self) -> None: """Run when device pipeline changes.""" @@ -243,9 +244,17 @@ class WyomingSatellite: chunk = AudioChunk.from_event(client_event) chunk = self._chunk_converter.convert(chunk) self._audio_queue.put_nowait(chunk.audio) + elif AudioStop.is_type(client_event.type): + # Stop pipeline + _LOGGER.debug("Client requested pipeline to stop") + self._audio_queue.put_nowait(b"") + break else: _LOGGER.debug("Unexpected event from satellite: %s", client_event) + # Ensure task finishes + await _pipeline_task + _LOGGER.debug("Pipeline finished") def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: @@ -336,12 +345,23 @@ class WyomingSatellite: async def _connect(self) -> None: """Connect to satellite over TCP.""" + await self._disconnect() + _LOGGER.debug( "Connecting to satellite at %s:%s", self.service.host, self.service.port ) self._client = AsyncTcpClient(self.service.host, self.service.port) await self._client.connect() + async def _disconnect(self) -> None: + """Disconnect if satellite is currently connected.""" + if self._client is None: + return + + _LOGGER.debug("Disconnecting from satellite") + await self._client.disconnect() + self._client = None + async def _stream_tts(self, media_id: str) -> None: """Stream TTS WAV audio to satellite in chunks.""" assert self._client is not None diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 50252007aa5..83e4d98d971 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -322,11 +322,12 @@ async def test_satellite_disabled(hass: HomeAssistant) -> None: hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService ): satellite = original_make_satellite(hass, config_entry, service) - satellite.device.is_enabled = False + satellite.device.set_is_enabled(False) return satellite async def on_disabled(self): + self.device.set_is_enabled(True) on_disabled_event.set() with patch( @@ -368,11 +369,19 @@ async def test_satellite_restart(hass: HomeAssistant) -> None: async def test_satellite_reconnect(hass: HomeAssistant) -> None: """Test satellite reconnect call after connection refused.""" - on_reconnect_event = asyncio.Event() + num_reconnects = 0 + reconnect_event = asyncio.Event() + stopped_event = asyncio.Event() async def on_reconnect(self): - self.stop() - on_reconnect_event.set() + nonlocal num_reconnects + num_reconnects += 1 + if num_reconnects >= 2: + reconnect_event.set() + self.stop() + + async def on_stopped(self): + stopped_event.set() with patch( "homeassistant.components.wyoming.data.load_wyoming_info", @@ -383,10 +392,14 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: ), patch( "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", on_reconnect, + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + on_stopped, ): await setup_config_entry(hass) async with asyncio.timeout(1): - await on_reconnect_event.wait() + await reconnect_event.wait() + await stopped_event.wait() async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None: