Disconnect before reconnecting to satellite (#105500)

Disconnect before reconnecting
This commit is contained in:
Michael Hansen 2023-12-11 10:18:46 -06:00 committed by Franck Nijhof
parent 8a0b1637b1
commit 4849884467
No known key found for this signature in database
GPG key ID: D62583BA8AB11CA3
2 changed files with 41 additions and 8 deletions

View file

@ -70,11 +70,11 @@ class WyomingSatellite:
while self.is_running:
try:
# Check if satellite has been disabled
if not self.device.is_enabled:
while not self.device.is_enabled:
await self.on_disabled()
if not self.is_running:
# Satellite was stopped while waiting to be enabled
break
return
# Connect and run pipeline loop
await self._run_once()
@ -86,7 +86,7 @@ class WyomingSatellite:
# Ensure sensor is off
self.device.set_is_active(False)
await self.on_stopped()
await self.on_stopped()
def stop(self) -> None:
"""Signal satellite task to stop running."""
@ -129,6 +129,7 @@ class WyomingSatellite:
self._audio_queue.put_nowait(None)
self._enabled_changed_event.set()
self._enabled_changed_event.clear()
def _pipeline_changed(self) -> None:
"""Run when device pipeline changes."""
@ -243,9 +244,17 @@ class WyomingSatellite:
chunk = AudioChunk.from_event(client_event)
chunk = self._chunk_converter.convert(chunk)
self._audio_queue.put_nowait(chunk.audio)
elif AudioStop.is_type(client_event.type):
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
break
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
# Ensure task finishes
await _pipeline_task
_LOGGER.debug("Pipeline finished")
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
@ -336,12 +345,23 @@ class WyomingSatellite:
async def _connect(self) -> None:
"""Connect to satellite over TCP."""
await self._disconnect()
_LOGGER.debug(
"Connecting to satellite at %s:%s", self.service.host, self.service.port
)
self._client = AsyncTcpClient(self.service.host, self.service.port)
await self._client.connect()
async def _disconnect(self) -> None:
"""Disconnect if satellite is currently connected."""
if self._client is None:
return
_LOGGER.debug("Disconnecting from satellite")
await self._client.disconnect()
self._client = None
async def _stream_tts(self, media_id: str) -> None:
"""Stream TTS WAV audio to satellite in chunks."""
assert self._client is not None

View file

@ -322,11 +322,12 @@ async def test_satellite_disabled(hass: HomeAssistant) -> None:
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
):
satellite = original_make_satellite(hass, config_entry, service)
satellite.device.is_enabled = False
satellite.device.set_is_enabled(False)
return satellite
async def on_disabled(self):
self.device.set_is_enabled(True)
on_disabled_event.set()
with patch(
@ -368,11 +369,19 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
"""Test satellite reconnect call after connection refused."""
on_reconnect_event = asyncio.Event()
num_reconnects = 0
reconnect_event = asyncio.Event()
stopped_event = asyncio.Event()
async def on_reconnect(self):
self.stop()
on_reconnect_event.set()
nonlocal num_reconnects
num_reconnects += 1
if num_reconnects >= 2:
reconnect_event.set()
self.stop()
async def on_stopped(self):
stopped_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
@ -383,10 +392,14 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
on_reconnect,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_reconnect_event.wait()
await reconnect_event.wait()
await stopped_event.wait()
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None: