Add wake word integration (#96380)
* Add wake component * Add wake support to Wyoming * Add helper function to assist_pipeline (not complete) * Rename wake to wake_word * Fix platform * Use send_event and clean up * Merge wake word into pipeline * Add wake option to async_pipeline_from_audio_stream * Add start/end stages to async_pipeline_from_audio_stream * Add wake timeout * Remove layer in wake_output * Use VAD for wake word timeout * Include audio metadata in wake-start * Remove unnecessary websocket command * wake -> wake_word * Incorporate feedback * Clean up wake_word tests * Add wyoming wake word tests * Add pipeline wake word test * Add last processed state * Fix tests * Add tests for wake word * More tests for the codebot
This commit is contained in:
parent
798fb3e31a
commit
7ea2998b55
28 changed files with 1802 additions and 27 deletions
|
@ -1373,6 +1373,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/vulcan/ @Antoni-Czaplicki
|
||||
/homeassistant/components/wake_on_lan/ @ntilley905
|
||||
/tests/components/wake_on_lan/ @ntilley905
|
||||
/homeassistant/components/wake_word/ @home-assistant/core @synesthesiam
|
||||
/tests/components/wake_word/ @home-assistant/core @synesthesiam
|
||||
/homeassistant/components/wallbox/ @hesselonline
|
||||
/tests/components/wallbox/ @hesselonline
|
||||
/homeassistant/components/waqi/ @andrey-git
|
||||
|
|
|
@ -18,6 +18,7 @@ from .pipeline import (
|
|||
PipelineInput,
|
||||
PipelineRun,
|
||||
PipelineStage,
|
||||
WakeWordSettings,
|
||||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
|
@ -35,6 +36,7 @@ __all__ = (
|
|||
"PipelineEvent",
|
||||
"PipelineEventType",
|
||||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
|
@ -57,7 +59,10 @@ async def async_pipeline_from_audio_stream(
|
|||
pipeline_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
tts_audio_output: str | None = None,
|
||||
wake_word_settings: WakeWordSettings | None = None,
|
||||
device_id: str | None = None,
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
) -> None:
|
||||
"""Create an audio pipeline from an audio stream.
|
||||
|
||||
|
@ -72,10 +77,11 @@ async def async_pipeline_from_audio_stream(
|
|||
hass,
|
||||
context=context,
|
||||
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
|
||||
start_stage=PipelineStage.STT,
|
||||
end_stage=PipelineStage.TTS,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
event_callback=event_callback,
|
||||
tts_audio_output=tts_audio_output,
|
||||
wake_word_settings=wake_word_settings,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
|
|
@ -18,6 +18,14 @@ class PipelineNotFound(PipelineError):
|
|||
"""Unspecified pipeline picked."""
|
||||
|
||||
|
||||
class WakeWordDetectionError(PipelineError):
|
||||
"""Error in wake-word-detection portion of pipeline."""
|
||||
|
||||
|
||||
class WakeWordTimeoutError(WakeWordDetectionError):
|
||||
"""Timeout when wake word was not detected."""
|
||||
|
||||
|
||||
class SpeechToTextError(PipelineError):
|
||||
"""Error in speech-to-text portion of pipeline."""
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
"domain": "assist_pipeline",
|
||||
"name": "Assist pipeline",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"dependencies": ["conversation", "stt", "tts"],
|
||||
"dependencies": ["conversation", "stt", "tts", "wake_word"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable, Iterable
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
|
@ -10,7 +10,14 @@ from typing import Any, cast
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
|
||||
from homeassistant.components import (
|
||||
conversation,
|
||||
media_source,
|
||||
stt,
|
||||
tts,
|
||||
wake_word,
|
||||
websocket_api,
|
||||
)
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
|
@ -39,7 +46,10 @@ from .error import (
|
|||
PipelineNotFound,
|
||||
SpeechToTextError,
|
||||
TextToSpeechError,
|
||||
WakeWordDetectionError,
|
||||
WakeWordTimeoutError,
|
||||
)
|
||||
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -241,6 +251,8 @@ class PipelineEventType(StrEnum):
|
|||
|
||||
RUN_START = "run-start"
|
||||
RUN_END = "run-end"
|
||||
WAKE_WORD_START = "wake_word-start"
|
||||
WAKE_WORD_END = "wake_word-end"
|
||||
STT_START = "stt-start"
|
||||
STT_END = "stt-end"
|
||||
INTENT_START = "intent-start"
|
||||
|
@ -297,12 +309,14 @@ class Pipeline:
|
|||
class PipelineStage(StrEnum):
|
||||
"""Stages of a pipeline."""
|
||||
|
||||
WAKE_WORD = "wake_word"
|
||||
STT = "stt"
|
||||
INTENT = "intent"
|
||||
TTS = "tts"
|
||||
|
||||
|
||||
PIPELINE_STAGE_ORDER = [
|
||||
PipelineStage.WAKE_WORD,
|
||||
PipelineStage.STT,
|
||||
PipelineStage.INTENT,
|
||||
PipelineStage.TTS,
|
||||
|
@ -327,6 +341,17 @@ class InvalidPipelineStagesError(PipelineRunValidationError):
|
|||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WakeWordSettings:
|
||||
"""Settings for wake word detection."""
|
||||
|
||||
timeout: float | None = None
|
||||
"""Seconds of silence before detection times out."""
|
||||
|
||||
audio_seconds_to_buffer: float = 0
|
||||
"""Seconds of audio to buffer before detection and forward to STT."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRun:
|
||||
"""Running context for a pipeline."""
|
||||
|
@ -341,17 +366,20 @@ class PipelineRun:
|
|||
runner_data: Any | None = None
|
||||
intent_agent: str | None = None
|
||||
tts_audio_output: str | None = None
|
||||
wake_word_settings: WakeWordSettings | None = None
|
||||
|
||||
id: str = field(default_factory=ulid_util.ulid)
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
||||
tts_engine: str = field(init=False)
|
||||
tts_options: dict | None = field(init=False, default=None)
|
||||
wake_word_engine: str = field(init=False)
|
||||
wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
|
||||
# stt -> intent -> tts
|
||||
# wake -> stt -> intent -> tts
|
||||
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
|
||||
self.start_stage
|
||||
):
|
||||
|
@ -393,6 +421,141 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
async def prepare_wake_word_detection(self) -> None:
|
||||
"""Prepare wake-word-detection."""
|
||||
# Need to add to pipeline store
|
||||
engine = wake_word.async_default_engine(self.hass)
|
||||
if engine is None:
|
||||
raise WakeWordDetectionError(
|
||||
code="wake-engine-missing",
|
||||
message="No wake word engine",
|
||||
)
|
||||
|
||||
wake_word_provider = wake_word.async_get_wake_word_detection_entity(
|
||||
self.hass, engine
|
||||
)
|
||||
if wake_word_provider is None:
|
||||
raise WakeWordDetectionError(
|
||||
code="wake-provider-missing",
|
||||
message=f"No wake-word-detection provider for: {engine}",
|
||||
)
|
||||
|
||||
self.wake_word_engine = engine
|
||||
self.wake_word_provider = wake_word_provider
|
||||
|
||||
async def wake_word_detection(
|
||||
self,
|
||||
stream: AsyncIterable[bytes],
|
||||
audio_buffer: list[bytes],
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Run wake-word-detection portion of pipeline. Returns detection result."""
|
||||
metadata_dict = asdict(
|
||||
stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove language since it doesn't apply to wake words yet
|
||||
metadata_dict.pop("language", None)
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.WAKE_WORD_START,
|
||||
{
|
||||
"engine": self.wake_word_engine,
|
||||
"metadata": metadata_dict,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
wake_word_vad: VoiceActivityTimeout | None = None
|
||||
if (wake_word_settings.timeout is not None) and (
|
||||
wake_word_settings.timeout > 0
|
||||
):
|
||||
# Use VAD to determine timeout
|
||||
wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
|
||||
|
||||
# Audio chunk buffer.
|
||||
audio_bytes_to_buffer = int(
|
||||
wake_word_settings.audio_seconds_to_buffer * 16000 * 2
|
||||
)
|
||||
audio_ring_buffer = b""
|
||||
|
||||
async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
|
||||
"""Yield audio with timestamps (milliseconds since start of stream)."""
|
||||
nonlocal audio_ring_buffer
|
||||
|
||||
timestamp_ms = 0
|
||||
async for chunk in stream:
|
||||
yield chunk, timestamp_ms
|
||||
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
|
||||
|
||||
# Keeping audio right before wake word detection allows the
|
||||
# voice command to be spoken immediately after the wake word.
|
||||
if audio_bytes_to_buffer > 0:
|
||||
audio_ring_buffer += chunk
|
||||
if len(audio_ring_buffer) > audio_bytes_to_buffer:
|
||||
# A proper ring buffer would be far more efficient
|
||||
audio_ring_buffer = audio_ring_buffer[
|
||||
len(audio_ring_buffer) - audio_bytes_to_buffer :
|
||||
]
|
||||
|
||||
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
|
||||
raise WakeWordTimeoutError(
|
||||
code="wake-word-timeout", message="Wake word was not detected"
|
||||
)
|
||||
|
||||
try:
|
||||
# Detect wake word(s)
|
||||
result = await self.wake_word_provider.async_process_audio_stream(
|
||||
timestamped_stream()
|
||||
)
|
||||
|
||||
if audio_ring_buffer:
|
||||
# All audio kept from right before the wake word was detected as
|
||||
# a single chunk.
|
||||
audio_buffer.append(audio_ring_buffer)
|
||||
except WakeWordTimeoutError:
|
||||
_LOGGER.debug("Timeout during wake word detection")
|
||||
raise
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during wake-word-detection")
|
||||
raise WakeWordDetectionError(
|
||||
code="wake-stream-failed",
|
||||
message="Unexpected error during wake-word-detection",
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("wake-word-detection result %s", result)
|
||||
|
||||
if result is None:
|
||||
wake_word_output: dict[str, Any] = {}
|
||||
else:
|
||||
if result.queued_audio:
|
||||
# Add audio that was pending at detection
|
||||
for chunk_ts in result.queued_audio:
|
||||
audio_buffer.append(chunk_ts[0])
|
||||
|
||||
wake_word_output = asdict(result)
|
||||
|
||||
# Remove non-JSON fields
|
||||
wake_word_output.pop("queued_audio", None)
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.WAKE_WORD_END,
|
||||
{"wake_word_output": wake_word_output},
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||
"""Prepare speech-to-text."""
|
||||
# pipeline.stt_engine can't be None or this function is not called
|
||||
|
@ -443,9 +606,21 @@ class PipelineRun:
|
|||
)
|
||||
|
||||
try:
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
|
||||
async def segment_stream(
|
||||
stream: AsyncIterable[bytes],
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""Stop stream when voice command is finished."""
|
||||
async for chunk in stream:
|
||||
if not segmenter.process(chunk):
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
# Transcribe audio stream
|
||||
result = await self.stt_provider.async_process_audio_stream(
|
||||
metadata, stream
|
||||
metadata, segment_stream(stream)
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||
|
@ -663,17 +838,45 @@ class PipelineInput:
|
|||
async def execute(self) -> None:
|
||||
"""Run pipeline."""
|
||||
self.run.start()
|
||||
current_stage = self.run.start_stage
|
||||
current_stage: PipelineStage | None = self.run.start_stage
|
||||
audio_buffer: list[bytes] = []
|
||||
|
||||
try:
|
||||
if current_stage == PipelineStage.WAKE_WORD:
|
||||
assert self.stt_stream is not None
|
||||
detect_result = await self.run.wake_word_detection(
|
||||
self.stt_stream, audio_buffer
|
||||
)
|
||||
if detect_result is None:
|
||||
# No wake word. Abort the rest of the pipeline.
|
||||
self.run.end()
|
||||
return
|
||||
|
||||
current_stage = PipelineStage.STT
|
||||
|
||||
# speech-to-text
|
||||
intent_input = self.intent_input
|
||||
if current_stage == PipelineStage.STT:
|
||||
assert self.stt_metadata is not None
|
||||
assert self.stt_stream is not None
|
||||
|
||||
if audio_buffer:
|
||||
|
||||
async def buffered_stream() -> AsyncGenerator[bytes, None]:
|
||||
for chunk in audio_buffer:
|
||||
yield chunk
|
||||
|
||||
assert self.stt_stream is not None
|
||||
async for chunk in self.stt_stream:
|
||||
yield chunk
|
||||
|
||||
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
|
||||
else:
|
||||
stt_stream = self.stt_stream
|
||||
|
||||
intent_input = await self.run.speech_to_text(
|
||||
self.stt_metadata,
|
||||
self.stt_stream,
|
||||
stt_stream,
|
||||
)
|
||||
current_stage = PipelineStage.INTENT
|
||||
|
||||
|
@ -707,7 +910,7 @@ class PipelineInput:
|
|||
|
||||
async def validate(self) -> None:
|
||||
"""Validate pipeline input against start stage."""
|
||||
if self.run.start_stage == PipelineStage.STT:
|
||||
if self.run.start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
|
||||
if self.run.pipeline.stt_engine is None:
|
||||
raise PipelineRunValidationError(
|
||||
"the pipeline does not support speech-to-text"
|
||||
|
@ -741,6 +944,13 @@ class PipelineInput:
|
|||
|
||||
prepare_tasks = []
|
||||
|
||||
if (
|
||||
start_stage_index
|
||||
<= PIPELINE_STAGE_ORDER.index(PipelineStage.WAKE_WORD)
|
||||
<= end_stage_index
|
||||
):
|
||||
prepare_tasks.append(self.run.prepare_wake_word_detection())
|
||||
|
||||
if (
|
||||
start_stage_index
|
||||
<= PIPELINE_STAGE_ORDER.index(PipelineStage.STT)
|
||||
|
|
|
@ -88,7 +88,7 @@ class VoiceCommandSegmenter:
|
|||
self.in_command = False
|
||||
|
||||
def process(self, samples: bytes) -> bool:
|
||||
"""Process a 16-bit 16Khz mono audio samples.
|
||||
"""Process 16-bit 16Khz mono audio samples.
|
||||
|
||||
Returns False when command is done.
|
||||
"""
|
||||
|
@ -148,3 +148,94 @@ class VoiceCommandSegmenter:
|
|||
self._silence_seconds_left = self.silence_seconds
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceActivityTimeout:
|
||||
"""Detects silence in audio until a timeout is reached."""
|
||||
|
||||
silence_seconds: float
|
||||
"""Seconds of silence before timeout."""
|
||||
|
||||
reset_seconds: float = 0.5
|
||||
"""Seconds of speech before resetting timeout."""
|
||||
|
||||
vad_mode: int = 3
|
||||
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
|
||||
|
||||
vad_frames: int = 480 # 30 ms
|
||||
"""Must be 10, 20, or 30 ms at 16Khz."""
|
||||
|
||||
_silence_seconds_left: float = 0.0
|
||||
"""Seconds left before considering voice command as stopped."""
|
||||
|
||||
_reset_seconds_left: float = 0.0
|
||||
"""Seconds left before resetting start/stop time counters."""
|
||||
|
||||
_vad: webrtcvad.Vad = None
|
||||
_audio_buffer: bytes = field(default_factory=bytes)
|
||||
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
|
||||
_seconds_per_chunk: float = 0.03 # 30 ms
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize VAD."""
|
||||
self._vad = webrtcvad.Vad(self.vad_mode)
|
||||
self._bytes_per_chunk = self.vad_frames * 2
|
||||
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all counters and state."""
|
||||
self._audio_buffer = b""
|
||||
self._silence_seconds_left = self.silence_seconds
|
||||
self._reset_seconds_left = self.reset_seconds
|
||||
|
||||
def process(self, samples: bytes) -> bool:
|
||||
"""Process 16-bit 16Khz mono audio samples.
|
||||
|
||||
Returns False when timeout is reached.
|
||||
"""
|
||||
self._audio_buffer += samples
|
||||
|
||||
# Process in 10, 20, or 30 ms chunks.
|
||||
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
|
||||
for chunk_idx in range(num_chunks):
|
||||
chunk_offset = chunk_idx * self._bytes_per_chunk
|
||||
chunk = self._audio_buffer[
|
||||
chunk_offset : chunk_offset + self._bytes_per_chunk
|
||||
]
|
||||
if not self._process_chunk(chunk):
|
||||
return False
|
||||
|
||||
if num_chunks > 0:
|
||||
# Remove from buffer
|
||||
self._audio_buffer = self._audio_buffer[
|
||||
num_chunks * self._bytes_per_chunk :
|
||||
]
|
||||
|
||||
return True
|
||||
|
||||
def _process_chunk(self, chunk: bytes) -> bool:
|
||||
"""Process a single chunk of 16-bit 16Khz mono audio.
|
||||
|
||||
Returns False when timeout is reached.
|
||||
"""
|
||||
if self._vad.is_speech(chunk, _SAMPLE_RATE):
|
||||
# Speech
|
||||
self._reset_seconds_left -= self._seconds_per_chunk
|
||||
if self._reset_seconds_left <= 0:
|
||||
# Reset timeout
|
||||
self._silence_seconds_left = self.silence_seconds
|
||||
else:
|
||||
# Silence
|
||||
self._silence_seconds_left -= self._seconds_per_chunk
|
||||
if self._silence_seconds_left <= 0:
|
||||
# Timeout reached
|
||||
return False
|
||||
|
||||
# Slowly build reset counter back up
|
||||
self._reset_seconds_left = min(
|
||||
self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
@ -26,11 +26,12 @@ from .pipeline import (
|
|||
PipelineInput,
|
||||
PipelineRun,
|
||||
PipelineStage,
|
||||
WakeWordSettings,
|
||||
async_get_pipeline,
|
||||
)
|
||||
from .vad import VoiceCommandSegmenter
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
DEFAULT_WAKE_WORD_TIMEOUT = 3
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -63,6 +64,18 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
cv.key_value_schemas(
|
||||
"start_stage",
|
||||
{
|
||||
PipelineStage.WAKE_WORD: vol.Schema(
|
||||
{
|
||||
vol.Required("input"): {
|
||||
vol.Required("sample_rate"): int,
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
vol.Optional("audio_seconds_to_buffer"): vol.Any(
|
||||
float, int
|
||||
),
|
||||
}
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.STT: vol.Schema(
|
||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
|
@ -102,6 +115,7 @@ async def websocket_run(
|
|||
end_stage = PipelineStage(msg["end_stage"])
|
||||
handler_id: int | None = None
|
||||
unregister_handler: Callable[[], None] | None = None
|
||||
wake_word_settings: WakeWordSettings | None = None
|
||||
|
||||
# Arguments to PipelineInput
|
||||
input_args: dict[str, Any] = {
|
||||
|
@ -109,24 +123,26 @@ async def websocket_run(
|
|||
"device_id": msg.get("device_id"),
|
||||
}
|
||||
|
||||
if start_stage == PipelineStage.STT:
|
||||
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
|
||||
# Audio pipeline that will receive audio as binary websocket messages
|
||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
incoming_sample_rate = msg["input"]["sample_rate"]
|
||||
|
||||
if start_stage == PipelineStage.WAKE_WORD:
|
||||
wake_word_settings = WakeWordSettings(
|
||||
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
||||
audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
|
||||
)
|
||||
|
||||
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||
state = None
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
|
||||
# Yield until we receive an empty chunk
|
||||
while chunk := await audio_queue.get():
|
||||
chunk, state = audioop.ratecv(
|
||||
chunk, 2, 1, incoming_sample_rate, 16000, state
|
||||
)
|
||||
if not segmenter.process(chunk):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
if incoming_sample_rate != 16000:
|
||||
chunk, state = audioop.ratecv(
|
||||
chunk, 2, 1, incoming_sample_rate, 16000, state
|
||||
)
|
||||
yield chunk
|
||||
|
||||
def handle_binary(
|
||||
|
@ -169,6 +185,7 @@ async def websocket_run(
|
|||
"stt_binary_handler_id": handler_id,
|
||||
"timeout": timeout,
|
||||
},
|
||||
wake_word_settings=wake_word_settings,
|
||||
)
|
||||
|
||||
pipeline_input = PipelineInput(**input_args)
|
||||
|
|
119
homeassistant/components/wake_word/__init__.py
Normal file
119
homeassistant/components/wake_word/__init__.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""Provide functionality to wake word."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .models import DetectionResult, WakeWord
|
||||
|
||||
__all__ = [
|
||||
"async_default_engine",
|
||||
"async_get_wake_word_detection_entity",
|
||||
"DetectionResult",
|
||||
"DOMAIN",
|
||||
"WakeWord",
|
||||
"WakeWordDetectionEntity",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine."""
|
||||
return next(iter(hass.states.async_entity_ids(DOMAIN)), None)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_wake_word_detection_entity(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> WakeWordDetectionEntity | None:
|
||||
"""Return wake word entity."""
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
return component.get_entity(entity_id)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component.register_shutdown()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
|
||||
|
||||
class WakeWordDetectionEntity(RestoreEntity):
|
||||
"""Represent a single wake word provider."""
|
||||
|
||||
_attr_should_poll = False
|
||||
__last_processed: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the entity."""
|
||||
if self.__last_processed is None:
|
||||
return None
|
||||
return self.__last_processed
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_wake_words(self) -> list[WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
|
||||
@abstractmethod
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
) -> DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||
|
||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||
"""
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
) -> DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||
|
||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||
"""
|
||||
self.__last_processed = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self._async_process_audio_stream(stream)
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if (
|
||||
state is not None
|
||||
and state.state is not None
|
||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
):
|
||||
self.__last_processed = state.state
|
2
homeassistant/components/wake_word/const.py
Normal file
2
homeassistant/components/wake_word/const.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
"""Wake word constants."""
|
||||
DOMAIN = "wake_word"
|
8
homeassistant/components/wake_word/manifest.json
Normal file
8
homeassistant/components/wake_word/manifest.json
Normal file
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"domain": "wake_word",
|
||||
"name": "Wake-word detection",
|
||||
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wake_word",
|
||||
"integration_type": "entity",
|
||||
"quality_scale": "internal"
|
||||
}
|
24
homeassistant/components/wake_word/models.py
Normal file
24
homeassistant/components/wake_word/models.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
"""Wake word models."""
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WakeWord:
|
||||
"""Wake word model."""
|
||||
|
||||
ww_id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Result of wake word detection."""
|
||||
|
||||
ww_id: str
|
||||
"""Id of detected wake word"""
|
||||
|
||||
timestamp: int | None
|
||||
"""Timestamp of audio chunk with detected wake word"""
|
||||
|
||||
queued_audio: list[tuple[bytes, int]] | None = None
|
||||
"""Audio chunks that were queued when wake word was detected."""
|
|
@ -50,14 +50,21 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
errors={"base": "cannot_connect"},
|
||||
)
|
||||
|
||||
# ASR = automated speech recognition (STT)
|
||||
# ASR = automated speech recognition (speech-to-text)
|
||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
||||
|
||||
# TTS = text-to-speech
|
||||
tts_installed = [tts for tts in service.info.tts if tts.installed]
|
||||
|
||||
# wake-word-detection
|
||||
wake_installed = [wake for wake in service.info.wake if wake.installed]
|
||||
|
||||
if asr_installed:
|
||||
name = asr_installed[0].name
|
||||
elif tts_installed:
|
||||
name = tts_installed[0].name
|
||||
elif wake_installed:
|
||||
name = wake_installed[0].name
|
||||
else:
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
|
|
|
@ -29,6 +29,8 @@ class WyomingService:
|
|||
platforms.append(Platform.STT)
|
||||
if any(tts.installed for tts in info.tts):
|
||||
platforms.append(Platform.TTS)
|
||||
if any(wake.installed for wake in info.wake):
|
||||
platforms.append(Platform.WAKE_WORD)
|
||||
self.platforms = platforms
|
||||
|
||||
@classmethod
|
||||
|
|
157
homeassistant/components/wyoming/wake_word.py
Normal file
157
homeassistant/components/wyoming/wake_word.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
"""Support for Wyoming wake-word-detection services."""
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioStart
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.wake import Detection
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming wake-word-detection."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingWakeWordProvider(config_entry, service),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||
"""Wyoming wake-word-detection provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.service = service
|
||||
wake_service = service.info.wake[0]
|
||||
|
||||
self._supported_wake_words = [
|
||||
wake_word.WakeWord(ww_id=ww.name, name=ww.name)
|
||||
for ww in wake_service.models
|
||||
]
|
||||
self._attr_name = wake_service.name
|
||||
self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return self._supported_wake_words
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect one or more wake words in an audio stream.
|
||||
|
||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||
"""
|
||||
|
||||
async def next_chunk():
|
||||
"""Get the next chunk from audio stream."""
|
||||
async for chunk_bytes in stream:
|
||||
return chunk_bytes
|
||||
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
await client.write_event(
|
||||
AudioStart(
|
||||
rate=16000,
|
||||
width=2,
|
||||
channels=1,
|
||||
).event(),
|
||||
)
|
||||
|
||||
# Read audio and wake events in "parallel"
|
||||
audio_task = asyncio.create_task(next_chunk())
|
||||
wake_task = asyncio.create_task(client.read_event())
|
||||
pending = {audio_task, wake_task}
|
||||
|
||||
try:
|
||||
while True:
|
||||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if wake_task in done:
|
||||
event = wake_task.result()
|
||||
if event is None:
|
||||
_LOGGER.debug("Connection lost")
|
||||
break
|
||||
|
||||
if Detection.is_type(event.type):
|
||||
# Successful detection
|
||||
detection = Detection.from_event(event)
|
||||
_LOGGER.info(detection)
|
||||
|
||||
# Retrieve queued audio
|
||||
queued_audio: list[tuple[bytes, int]] | None = None
|
||||
if audio_task in pending:
|
||||
# Save queued audio
|
||||
await audio_task
|
||||
pending.remove(audio_task)
|
||||
queued_audio = [audio_task.result()]
|
||||
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=detection.name,
|
||||
timestamp=detection.timestamp,
|
||||
queued_audio=queued_audio,
|
||||
)
|
||||
|
||||
# Next event
|
||||
wake_task = asyncio.create_task(client.read_event())
|
||||
pending.add(wake_task)
|
||||
|
||||
if audio_task in done:
|
||||
# Forward audio to wake service
|
||||
chunk_info = audio_task.result()
|
||||
if chunk_info is None:
|
||||
break
|
||||
|
||||
chunk_bytes, chunk_timestamp = chunk_info
|
||||
chunk = AudioChunk(
|
||||
rate=16000,
|
||||
width=2,
|
||||
channels=1,
|
||||
audio=chunk_bytes,
|
||||
timestamp=chunk_timestamp,
|
||||
)
|
||||
await client.write_event(chunk.event())
|
||||
|
||||
# Next chunk
|
||||
audio_task = asyncio.create_task(next_chunk())
|
||||
pending.add(audio_task)
|
||||
finally:
|
||||
# Clean up
|
||||
if audio_task in pending:
|
||||
# It's critical that we don't cancel the audio task or
|
||||
# leave it hanging. This would mess up the pipeline STT
|
||||
# by stopping the audio stream.
|
||||
await audio_task
|
||||
pending.remove(audio_task)
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
except (OSError, WyomingError) as err:
|
||||
_LOGGER.exception("Error processing audio stream: %s", err)
|
||||
|
||||
return None
|
|
@ -57,6 +57,7 @@ class Platform(StrEnum):
|
|||
TTS = "tts"
|
||||
VACUUM = "vacuum"
|
||||
UPDATE = "update"
|
||||
WAKE_WORD = "wake_word"
|
||||
WATER_HEATER = "water_heater"
|
||||
WEATHER = "weather"
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components import stt, tts, wake_word
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineData,
|
||||
|
@ -174,6 +174,40 @@ class MockSttPlatform(MockPlatform):
|
|||
self.async_get_engine = async_get_engine
|
||||
|
||||
|
||||
class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
||||
"""Mock wake word entity."""
|
||||
|
||||
fail_process_audio = False
|
||||
url_path = "wake_word.test"
|
||||
_attr_name = "test"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
async for chunk, timestamp in stream:
|
||||
if chunk == b"wake word":
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=self.supported_wake_words[0].ww_id,
|
||||
timestamp=timestamp,
|
||||
queued_audio=[(b"queued audio", 0)],
|
||||
)
|
||||
|
||||
# Not detected
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity:
|
||||
"""Mock wake word provider."""
|
||||
return MockWakeWordEntity()
|
||||
|
||||
|
||||
class MockFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
@ -193,6 +227,7 @@ async def init_supporting_components(
|
|||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
config_flow_fixture,
|
||||
):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
|
@ -201,14 +236,18 @@ async def init_supporting_components(
|
|||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
|
||||
await hass.config_entries.async_unload_platforms(
|
||||
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_setup_entry_stt_platform(
|
||||
|
@ -219,6 +258,14 @@ async def init_supporting_components(
|
|||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_stt_provider_entity])
|
||||
|
||||
async def async_setup_entry_wake_word_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test wake word platform via config entry."""
|
||||
async_add_entities([mock_wake_word_provider_entity])
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
|
@ -242,11 +289,19 @@ async def init_supporting_components(
|
|||
async_setup_entry=async_setup_entry_stt_platform,
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
hass,
|
||||
"test.wake_word",
|
||||
MockPlatform(
|
||||
async_setup_entry=async_setup_entry_wake_word_platform,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.config_flow")
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||
# assert await async_setup_component(hass, wake_word.DOMAIN, {"wake_word": {}})
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
|
|
|
@ -266,3 +266,114 @@
|
|||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_from_audio_stream_wake_word
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'wake_word_output': dict({
|
||||
'timestamp': 2000,
|
||||
'ww_id': 'test_ww',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'language': 'en-US',
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_END: 'stt-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
|
|
@ -155,6 +155,243 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_no_wake_word_engine
|
||||
dict({
|
||||
'code': 'wake-engine-missing',
|
||||
'message': 'No wake word engine',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_no_wake_word_entity
|
||||
dict({
|
||||
'code': 'wake-provider-missing',
|
||||
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.1
|
||||
dict({
|
||||
'engine': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.2
|
||||
dict({
|
||||
'wake_word_output': dict({
|
||||
'queued_audio': None,
|
||||
'timestamp': 1000,
|
||||
'ww_id': 'test_ww',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.3
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.4
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.5
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.6
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.7
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word.8
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.1
|
||||
dict({
|
||||
'engine': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.2
|
||||
dict({
|
||||
'wake_word_output': dict({
|
||||
'timestamp': 0,
|
||||
'ww_id': 'test_ww',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.3
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.4
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.5
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'device_id': None,
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.6
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.7
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.1
|
||||
dict({
|
||||
'engine': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.2
|
||||
dict({
|
||||
'code': 'wake-word-timeout',
|
||||
'message': 'Wake word was not detected',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Test Voice Assistant init."""
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
@ -8,10 +9,12 @@ from syrupy.assertion import SnapshotAssertion
|
|||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from .conftest import MockSttProvider, MockSttProviderEntity
|
||||
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
BYTES_ONE_SECOND = 16000 * 2
|
||||
|
||||
|
||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||
"""Process events to remove dynamic values."""
|
||||
|
@ -280,3 +283,61 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||
)
|
||||
|
||||
assert not events
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_wake_word(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream with wake word."""
|
||||
|
||||
events = []
|
||||
|
||||
# [0, 1, ...]
|
||||
wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND))
|
||||
|
||||
# [0, 2, ...]
|
||||
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
|
||||
|
||||
async def audio_data():
|
||||
yield wake_chunk_1 # 1 second
|
||||
yield wake_chunk_2 # 1 second
|
||||
yield b"wake word"
|
||||
yield b"part1"
|
||||
yield b"part2"
|
||||
yield b""
|
||||
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||
audio_seconds_to_buffer=1.5
|
||||
),
|
||||
)
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
# 1. Half of wake_chunk_1 + all wake_chunk_2
|
||||
# 2. queued audio (from mock wake word entity)
|
||||
# 3. part1
|
||||
# 4. part2
|
||||
assert len(mock_stt_provider.received) == 4
|
||||
|
||||
first_chunk = mock_stt_provider.received[0]
|
||||
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
|
||||
|
||||
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]
|
||||
|
|
|
@ -167,6 +167,224 @@ async def test_audio_pipeline(
|
|||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_audio_pipeline_with_wake_word_timeout(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test timeout from a pipeline run with audio input/output + wake word."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"timeout": 1,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# wake_word
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# 2 seconds of silence
|
||||
await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2))
|
||||
|
||||
# Time out error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
|
||||
async def test_audio_pipeline_with_wake_word_no_timeout(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with audio input/output + wake word with no timeout."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
"timeout": 0,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"], msg
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# wake_word
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# "audio"
|
||||
await client.send_bytes(bytes([1]) + b"wake word")
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "wake_word-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(bytes([1]))
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# text-to-speech
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_audio_pipeline_no_wake_word_engine(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test timeout from a pipeline run with audio input/output + wake word."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wake_word.async_default_engine", return_value=None
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert "error" in msg
|
||||
assert msg["error"] == snapshot
|
||||
|
||||
|
||||
async def test_audio_pipeline_no_wake_word_entity(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test timeout from a pipeline run with audio input/output + wake word."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wake_word.async_default_engine",
|
||||
return_value="wake_word.bad-entity-id",
|
||||
), patch(
|
||||
"homeassistant.components.wake_word.async_get_wake_word_detection_entity",
|
||||
return_value=None,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "wake_word",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 16000,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert "error" in msg
|
||||
assert msg["error"] == snapshot
|
||||
|
||||
|
||||
async def test_intent_timeout(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
|
|
1
tests/components/wake_word/__init__.py
Normal file
1
tests/components/wake_word/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Wake-word-detection tests."""
|
29
tests/components/wake_word/common.py
Normal file
29
tests/components/wake_word/common.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
"""Provide common test tools for wake-word-detection."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from tests.common import MockPlatform, mock_platform
|
||||
|
||||
|
||||
def mock_wake_word_entity_platform(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
integration: str,
|
||||
async_setup_entry: Callable[
|
||||
[HomeAssistant, ConfigEntry, AddEntitiesCallback],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None = None,
|
||||
) -> MockPlatform:
|
||||
"""Specialize the mock platform for stt."""
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
mock_platform(hass, f"{integration}.{wake_word.DOMAIN}", loaded_platform)
|
||||
return loaded_platform
|
11
tests/components/wake_word/snapshots/test_init.ambr
Normal file
11
tests/components/wake_word/snapshots/test_init.ambr
Normal file
|
@ -0,0 +1,11 @@
|
|||
# serializer version: 1
|
||||
# name: test_ws_detect
|
||||
dict({
|
||||
'event': dict({
|
||||
'timestamp': 2048.0,
|
||||
'ww_id': 'test_ww',
|
||||
}),
|
||||
'id': 1,
|
||||
'type': 'event',
|
||||
})
|
||||
# ---
|
226
tests/components/wake_word/test_init.py
Normal file
226
tests/components/wake_word/test_init.py
Normal file
|
@ -0,0 +1,226 @@
|
|||
"""Test wake_word component setup."""
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import mock_wake_word_entity_platform
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
mock_restore_cache,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test"
|
||||
|
||||
_SAMPLES_PER_CHUNK = 1024
|
||||
_BYTES_PER_CHUNK = _SAMPLES_PER_CHUNK * 2 # 16-bit
|
||||
_MS_PER_CHUNK = (_BYTES_PER_CHUNK // 2) // 16 # 16Khz
|
||||
|
||||
|
||||
class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
||||
"""Mock provider entity."""
|
||||
|
||||
url_path = "wake_word.test"
|
||||
_attr_name = "test"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
async for _chunk, timestamp in stream:
|
||||
if timestamp >= 2000:
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp
|
||||
)
|
||||
|
||||
# Not detected
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity() -> MockProviderEntity:
|
||||
"""Test provider entity fixture."""
|
||||
return MockProviderEntity()
|
||||
|
||||
|
||||
class WakeWordFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, WakeWordFlow):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="setup")
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
) -> MockProviderEntity:
|
||||
"""Set up the test environment."""
|
||||
provider = MockProviderEntity()
|
||||
await mock_config_entry_setup(hass, tmp_path, provider)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test provider via config entry."""
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setup(
|
||||
config_entry, wake_word.DOMAIN
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(
|
||||
config_entry, wake_word.DOMAIN
|
||||
)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_provider_entity])
|
||||
|
||||
mock_wake_word_entity_platform(
|
||||
hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform
|
||||
)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
||||
|
||||
|
||||
async def test_config_entry_unload(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test we can unload config entry."""
|
||||
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
assert config_entry.state == ConfigEntryState.LOADED
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
async def test_detected_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test successful detection through entity."""
|
||||
|
||||
async def three_second_stream():
|
||||
timestamp = 0
|
||||
while timestamp < 3000:
|
||||
yield bytes(_BYTES_PER_CHUNK), timestamp
|
||||
timestamp += _MS_PER_CHUNK
|
||||
|
||||
# Need 2 seconds to trigger
|
||||
result = await setup.async_process_audio_stream(three_second_stream())
|
||||
assert result == wake_word.DetectionResult("test_ww", 2048)
|
||||
|
||||
|
||||
async def test_not_detected_entity(
|
||||
hass: HomeAssistant, setup: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test unsuccessful detection through entity."""
|
||||
|
||||
async def one_second_stream():
|
||||
timestamp = 0
|
||||
while timestamp < 1000:
|
||||
yield bytes(_BYTES_PER_CHUNK), timestamp
|
||||
timestamp += _MS_PER_CHUNK
|
||||
|
||||
# Need 2 seconds to trigger
|
||||
result = await setup.async_process_audio_stream(one_second_stream())
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test async_default_engine."""
|
||||
assert await async_setup_component(hass, wake_word.DOMAIN, {wake_word.DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert wake_word.async_default_engine(hass) is None
|
||||
|
||||
|
||||
async def test_default_engine_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
|
||||
assert wake_word.async_default_engine(hass) == f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
|
||||
|
||||
|
||||
async def test_get_engine_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test async_get_speech_to_text_engine."""
|
||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
|
||||
assert (
|
||||
wake_word.async_get_wake_word_detection_entity(hass, f"{wake_word.DOMAIN}.test")
|
||||
is mock_provider_entity
|
||||
)
|
||||
|
||||
|
||||
async def test_restore_state(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test we restore state in the integration."""
|
||||
entity_id = f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
|
||||
timestamp = "2023-01-01T23:59:59+00:00"
|
||||
mock_restore_cache(hass, (State(entity_id, timestamp),))
|
||||
|
||||
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert config_entry.state == ConfigEntryState.LOADED
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == timestamp
|
|
@ -1,4 +1,6 @@
|
|||
"""Tests for the Wyoming integration."""
|
||||
import asyncio
|
||||
|
||||
from wyoming.info import (
|
||||
AsrModel,
|
||||
AsrProgram,
|
||||
|
@ -7,6 +9,8 @@ from wyoming.info import (
|
|||
TtsProgram,
|
||||
TtsVoice,
|
||||
TtsVoiceSpeaker,
|
||||
WakeModel,
|
||||
WakeProgram,
|
||||
)
|
||||
|
||||
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
|
||||
|
@ -49,6 +53,25 @@ TTS_INFO = Info(
|
|||
)
|
||||
]
|
||||
)
|
||||
WAKE_WORD_INFO = Info(
|
||||
wake=[
|
||||
WakeProgram(
|
||||
name="Test Wake Word",
|
||||
description="Test Wake Word",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
models=[
|
||||
WakeModel(
|
||||
name="Test Model",
|
||||
description="Test Model",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
EMPTY_INFO = Info()
|
||||
|
||||
|
||||
|
@ -68,6 +91,7 @@ class MockAsyncTcpClient:
|
|||
|
||||
async def read_event(self):
|
||||
"""Receive."""
|
||||
await asyncio.sleep(0) # force context switch
|
||||
return self.responses.pop(0)
|
||||
|
||||
async def __aenter__(self):
|
||||
|
|
|
@ -8,7 +8,7 @@ from homeassistant.components import stt
|
|||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import STT_INFO, TTS_INFO
|
||||
from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -52,6 +52,21 @@ def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
|||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test Wake Word",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming STT."""
|
||||
|
@ -72,6 +87,18 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
|
|||
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_wake_word(
|
||||
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
|
||||
):
|
||||
"""Initialize Wyoming Wake Word."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=WAKE_WORD_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(wake_word_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
|
||||
"""Get default STT metadata."""
|
||||
|
|
13
tests/components/wyoming/snapshots/test_wake_word.ambr
Normal file
13
tests/components/wyoming/snapshots/test_wake_word.ambr
Normal file
|
@ -0,0 +1,13 @@
|
|||
# serializer version: 1
|
||||
# name: test_streaming_audio
|
||||
dict({
|
||||
'queued_audio': list([
|
||||
tuple(
|
||||
b'chunk',
|
||||
1,
|
||||
),
|
||||
]),
|
||||
'timestamp': 0,
|
||||
'ww_id': 'Test Model',
|
||||
})
|
||||
# ---
|
108
tests/components/wyoming/test_wake_word.py
Normal file
108
tests/components/wyoming/test_wake_word.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
"""Test stt."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.wake import Detection
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
|
||||
"""Test supported properties."""
|
||||
state = hass.states.get("wake_word.test_wake_word")
|
||||
assert state is not None
|
||||
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
assert entity.supported_wake_words == [
|
||||
wake_word.WakeWord(ww_id="Test Model", name="Test Model")
|
||||
]
|
||||
|
||||
|
||||
async def test_streaming_audio(
|
||||
hass: HomeAssistant, init_wyoming_wake_word, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test streaming audio."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield b"chunk", 0
|
||||
|
||||
# Delay to force a pending audio chunk
|
||||
await asyncio.sleep(0.05)
|
||||
yield b"chunk", 1
|
||||
|
||||
client_events = [
|
||||
Transcript("not a wake word event").event(),
|
||||
Detection(name="Test Model", timestamp=0).event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
MockAsyncTcpClient(client_events),
|
||||
):
|
||||
result = await entity.async_process_audio_stream(audio_stream())
|
||||
|
||||
assert result is not None
|
||||
assert result == snapshot
|
||||
|
||||
|
||||
async def test_streaming_audio_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_wake_word
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
# Delay to force a pending audio chunk
|
||||
await asyncio.sleep(0.05)
|
||||
yield b"chunk", 1
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
):
|
||||
result = await entity.async_process_audio_stream(audio_stream())
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_streaming_audio_oserror(
|
||||
hass: HomeAssistant, init_wyoming_wake_word
|
||||
) -> None:
|
||||
"""Test streaming audio and error raising."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield b"chunk1", 1000
|
||||
|
||||
mock_client = MockAsyncTcpClient(
|
||||
[Detection(name="Test Model", timestamp=1000).event()]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||
result = await entity.async_process_audio_stream(audio_stream())
|
||||
|
||||
assert result is None
|
Loading…
Add table
Reference in a new issue