Add speech-to-text cooldown for local wake word (#108806)
* Deconflict based on wake word * Undo test * Make wake up key a string, rename error * Update snapshot * Change to "wake word phrase" and normalize * Move normalization into the wake provider * Working on describe * Use satellite info to resolve wake word phrase * Add test for wake word phrase * Match phrase with model name in wake word provider * Check model id * Use one constant wake word cooldown * Update homeassistant/components/assist_pipeline/error.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Fix wake word tests --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
c38e0d22b8
commit
f6622ea8e0
20 changed files with 641 additions and 184 deletions
|
@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream(
|
|||
event_callback: PipelineEventCallback,
|
||||
stt_metadata: stt.SpeechMetadata,
|
||||
stt_stream: AsyncIterable[bytes],
|
||||
wake_word_phrase: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
tts_audio_output: str | None = None,
|
||||
|
@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream(
|
|||
device_id=device_id,
|
||||
stt_metadata=stt_metadata,
|
||||
stt_stream=stt_stream,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
run=PipelineRun(
|
||||
hass,
|
||||
context=context,
|
||||
|
|
|
@ -10,6 +10,6 @@ DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
|||
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
||||
|
||||
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
||||
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
||||
WAKE_WORD_COOLDOWN = 2 # seconds
|
||||
|
||||
EVENT_RECORDING = f"{DOMAIN}_recording"
|
||||
|
|
|
@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError):
|
|||
"""Error in speech-to-text portion of pipeline."""
|
||||
|
||||
|
||||
class DuplicateWakeUpDetectedError(WakeWordDetectionError):
|
||||
"""Error when multiple voice assistants wake up at the same time (same wake word)."""
|
||||
|
||||
def __init__(self, wake_up_phrase: str) -> None:
|
||||
"""Set error message."""
|
||||
super().__init__(
|
||||
"duplicate_wake_up_detected",
|
||||
f"Duplicate wake-up detected for {wake_up_phrase}",
|
||||
)
|
||||
|
||||
|
||||
class IntentRecognitionError(PipelineError):
|
||||
"""Error in intent recognition portion of pipeline."""
|
||||
|
||||
|
|
|
@ -55,10 +55,11 @@ from .const import (
|
|||
CONF_DEBUG_RECORDING_DIR,
|
||||
DATA_CONFIG,
|
||||
DATA_LAST_WAKE_UP,
|
||||
DEFAULT_WAKE_WORD_COOLDOWN,
|
||||
DOMAIN,
|
||||
WAKE_WORD_COOLDOWN,
|
||||
)
|
||||
from .error import (
|
||||
DuplicateWakeUpDetectedError,
|
||||
IntentRecognitionError,
|
||||
PipelineError,
|
||||
PipelineNotFound,
|
||||
|
@ -453,9 +454,6 @@ class WakeWordSettings:
|
|||
audio_seconds_to_buffer: float = 0
|
||||
"""Seconds of audio to buffer before detection and forward to STT."""
|
||||
|
||||
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
|
||||
"""Seconds after a wake word detection where other detections are ignored."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioSettings:
|
||||
|
@ -742,16 +740,22 @@ class PipelineRun:
|
|||
wake_word_output: dict[str, Any] = {}
|
||||
else:
|
||||
# Avoid duplicate detections by checking cooldown
|
||||
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
|
||||
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
|
||||
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(
|
||||
result.wake_word_phrase
|
||||
)
|
||||
if last_wake_up is not None:
|
||||
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
|
||||
_LOGGER.debug("Duplicate wake word detection occurred")
|
||||
raise WakeWordDetectionAborted
|
||||
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
|
||||
_LOGGER.debug(
|
||||
"Duplicate wake word detection occurred for %s",
|
||||
result.wake_word_phrase,
|
||||
)
|
||||
raise DuplicateWakeUpDetectedError(result.wake_word_phrase)
|
||||
|
||||
# Record last wake up time to block duplicate detections
|
||||
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
|
||||
self.hass.data[DATA_LAST_WAKE_UP][
|
||||
result.wake_word_phrase
|
||||
] = time.monotonic()
|
||||
|
||||
if result.queued_audio:
|
||||
# Add audio that was pending at detection.
|
||||
|
@ -1308,6 +1312,9 @@ class PipelineInput:
|
|||
stt_stream: AsyncIterable[bytes] | None = None
|
||||
"""Input audio for stt. Required when start_stage = stt."""
|
||||
|
||||
wake_word_phrase: str | None = None
|
||||
"""Optional key used to de-duplicate wake-ups for local wake word detection."""
|
||||
|
||||
intent_input: str | None = None
|
||||
"""Input for conversation agent. Required when start_stage = intent."""
|
||||
|
||||
|
@ -1352,6 +1359,25 @@ class PipelineInput:
|
|||
assert self.stt_metadata is not None
|
||||
assert stt_processed_stream is not None
|
||||
|
||||
if self.wake_word_phrase is not None:
|
||||
# Avoid duplicate wake-ups by checking cooldown
|
||||
last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get(
|
||||
self.wake_word_phrase
|
||||
)
|
||||
if last_wake_up is not None:
|
||||
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
|
||||
_LOGGER.debug(
|
||||
"Speech-to-text cancelled to avoid duplicate wake-up for %s",
|
||||
self.wake_word_phrase,
|
||||
)
|
||||
raise DuplicateWakeUpDetectedError(self.wake_word_phrase)
|
||||
|
||||
# Record last wake up time to block duplicate detections
|
||||
self.run.hass.data[DATA_LAST_WAKE_UP][
|
||||
self.wake_word_phrase
|
||||
] = time.monotonic()
|
||||
|
||||
stt_input_stream = stt_processed_stream
|
||||
|
||||
if stt_audio_buffer:
|
||||
|
|
|
@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.STT: vol.Schema(
|
||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||
{
|
||||
vol.Required("input"): {
|
||||
vol.Required("sample_rate"): int,
|
||||
vol.Optional("wake_word_phrase"): str,
|
||||
}
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.INTENT: vol.Schema(
|
||||
|
@ -149,12 +154,15 @@ async def websocket_run(
|
|||
msg_input = msg["input"]
|
||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
incoming_sample_rate = msg_input["sample_rate"]
|
||||
wake_word_phrase: str | None = None
|
||||
|
||||
if start_stage == PipelineStage.WAKE_WORD:
|
||||
wake_word_settings = WakeWordSettings(
|
||||
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
||||
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
|
||||
)
|
||||
elif start_stage == PipelineStage.STT:
|
||||
wake_word_phrase = msg["input"].get("wake_word_phrase")
|
||||
|
||||
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||
state = None
|
||||
|
@ -189,6 +197,7 @@ async def websocket_run(
|
|||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
input_args["stt_stream"] = stt_stream()
|
||||
input_args["wake_word_phrase"] = wake_word_phrase
|
||||
|
||||
# Audio settings
|
||||
audio_settings = AudioSettings(
|
||||
|
|
|
@ -7,7 +7,13 @@ class WakeWord:
|
|||
"""Wake word model."""
|
||||
|
||||
id: str
|
||||
"""Id of wake word model"""
|
||||
|
||||
name: str
|
||||
"""Name of wake word model"""
|
||||
|
||||
phrase: str | None = None
|
||||
"""Wake word phrase used to trigger model"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -17,6 +23,9 @@ class DetectionResult:
|
|||
wake_word_id: str
|
||||
"""Id of detected wake word"""
|
||||
|
||||
wake_word_phrase: str
|
||||
"""Normalized phrase for the detected wake word"""
|
||||
|
||||
timestamp: int | None
|
||||
"""Timestamp of audio chunk with detected wake word"""
|
||||
|
||||
|
|
|
@ -6,6 +6,6 @@
|
|||
"dependencies": ["assist_pipeline"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==1.5.2"],
|
||||
"requirements": ["wyoming==1.5.3"],
|
||||
"zeroconf": ["_wyoming._tcp.local."]
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Support for Wyoming satellite services."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
|
@ -10,6 +11,7 @@ from wyoming.asr import Transcribe, Transcript
|
|||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.error import Error
|
||||
from wyoming.info import Describe, Info
|
||||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import PauseSatellite, RunSatellite
|
||||
|
@ -86,7 +88,9 @@ class WyomingSatellite:
|
|||
await self._connect_and_loop()
|
||||
except asyncio.CancelledError:
|
||||
raise # don't restart
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))
|
||||
|
||||
# Ensure sensor is off (before restart)
|
||||
self.device.set_is_active(False)
|
||||
|
||||
|
@ -197,6 +201,8 @@ class WyomingSatellite:
|
|||
async def _run_pipeline_loop(self) -> None:
|
||||
"""Run a pipeline one or more times."""
|
||||
assert self._client is not None
|
||||
client_info: Info | None = None
|
||||
wake_word_phrase: str | None = None
|
||||
run_pipeline: RunPipeline | None = None
|
||||
send_ping = True
|
||||
|
||||
|
@ -209,6 +215,9 @@ class WyomingSatellite:
|
|||
)
|
||||
pending = {pipeline_ended_task, client_event_task}
|
||||
|
||||
# Update info from satellite
|
||||
await self._client.write_event(Describe().event())
|
||||
|
||||
while self.is_running and (not self.device.is_muted):
|
||||
if send_ping:
|
||||
# Ensure satellite is still connected
|
||||
|
@ -230,6 +239,9 @@ class WyomingSatellite:
|
|||
)
|
||||
pending.add(pipeline_ended_task)
|
||||
|
||||
# Clear last wake word detection
|
||||
wake_word_phrase = None
|
||||
|
||||
if (run_pipeline is not None) and run_pipeline.restart_on_end:
|
||||
# Automatically restart pipeline.
|
||||
# Used with "always on" streaming satellites.
|
||||
|
@ -253,7 +265,7 @@ class WyomingSatellite:
|
|||
elif RunPipeline.is_type(client_event.type):
|
||||
# Satellite requested pipeline run
|
||||
run_pipeline = RunPipeline.from_event(client_event)
|
||||
self._run_pipeline_once(run_pipeline)
|
||||
self._run_pipeline_once(run_pipeline, wake_word_phrase)
|
||||
elif (
|
||||
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
|
||||
):
|
||||
|
@ -265,6 +277,32 @@ class WyomingSatellite:
|
|||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
elif Info.is_type(client_event.type):
|
||||
client_info = Info.from_event(client_event)
|
||||
_LOGGER.debug("Updated client info: %s", client_info)
|
||||
elif Detection.is_type(client_event.type):
|
||||
detection = Detection.from_event(client_event)
|
||||
wake_word_phrase = detection.name
|
||||
|
||||
# Resolve wake word name/id to phrase if info is available.
|
||||
#
|
||||
# This allows us to deconflict multiple satellite wake-ups
|
||||
# with the same wake word.
|
||||
if (client_info is not None) and (client_info.wake is not None):
|
||||
found_phrase = False
|
||||
for wake_service in client_info.wake:
|
||||
for wake_model in wake_service.models:
|
||||
if wake_model.name == detection.name:
|
||||
wake_word_phrase = (
|
||||
wake_model.phrase or wake_model.name
|
||||
)
|
||||
found_phrase = True
|
||||
break
|
||||
|
||||
if found_phrase:
|
||||
break
|
||||
|
||||
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
|
@ -274,7 +312,9 @@ class WyomingSatellite:
|
|||
)
|
||||
pending.add(client_event_task)
|
||||
|
||||
def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
|
||||
def _run_pipeline_once(
|
||||
self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None
|
||||
) -> None:
|
||||
"""Run a pipeline once."""
|
||||
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||
|
||||
|
@ -332,6 +372,7 @@ class WyomingSatellite:
|
|||
volume_multiplier=self.device.volume_multiplier,
|
||||
),
|
||||
device_id=self.device.device_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
),
|
||||
name="wyoming satellite pipeline",
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Support for Wyoming wake-word-detection services."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
@ -49,7 +50,9 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
wake_service = service.info.wake[0]
|
||||
|
||||
self._supported_wake_words = [
|
||||
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
|
||||
wake_word.WakeWord(
|
||||
id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
|
||||
)
|
||||
for ww in wake_service.models
|
||||
]
|
||||
self._attr_name = wake_service.name
|
||||
|
@ -64,7 +67,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
if info is not None:
|
||||
wake_service = info.wake[0]
|
||||
self._supported_wake_words = [
|
||||
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
|
||||
wake_word.WakeWord(
|
||||
id=ww.name,
|
||||
name=ww.description or ww.name,
|
||||
phrase=ww.phrase,
|
||||
)
|
||||
for ww in wake_service.models
|
||||
]
|
||||
|
||||
|
@ -140,6 +147,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
|
||||
return wake_word.DetectionResult(
|
||||
wake_word_id=detection.name,
|
||||
wake_word_phrase=self._get_phrase(detection.name),
|
||||
timestamp=detection.timestamp,
|
||||
queued_audio=queued_audio,
|
||||
)
|
||||
|
@ -183,3 +191,14 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
_LOGGER.exception("Error processing audio stream: %s", err)
|
||||
|
||||
return None
|
||||
|
||||
def _get_phrase(self, model_id: str) -> str:
|
||||
"""Get wake word phrase for model id."""
|
||||
for ww_model in self._supported_wake_words:
|
||||
if not ww_model.phrase:
|
||||
continue
|
||||
|
||||
if ww_model.id == model_id:
|
||||
return ww_model.phrase
|
||||
|
||||
return model_id
|
||||
|
|
|
@ -2863,7 +2863,7 @@ wled==0.17.0
|
|||
wolf-comm==0.0.4
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.5.2
|
||||
wyoming==1.5.3
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -2195,7 +2195,7 @@ wled==0.17.0
|
|||
wolf-comm==0.0.4
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.5.2
|
||||
wyoming==1.5.3
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -201,16 +201,19 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||
|
||||
if self.alternate_detections:
|
||||
detected_id = wake_words[self.detected_wake_word_index].id
|
||||
detected_name = wake_words[self.detected_wake_word_index].name
|
||||
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
|
||||
wake_words
|
||||
)
|
||||
else:
|
||||
detected_id = wake_words[0].id
|
||||
detected_name = wake_words[0].name
|
||||
|
||||
async for chunk, timestamp in stream:
|
||||
if chunk.startswith(b"wake word"):
|
||||
return wake_word.DetectionResult(
|
||||
wake_word_id=detected_id,
|
||||
wake_word_phrase=detected_name,
|
||||
timestamp=timestamp,
|
||||
queued_audio=[(b"queued audio", 0)],
|
||||
)
|
||||
|
@ -240,6 +243,7 @@ class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
|
|||
if chunk.startswith(b"wake word"):
|
||||
return wake_word.DetectionResult(
|
||||
wake_word_id=wake_words[0].id,
|
||||
wake_word_phrase=wake_words[0].name,
|
||||
timestamp=timestamp,
|
||||
queued_audio=[(b"queued audio", 0)],
|
||||
)
|
||||
|
|
|
@ -294,6 +294,7 @@
|
|||
'wake_word_output': dict({
|
||||
'timestamp': 2000,
|
||||
'wake_word_id': 'test_ww',
|
||||
'wake_word_phrase': 'Test Wake Word',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
|
||||
|
|
|
@ -381,6 +381,7 @@
|
|||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
'wake_word_phrase': 'Test Wake Word',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
|
@ -695,6 +696,46 @@
|
|||
# name: test_pipeline_empty_tts_output.3
|
||||
None
|
||||
# ---
|
||||
# name: test_stt_cooldown_different_ids
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_different_ids.1
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id.1
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_provider_missing
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -926,15 +967,14 @@
|
|||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
'wake_word_phrase': 'Test Wake Word',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities.5
|
||||
dict({
|
||||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
}),
|
||||
'code': 'duplicate_wake_up_detected',
|
||||
'message': 'Duplicate wake-up detected for Test Wake Word',
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_ids
|
||||
|
@ -988,6 +1028,7 @@
|
|||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
'wake_word_phrase': 'Test Wake Word',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
|
@ -996,6 +1037,7 @@
|
|||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww_2',
|
||||
'wake_word_phrase': 'Test Wake Word 2',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
|
@ -1045,3 +1087,18 @@
|
|||
'timeout': 3,
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id.4
|
||||
dict({
|
||||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'test_ww',
|
||||
'wake_word_phrase': 'Test Wake Word',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id.5
|
||||
dict({
|
||||
'code': 'duplicate_wake_up_detected',
|
||||
'message': 'Duplicate wake-up detected for Test Wake Word',
|
||||
})
|
||||
# ---
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
@ -1887,14 +1888,23 @@ async def test_wake_word_cooldown_same_id(
|
|||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||
|
||||
# Get response events
|
||||
error_data: dict[str, Any] | None = None
|
||||
msg = await client_1.receive_json()
|
||||
event_type_1 = msg["event"]["type"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
if event_type_1 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
event_type_2 = msg["event"]["type"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
if event_type_2 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
# One should be a wake up, one should be an error
|
||||
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||
assert error_data is not None
|
||||
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||
|
||||
|
||||
async def test_wake_word_cooldown_different_ids(
|
||||
|
@ -1989,7 +1999,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that duplicate wake word detections are allowed with different entities."""
|
||||
"""Test that duplicate wake word detections are blocked even with different wake word entities."""
|
||||
client_pipeline = await hass_ws_client(hass)
|
||||
await client_pipeline.send_json_auto_id(
|
||||
{
|
||||
|
@ -2049,7 +2059,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||
}
|
||||
)
|
||||
|
||||
# Use different wake word entity
|
||||
# Use different wake word entity (but same wake word)
|
||||
await client_2.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
|
@ -2099,18 +2109,23 @@ async def test_wake_word_cooldown_different_entities(
|
|||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||
|
||||
# Get response events
|
||||
error_data: dict[str, Any] | None = None
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-end", msg
|
||||
ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
||||
event_type_1 = msg["event"]["type"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
if event_type_1 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-end", msg
|
||||
ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
||||
event_type_2 = msg["event"]["type"]
|
||||
assert msg["event"]["data"] == snapshot
|
||||
if event_type_2 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
# Wake words should be the same
|
||||
assert ww_id_1 == ww_id_2
|
||||
# One should be a wake up, one should be an error
|
||||
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||
assert error_data is not None
|
||||
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||
|
||||
|
||||
async def test_device_capture(
|
||||
|
@ -2521,3 +2536,138 @@ async def test_pipeline_list_devices(
|
|||
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_stt_cooldown_same_id(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_stt_provider,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that two speech-to-text pipelines cannot run within the cooldown period if they have the same wake word."""
|
||||
client_1 = await hass_ws_client(hass)
|
||||
client_2 = await hass_ws_client(hass)
|
||||
|
||||
await client_1.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"wake_word_phrase": "ok_nabu",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await client_2.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"wake_word_phrase": "ok_nabu",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# Get response events
|
||||
error_data: dict[str, Any] | None = None
|
||||
msg = await client_1.receive_json()
|
||||
event_type_1 = msg["event"]["type"]
|
||||
if event_type_1 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
event_type_2 = msg["event"]["type"]
|
||||
if event_type_2 == "error":
|
||||
error_data = msg["event"]["data"]
|
||||
|
||||
# One should be a stt start, one should be an error
|
||||
assert {event_type_1, event_type_2} == {"stt-start", "error"}
|
||||
assert error_data is not None
|
||||
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||
|
||||
|
||||
async def test_stt_cooldown_different_ids(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_stt_provider,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that two speech-to-text pipelines can run within the cooldown period if they have the different wake words."""
|
||||
client_1 = await hass_ws_client(hass)
|
||||
client_2 = await hass_ws_client(hass)
|
||||
|
||||
await client_1.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"wake_word_phrase": "ok_nabu",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await client_2.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"wake_word_phrase": "hey_jarvis",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client_1.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# Get response events
|
||||
msg = await client_1.receive_json()
|
||||
event_type_1 = msg["event"]["type"]
|
||||
|
||||
msg = await client_2.receive_json()
|
||||
event_type_2 = msg["event"]["type"]
|
||||
|
||||
# Both should start stt
|
||||
assert {event_type_1, event_type_2} == {"stt-start"}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test wake_word component setup."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from functools import partial
|
||||
|
@ -43,8 +44,12 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
|||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [
|
||||
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
|
||||
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"),
|
||||
wake_word.WakeWord(
|
||||
id="test_ww", name="Test Wake Word", phrase="Test Phrase"
|
||||
),
|
||||
wake_word.WakeWord(
|
||||
id="test_ww_2", name="Test Wake Word 2", phrase="Test Phrase 2"
|
||||
),
|
||||
]
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
|
@ -54,10 +59,18 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
|||
if wake_word_id is None:
|
||||
wake_word_id = (await self.get_supported_wake_words())[0].id
|
||||
|
||||
wake_word_phrase = wake_word_id
|
||||
for ww in await self.get_supported_wake_words():
|
||||
if ww.id == wake_word_id:
|
||||
wake_word_phrase = ww.phrase or ww.name
|
||||
break
|
||||
|
||||
async for _chunk, timestamp in stream:
|
||||
if timestamp >= 2000:
|
||||
return wake_word.DetectionResult(
|
||||
wake_word_id=wake_word_id, timestamp=timestamp
|
||||
wake_word_id=wake_word_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
# Not detected
|
||||
|
@ -159,10 +172,10 @@ async def test_config_entry_unload(
|
|||
|
||||
@freeze_time("2023-06-22 10:30:00+00:00")
|
||||
@pytest.mark.parametrize(
|
||||
("wake_word_id", "expected_ww"),
|
||||
("wake_word_id", "expected_ww", "expected_phrase"),
|
||||
[
|
||||
(None, "test_ww"),
|
||||
("test_ww_2", "test_ww_2"),
|
||||
(None, "test_ww", "Test Phrase"),
|
||||
("test_ww_2", "test_ww_2", "Test Phrase 2"),
|
||||
],
|
||||
)
|
||||
async def test_detected_entity(
|
||||
|
@ -171,6 +184,7 @@ async def test_detected_entity(
|
|||
setup: MockProviderEntity,
|
||||
wake_word_id: str | None,
|
||||
expected_ww: str,
|
||||
expected_phrase: str,
|
||||
) -> None:
|
||||
"""Test successful detection through entity."""
|
||||
|
||||
|
@ -184,7 +198,9 @@ async def test_detected_entity(
|
|||
state = setup.state
|
||||
assert state is None
|
||||
result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id)
|
||||
assert result == wake_word.DetectionResult(expected_ww, 2048)
|
||||
assert result == wake_word.DetectionResult(
|
||||
wake_word_id=expected_ww, wake_word_phrase=expected_phrase, timestamp=2048
|
||||
)
|
||||
|
||||
assert state != setup.state
|
||||
assert setup.state == "2023-06-22T10:30:00+00:00"
|
||||
|
@ -285,8 +301,8 @@ async def test_list_wake_words(
|
|||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"wake_words": [
|
||||
{"id": "test_ww", "name": "Test Wake Word"},
|
||||
{"id": "test_ww_2", "name": "Test Wake Word 2"},
|
||||
{"id": "test_ww", "name": "Test Wake Word", "phrase": "Test Phrase"},
|
||||
{"id": "test_ww_2", "name": "Test Wake Word 2", "phrase": "Test Phrase 2"},
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -320,9 +336,10 @@ async def test_list_wake_words_timeout(
|
|||
"""Test that the list_wake_words websocket command handles unknown entity."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch.object(
|
||||
setup, "get_supported_wake_words", partial(asyncio.sleep, 1)
|
||||
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0):
|
||||
with (
|
||||
patch.object(setup, "get_supported_wake_words", partial(asyncio.sleep, 1)),
|
||||
patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0),
|
||||
):
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
|
|
|
@ -75,6 +75,7 @@ WAKE_WORD_INFO = Info(
|
|||
WakeModel(
|
||||
name="Test Model",
|
||||
description="Test Model",
|
||||
phrase="Test Phrase",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
|
|
|
@ -9,5 +9,6 @@
|
|||
]),
|
||||
'timestamp': 0,
|
||||
'wake_word_id': 'Test Model',
|
||||
'wake_word_phrase': 'Test Phrase',
|
||||
})
|
||||
# ---
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test Wyoming satellite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
@ -12,6 +13,7 @@ from wyoming.asr import Transcribe, Transcript
|
|||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.error import Error
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import Info
|
||||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
|
@ -26,7 +28,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import SATELLITE_INFO, MockAsyncTcpClient
|
||||
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -207,19 +209,25 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
audio_chunk_received.set()
|
||||
break
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0):
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
|
@ -433,14 +441,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
self.device.set_is_muted(False)
|
||||
on_muted_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming._make_satellite", make_muted_satellite
|
||||
), patch(
|
||||
),
|
||||
patch("homeassistant.components.wyoming._make_satellite", make_muted_satellite),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
||||
on_muted,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -462,16 +472,21 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
||||
side_effect=RuntimeError(),
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0):
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
@ -497,19 +512,25 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
async def on_stopped(self):
|
||||
stopped_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||
side_effect=ConnectionRefusedError(),
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||
on_reconnect,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0):
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await reconnect_event.wait()
|
||||
|
@ -524,17 +545,22 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
|
|||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient([]), # no RunPipeline event
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -564,20 +590,26 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
|||
async def on_stopped(self):
|
||||
on_stopped_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient(events),
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
|
@ -608,16 +640,20 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
|||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline:
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -663,21 +699,27 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline, patch(
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("mp3", bytes(1)),
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
||||
_stream_tts,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
|
||||
|
@ -752,15 +794,19 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
|||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
|
@ -822,15 +868,19 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
|||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
|
@ -873,7 +923,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
start_stage_event = asyncio.Event()
|
||||
end_stage_event = asyncio.Event()
|
||||
|
||||
def _run_pipeline_once(self, run_pipeline):
|
||||
def _run_pipeline_once(self, run_pipeline, wake_word_phrase):
|
||||
# Set bad start stage
|
||||
run_pipeline.start_stage = PipelineStage.INTENT
|
||||
run_pipeline.end_stage = PipelineStage.TTS
|
||||
|
@ -892,15 +942,19 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
except ValueError:
|
||||
end_stage_event.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
||||
_run_pipeline_once,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
|
||||
|
@ -950,15 +1004,19 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
|||
|
||||
pipeline_stopped.set()
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
|
||||
|
@ -982,3 +1040,46 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
|||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_wake_word_phrase(hass: HomeAssistant) -> None:
|
||||
"""Test that wake word phrase from info is given to pipeline."""
|
||||
events = [
|
||||
# Fake local wake word detection
|
||||
Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(),
|
||||
Detection(name="Test Model").event(),
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event = asyncio.Event()
|
||||
|
||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_event.wait()
|
||||
|
||||
# async_pipeline_from_audio_stream will receive the wake word phrase for
|
||||
# deconfliction.
|
||||
mock_run_pipeline.assert_called_once()
|
||||
assert (
|
||||
mock_run_pipeline.call_args.kwargs.get("wake_word_phrase") == "Test Phrase"
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test stt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
@ -26,7 +27,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
|
|||
assert entity is not None
|
||||
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord(id="Test Model", name="Test Model")
|
||||
wake_word.WakeWord(id="Test Model", name="Test Model", phrase="Test Phrase")
|
||||
]
|
||||
|
||||
|
||||
|
@ -59,6 +60,8 @@ async def test_streaming_audio(
|
|||
|
||||
assert result is not None
|
||||
assert result == snapshot
|
||||
assert result.wake_word_id == "Test Model"
|
||||
assert result.wake_word_phrase == "Test Phrase"
|
||||
|
||||
|
||||
async def test_streaming_audio_connection_lost(
|
||||
|
@ -100,10 +103,13 @@ async def test_streaming_audio_oserror(
|
|||
[Detection(name="Test Model", timestamp=1000).event()]
|
||||
)
|
||||
|
||||
with patch(
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||
),
|
||||
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
|
||||
):
|
||||
result = await entity.async_process_audio_stream(audio_stream(), None)
|
||||
|
||||
assert result is None
|
||||
|
@ -171,7 +177,7 @@ async def test_dynamic_wake_word_info(
|
|||
|
||||
# Original info
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord("Test Model", "Test Model")
|
||||
wake_word.WakeWord("Test Model", "Test Model", "Test Phrase")
|
||||
]
|
||||
|
||||
new_info = Info(
|
||||
|
@ -185,6 +191,7 @@ async def test_dynamic_wake_word_info(
|
|||
WakeModel(
|
||||
name="ww1",
|
||||
description="Wake Word 1",
|
||||
phrase="Wake Word Phrase 1",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=[],
|
||||
|
@ -193,6 +200,7 @@ async def test_dynamic_wake_word_info(
|
|||
WakeModel(
|
||||
name="ww2",
|
||||
description="Wake Word 2",
|
||||
phrase="Wake Word Phrase 2",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=[],
|
||||
|
@ -210,6 +218,6 @@ async def test_dynamic_wake_word_info(
|
|||
return_value=new_info,
|
||||
):
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord("ww1", "Wake Word 1"),
|
||||
wake_word.WakeWord("ww2", "Wake Word 2"),
|
||||
wake_word.WakeWord("ww1", "Wake Word 1", "Wake Word Phrase 1"),
|
||||
wake_word.WakeWord("ww2", "Wake Word 2", "Wake Word Phrase 2"),
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue