Add wake word integration (#96380)

* Add wake component

* Add wake support to Wyoming

* Add helper function to assist_pipeline (not complete)

* Rename wake to wake_word

* Fix platform

* Use send_event and clean up

* Merge wake word into pipeline

* Add wake option to async_pipeline_from_audio_stream

* Add start/end stages to async_pipeline_from_audio_stream

* Add wake timeout

* Remove layer in wake_output

* Use VAD for wake word timeout

* Include audio metadata in wake-start

* Remove unnecessary websocket command

* wake -> wake_word

* Incorporate feedback

* Clean up wake_word tests

* Add wyoming wake word tests

* Add pipeline wake word test

* Add last processed state

* Fix tests

* Add tests for wake word

* More tests for the codebot
Michael Hansen 2023-08-07 21:22:16 -05:00 committed by GitHub
parent 798fb3e31a
commit 7ea2998b55
28 changed files with 1802 additions and 27 deletions

View file

@ -1373,6 +1373,8 @@ build.json @home-assistant/supervisor
/tests/components/vulcan/ @Antoni-Czaplicki
/homeassistant/components/wake_on_lan/ @ntilley905
/tests/components/wake_on_lan/ @ntilley905
/homeassistant/components/wake_word/ @home-assistant/core @synesthesiam
/tests/components/wake_word/ @home-assistant/core @synesthesiam
/homeassistant/components/wallbox/ @hesselonline
/tests/components/wallbox/ @hesselonline
/homeassistant/components/waqi/ @andrey-git

View file

@ -18,6 +18,7 @@ from .pipeline import (
PipelineInput,
PipelineRun,
PipelineStage,
WakeWordSettings,
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
@ -35,6 +36,7 @@ __all__ = (
"PipelineEvent",
"PipelineEventType",
"PipelineNotFound",
"WakeWordSettings",
)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@ -57,7 +59,10 @@ async def async_pipeline_from_audio_stream(
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
wake_word_settings: WakeWordSettings | None = None,
device_id: str | None = None,
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
) -> None:
"""Create an audio pipeline from an audio stream.
@ -72,10 +77,11 @@ async def async_pipeline_from_audio_stream(
hass,
context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=PipelineStage.STT,
end_stage=PipelineStage.TTS,
start_stage=start_stage,
end_stage=end_stage,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
),
)
await pipeline_input.validate()
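
For callers, the new start_stage/end_stage and wake_word_settings parameters make it possible to run the whole wake-to-speech chain from one audio stream. A minimal sketch, mirroring the call in the new pipeline test (the mic_stream source and the settings values here are illustrative, not part of this change):

from homeassistant.components import assist_pipeline, stt
from homeassistant.core import Context, HomeAssistant

async def run_from_wake_word(hass: HomeAssistant, mic_stream) -> None:
    """Run wake_word -> stt -> intent -> tts over one audio stream."""
    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        Context(),
        lambda event: print(event),  # event_callback receives PipelineEvents
        stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        mic_stream,  # AsyncIterable[bytes] of 16-bit 16Khz mono PCM
        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
        wake_word_settings=assist_pipeline.WakeWordSettings(
            timeout=3, audio_seconds_to_buffer=0.5
        ),
    )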

View file

@ -18,6 +18,14 @@ class PipelineNotFound(PipelineError):
"""Unspecified pipeline picked."""
class WakeWordDetectionError(PipelineError):
"""Error in wake-word-detection portion of pipeline."""
class WakeWordTimeoutError(WakeWordDetectionError):
"""Timeout when wake word was not detected."""
class SpeechToTextError(PipelineError):
"""Error in speech-to-text portion of pipeline."""

View file

@ -2,7 +2,7 @@
"domain": "assist_pipeline",
"name": "Assist pipeline",
"codeowners": ["@balloob", "@synesthesiam"],
"dependencies": ["conversation", "stt", "tts"],
"dependencies": ["conversation", "stt", "tts", "wake_word"],
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
"iot_class": "local_push",
"quality_scale": "internal",

View file

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable, Iterable
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field
from enum import StrEnum
import logging
@ -10,7 +10,14 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
from homeassistant.components import (
conversation,
media_source,
stt,
tts,
wake_word,
websocket_api,
)
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
@ -39,7 +46,10 @@ from .error import (
PipelineNotFound,
SpeechToTextError,
TextToSpeechError,
WakeWordDetectionError,
WakeWordTimeoutError,
)
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
_LOGGER = logging.getLogger(__name__)
@ -241,6 +251,8 @@ class PipelineEventType(StrEnum):
RUN_START = "run-start"
RUN_END = "run-end"
WAKE_WORD_START = "wake_word-start"
WAKE_WORD_END = "wake_word-end"
STT_START = "stt-start"
STT_END = "stt-end"
INTENT_START = "intent-start"
@ -297,12 +309,14 @@ class Pipeline:
class PipelineStage(StrEnum):
"""Stages of a pipeline."""
WAKE_WORD = "wake_word"
STT = "stt"
INTENT = "intent"
TTS = "tts"
PIPELINE_STAGE_ORDER = [
PipelineStage.WAKE_WORD,
PipelineStage.STT,
PipelineStage.INTENT,
PipelineStage.TTS,
@ -327,6 +341,17 @@ class InvalidPipelineStagesError(PipelineRunValidationError):
)
@dataclass(frozen=True)
class WakeWordSettings:
"""Settings for wake word detection."""
timeout: float | None = None
"""Seconds of silence before detection times out."""
audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT."""
@dataclass
class PipelineRun:
"""Running context for a pipeline."""
@ -341,17 +366,20 @@ class PipelineRun:
runner_data: Any | None = None
intent_agent: str | None = None
tts_audio_output: str | None = None
wake_word_settings: WakeWordSettings | None = None
id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None)
wake_word_engine: str = field(init=False)
wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False)
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
# stt -> intent -> tts
# wake -> stt -> intent -> tts
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
self.start_stage
):
@ -393,6 +421,141 @@ class PipelineRun:
)
)
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
# Need to add to pipeline store
engine = wake_word.async_default_engine(self.hass)
if engine is None:
raise WakeWordDetectionError(
code="wake-engine-missing",
message="No wake word engine",
)
wake_word_provider = wake_word.async_get_wake_word_detection_entity(
self.hass, engine
)
if wake_word_provider is None:
raise WakeWordDetectionError(
code="wake-provider-missing",
message=f"No wake-word-detection provider for: {engine}",
)
self.wake_word_engine = engine
self.wake_word_provider = wake_word_provider
async def wake_word_detection(
self,
stream: AsyncIterable[bytes],
audio_buffer: list[bytes],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
)
# Remove language since it doesn't apply to wake words yet
metadata_dict.pop("language", None)
self.process_event(
PipelineEvent(
PipelineEventType.WAKE_WORD_START,
{
"engine": self.wake_word_engine,
"metadata": metadata_dict,
},
)
)
wake_word_settings = self.wake_word_settings or WakeWordSettings()
wake_word_vad: VoiceActivityTimeout | None = None
if (wake_word_settings.timeout is not None) and (
wake_word_settings.timeout > 0
):
# Use VAD to determine timeout
wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
# Audio chunk buffer.
audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2
)
audio_ring_buffer = b""
async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio with timestamps (milliseconds since start of stream)."""
nonlocal audio_ring_buffer
timestamp_ms = 0
async for chunk in stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Keeping audio right before wake word detection allows the
# voice command to be spoken immediately after the wake word.
if audio_bytes_to_buffer > 0:
audio_ring_buffer += chunk
if len(audio_ring_buffer) > audio_bytes_to_buffer:
# A proper ring buffer would be far more efficient
audio_ring_buffer = audio_ring_buffer[
len(audio_ring_buffer) - audio_bytes_to_buffer :
]
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
try:
# Detect wake word(s)
result = await self.wake_word_provider.async_process_audio_stream(
timestamped_stream()
)
if audio_ring_buffer:
# All audio kept from right before the wake word was detected as
# a single chunk.
audio_buffer.append(audio_ring_buffer)
except WakeWordTimeoutError:
_LOGGER.debug("Timeout during wake word detection")
raise
except Exception as src_error:
_LOGGER.exception("Unexpected error during wake-word-detection")
raise WakeWordDetectionError(
code="wake-stream-failed",
message="Unexpected error during wake-word-detection",
) from src_error
_LOGGER.debug("wake-word-detection result %s", result)
if result is None:
wake_word_output: dict[str, Any] = {}
else:
if result.queued_audio:
# Add audio that was pending at detection
for chunk_ts in result.queued_audio:
audio_buffer.append(chunk_ts[0])
wake_word_output = asdict(result)
# Remove non-JSON fields
wake_word_output.pop("queued_audio", None)
self.process_event(
PipelineEvent(
PipelineEventType.WAKE_WORD_END,
{"wake_word_output": wake_word_output},
)
)
return result
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech-to-text."""
# pipeline.stt_engine can't be None or this function is not called
@ -443,9 +606,21 @@ class PipelineRun:
)
try:
segmenter = VoiceCommandSegmenter()
async def segment_stream(
stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished."""
async for chunk in stream:
if not segmenter.process(chunk):
break
yield chunk
# Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream(
metadata, stream
metadata, segment_stream(stream)
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
@ -663,17 +838,45 @@ class PipelineInput:
async def execute(self) -> None:
"""Run pipeline."""
self.run.start()
current_stage = self.run.start_stage
current_stage: PipelineStage | None = self.run.start_stage
audio_buffer: list[bytes] = []
try:
if current_stage == PipelineStage.WAKE_WORD:
assert self.stt_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
self.run.end()
return
current_stage = PipelineStage.STT
# speech-to-text
intent_input = self.intent_input
if current_stage == PipelineStage.STT:
assert self.stt_metadata is not None
assert self.stt_stream is not None
if audio_buffer:
async def buffered_stream() -> AsyncGenerator[bytes, None]:
for chunk in audio_buffer:
yield chunk
assert self.stt_stream is not None
async for chunk in self.stt_stream:
yield chunk
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
else:
stt_stream = self.stt_stream
intent_input = await self.run.speech_to_text(
self.stt_metadata,
self.stt_stream,
stt_stream,
)
current_stage = PipelineStage.INTENT
@ -707,7 +910,7 @@ class PipelineInput:
async def validate(self) -> None:
"""Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT:
if self.run.start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
if self.run.pipeline.stt_engine is None:
raise PipelineRunValidationError(
"the pipeline does not support speech-to-text"
@ -741,6 +944,13 @@ class PipelineInput:
prepare_tasks = []
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.WAKE_WORD)
<= end_stage_index
):
prepare_tasks.append(self.run.prepare_wake_word_detection())
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.STT)

View file

@ -88,7 +88,7 @@ class VoiceCommandSegmenter:
self.in_command = False
def process(self, samples: bytes) -> bool:
"""Process a 16-bit 16Khz mono audio samples.
"""Process 16-bit 16Khz mono audio samples.
Returns False when command is done.
"""
@ -148,3 +148,94 @@ class VoiceCommandSegmenter:
self._silence_seconds_left = self.silence_seconds
return True
@dataclass
class VoiceActivityTimeout:
"""Detects silence in audio until a timeout is reached."""
silence_seconds: float
"""Seconds of silence before timeout."""
reset_seconds: float = 0.5
"""Seconds of speech before resetting timeout."""
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_frames: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
_silence_seconds_left: float = 0.0
"""Seconds left before considering voice command as stopped."""
_reset_seconds_left: float = 0.0
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_audio_buffer: bytes = field(default_factory=bytes)
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._silence_seconds_left = self.silence_seconds
self._reset_seconds_left = self.reset_seconds
def process(self, samples: bytes) -> bool:
"""Process 16-bit 16Khz mono audio samples.
Returns False when timeout is reached.
"""
self._audio_buffer += samples
# Process in 10, 20, or 30 ms chunks.
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
for chunk_idx in range(num_chunks):
chunk_offset = chunk_idx * self._bytes_per_chunk
chunk = self._audio_buffer[
chunk_offset : chunk_offset + self._bytes_per_chunk
]
if not self._process_chunk(chunk):
return False
if num_chunks > 0:
# Remove from buffer
self._audio_buffer = self._audio_buffer[
num_chunks * self._bytes_per_chunk :
]
return True
def _process_chunk(self, chunk: bytes) -> bool:
"""Process a single chunk of 16-bit 16Khz mono audio.
Returns False when timeout is reached.
"""
if self._vad.is_speech(chunk, _SAMPLE_RATE):
# Speech
self._reset_seconds_left -= self._seconds_per_chunk
if self._reset_seconds_left <= 0:
# Reset timeout
self._silence_seconds_left = self.silence_seconds
else:
# Silence
self._silence_seconds_left -= self._seconds_per_chunk
if self._silence_seconds_left <= 0:
# Timeout reached
return False
# Slowly build reset counter back up
self._reset_seconds_left = min(
self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
)
return True
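
VoiceActivityTimeout is driven chunk by chunk: process() returns True while it is still waiting and False once the configured stretch of silence has elapsed. A short sketch, assuming webrtcvad classifies all-zero PCM as non-speech:

from homeassistant.components.assist_pipeline.vad import VoiceActivityTimeout

def times_out(chunks: list[bytes]) -> bool:
    """Return True if the silence timeout fires before the chunks run out."""
    vad_timeout = VoiceActivityTimeout(silence_seconds=3.0)
    for chunk in chunks:  # 16-bit 16Khz mono PCM
        if not vad_timeout.process(chunk):
            return True  # timeout reached
    return False

# 134 chunks * 30 ms of silence (960 bytes each) is about 4 s,
# which trips a 3-second timeout:
assert times_out([bytes(960)] * 134)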

View file

@ -26,11 +26,12 @@ from .pipeline import (
PipelineInput,
PipelineRun,
PipelineStage,
WakeWordSettings,
async_get_pipeline,
)
from .vad import VoiceCommandSegmenter
DEFAULT_TIMEOUT = 30
DEFAULT_WAKE_WORD_TIMEOUT = 3
_LOGGER = logging.getLogger(__name__)
@ -63,6 +64,18 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
cv.key_value_schemas(
"start_stage",
{
PipelineStage.WAKE_WORD: vol.Schema(
{
vol.Required("input"): {
vol.Required("sample_rate"): int,
vol.Optional("timeout"): vol.Any(float, int),
vol.Optional("audio_seconds_to_buffer"): vol.Any(
float, int
),
}
},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
extra=vol.ALLOW_EXTRA,
@ -102,6 +115,7 @@ async def websocket_run(
end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None
unregister_handler: Callable[[], None] | None = None
wake_word_settings: WakeWordSettings | None = None
# Arguments to PipelineInput
input_args: dict[str, Any] = {
@ -109,24 +123,26 @@ async def websocket_run(
"device_id": msg.get("device_id"),
}
if start_stage == PipelineStage.STT:
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
# Audio pipeline that will receive audio as binary websocket messages
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"]
if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
)
async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None
segmenter = VoiceCommandSegmenter()
# Yield until we receive an empty chunk
while chunk := await audio_queue.get():
chunk, state = audioop.ratecv(
chunk, 2, 1, incoming_sample_rate, 16000, state
)
if not segmenter.process(chunk):
# Voice command is finished
break
if incoming_sample_rate != 16000:
chunk, state = audioop.ratecv(
chunk, 2, 1, incoming_sample_rate, 16000, state
)
yield chunk
def handle_binary(
@ -169,6 +185,7 @@ async def websocket_run(
"stt_binary_handler_id": handler_id,
"timeout": timeout,
},
wake_word_settings=wake_word_settings,
)
pipeline_input = PipelineInput(**input_args)
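
Putting the schema and settings together, a websocket client starts a wake word pipeline like this (values are examples; the binary handler id comes back in runner_data as stt_binary_handler_id):

await client.send_json_auto_id(
    {
        "type": "assist_pipeline/run",
        "start_stage": "wake_word",
        "end_stage": "tts",
        "input": {
            "sample_rate": 16000,
            # Optional; seconds of silence before giving up (default 3)
            "timeout": 3,
            # Optional; audio kept from before detection and replayed to STT
            "audio_seconds_to_buffer": 0.5,
        },
    }
)
# Audio then follows as binary messages prefixed with the handler id byte;
# an empty payload after the prefix ends the stream.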

View file

@ -0,0 +1,119 @@
"""Provide functionality to wake word."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterable
import logging
from typing import final
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
from .const import DOMAIN
from .models import DetectionResult, WakeWord
__all__ = [
"async_default_engine",
"async_get_wake_word_detection_entity",
"DetectionResult",
"DOMAIN",
"WakeWord",
"WakeWordDetectionEntity",
]
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
return next(iter(hass.states.async_entity_ids(DOMAIN)), None)
@callback
def async_get_wake_word_detection_entity(
hass: HomeAssistant, entity_id: str
) -> WakeWordDetectionEntity | None:
"""Return wake word entity."""
component: EntityComponent = hass.data[DOMAIN]
return component.get_entity(entity_id)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
component.register_shutdown()
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
class WakeWordDetectionEntity(RestoreEntity):
"""Represent a single wake word provider."""
_attr_should_poll = False
__last_processed: str | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_processed is None:
return None
return self.__last_processed
@property
@abstractmethod
def supported_wake_words(self) -> list[WakeWord]:
"""Return a list of supported wake words."""
@abstractmethod
async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]]
) -> DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
"""
async def async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]]
) -> DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
"""
self.__last_processed = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self._async_process_audio_stream(stream)
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_processed = state.state
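
Providers only need to implement supported_wake_words and _async_process_audio_stream; the last-processed state tracking and restore behavior come from the base class. A minimal sketch of a custom entity, using a crude loudness check as a stand-in for a real detector:

from collections.abc import AsyncIterable

from homeassistant.components import wake_word


class ThresholdWakeWordEntity(wake_word.WakeWordDetectionEntity):
    """Illustrative provider that fires on the first loud chunk."""

    _attr_name = "threshold"

    @property
    def supported_wake_words(self) -> list[wake_word.WakeWord]:
        """Return a list of supported wake words."""
        return [wake_word.WakeWord(ww_id="loud", name="Loud Noise")]

    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> wake_word.DetectionResult | None:
        """Scan 16-bit 16Khz mono PCM chunks for a loud sample."""
        async for chunk, timestamp in stream:
            if max(chunk, default=0) > 200:  # stand-in for real detection
                return wake_word.DetectionResult(ww_id="loud", timestamp=timestamp)
        return None  # stream ended without detection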

View file

@ -0,0 +1,2 @@
"""Wake word constants."""
DOMAIN = "wake_word"

View file

@ -0,0 +1,8 @@
{
"domain": "wake_word",
"name": "Wake-word detection",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"documentation": "https://www.home-assistant.io/integrations/wake_word",
"integration_type": "entity",
"quality_scale": "internal"
}

View file

@ -0,0 +1,24 @@
"""Wake word models."""
from dataclasses import dataclass
@dataclass(frozen=True)
class WakeWord:
"""Wake word model."""
ww_id: str
name: str
@dataclass
class DetectionResult:
"""Result of wake word detection."""
ww_id: str
"""Id of detected wake word"""
timestamp: int | None
"""Timestamp of audio chunk with detected wake word"""
queued_audio: list[tuple[bytes, int]] | None = None
"""Audio chunks that were queued when wake word was detected."""

View file

@ -50,14 +50,21 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors={"base": "cannot_connect"},
)
# ASR = automated speech recognition (STT)
# ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in service.info.asr if asr.installed]
# TTS = text-to-speech
tts_installed = [tts for tts in service.info.tts if tts.installed]
# wake-word-detection
wake_installed = [wake for wake in service.info.wake if wake.installed]
if asr_installed:
name = asr_installed[0].name
elif tts_installed:
name = tts_installed[0].name
elif wake_installed:
name = wake_installed[0].name
else:
return self.async_abort(reason="no_services")

View file

@ -29,6 +29,8 @@ class WyomingService:
platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
platforms.append(Platform.WAKE_WORD)
self.platforms = platforms
@classmethod

View file

@ -0,0 +1,157 @@
"""Support for Wyoming wake-word-detection services."""
import asyncio
from collections.abc import AsyncIterable
import logging
from wyoming.audio import AudioChunk, AudioStart
from wyoming.client import AsyncTcpClient
from wyoming.wake import Detection
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Wyoming wake-word-detection."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities(
[
WyomingWakeWordProvider(config_entry, service),
]
)
class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
"""Wyoming wake-word-detection provider."""
def __init__(
self,
config_entry: ConfigEntry,
service: WyomingService,
) -> None:
"""Set up provider."""
self.service = service
wake_service = service.info.wake[0]
self._supported_wake_words = [
wake_word.WakeWord(ww_id=ww.name, name=ww.name)
for ww in wake_service.models
]
self._attr_name = wake_service.name
self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return self._supported_wake_words
async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]]
) -> wake_word.DetectionResult | None:
"""Try to detect one or more wake words in an audio stream.
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
"""
async def next_chunk():
"""Get the next chunk from audio stream."""
async for chunk_bytes in stream:
return chunk_bytes
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
await client.write_event(
AudioStart(
rate=16000,
width=2,
channels=1,
).event(),
)
# Read audio and wake events in "parallel"
audio_task = asyncio.create_task(next_chunk())
wake_task = asyncio.create_task(client.read_event())
pending = {audio_task, wake_task}
try:
while True:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
if wake_task in done:
event = wake_task.result()
if event is None:
_LOGGER.debug("Connection lost")
break
if Detection.is_type(event.type):
# Successful detection
detection = Detection.from_event(event)
_LOGGER.info(detection)
# Retrieve queued audio
queued_audio: list[tuple[bytes, int]] | None = None
if audio_task in pending:
# Save queued audio
await audio_task
pending.remove(audio_task)
queued_audio = [audio_task.result()]
return wake_word.DetectionResult(
ww_id=detection.name,
timestamp=detection.timestamp,
queued_audio=queued_audio,
)
# Next event
wake_task = asyncio.create_task(client.read_event())
pending.add(wake_task)
if audio_task in done:
# Forward audio to wake service
chunk_info = audio_task.result()
if chunk_info is None:
break
chunk_bytes, chunk_timestamp = chunk_info
chunk = AudioChunk(
rate=16000,
width=2,
channels=1,
audio=chunk_bytes,
timestamp=chunk_timestamp,
)
await client.write_event(chunk.event())
# Next chunk
audio_task = asyncio.create_task(next_chunk())
pending.add(audio_task)
finally:
# Clean up
if audio_task in pending:
# It's critical that we don't cancel the audio task or
# leave it hanging. This would mess up the pipeline STT
# by stopping the audio stream.
await audio_task
pending.remove(audio_task)
for task in pending:
task.cancel()
except (OSError, WyomingError) as err:
_LOGGER.exception("Error processing audio stream: %s", err)
return None
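
The provider above multiplexes two awaitables, the next audio chunk and the next Wyoming event, re-creating whichever task completes first. Note that on exit it deliberately awaits (rather than cancels) an in-flight audio read so the shared pipeline stream is never left mid-chunk. A stripped-down version of the pattern with illustrative queue-based sources:

import asyncio


async def pump(audio_queue: asyncio.Queue, event_queue: asyncio.Queue):
    """Forward audio until an event arrives; return that event (or None)."""
    audio_task = asyncio.create_task(audio_queue.get())
    event_task = asyncio.create_task(event_queue.get())
    pending = {audio_task, event_task}
    try:
        while True:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            if event_task in done:
                return event_task.result()
            if audio_task in done:
                chunk = audio_task.result()
                if chunk is None:  # end of stream
                    return None
                # ... write chunk to the service here ...
                audio_task = asyncio.create_task(audio_queue.get())
                pending.add(audio_task)
    finally:
        # Simplified: the real code awaits a pending audio task instead
        for task in pending:
            task.cancel()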

View file

@ -57,6 +57,7 @@ class Platform(StrEnum):
TTS = "tts"
VACUUM = "vacuum"
UPDATE = "update"
WAKE_WORD = "wake_word"
WATER_HEATER = "water_heater"
WEATHER = "weather"

View file

@ -7,7 +7,7 @@ from unittest.mock import AsyncMock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components import stt, tts, wake_word
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
@ -174,6 +174,40 @@ class MockSttPlatform(MockPlatform):
self.async_get_engine = async_get_engine
class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
"""Mock wake word entity."""
fail_process_audio = False
url_path = "wake_word.test"
_attr_name = "test"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]]
) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps."""
async for chunk, timestamp in stream:
if chunk == b"wake word":
return wake_word.DetectionResult(
ww_id=self.supported_wake_words[0].ww_id,
timestamp=timestamp,
queued_audio=[(b"queued audio", 0)],
)
# Not detected
return None
@pytest.fixture
async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity:
"""Mock wake word provider."""
return MockWakeWordEntity()
class MockFlow(ConfigFlow):
"""Test flow."""
@ -193,6 +227,7 @@ async def init_supporting_components(
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
config_flow_fixture,
):
"""Initialize relevant components with empty configs."""
@ -201,14 +236,18 @@ async def init_supporting_components(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
await hass.config_entries.async_forward_entry_setups(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
await hass.config_entries.async_unload_platforms(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
)
return True
async def async_setup_entry_stt_platform(
@ -219,6 +258,14 @@ async def init_supporting_components(
"""Set up test stt platform via config entry."""
async_add_entities([mock_stt_provider_entity])
async def async_setup_entry_wake_word_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test wake word platform via config entry."""
async_add_entities([mock_wake_word_provider_entity])
mock_integration(
hass,
MockModule(
@ -242,11 +289,19 @@ async def init_supporting_components(
async_setup_entry=async_setup_entry_stt_platform,
),
)
mock_platform(
hass,
"test.wake_word",
MockPlatform(
async_setup_entry=async_setup_entry_wake_word_platform,
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
# assert await async_setup_component(hass, wake_word.DOMAIN, {"wake_word": {}})
assert await async_setup_component(hass, "media_source", {})
config_entry = MockConfigEntry(domain="test")

View file

@ -266,3 +266,114 @@
}),
])
# ---
# name: test_pipeline_from_audio_stream_wake_word
list([
dict({
'data': dict({
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
dict({
'data': dict({
'wake_word_output': dict({
'timestamp': 2000,
'ww_id': 'test_ww',
}),
}),
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
}),
dict({
'data': dict({
'engine': 'test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'language': 'en-US',
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.STT_START: 'stt-start'>,
}),
dict({
'data': dict({
'stt_output': dict({
'text': 'test transcript',
}),
}),
'type': <PipelineEventType.STT_END: 'stt-end'>,
}),
dict({
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---

View file

@ -155,6 +155,243 @@
}),
})
# ---
# name: test_audio_pipeline_no_wake_word_engine
dict({
'code': 'wake-engine-missing',
'message': 'No wake word engine',
})
# ---
# name: test_audio_pipeline_no_wake_word_entity
dict({
'code': 'wake-provider-missing',
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
})
# ---
# name: test_audio_pipeline_with_wake_word
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.2
dict({
'wake_word_output': dict({
'queued_audio': None,
'timestamp': 1000,
'ww_id': 'test_ww',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.3
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.4
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.5
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_wake_word.6
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.7
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_wake_word.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2
dict({
'wake_word_output': dict({
'timestamp': 0,
'ww_id': 'test_ww',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.3
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.4
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.6
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.7
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.2
dict({
'code': 'wake-word-timeout',
'message': 'Wake word was not detected',
})
# ---
# name: test_intent_failed
dict({
'language': 'en',

View file

@ -1,5 +1,6 @@
"""Test Voice Assistant init."""
from dataclasses import asdict
import itertools as it
from unittest.mock import ANY
import pytest
@ -8,10 +9,12 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
from homeassistant.core import Context, HomeAssistant
from .conftest import MockSttProvider, MockSttProviderEntity
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
from tests.typing import WebSocketGenerator
BYTES_ONE_SECOND = 16000 * 2
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
@ -280,3 +283,61 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
)
assert not events
async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream with wake word."""
events = []
# [0, 1, ...]
wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND))
# [0, 2, ...]
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
async def audio_data():
yield wake_chunk_1 # 1 second
yield wake_chunk_2 # 1 second
yield b"wake word"
yield b"part1"
yield b"part2"
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
)
assert process_events(events) == snapshot
# 1. Half of wake_chunk_1 + all wake_chunk_2
# 2. queued audio (from mock wake word entity)
# 3. part1
# 4. part2
assert len(mock_stt_provider.received) == 4
first_chunk = mock_stt_provider.received[0]
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]

View file

@ -167,6 +167,224 @@ async def test_audio_pipeline(
assert msg["result"] == {"events": events}
async def test_audio_pipeline_with_wake_word_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test timeout from a pipeline run with audio input/output + wake word."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"timeout": 1,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# 2 seconds of silence
await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2))
# Time out error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
async def test_audio_pipeline_with_wake_word_no_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output + wake word with no timeout."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"timeout": 0,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# "audio"
await client.send_bytes(bytes([1]) + b"wake word")
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text-to-speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] is None
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_audio_pipeline_no_wake_word_engine(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test timeout from a pipeline run with audio input/output + wake word."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.wake_word.async_default_engine", return_value=None
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
},
}
)
# error
msg = await client.receive_json()
assert not msg["success"]
assert "error" in msg
assert msg["error"] == snapshot
async def test_audio_pipeline_no_wake_word_entity(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test timeout from a pipeline run with audio input/output + wake word."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.wake_word.async_default_engine",
return_value="wake_word.bad-entity-id",
), patch(
"homeassistant.components.wake_word.async_get_wake_word_detection_entity",
return_value=None,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
},
}
)
# error
msg = await client.receive_json()
assert not msg["success"]
assert "error" in msg
assert msg["error"] == snapshot
async def test_intent_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,

View file

@ -0,0 +1 @@
"""Wake-word-detection tests."""

View file

@ -0,0 +1,29 @@
"""Provide common test tools for wake-word-detection."""
from __future__ import annotations
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from tests.common import MockPlatform, mock_platform
def mock_wake_word_entity_platform(
hass: HomeAssistant,
tmp_path: Path,
integration: str,
async_setup_entry: Callable[
[HomeAssistant, ConfigEntry, AddEntitiesCallback],
Coroutine[Any, Any, None],
]
| None = None,
) -> MockPlatform:
"""Specialize the mock platform for stt."""
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
mock_platform(hass, f"{integration}.{wake_word.DOMAIN}", loaded_platform)
return loaded_platform

View file

@ -0,0 +1,11 @@
# serializer version: 1
# name: test_ws_detect
dict({
'event': dict({
'timestamp': 2048.0,
'ww_id': 'test_ww',
}),
'id': 1,
'type': 'event',
})
# ---

View file

@ -0,0 +1,226 @@
"""Test wake_word component setup."""
from collections.abc import AsyncIterable, Generator
from pathlib import Path
import pytest
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from .common import mock_wake_word_entity_platform
from tests.common import (
MockConfigEntry,
MockModule,
mock_config_flow,
mock_integration,
mock_platform,
mock_restore_cache,
)
TEST_DOMAIN = "test"
_SAMPLES_PER_CHUNK = 1024
_BYTES_PER_CHUNK = _SAMPLES_PER_CHUNK * 2 # 16-bit
_MS_PER_CHUNK = (_BYTES_PER_CHUNK // 2) // 16 # 16Khz
class MockProviderEntity(wake_word.WakeWordDetectionEntity):
"""Mock provider entity."""
url_path = "wake_word.test"
_attr_name = "test"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]]
) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps."""
async for _chunk, timestamp in stream:
if timestamp >= 2000:
return wake_word.DetectionResult(
ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp
)
# Not detected
return None
@pytest.fixture
def mock_provider_entity() -> MockProviderEntity:
"""Test provider entity fixture."""
return MockProviderEntity()
class WakeWordFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
with mock_config_flow(TEST_DOMAIN, WakeWordFlow):
yield
@pytest.fixture(name="setup")
async def setup_fixture(
hass: HomeAssistant,
tmp_path: Path,
) -> MockProviderEntity:
"""Set up the test environment."""
provider = MockProviderEntity()
await mock_config_entry_setup(hass, tmp_path, provider)
return provider
async def mock_config_entry_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> MockConfigEntry:
"""Set up a test provider via config entry."""
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(
config_entry, wake_word.DOMAIN
)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(
config_entry, wake_word.DOMAIN
)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test stt platform via config entry."""
async_add_entities([mock_provider_entity])
mock_wake_word_entity_platform(
hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform
)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry
async def test_config_entry_unload(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test we can unload config entry."""
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert config_entry.state == ConfigEntryState.LOADED
await hass.config_entries.async_unload(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.NOT_LOADED
async def test_detected_entity(
hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity
) -> None:
"""Test successful detection through entity."""
async def three_second_stream():
timestamp = 0
while timestamp < 3000:
yield bytes(_BYTES_PER_CHUNK), timestamp
timestamp += _MS_PER_CHUNK
# Need 2 seconds to trigger
result = await setup.async_process_audio_stream(three_second_stream())
assert result == wake_word.DetectionResult("test_ww", 2048)
async def test_not_detected_entity(
hass: HomeAssistant, setup: MockProviderEntity
) -> None:
"""Test unsuccessful detection through entity."""
async def one_second_stream():
timestamp = 0
while timestamp < 1000:
yield bytes(_BYTES_PER_CHUNK), timestamp
timestamp += _MS_PER_CHUNK
# Need 2 seconds to trigger
result = await setup.async_process_audio_stream(one_second_stream())
assert result is None
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test async_default_engine."""
assert await async_setup_component(hass, wake_word.DOMAIN, {wake_word.DOMAIN: {}})
await hass.async_block_till_done()
assert wake_word.async_default_engine(hass) is None
async def test_default_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test async_default_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert wake_word.async_default_engine(hass) == f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
async def test_get_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test async_get_speech_to_text_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert (
wake_word.async_get_wake_word_detection_entity(hass, f"{wake_word.DOMAIN}.test")
is mock_provider_entity
)
async def test_restore_state(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
await hass.async_block_till_done()
assert config_entry.state == ConfigEntryState.LOADED
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp

View file

@ -1,4 +1,6 @@
"""Tests for the Wyoming integration."""
import asyncio
from wyoming.info import (
AsrModel,
AsrProgram,
@ -7,6 +9,8 @@ from wyoming.info import (
TtsProgram,
TtsVoice,
TtsVoiceSpeaker,
WakeModel,
WakeProgram,
)
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
@ -49,6 +53,25 @@ TTS_INFO = Info(
)
]
)
WAKE_WORD_INFO = Info(
wake=[
WakeProgram(
name="Test Wake Word",
description="Test Wake Word",
installed=True,
attribution=TEST_ATTR,
models=[
WakeModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
)
],
)
]
)
EMPTY_INFO = Info()
@ -68,6 +91,7 @@ class MockAsyncTcpClient:
async def read_event(self):
"""Receive."""
await asyncio.sleep(0) # force context switch
return self.responses.pop(0)
async def __aenter__(self):

View file

@ -8,7 +8,7 @@ from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import STT_INFO, TTS_INFO
from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO
from tests.common import MockConfigEntry
@ -52,6 +52,21 @@ def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
return entry
@pytest.fixture
def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Wake Word",
)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
"""Initialize Wyoming STT."""
@ -72,6 +87,18 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
await hass.config_entries.async_setup(tts_config_entry.entry_id)
@pytest.fixture
async def init_wyoming_wake_word(
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
):
"""Initialize Wyoming Wake Word."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=WAKE_WORD_INFO,
):
await hass.config_entries.async_setup(wake_word_config_entry.entry_id)
@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata."""

View file

@ -0,0 +1,13 @@
# serializer version: 1
# name: test_streaming_audio
dict({
'queued_audio': list([
tuple(
b'chunk',
1,
),
]),
'timestamp': 0,
'ww_id': 'Test Model',
})
# ---

View file

@ -0,0 +1,108 @@
"""Test stt."""
from __future__ import annotations
import asyncio
from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion
from wyoming.asr import Transcript
from wyoming.wake import Detection
from homeassistant.components import wake_word
from homeassistant.core import HomeAssistant
from . import MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
"""Test supported properties."""
state = hass.states.get("wake_word.test_wake_word")
assert state is not None
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
assert entity.supported_wake_words == [
wake_word.WakeWord(ww_id="Test Model", name="Test Model")
]
async def test_streaming_audio(
hass: HomeAssistant, init_wyoming_wake_word, snapshot: SnapshotAssertion
) -> None:
"""Test streaming audio."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk", 0
# Delay to force a pending audio chunk
await asyncio.sleep(0.05)
yield b"chunk", 1
client_events = [
Transcript("not a wake word event").event(),
Detection(name="Test Model", timestamp=0).event(),
]
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
MockAsyncTcpClient(client_events),
):
result = await entity.async_process_audio_stream(audio_stream())
assert result is not None
assert result == snapshot
async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test streaming audio and losing connection."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
# Delay to force a pending audio chunk
await asyncio.sleep(0.05)
yield b"chunk", 1
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
result = await entity.async_process_audio_stream(audio_stream())
assert result is None
async def test_streaming_audio_oserror(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test streaming audio and error raising."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk1", 1000
mock_client = MockAsyncTcpClient(
[Detection(name="Test Model", timestamp=1000).event()]
)
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
result = await entity.async_process_audio_stream(audio_stream())
assert result is None