Disconnect before reconnecting to satellite (#105500)
Disconnect before reconnecting
This commit is contained in:
parent
8a0b1637b1
commit
4849884467
2 changed files with 41 additions and 8 deletions
|
@ -70,11 +70,11 @@ class WyomingSatellite:
|
|||
while self.is_running:
|
||||
try:
|
||||
# Check if satellite has been disabled
|
||||
if not self.device.is_enabled:
|
||||
while not self.device.is_enabled:
|
||||
await self.on_disabled()
|
||||
if not self.is_running:
|
||||
# Satellite was stopped while waiting to be enabled
|
||||
break
|
||||
return
|
||||
|
||||
# Connect and run pipeline loop
|
||||
await self._run_once()
|
||||
|
@ -86,7 +86,7 @@ class WyomingSatellite:
|
|||
# Ensure sensor is off
|
||||
self.device.set_is_active(False)
|
||||
|
||||
await self.on_stopped()
|
||||
await self.on_stopped()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal satellite task to stop running."""
|
||||
|
@ -129,6 +129,7 @@ class WyomingSatellite:
|
|||
self._audio_queue.put_nowait(None)
|
||||
|
||||
self._enabled_changed_event.set()
|
||||
self._enabled_changed_event.clear()
|
||||
|
||||
def _pipeline_changed(self) -> None:
|
||||
"""Run when device pipeline changes."""
|
||||
|
@ -243,9 +244,17 @@ class WyomingSatellite:
|
|||
chunk = AudioChunk.from_event(client_event)
|
||||
chunk = self._chunk_converter.convert(chunk)
|
||||
self._audio_queue.put_nowait(chunk.audio)
|
||||
elif AudioStop.is_type(client_event.type):
|
||||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
break
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
# Ensure task finishes
|
||||
await _pipeline_task
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
|
@ -336,12 +345,23 @@ class WyomingSatellite:
|
|||
|
||||
async def _connect(self) -> None:
|
||||
"""Connect to satellite over TCP."""
|
||||
await self._disconnect()
|
||||
|
||||
_LOGGER.debug(
|
||||
"Connecting to satellite at %s:%s", self.service.host, self.service.port
|
||||
)
|
||||
self._client = AsyncTcpClient(self.service.host, self.service.port)
|
||||
await self._client.connect()
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
"""Disconnect if satellite is currently connected."""
|
||||
if self._client is None:
|
||||
return
|
||||
|
||||
_LOGGER.debug("Disconnecting from satellite")
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
|
||||
async def _stream_tts(self, media_id: str) -> None:
|
||||
"""Stream TTS WAV audio to satellite in chunks."""
|
||||
assert self._client is not None
|
||||
|
|
|
@ -322,11 +322,12 @@ async def test_satellite_disabled(hass: HomeAssistant) -> None:
|
|||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
):
|
||||
satellite = original_make_satellite(hass, config_entry, service)
|
||||
satellite.device.is_enabled = False
|
||||
satellite.device.set_is_enabled(False)
|
||||
|
||||
return satellite
|
||||
|
||||
async def on_disabled(self):
|
||||
self.device.set_is_enabled(True)
|
||||
on_disabled_event.set()
|
||||
|
||||
with patch(
|
||||
|
@ -368,11 +369,19 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
|
||||
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
||||
"""Test satellite reconnect call after connection refused."""
|
||||
on_reconnect_event = asyncio.Event()
|
||||
num_reconnects = 0
|
||||
reconnect_event = asyncio.Event()
|
||||
stopped_event = asyncio.Event()
|
||||
|
||||
async def on_reconnect(self):
|
||||
self.stop()
|
||||
on_reconnect_event.set()
|
||||
nonlocal num_reconnects
|
||||
num_reconnects += 1
|
||||
if num_reconnects >= 2:
|
||||
reconnect_event.set()
|
||||
self.stop()
|
||||
|
||||
async def on_stopped(self):
|
||||
stopped_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
|
@ -383,10 +392,14 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||
on_reconnect,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_reconnect_event.wait()
|
||||
await reconnect_event.wait()
|
||||
await stopped_event.wait()
|
||||
|
||||
|
||||
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
|
||||
|
|
Loading…
Add table
Reference in a new issue