Wyoming satellite ping and bugfix for local wake word (#108164)

* Refactor with ping

* Fix tests

* Increase test coverage
This commit is contained in:
Michael Hansen 2024-01-16 15:43:30 -06:00 committed by GitHub
parent 7dffc9f515
commit db81f4d046
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 627 additions and 117 deletions

View file

@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==1.4.0"],
"requirements": ["wyoming==1.5.0"],
"zeroconf": ["_wyoming._tcp.local."]
}

View file

@ -10,6 +10,7 @@ from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize, SynthesizeVoice
@ -29,6 +30,9 @@ _LOGGER = logging.getLogger()
_SAMPLES_PER_CHUNK: Final = 1024
_RECONNECT_SECONDS: Final = 10
_RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -54,6 +58,7 @@ class WyomingSatellite:
self._client: AsyncTcpClient | None = None
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
self._is_pipeline_running = False
self._pipeline_ended_event = asyncio.Event()
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._pipeline_id: str | None = None
self._muted_changed_event = asyncio.Event()
@ -77,9 +82,9 @@ class WyomingSatellite:
return
# Connect and run pipeline loop
await self._run_once()
await self._connect_and_loop()
except asyncio.CancelledError:
raise
raise # don't restart
except Exception: # pylint: disable=broad-exception-caught
await self.on_restart()
finally:
@ -142,8 +147,8 @@ class WyomingSatellite:
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
async def _run_once(self) -> None:
"""Run pipelines until an error occurs."""
async def _connect_and_loop(self) -> None:
"""Connect to satellite and run pipelines until an error occurs."""
self.device.set_is_active(False)
while self.is_running and (not self.device.is_muted):
@ -163,27 +168,94 @@ class WyomingSatellite:
# Tell satellite that we're ready
await self._client.write_event(RunSatellite().event())
# Wait until we get RunPipeline event
run_pipeline: RunPipeline | None = None
# Run until stopped or muted
while self.is_running and (not self.device.is_muted):
run_event = await self._client.read_event()
if run_event is None:
raise ConnectionResetError("Satellite disconnected")
await self._run_pipeline_loop()
if RunPipeline.is_type(run_event.type):
run_pipeline = RunPipeline.from_event(run_event)
break
async def _run_pipeline_loop(self) -> None:
"""Run a pipeline one or more times."""
assert self._client is not None
run_pipeline: RunPipeline | None = None
send_ping = True
_LOGGER.debug("Unexpected event from satellite: %s", run_event)
# Read events and check for pipeline end in parallel
pipeline_ended_task = self.hass.async_create_background_task(
self._pipeline_ended_event.wait(), "satellite pipeline ended"
)
client_event_task = self.hass.async_create_background_task(
self._client.read_event(), "satellite event read"
)
pending = {pipeline_ended_task, client_event_task}
assert run_pipeline is not None
while self.is_running and (not self.device.is_muted):
if send_ping:
# Ensure satellite is still connected
send_ping = False
self.hass.async_create_background_task(
self._send_delayed_ping(), "ping satellite"
)
async with asyncio.timeout(_PING_TIMEOUT):
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
if pipeline_ended_task in done:
# Pipeline run end event was received
_LOGGER.debug("Pipeline finished")
self._pipeline_ended_event.clear()
pipeline_ended_task = self.hass.async_create_background_task(
self._pipeline_ended_event.wait(), "satellite pipeline ended"
)
pending.add(pipeline_ended_task)
if (run_pipeline is not None) and run_pipeline.restart_on_end:
# Automatically restart pipeline.
# Used with "always on" streaming satellites.
self._run_pipeline_once(run_pipeline)
continue
if client_event_task not in done:
continue
client_event = client_event_task.result()
if client_event is None:
raise ConnectionResetError("Satellite disconnected")
if Pong.is_type(client_event.type):
# Satellite is still there, send next ping
send_ping = True
elif Ping.is_type(client_event.type):
# Respond to ping from satellite
ping = Ping.from_event(client_event)
await self._client.write_event(Pong(text=ping.text).event())
elif RunPipeline.is_type(client_event.type):
# Satellite requested pipeline run
run_pipeline = RunPipeline.from_event(client_event)
self._run_pipeline_once(run_pipeline)
elif (
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
):
# Microphone audio
chunk = AudioChunk.from_event(client_event)
chunk = self._chunk_converter.convert(chunk)
self._audio_queue.put_nowait(chunk.audio)
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
# Next event
client_event_task = self.hass.async_create_background_task(
self._client.read_event(), "satellite event read"
)
pending.add(client_event_task)
def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
"""Run a pipeline once."""
_LOGGER.debug("Received run information: %s", run_pipeline)
if (not self.is_running) or self.device.is_muted:
# Run was cancelled or satellite was disabled while waiting for
# RunPipeline event.
return
start_stage = _STAGES.get(run_pipeline.start_stage)
end_stage = _STAGES.get(run_pipeline.end_stage)
@ -193,79 +265,64 @@ class WyomingSatellite:
if end_stage is None:
raise ValueError(f"Invalid end stage: {end_stage}")
# Each loop is a pipeline run
while self.is_running and (not self.device.is_muted):
# Use select to get pipeline each time in case it's changed
pipeline_id = pipeline_select.get_chosen_pipeline(
pipeline_id = pipeline_select.get_chosen_pipeline(
self.hass,
DOMAIN,
self.device.satellite_id,
)
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
assert pipeline is not None
# We will push audio in through a queue
self._audio_queue = asyncio.Queue()
stt_stream = self._stt_stream()
# Start pipeline running
_LOGGER.debug(
"Starting pipeline %s from %s to %s",
pipeline.name,
start_stage,
end_stage,
)
self._is_pipeline_running = True
self._pipeline_ended_event.clear()
self.hass.async_create_background_task(
assist_pipeline.async_pipeline_from_audio_stream(
self.hass,
DOMAIN,
self.device.satellite_id,
)
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
assert pipeline is not None
context=Context(),
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream,
start_stage=start_stage,
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
noise_suppression_level=self.device.noise_suppression_level,
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
),
device_id=self.device.device_id,
),
name="wyoming satellite pipeline",
)
# We will push audio in through a queue
self._audio_queue = asyncio.Queue()
stt_stream = self._stt_stream()
async def _send_delayed_ping(self) -> None:
"""Send ping to satellite after a delay."""
assert self._client is not None
# Start pipeline running
_LOGGER.debug(
"Starting pipeline %s from %s to %s",
pipeline.name,
start_stage,
end_stage,
)
self._is_pipeline_running = True
_pipeline_task = asyncio.create_task(
assist_pipeline.async_pipeline_from_audio_stream(
self.hass,
context=Context(),
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream,
start_stage=start_stage,
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
noise_suppression_level=self.device.noise_suppression_level,
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
),
device_id=self.device.device_id,
)
)
# Run until pipeline is complete or cancelled with an empty audio chunk
while self._is_pipeline_running:
client_event = await self._client.read_event()
if client_event is None:
raise ConnectionResetError("Satellite disconnected")
if AudioChunk.is_type(client_event.type):
# Microphone audio
chunk = AudioChunk.from_event(client_event)
chunk = self._chunk_converter.convert(chunk)
self._audio_queue.put_nowait(chunk.audio)
elif AudioStop.is_type(client_event.type):
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
break
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
# Ensure task finishes
await _pipeline_task
_LOGGER.debug("Pipeline finished")
try:
await asyncio.sleep(_PING_SEND_DELAY)
await self._client.write_event(Ping().event())
except ConnectionError:
pass # handled with timeout
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
"""Translate pipeline events into Wyoming events."""
@ -274,6 +331,7 @@ class WyomingSatellite:
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
@ -413,10 +471,13 @@ class WyomingSatellite:
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks from a queue."""
is_first_chunk = True
while chunk := await self._audio_queue.get():
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
try:
is_first_chunk = True
while chunk := await self._audio_queue.get():
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
yield chunk
yield chunk
except asyncio.CancelledError:
pass # ignore

View file

@ -2821,7 +2821,7 @@ wled==0.17.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.4.0
wyoming==1.5.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -2141,7 +2141,7 @@ wled==0.17.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.4.0
wyoming==1.5.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -2,7 +2,9 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import io
from typing import Any
from unittest.mock import patch
import wave
@ -10,6 +12,7 @@ from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize
@ -100,6 +103,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.error_event = asyncio.Event()
self.error: Error | None = None
self.pong_event = asyncio.Event()
self.pong: Pong | None = None
self.ping_event = asyncio.Event()
self.ping: Ping | None = None
self._mic_audio_chunk = AudioChunk(
rate=16000, width=2, channels=1, audio=b"chunk"
).event()
@ -142,6 +151,12 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
elif Error.is_type(event.type):
self.error = Error.from_event(event)
self.error_event.set()
elif Pong.is_type(event.type):
self.pong = Pong.from_event(event)
self.pong_event.set()
elif Ping.is_type(event.type):
self.ping = Ping.from_event(event)
self.ping_event.set()
async def read_event(self) -> Event | None:
"""Receive."""
@ -150,6 +165,10 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
# Keep sending audio chunks instead of None
return event or self._mic_audio_chunk
def inject_event(self, event: Event) -> None:
"""Put an event in as the next response."""
self.responses = [event] + self.responses
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
"""Test running a pipeline with a satellite."""
@ -157,10 +176,37 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
start_stage=PipelineStage.WAKE,
end_stage=PipelineStage.TTS,
restart_on_end=True,
).event(),
]
pipeline_kwargs: dict[str, Any] = {}
pipeline_event_callback: Callable[
[assist_pipeline.PipelineEvent], None
] | None = None
run_pipeline_called = asyncio.Event()
audio_chunk_received = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_kwargs, pipeline_event_callback
pipeline_kwargs = kwargs
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for chunk in stt_stream:
if chunk:
audio_chunk_received.set()
break
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
@ -169,10 +215,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch(
async_pipeline_from_audio_stream,
), patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
):
), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
@ -182,12 +229,39 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Reset so we can check the pipeline is automatically restarted below
run_pipeline_called.clear()
assert pipeline_event_callback is not None
assert pipeline_kwargs.get("device_id") == device.device_id
# Test a ping
mock_client.inject_event(Ping("test-ping").event())
# Pong is expected with the same text
async with asyncio.timeout(1):
await mock_client.pong_event.wait()
assert mock_client.pong is not None
assert mock_client.pong.text == "test-ping"
# The client should have received the first ping
async with asyncio.timeout(1):
await mock_client.ping_event.wait()
assert mock_client.ping is not None
# Reset and send a pong back.
# We will get a second ping by the end of the test.
mock_client.ping_event.clear()
mock_client.ping = None
mock_client.inject_event(Pong().event())
# Start detecting wake word
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.WAKE_WORD_START
)
@ -198,8 +272,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert not device.is_active
assert not device.is_muted
# Push in some audio
mock_client.inject_event(
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
)
# Wake word is detected
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.WAKE_WORD_END,
{"wake_word_output": {"wake_word_id": "test_wake_word"}},
@ -215,7 +294,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert device.is_active
# Speech-to-text started
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_START,
{"metadata": {"language": "en"}},
@ -227,8 +306,13 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.transcribe is not None
assert mock_client.transcribe.language == "en"
# Push in some audio
mock_client.inject_event(
AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
)
# User started speaking
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
)
@ -240,7 +324,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.voice_started.timestamp == 1234
# User stopped speaking
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
)
@ -252,7 +336,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.voice_stopped.timestamp == 5678
# Speech-to-text transcription
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_END,
{"stt_output": {"text": "test transcript"}},
@ -265,7 +349,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.transcript.text == "test transcript"
# Text-to-speech text
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_START,
{
@ -283,7 +367,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.synthesize.voice.name == "test voice"
# Text-to-speech media
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_END,
{"tts_output": {"media_id": "test media id"}},
@ -302,11 +386,21 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.tts_audio_chunk.audio == b"123"
# Pipeline finished
event_callback(
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
assert not device.is_active
# The client should have received another ping by now
async with asyncio.timeout(1):
await mock_client.ping_event.wait()
assert mock_client.ping is not None
# Pipeline should automatically restart
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
@ -317,6 +411,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
on_muted_event = asyncio.Event()
original_make_satellite = wyoming._make_satellite
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
def make_muted_satellite(
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
@ -327,6 +422,14 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
return satellite
async def on_muted(self):
# Trigger original function
self._muted_changed_event.set()
await original_on_muted(self)
# Ensure satellite stops
self.is_running = False
# Proceed with test
self.device.set_is_muted(False)
on_muted_event.set()
@ -339,16 +442,23 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
on_muted,
):
await setup_config_entry(hass)
entry = await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_muted_event.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_satellite_restart(hass: HomeAssistant) -> None:
"""Test pipeline loop restart after unexpected error."""
on_restart_event = asyncio.Event()
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
async def on_restart(self):
await original_on_restart(self)
self.stop()
on_restart_event.set()
@ -356,12 +466,12 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
side_effect=RuntimeError(),
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart,
):
), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_restart_event.wait()
@ -373,7 +483,11 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
reconnect_event = asyncio.Event()
stopped_event = asyncio.Event()
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
async def on_reconnect(self):
await original_on_reconnect(self)
nonlocal num_reconnects
num_reconnects += 1
if num_reconnects >= 2:
@ -395,7 +509,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped,
):
), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await reconnect_event.wait()
@ -519,3 +633,338 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
assert mock_client.error is not None
assert mock_client.error.text == "test message"
assert mock_client.error.code == "test code"
async def test_tts_not_wav(hass: HomeAssistant) -> None:
"""Test satellite receiving non-WAV audio from text-to-speech."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
error_event = asyncio.Event()
async def _stream_tts(self, media_id):
try:
await original_stream_tts(self, media_id)
except ValueError:
error_event.set()
events = [
RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
]
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("mp3", bytes(1)),
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
_stream_tts,
):
entry = await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
# Text-to-speech text
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_START,
{
"tts_input": "test text to speak",
"voice": "test voice",
},
)
)
async with asyncio.timeout(1):
await mock_client.synthesize_event.wait()
# Text-to-speech media
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_END,
{"tts_output": {"media_id": "test media id"}},
)
)
# Expect error because only WAV is supported
async with asyncio.timeout(1):
await error_event.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_pipeline_changed(hass: HomeAssistant) -> None:
"""Test that changing the pipeline setting stops the current pipeline."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event_callback: Callable[
[assist_pipeline.PipelineEvent], None
] | None = None
run_pipeline_called = asyncio.Event()
pipeline_stopped = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_event_callback
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
pipeline_stopped.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
# Pipeline has started
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# Change pipelines
device.set_pipeline_name("different pipeline")
# Running pipeline should be cancelled
async with asyncio.timeout(1):
await pipeline_stopped.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_audio_settings_changed(hass: HomeAssistant) -> None:
"""Test that changing audio settings stops the current pipeline."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event_callback: Callable[
[assist_pipeline.PipelineEvent], None
] | None = None
run_pipeline_called = asyncio.Event()
pipeline_stopped = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_event_callback
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
pipeline_stopped.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
# Pipeline has started
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# Change audio setting
device.set_noise_suppression_level(1)
# Running pipeline should be cancelled
async with asyncio.timeout(1):
await pipeline_stopped.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_invalid_stages(hass: HomeAssistant) -> None:
"""Test error when providing invalid pipeline stages."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
start_stage_event = asyncio.Event()
end_stage_event = asyncio.Event()
def _run_pipeline_once(self, run_pipeline):
# Set bad start stage
run_pipeline.start_stage = PipelineStage.INTENT
run_pipeline.end_stage = PipelineStage.TTS
try:
original_run_pipeline_once(self, run_pipeline)
except ValueError:
start_stage_event.set()
# Set bad end stage
run_pipeline.start_stage = PipelineStage.WAKE
run_pipeline.end_stage = PipelineStage.INTENT
try:
original_run_pipeline_once(self, run_pipeline)
except ValueError:
end_stage_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
_run_pipeline_once,
):
entry = await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
async with asyncio.timeout(1):
await start_stage_event.wait()
await end_stage_event.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
"""Test that an AudioStop message stops the current pipeline."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event_callback: Callable[
[assist_pipeline.PipelineEvent], None
] | None = None
run_pipeline_called = asyncio.Event()
pipeline_stopped = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_event_callback
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
pipeline_stopped.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
):
entry = await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
# Pipeline has started
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# Client sends stop message
mock_client.inject_event(AudioStop().event())
# Running pipeline should be cancelled
async with asyncio.timeout(1):
await pipeline_stopped.wait()
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()