Wyoming satellite ping and bugfix for local wake word (#108164)
* Refactor with ping * Fix tests * Increase test coverage
This commit is contained in:
parent
7dffc9f515
commit
db81f4d046
5 changed files with 627 additions and 117 deletions
|
@ -6,6 +6,6 @@
|
|||
"dependencies": ["assist_pipeline"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==1.4.0"],
|
||||
"requirements": ["wyoming==1.5.0"],
|
||||
"zeroconf": ["_wyoming._tcp.local."]
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ from wyoming.asr import Transcribe, Transcript
|
|||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.error import Error
|
||||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
|
@ -29,6 +30,9 @@ _LOGGER = logging.getLogger()
|
|||
_SAMPLES_PER_CHUNK: Final = 1024
|
||||
_RECONNECT_SECONDS: Final = 10
|
||||
_RESTART_SECONDS: Final = 3
|
||||
_PING_TIMEOUT: Final = 5
|
||||
_PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
|
@ -54,6 +58,7 @@ class WyomingSatellite:
|
|||
self._client: AsyncTcpClient | None = None
|
||||
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event = asyncio.Event()
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._pipeline_id: str | None = None
|
||||
self._muted_changed_event = asyncio.Event()
|
||||
|
@ -77,9 +82,9 @@ class WyomingSatellite:
|
|||
return
|
||||
|
||||
# Connect and run pipeline loop
|
||||
await self._run_once()
|
||||
await self._connect_and_loop()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
raise # don't restart
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
await self.on_restart()
|
||||
finally:
|
||||
|
@ -142,8 +147,8 @@ class WyomingSatellite:
|
|||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
async def _run_once(self) -> None:
|
||||
"""Run pipelines until an error occurs."""
|
||||
async def _connect_and_loop(self) -> None:
|
||||
"""Connect to satellite and run pipelines until an error occurs."""
|
||||
self.device.set_is_active(False)
|
||||
|
||||
while self.is_running and (not self.device.is_muted):
|
||||
|
@ -163,27 +168,94 @@ class WyomingSatellite:
|
|||
# Tell satellite that we're ready
|
||||
await self._client.write_event(RunSatellite().event())
|
||||
|
||||
# Wait until we get RunPipeline event
|
||||
run_pipeline: RunPipeline | None = None
|
||||
# Run until stopped or muted
|
||||
while self.is_running and (not self.device.is_muted):
|
||||
run_event = await self._client.read_event()
|
||||
if run_event is None:
|
||||
raise ConnectionResetError("Satellite disconnected")
|
||||
await self._run_pipeline_loop()
|
||||
|
||||
if RunPipeline.is_type(run_event.type):
|
||||
run_pipeline = RunPipeline.from_event(run_event)
|
||||
break
|
||||
async def _run_pipeline_loop(self) -> None:
|
||||
"""Run a pipeline one or more times."""
|
||||
assert self._client is not None
|
||||
run_pipeline: RunPipeline | None = None
|
||||
send_ping = True
|
||||
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", run_event)
|
||||
# Read events and check for pipeline end in parallel
|
||||
pipeline_ended_task = self.hass.async_create_background_task(
|
||||
self._pipeline_ended_event.wait(), "satellite pipeline ended"
|
||||
)
|
||||
client_event_task = self.hass.async_create_background_task(
|
||||
self._client.read_event(), "satellite event read"
|
||||
)
|
||||
pending = {pipeline_ended_task, client_event_task}
|
||||
|
||||
assert run_pipeline is not None
|
||||
while self.is_running and (not self.device.is_muted):
|
||||
if send_ping:
|
||||
# Ensure satellite is still connected
|
||||
send_ping = False
|
||||
self.hass.async_create_background_task(
|
||||
self._send_delayed_ping(), "ping satellite"
|
||||
)
|
||||
|
||||
async with asyncio.timeout(_PING_TIMEOUT):
|
||||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
if pipeline_ended_task in done:
|
||||
# Pipeline run end event was received
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
self._pipeline_ended_event.clear()
|
||||
pipeline_ended_task = self.hass.async_create_background_task(
|
||||
self._pipeline_ended_event.wait(), "satellite pipeline ended"
|
||||
)
|
||||
pending.add(pipeline_ended_task)
|
||||
|
||||
if (run_pipeline is not None) and run_pipeline.restart_on_end:
|
||||
# Automatically restart pipeline.
|
||||
# Used with "always on" streaming satellites.
|
||||
self._run_pipeline_once(run_pipeline)
|
||||
continue
|
||||
|
||||
if client_event_task not in done:
|
||||
continue
|
||||
|
||||
client_event = client_event_task.result()
|
||||
if client_event is None:
|
||||
raise ConnectionResetError("Satellite disconnected")
|
||||
|
||||
if Pong.is_type(client_event.type):
|
||||
# Satellite is still there, send next ping
|
||||
send_ping = True
|
||||
elif Ping.is_type(client_event.type):
|
||||
# Respond to ping from satellite
|
||||
ping = Ping.from_event(client_event)
|
||||
await self._client.write_event(Pong(text=ping.text).event())
|
||||
elif RunPipeline.is_type(client_event.type):
|
||||
# Satellite requested pipeline run
|
||||
run_pipeline = RunPipeline.from_event(client_event)
|
||||
self._run_pipeline_once(run_pipeline)
|
||||
elif (
|
||||
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
|
||||
):
|
||||
# Microphone audio
|
||||
chunk = AudioChunk.from_event(client_event)
|
||||
chunk = self._chunk_converter.convert(chunk)
|
||||
self._audio_queue.put_nowait(chunk.audio)
|
||||
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
|
||||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
# Next event
|
||||
client_event_task = self.hass.async_create_background_task(
|
||||
self._client.read_event(), "satellite event read"
|
||||
)
|
||||
pending.add(client_event_task)
|
||||
|
||||
def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
|
||||
"""Run a pipeline once."""
|
||||
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||
|
||||
if (not self.is_running) or self.device.is_muted:
|
||||
# Run was cancelled or satellite was disabled while waiting for
|
||||
# RunPipeline event.
|
||||
return
|
||||
|
||||
start_stage = _STAGES.get(run_pipeline.start_stage)
|
||||
end_stage = _STAGES.get(run_pipeline.end_stage)
|
||||
|
||||
|
@ -193,79 +265,64 @@ class WyomingSatellite:
|
|||
if end_stage is None:
|
||||
raise ValueError(f"Invalid end stage: {end_stage}")
|
||||
|
||||
# Each loop is a pipeline run
|
||||
while self.is_running and (not self.device.is_muted):
|
||||
# Use select to get pipeline each time in case it's changed
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
self.device.satellite_id,
|
||||
)
|
||||
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||
assert pipeline is not None
|
||||
|
||||
# We will push audio in through a queue
|
||||
self._audio_queue = asyncio.Queue()
|
||||
stt_stream = self._stt_stream()
|
||||
|
||||
# Start pipeline running
|
||||
_LOGGER.debug(
|
||||
"Starting pipeline %s from %s to %s",
|
||||
pipeline.name,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
self._is_pipeline_running = True
|
||||
self._pipeline_ended_event.clear()
|
||||
self.hass.async_create_background_task(
|
||||
assist_pipeline.async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
self.device.satellite_id,
|
||||
)
|
||||
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||
assert pipeline is not None
|
||||
context=Context(),
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language=pipeline.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
tts_audio_output="wav",
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(
|
||||
noise_suppression_level=self.device.noise_suppression_level,
|
||||
auto_gain_dbfs=self.device.auto_gain,
|
||||
volume_multiplier=self.device.volume_multiplier,
|
||||
),
|
||||
device_id=self.device.device_id,
|
||||
),
|
||||
name="wyoming satellite pipeline",
|
||||
)
|
||||
|
||||
# We will push audio in through a queue
|
||||
self._audio_queue = asyncio.Queue()
|
||||
stt_stream = self._stt_stream()
|
||||
async def _send_delayed_ping(self) -> None:
|
||||
"""Send ping to satellite after a delay."""
|
||||
assert self._client is not None
|
||||
|
||||
# Start pipeline running
|
||||
_LOGGER.debug(
|
||||
"Starting pipeline %s from %s to %s",
|
||||
pipeline.name,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
self._is_pipeline_running = True
|
||||
_pipeline_task = asyncio.create_task(
|
||||
assist_pipeline.async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=Context(),
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language=pipeline.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
tts_audio_output="wav",
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(
|
||||
noise_suppression_level=self.device.noise_suppression_level,
|
||||
auto_gain_dbfs=self.device.auto_gain,
|
||||
volume_multiplier=self.device.volume_multiplier,
|
||||
),
|
||||
device_id=self.device.device_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Run until pipeline is complete or cancelled with an empty audio chunk
|
||||
while self._is_pipeline_running:
|
||||
client_event = await self._client.read_event()
|
||||
if client_event is None:
|
||||
raise ConnectionResetError("Satellite disconnected")
|
||||
|
||||
if AudioChunk.is_type(client_event.type):
|
||||
# Microphone audio
|
||||
chunk = AudioChunk.from_event(client_event)
|
||||
chunk = self._chunk_converter.convert(chunk)
|
||||
self._audio_queue.put_nowait(chunk.audio)
|
||||
elif AudioStop.is_type(client_event.type):
|
||||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
break
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
# Ensure task finishes
|
||||
await _pipeline_task
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
try:
|
||||
await asyncio.sleep(_PING_SEND_DELAY)
|
||||
await self._client.write_event(Ping().event())
|
||||
except ConnectionError:
|
||||
pass # handled with timeout
|
||||
|
||||
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
"""Translate pipeline events into Wyoming events."""
|
||||
|
@ -274,6 +331,7 @@ class WyomingSatellite:
|
|||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
|
@ -413,10 +471,13 @@ class WyomingSatellite:
|
|||
|
||||
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
|
||||
"""Yield audio chunks from a queue."""
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
try:
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
|
||||
yield chunk
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
pass # ignore
|
||||
|
|
|
@ -2821,7 +2821,7 @@ wled==0.17.0
|
|||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.4.0
|
||||
wyoming==1.5.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -2141,7 +2141,7 @@ wled==0.17.0
|
|||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.4.0
|
||||
wyoming==1.5.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import io
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
import wave
|
||||
|
||||
|
@ -10,6 +12,7 @@ from wyoming.asr import Transcribe, Transcript
|
|||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.error import Error
|
||||
from wyoming.event import Event
|
||||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.tts import Synthesize
|
||||
|
@ -100,6 +103,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
self.error_event = asyncio.Event()
|
||||
self.error: Error | None = None
|
||||
|
||||
self.pong_event = asyncio.Event()
|
||||
self.pong: Pong | None = None
|
||||
|
||||
self.ping_event = asyncio.Event()
|
||||
self.ping: Ping | None = None
|
||||
|
||||
self._mic_audio_chunk = AudioChunk(
|
||||
rate=16000, width=2, channels=1, audio=b"chunk"
|
||||
).event()
|
||||
|
@ -142,6 +151,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
elif Error.is_type(event.type):
|
||||
self.error = Error.from_event(event)
|
||||
self.error_event.set()
|
||||
elif Pong.is_type(event.type):
|
||||
self.pong = Pong.from_event(event)
|
||||
self.pong_event.set()
|
||||
elif Ping.is_type(event.type):
|
||||
self.ping = Ping.from_event(event)
|
||||
self.ping_event.set()
|
||||
|
||||
async def read_event(self) -> Event | None:
|
||||
"""Receive."""
|
||||
|
@ -150,6 +165,10 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
# Keep sending audio chunks instead of None
|
||||
return event or self._mic_audio_chunk
|
||||
|
||||
def inject_event(self, event: Event) -> None:
|
||||
"""Put an event in as the next response."""
|
||||
self.responses = [event] + self.responses
|
||||
|
||||
|
||||
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test running a pipeline with a satellite."""
|
||||
|
@ -157,10 +176,37 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
start_stage=PipelineStage.WAKE,
|
||||
end_stage=PipelineStage.TTS,
|
||||
restart_on_end=True,
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_kwargs: dict[str, Any] = {}
|
||||
pipeline_event_callback: Callable[
|
||||
[assist_pipeline.PipelineEvent], None
|
||||
] | None = None
|
||||
run_pipeline_called = asyncio.Event()
|
||||
audio_chunk_received = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||
pipeline_kwargs = kwargs
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for chunk in stt_stream:
|
||||
if chunk:
|
||||
audio_chunk_received.set()
|
||||
break
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
|
@ -169,10 +215,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
async_pipeline_from_audio_stream,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
):
|
||||
), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
|
@ -182,12 +229,39 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
mock_run_pipeline.assert_called_once()
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
# Reset so we can check the pipeline is automatically restarted below
|
||||
run_pipeline_called.clear()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
assert pipeline_kwargs.get("device_id") == device.device_id
|
||||
|
||||
# Test a ping
|
||||
mock_client.inject_event(Ping("test-ping").event())
|
||||
|
||||
# Pong is expected with the same text
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.pong_event.wait()
|
||||
|
||||
assert mock_client.pong is not None
|
||||
assert mock_client.pong.text == "test-ping"
|
||||
|
||||
# The client should have received the first ping
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.ping_event.wait()
|
||||
|
||||
assert mock_client.ping is not None
|
||||
|
||||
# Reset and send a pong back.
|
||||
# We will get a second ping by the end of the test.
|
||||
mock_client.ping_event.clear()
|
||||
mock_client.ping = None
|
||||
mock_client.inject_event(Pong().event())
|
||||
|
||||
# Start detecting wake word
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.WAKE_WORD_START
|
||||
)
|
||||
|
@ -198,8 +272,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert not device.is_active
|
||||
assert not device.is_muted
|
||||
|
||||
# Push in some audio
|
||||
mock_client.inject_event(
|
||||
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
|
||||
)
|
||||
|
||||
# Wake word is detected
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.WAKE_WORD_END,
|
||||
{"wake_word_output": {"wake_word_id": "test_wake_word"}},
|
||||
|
@ -215,7 +294,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert device.is_active
|
||||
|
||||
# Speech-to-text started
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_START,
|
||||
{"metadata": {"language": "en"}},
|
||||
|
@ -227,8 +306,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.transcribe is not None
|
||||
assert mock_client.transcribe.language == "en"
|
||||
|
||||
# Push in some audio
|
||||
mock_client.inject_event(
|
||||
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
|
||||
)
|
||||
|
||||
# User started speaking
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
|
||||
)
|
||||
|
@ -240,7 +324,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.voice_started.timestamp == 1234
|
||||
|
||||
# User stopped speaking
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
|
||||
)
|
||||
|
@ -252,7 +336,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.voice_stopped.timestamp == 5678
|
||||
|
||||
# Speech-to-text transcription
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_END,
|
||||
{"stt_output": {"text": "test transcript"}},
|
||||
|
@ -265,7 +349,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.transcript.text == "test transcript"
|
||||
|
||||
# Text-to-speech text
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_START,
|
||||
{
|
||||
|
@ -283,7 +367,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.synthesize.voice.name == "test voice"
|
||||
|
||||
# Text-to-speech media
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_END,
|
||||
{"tts_output": {"media_id": "test media id"}},
|
||||
|
@ -302,11 +386,21 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.tts_audio_chunk.audio == b"123"
|
||||
|
||||
# Pipeline finished
|
||||
event_callback(
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||
)
|
||||
assert not device.is_active
|
||||
|
||||
# The client should have received another ping by now
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.ping_event.wait()
|
||||
|
||||
assert mock_client.ping is not None
|
||||
|
||||
# Pipeline should automatically restart
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -317,6 +411,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
on_muted_event = asyncio.Event()
|
||||
|
||||
original_make_satellite = wyoming._make_satellite
|
||||
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
|
||||
|
||||
def make_muted_satellite(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
|
@ -327,6 +422,14 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
return satellite
|
||||
|
||||
async def on_muted(self):
|
||||
# Trigger original function
|
||||
self._muted_changed_event.set()
|
||||
await original_on_muted(self)
|
||||
|
||||
# Ensure satellite stops
|
||||
self.is_running = False
|
||||
|
||||
# Proceed with test
|
||||
self.device.set_is_muted(False)
|
||||
on_muted_event.set()
|
||||
|
||||
|
@ -339,16 +442,23 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
||||
on_muted,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
entry = await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_muted_event.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_satellite_restart(hass: HomeAssistant) -> None:
|
||||
"""Test pipeline loop restart after unexpected error."""
|
||||
on_restart_event = asyncio.Event()
|
||||
|
||||
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
|
||||
|
||||
async def on_restart(self):
|
||||
await original_on_restart(self)
|
||||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
|
@ -356,12 +466,12 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
||||
side_effect=RuntimeError(),
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
):
|
||||
), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
@ -373,7 +483,11 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
reconnect_event = asyncio.Event()
|
||||
stopped_event = asyncio.Event()
|
||||
|
||||
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
|
||||
|
||||
async def on_reconnect(self):
|
||||
await original_on_reconnect(self)
|
||||
|
||||
nonlocal num_reconnects
|
||||
num_reconnects += 1
|
||||
if num_reconnects >= 2:
|
||||
|
@ -395,7 +509,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
):
|
||||
), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await reconnect_event.wait()
|
||||
|
@ -519,3 +633,338 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.error is not None
|
||||
assert mock_client.error.text == "test message"
|
||||
assert mock_client.error.code == "test code"
|
||||
|
||||
|
||||
async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
||||
"""Test satellite receiving non-WAV audio from text-to-speech."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
|
||||
error_event = asyncio.Event()
|
||||
|
||||
async def _stream_tts(self, media_id):
|
||||
try:
|
||||
await original_stream_tts(self, media_id)
|
||||
except ValueError:
|
||||
error_event.set()
|
||||
|
||||
events = [
|
||||
RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("mp3", bytes(1)),
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
||||
_stream_tts,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
mock_run_pipeline.assert_called_once()
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
|
||||
# Text-to-speech text
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_START,
|
||||
{
|
||||
"tts_input": "test text to speak",
|
||||
"voice": "test voice",
|
||||
},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.synthesize_event.wait()
|
||||
|
||||
# Text-to-speech media
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_END,
|
||||
{"tts_output": {"media_id": "test media id"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Expect error because only WAV is supported
|
||||
async with asyncio.timeout(1):
|
||||
await error_event.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
||||
"""Test that changing the pipeline setting stops the current pipeline."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event_callback: Callable[
|
||||
[assist_pipeline.PipelineEvent], None
|
||||
] | None = None
|
||||
run_pipeline_called = asyncio.Event()
|
||||
pipeline_stopped = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_event_callback
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
# Pipeline has started
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
|
||||
# Change pipelines
|
||||
device.set_pipeline_name("different pipeline")
|
||||
|
||||
# Running pipeline should be cancelled
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_stopped.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
||||
"""Test that changing audio settings stops the current pipeline."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event_callback: Callable[
|
||||
[assist_pipeline.PipelineEvent], None
|
||||
] | None = None
|
||||
run_pipeline_called = asyncio.Event()
|
||||
pipeline_stopped = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_event_callback
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
# Pipeline has started
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
|
||||
# Change audio setting
|
||||
device.set_noise_suppression_level(1)
|
||||
|
||||
# Running pipeline should be cancelled
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_stopped.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_invalid_stages(hass: HomeAssistant) -> None:
|
||||
"""Test error when providing invalid pipeline stages."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
|
||||
start_stage_event = asyncio.Event()
|
||||
end_stage_event = asyncio.Event()
|
||||
|
||||
def _run_pipeline_once(self, run_pipeline):
|
||||
# Set bad start stage
|
||||
run_pipeline.start_stage = PipelineStage.INTENT
|
||||
run_pipeline.end_stage = PipelineStage.TTS
|
||||
|
||||
try:
|
||||
original_run_pipeline_once(self, run_pipeline)
|
||||
except ValueError:
|
||||
start_stage_event.set()
|
||||
|
||||
# Set bad end stage
|
||||
run_pipeline.start_stage = PipelineStage.WAKE
|
||||
run_pipeline.end_stage = PipelineStage.INTENT
|
||||
|
||||
try:
|
||||
original_run_pipeline_once(self, run_pipeline)
|
||||
except ValueError:
|
||||
end_stage_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
||||
_run_pipeline_once,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await start_stage_event.wait()
|
||||
await end_stage_event.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test that an AudioStop message stops the current pipeline."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event_callback: Callable[
|
||||
[assist_pipeline.PipelineEvent], None
|
||||
] | None = None
|
||||
run_pipeline_called = asyncio.Event()
|
||||
pipeline_stopped = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_event_callback
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
# Pipeline has started
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
|
||||
# Client sends stop message
|
||||
mock_client.inject_event(AudioStop().event())
|
||||
|
||||
# Running pipeline should be cancelled
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_stopped.wait()
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
|
Loading…
Add table
Reference in a new issue