Add speech-to-text cooldown for local wake word (#108806)

* Deconflict based on wake word

* Undo test

* Make wake up key a string, rename error

* Update snapshot

* Change to "wake word phrase" and normalize

* Move normalization into the wake provider

* Working on describe

* Use satellite info to resolve wake word phrase

* Add test for wake word phrase

* Match phrase with model name in wake word provider

* Check model id

* Use one constant wake word cooldown

* Update homeassistant/components/assist_pipeline/error.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Fix wake word tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-02-26 19:35:19 -06:00 committed by GitHub
parent c38e0d22b8
commit f6622ea8e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 641 additions and 184 deletions

View file

@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback, event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata, stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes], stt_stream: AsyncIterable[bytes],
wake_word_phrase: str | None = None,
pipeline_id: str | None = None, pipeline_id: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
tts_audio_output: str | None = None, tts_audio_output: str | None = None,
@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream(
device_id=device_id, device_id=device_id,
stt_metadata=stt_metadata, stt_metadata=stt_metadata,
stt_stream=stt_stream, stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
run=PipelineRun( run=PipelineRun(
hass, hass,
context=context, context=context,

View file

@ -10,6 +10,6 @@ DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir" CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up" DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds WAKE_WORD_COOLDOWN = 2 # seconds
EVENT_RECORDING = f"{DOMAIN}_recording" EVENT_RECORDING = f"{DOMAIN}_recording"

View file

@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError):
"""Error in speech-to-text portion of pipeline.""" """Error in speech-to-text portion of pipeline."""
class DuplicateWakeUpDetectedError(WakeWordDetectionError):
    """Raised when several voice assistants wake up together for the same wake word."""

    def __init__(self, wake_up_phrase: str) -> None:
        """Build the error code and message for the given wake-up phrase."""
        error_code = "duplicate_wake_up_detected"
        error_message = f"Duplicate wake-up detected for {wake_up_phrase}"
        super().__init__(error_code, error_message)
class IntentRecognitionError(PipelineError): class IntentRecognitionError(PipelineError):
"""Error in intent recognition portion of pipeline.""" """Error in intent recognition portion of pipeline."""

View file

@ -55,10 +55,11 @@ from .const import (
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG, DATA_CONFIG,
DATA_LAST_WAKE_UP, DATA_LAST_WAKE_UP,
DEFAULT_WAKE_WORD_COOLDOWN,
DOMAIN, DOMAIN,
WAKE_WORD_COOLDOWN,
) )
from .error import ( from .error import (
DuplicateWakeUpDetectedError,
IntentRecognitionError, IntentRecognitionError,
PipelineError, PipelineError,
PipelineNotFound, PipelineNotFound,
@ -453,9 +454,6 @@ class WakeWordSettings:
audio_seconds_to_buffer: float = 0 audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT.""" """Seconds of audio to buffer before detection and forward to STT."""
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
"""Seconds after a wake word detection where other detections are ignored."""
@dataclass(frozen=True) @dataclass(frozen=True)
class AudioSettings: class AudioSettings:
@ -742,16 +740,22 @@ class PipelineRun:
wake_word_output: dict[str, Any] = {} wake_word_output: dict[str, Any] = {}
else: else:
# Avoid duplicate detections by checking cooldown # Avoid duplicate detections by checking cooldown
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}" last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key) result.wake_word_phrase
)
if last_wake_up is not None: if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds: if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug("Duplicate wake word detection occurred") _LOGGER.debug(
raise WakeWordDetectionAborted "Duplicate wake word detection occurred for %s",
result.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(result.wake_word_phrase)
# Record last wake up time to block duplicate detections # Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic() self.hass.data[DATA_LAST_WAKE_UP][
result.wake_word_phrase
] = time.monotonic()
if result.queued_audio: if result.queued_audio:
# Add audio that was pending at detection. # Add audio that was pending at detection.
@ -1308,6 +1312,9 @@ class PipelineInput:
stt_stream: AsyncIterable[bytes] | None = None stt_stream: AsyncIterable[bytes] | None = None
"""Input audio for stt. Required when start_stage = stt.""" """Input audio for stt. Required when start_stage = stt."""
wake_word_phrase: str | None = None
"""Optional key used to de-duplicate wake-ups for local wake word detection."""
intent_input: str | None = None intent_input: str | None = None
"""Input for conversation agent. Required when start_stage = intent.""" """Input for conversation agent. Required when start_stage = intent."""
@ -1352,6 +1359,25 @@ class PipelineInput:
assert self.stt_metadata is not None assert self.stt_metadata is not None
assert stt_processed_stream is not None assert stt_processed_stream is not None
if self.wake_word_phrase is not None:
# Avoid duplicate wake-ups by checking cooldown
last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get(
self.wake_word_phrase
)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug(
"Speech-to-text cancelled to avoid duplicate wake-up for %s",
self.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(self.wake_word_phrase)
# Record last wake up time to block duplicate detections
self.run.hass.data[DATA_LAST_WAKE_UP][
self.wake_word_phrase
] = time.monotonic()
stt_input_stream = stt_processed_stream stt_input_stream = stt_processed_stream
if stt_audio_buffer: if stt_audio_buffer:

View file

@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
), ),
PipelineStage.STT: vol.Schema( PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}}, {
vol.Required("input"): {
vol.Required("sample_rate"): int,
vol.Optional("wake_word_phrase"): str,
}
},
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
), ),
PipelineStage.INTENT: vol.Schema( PipelineStage.INTENT: vol.Schema(
@ -149,12 +154,15 @@ async def websocket_run(
msg_input = msg["input"] msg_input = msg["input"]
audio_queue: asyncio.Queue[bytes] = asyncio.Queue() audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg_input["sample_rate"] incoming_sample_rate = msg_input["sample_rate"]
wake_word_phrase: str | None = None
if start_stage == PipelineStage.WAKE_WORD: if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings( wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0), audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
) )
elif start_stage == PipelineStage.STT:
wake_word_phrase = msg["input"].get("wake_word_phrase")
async def stt_stream() -> AsyncGenerator[bytes, None]: async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None state = None
@ -189,6 +197,7 @@ async def websocket_run(
channel=stt.AudioChannels.CHANNEL_MONO, channel=stt.AudioChannels.CHANNEL_MONO,
) )
input_args["stt_stream"] = stt_stream() input_args["stt_stream"] = stt_stream()
input_args["wake_word_phrase"] = wake_word_phrase
# Audio settings # Audio settings
audio_settings = AudioSettings( audio_settings = AudioSettings(

View file

@ -7,7 +7,13 @@ class WakeWord:
"""Wake word model.""" """Wake word model."""
id: str id: str
"""Id of wake word model"""
name: str name: str
"""Name of wake word model"""
phrase: str | None = None
"""Wake word phrase used to trigger model"""
@dataclass @dataclass
@ -17,6 +23,9 @@ class DetectionResult:
wake_word_id: str wake_word_id: str
"""Id of detected wake word""" """Id of detected wake word"""
wake_word_phrase: str
"""Normalized phrase for the detected wake word"""
timestamp: int | None timestamp: int | None
"""Timestamp of audio chunk with detected wake word""" """Timestamp of audio chunk with detected wake word"""

View file

@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"], "dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming", "documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["wyoming==1.5.2"], "requirements": ["wyoming==1.5.3"],
"zeroconf": ["_wyoming._tcp.local."] "zeroconf": ["_wyoming._tcp.local."]
} }

View file

@ -1,4 +1,5 @@
"""Support for Wyoming satellite services.""" """Support for Wyoming satellite services."""
import asyncio import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import io import io
@ -10,6 +11,7 @@ from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.error import Error from wyoming.error import Error
from wyoming.info import Describe, Info
from wyoming.ping import Ping, Pong from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import PauseSatellite, RunSatellite from wyoming.satellite import PauseSatellite, RunSatellite
@ -86,7 +88,9 @@ class WyomingSatellite:
await self._connect_and_loop() await self._connect_and_loop()
except asyncio.CancelledError: except asyncio.CancelledError:
raise # don't restart raise # don't restart
except Exception: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))
# Ensure sensor is off (before restart) # Ensure sensor is off (before restart)
self.device.set_is_active(False) self.device.set_is_active(False)
@ -197,6 +201,8 @@ class WyomingSatellite:
async def _run_pipeline_loop(self) -> None: async def _run_pipeline_loop(self) -> None:
"""Run a pipeline one or more times.""" """Run a pipeline one or more times."""
assert self._client is not None assert self._client is not None
client_info: Info | None = None
wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None run_pipeline: RunPipeline | None = None
send_ping = True send_ping = True
@ -209,6 +215,9 @@ class WyomingSatellite:
) )
pending = {pipeline_ended_task, client_event_task} pending = {pipeline_ended_task, client_event_task}
# Update info from satellite
await self._client.write_event(Describe().event())
while self.is_running and (not self.device.is_muted): while self.is_running and (not self.device.is_muted):
if send_ping: if send_ping:
# Ensure satellite is still connected # Ensure satellite is still connected
@ -230,6 +239,9 @@ class WyomingSatellite:
) )
pending.add(pipeline_ended_task) pending.add(pipeline_ended_task)
# Clear last wake word detection
wake_word_phrase = None
if (run_pipeline is not None) and run_pipeline.restart_on_end: if (run_pipeline is not None) and run_pipeline.restart_on_end:
# Automatically restart pipeline. # Automatically restart pipeline.
# Used with "always on" streaming satellites. # Used with "always on" streaming satellites.
@ -253,7 +265,7 @@ class WyomingSatellite:
elif RunPipeline.is_type(client_event.type): elif RunPipeline.is_type(client_event.type):
# Satellite requested pipeline run # Satellite requested pipeline run
run_pipeline = RunPipeline.from_event(client_event) run_pipeline = RunPipeline.from_event(client_event)
self._run_pipeline_once(run_pipeline) self._run_pipeline_once(run_pipeline, wake_word_phrase)
elif ( elif (
AudioChunk.is_type(client_event.type) and self._is_pipeline_running AudioChunk.is_type(client_event.type) and self._is_pipeline_running
): ):
@ -265,6 +277,32 @@ class WyomingSatellite:
# Stop pipeline # Stop pipeline
_LOGGER.debug("Client requested pipeline to stop") _LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"") self._audio_queue.put_nowait(b"")
elif Info.is_type(client_event.type):
client_info = Info.from_event(client_event)
_LOGGER.debug("Updated client info: %s", client_info)
elif Detection.is_type(client_event.type):
detection = Detection.from_event(client_event)
wake_word_phrase = detection.name
# Resolve wake word name/id to phrase if info is available.
#
# This allows us to deconflict multiple satellite wake-ups
# with the same wake word.
if (client_info is not None) and (client_info.wake is not None):
found_phrase = False
for wake_service in client_info.wake:
for wake_model in wake_service.models:
if wake_model.name == detection.name:
wake_word_phrase = (
wake_model.phrase or wake_model.name
)
found_phrase = True
break
if found_phrase:
break
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
else: else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event) _LOGGER.debug("Unexpected event from satellite: %s", client_event)
@ -274,7 +312,9 @@ class WyomingSatellite:
) )
pending.add(client_event_task) pending.add(client_event_task)
def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None: def _run_pipeline_once(
self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None
) -> None:
"""Run a pipeline once.""" """Run a pipeline once."""
_LOGGER.debug("Received run information: %s", run_pipeline) _LOGGER.debug("Received run information: %s", run_pipeline)
@ -332,6 +372,7 @@ class WyomingSatellite:
volume_multiplier=self.device.volume_multiplier, volume_multiplier=self.device.volume_multiplier,
), ),
device_id=self.device.device_id, device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
), ),
name="wyoming satellite pipeline", name="wyoming satellite pipeline",
) )

View file

@ -1,4 +1,5 @@
"""Support for Wyoming wake-word-detection services.""" """Support for Wyoming wake-word-detection services."""
import asyncio import asyncio
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import logging import logging
@ -49,7 +50,9 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
wake_service = service.info.wake[0] wake_service = service.info.wake[0]
self._supported_wake_words = [ self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name) wake_word.WakeWord(
id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
)
for ww in wake_service.models for ww in wake_service.models
] ]
self._attr_name = wake_service.name self._attr_name = wake_service.name
@ -64,7 +67,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
if info is not None: if info is not None:
wake_service = info.wake[0] wake_service = info.wake[0]
self._supported_wake_words = [ self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name) wake_word.WakeWord(
id=ww.name,
name=ww.description or ww.name,
phrase=ww.phrase,
)
for ww in wake_service.models for ww in wake_service.models
] ]
@ -140,6 +147,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
return wake_word.DetectionResult( return wake_word.DetectionResult(
wake_word_id=detection.name, wake_word_id=detection.name,
wake_word_phrase=self._get_phrase(detection.name),
timestamp=detection.timestamp, timestamp=detection.timestamp,
queued_audio=queued_audio, queued_audio=queued_audio,
) )
@ -183,3 +191,14 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
_LOGGER.exception("Error processing audio stream: %s", err) _LOGGER.exception("Error processing audio stream: %s", err)
return None return None
def _get_phrase(self, model_id: str) -> str:
    """Return the wake word phrase registered for a model id.

    Only supported wake words that carry a non-empty phrase are considered.
    When none of them has the requested id, the model id itself is returned.
    """
    matching_phrases = (
        candidate.phrase
        for candidate in self._supported_wake_words
        if candidate.phrase and candidate.id == model_id
    )
    return next(matching_phrases, model_id)

View file

@ -2863,7 +2863,7 @@ wled==0.17.0
wolf-comm==0.0.4 wolf-comm==0.0.4
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.5.2 wyoming==1.5.3
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View file

@ -2195,7 +2195,7 @@ wled==0.17.0
wolf-comm==0.0.4 wolf-comm==0.0.4
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.5.2 wyoming==1.5.3
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View file

@ -201,16 +201,19 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
if self.alternate_detections: if self.alternate_detections:
detected_id = wake_words[self.detected_wake_word_index].id detected_id = wake_words[self.detected_wake_word_index].id
detected_name = wake_words[self.detected_wake_word_index].name
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len( self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
wake_words wake_words
) )
else: else:
detected_id = wake_words[0].id detected_id = wake_words[0].id
detected_name = wake_words[0].name
async for chunk, timestamp in stream: async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"): if chunk.startswith(b"wake word"):
return wake_word.DetectionResult( return wake_word.DetectionResult(
wake_word_id=detected_id, wake_word_id=detected_id,
wake_word_phrase=detected_name,
timestamp=timestamp, timestamp=timestamp,
queued_audio=[(b"queued audio", 0)], queued_audio=[(b"queued audio", 0)],
) )
@ -240,6 +243,7 @@ class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
if chunk.startswith(b"wake word"): if chunk.startswith(b"wake word"):
return wake_word.DetectionResult( return wake_word.DetectionResult(
wake_word_id=wake_words[0].id, wake_word_id=wake_words[0].id,
wake_word_phrase=wake_words[0].name,
timestamp=timestamp, timestamp=timestamp,
queued_audio=[(b"queued audio", 0)], queued_audio=[(b"queued audio", 0)],
) )

View file

@ -294,6 +294,7 @@
'wake_word_output': dict({ 'wake_word_output': dict({
'timestamp': 2000, 'timestamp': 2000,
'wake_word_id': 'test_ww', 'wake_word_id': 'test_ww',
'wake_word_phrase': 'Test Wake Word',
}), }),
}), }),
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>, 'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,

View file

@ -381,6 +381,7 @@
'wake_word_output': dict({ 'wake_word_output': dict({
'timestamp': 0, 'timestamp': 0,
'wake_word_id': 'test_ww', 'wake_word_id': 'test_ww',
'wake_word_phrase': 'Test Wake Word',
}), }),
}) })
# --- # ---
@ -695,6 +696,46 @@
# name: test_pipeline_empty_tts_output.3 # name: test_pipeline_empty_tts_output.3
None None
# --- # ---
# name: test_stt_cooldown_different_ids
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_stt_cooldown_different_ids.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_stt_cooldown_same_id
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_stt_cooldown_same_id.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_stt_provider_missing # name: test_stt_provider_missing
dict({ dict({
'language': 'en', 'language': 'en',
@ -926,15 +967,14 @@
'wake_word_output': dict({ 'wake_word_output': dict({
'timestamp': 0, 'timestamp': 0,
'wake_word_id': 'test_ww', 'wake_word_id': 'test_ww',
'wake_word_phrase': 'Test Wake Word',
}), }),
}) })
# --- # ---
# name: test_wake_word_cooldown_different_entities.5 # name: test_wake_word_cooldown_different_entities.5
dict({ dict({
'wake_word_output': dict({ 'code': 'duplicate_wake_up_detected',
'timestamp': 0, 'message': 'Duplicate wake-up detected for Test Wake Word',
'wake_word_id': 'test_ww',
}),
}) })
# --- # ---
# name: test_wake_word_cooldown_different_ids # name: test_wake_word_cooldown_different_ids
@ -988,6 +1028,7 @@
'wake_word_output': dict({ 'wake_word_output': dict({
'timestamp': 0, 'timestamp': 0,
'wake_word_id': 'test_ww', 'wake_word_id': 'test_ww',
'wake_word_phrase': 'Test Wake Word',
}), }),
}) })
# --- # ---
@ -996,6 +1037,7 @@
'wake_word_output': dict({ 'wake_word_output': dict({
'timestamp': 0, 'timestamp': 0,
'wake_word_id': 'test_ww_2', 'wake_word_id': 'test_ww_2',
'wake_word_phrase': 'Test Wake Word 2',
}), }),
}) })
# --- # ---
@ -1045,3 +1087,18 @@
'timeout': 3, 'timeout': 3,
}) })
# --- # ---
# name: test_wake_word_cooldown_same_id.4
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
'wake_word_phrase': 'Test Wake Word',
}),
})
# ---
# name: test_wake_word_cooldown_same_id.5
dict({
'code': 'duplicate_wake_up_detected',
'message': 'Duplicate wake-up detected for Test Wake Word',
})
# ---

View file

@ -1,6 +1,7 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
import asyncio import asyncio
import base64 import base64
from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -1887,14 +1888,23 @@ async def test_wake_word_cooldown_same_id(
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
# Get response events # Get response events
error_data: dict[str, Any] | None = None
msg = await client_1.receive_json() msg = await client_1.receive_json()
event_type_1 = msg["event"]["type"] event_type_1 = msg["event"]["type"]
assert msg["event"]["data"] == snapshot
if event_type_1 == "error":
error_data = msg["event"]["data"]
msg = await client_2.receive_json() msg = await client_2.receive_json()
event_type_2 = msg["event"]["type"] event_type_2 = msg["event"]["type"]
assert msg["event"]["data"] == snapshot
if event_type_2 == "error":
error_data = msg["event"]["data"]
# One should be a wake up, one should be an error # One should be a wake up, one should be an error
assert {event_type_1, event_type_2} == {"wake_word-end", "error"} assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
assert error_data is not None
assert error_data["code"] == "duplicate_wake_up_detected"
async def test_wake_word_cooldown_different_ids( async def test_wake_word_cooldown_different_ids(
@ -1989,7 +1999,7 @@ async def test_wake_word_cooldown_different_entities(
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that duplicate wake word detections are allowed with different entities.""" """Test that duplicate wake word detections are blocked even with different wake word entities."""
client_pipeline = await hass_ws_client(hass) client_pipeline = await hass_ws_client(hass)
await client_pipeline.send_json_auto_id( await client_pipeline.send_json_auto_id(
{ {
@ -2049,7 +2059,7 @@ async def test_wake_word_cooldown_different_entities(
} }
) )
# Use different wake word entity # Use different wake word entity (but same wake word)
await client_2.send_json_auto_id( await client_2.send_json_auto_id(
{ {
"type": "assist_pipeline/run", "type": "assist_pipeline/run",
@ -2099,18 +2109,23 @@ async def test_wake_word_cooldown_different_entities(
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
# Get response events # Get response events
error_data: dict[str, Any] | None = None
msg = await client_1.receive_json() msg = await client_1.receive_json()
assert msg["event"]["type"] == "wake_word-end", msg event_type_1 = msg["event"]["type"]
ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
if event_type_1 == "error":
error_data = msg["event"]["data"]
msg = await client_2.receive_json() msg = await client_2.receive_json()
assert msg["event"]["type"] == "wake_word-end", msg event_type_2 = msg["event"]["type"]
ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
if event_type_2 == "error":
error_data = msg["event"]["data"]
# Wake words should be the same # One should be a wake up, one should be an error
assert ww_id_1 == ww_id_2 assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
assert error_data is not None
assert error_data["code"] == "duplicate_wake_up_detected"
async def test_device_capture( async def test_device_capture(
@ -2521,3 +2536,138 @@ async def test_pipeline_list_devices(
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline", "pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
} }
] ]
async def test_stt_cooldown_same_id(
hass: HomeAssistant,
init_components,
mock_stt_provider,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test that two speech-to-text pipelines cannot run within the cooldown period if they have the same wake word.

Two websocket clients each start a pipeline at the stt stage with the same
wake word phrase. Exactly one of them is expected to report stt-start while
the other reports an error with code duplicate_wake_up_detected.
"""
# One websocket client per pipeline run
client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass)
# Start both runs at the stt stage with an identical wake word phrase
await client_1.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"wake_word_phrase": "ok_nabu",
},
}
)
await client_2.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"wake_word_phrase": "ok_nabu",
},
}
)
# Both run commands must be acknowledged with a successful result
msg = await client_1.receive_json()
assert msg["success"], msg
msg = await client_2.receive_json()
assert msg["success"], msg
# Both runs emit run-start; the pipeline value is replaced with ANY
# before the event data is compared with the snapshot
msg = await client_1.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
# Read the next event from each client; either one may be the error,
# so the error data is captured from whichever client reports it
error_data: dict[str, Any] | None = None
msg = await client_1.receive_json()
event_type_1 = msg["event"]["type"]
if event_type_1 == "error":
error_data = msg["event"]["data"]
msg = await client_2.receive_json()
event_type_2 = msg["event"]["type"]
if event_type_2 == "error":
error_data = msg["event"]["data"]
# One client should get stt-start, the other should get an error
assert {event_type_1, event_type_2} == {"stt-start", "error"}
assert error_data is not None
assert error_data["code"] == "duplicate_wake_up_detected"
async def test_stt_cooldown_different_ids(
hass: HomeAssistant,
init_components,
mock_stt_provider,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test that two speech-to-text pipelines can run within the cooldown period if they have different wake words.

Two websocket clients each start a pipeline at the stt stage, one with the
phrase "ok_nabu" and one with "hey_jarvis". Both are expected to report
stt-start.
"""
# One websocket client per pipeline run
client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass)
# Start both runs at the stt stage, each with its own wake word phrase
await client_1.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"wake_word_phrase": "ok_nabu",
},
}
)
await client_2.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"wake_word_phrase": "hey_jarvis",
},
}
)
# Both run commands must be acknowledged with a successful result
msg = await client_1.receive_json()
assert msg["success"], msg
msg = await client_2.receive_json()
assert msg["success"], msg
# Both runs emit run-start; the pipeline value is replaced with ANY
# before the event data is compared with the snapshot
msg = await client_1.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
# Read the next event from each client
msg = await client_1.receive_json()
event_type_1 = msg["event"]["type"]
msg = await client_2.receive_json()
event_type_2 = msg["event"]["type"]
# Both should start stt
assert {event_type_1, event_type_2} == {"stt-start"}

View file

@ -1,4 +1,5 @@
"""Test wake_word component setup.""" """Test wake_word component setup."""
import asyncio import asyncio
from collections.abc import AsyncIterable, Generator from collections.abc import AsyncIterable, Generator
from functools import partial from functools import partial
@ -43,8 +44,12 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
return [ return [
wake_word.WakeWord(id="test_ww", name="Test Wake Word"), wake_word.WakeWord(
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"), id="test_ww", name="Test Wake Word", phrase="Test Phrase"
),
wake_word.WakeWord(
id="test_ww_2", name="Test Wake Word 2", phrase="Test Phrase 2"
),
] ]
async def _async_process_audio_stream( async def _async_process_audio_stream(
@ -54,10 +59,18 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
if wake_word_id is None: if wake_word_id is None:
wake_word_id = (await self.get_supported_wake_words())[0].id wake_word_id = (await self.get_supported_wake_words())[0].id
wake_word_phrase = wake_word_id
for ww in await self.get_supported_wake_words():
if ww.id == wake_word_id:
wake_word_phrase = ww.phrase or ww.name
break
async for _chunk, timestamp in stream: async for _chunk, timestamp in stream:
if timestamp >= 2000: if timestamp >= 2000:
return wake_word.DetectionResult( return wake_word.DetectionResult(
wake_word_id=wake_word_id, timestamp=timestamp wake_word_id=wake_word_id,
wake_word_phrase=wake_word_phrase,
timestamp=timestamp,
) )
# Not detected # Not detected
@ -159,10 +172,10 @@ async def test_config_entry_unload(
@freeze_time("2023-06-22 10:30:00+00:00") @freeze_time("2023-06-22 10:30:00+00:00")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("wake_word_id", "expected_ww"), ("wake_word_id", "expected_ww", "expected_phrase"),
[ [
(None, "test_ww"), (None, "test_ww", "Test Phrase"),
("test_ww_2", "test_ww_2"), ("test_ww_2", "test_ww_2", "Test Phrase 2"),
], ],
) )
async def test_detected_entity( async def test_detected_entity(
@ -171,6 +184,7 @@ async def test_detected_entity(
setup: MockProviderEntity, setup: MockProviderEntity,
wake_word_id: str | None, wake_word_id: str | None,
expected_ww: str, expected_ww: str,
expected_phrase: str,
) -> None: ) -> None:
"""Test successful detection through entity.""" """Test successful detection through entity."""
@ -184,7 +198,9 @@ async def test_detected_entity(
state = setup.state state = setup.state
assert state is None assert state is None
result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id) result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id)
assert result == wake_word.DetectionResult(expected_ww, 2048) assert result == wake_word.DetectionResult(
wake_word_id=expected_ww, wake_word_phrase=expected_phrase, timestamp=2048
)
assert state != setup.state assert state != setup.state
assert setup.state == "2023-06-22T10:30:00+00:00" assert setup.state == "2023-06-22T10:30:00+00:00"
@ -285,8 +301,8 @@ async def test_list_wake_words(
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"wake_words": [ "wake_words": [
{"id": "test_ww", "name": "Test Wake Word"}, {"id": "test_ww", "name": "Test Wake Word", "phrase": "Test Phrase"},
{"id": "test_ww_2", "name": "Test Wake Word 2"}, {"id": "test_ww_2", "name": "Test Wake Word 2", "phrase": "Test Phrase 2"},
] ]
} }
@ -320,9 +336,10 @@ async def test_list_wake_words_timeout(
"""Test that the list_wake_words websocket command handles unknown entity.""" """Test that the list_wake_words websocket command handles unknown entity."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
with patch.object( with (
setup, "get_supported_wake_words", partial(asyncio.sleep, 1) patch.object(setup, "get_supported_wake_words", partial(asyncio.sleep, 1)),
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0): patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0),
):
await client.send_json( await client.send_json(
{ {
"id": 5, "id": 5,

View file

@ -75,6 +75,7 @@ WAKE_WORD_INFO = Info(
WakeModel( WakeModel(
name="Test Model", name="Test Model",
description="Test Model", description="Test Model",
phrase="Test Phrase",
installed=True, installed=True,
attribution=TEST_ATTR, attribution=TEST_ATTR,
languages=["en-US"], languages=["en-US"],

View file

@ -9,5 +9,6 @@
]), ]),
'timestamp': 0, 'timestamp': 0,
'wake_word_id': 'Test Model', 'wake_word_id': 'Test Model',
'wake_word_phrase': 'Test Phrase',
}) })
# --- # ---

View file

@ -1,4 +1,5 @@
"""Test Wyoming satellite.""" """Test Wyoming satellite."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -12,6 +13,7 @@ from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error from wyoming.error import Error
from wyoming.event import Event from wyoming.event import Event
from wyoming.info import Info
from wyoming.ping import Ping, Pong from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite from wyoming.satellite import RunSatellite
@ -26,7 +28,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import SATELLITE_INFO, MockAsyncTcpClient from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -207,19 +209,25 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
audio_chunk_received.set() audio_chunk_received.set()
break break
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()), return_value=("wav", get_test_wav()),
), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0): ),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][ device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id entry.entry_id
@ -433,14 +441,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
self.device.set_is_muted(False) self.device.set_is_muted(False)
on_muted_event.set() on_muted_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
"homeassistant.components.wyoming._make_satellite", make_muted_satellite patch("homeassistant.components.wyoming._make_satellite", make_muted_satellite),
), patch( patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
on_muted, on_muted,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
@ -462,16 +472,21 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
self.stop() self.stop()
on_restart_event.set() on_restart_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop", "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
side_effect=RuntimeError(), side_effect=RuntimeError(),
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart, on_restart,
), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0): ),
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
):
await setup_config_entry(hass) await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
await on_restart_event.wait() await on_restart_event.wait()
@ -497,19 +512,25 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
async def on_stopped(self): async def on_stopped(self):
stopped_event.set() stopped_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect", "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
side_effect=ConnectionRefusedError(), side_effect=ConnectionRefusedError(),
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
on_reconnect, on_reconnect,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped, on_stopped,
), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0): ),
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
):
await setup_config_entry(hass) await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
await reconnect_event.wait() await reconnect_event.wait()
@ -524,17 +545,22 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
self.stop() self.stop()
on_restart_event.set() on_restart_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
MockAsyncTcpClient([]), # no RunPipeline event MockAsyncTcpClient([]), # no RunPipeline event
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch( ) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart, on_restart,
),
): ):
await setup_config_entry(hass) await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
@ -564,20 +590,26 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
async def on_stopped(self): async def on_stopped(self):
on_stopped_event.set() on_stopped_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
MockAsyncTcpClient(events), MockAsyncTcpClient(events),
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch( ) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart, on_restart,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped, on_stopped,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][ device: SatelliteDevice = hass.data[wyoming.DOMAIN][
@ -608,16 +640,20 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None: def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set() pipeline_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream, wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline: ) as mock_run_pipeline,
):
await setup_config_entry(hass) await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
@ -663,21 +699,27 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None: def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set() pipeline_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream, wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline, patch( ) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("mp3", bytes(1)), return_value=("mp3", bytes(1)),
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts", "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
_stream_tts, _stream_tts,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
@ -752,15 +794,19 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
pipeline_stopped.set() pipeline_stopped.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][ device: SatelliteDevice = hass.data[wyoming.DOMAIN][
@ -822,15 +868,19 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
pipeline_stopped.set() pipeline_stopped.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][ device: SatelliteDevice = hass.data[wyoming.DOMAIN][
@ -873,7 +923,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
start_stage_event = asyncio.Event() start_stage_event = asyncio.Event()
end_stage_event = asyncio.Event() end_stage_event = asyncio.Event()
def _run_pipeline_once(self, run_pipeline): def _run_pipeline_once(self, run_pipeline, wake_word_phrase):
# Set bad start stage # Set bad start stage
run_pipeline.start_stage = PipelineStage.INTENT run_pipeline.start_stage = PipelineStage.INTENT
run_pipeline.end_stage = PipelineStage.TTS run_pipeline.end_stage = PipelineStage.TTS
@ -892,15 +942,19 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
except ValueError: except ValueError:
end_stage_event.set() end_stage_event.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once", "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
_run_pipeline_once, _run_pipeline_once,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
@ -950,15 +1004,19 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
pipeline_stopped.set() pipeline_stopped.set()
with patch( with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO, return_value=SATELLITE_INFO,
), patch( ),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient", "homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events), SatelliteAsyncTcpClient(events),
) as mock_client, patch( ) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
),
): ):
entry = await setup_config_entry(hass) entry = await setup_config_entry(hass)
@ -982,3 +1040,46 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
# Stop the satellite # Stop the satellite
await hass.config_entries.async_unload(entry.entry_id) await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_wake_word_phrase(hass: HomeAssistant) -> None:
"""Test that wake word phrase from info is given to pipeline."""
events = [
# Fake local wake word detection
Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(),
Detection(name="Test Model").event(),
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event = asyncio.Event()
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await pipeline_event.wait()
# async_pipeline_from_audio_stream will receive the wake word phrase for
# deconfliction.
mock_run_pipeline.assert_called_once()
assert (
mock_run_pipeline.call_args.kwargs.get("wake_word_phrase") == "Test Phrase"
)

View file

@ -1,4 +1,5 @@
"""Test stt.""" """Test stt."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -26,7 +27,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
assert entity is not None assert entity is not None
assert (await entity.get_supported_wake_words()) == [ assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord(id="Test Model", name="Test Model") wake_word.WakeWord(id="Test Model", name="Test Model", phrase="Test Phrase")
] ]
@ -59,6 +60,8 @@ async def test_streaming_audio(
assert result is not None assert result is not None
assert result == snapshot assert result == snapshot
assert result.wake_word_id == "Test Model"
assert result.wake_word_phrase == "Test Phrase"
async def test_streaming_audio_connection_lost( async def test_streaming_audio_connection_lost(
@ -100,10 +103,13 @@ async def test_streaming_audio_oserror(
[Detection(name="Test Model", timestamp=1000).event()] [Detection(name="Test Model", timestamp=1000).event()]
) )
with patch( with (
patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient", "homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client, mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): ),
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
):
result = await entity.async_process_audio_stream(audio_stream(), None) result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is None assert result is None
@ -171,7 +177,7 @@ async def test_dynamic_wake_word_info(
# Original info # Original info
assert (await entity.get_supported_wake_words()) == [ assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("Test Model", "Test Model") wake_word.WakeWord("Test Model", "Test Model", "Test Phrase")
] ]
new_info = Info( new_info = Info(
@ -185,6 +191,7 @@ async def test_dynamic_wake_word_info(
WakeModel( WakeModel(
name="ww1", name="ww1",
description="Wake Word 1", description="Wake Word 1",
phrase="Wake Word Phrase 1",
installed=True, installed=True,
attribution=TEST_ATTR, attribution=TEST_ATTR,
languages=[], languages=[],
@ -193,6 +200,7 @@ async def test_dynamic_wake_word_info(
WakeModel( WakeModel(
name="ww2", name="ww2",
description="Wake Word 2", description="Wake Word 2",
phrase="Wake Word Phrase 2",
installed=True, installed=True,
attribution=TEST_ATTR, attribution=TEST_ATTR,
languages=[], languages=[],
@ -210,6 +218,6 @@ async def test_dynamic_wake_word_info(
return_value=new_info, return_value=new_info,
): ):
assert (await entity.get_supported_wake_words()) == [ assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("ww1", "Wake Word 1"), wake_word.WakeWord("ww1", "Wake Word 1", "Wake Word Phrase 1"),
wake_word.WakeWord("ww2", "Wake Word 2"), wake_word.WakeWord("ww2", "Wake Word 2", "Wake Word Phrase 2"),
] ]