Add speech-to-text cooldown for local wake word (#108806)
* Deconflict based on wake word * Undo test * Make wake up key a string, rename error * Update snapshot * Change to "wake word phrase" and normalize * Move normalization into the wake provider * Working on describe * Use satellite info to resolve wake word phrase * Add test for wake word phrase * Match phrase with model name in wake word provider * Check model id * Use one constant wake word cooldown * Update homeassistant/components/assist_pipeline/error.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Fix wake word tests --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
c38e0d22b8
commit
f6622ea8e0
20 changed files with 641 additions and 184 deletions
|
@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
event_callback: PipelineEventCallback,
|
event_callback: PipelineEventCallback,
|
||||||
stt_metadata: stt.SpeechMetadata,
|
stt_metadata: stt.SpeechMetadata,
|
||||||
stt_stream: AsyncIterable[bytes],
|
stt_stream: AsyncIterable[bytes],
|
||||||
|
wake_word_phrase: str | None = None,
|
||||||
pipeline_id: str | None = None,
|
pipeline_id: str | None = None,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
tts_audio_output: str | None = None,
|
tts_audio_output: str | None = None,
|
||||||
|
@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
stt_stream=stt_stream,
|
stt_stream=stt_stream,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
run=PipelineRun(
|
run=PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=context,
|
context=context,
|
||||||
|
|
|
@ -10,6 +10,6 @@ DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
|
||||||
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
|
||||||
|
|
||||||
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
|
||||||
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
|
WAKE_WORD_COOLDOWN = 2 # seconds
|
||||||
|
|
||||||
EVENT_RECORDING = f"{DOMAIN}_recording"
|
EVENT_RECORDING = f"{DOMAIN}_recording"
|
||||||
|
|
|
@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError):
|
||||||
"""Error in speech-to-text portion of pipeline."""
|
"""Error in speech-to-text portion of pipeline."""
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateWakeUpDetectedError(WakeWordDetectionError):
|
||||||
|
"""Error when multiple voice assistants wake up at the same time (same wake word)."""
|
||||||
|
|
||||||
|
def __init__(self, wake_up_phrase: str) -> None:
|
||||||
|
"""Set error message."""
|
||||||
|
super().__init__(
|
||||||
|
"duplicate_wake_up_detected",
|
||||||
|
f"Duplicate wake-up detected for {wake_up_phrase}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IntentRecognitionError(PipelineError):
|
class IntentRecognitionError(PipelineError):
|
||||||
"""Error in intent recognition portion of pipeline."""
|
"""Error in intent recognition portion of pipeline."""
|
||||||
|
|
||||||
|
|
|
@ -55,10 +55,11 @@ from .const import (
|
||||||
CONF_DEBUG_RECORDING_DIR,
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
DATA_CONFIG,
|
DATA_CONFIG,
|
||||||
DATA_LAST_WAKE_UP,
|
DATA_LAST_WAKE_UP,
|
||||||
DEFAULT_WAKE_WORD_COOLDOWN,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
WAKE_WORD_COOLDOWN,
|
||||||
)
|
)
|
||||||
from .error import (
|
from .error import (
|
||||||
|
DuplicateWakeUpDetectedError,
|
||||||
IntentRecognitionError,
|
IntentRecognitionError,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineNotFound,
|
PipelineNotFound,
|
||||||
|
@ -453,9 +454,6 @@ class WakeWordSettings:
|
||||||
audio_seconds_to_buffer: float = 0
|
audio_seconds_to_buffer: float = 0
|
||||||
"""Seconds of audio to buffer before detection and forward to STT."""
|
"""Seconds of audio to buffer before detection and forward to STT."""
|
||||||
|
|
||||||
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
|
|
||||||
"""Seconds after a wake word detection where other detections are ignored."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class AudioSettings:
|
class AudioSettings:
|
||||||
|
@ -742,16 +740,22 @@ class PipelineRun:
|
||||||
wake_word_output: dict[str, Any] = {}
|
wake_word_output: dict[str, Any] = {}
|
||||||
else:
|
else:
|
||||||
# Avoid duplicate detections by checking cooldown
|
# Avoid duplicate detections by checking cooldown
|
||||||
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
|
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(
|
||||||
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
|
result.wake_word_phrase
|
||||||
|
)
|
||||||
if last_wake_up is not None:
|
if last_wake_up is not None:
|
||||||
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||||
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
|
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
|
||||||
_LOGGER.debug("Duplicate wake word detection occurred")
|
_LOGGER.debug(
|
||||||
raise WakeWordDetectionAborted
|
"Duplicate wake word detection occurred for %s",
|
||||||
|
result.wake_word_phrase,
|
||||||
|
)
|
||||||
|
raise DuplicateWakeUpDetectedError(result.wake_word_phrase)
|
||||||
|
|
||||||
# Record last wake up time to block duplicate detections
|
# Record last wake up time to block duplicate detections
|
||||||
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
|
self.hass.data[DATA_LAST_WAKE_UP][
|
||||||
|
result.wake_word_phrase
|
||||||
|
] = time.monotonic()
|
||||||
|
|
||||||
if result.queued_audio:
|
if result.queued_audio:
|
||||||
# Add audio that was pending at detection.
|
# Add audio that was pending at detection.
|
||||||
|
@ -1308,6 +1312,9 @@ class PipelineInput:
|
||||||
stt_stream: AsyncIterable[bytes] | None = None
|
stt_stream: AsyncIterable[bytes] | None = None
|
||||||
"""Input audio for stt. Required when start_stage = stt."""
|
"""Input audio for stt. Required when start_stage = stt."""
|
||||||
|
|
||||||
|
wake_word_phrase: str | None = None
|
||||||
|
"""Optional key used to de-duplicate wake-ups for local wake word detection."""
|
||||||
|
|
||||||
intent_input: str | None = None
|
intent_input: str | None = None
|
||||||
"""Input for conversation agent. Required when start_stage = intent."""
|
"""Input for conversation agent. Required when start_stage = intent."""
|
||||||
|
|
||||||
|
@ -1352,6 +1359,25 @@ class PipelineInput:
|
||||||
assert self.stt_metadata is not None
|
assert self.stt_metadata is not None
|
||||||
assert stt_processed_stream is not None
|
assert stt_processed_stream is not None
|
||||||
|
|
||||||
|
if self.wake_word_phrase is not None:
|
||||||
|
# Avoid duplicate wake-ups by checking cooldown
|
||||||
|
last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get(
|
||||||
|
self.wake_word_phrase
|
||||||
|
)
|
||||||
|
if last_wake_up is not None:
|
||||||
|
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||||
|
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Speech-to-text cancelled to avoid duplicate wake-up for %s",
|
||||||
|
self.wake_word_phrase,
|
||||||
|
)
|
||||||
|
raise DuplicateWakeUpDetectedError(self.wake_word_phrase)
|
||||||
|
|
||||||
|
# Record last wake up time to block duplicate detections
|
||||||
|
self.run.hass.data[DATA_LAST_WAKE_UP][
|
||||||
|
self.wake_word_phrase
|
||||||
|
] = time.monotonic()
|
||||||
|
|
||||||
stt_input_stream = stt_processed_stream
|
stt_input_stream = stt_processed_stream
|
||||||
|
|
||||||
if stt_audio_buffer:
|
if stt_audio_buffer:
|
||||||
|
|
|
@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
),
|
),
|
||||||
PipelineStage.STT: vol.Schema(
|
PipelineStage.STT: vol.Schema(
|
||||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
{
|
||||||
|
vol.Required("input"): {
|
||||||
|
vol.Required("sample_rate"): int,
|
||||||
|
vol.Optional("wake_word_phrase"): str,
|
||||||
|
}
|
||||||
|
},
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
),
|
),
|
||||||
PipelineStage.INTENT: vol.Schema(
|
PipelineStage.INTENT: vol.Schema(
|
||||||
|
@ -149,12 +154,15 @@ async def websocket_run(
|
||||||
msg_input = msg["input"]
|
msg_input = msg["input"]
|
||||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||||
incoming_sample_rate = msg_input["sample_rate"]
|
incoming_sample_rate = msg_input["sample_rate"]
|
||||||
|
wake_word_phrase: str | None = None
|
||||||
|
|
||||||
if start_stage == PipelineStage.WAKE_WORD:
|
if start_stage == PipelineStage.WAKE_WORD:
|
||||||
wake_word_settings = WakeWordSettings(
|
wake_word_settings = WakeWordSettings(
|
||||||
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
||||||
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
|
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
|
||||||
)
|
)
|
||||||
|
elif start_stage == PipelineStage.STT:
|
||||||
|
wake_word_phrase = msg["input"].get("wake_word_phrase")
|
||||||
|
|
||||||
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||||
state = None
|
state = None
|
||||||
|
@ -189,6 +197,7 @@ async def websocket_run(
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
)
|
)
|
||||||
input_args["stt_stream"] = stt_stream()
|
input_args["stt_stream"] = stt_stream()
|
||||||
|
input_args["wake_word_phrase"] = wake_word_phrase
|
||||||
|
|
||||||
# Audio settings
|
# Audio settings
|
||||||
audio_settings = AudioSettings(
|
audio_settings = AudioSettings(
|
||||||
|
|
|
@ -7,7 +7,13 @@ class WakeWord:
|
||||||
"""Wake word model."""
|
"""Wake word model."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
|
"""Id of wake word model"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
"""Name of wake word model"""
|
||||||
|
|
||||||
|
phrase: str | None = None
|
||||||
|
"""Wake word phrase used to trigger model"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -17,6 +23,9 @@ class DetectionResult:
|
||||||
wake_word_id: str
|
wake_word_id: str
|
||||||
"""Id of detected wake word"""
|
"""Id of detected wake word"""
|
||||||
|
|
||||||
|
wake_word_phrase: str
|
||||||
|
"""Normalized phrase for the detected wake word"""
|
||||||
|
|
||||||
timestamp: int | None
|
timestamp: int | None
|
||||||
"""Timestamp of audio chunk with detected wake word"""
|
"""Timestamp of audio chunk with detected wake word"""
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,6 @@
|
||||||
"dependencies": ["assist_pipeline"],
|
"dependencies": ["assist_pipeline"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"requirements": ["wyoming==1.5.2"],
|
"requirements": ["wyoming==1.5.3"],
|
||||||
"zeroconf": ["_wyoming._tcp.local."]
|
"zeroconf": ["_wyoming._tcp.local."]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Support for Wyoming satellite services."""
|
"""Support for Wyoming satellite services."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
import io
|
import io
|
||||||
|
@ -10,6 +11,7 @@ from wyoming.asr import Transcribe, Transcript
|
||||||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
||||||
from wyoming.client import AsyncTcpClient
|
from wyoming.client import AsyncTcpClient
|
||||||
from wyoming.error import Error
|
from wyoming.error import Error
|
||||||
|
from wyoming.info import Describe, Info
|
||||||
from wyoming.ping import Ping, Pong
|
from wyoming.ping import Ping, Pong
|
||||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||||
from wyoming.satellite import PauseSatellite, RunSatellite
|
from wyoming.satellite import PauseSatellite, RunSatellite
|
||||||
|
@ -86,7 +88,9 @@ class WyomingSatellite:
|
||||||
await self._connect_and_loop()
|
await self._connect_and_loop()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise # don't restart
|
raise # don't restart
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
|
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))
|
||||||
|
|
||||||
# Ensure sensor is off (before restart)
|
# Ensure sensor is off (before restart)
|
||||||
self.device.set_is_active(False)
|
self.device.set_is_active(False)
|
||||||
|
|
||||||
|
@ -197,6 +201,8 @@ class WyomingSatellite:
|
||||||
async def _run_pipeline_loop(self) -> None:
|
async def _run_pipeline_loop(self) -> None:
|
||||||
"""Run a pipeline one or more times."""
|
"""Run a pipeline one or more times."""
|
||||||
assert self._client is not None
|
assert self._client is not None
|
||||||
|
client_info: Info | None = None
|
||||||
|
wake_word_phrase: str | None = None
|
||||||
run_pipeline: RunPipeline | None = None
|
run_pipeline: RunPipeline | None = None
|
||||||
send_ping = True
|
send_ping = True
|
||||||
|
|
||||||
|
@ -209,6 +215,9 @@ class WyomingSatellite:
|
||||||
)
|
)
|
||||||
pending = {pipeline_ended_task, client_event_task}
|
pending = {pipeline_ended_task, client_event_task}
|
||||||
|
|
||||||
|
# Update info from satellite
|
||||||
|
await self._client.write_event(Describe().event())
|
||||||
|
|
||||||
while self.is_running and (not self.device.is_muted):
|
while self.is_running and (not self.device.is_muted):
|
||||||
if send_ping:
|
if send_ping:
|
||||||
# Ensure satellite is still connected
|
# Ensure satellite is still connected
|
||||||
|
@ -230,6 +239,9 @@ class WyomingSatellite:
|
||||||
)
|
)
|
||||||
pending.add(pipeline_ended_task)
|
pending.add(pipeline_ended_task)
|
||||||
|
|
||||||
|
# Clear last wake word detection
|
||||||
|
wake_word_phrase = None
|
||||||
|
|
||||||
if (run_pipeline is not None) and run_pipeline.restart_on_end:
|
if (run_pipeline is not None) and run_pipeline.restart_on_end:
|
||||||
# Automatically restart pipeline.
|
# Automatically restart pipeline.
|
||||||
# Used with "always on" streaming satellites.
|
# Used with "always on" streaming satellites.
|
||||||
|
@ -253,7 +265,7 @@ class WyomingSatellite:
|
||||||
elif RunPipeline.is_type(client_event.type):
|
elif RunPipeline.is_type(client_event.type):
|
||||||
# Satellite requested pipeline run
|
# Satellite requested pipeline run
|
||||||
run_pipeline = RunPipeline.from_event(client_event)
|
run_pipeline = RunPipeline.from_event(client_event)
|
||||||
self._run_pipeline_once(run_pipeline)
|
self._run_pipeline_once(run_pipeline, wake_word_phrase)
|
||||||
elif (
|
elif (
|
||||||
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
|
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
|
||||||
):
|
):
|
||||||
|
@ -265,6 +277,32 @@ class WyomingSatellite:
|
||||||
# Stop pipeline
|
# Stop pipeline
|
||||||
_LOGGER.debug("Client requested pipeline to stop")
|
_LOGGER.debug("Client requested pipeline to stop")
|
||||||
self._audio_queue.put_nowait(b"")
|
self._audio_queue.put_nowait(b"")
|
||||||
|
elif Info.is_type(client_event.type):
|
||||||
|
client_info = Info.from_event(client_event)
|
||||||
|
_LOGGER.debug("Updated client info: %s", client_info)
|
||||||
|
elif Detection.is_type(client_event.type):
|
||||||
|
detection = Detection.from_event(client_event)
|
||||||
|
wake_word_phrase = detection.name
|
||||||
|
|
||||||
|
# Resolve wake word name/id to phrase if info is available.
|
||||||
|
#
|
||||||
|
# This allows us to deconflict multiple satellite wake-ups
|
||||||
|
# with the same wake word.
|
||||||
|
if (client_info is not None) and (client_info.wake is not None):
|
||||||
|
found_phrase = False
|
||||||
|
for wake_service in client_info.wake:
|
||||||
|
for wake_model in wake_service.models:
|
||||||
|
if wake_model.name == detection.name:
|
||||||
|
wake_word_phrase = (
|
||||||
|
wake_model.phrase or wake_model.name
|
||||||
|
)
|
||||||
|
found_phrase = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found_phrase:
|
||||||
|
break
|
||||||
|
|
||||||
|
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||||
|
|
||||||
|
@ -274,7 +312,9 @@ class WyomingSatellite:
|
||||||
)
|
)
|
||||||
pending.add(client_event_task)
|
pending.add(client_event_task)
|
||||||
|
|
||||||
def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
|
def _run_pipeline_once(
|
||||||
|
self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None
|
||||||
|
) -> None:
|
||||||
"""Run a pipeline once."""
|
"""Run a pipeline once."""
|
||||||
_LOGGER.debug("Received run information: %s", run_pipeline)
|
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||||
|
|
||||||
|
@ -332,6 +372,7 @@ class WyomingSatellite:
|
||||||
volume_multiplier=self.device.volume_multiplier,
|
volume_multiplier=self.device.volume_multiplier,
|
||||||
),
|
),
|
||||||
device_id=self.device.device_id,
|
device_id=self.device.device_id,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
),
|
),
|
||||||
name="wyoming satellite pipeline",
|
name="wyoming satellite pipeline",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Support for Wyoming wake-word-detection services."""
|
"""Support for Wyoming wake-word-detection services."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
import logging
|
import logging
|
||||||
|
@ -49,7 +50,9 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||||
wake_service = service.info.wake[0]
|
wake_service = service.info.wake[0]
|
||||||
|
|
||||||
self._supported_wake_words = [
|
self._supported_wake_words = [
|
||||||
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
|
wake_word.WakeWord(
|
||||||
|
id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
|
||||||
|
)
|
||||||
for ww in wake_service.models
|
for ww in wake_service.models
|
||||||
]
|
]
|
||||||
self._attr_name = wake_service.name
|
self._attr_name = wake_service.name
|
||||||
|
@ -64,7 +67,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||||
if info is not None:
|
if info is not None:
|
||||||
wake_service = info.wake[0]
|
wake_service = info.wake[0]
|
||||||
self._supported_wake_words = [
|
self._supported_wake_words = [
|
||||||
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
|
wake_word.WakeWord(
|
||||||
|
id=ww.name,
|
||||||
|
name=ww.description or ww.name,
|
||||||
|
phrase=ww.phrase,
|
||||||
|
)
|
||||||
for ww in wake_service.models
|
for ww in wake_service.models
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -140,6 +147,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||||
|
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
wake_word_id=detection.name,
|
wake_word_id=detection.name,
|
||||||
|
wake_word_phrase=self._get_phrase(detection.name),
|
||||||
timestamp=detection.timestamp,
|
timestamp=detection.timestamp,
|
||||||
queued_audio=queued_audio,
|
queued_audio=queued_audio,
|
||||||
)
|
)
|
||||||
|
@ -183,3 +191,14 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||||
_LOGGER.exception("Error processing audio stream: %s", err)
|
_LOGGER.exception("Error processing audio stream: %s", err)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _get_phrase(self, model_id: str) -> str:
|
||||||
|
"""Get wake word phrase for model id."""
|
||||||
|
for ww_model in self._supported_wake_words:
|
||||||
|
if not ww_model.phrase:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ww_model.id == model_id:
|
||||||
|
return ww_model.phrase
|
||||||
|
|
||||||
|
return model_id
|
||||||
|
|
|
@ -2863,7 +2863,7 @@ wled==0.17.0
|
||||||
wolf-comm==0.0.4
|
wolf-comm==0.0.4
|
||||||
|
|
||||||
# homeassistant.components.wyoming
|
# homeassistant.components.wyoming
|
||||||
wyoming==1.5.2
|
wyoming==1.5.3
|
||||||
|
|
||||||
# homeassistant.components.xbox
|
# homeassistant.components.xbox
|
||||||
xbox-webapi==2.0.11
|
xbox-webapi==2.0.11
|
||||||
|
|
|
@ -2195,7 +2195,7 @@ wled==0.17.0
|
||||||
wolf-comm==0.0.4
|
wolf-comm==0.0.4
|
||||||
|
|
||||||
# homeassistant.components.wyoming
|
# homeassistant.components.wyoming
|
||||||
wyoming==1.5.2
|
wyoming==1.5.3
|
||||||
|
|
||||||
# homeassistant.components.xbox
|
# homeassistant.components.xbox
|
||||||
xbox-webapi==2.0.11
|
xbox-webapi==2.0.11
|
||||||
|
|
|
@ -201,16 +201,19 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
||||||
|
|
||||||
if self.alternate_detections:
|
if self.alternate_detections:
|
||||||
detected_id = wake_words[self.detected_wake_word_index].id
|
detected_id = wake_words[self.detected_wake_word_index].id
|
||||||
|
detected_name = wake_words[self.detected_wake_word_index].name
|
||||||
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
|
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
|
||||||
wake_words
|
wake_words
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
detected_id = wake_words[0].id
|
detected_id = wake_words[0].id
|
||||||
|
detected_name = wake_words[0].name
|
||||||
|
|
||||||
async for chunk, timestamp in stream:
|
async for chunk, timestamp in stream:
|
||||||
if chunk.startswith(b"wake word"):
|
if chunk.startswith(b"wake word"):
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
wake_word_id=detected_id,
|
wake_word_id=detected_id,
|
||||||
|
wake_word_phrase=detected_name,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
queued_audio=[(b"queued audio", 0)],
|
queued_audio=[(b"queued audio", 0)],
|
||||||
)
|
)
|
||||||
|
@ -240,6 +243,7 @@ class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
|
||||||
if chunk.startswith(b"wake word"):
|
if chunk.startswith(b"wake word"):
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
wake_word_id=wake_words[0].id,
|
wake_word_id=wake_words[0].id,
|
||||||
|
wake_word_phrase=wake_words[0].name,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
queued_audio=[(b"queued audio", 0)],
|
queued_audio=[(b"queued audio", 0)],
|
||||||
)
|
)
|
||||||
|
|
|
@ -294,6 +294,7 @@
|
||||||
'wake_word_output': dict({
|
'wake_word_output': dict({
|
||||||
'timestamp': 2000,
|
'timestamp': 2000,
|
||||||
'wake_word_id': 'test_ww',
|
'wake_word_id': 'test_ww',
|
||||||
|
'wake_word_phrase': 'Test Wake Word',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
|
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
|
||||||
|
|
|
@ -381,6 +381,7 @@
|
||||||
'wake_word_output': dict({
|
'wake_word_output': dict({
|
||||||
'timestamp': 0,
|
'timestamp': 0,
|
||||||
'wake_word_id': 'test_ww',
|
'wake_word_id': 'test_ww',
|
||||||
|
'wake_word_phrase': 'Test Wake Word',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -695,6 +696,46 @@
|
||||||
# name: test_pipeline_empty_tts_output.3
|
# name: test_pipeline_empty_tts_output.3
|
||||||
None
|
None
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_stt_cooldown_different_ids
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_stt_cooldown_different_ids.1
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_stt_cooldown_same_id
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_stt_cooldown_same_id.1
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_stt_provider_missing
|
# name: test_stt_provider_missing
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
|
@ -926,15 +967,14 @@
|
||||||
'wake_word_output': dict({
|
'wake_word_output': dict({
|
||||||
'timestamp': 0,
|
'timestamp': 0,
|
||||||
'wake_word_id': 'test_ww',
|
'wake_word_id': 'test_ww',
|
||||||
|
'wake_word_phrase': 'Test Wake Word',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_entities.5
|
# name: test_wake_word_cooldown_different_entities.5
|
||||||
dict({
|
dict({
|
||||||
'wake_word_output': dict({
|
'code': 'duplicate_wake_up_detected',
|
||||||
'timestamp': 0,
|
'message': 'Duplicate wake-up detected for Test Wake Word',
|
||||||
'wake_word_id': 'test_ww',
|
|
||||||
}),
|
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_ids
|
# name: test_wake_word_cooldown_different_ids
|
||||||
|
@ -988,6 +1028,7 @@
|
||||||
'wake_word_output': dict({
|
'wake_word_output': dict({
|
||||||
'timestamp': 0,
|
'timestamp': 0,
|
||||||
'wake_word_id': 'test_ww',
|
'wake_word_id': 'test_ww',
|
||||||
|
'wake_word_phrase': 'Test Wake Word',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -996,6 +1037,7 @@
|
||||||
'wake_word_output': dict({
|
'wake_word_output': dict({
|
||||||
'timestamp': 0,
|
'timestamp': 0,
|
||||||
'wake_word_id': 'test_ww_2',
|
'wake_word_id': 'test_ww_2',
|
||||||
|
'wake_word_phrase': 'Test Wake Word 2',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -1045,3 +1087,18 @@
|
||||||
'timeout': 3,
|
'timeout': 3,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id.4
|
||||||
|
dict({
|
||||||
|
'wake_word_output': dict({
|
||||||
|
'timestamp': 0,
|
||||||
|
'wake_word_id': 'test_ww',
|
||||||
|
'wake_word_phrase': 'Test Wake Word',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id.5
|
||||||
|
dict({
|
||||||
|
'code': 'duplicate_wake_up_detected',
|
||||||
|
'message': 'Duplicate wake-up detected for Test Wake Word',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
@ -1887,14 +1888,23 @@ async def test_wake_word_cooldown_same_id(
|
||||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||||
|
|
||||||
# Get response events
|
# Get response events
|
||||||
|
error_data: dict[str, Any] | None = None
|
||||||
msg = await client_1.receive_json()
|
msg = await client_1.receive_json()
|
||||||
event_type_1 = msg["event"]["type"]
|
event_type_1 = msg["event"]["type"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
if event_type_1 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
msg = await client_2.receive_json()
|
msg = await client_2.receive_json()
|
||||||
event_type_2 = msg["event"]["type"]
|
event_type_2 = msg["event"]["type"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
if event_type_2 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
# One should be a wake up, one should be an error
|
# One should be a wake up, one should be an error
|
||||||
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||||
|
assert error_data is not None
|
||||||
|
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_cooldown_different_ids(
|
async def test_wake_word_cooldown_different_ids(
|
||||||
|
@ -1989,7 +1999,7 @@ async def test_wake_word_cooldown_different_entities(
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that duplicate wake word detections are allowed with different entities."""
|
"""Test that duplicate wake word detections are blocked even with different wake word entities."""
|
||||||
client_pipeline = await hass_ws_client(hass)
|
client_pipeline = await hass_ws_client(hass)
|
||||||
await client_pipeline.send_json_auto_id(
|
await client_pipeline.send_json_auto_id(
|
||||||
{
|
{
|
||||||
|
@ -2049,7 +2059,7 @@ async def test_wake_word_cooldown_different_entities(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use different wake word entity
|
# Use different wake word entity (but same wake word)
|
||||||
await client_2.send_json_auto_id(
|
await client_2.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
|
@ -2099,18 +2109,23 @@ async def test_wake_word_cooldown_different_entities(
|
||||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||||
|
|
||||||
# Get response events
|
# Get response events
|
||||||
|
error_data: dict[str, Any] | None = None
|
||||||
msg = await client_1.receive_json()
|
msg = await client_1.receive_json()
|
||||||
assert msg["event"]["type"] == "wake_word-end", msg
|
event_type_1 = msg["event"]["type"]
|
||||||
ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
if event_type_1 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
msg = await client_2.receive_json()
|
msg = await client_2.receive_json()
|
||||||
assert msg["event"]["type"] == "wake_word-end", msg
|
event_type_2 = msg["event"]["type"]
|
||||||
ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
if event_type_2 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
# Wake words should be the same
|
# One should be a wake up, one should be an error
|
||||||
assert ww_id_1 == ww_id_2
|
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||||
|
assert error_data is not None
|
||||||
|
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||||
|
|
||||||
|
|
||||||
async def test_device_capture(
|
async def test_device_capture(
|
||||||
|
@ -2521,3 +2536,138 @@ async def test_pipeline_list_devices(
|
||||||
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
|
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_cooldown_same_id(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_stt_provider,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that two speech-to-text pipelines cannot run within the cooldown period if they have the same wake word."""
|
||||||
|
client_1 = await hass_ws_client(hass)
|
||||||
|
client_2 = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client_1.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"wake_word_phrase": "ok_nabu",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await client_2.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"wake_word_phrase": "ok_nabu",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Get response events
|
||||||
|
error_data: dict[str, Any] | None = None
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
event_type_1 = msg["event"]["type"]
|
||||||
|
if event_type_1 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
event_type_2 = msg["event"]["type"]
|
||||||
|
if event_type_2 == "error":
|
||||||
|
error_data = msg["event"]["data"]
|
||||||
|
|
||||||
|
# One should be a stt start, one should be an error
|
||||||
|
assert {event_type_1, event_type_2} == {"stt-start", "error"}
|
||||||
|
assert error_data is not None
|
||||||
|
assert error_data["code"] == "duplicate_wake_up_detected"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_cooldown_different_ids(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_stt_provider,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that two speech-to-text pipelines can run within the cooldown period if they have the different wake words."""
|
||||||
|
client_1 = await hass_ws_client(hass)
|
||||||
|
client_2 = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client_1.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"wake_word_phrase": "ok_nabu",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await client_2.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"wake_word_phrase": "hey_jarvis",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Get response events
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
event_type_1 = msg["event"]["type"]
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
event_type_2 = msg["event"]["type"]
|
||||||
|
|
||||||
|
# Both should start stt
|
||||||
|
assert {event_type_1, event_type_2} == {"stt-start"}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test wake_word component setup."""
|
"""Test wake_word component setup."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable, Generator
|
from collections.abc import AsyncIterable, Generator
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -43,8 +44,12 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
||||||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||||
"""Return a list of supported wake words."""
|
"""Return a list of supported wake words."""
|
||||||
return [
|
return [
|
||||||
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
|
wake_word.WakeWord(
|
||||||
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"),
|
id="test_ww", name="Test Wake Word", phrase="Test Phrase"
|
||||||
|
),
|
||||||
|
wake_word.WakeWord(
|
||||||
|
id="test_ww_2", name="Test Wake Word 2", phrase="Test Phrase 2"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def _async_process_audio_stream(
|
async def _async_process_audio_stream(
|
||||||
|
@ -54,10 +59,18 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
||||||
if wake_word_id is None:
|
if wake_word_id is None:
|
||||||
wake_word_id = (await self.get_supported_wake_words())[0].id
|
wake_word_id = (await self.get_supported_wake_words())[0].id
|
||||||
|
|
||||||
|
wake_word_phrase = wake_word_id
|
||||||
|
for ww in await self.get_supported_wake_words():
|
||||||
|
if ww.id == wake_word_id:
|
||||||
|
wake_word_phrase = ww.phrase or ww.name
|
||||||
|
break
|
||||||
|
|
||||||
async for _chunk, timestamp in stream:
|
async for _chunk, timestamp in stream:
|
||||||
if timestamp >= 2000:
|
if timestamp >= 2000:
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
wake_word_id=wake_word_id, timestamp=timestamp
|
wake_word_id=wake_word_id,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Not detected
|
# Not detected
|
||||||
|
@ -159,10 +172,10 @@ async def test_config_entry_unload(
|
||||||
|
|
||||||
@freeze_time("2023-06-22 10:30:00+00:00")
|
@freeze_time("2023-06-22 10:30:00+00:00")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("wake_word_id", "expected_ww"),
|
("wake_word_id", "expected_ww", "expected_phrase"),
|
||||||
[
|
[
|
||||||
(None, "test_ww"),
|
(None, "test_ww", "Test Phrase"),
|
||||||
("test_ww_2", "test_ww_2"),
|
("test_ww_2", "test_ww_2", "Test Phrase 2"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_detected_entity(
|
async def test_detected_entity(
|
||||||
|
@ -171,6 +184,7 @@ async def test_detected_entity(
|
||||||
setup: MockProviderEntity,
|
setup: MockProviderEntity,
|
||||||
wake_word_id: str | None,
|
wake_word_id: str | None,
|
||||||
expected_ww: str,
|
expected_ww: str,
|
||||||
|
expected_phrase: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful detection through entity."""
|
"""Test successful detection through entity."""
|
||||||
|
|
||||||
|
@ -184,7 +198,9 @@ async def test_detected_entity(
|
||||||
state = setup.state
|
state = setup.state
|
||||||
assert state is None
|
assert state is None
|
||||||
result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id)
|
result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id)
|
||||||
assert result == wake_word.DetectionResult(expected_ww, 2048)
|
assert result == wake_word.DetectionResult(
|
||||||
|
wake_word_id=expected_ww, wake_word_phrase=expected_phrase, timestamp=2048
|
||||||
|
)
|
||||||
|
|
||||||
assert state != setup.state
|
assert state != setup.state
|
||||||
assert setup.state == "2023-06-22T10:30:00+00:00"
|
assert setup.state == "2023-06-22T10:30:00+00:00"
|
||||||
|
@ -285,8 +301,8 @@ async def test_list_wake_words(
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"wake_words": [
|
"wake_words": [
|
||||||
{"id": "test_ww", "name": "Test Wake Word"},
|
{"id": "test_ww", "name": "Test Wake Word", "phrase": "Test Phrase"},
|
||||||
{"id": "test_ww_2", "name": "Test Wake Word 2"},
|
{"id": "test_ww_2", "name": "Test Wake Word 2", "phrase": "Test Phrase 2"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,9 +336,10 @@ async def test_list_wake_words_timeout(
|
||||||
"""Test that the list_wake_words websocket command handles unknown entity."""
|
"""Test that the list_wake_words websocket command handles unknown entity."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch.object(
|
with (
|
||||||
setup, "get_supported_wake_words", partial(asyncio.sleep, 1)
|
patch.object(setup, "get_supported_wake_words", partial(asyncio.sleep, 1)),
|
||||||
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0):
|
patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0),
|
||||||
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
|
|
|
@ -75,6 +75,7 @@ WAKE_WORD_INFO = Info(
|
||||||
WakeModel(
|
WakeModel(
|
||||||
name="Test Model",
|
name="Test Model",
|
||||||
description="Test Model",
|
description="Test Model",
|
||||||
|
phrase="Test Phrase",
|
||||||
installed=True,
|
installed=True,
|
||||||
attribution=TEST_ATTR,
|
attribution=TEST_ATTR,
|
||||||
languages=["en-US"],
|
languages=["en-US"],
|
||||||
|
|
|
@ -9,5 +9,6 @@
|
||||||
]),
|
]),
|
||||||
'timestamp': 0,
|
'timestamp': 0,
|
||||||
'wake_word_id': 'Test Model',
|
'wake_word_id': 'Test Model',
|
||||||
|
'wake_word_phrase': 'Test Phrase',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test Wyoming satellite."""
|
"""Test Wyoming satellite."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -12,6 +13,7 @@ from wyoming.asr import Transcribe, Transcript
|
||||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||||
from wyoming.error import Error
|
from wyoming.error import Error
|
||||||
from wyoming.event import Event
|
from wyoming.event import Event
|
||||||
|
from wyoming.info import Info
|
||||||
from wyoming.ping import Ping, Pong
|
from wyoming.ping import Ping, Pong
|
||||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||||
from wyoming.satellite import RunSatellite
|
from wyoming.satellite import RunSatellite
|
||||||
|
@ -26,7 +28,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import SATELLITE_INFO, MockAsyncTcpClient
|
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
@ -207,19 +209,25 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||||
audio_chunk_received.set()
|
audio_chunk_received.set()
|
||||||
break
|
break
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||||
return_value=("wav", get_test_wav()),
|
return_value=("wav", get_test_wav()),
|
||||||
), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0):
|
),
|
||||||
|
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||||
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
entry.entry_id
|
entry.entry_id
|
||||||
|
@ -433,14 +441,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
||||||
self.device.set_is_muted(False)
|
self.device.set_is_muted(False)
|
||||||
on_muted_event.set()
|
on_muted_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
"homeassistant.components.wyoming._make_satellite", make_muted_satellite
|
patch("homeassistant.components.wyoming._make_satellite", make_muted_satellite),
|
||||||
), patch(
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
||||||
on_muted,
|
on_muted,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -462,16 +472,21 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
||||||
self.stop()
|
self.stop()
|
||||||
on_restart_event.set()
|
on_restart_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
||||||
side_effect=RuntimeError(),
|
side_effect=RuntimeError(),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
on_restart,
|
on_restart,
|
||||||
), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0):
|
),
|
||||||
|
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
|
||||||
|
):
|
||||||
await setup_config_entry(hass)
|
await setup_config_entry(hass)
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await on_restart_event.wait()
|
await on_restart_event.wait()
|
||||||
|
@ -497,19 +512,25 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
||||||
async def on_stopped(self):
|
async def on_stopped(self):
|
||||||
stopped_event.set()
|
stopped_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||||
side_effect=ConnectionRefusedError(),
|
side_effect=ConnectionRefusedError(),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||||
on_reconnect,
|
on_reconnect,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||||
on_stopped,
|
on_stopped,
|
||||||
), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0):
|
),
|
||||||
|
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
|
||||||
|
):
|
||||||
await setup_config_entry(hass)
|
await setup_config_entry(hass)
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await reconnect_event.wait()
|
await reconnect_event.wait()
|
||||||
|
@ -524,17 +545,22 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
|
||||||
self.stop()
|
self.stop()
|
||||||
on_restart_event.set()
|
on_restart_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
MockAsyncTcpClient([]), # no RunPipeline event
|
MockAsyncTcpClient([]), # no RunPipeline event
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
) as mock_run_pipeline, patch(
|
) as mock_run_pipeline,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
on_restart,
|
on_restart,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await setup_config_entry(hass)
|
await setup_config_entry(hass)
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -564,20 +590,26 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
||||||
async def on_stopped(self):
|
async def on_stopped(self):
|
||||||
on_stopped_event.set()
|
on_stopped_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
MockAsyncTcpClient(events),
|
MockAsyncTcpClient(events),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
) as mock_run_pipeline, patch(
|
) as mock_run_pipeline,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
on_restart,
|
on_restart,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||||
on_stopped,
|
on_stopped,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
|
@ -608,16 +640,20 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
||||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||||
pipeline_event.set()
|
pipeline_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
wraps=_async_pipeline_from_audio_stream,
|
wraps=_async_pipeline_from_audio_stream,
|
||||||
) as mock_run_pipeline:
|
) as mock_run_pipeline,
|
||||||
|
):
|
||||||
await setup_config_entry(hass)
|
await setup_config_entry(hass)
|
||||||
|
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
|
@ -663,21 +699,27 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
||||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||||
pipeline_event.set()
|
pipeline_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
wraps=_async_pipeline_from_audio_stream,
|
wraps=_async_pipeline_from_audio_stream,
|
||||||
) as mock_run_pipeline, patch(
|
) as mock_run_pipeline,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||||
return_value=("mp3", bytes(1)),
|
return_value=("mp3", bytes(1)),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
||||||
_stream_tts,
|
_stream_tts,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
|
|
||||||
|
@ -752,15 +794,19 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
pipeline_stopped.set()
|
pipeline_stopped.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
|
@ -822,15 +868,19 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
pipeline_stopped.set()
|
pipeline_stopped.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
|
@ -873,7 +923,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
||||||
start_stage_event = asyncio.Event()
|
start_stage_event = asyncio.Event()
|
||||||
end_stage_event = asyncio.Event()
|
end_stage_event = asyncio.Event()
|
||||||
|
|
||||||
def _run_pipeline_once(self, run_pipeline):
|
def _run_pipeline_once(self, run_pipeline, wake_word_phrase):
|
||||||
# Set bad start stage
|
# Set bad start stage
|
||||||
run_pipeline.start_stage = PipelineStage.INTENT
|
run_pipeline.start_stage = PipelineStage.INTENT
|
||||||
run_pipeline.end_stage = PipelineStage.TTS
|
run_pipeline.end_stage = PipelineStage.TTS
|
||||||
|
@ -892,15 +942,19 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
end_stage_event.set()
|
end_stage_event.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
||||||
_run_pipeline_once,
|
_run_pipeline_once,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
|
|
||||||
|
@ -950,15 +1004,19 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
pipeline_stopped.set()
|
pipeline_stopped.set()
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=SATELLITE_INFO,
|
return_value=SATELLITE_INFO,
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
SatelliteAsyncTcpClient(events),
|
SatelliteAsyncTcpClient(events),
|
||||||
) as mock_client, patch(
|
) as mock_client,
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
|
|
||||||
|
@ -982,3 +1040,46 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
||||||
# Stop the satellite
|
# Stop the satellite
|
||||||
await hass.config_entries.async_unload(entry.entry_id)
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wake_word_phrase(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that wake word phrase from info is given to pipeline."""
|
||||||
|
events = [
|
||||||
|
# Fake local wake word detection
|
||||||
|
Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(),
|
||||||
|
Detection(name="Test Model").event(),
|
||||||
|
RunPipeline(
|
||||||
|
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||||
|
).event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
pipeline_event = asyncio.Event()
|
||||||
|
|
||||||
|
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||||
|
pipeline_event.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
|
SatelliteAsyncTcpClient(events),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
|
wraps=_async_pipeline_from_audio_stream,
|
||||||
|
) as mock_run_pipeline,
|
||||||
|
):
|
||||||
|
await setup_config_entry(hass)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await pipeline_event.wait()
|
||||||
|
|
||||||
|
# async_pipeline_from_audio_stream will receive the wake word phrase for
|
||||||
|
# deconfliction.
|
||||||
|
mock_run_pipeline.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_run_pipeline.call_args.kwargs.get("wake_word_phrase") == "Test Phrase"
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test stt."""
|
"""Test stt."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -26,7 +27,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
|
|
||||||
assert (await entity.get_supported_wake_words()) == [
|
assert (await entity.get_supported_wake_words()) == [
|
||||||
wake_word.WakeWord(id="Test Model", name="Test Model")
|
wake_word.WakeWord(id="Test Model", name="Test Model", phrase="Test Phrase")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,6 +60,8 @@ async def test_streaming_audio(
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result == snapshot
|
assert result == snapshot
|
||||||
|
assert result.wake_word_id == "Test Model"
|
||||||
|
assert result.wake_word_phrase == "Test Phrase"
|
||||||
|
|
||||||
|
|
||||||
async def test_streaming_audio_connection_lost(
|
async def test_streaming_audio_connection_lost(
|
||||||
|
@ -100,10 +103,13 @@ async def test_streaming_audio_oserror(
|
||||||
[Detection(name="Test Model", timestamp=1000).event()]
|
[Detection(name="Test Model", timestamp=1000).event()]
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||||
mock_client,
|
mock_client,
|
||||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
),
|
||||||
|
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
|
||||||
|
):
|
||||||
result = await entity.async_process_audio_stream(audio_stream(), None)
|
result = await entity.async_process_audio_stream(audio_stream(), None)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
@ -171,7 +177,7 @@ async def test_dynamic_wake_word_info(
|
||||||
|
|
||||||
# Original info
|
# Original info
|
||||||
assert (await entity.get_supported_wake_words()) == [
|
assert (await entity.get_supported_wake_words()) == [
|
||||||
wake_word.WakeWord("Test Model", "Test Model")
|
wake_word.WakeWord("Test Model", "Test Model", "Test Phrase")
|
||||||
]
|
]
|
||||||
|
|
||||||
new_info = Info(
|
new_info = Info(
|
||||||
|
@ -185,6 +191,7 @@ async def test_dynamic_wake_word_info(
|
||||||
WakeModel(
|
WakeModel(
|
||||||
name="ww1",
|
name="ww1",
|
||||||
description="Wake Word 1",
|
description="Wake Word 1",
|
||||||
|
phrase="Wake Word Phrase 1",
|
||||||
installed=True,
|
installed=True,
|
||||||
attribution=TEST_ATTR,
|
attribution=TEST_ATTR,
|
||||||
languages=[],
|
languages=[],
|
||||||
|
@ -193,6 +200,7 @@ async def test_dynamic_wake_word_info(
|
||||||
WakeModel(
|
WakeModel(
|
||||||
name="ww2",
|
name="ww2",
|
||||||
description="Wake Word 2",
|
description="Wake Word 2",
|
||||||
|
phrase="Wake Word Phrase 2",
|
||||||
installed=True,
|
installed=True,
|
||||||
attribution=TEST_ATTR,
|
attribution=TEST_ATTR,
|
||||||
languages=[],
|
languages=[],
|
||||||
|
@ -210,6 +218,6 @@ async def test_dynamic_wake_word_info(
|
||||||
return_value=new_info,
|
return_value=new_info,
|
||||||
):
|
):
|
||||||
assert (await entity.get_supported_wake_words()) == [
|
assert (await entity.get_supported_wake_words()) == [
|
||||||
wake_word.WakeWord("ww1", "Wake Word 1"),
|
wake_word.WakeWord("ww1", "Wake Word 1", "Wake Word Phrase 1"),
|
||||||
wake_word.WakeWord("ww2", "Wake Word 2"),
|
wake_word.WakeWord("ww2", "Wake Word 2", "Wake Word Phrase 2"),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue