Use webrtc-noise-gain for audio enhancement in Assist pipelines (#100698)

* Use webrtc-noise-gain instead of the webrtcvad package

* Switch to ProcessedAudioChunk

* Refactor VAD and fix tests

* Add vad no chunking test

* Add test that runs audio enhancements
Author: Michael Hansen · 2023-09-25 19:03:50 -05:00
Commit: 785618909a (parent: a4f7f3ba7e)
15 changed files with 707 additions and 258 deletions
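For reference, a minimal sketch of how a caller opts into the new audio processing; it mirrors the API added in the diff below, and the surrounding harness (hass, context, events, audio_chunks) is assumed:

from homeassistant.components import assist_pipeline, stt

async def run_with_enhancements(hass, context, events, audio_chunks):
    # audio_chunks: AsyncIterable[bytes] of 16-bit 16 kHz mono PCM (assumed helper)
    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        context=context,
        event_callback=events.append,
        stt_metadata=stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        stt_stream=audio_chunks,
        audio_settings=assist_pipeline.AudioSettings(
            noise_suppression_level=2,  # 0 = off, 4 = max
            auto_gain_dbfs=15,          # 0 = off, 31 = max
            volume_multiplier=1.0,      # applied directly to PCM samples
        ),
    )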


@@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
 from .const import DATA_CONFIG, DOMAIN
 from .error import PipelineNotFound
 from .pipeline import (
+    AudioSettings,
     Pipeline,
     PipelineEvent,
     PipelineEventCallback,
@@ -33,6 +34,7 @@ __all__ = (
     "async_get_pipelines",
     "async_setup",
     "async_pipeline_from_audio_stream",
+    "AudioSettings",
     "Pipeline",
     "PipelineEvent",
     "PipelineEventType",
@@ -71,6 +73,7 @@ async def async_pipeline_from_audio_stream(
     conversation_id: str | None = None,
     tts_audio_output: str | None = None,
     wake_word_settings: WakeWordSettings | None = None,
+    audio_settings: AudioSettings | None = None,
     device_id: str | None = None,
     start_stage: PipelineStage = PipelineStage.STT,
     end_stage: PipelineStage = PipelineStage.TTS,
@@ -93,6 +96,7 @@ async def async_pipeline_from_audio_stream(
             event_callback=event_callback,
             tts_audio_output=tts_audio_output,
             wake_word_settings=wake_word_settings,
+            audio_settings=audio_settings or AudioSettings(),
         ),
     )
     await pipeline_input.validate()


@@ -6,5 +6,5 @@
   "documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
   "iot_class": "local_push",
   "quality_scale": "internal",
-  "requirements": ["webrtcvad==2.0.10"]
+  "requirements": ["webrtc-noise-gain==1.1.0"]
 }


@@ -1,7 +1,9 @@
 """Classes for voice assistant pipelines."""
 from __future__ import annotations
+import array
 import asyncio
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
 from dataclasses import asdict, dataclass, field
 from enum import StrEnum
@@ -10,10 +12,11 @@ from pathlib import Path
 from queue import Queue
 from threading import Thread
 import time
-from typing import Any, cast
+from typing import Any, Final, cast
 import wave
 import voluptuous as vol
+from webrtc_noise_gain import AudioProcessor
 from homeassistant.components import (
     conversation,
@@ -54,8 +57,7 @@ from .error import (
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
-from .ring_buffer import RingBuffer
-from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
+from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
 _LOGGER = logging.getLogger(__name__)
@@ -95,6 +97,9 @@ STORED_PIPELINE_RUNS = 10
 SAVE_DELAY = 10
+AUDIO_PROCESSOR_SAMPLES: Final = 160  # 10 ms @ 16 Khz
+AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2  # 16-bit samples
 async def _async_resolve_default_pipeline_settings(
     hass: HomeAssistant,
@@ -393,6 +398,60 @@ class WakeWordSettings:
     """Seconds of audio to buffer before detection and forward to STT."""
@dataclass(frozen=True)
class AudioSettings:
"""Settings for pipeline audio processing."""
noise_suppression_level: int = 0
"""Level of noise suppression (0 = disabled, 4 = max)"""
auto_gain_dbfs: int = 0
"""Amount of automatic gain in dbFS (0 = disabled, 31 = max)"""
volume_multiplier: float = 1.0
"""Multiplier used directly on PCM samples (1.0 = no change, 2.0 = twice as loud)"""
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
is_chunking_enabled: bool = True
"""True if audio is automatically split into 10 ms chunks (required for VAD, etc.)"""
def __post_init__(self) -> None:
"""Verify settings post-initialization."""
if (self.noise_suppression_level < 0) or (self.noise_suppression_level > 4):
raise ValueError("noise_suppression_level must be in [0, 4]")
if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31):
raise ValueError("auto_gain_dbfs must be in [0, 31]")
if self.needs_processor and (not self.is_chunking_enabled):
raise ValueError("Chunking must be enabled for audio processing")
@property
def needs_processor(self) -> bool:
"""True if an audio processor is needed."""
return (
self.is_vad_enabled
or (self.noise_suppression_level > 0)
or (self.auto_gain_dbfs > 0)
)
@dataclass(frozen=True, slots=True)
class ProcessedAudioChunk:
"""Processed audio chunk and metadata."""
audio: bytes
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
timestamp_ms: int
"""Timestamp relative to start of audio stream (milliseconds)"""
is_speech: bool | None
"""True if audio chunk likely contains speech, False if not, None if unknown"""
 @dataclass
 class PipelineRun:
     """Running context for a pipeline."""
@@ -408,6 +467,7 @@ class PipelineRun:
     intent_agent: str | None = None
     tts_audio_output: str | None = None
     wake_word_settings: WakeWordSettings | None = None
+    audio_settings: AudioSettings = field(default_factory=AudioSettings)
     id: str = field(default_factory=ulid_util.ulid)
     stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
@@ -422,6 +482,12 @@ class PipelineRun:
     debug_recording_queue: Queue[str | bytes | None] | None = None
     """Queue to communicate with debug recording thread"""
+    audio_processor: AudioProcessor | None = None
+    """VAD/noise suppression/auto gain"""
+
+    audio_processor_buffer: AudioBuffer = field(init=False)
+    """Buffer used when splitting audio into chunks for audio processing"""
     def __post_init__(self) -> None:
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
@@ -439,6 +505,14 @@ class PipelineRun:
         )
         pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
+        # Initialize with audio settings
+        self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
+        if self.audio_settings.needs_processor:
+            self.audio_processor = AudioProcessor(
+                self.audio_settings.auto_gain_dbfs,
+                self.audio_settings.noise_suppression_level,
+            )
     @callback
     def process_event(self, event: PipelineEvent) -> None:
         """Log an event and call listener."""
@@ -499,8 +573,8 @@ class PipelineRun:
     async def wake_word_detection(
         self,
-        stream: AsyncIterable[bytes],
-        audio_chunks_for_stt: list[bytes],
+        stream: AsyncIterable[ProcessedAudioChunk],
+        audio_chunks_for_stt: list[ProcessedAudioChunk],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@@ -541,12 +615,13 @@ class PipelineRun:
         # Audio chunk buffer. This audio will be forwarded to speech-to-text
         # after wake-word-detection.
-        num_audio_bytes_to_buffer = int(
-            wake_word_settings.audio_seconds_to_buffer * 16000 * 2  # 16-bit @ 16Khz
+        num_audio_chunks_to_buffer = int(
+            (wake_word_settings.audio_seconds_to_buffer * 16000)
+            / AUDIO_PROCESSOR_SAMPLES
         )
-        stt_audio_buffer: RingBuffer | None = None
-        if num_audio_bytes_to_buffer > 0:
-            stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
+        stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
+        if num_audio_chunks_to_buffer > 0:
+            stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
         try:
             # Detect wake word(s)
@@ -562,7 +637,7 @@ class PipelineRun:
             if stt_audio_buffer is not None:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
-                audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
+                audio_chunks_for_stt.extend(stt_audio_buffer)
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -586,7 +661,11 @@ class PipelineRun:
                 # speech-to-text so the user does not have to pause before
                 # speaking the voice command.
                 for chunk_ts in result.queued_audio:
-                    audio_chunks_for_stt.append(chunk_ts[0])
+                    audio_chunks_for_stt.append(
+                        ProcessedAudioChunk(
+                            audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
+                        )
+                    )
             wake_word_output = asdict(result)
@@ -604,8 +683,8 @@ class PipelineRun:
     async def _wake_word_audio_stream(
         self,
-        audio_stream: AsyncIterable[bytes],
-        stt_audio_buffer: RingBuffer | None,
+        audio_stream: AsyncIterable[ProcessedAudioChunk],
+        stt_audio_buffer: deque[ProcessedAudioChunk] | None,
         wake_word_vad: VoiceActivityTimeout | None,
         sample_rate: int = 16000,
         sample_width: int = 2,
@@ -615,25 +694,24 @@ class PipelineRun:
         Adds audio to a ring buffer that will be forwarded to speech-to-text after
        detection. Times out if VAD detects enough silence.
         """
-        ms_per_sample = sample_rate // 1000
-        timestamp_ms = 0
+        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         async for chunk in audio_stream:
             if self.debug_recording_queue is not None:
-                self.debug_recording_queue.put_nowait(chunk)
+                self.debug_recording_queue.put_nowait(chunk.audio)
-            yield chunk, timestamp_ms
-            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+            yield chunk.audio, chunk.timestamp_ms
             # Wake-word-detection occurs *after* the wake word was actually
             # spoken. Keeping audio right before detection allows the voice
             # command to be spoken immediately after the wake word.
             if stt_audio_buffer is not None:
-                stt_audio_buffer.put(chunk)
+                stt_audio_buffer.append(chunk)
-            if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
-                raise WakeWordTimeoutError(
-                    code="wake-word-timeout", message="Wake word was not detected"
-                )
+            if wake_word_vad is not None:
+                if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
+                    raise WakeWordTimeoutError(
+                        code="wake-word-timeout", message="Wake word was not detected"
+                    )
     async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
         """Prepare speech-to-text."""
@@ -666,7 +744,7 @@ class PipelineRun:
     async def speech_to_text(
         self,
         metadata: stt.SpeechMetadata,
-        stream: AsyncIterable[bytes],
+        stream: AsyncIterable[ProcessedAudioChunk],
     ) -> str:
         """Run speech-to-text portion of pipeline. Returns the spoken text."""
         if isinstance(self.stt_provider, stt.Provider):
@@ -690,11 +768,13 @@ class PipelineRun:
         try:
             # Transcribe audio stream
+            stt_vad: VoiceCommandSegmenter | None = None
+            if self.audio_settings.is_vad_enabled:
+                stt_vad = VoiceCommandSegmenter()
             result = await self.stt_provider.async_process_audio_stream(
                 metadata,
-                self._speech_to_text_stream(
-                    audio_stream=stream, stt_vad=VoiceCommandSegmenter()
-                ),
+                self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
             )
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during speech-to-text")
@@ -731,26 +811,25 @@ class PipelineRun:
     async def _speech_to_text_stream(
         self,
-        audio_stream: AsyncIterable[bytes],
+        audio_stream: AsyncIterable[ProcessedAudioChunk],
         stt_vad: VoiceCommandSegmenter | None,
         sample_rate: int = 16000,
         sample_width: int = 2,
     ) -> AsyncGenerator[bytes, None]:
         """Yield audio chunks until VAD detects silence or speech-to-text completes."""
-        ms_per_sample = sample_rate // 1000
+        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         sent_vad_start = False
-        timestamp_ms = 0
         async for chunk in audio_stream:
             if self.debug_recording_queue is not None:
-                self.debug_recording_queue.put_nowait(chunk)
+                self.debug_recording_queue.put_nowait(chunk.audio)
             if stt_vad is not None:
-                if not stt_vad.process(chunk):
+                if not stt_vad.process(chunk_seconds, chunk.is_speech):
                     # Silence detected at the end of voice command
                     self.process_event(
                         PipelineEvent(
                             PipelineEventType.STT_VAD_END,
-                            {"timestamp": timestamp_ms},
+                            {"timestamp": chunk.timestamp_ms},
                         )
                     )
                     break
@@ -760,13 +839,12 @@ class PipelineRun:
                     self.process_event(
                         PipelineEvent(
                             PipelineEventType.STT_VAD_START,
-                            {"timestamp": timestamp_ms},
+                            {"timestamp": chunk.timestamp_ms},
                         )
                     )
                     sent_vad_start = True
-            yield chunk
-            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+            yield chunk.audio
     async def prepare_recognize_intent(self) -> None:
         """Prepare recognizing an intent."""
@@ -977,6 +1055,94 @@ class PipelineRun:
         self.debug_recording_queue = None
         self.debug_recording_thread = None
async def process_volume_only(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
"""Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
timestamp_ms = 0
async for chunk in audio_stream:
if self.audio_settings.volume_multiplier != 1.0:
chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
if self.audio_settings.is_chunking_enabled:
# 10 ms chunking
for chunk_10ms in chunk_samples(
chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
):
yield ProcessedAudioChunk(
audio=chunk_10ms,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
)
timestamp_ms += ms_per_chunk
else:
# No chunking
yield ProcessedAudioChunk(
audio=chunk,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
)
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
async def process_enhance_audio(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
"""Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
assert self.audio_processor is not None
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
timestamp_ms = 0
async for dirty_samples in audio_stream:
if self.audio_settings.volume_multiplier != 1.0:
# Static gain
dirty_samples = _multiply_volume(
dirty_samples, self.audio_settings.volume_multiplier
)
# Split into 10ms chunks for audio enhancements/VAD
for dirty_10ms_chunk in chunk_samples(
dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
):
ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
yield ProcessedAudioChunk(
audio=ap_result.audio,
timestamp_ms=timestamp_ms,
is_speech=ap_result.is_speech,
)
timestamp_ms += ms_per_chunk
def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
"""Multiplies 16-bit PCM samples by a constant."""
return array.array(
"h",
[
int(
# Clamp to signed 16-bit range
max(
-32767,
min(
32767,
value * volume_multiplier,
),
)
)
for value in array.array("h", chunk)
],
).tobytes()
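For orientation (added, not part of the diff): the constants above fix the processing block size, and the volume multiplier clamps rather than wraps:

SAMPLE_RATE = 16000                                  # Hz, 16-bit mono PCM assumed
AUDIO_PROCESSOR_SAMPLES = 160                        # 10 ms: 16000 samples/s * 0.010 s
AUDIO_PROCESSOR_BYTES = AUDIO_PROCESSOR_SAMPLES * 2  # 320 bytes per processed chunk

# A loud sample pushed past the signed 16-bit range is clamped, not wrapped:
sample = 30000
assert int(max(-32767, min(32767, sample * 2.0))) == 32767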
 def _pipeline_debug_recording_thread_proc(
     run_recording_dir: Path,
@@ -1042,14 +1208,23 @@ class PipelineInput:
         """Run pipeline."""
         self.run.start(device_id=self.device_id)
         current_stage: PipelineStage | None = self.run.start_stage
-        stt_audio_buffer: list[bytes] = []
+        stt_audio_buffer: list[ProcessedAudioChunk] = []
+        stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
+
+        if self.stt_stream is not None:
+            if self.run.audio_settings.needs_processor:
+                # VAD/noise suppression/auto gain/volume
+                stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
+            else:
+                # Volume multiplier only
+                stt_processed_stream = self.run.process_volume_only(self.stt_stream)
         try:
             if current_stage == PipelineStage.WAKE_WORD:
                 # wake-word-detection
-                assert self.stt_stream is not None
+                assert stt_processed_stream is not None
                 detect_result = await self.run.wake_word_detection(
-                    self.stt_stream, stt_audio_buffer
+                    stt_processed_stream, stt_audio_buffer
                 )
                 if detect_result is None:
                     # No wake word. Abort the rest of the pipeline.
@@ -1062,28 +1237,30 @@ class PipelineInput:
             intent_input = self.intent_input
             if current_stage == PipelineStage.STT:
                 assert self.stt_metadata is not None
-                assert self.stt_stream is not None
+                assert stt_processed_stream is not None
-                stt_stream = self.stt_stream
+                stt_input_stream = stt_processed_stream
                 if stt_audio_buffer:
                     # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
                     # This is basically an async itertools.chain.
-                    async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
+                    async def buffer_then_audio_stream() -> AsyncGenerator[
+                        ProcessedAudioChunk, None
+                    ]:
                         # Buffered audio
                         for chunk in stt_audio_buffer:
                             yield chunk
                         # Streamed audio
-                        assert self.stt_stream is not None
-                        async for chunk in self.stt_stream:
+                        assert stt_processed_stream is not None
+                        async for chunk in stt_processed_stream:
                             yield chunk
-                    stt_stream = buffer_then_audio_stream()
+                    stt_input_stream = buffer_then_audio_stream()
                 intent_input = await self.run.speech_to_text(
                     self.stt_metadata,
-                    stt_stream,
+                    stt_input_stream,
                 )
                 current_stage = PipelineStage.INTENT


@@ -1,12 +1,13 @@
 """Voice activity detection."""
 from __future__ import annotations
+from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import StrEnum
-from typing import Final
+from typing import Final, cast
-import webrtcvad
+from webrtc_noise_gain import AudioProcessor
 _SAMPLE_RATE: Final = 16000  # Hz
 _SAMPLE_WIDTH: Final = 2  # bytes
@@ -32,6 +33,38 @@ class VadSensitivity(StrEnum):
         return 1.0
class VoiceActivityDetector(ABC):
"""Base class for voice activity detectors (VAD)."""
@abstractmethod
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
@property
@abstractmethod
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking is not required."""
class WebRtcVad(VoiceActivityDetector):
"""Voice activity detector based on webrtc."""
def __init__(self) -> None:
"""Initialize webrtcvad."""
# Just VAD: no noise suppression or auto gain
self._audio_processor = AudioProcessor(0, 0)
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
result = self._audio_processor.Process10ms(chunk)
return cast(bool, result.is_speech)
@property
def samples_per_chunk(self) -> int | None:
"""Return 10 ms."""
return int(0.01 * _SAMPLE_RATE) # 10 ms
 class AudioBuffer:
     """Fixed-sized audio buffer with variable internal length."""
@@ -73,13 +106,7 @@ class AudioBuffer:
 @dataclass
 class VoiceCommandSegmenter:
-    """Segments an audio stream into voice commands using webrtcvad."""
-
-    vad_mode: int = 3
-    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
-
-    vad_samples_per_chunk: int = 480  # 30 ms
-    """Must be 10, 20, or 30 ms at 16Khz."""
+    """Segments an audio stream into voice commands."""
     speech_seconds: float = 0.3
     """Seconds of speech before voice command has started."""
@@ -108,85 +135,85 @@ class VoiceCommandSegmenter:
     _reset_seconds_left: float = 0.0
     """Seconds left before resetting start/stop time counters."""
-    _vad: webrtcvad.Vad = None
-    _leftover_chunk_buffer: AudioBuffer = field(init=False)
-    _bytes_per_chunk: int = field(init=False)
-    _seconds_per_chunk: float = field(init=False)
     def __post_init__(self) -> None:
-        """Initialize VAD."""
-        self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
-        self._leftover_chunk_buffer = AudioBuffer(
-            self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        )
+        """Reset after initialization."""
         self.reset()
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._leftover_chunk_buffer.clear()
         self._speech_seconds_left = self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
         self._reset_seconds_left = self.reset_seconds
         self.in_command = False
-    def process(self, samples: bytes) -> bool:
-        """Process 16-bit 16Khz mono audio samples.
+    def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
+        """Process samples using external VAD.
         Returns False when command is done.
         """
-        for chunk in chunk_samples(
-            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
-        ):
-            if not self._process_chunk(chunk):
-                self.reset()
-                return False
-        return True
-
-    @property
-    def audio_buffer(self) -> bytes:
-        """Get partial chunk in the audio buffer."""
-        return self._leftover_chunk_buffer.bytes()
-
-    def _process_chunk(self, chunk: bytes) -> bool:
-        """Process a single chunk of 16-bit 16Khz mono audio.
-        Returns False when command is done.
-        """
-        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
-        self._timeout_seconds_left -= self._seconds_per_chunk
+        self._timeout_seconds_left -= chunk_seconds
         if self._timeout_seconds_left <= 0:
+            self.reset()
             return False
         if not self.in_command:
             if is_speech:
                 self._reset_seconds_left = self.reset_seconds
-                self._speech_seconds_left -= self._seconds_per_chunk
+                self._speech_seconds_left -= chunk_seconds
                 if self._speech_seconds_left <= 0:
                     # Inside voice command
                     self.in_command = True
             else:
                 # Reset if enough silence
-                self._reset_seconds_left -= self._seconds_per_chunk
+                self._reset_seconds_left -= chunk_seconds
                 if self._reset_seconds_left <= 0:
                     self._speech_seconds_left = self.speech_seconds
         elif not is_speech:
             self._reset_seconds_left = self.reset_seconds
-            self._silence_seconds_left -= self._seconds_per_chunk
+            self._silence_seconds_left -= chunk_seconds
             if self._silence_seconds_left <= 0:
+                self.reset()
                 return False
         else:
             # Reset if enough speech
-            self._reset_seconds_left -= self._seconds_per_chunk
+            self._reset_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 self._silence_seconds_left = self.silence_seconds
         return True
def process_with_vad(
self,
chunk: bytes,
vad: VoiceActivityDetector,
leftover_chunk_buffer: AudioBuffer | None,
) -> bool:
"""Process an audio chunk using an external VAD.
A buffer is required if the VAD requires fixed-sized audio chunks (usually the case).
Returns False when voice command is finished.
"""
if vad.samples_per_chunk is None:
# No chunking
chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
is_speech = vad.is_speech(chunk)
return self.process(chunk_seconds, is_speech)
if leftover_chunk_buffer is None:
raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
# With chunking
seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
is_speech = vad.is_speech(vad_chunk)
if not self.process(seconds_per_chunk, is_speech):
return False
return True
 @dataclass
 class VoiceActivityTimeout:
@@ -198,73 +225,43 @@ class VoiceActivityTimeout:
     reset_seconds: float = 0.5
     """Seconds of speech before resetting timeout."""
-    vad_mode: int = 3
-    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
-
-    vad_samples_per_chunk: int = 480  # 30 ms
-    """Must be 10, 20, or 30 ms at 16Khz."""
     _silence_seconds_left: float = 0.0
     """Seconds left before considering voice command as stopped."""
     _reset_seconds_left: float = 0.0
     """Seconds left before resetting start/stop time counters."""
-    _vad: webrtcvad.Vad = None
-    _leftover_chunk_buffer: AudioBuffer = field(init=False)
-    _bytes_per_chunk: int = field(init=False)
-    _seconds_per_chunk: float = field(init=False)
     def __post_init__(self) -> None:
-        """Initialize VAD."""
-        self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
-        self._leftover_chunk_buffer = AudioBuffer(
-            self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        )
+        """Reset after initialization."""
         self.reset()
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._leftover_chunk_buffer.clear()
         self._silence_seconds_left = self.silence_seconds
         self._reset_seconds_left = self.reset_seconds
-    def process(self, samples: bytes) -> bool:
-        """Process 16-bit 16Khz mono audio samples.
+    def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
+        """Process samples using external VAD.
         Returns False when timeout is reached.
         """
-        for chunk in chunk_samples(
-            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
-        ):
-            if not self._process_chunk(chunk):
-                return False
-        return True
-
-    def _process_chunk(self, chunk: bytes) -> bool:
-        """Process a single chunk of 16-bit 16Khz mono audio.
-        Returns False when timeout is reached.
-        """
-        if self._vad.is_speech(chunk, _SAMPLE_RATE):
+        if is_speech:
             # Speech
-            self._reset_seconds_left -= self._seconds_per_chunk
+            self._reset_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 # Reset timeout
                 self._silence_seconds_left = self.silence_seconds
         else:
             # Silence
-            self._silence_seconds_left -= self._seconds_per_chunk
+            self._silence_seconds_left -= chunk_seconds
            if self._silence_seconds_left <= 0:
                 # Timeout reached
+                self.reset()
                 return False
             # Slowly build reset counter back up
             self._reset_seconds_left = min(
-                self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
+                self.reset_seconds, self._reset_seconds_left + chunk_seconds
             )
         return True
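A small usage sketch of the reworked segmenter API (added for context; incoming_chunks, an iterable of raw 16-bit 16 kHz mono PCM, is assumed):

from homeassistant.components.assist_pipeline.vad import (
    AudioBuffer,
    VoiceCommandSegmenter,
    WebRtcVad,
)

vad = WebRtcVad()  # VAD only: no noise suppression or auto gain
segmenter = VoiceCommandSegmenter()
vad_buffer = AudioBuffer(vad.samples_per_chunk * 2)  # WebRtcVad needs fixed 10 ms chunks

for chunk in incoming_chunks:
    if not segmenter.process_with_vad(chunk, vad, vad_buffer):
        break  # enough trailing silence: the voice command is finished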


@@ -18,6 +18,7 @@ from homeassistant.util import language as language_util
 from .const import DOMAIN
 from .error import PipelineNotFound
 from .pipeline import (
+    AudioSettings,
     PipelineData,
     PipelineError,
     PipelineEvent,
@@ -71,6 +72,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
                     vol.Optional("audio_seconds_to_buffer"): vol.Any(
                         float, int
                     ),
+                    # Audio enhancement
+                    vol.Optional("noise_suppression_level"): int,
+                    vol.Optional("auto_gain_dbfs"): int,
+                    vol.Optional("volume_multiplier"): float,
+                    # Advanced use cases/testing
+                    vol.Optional("no_vad"): bool,
+                    vol.Optional("no_chunking"): bool,
                 }
             },
             extra=vol.ALLOW_EXTRA,
@@ -115,6 +123,7 @@ async def websocket_run(
     handler_id: int | None = None
     unregister_handler: Callable[[], None] | None = None
     wake_word_settings: WakeWordSettings | None = None
+    audio_settings: AudioSettings | None = None
     # Arguments to PipelineInput
     input_args: dict[str, Any] = {
@@ -124,13 +133,14 @@ async def websocket_run(
     if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
         # Audio pipeline that will receive audio as binary websocket messages
+        msg_input = msg["input"]
         audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
-        incoming_sample_rate = msg["input"]["sample_rate"]
+        incoming_sample_rate = msg_input["sample_rate"]
         if start_stage == PipelineStage.WAKE_WORD:
             wake_word_settings = WakeWordSettings(
                 timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
-                audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
+                audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
             )
         async def stt_stream() -> AsyncGenerator[bytes, None]:
@@ -166,6 +176,15 @@ async def websocket_run(
             channel=stt.AudioChannels.CHANNEL_MONO,
         )
         input_args["stt_stream"] = stt_stream()
# Audio settings
audio_settings = AudioSettings(
noise_suppression_level=msg_input.get("noise_suppression_level", 0),
auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
volume_multiplier=msg_input.get("volume_multiplier", 1.0),
is_vad_enabled=not msg_input.get("no_vad", False),
is_chunking_enabled=not msg_input.get("no_chunking", False),
)
     elif start_stage == PipelineStage.INTENT:
         # Input to conversation agent
         input_args["intent_input"] = msg["input"]["text"]
@@ -185,6 +204,7 @@ async def websocket_run(
             "timeout": timeout,
         },
         wake_word_settings=wake_word_settings,
+        audio_settings=audio_settings or AudioSettings(),
     )

    pipeline_input = PipelineInput(**input_args)
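For illustration, a client request enabling the new options over the websocket API (the shape mirrors the test added further below; values are examples):

await client.send_json_auto_id(
    {
        "type": "assist_pipeline/run",
        "start_stage": "stt",
        "end_stage": "tts",
        "input": {
            "sample_rate": 16000,
            "noise_suppression_level": 2,  # 0-4
            "auto_gain_dbfs": 15,          # 0-31
            "volume_multiplier": 2.0,
            # "no_vad": True, "no_chunking": True  # advanced/testing switches
        },
    }
)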


@ -29,8 +29,11 @@ from homeassistant.components.assist_pipeline import (
select as pipeline_select, select as pipeline_select,
) )
from homeassistant.components.assist_pipeline.vad import ( from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity, VadSensitivity,
VoiceActivityDetector,
VoiceCommandSegmenter, VoiceCommandSegmenter,
WebRtcVad,
) )
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
@ -225,11 +228,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try: try:
# Wait for speech before starting pipeline # Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds) segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
vad = WebRtcVad()
chunk_buffer: deque[bytes] = deque( chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech, maxlen=self.buffered_chunks_before_speech,
) )
speech_detected = await self._wait_for_speech( speech_detected = await self._wait_for_speech(
segmenter, segmenter,
vad,
chunk_buffer, chunk_buffer,
) )
if not speech_detected: if not speech_detected:
@ -243,6 +248,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try: try:
async for chunk in self._segment_audio( async for chunk in self._segment_audio(
segmenter, segmenter,
vad,
chunk_buffer, chunk_buffer,
): ):
yield chunk yield chunk
@ -306,6 +312,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _wait_for_speech( async def _wait_for_speech(
self, self,
segmenter: VoiceCommandSegmenter, segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
chunk_buffer: MutableSequence[bytes], chunk_buffer: MutableSequence[bytes],
): ):
"""Buffer audio chunks until speech is detected. """Buffer audio chunks until speech is detected.
@ -317,12 +324,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout): async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get() chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
while chunk: while chunk:
chunk_buffer.append(chunk) chunk_buffer.append(chunk)
segmenter.process(chunk) segmenter.process_with_vad(chunk, vad, vad_buffer)
if segmenter.in_command: if segmenter.in_command:
# Buffer until command starts # Buffer until command starts
if len(vad_buffer) > 0:
chunk_buffer.append(vad_buffer.bytes())
return True return True
async with asyncio.timeout(self.audio_timeout): async with asyncio.timeout(self.audio_timeout):
@ -333,6 +346,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _segment_audio( async def _segment_audio(
self, self,
segmenter: VoiceCommandSegmenter, segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
chunk_buffer: Sequence[bytes], chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]: ) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished.""" """Yield audio chunks until voice command has finished."""
@ -345,8 +359,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout): async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get() chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
while chunk: while chunk:
if not segmenter.process(chunk): if not segmenter.process_with_vad(chunk, vad, vad_buffer):
# Voice command is finished # Voice command is finished
break break


@@ -51,7 +51,7 @@ typing-extensions>=4.8.0,<5.0
 ulid-transform==0.8.1
 voluptuous-serialize==2.6.0
 voluptuous==0.13.1
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0
 yarl==1.9.2
 zeroconf==0.114.0


@@ -2691,7 +2691,7 @@ waterfurnace==1.1.0
 webexteamssdk==1.1.1

 # homeassistant.components.assist_pipeline
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0

 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.4


@@ -1994,7 +1994,7 @@ wallbox==0.4.12
 watchdog==2.3.1

 # homeassistant.components.assist_pipeline
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0

 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.4


@ -311,18 +311,6 @@
}), }),
'type': <PipelineEventType.STT_START: 'stt-start'>, 'type': <PipelineEventType.STT_START: 'stt-start'>,
}), }),
dict({
'data': dict({
'timestamp': 0,
}),
'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
}),
dict({
'data': dict({
'timestamp': 1500,
}),
'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
}),
dict({ dict({
'data': dict({ 'data': dict({
'stt_output': dict({ 'stt_output': dict({


@ -173,6 +173,87 @@
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id', 'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
}) })
# --- # ---
# name: test_audio_pipeline_with_enhancements
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.2
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.3
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_enhancements.4
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.5
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_enhancements.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.7
None
# ---
# name: test_audio_pipeline_with_wake_word # name: test_audio_pipeline_with_wake_word
dict({ dict({
'language': 'en', 'language': 'en',


@ -64,6 +64,9 @@ async def test_pipeline_from_audio_stream_auto(
channel=stt.AudioChannels.CHANNEL_MONO, channel=stt.AudioChannels.CHANNEL_MONO,
), ),
stt_stream=audio_data(), stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
) )
assert process_events(events) == snapshot assert process_events(events) == snapshot
@ -126,6 +129,9 @@ async def test_pipeline_from_audio_stream_legacy(
), ),
stt_stream=audio_data(), stt_stream=audio_data(),
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
) )
assert process_events(events) == snapshot assert process_events(events) == snapshot
@ -188,6 +194,9 @@ async def test_pipeline_from_audio_stream_entity(
), ),
stt_stream=audio_data(), stt_stream=audio_data(),
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
) )
assert process_events(events) == snapshot assert process_events(events) == snapshot
@ -251,6 +260,9 @@ async def test_pipeline_from_audio_stream_no_stt(
), ),
stt_stream=audio_data(), stt_stream=audio_data(),
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
) )
assert not events assert not events
@ -312,44 +324,47 @@ async def test_pipeline_from_audio_stream_wake_word(
# [0, 2, ...] # [0, 2, ...]
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND)) wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
async def audio_data(): async def audio_data():
yield wake_chunk_1 # 1 second # 1 second in 10 ms chunks
yield wake_chunk_2 # 1 second i = 0
while i < len(wake_chunk_1):
yield wake_chunk_1[i : i + bytes_per_chunk]
i += bytes_per_chunk
# 1 second in 30 ms chunks
i = 0
while i < len(wake_chunk_2):
yield wake_chunk_2[i : i + bytes_per_chunk]
i += bytes_per_chunk
yield b"wake word!" yield b"wake word!"
yield b"part1" yield b"part1"
yield b"part2" yield b"part2"
yield b"end"
yield b"" yield b""
def continue_stt(self, chunk): await assist_pipeline.async_pipeline_from_audio_stream(
# Ensure stt_vad_start event is triggered hass,
self.in_command = True context=Context(),
event_callback=events.append,
# Stop on fake end chunk to trigger stt_vad_end stt_metadata=stt.SpeechMetadata(
return chunk != b"end" language="",
format=stt.AudioFormats.WAV,
with patch( codec=stt.AudioCodecs.PCM,
"homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process", bit_rate=stt.AudioBitRates.BITRATE_16,
continue_stt, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
): channel=stt.AudioChannels.CHANNEL_MONO,
await assist_pipeline.async_pipeline_from_audio_stream( ),
hass, stt_stream=audio_data(),
context=Context(), start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
event_callback=events.append, wake_word_settings=assist_pipeline.WakeWordSettings(
stt_metadata=stt.SpeechMetadata( audio_seconds_to_buffer=1.5
language="", ),
format=stt.AudioFormats.WAV, audio_settings=assist_pipeline.AudioSettings(
codec=stt.AudioCodecs.PCM, is_vad_enabled=False, is_chunking_enabled=False
bit_rate=stt.AudioBitRates.BITRATE_16, ),
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, )
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
)
assert process_events(events) == snapshot assert process_events(events) == snapshot
@ -357,12 +372,14 @@ async def test_pipeline_from_audio_stream_wake_word(
# 2. queued audio (from mock wake word entity) # 2. queued audio (from mock wake word entity)
# 3. part1 # 3. part1
# 4. part2 # 4. part2
assert len(mock_stt_provider.received) == 4 assert len(mock_stt_provider.received) > 3
first_chunk = mock_stt_provider.received[0] first_chunk = bytes(
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
)
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2 assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"] assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
async def test_pipeline_save_audio( async def test_pipeline_save_audio(
@ -410,6 +427,9 @@ async def test_pipeline_save_audio(
pipeline_id=pipeline.id, pipeline_id=pipeline.id,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD, start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT, end_stage=assist_pipeline.PipelineStage.STT,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
) )
pipeline_dirs = list(temp_dir.iterdir()) pipeline_dirs = list(temp_dir.iterdir())


@@ -1,14 +1,15 @@
-"""Tests for webrtcvad voice command segmenter."""
+"""Tests for voice command segmenter."""
 import itertools as it
 from unittest.mock import patch
 from homeassistant.components.assist_pipeline.vad import (
     AudioBuffer,
+    VoiceActivityDetector,
     VoiceCommandSegmenter,
     chunk_samples,
 )
-_ONE_SECOND = 16000 * 2  # 16Khz 16-bit
+_ONE_SECOND = 1.0
 def test_silence() -> None:
@@ -16,87 +17,85 @@ def test_silence() -> None:
     segmenter = VoiceCommandSegmenter()
     # True return value indicates voice command has not finished
-    assert segmenter.process(bytes(_ONE_SECOND * 3))
+    assert segmenter.process(_ONE_SECOND * 3, False)
 def test_speech() -> None:
     """Test that silence + speech + silence triggers a voice command."""
-    def is_speech(self, chunk, sample_rate):
+    def is_speech(chunk):
         """Anything non-zero is speech."""
         return sum(chunk) > 0
-    with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ):
-        segmenter = VoiceCommandSegmenter()
+    segmenter = VoiceCommandSegmenter()
     # silence
-    assert segmenter.process(bytes(_ONE_SECOND))
+    assert segmenter.process(_ONE_SECOND, False)
     # "speech"
-    assert segmenter.process(bytes([255] * _ONE_SECOND))
+    assert segmenter.process(_ONE_SECOND, True)
     # silence
     # False return value indicates voice command is finished
-    assert not segmenter.process(bytes(_ONE_SECOND))
+    assert not segmenter.process(_ONE_SECOND, False)
def test_audio_buffer() -> None: def test_audio_buffer() -> None:
"""Test audio buffer wrapping.""" """Test audio buffer wrapping."""
def is_speech(self, chunk, sample_rate): class DisabledVad(VoiceActivityDetector):
"""Disable VAD.""" def is_speech(self, chunk):
return False return False
with patch( @property
"webrtcvad.Vad.is_speech", def samples_per_chunk(self):
new=is_speech, return 160 # 10 ms
):
segmenter = VoiceCommandSegmenter()
bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
with patch.object( vad = DisabledVad()
segmenter, "_process_chunk", return_value=True bytes_per_chunk = vad.samples_per_chunk * 2
) as mock_process: vad_buffer = AudioBuffer(bytes_per_chunk)
# Partially fill audio buffer segmenter = VoiceCommandSegmenter()
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process(half_chunk)
assert not mock_process.called with patch.object(vad, "is_speech", return_value=False) as mock_process:
assert segmenter.audio_buffer == half_chunk # Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process_with_vad(half_chunk, vad, vad_buffer)
# Fill and wrap with 1/4 chunk left over assert not mock_process.called
three_quarters_chunk = bytes( assert vad_buffer is not None
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk)) assert vad_buffer.bytes() == half_chunk
)
segmenter.process(three_quarters_chunk)
assert mock_process.call_count == 1 # Fill and wrap with 1/4 chunk left over
assert ( three_quarters_chunk = bytes(
segmenter.audio_buffer it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
== three_quarters_chunk[ )
len(three_quarters_chunk) - (bytes_per_chunk // 4) : segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
# Run 2 chunks through assert mock_process.call_count == 1
segmenter.reset() assert (
assert len(segmenter.audio_buffer) == 0 vad_buffer.bytes()
== three_quarters_chunk[
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
mock_process.reset_mock() # Run 2 chunks through
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2)) segmenter.reset()
segmenter.process(two_chunks) vad_buffer.clear()
assert len(vad_buffer) == 0
assert mock_process.call_count == 2 mock_process.reset_mock()
assert len(segmenter.audio_buffer) == 0 two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk] segmenter.process_with_vad(two_chunks, vad, vad_buffer)
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
assert mock_process.call_count == 2
assert len(vad_buffer) == 0
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
 def test_partial_chunk() -> None:
@@ -125,3 +124,43 @@ def test_chunk_samples_leftover() -> None:
     assert len(chunks) == 1
     assert leftover_chunk_buffer.bytes() == bytes([5, 6])
def test_vad_no_chunking() -> None:
"""Test VAD that doesn't require chunking."""
class VadNoChunk(VoiceActivityDetector):
def is_speech(self, chunk: bytes) -> bool:
return sum(chunk) > 0
@property
def samples_per_chunk(self) -> int | None:
return None
vad = VadNoChunk()
segmenter = VoiceCommandSegmenter(
speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
)
silence = bytes([0] * 16000)
speech = bytes([255] * (16000 // 2))
# Test with differently-sized chunks
assert vad.is_speech(speech)
assert not vad.is_speech(silence)
# Simulate voice command
assert segmenter.process_with_vad(silence, vad, None)
# begin
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# reset with silence
assert segmenter.process_with_vad(silence, vad, None)
# resume
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# end
assert segmenter.process_with_vad(silence, vad, None)
assert not segmenter.process_with_vad(silence, vad, None)


@ -107,6 +107,7 @@ async def test_audio_pipeline(
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
@ -116,7 +117,7 @@ async def test_audio_pipeline(
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
@ -240,6 +241,8 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
"input": { "input": {
"sample_rate": 16000, "sample_rate": 16000,
"timeout": 0, "timeout": 0,
"no_vad": True,
"no_chunking": True,
}, },
} }
) )
@ -253,6 +256,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# wake_word # wake_word
@ -276,7 +280,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
@ -731,6 +735,7 @@ async def test_stt_stream_failed(
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
@ -740,7 +745,7 @@ async def test_stt_stream_failed(
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(b"1") await client.send_bytes(bytes([handler_id]))
# stt error # stt error
msg = await client.receive_json() msg = await client.receive_json()
@ -1489,6 +1494,7 @@ async def test_audio_pipeline_debug(
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
@ -1498,7 +1504,7 @@ async def test_audio_pipeline_debug(
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
@ -1699,3 +1705,103 @@ async def test_list_pipeline_languages_with_aliases(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"languages": ["he", "nb"]} assert msg["result"] == {"languages": ["he", "nb"]}
async def test_audio_pipeline_with_enhancements(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
# Enhancements
"noise_suppression_level": 2,
"auto_gain_dbfs": 15,
"volume_multiplier": 2.0,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# One second of silence.
# This will pass through the audio enhancement pipeline, but we don't test
# the actual output.
await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2))
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text-to-speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}


@ -21,7 +21,7 @@ async def test_pipeline(
"""Test that pipeline function is called from RTP protocol.""" """Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk, sample_rate): def is_speech(self, chunk):
"""Anything non-zero is speech.""" """Anything non-zero is speech."""
return sum(chunk) > 0 return sum(chunk) > 0
@ -76,7 +76,7 @@ async def test_pipeline(
return ("mp3", b"") return ("mp3", b"")
with patch( with patch(
"webrtcvad.Vad.is_speech", "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech, new=is_speech,
), patch( ), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream", "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -210,7 +210,7 @@ async def test_tts_timeout(
"""Test that TTS will time out based on its length.""" """Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk, sample_rate): def is_speech(self, chunk):
"""Anything non-zero is speech.""" """Anything non-zero is speech."""
return sum(chunk) > 0 return sum(chunk) > 0
@ -269,7 +269,7 @@ async def test_tts_timeout(
return ("raw", bytes(0)) return ("raw", bytes(0))
with patch( with patch(
"webrtcvad.Vad.is_speech", "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech, new=is_speech,
), patch( ), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream", "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",