Wake word cleanup (#98652)
* Make arguments for async_pipeline_from_audio_stream keyword-only after hass
* Use a bytearray ring buffer
* Move generator outside
* Move stt stream generator outside
* Clean up execute
* Refactor VAD to use bytearray
* More tests
* Refactor chunk_samples to be more correct and robust
* Change AudioBuffer to use append instead of setitem
* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
parent 49897341ba
commit 8768c39021
9 changed files with 458 additions and 163 deletions
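The headline change makes every argument of `async_pipeline_from_audio_stream` after `hass` keyword-only. A minimal calling sketch (not part of the diff; it mirrors the updated tests at the bottom, and `events` and `audio_stream` are illustrative placeholders):

```python
from homeassistant.components import assist_pipeline, stt
from homeassistant.core import Context, HomeAssistant


async def run(hass: HomeAssistant, audio_stream) -> None:
    """Sketch: everything after hass must now be passed by keyword."""
    events = []
    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        context=Context(),
        event_callback=events.append,
        stt_metadata=stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        stt_stream=audio_stream,
    )
```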
@@ -52,6 +52,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
 
 async def async_pipeline_from_audio_stream(
     hass: HomeAssistant,
+    *,
     context: Context,
     event_callback: PipelineEventCallback,
     stt_metadata: stt.SpeechMetadata,
@@ -49,6 +49,7 @@ from .error import (
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
+from .ring_buffer import RingBuffer
 from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
 
 _LOGGER = logging.getLogger(__name__)
@@ -425,7 +426,6 @@ class PipelineRun:
 
     async def prepare_wake_word_detection(self) -> None:
        """Prepare wake-word-detection."""
-        # Need to add to pipeline store
        engine = wake_word.async_default_engine(self.hass)
        if engine is None:
            raise WakeWordDetectionError(
@@ -448,7 +448,7 @@ class PipelineRun:
     async def wake_word_detection(
         self,
         stream: AsyncIterable[bytes],
-        audio_buffer: list[bytes],
+        audio_chunks_for_stt: list[bytes],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@@ -484,46 +484,29 @@
         # Use VAD to determine timeout
         wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
 
-        # Audio chunk buffer.
-        audio_bytes_to_buffer = int(
-            wake_word_settings.audio_seconds_to_buffer * 16000 * 2
+        # Audio chunk buffer. This audio will be forwarded to speech-to-text
+        # after wake-word-detection.
+        num_audio_bytes_to_buffer = int(
+            wake_word_settings.audio_seconds_to_buffer * 16000 * 2  # 16-bit @ 16Khz
         )
-        audio_ring_buffer = b""
-
-        async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
-            """Yield audio with timestamps (milliseconds since start of stream)."""
-            nonlocal audio_ring_buffer
-
-            timestamp_ms = 0
-            async for chunk in stream:
-                yield chunk, timestamp_ms
-                timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
-                # Keeping audio right before wake word detection allows the
-                # voice command to be spoken immediately after the wake word.
-                if audio_bytes_to_buffer > 0:
-                    audio_ring_buffer += chunk
-                    if len(audio_ring_buffer) > audio_bytes_to_buffer:
-                        # A proper ring buffer would be far more efficient
-                        audio_ring_buffer = audio_ring_buffer[
-                            len(audio_ring_buffer) - audio_bytes_to_buffer :
-                        ]
-
-                if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
-                    raise WakeWordTimeoutError(
-                        code="wake-word-timeout", message="Wake word was not detected"
-                    )
+        stt_audio_buffer: RingBuffer | None = None
+        if num_audio_bytes_to_buffer > 0:
+            stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
 
         try:
             # Detect wake word(s)
             result = await self.wake_word_provider.async_process_audio_stream(
-                timestamped_stream()
+                _wake_word_audio_stream(
+                    audio_stream=stream,
+                    stt_audio_buffer=stt_audio_buffer,
+                    wake_word_vad=wake_word_vad,
+                )
             )
 
-            if audio_ring_buffer:
+            if stt_audio_buffer is not None:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
-                audio_buffer.append(audio_ring_buffer)
+                audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
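The `* 16000 * 2` factor above converts seconds of audio into a byte count for 16-bit samples at 16 kHz. A small worked sketch of that arithmetic (illustrative, not part of the diff):

```python
SAMPLE_RATE = 16000  # samples per second (16 kHz)
SAMPLE_WIDTH = 2     # bytes per sample (16-bit)


def seconds_to_bytes(audio_seconds: float) -> int:
    """One second of 16-bit, 16 kHz mono audio is 32000 bytes."""
    return int(audio_seconds * SAMPLE_RATE * SAMPLE_WIDTH)


# WakeWordSettings(audio_seconds_to_buffer=1.5), as used in the tests below,
# therefore allocates a 48000-byte ring buffer.
assert seconds_to_bytes(1.5) == 48000
```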
@@ -540,9 +523,14 @@
             wake_word_output: dict[str, Any] = {}
         else:
             if result.queued_audio:
-                # Add audio that was pending at detection
+                # Add audio that was pending at detection.
+                #
+                # Because detection occurs *after* the wake word was actually
+                # spoken, we need to make sure pending audio is forwarded to
+                # speech-to-text so the user does not have to pause before
+                # speaking the voice command.
                 for chunk_ts in result.queued_audio:
-                    audio_buffer.append(chunk_ts[0])
+                    audio_chunks_for_stt.append(chunk_ts[0])
 
             wake_word_output = asdict(result)
 
@@ -608,41 +596,12 @@
         )
 
         try:
-            segmenter = VoiceCommandSegmenter()
-
-            async def segment_stream(
-                stream: AsyncIterable[bytes],
-            ) -> AsyncGenerator[bytes, None]:
-                """Stop stream when voice command is finished."""
-                sent_vad_start = False
-                timestamp_ms = 0
-                async for chunk in stream:
-                    if not segmenter.process(chunk):
-                        # Silence detected at the end of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_END,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        break
-
-                    if segmenter.in_command and (not sent_vad_start):
-                        # Speech detected at start of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_START,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        sent_vad_start = True
-
-                    yield chunk
-                    timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
             # Transcribe audio stream
             result = await self.stt_provider.async_process_audio_stream(
-                metadata, segment_stream(stream)
+                metadata,
+                self._speech_to_text_stream(
+                    audio_stream=stream, stt_vad=VoiceCommandSegmenter()
+                ),
             )
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during speech-to-text")
@@ -677,6 +636,42 @@
 
         return result.text
 
+    async def _speech_to_text_stream(
+        self,
+        audio_stream: AsyncIterable[bytes],
+        stt_vad: VoiceCommandSegmenter | None,
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+    ) -> AsyncGenerator[bytes, None]:
+        """Yield audio chunks until VAD detects silence or speech-to-text completes."""
+        ms_per_sample = sample_rate // 1000
+        sent_vad_start = False
+        timestamp_ms = 0
+        async for chunk in audio_stream:
+            if stt_vad is not None:
+                if not stt_vad.process(chunk):
+                    # Silence detected at the end of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_END,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    break
+
+                if stt_vad.in_command and (not sent_vad_start):
+                    # Speech detected at start of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_START,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    sent_vad_start = True
+
+            yield chunk
+            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
     async def prepare_recognize_intent(self) -> None:
         """Prepare recognizing an intent."""
         agent_info = conversation.async_get_agent_info(
@@ -861,13 +856,14 @@ class PipelineInput:
         """Run pipeline."""
         self.run.start()
         current_stage: PipelineStage | None = self.run.start_stage
-        audio_buffer: list[bytes] = []
+        stt_audio_buffer: list[bytes] = []
 
         try:
             if current_stage == PipelineStage.WAKE_WORD:
                 # wake-word-detection
                 assert self.stt_stream is not None
                 detect_result = await self.run.wake_word_detection(
-                    self.stt_stream, audio_buffer
+                    self.stt_stream, stt_audio_buffer
                 )
                 if detect_result is None:
+                    # No wake word. Abort the rest of the pipeline.
@@ -882,19 +878,22 @@
             assert self.stt_metadata is not None
             assert self.stt_stream is not None
 
-            if audio_buffer:
-                stt_stream = self.stt_stream
-
-                async def buffered_stream() -> AsyncGenerator[bytes, None]:
-                    for chunk in audio_buffer:
+            if stt_audio_buffer:
+                # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
+                # This is basically an async itertools.chain.
+                async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
+                    # Buffered audio
+                    for chunk in stt_audio_buffer:
                         yield chunk
 
+                    # Streamed audio
+                    assert self.stt_stream is not None
                     async for chunk in self.stt_stream:
                         yield chunk
 
-                stt_stream = cast(AsyncIterable[bytes], buffered_stream())
+                stt_stream = buffer_then_audio_stream()
             else:
                 stt_stream = self.stt_stream
 
             intent_input = await self.run.speech_to_text(
                 self.stt_metadata,
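The comment in the new code calls the helper "basically an async itertools.chain". A standalone sketch of that pattern (hypothetical names; not part of the diff):

```python
from collections.abc import AsyncGenerator, AsyncIterable


async def chain_buffer_then_stream(
    buffered: list[bytes], stream: AsyncIterable[bytes]
) -> AsyncGenerator[bytes, None]:
    """Yield already-buffered chunks first, then pass through the live stream."""
    for chunk in buffered:  # audio captured before and at wake word detection
        yield chunk
    async for chunk in stream:  # then the ongoing microphone stream
        yield chunk
```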
@@ -906,6 +905,7 @@
             tts_input = self.tts_input
 
             if current_stage == PipelineStage.INTENT:
+                # intent-recognition
                 assert intent_input is not None
                 tts_input = await self.run.recognize_intent(
                     intent_input,
@@ -915,6 +915,7 @@
                 current_stage = PipelineStage.TTS
 
             if self.run.end_stage != PipelineStage.INTENT:
+                # text-to-speech
                 if current_stage == PipelineStage.TTS:
                     assert tts_input is not None
                     await self.run.text_to_speech(tts_input)
@@ -999,6 +1000,36 @@
         await asyncio.gather(*prepare_tasks)
 
 
+async def _wake_word_audio_stream(
+    audio_stream: AsyncIterable[bytes],
+    stt_audio_buffer: RingBuffer | None,
+    wake_word_vad: VoiceActivityTimeout | None,
+    sample_rate: int = 16000,
+    sample_width: int = 2,
+) -> AsyncIterable[tuple[bytes, int]]:
+    """Yield audio chunks with timestamps (milliseconds since start of stream).
+
+    Adds audio to a ring buffer that will be forwarded to speech-to-text after
+    detection. Times out if VAD detects enough silence.
+    """
+    ms_per_sample = sample_rate // 1000
+    timestamp_ms = 0
+    async for chunk in audio_stream:
+        yield chunk, timestamp_ms
+        timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
+        # Wake-word-detection occurs *after* the wake word was actually
+        # spoken. Keeping audio right before detection allows the voice
+        # command to be spoken immediately after the wake word.
+        if stt_audio_buffer is not None:
+            stt_audio_buffer.put(chunk)
+
+        if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
+            raise WakeWordTimeoutError(
+                code="wake-word-timeout", message="Wake word was not detected"
+            )
+
+
 class PipelinePreferred(CollectionError):
     """Raised when attempting to delete the preferred pipeline."""
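Both `_wake_word_audio_stream` and `_speech_to_text_stream` advance their clocks with `(len(chunk) // sample_width) // ms_per_sample`. A quick sanity check of that arithmetic (illustrative, not part of the diff):

```python
sample_rate = 16000  # Hz
sample_width = 2     # bytes per 16-bit sample
ms_per_sample = sample_rate // 1000  # 16 samples per millisecond

chunk = bytes(960)  # 480 samples, i.e. one 30 ms VAD-sized chunk
assert (len(chunk) // sample_width) // ms_per_sample == 30  # milliseconds
```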
homeassistant/components/assist_pipeline/ring_buffer.py (new file, 57 lines)

@@ -0,0 +1,57 @@
+"""Implementation of a ring buffer using bytearray."""
+
+
+class RingBuffer:
+    """Basic ring buffer using a bytearray.
+
+    Not threadsafe.
+    """
+
+    def __init__(self, maxlen: int) -> None:
+        """Initialize empty buffer."""
+        self._buffer = bytearray(maxlen)
+        self._pos = 0
+        self._length = 0
+        self._maxlen = maxlen
+
+    @property
+    def maxlen(self) -> int:
+        """Return the maximum size of the buffer."""
+        return self._maxlen
+
+    @property
+    def pos(self) -> int:
+        """Return the current put position."""
+        return self._pos
+
+    def __len__(self) -> int:
+        """Return the length of data stored in the buffer."""
+        return self._length
+
+    def put(self, data: bytes) -> None:
+        """Put a chunk of data into the buffer, possibly wrapping around."""
+        data_len = len(data)
+        new_pos = self._pos + data_len
+        if new_pos >= self._maxlen:
+            # Split into two chunks
+            num_bytes_1 = self._maxlen - self._pos
+            num_bytes_2 = new_pos - self._maxlen
+
+            self._buffer[self._pos : self._maxlen] = data[:num_bytes_1]
+            self._buffer[:num_bytes_2] = data[num_bytes_1:]
+            new_pos = new_pos - self._maxlen
+        else:
+            # Entire chunk fits at current position
+            self._buffer[self._pos : self._pos + data_len] = data
+
+        self._pos = new_pos
+        self._length = min(self._maxlen, self._length + data_len)
+
+    def getvalue(self) -> bytes:
+        """Get bytes written to the buffer."""
+        if (self._pos + self._length) <= self._maxlen:
+            # Single chunk
+            return bytes(self._buffer[: self._length])
+
+        # Two chunks
+        return bytes(self._buffer[self._pos :] + self._buffer[: self._pos])
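A short usage sketch of the new `RingBuffer` (mirroring `test_ring_buffer.py` added below; not part of the diff):

```python
from homeassistant.components.assist_pipeline.ring_buffer import RingBuffer

rb = RingBuffer(10)
rb.put(bytes([1, 2, 3, 4, 5]))
rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))  # wraps; the two oldest bytes fall off

assert len(rb) == 10
assert rb.pos == 2  # next write position after wrapping
assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
```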
@@ -1,12 +1,15 @@
 """Voice activity detection."""
 from __future__ import annotations
 
+from collections.abc import Iterable
 from dataclasses import dataclass, field
 from enum import StrEnum
 from typing import Final
 
 import webrtcvad
 
-_SAMPLE_RATE = 16000
+_SAMPLE_RATE: Final = 16000  # Hz
+_SAMPLE_WIDTH: Final = 2  # bytes
 
 
 class VadSensitivity(StrEnum):
@@ -29,6 +32,45 @@ class VadSensitivity(StrEnum):
         return 1.0
 
 
+class AudioBuffer:
+    """Fixed-sized audio buffer with variable internal length."""
+
+    def __init__(self, maxlen: int) -> None:
+        """Initialize buffer."""
+        self._buffer = bytearray(maxlen)
+        self._length = 0
+
+    @property
+    def length(self) -> int:
+        """Get number of bytes currently in the buffer."""
+        return self._length
+
+    def clear(self) -> None:
+        """Clear the buffer."""
+        self._length = 0
+
+    def append(self, data: bytes) -> None:
+        """Append bytes to the buffer, increasing the internal length."""
+        data_len = len(data)
+        if (self._length + data_len) > len(self._buffer):
+            raise ValueError("Length cannot be greater than buffer size")
+
+        self._buffer[self._length : self._length + data_len] = data
+        self._length += data_len
+
+    def bytes(self) -> bytes:
+        """Convert written portion of buffer to bytes."""
+        return bytes(self._buffer[: self._length])
+
+    def __len__(self) -> int:
+        """Get the number of bytes currently in the buffer."""
+        return self._length
+
+    def __bool__(self) -> bool:
+        """Return True if there are bytes in the buffer."""
+        return self._length > 0
+
+
 @dataclass
 class VoiceCommandSegmenter:
     """Segments an audio stream into voice commands using webrtcvad."""
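Unlike `RingBuffer`, `AudioBuffer` never wraps: it holds at most one chunk's worth of leftover bytes and raises if overfilled. A short sketch (illustrative, not part of the diff):

```python
from homeassistant.components.assist_pipeline.vad import AudioBuffer

buf = AudioBuffer(maxlen=4)
buf.append(bytes([1, 2]))
assert bool(buf) and len(buf) == 2

buf.append(bytes([3, 4]))
assert buf.bytes() == bytes([1, 2, 3, 4])

buf.clear()
assert len(buf) == 0  # appending more than maxlen bytes raises ValueError
```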
@@ -36,7 +78,7 @@ class VoiceCommandSegmenter:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
 
-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""
 
     speech_seconds: float = 0.3
@@ -67,20 +109,23 @@ class VoiceCommandSegmenter:
     """Seconds left before resetting start/stop time counters."""
 
     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)
 
     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
         self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._speech_seconds_left = self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
@@ -92,27 +137,20 @@ class VoiceCommandSegmenter:
 
         Returns False when command is done.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 self.reset()
                 return False
 
-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True
 
+    @property
+    def audio_buffer(self) -> bytes:
+        """Get partial chunk in the audio buffer."""
+        return self._leftover_chunk_buffer.bytes()
+
     def _process_chunk(self, chunk: bytes) -> bool:
         """Process a single chunk of 16-bit 16Khz mono audio.
@@ -163,7 +201,7 @@ class VoiceActivityTimeout:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
 
-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""
 
     _silence_seconds_left: float = 0.0
@@ -173,20 +211,23 @@ class VoiceActivityTimeout:
     """Seconds left before resetting start/stop time counters."""
 
     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)
 
     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
         self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._silence_seconds_left = self.silence_seconds
         self._reset_seconds_left = self.reset_seconds
@@ -195,24 +236,12 @@ class VoiceActivityTimeout:
 
         Returns False when timeout is reached.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 return False
 
-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True
 
     def _process_chunk(self, chunk: bytes) -> bool:
@@ -239,3 +268,37 @@ class VoiceActivityTimeout:
         )
 
         return True
+
+
+def chunk_samples(
+    samples: bytes,
+    bytes_per_chunk: int,
+    leftover_chunk_buffer: AudioBuffer,
+) -> Iterable[bytes]:
+    """Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s)."""
+
+    if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
+        # Extend leftover chunk, but not enough samples to complete it
+        leftover_chunk_buffer.append(samples)
+        return
+
+    next_chunk_idx = 0
+
+    if leftover_chunk_buffer:
+        # Add to leftover chunk from previous call(s).
+        bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
+        leftover_chunk_buffer.append(samples[:bytes_to_copy])
+        next_chunk_idx = bytes_to_copy
+
+        # Process full chunk in buffer
+        yield leftover_chunk_buffer.bytes()
+        leftover_chunk_buffer.clear()
+
+    while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
+        # Process full chunk
+        yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
+        next_chunk_idx += bytes_per_chunk
+
+    # Capture leftover chunks
+    if rest_samples := samples[next_chunk_idx:]:
+        leftover_chunk_buffer.append(rest_samples)
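To make the leftover handling concrete, a short walkthrough (illustrative; the same behavior is exercised by the new tests at the bottom of this diff):

```python
from homeassistant.components.assist_pipeline.vad import AudioBuffer, chunk_samples

leftover = AudioBuffer(maxlen=5)

# 6 bytes with 5-byte chunks: one full chunk out, 1 byte kept as leftover.
chunks = list(chunk_samples(bytes([1, 2, 3, 4, 5, 6]), 5, leftover))
assert chunks == [bytes([1, 2, 3, 4, 5])]
assert leftover.bytes() == bytes([6])

# The leftover byte is completed by the next call's samples.
chunks = list(chunk_samples(bytes([7, 8, 9, 10]), 5, leftover))
assert chunks == [bytes([6, 7, 8, 9, 10])]
assert leftover.bytes() == b""
```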
@@ -79,8 +79,6 @@ class WakeWordDetectionEntity(RestoreEntity):
     @final
     def state(self) -> str | None:
         """Return the state of the entity."""
-        if self.__last_detected is None:
-            return None
         return self.__last_detected
 
     @property
@@ -317,6 +317,12 @@
           }),
           'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
         }),
+        dict({
+          'data': dict({
+            'timestamp': 1500,
+          }),
+          'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
+        }),
         dict({
           'data': dict({
             'stt_output': dict({
@@ -1,7 +1,7 @@
 """Test Voice Assistant init."""
 from dataclasses import asdict
 import itertools as it
-from unittest.mock import ANY
+from unittest.mock import ANY, patch
 
 import pytest
 from syrupy.assertion import SnapshotAssertion
@@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
 
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
     )
 
     assert process_events(events) == snapshot
@@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
        ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
     with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id=pipeline_id,
         )
 
@@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
     with pytest.raises(assist_pipeline.PipelineNotFound):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id="blah",
         )
 
@@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
         yield b"wake word"
         yield b"part1"
         yield b"part2"
+        yield b"end"
         yield b""
 
-    await assist_pipeline.async_pipeline_from_audio_stream(
-        hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
-            language="",
-            format=stt.AudioFormats.WAV,
-            codec=stt.AudioCodecs.PCM,
-            bit_rate=stt.AudioBitRates.BITRATE_16,
-            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-            channel=stt.AudioChannels.CHANNEL_MONO,
-        ),
-        audio_data(),
-        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
-        wake_word_settings=assist_pipeline.WakeWordSettings(
-            audio_seconds_to_buffer=1.5
-        ),
-    )
+    def continue_stt(self, chunk):
+        # Ensure stt_vad_start event is triggered
+        self.in_command = True
+
+        # Stop on fake end chunk to trigger stt_vad_end
+        return chunk != b"end"
+
+    with patch(
+        "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
+        continue_stt,
+    ):
+        await assist_pipeline.async_pipeline_from_audio_stream(
+            hass,
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
+                language="",
+                format=stt.AudioFormats.WAV,
+                codec=stt.AudioCodecs.PCM,
+                bit_rate=stt.AudioBitRates.BITRATE_16,
+                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                channel=stt.AudioChannels.CHANNEL_MONO,
+            ),
+            stt_stream=audio_data(),
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+        )
 
     assert process_events(events) == snapshot
tests/components/assist_pipeline/test_ring_buffer.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+"""Tests for audio ring buffer."""
+from homeassistant.components.assist_pipeline.ring_buffer import RingBuffer
+
+
+def test_ring_buffer_empty() -> None:
+    """Test empty ring buffer."""
+    rb = RingBuffer(10)
+    assert rb.maxlen == 10
+    assert rb.pos == 0
+    assert rb.getvalue() == b""
+
+
+def test_ring_buffer_put_1() -> None:
+    """Test putting some data smaller than the maximum length."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5]))
+    assert len(rb) == 5
+    assert rb.pos == 5
+    assert rb.getvalue() == bytes([1, 2, 3, 4, 5])
+
+
+def test_ring_buffer_put_2() -> None:
+    """Test putting some data past the end of the buffer."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5]))
+    rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))
+    assert len(rb) == 10
+    assert rb.pos == 2
+    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
+
+
+def test_ring_buffer_put_too_large() -> None:
+    """Test putting data too large for the buffer."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]))
+    assert len(rb) == 10
+    assert rb.pos == 2
+    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
@@ -1,7 +1,12 @@
 """Tests for webrtcvad voice command segmenter."""
+import itertools as it
 from unittest.mock import patch
 
-from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
+from homeassistant.components.assist_pipeline.vad import (
+    AudioBuffer,
+    VoiceCommandSegmenter,
+    chunk_samples,
+)
 
 _ONE_SECOND = 16000 * 2  # 16Khz 16-bit
@@ -36,3 +41,87 @@ def test_speech() -> None:
     # silence
     # False return value indicates voice command is finished
     assert not segmenter.process(bytes(_ONE_SECOND))
+
+
+def test_audio_buffer() -> None:
+    """Test audio buffer wrapping."""
+
+    def is_speech(self, chunk, sample_rate):
+        """Disable VAD."""
+        return False
+
+    with patch(
+        "webrtcvad.Vad.is_speech",
+        new=is_speech,
+    ):
+        segmenter = VoiceCommandSegmenter()
+        bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
+
+        with patch.object(
+            segmenter, "_process_chunk", return_value=True
+        ) as mock_process:
+            # Partially fill audio buffer
+            half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
+            segmenter.process(half_chunk)
+
+            assert not mock_process.called
+            assert segmenter.audio_buffer == half_chunk
+
+            # Fill and wrap with 1/4 chunk left over
+            three_quarters_chunk = bytes(
+                it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
+            )
+            segmenter.process(three_quarters_chunk)
+
+            assert mock_process.call_count == 1
+            assert (
+                segmenter.audio_buffer
+                == three_quarters_chunk[
+                    len(three_quarters_chunk) - (bytes_per_chunk // 4) :
+                ]
+            )
+            assert (
+                mock_process.call_args[0][0]
+                == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
+            )
+
+            # Run 2 chunks through
+            segmenter.reset()
+            assert len(segmenter.audio_buffer) == 0
+
+            mock_process.reset_mock()
+            two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
+            segmenter.process(two_chunks)
+
+            assert mock_process.call_count == 2
+            assert len(segmenter.audio_buffer) == 0
+            assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
+            assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
+
+
+def test_partial_chunk() -> None:
+    """Test that chunk_samples returns when given a partial chunk."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 0
+    assert leftover_chunk_buffer.bytes() == samples
+
+
+def test_chunk_samples_leftover() -> None:
+    """Test that chunk_samples keeps leftover bytes across calls."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3, 4, 5, 6])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([6])
+
+    # Add some more to the chunk
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([5, 6])