Wake word cleanup (#98652)

* Make arguments for async_pipeline_from_audio_stream keyword-only after hass

* Use a bytearray ring buffer

* Move generator outside

* Move stt stream generator outside

* Clean up execute

* Refactor VAD to use bytearray

* More tests

* Refactor chunk_samples to be more correct and robust

* Change AudioBuffer to use append instead of setitem

* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Michael Hansen 2023-08-25 12:28:48 -05:00, committed by GitHub
parent 49897341ba
commit 8768c39021
9 changed files with 458 additions and 163 deletions


@@ -52,6 +52,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
*,
context: Context,
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
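With the * marker after hass, every later parameter becomes keyword-only, so positional calls now raise TypeError. A minimal caller sketch (stt_metadata and audio_stream are placeholder names, not part of the diff; the keyword style matches the updated tests below):

    # New call style; the old positional form
    # async_pipeline_from_audio_stream(hass, Context(), events.append, ...)
    # no longer works.
    await async_pipeline_from_audio_stream(
        hass,
        context=Context(),
        event_callback=events.append,
        stt_metadata=stt_metadata,   # stt.SpeechMetadata, prepared elsewhere
        stt_stream=audio_stream(),   # AsyncIterable[bytes]
    )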


@@ -49,6 +49,7 @@ from .error import (
WakeWordDetectionError,
WakeWordTimeoutError,
)
from .ring_buffer import RingBuffer
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
_LOGGER = logging.getLogger(__name__)
@@ -425,7 +426,6 @@ class PipelineRun:
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
# Need to add to pipeline store
engine = wake_word.async_default_engine(self.hass)
if engine is None:
raise WakeWordDetectionError(
@@ -448,7 +448,7 @@ class PipelineRun:
async def wake_word_detection(
self,
stream: AsyncIterable[bytes],
audio_buffer: list[bytes],
audio_chunks_for_stt: list[bytes],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
@@ -484,46 +484,29 @@ class PipelineRun:
# Use VAD to determine timeout
wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
# Audio chunk buffer.
audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2
# Audio chunk buffer. This audio will be forwarded to speech-to-text
# after wake-word-detection.
num_audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz
)
audio_ring_buffer = b""
async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio with timestamps (milliseconds since start of stream)."""
nonlocal audio_ring_buffer
timestamp_ms = 0
async for chunk in stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Keeping audio right before wake word detection allows the
# voice command to be spoken immediately after the wake word.
if audio_bytes_to_buffer > 0:
audio_ring_buffer += chunk
if len(audio_ring_buffer) > audio_bytes_to_buffer:
# A proper ring buffer would be far more efficient
audio_ring_buffer = audio_ring_buffer[
len(audio_ring_buffer) - audio_bytes_to_buffer :
]
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
stt_audio_buffer: RingBuffer | None = None
if num_audio_bytes_to_buffer > 0:
stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
try:
# Detect wake word(s)
result = await self.wake_word_provider.async_process_audio_stream(
timestamped_stream()
_wake_word_audio_stream(
audio_stream=stream,
stt_audio_buffer=stt_audio_buffer,
wake_word_vad=wake_word_vad,
)
)
if audio_ring_buffer:
if stt_audio_buffer is not None:
# All audio kept from right before the wake word was detected as
# a single chunk.
audio_buffer.append(audio_ring_buffer)
audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
except WakeWordTimeoutError:
_LOGGER.debug("Timeout during wake word detection")
raise
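For reference, the sizing above works out as follows (a worked example using the 1.5 s setting from the tests below):

    bytes_per_second = 16000 * 2                             # 16-bit mono at 16 kHz: 32000
    num_audio_bytes_to_buffer = int(1.5 * bytes_per_second)  # 48000
    # RingBuffer(48000) keeps the most recent 1.5 s of audio, forwarded to
    # speech-to-text as a single chunk once the wake word is detected.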
@@ -540,9 +523,14 @@ class PipelineRun:
wake_word_output: dict[str, Any] = {}
else:
if result.queued_audio:
# Add audio that was pending at detection
# Add audio that was pending at detection.
#
# Because detection occurs *after* the wake word was actually
# spoken, we need to make sure pending audio is forwarded to
# speech-to-text so the user does not have to pause before
# speaking the voice command.
for chunk_ts in result.queued_audio:
audio_buffer.append(chunk_ts[0])
audio_chunks_for_stt.append(chunk_ts[0])
wake_word_output = asdict(result)
@@ -608,41 +596,12 @@ class PipelineRun:
)
try:
segmenter = VoiceCommandSegmenter()
async def segment_stream(
stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished."""
sent_vad_start = False
timestamp_ms = 0
async for chunk in stream:
if not segmenter.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if segmenter.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream(
metadata, segment_stream(stream)
metadata,
self._speech_to_text_stream(
audio_stream=stream, stt_vad=VoiceCommandSegmenter()
),
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
@@ -677,6 +636,42 @@ class PipelineRun:
return result.text
async def _speech_to_text_stream(
self,
audio_stream: AsyncIterable[bytes],
stt_vad: VoiceCommandSegmenter | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
ms_per_sample = sample_rate // 1000
sent_vad_start = False
timestamp_ms = 0
async for chunk in audio_stream:
if stt_vad is not None:
if not stt_vad.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if stt_vad.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
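# Editor's sketch (not part of the diff): the timestamp arithmetic above,
# worked through for one 100 ms chunk of 16-bit audio at 16 kHz:
#   ms_per_sample = 16000 // 1000       -> 16 samples per millisecond
#   chunk = bytes(3200)                 -> 3200 // 2 = 1600 samples
#   timestamp_ms += 1600 // 16          -> advances by 100 ms per chunk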
async def prepare_recognize_intent(self) -> None:
"""Prepare recognizing an intent."""
agent_info = conversation.async_get_agent_info(
@@ -861,13 +856,14 @@ class PipelineInput:
"""Run pipeline."""
self.run.start()
current_stage: PipelineStage | None = self.run.start_stage
audio_buffer: list[bytes] = []
stt_audio_buffer: list[bytes] = []
try:
if current_stage == PipelineStage.WAKE_WORD:
# wake-word-detection
assert self.stt_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, audio_buffer
self.stt_stream, stt_audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
@@ -882,19 +878,22 @@ class PipelineInput:
assert self.stt_metadata is not None
assert self.stt_stream is not None
if audio_buffer:
stt_stream = self.stt_stream
async def buffered_stream() -> AsyncGenerator[bytes, None]:
for chunk in audio_buffer:
if stt_audio_buffer:
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
# This is basically an async itertools.chain.
async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
# Buffered audio
for chunk in stt_audio_buffer:
yield chunk
# Streamed audio
assert self.stt_stream is not None
async for chunk in self.stt_stream:
yield chunk
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
else:
stt_stream = self.stt_stream
stt_stream = buffer_then_audio_stream()
intent_input = await self.run.speech_to_text(
self.stt_metadata,
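The buffer_then_audio_stream generator above is the async analogue of itertools.chain. A standalone sketch of the same pattern, with hypothetical names:

    from collections.abc import AsyncGenerator, AsyncIterable

    async def chain_async(
        buffered: list[bytes], live: AsyncIterable[bytes]
    ) -> AsyncGenerator[bytes, None]:
        # Drain the buffered chunks first...
        for chunk in buffered:
            yield chunk
        # ...then continue with the live stream.
        async for chunk in live:
            yield chunk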
@@ -906,6 +905,7 @@ class PipelineInput:
tts_input = self.tts_input
if current_stage == PipelineStage.INTENT:
# intent-recognition
assert intent_input is not None
tts_input = await self.run.recognize_intent(
intent_input,
@@ -915,6 +915,7 @@
current_stage = PipelineStage.TTS
if self.run.end_stage != PipelineStage.INTENT:
# text-to-speech
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
@@ -999,6 +1000,36 @@ class PipelineInput:
await asyncio.gather(*prepare_tasks)
async def _wake_word_audio_stream(
audio_stream: AsyncIterable[bytes],
stt_audio_buffer: RingBuffer | None,
wake_word_vad: VoiceActivityTimeout | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio chunks with timestamps (milliseconds since start of stream).
Adds audio to a ring buffer that will be forwarded to speech-to-text after
detection. Times out if VAD detects enough silence.
"""
ms_per_sample = sample_rate // 1000
timestamp_ms = 0
async for chunk in audio_stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
# Wake-word-detection occurs *after* the wake word was actually
# spoken. Keeping audio right before detection allows the voice
# command to be spoken immediately after the wake word.
if stt_audio_buffer is not None:
stt_audio_buffer.put(chunk)
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
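# Editor's sketch (not part of the diff): consuming the generator above with
# a fake stream of 100 ms chunks (3200 bytes each at 16-bit/16 kHz) and no
# VAD timeout:
#
#     stt_buffer = RingBuffer(48000)  # 1.5 s at 16-bit/16 kHz
#
#     async def fake_stream():
#         for _ in range(5):
#             yield bytes(3200)
#
#     async for chunk, timestamp_ms in _wake_word_audio_stream(
#         audio_stream=fake_stream(),
#         stt_audio_buffer=stt_buffer,
#         wake_word_vad=None,
#     ):
#         pass  # timestamp_ms advances 0, 100, 200, ... while stt_buffer fills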
class PipelinePreferred(CollectionError):
"""Raised when attempting to delete the preferred pipelen."""


@@ -0,0 +1,57 @@
"""Implementation of a ring buffer using bytearray."""
class RingBuffer:
"""Basic ring buffer using a bytearray.
Not threadsafe.
"""
def __init__(self, maxlen: int) -> None:
"""Initialize empty buffer."""
self._buffer = bytearray(maxlen)
self._pos = 0
self._length = 0
self._maxlen = maxlen
@property
def maxlen(self) -> int:
"""Return the maximum size of the buffer."""
return self._maxlen
@property
def pos(self) -> int:
"""Return the current put position."""
return self._pos
def __len__(self) -> int:
"""Return the length of data stored in the buffer."""
return self._length
def put(self, data: bytes) -> None:
"""Put a chunk of data into the buffer, possibly wrapping around."""
data_len = len(data)
new_pos = self._pos + data_len
if new_pos >= self._maxlen:
# Split into two chunks
num_bytes_1 = self._maxlen - self._pos
num_bytes_2 = new_pos - self._maxlen
self._buffer[self._pos : self._maxlen] = data[:num_bytes_1]
self._buffer[:num_bytes_2] = data[num_bytes_1:]
new_pos = new_pos - self._maxlen
else:
# Entire chunk fits at current position
self._buffer[self._pos : self._pos + data_len] = data
self._pos = new_pos
self._length = min(self._maxlen, self._length + data_len)
def getvalue(self) -> bytes:
"""Get bytes written to the buffer."""
if (self._pos + self._length) <= self._maxlen:
# Single chunk
return bytes(self._buffer[: self._length])
# Two chunks
return bytes(self._buffer[self._pos :] + self._buffer[: self._pos])
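A short usage sketch of the wrap-around behavior, mirroring the unit tests added below:

    rb = RingBuffer(10)
    rb.put(bytes([1, 2, 3, 4, 5]))
    rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))     # 12 bytes into a 10-byte buffer: wraps
    assert len(rb) == 10                        # length is capped at maxlen
    assert rb.getvalue() == bytes(range(3, 13)) # oldest bytes (1, 2) were overwritten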


@@ -1,12 +1,15 @@
"""Voice activity detection."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Final
import webrtcvad
_SAMPLE_RATE = 16000
_SAMPLE_RATE: Final = 16000 # Hz
_SAMPLE_WIDTH: Final = 2 # bytes
class VadSensitivity(StrEnum):
@@ -29,6 +32,45 @@ class VadSensitivity(StrEnum):
return 1.0
class AudioBuffer:
"""Fixed-sized audio buffer with variable internal length."""
def __init__(self, maxlen: int) -> None:
"""Initialize buffer."""
self._buffer = bytearray(maxlen)
self._length = 0
@property
def length(self) -> int:
"""Get number of bytes currently in the buffer."""
return self._length
def clear(self) -> None:
"""Clear the buffer."""
self._length = 0
def append(self, data: bytes) -> None:
"""Append bytes to the buffer, increasing the internal length."""
data_len = len(data)
if (self._length + data_len) > len(self._buffer):
raise ValueError("Length cannot be greater than buffer size")
self._buffer[self._length : self._length + data_len] = data
self._length += data_len
def bytes(self) -> bytes:
"""Convert written portion of buffer to bytes."""
return bytes(self._buffer[: self._length])
def __len__(self) -> int:
"""Get the number of bytes currently in the buffer."""
return self._length
def __bool__(self) -> bool:
"""Return True if there are bytes in the buffer."""
return self._length > 0
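# Editor's sketch (not part of the diff): unlike RingBuffer, AudioBuffer is
# append-only up to maxlen and never wraps; overflowing raises ValueError.
#
#     buf = AudioBuffer(4)
#     buf.append(bytes([1, 2, 3]))
#     assert buf.bytes() == bytes([1, 2, 3]) and len(buf) == 3
#     buf.clear()           # length back to 0; storage is reused
#     buf.append(bytes(5))  # 5 bytes > maxlen of 4 -> raises ValueError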
@dataclass
class VoiceCommandSegmenter:
"""Segments an audio stream into voice commands using webrtcvad."""
@@ -36,7 +78,7 @@ class VoiceCommandSegmenter:
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_frames: int = 480 # 30 ms
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
speech_seconds: float = 0.3
@@ -67,20 +109,23 @@ class VoiceCommandSegmenter:
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_audio_buffer: bytes = field(default_factory=bytes)
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._leftover_chunk_buffer.clear()
self._speech_seconds_left = self.speech_seconds
self._silence_seconds_left = self.silence_seconds
self._timeout_seconds_left = self.timeout_seconds
@@ -92,27 +137,20 @@ class VoiceCommandSegmenter:
Returns False when command is done.
"""
self._audio_buffer += samples
# Process in 10, 20, or 30 ms chunks.
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
for chunk_idx in range(num_chunks):
chunk_offset = chunk_idx * self._bytes_per_chunk
chunk = self._audio_buffer[
chunk_offset : chunk_offset + self._bytes_per_chunk
]
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
self.reset()
return False
if num_chunks > 0:
# Remove from buffer
self._audio_buffer = self._audio_buffer[
num_chunks * self._bytes_per_chunk :
]
return True
@property
def audio_buffer(self) -> bytes:
"""Get partial chunk in the audio buffer."""
return self._leftover_chunk_buffer.bytes()
def _process_chunk(self, chunk: bytes) -> bool:
"""Process a single chunk of 16-bit 16Khz mono audio.
@@ -163,7 +201,7 @@ class VoiceActivityTimeout:
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_frames: int = 480 # 30 ms
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
_silence_seconds_left: float = 0.0
@@ -173,20 +211,23 @@ class VoiceActivityTimeout:
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_audio_buffer: bytes = field(default_factory=bytes)
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._leftover_chunk_buffer.clear()
self._silence_seconds_left = self.silence_seconds
self._reset_seconds_left = self.reset_seconds
@@ -195,24 +236,12 @@ class VoiceActivityTimeout:
Returns False when timeout is reached.
"""
self._audio_buffer += samples
# Process in 10, 20, or 30 ms chunks.
num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
for chunk_idx in range(num_chunks):
chunk_offset = chunk_idx * self._bytes_per_chunk
chunk = self._audio_buffer[
chunk_offset : chunk_offset + self._bytes_per_chunk
]
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
return False
if num_chunks > 0:
# Remove from buffer
self._audio_buffer = self._audio_buffer[
num_chunks * self._bytes_per_chunk :
]
return True
def _process_chunk(self, chunk: bytes) -> bool:
@@ -239,3 +268,37 @@ class VoiceActivityTimeout:
)
return True
def chunk_samples(
samples: bytes,
bytes_per_chunk: int,
leftover_chunk_buffer: AudioBuffer,
) -> Iterable[bytes]:
"""Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s)."""
if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
# Extend leftover chunk, but not enough samples to complete it
leftover_chunk_buffer.append(samples)
return
next_chunk_idx = 0
if leftover_chunk_buffer:
# Add to leftover chunk from previous call(s).
bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
leftover_chunk_buffer.append(samples[:bytes_to_copy])
next_chunk_idx = bytes_to_copy
# Process full chunk in buffer
yield leftover_chunk_buffer.bytes()
leftover_chunk_buffer.clear()
while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
# Process full chunk
yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
next_chunk_idx += bytes_per_chunk
# Capture leftover chunks
if rest_samples := samples[next_chunk_idx:]:
leftover_chunk_buffer.append(rest_samples)
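A worked example of the leftover handling, matching the behavior the new tests below exercise:

    leftover = AudioBuffer(5)
    # 6 bytes with a 5-byte chunk size: one full chunk, one leftover byte
    assert list(chunk_samples(bytes([1, 2, 3, 4, 5, 6]), 5, leftover)) == [bytes([1, 2, 3, 4, 5])]
    assert leftover.bytes() == bytes([6])
    # The next call completes the leftover chunk before emitting new ones
    assert list(chunk_samples(bytes([7, 8, 9, 10]), 5, leftover)) == [bytes([6, 7, 8, 9, 10])]
    assert leftover.bytes() == b""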


@@ -79,8 +79,6 @@ class WakeWordDetectionEntity(RestoreEntity):
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_detected is None:
return None
return self.__last_detected
@property


@@ -317,6 +317,12 @@
}),
'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
}),
dict({
'data': dict({
'timestamp': 1500,
}),
'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
}),
dict({
'data': dict({
'stt_output': dict({


@@ -1,7 +1,7 @@
"""Test Voice Assistant init."""
from dataclasses import asdict
import itertools as it
from unittest.mock import ANY
from unittest.mock import ANY, patch
import pytest
from syrupy.assertion import SnapshotAssertion
@@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
@@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
stt_stream=audio_data(),
)
assert process_events(events) == snapshot
@@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
# Use the created pipeline
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
@@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
)
@@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
# Use the created pipeline
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
@@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
)
@@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
@@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
)
@@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
with pytest.raises(assist_pipeline.PipelineNotFound):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
@@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
stt_stream=audio_data(),
pipeline_id="blah",
)
@@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
yield b"wake word"
yield b"part1"
yield b"part2"
yield b"end"
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
)
def continue_stt(self, chunk):
# Ensure stt_vad_start event is triggered
self.in_command = True
# Stop on fake end chunk to trigger stt_vad_end
return chunk != b"end"
with patch(
"homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
continue_stt,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
)
assert process_events(events) == snapshot


@@ -0,0 +1,38 @@
"""Tests for audio ring buffer."""
from homeassistant.components.assist_pipeline.ring_buffer import RingBuffer
def test_ring_buffer_empty() -> None:
"""Test empty ring buffer."""
rb = RingBuffer(10)
assert rb.maxlen == 10
assert rb.pos == 0
assert rb.getvalue() == b""
def test_ring_buffer_put_1() -> None:
"""Test putting some data smaller than the maximum length."""
rb = RingBuffer(10)
rb.put(bytes([1, 2, 3, 4, 5]))
assert len(rb) == 5
assert rb.pos == 5
assert rb.getvalue() == bytes([1, 2, 3, 4, 5])
def test_ring_buffer_put_2() -> None:
"""Test putting some data past the end of the buffer."""
rb = RingBuffer(10)
rb.put(bytes([1, 2, 3, 4, 5]))
rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))
assert len(rb) == 10
assert rb.pos == 2
assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
def test_ring_buffer_put_too_large() -> None:
"""Test putting data too large for the buffer."""
rb = RingBuffer(10)
rb.put(bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]))
assert len(rb) == 10
assert rb.pos == 2
assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])


@@ -1,7 +1,12 @@
"""Tests for webrtcvad voice command segmenter."""
import itertools as it
from unittest.mock import patch
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VoiceCommandSegmenter,
chunk_samples,
)
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@@ -36,3 +41,87 @@ def test_speech() -> None:
# silence
# False return value indicates voice command is finished
assert not segmenter.process(bytes(_ONE_SECOND))
def test_audio_buffer() -> None:
"""Test audio buffer wrapping."""
def is_speech(self, chunk, sample_rate):
"""Disable VAD."""
return False
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
):
segmenter = VoiceCommandSegmenter()
bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
with patch.object(
segmenter, "_process_chunk", return_value=True
) as mock_process:
# Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process(half_chunk)
assert not mock_process.called
assert segmenter.audio_buffer == half_chunk
# Fill and wrap with 1/4 chunk left over
three_quarters_chunk = bytes(
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
)
segmenter.process(three_quarters_chunk)
assert mock_process.call_count == 1
assert (
segmenter.audio_buffer
== three_quarters_chunk[
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
# Run 2 chunks through
segmenter.reset()
assert len(segmenter.audio_buffer) == 0
mock_process.reset_mock()
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
segmenter.process(two_chunks)
assert mock_process.call_count == 2
assert len(segmenter.audio_buffer) == 0
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
def test_partial_chunk() -> None:
"""Test that chunk_samples returns when given a partial chunk."""
bytes_per_chunk = 5
samples = bytes([1, 2, 3])
leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
assert len(chunks) == 0
assert leftover_chunk_buffer.bytes() == samples
def test_chunk_samples_leftover() -> None:
"""Test that chunk_samples property keeps left over bytes across calls."""
bytes_per_chunk = 5
samples = bytes([1, 2, 3, 4, 5, 6])
leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
assert len(chunks) == 1
assert leftover_chunk_buffer.bytes() == bytes([6])
# Add some more to the chunk
chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
assert len(chunks) == 1
assert leftover_chunk_buffer.bytes() == bytes([5, 6])