Use webrtc-noise-gain for audio enhancement in Assist pipelines (#100698)

* Use webrtc-noise-gain instead of the webrtcvad package

* Switch to ProcessedAudioChunk

* Refactor VAD and fix tests

* Add VAD no-chunking test

* Add test that runs audio enhancements
Michael Hansen, 2023-09-25 19:03:50 -05:00, committed by GitHub
parent a4f7f3ba7e
commit 785618909a
15 changed files with 707 additions and 258 deletions
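The new dependency bundles noise suppression, automatic gain control, and voice activity detection into a single AudioProcessor that consumes 10 ms chunks of 16-bit, 16 kHz mono PCM. A minimal sketch of the call pattern the pipeline below relies on (constructor argument order and the Process10ms result fields are taken from this diff; the silent chunk is purely illustrative):

from webrtc_noise_gain import AudioProcessor

# Arguments: auto gain in dBFS (0-31), then noise suppression level (0-4),
# matching PipelineRun.__post_init__ below
audio_processor = AudioProcessor(15, 2)

# 10 ms of 16-bit mono audio @ 16 kHz = 160 samples = 320 bytes
chunk_10ms = bytes(320)  # silence, for illustration only

result = audio_processor.Process10ms(chunk_10ms)
clean_audio: bytes = result.audio   # enhanced PCM audio
is_speech = bool(result.is_speech)  # VAD decision for this chunk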


@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
from .const import DATA_CONFIG, DOMAIN
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
Pipeline,
PipelineEvent,
PipelineEventCallback,
@ -33,6 +34,7 @@ __all__ = (
"async_get_pipelines",
"async_setup",
"async_pipeline_from_audio_stream",
"AudioSettings",
"Pipeline",
"PipelineEvent",
"PipelineEventType",
@ -71,6 +73,7 @@ async def async_pipeline_from_audio_stream(
conversation_id: str | None = None,
tts_audio_output: str | None = None,
wake_word_settings: WakeWordSettings | None = None,
audio_settings: AudioSettings | None = None,
device_id: str | None = None,
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
@ -93,6 +96,7 @@ async def async_pipeline_from_audio_stream(
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
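Callers of async_pipeline_from_audio_stream opt into the enhancements through the new audio_settings argument; omitting it falls back to AudioSettings(), i.e. VAD and 10 ms chunking enabled, no noise suppression or gain. A sketch (inside a coroutine) under the assumption that hass, context, stt_metadata, and an async PCM generator are already set up as in the tests further down:

from homeassistant.components import assist_pipeline

settings = assist_pipeline.AudioSettings(
    noise_suppression_level=2,  # 0 (off) to 4 (max)
    auto_gain_dbfs=15,          # 0 (off) to 31 (max)
    volume_multiplier=2.0,      # plain PCM gain applied before the processor
)

await assist_pipeline.async_pipeline_from_audio_stream(
    hass,                          # assumed: running HomeAssistant instance
    context=context,               # assumed: homeassistant.core.Context
    event_callback=events.append,  # assumed: events collects PipelineEvents
    stt_metadata=stt_metadata,     # assumed: stt.SpeechMetadata as in the tests
    stt_stream=audio_chunks(),     # assumed: async generator yielding PCM bytes
    audio_settings=settings,
)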


@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
"iot_class": "local_push",
"quality_scale": "internal",
"requirements": ["webrtcvad==2.0.10"]
"requirements": ["webrtc-noise-gain==1.1.0"]
}


@ -1,7 +1,9 @@
"""Classes for voice assistant pipelines."""
from __future__ import annotations
import array
import asyncio
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field
from enum import StrEnum
@ -10,10 +12,11 @@ from pathlib import Path
from queue import Queue
from threading import Thread
import time
from typing import Any, cast
from typing import Any, Final, cast
import wave
import voluptuous as vol
from webrtc_noise_gain import AudioProcessor
from homeassistant.components import (
conversation,
@ -54,8 +57,7 @@ from .error import (
WakeWordDetectionError,
WakeWordTimeoutError,
)
from .ring_buffer import RingBuffer
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
_LOGGER = logging.getLogger(__name__)
@ -95,6 +97,9 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
AUDIO_PROCESSOR_SAMPLES: Final = 160 # 10 ms @ 16 Khz
AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
async def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
@ -393,6 +398,60 @@ class WakeWordSettings:
"""Seconds of audio to buffer before detection and forward to STT."""
@dataclass(frozen=True)
class AudioSettings:
"""Settings for pipeline audio processing."""
noise_suppression_level: int = 0
"""Level of noise suppression (0 = disabled, 4 = max)"""
auto_gain_dbfs: int = 0
"""Amount of automatic gain in dbFS (0 = disabled, 31 = max)"""
volume_multiplier: float = 1.0
"""Multiplier used directly on PCM samples (1.0 = no change, 2.0 = twice as loud)"""
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
is_chunking_enabled: bool = True
"""True if audio is automatically split into 10 ms chunks (required for VAD, etc.)"""
def __post_init__(self) -> None:
"""Verify settings post-initialization."""
if (self.noise_suppression_level < 0) or (self.noise_suppression_level > 4):
raise ValueError("noise_suppression_level must be in [0, 4]")
if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31):
raise ValueError("auto_gain_dbfs must be in [0, 31]")
if self.needs_processor and (not self.is_chunking_enabled):
raise ValueError("Chunking must be enabled for audio processing")
@property
def needs_processor(self) -> bool:
"""True if an audio processor is needed."""
return (
self.is_vad_enabled
or (self.noise_suppression_level > 0)
or (self.auto_gain_dbfs > 0)
)
@dataclass(frozen=True, slots=True)
class ProcessedAudioChunk:
"""Processed audio chunk and metadata."""
audio: bytes
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
timestamp_ms: int
"""Timestamp relative to start of audio stream (milliseconds)"""
is_speech: bool | None
"""True if audio chunk likely contains speech, False if not, None if unknown"""
@dataclass
class PipelineRun:
"""Running context for a pipeline."""
@ -408,6 +467,7 @@ class PipelineRun:
intent_agent: str | None = None
tts_audio_output: str | None = None
wake_word_settings: WakeWordSettings | None = None
audio_settings: AudioSettings = field(default_factory=AudioSettings)
id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
@ -422,6 +482,12 @@ class PipelineRun:
debug_recording_queue: Queue[str | bytes | None] | None = None
"""Queue to communicate with debug recording thread"""
audio_processor: AudioProcessor | None = None
"""VAD/noise suppression/auto gain"""
audio_processor_buffer: AudioBuffer = field(init=False)
"""Buffer used when splitting audio into chunks for audio processing"""
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
@ -439,6 +505,14 @@ class PipelineRun:
)
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
# Initialize with audio settings
self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
if self.audio_settings.needs_processor:
self.audio_processor = AudioProcessor(
self.audio_settings.auto_gain_dbfs,
self.audio_settings.noise_suppression_level,
)
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
@ -499,8 +573,8 @@ class PipelineRun:
async def wake_word_detection(
self,
stream: AsyncIterable[bytes],
audio_chunks_for_stt: list[bytes],
stream: AsyncIterable[ProcessedAudioChunk],
audio_chunks_for_stt: list[ProcessedAudioChunk],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
@ -541,12 +615,13 @@ class PipelineRun:
# Audio chunk buffer. This audio will be forwarded to speech-to-text
# after wake-word-detection.
num_audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz
num_audio_chunks_to_buffer = int(
(wake_word_settings.audio_seconds_to_buffer * 16000)
/ AUDIO_PROCESSOR_SAMPLES
)
stt_audio_buffer: RingBuffer | None = None
if num_audio_bytes_to_buffer > 0:
stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
if num_audio_chunks_to_buffer > 0:
stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
try:
# Detect wake word(s)
@ -562,7 +637,7 @@ class PipelineRun:
if stt_audio_buffer is not None:
# All audio kept from right before the wake word was detected as
# a single chunk.
audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
audio_chunks_for_stt.extend(stt_audio_buffer)
except WakeWordTimeoutError:
_LOGGER.debug("Timeout during wake word detection")
raise
@ -586,7 +661,11 @@ class PipelineRun:
# speech-to-text so the user does not have to pause before
# speaking the voice command.
for chunk_ts in result.queued_audio:
audio_chunks_for_stt.append(chunk_ts[0])
audio_chunks_for_stt.append(
ProcessedAudioChunk(
audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
)
)
wake_word_output = asdict(result)
@ -604,8 +683,8 @@ class PipelineRun:
async def _wake_word_audio_stream(
self,
audio_stream: AsyncIterable[bytes],
stt_audio_buffer: RingBuffer | None,
audio_stream: AsyncIterable[ProcessedAudioChunk],
stt_audio_buffer: deque[ProcessedAudioChunk] | None,
wake_word_vad: VoiceActivityTimeout | None,
sample_rate: int = 16000,
sample_width: int = 2,
@ -615,25 +694,24 @@ class PipelineRun:
Adds audio to a ring buffer that will be forwarded to speech-to-text after
detection. Times out if VAD detects enough silence.
"""
ms_per_sample = sample_rate // 1000
timestamp_ms = 0
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
async for chunk in audio_stream:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk)
self.debug_recording_queue.put_nowait(chunk.audio)
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
yield chunk.audio, chunk.timestamp_ms
# Wake-word-detection occurs *after* the wake word was actually
# spoken. Keeping audio right before detection allows the voice
# command to be spoken immediately after the wake word.
if stt_audio_buffer is not None:
stt_audio_buffer.put(chunk)
stt_audio_buffer.append(chunk)
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
if wake_word_vad is not None:
if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech-to-text."""
@ -666,7 +744,7 @@ class PipelineRun:
async def speech_to_text(
self,
metadata: stt.SpeechMetadata,
stream: AsyncIterable[bytes],
stream: AsyncIterable[ProcessedAudioChunk],
) -> str:
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
if isinstance(self.stt_provider, stt.Provider):
@ -690,11 +768,13 @@ class PipelineRun:
try:
# Transcribe audio stream
stt_vad: VoiceCommandSegmenter | None = None
if self.audio_settings.is_vad_enabled:
stt_vad = VoiceCommandSegmenter()
result = await self.stt_provider.async_process_audio_stream(
metadata,
self._speech_to_text_stream(
audio_stream=stream, stt_vad=VoiceCommandSegmenter()
),
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
@ -731,26 +811,25 @@ class PipelineRun:
async def _speech_to_text_stream(
self,
audio_stream: AsyncIterable[bytes],
audio_stream: AsyncIterable[ProcessedAudioChunk],
stt_vad: VoiceCommandSegmenter | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
ms_per_sample = sample_rate // 1000
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
sent_vad_start = False
timestamp_ms = 0
async for chunk in audio_stream:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk)
self.debug_recording_queue.put_nowait(chunk.audio)
if stt_vad is not None:
if not stt_vad.process(chunk):
if not stt_vad.process(chunk_seconds, chunk.is_speech):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
{"timestamp": chunk.timestamp_ms},
)
)
break
@ -760,13 +839,12 @@ class PipelineRun:
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
{"timestamp": chunk.timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
yield chunk.audio
async def prepare_recognize_intent(self) -> None:
"""Prepare recognizing an intent."""
@ -977,6 +1055,94 @@ class PipelineRun:
self.debug_recording_queue = None
self.debug_recording_thread = None
async def process_volume_only(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
"""Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_BYTES // sample_width) // ms_per_sample
timestamp_ms = 0
async for chunk in audio_stream:
if self.audio_settings.volume_multiplier != 1.0:
chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
if self.audio_settings.is_chunking_enabled:
# 10 ms chunking
for chunk_10ms in chunk_samples(
chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
):
yield ProcessedAudioChunk(
audio=chunk_10ms,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
)
timestamp_ms += ms_per_chunk
else:
# No chunking
yield ProcessedAudioChunk(
audio=chunk,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
)
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
async def process_enhance_audio(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
"""Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
assert self.audio_processor is not None
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_BYTES // sample_width) // ms_per_sample
timestamp_ms = 0
async for dirty_samples in audio_stream:
if self.audio_settings.volume_multiplier != 1.0:
# Static gain
dirty_samples = _multiply_volume(
dirty_samples, self.audio_settings.volume_multiplier
)
# Split into 10ms chunks for audio enhancements/VAD
for dirty_10ms_chunk in chunk_samples(
dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
):
ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
yield ProcessedAudioChunk(
audio=ap_result.audio,
timestamp_ms=timestamp_ms,
is_speech=ap_result.is_speech,
)
timestamp_ms += ms_per_chunk
def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
"""Multiplies 16-bit PCM samples by a constant."""
return array.array(
"h",
[
int(
# Clamp to signed 16-bit range
max(
-32767,
min(
32767,
value * volume_multiplier,
),
)
)
for value in array.array("h", chunk)
],
).tobytes()
def _pipeline_debug_recording_thread_proc(
run_recording_dir: Path,
@ -1042,14 +1208,23 @@ class PipelineInput:
"""Run pipeline."""
self.run.start(device_id=self.device_id)
current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[bytes] = []
stt_audio_buffer: list[ProcessedAudioChunk] = []
stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
if self.stt_stream is not None:
if self.run.audio_settings.needs_processor:
# VAD/noise suppression/auto gain/volume
stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
else:
# Volume multiplier only
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
try:
if current_stage == PipelineStage.WAKE_WORD:
# wake-word-detection
assert self.stt_stream is not None
assert stt_processed_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, stt_audio_buffer
stt_processed_stream, stt_audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
@ -1062,28 +1237,30 @@ class PipelineInput:
intent_input = self.intent_input
if current_stage == PipelineStage.STT:
assert self.stt_metadata is not None
assert self.stt_stream is not None
assert stt_processed_stream is not None
stt_stream = self.stt_stream
stt_input_stream = stt_processed_stream
if stt_audio_buffer:
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
# This is basically an async itertools.chain.
async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
async def buffer_then_audio_stream() -> AsyncGenerator[
ProcessedAudioChunk, None
]:
# Buffered audio
for chunk in stt_audio_buffer:
yield chunk
# Streamed audio
assert self.stt_stream is not None
async for chunk in self.stt_stream:
assert stt_processed_stream is not None
async for chunk in stt_processed_stream:
yield chunk
stt_stream = buffer_then_audio_stream()
stt_input_stream = buffer_then_audio_stream()
intent_input = await self.run.speech_to_text(
self.stt_metadata,
stt_stream,
stt_input_stream,
)
current_stage = PipelineStage.INTENT
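The process_volume_only and process_enhance_audio methods above lean on the pre-existing chunk_samples helper plus an AudioBuffer that carries leftover bytes between incoming chunks, so arbitrary stream chunk sizes are re-sliced into exact 10 ms pieces. A small sketch of that contract, assuming only the signature visible in this diff and its tests:

from homeassistant.components.assist_pipeline.vad import AudioBuffer, chunk_samples

AUDIO_PROCESSOR_BYTES = 160 * 2  # 10 ms of 16-bit samples @ 16 kHz

leftover = AudioBuffer(AUDIO_PROCESSOR_BYTES)

# 25 ms of audio (800 bytes) arrives as a single network chunk
chunks = list(chunk_samples(bytes(800), AUDIO_PROCESSOR_BYTES, leftover))

assert len(chunks) == 2      # two full 10 ms chunks are yielded
assert len(leftover) == 160  # the trailing 5 ms waits for the next call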


@ -1,12 +1,13 @@
"""Voice activity detection."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import StrEnum
from typing import Final
from typing import Final, cast
import webrtcvad
from webrtc_noise_gain import AudioProcessor
_SAMPLE_RATE: Final = 16000 # Hz
_SAMPLE_WIDTH: Final = 2 # bytes
@ -32,6 +33,38 @@ class VadSensitivity(StrEnum):
return 1.0
class VoiceActivityDetector(ABC):
"""Base class for voice activity detectors (VAD)."""
@abstractmethod
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
@property
@abstractmethod
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking is not required."""
class WebRtcVad(VoiceActivityDetector):
"""Voice activity detector based on webrtc."""
def __init__(self) -> None:
"""Initialize webrtcvad."""
# Just VAD: no noise suppression or auto gain
self._audio_processor = AudioProcessor(0, 0)
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
result = self._audio_processor.Process10ms(chunk)
return cast(bool, result.is_speech)
@property
def samples_per_chunk(self) -> int | None:
"""Return 10 ms."""
return int(0.01 * _SAMPLE_RATE) # 10 ms
class AudioBuffer:
"""Fixed-sized audio buffer with variable internal length."""
@ -73,13 +106,7 @@ class AudioBuffer:
@dataclass
class VoiceCommandSegmenter:
"""Segments an audio stream into voice commands using webrtcvad."""
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
"""Segments an audio stream into voice commands."""
speech_seconds: float = 0.3
"""Seconds of speech before voice command has started."""
@ -108,85 +135,85 @@ class VoiceCommandSegmenter:
_reset_seconds_left: float = 0.0
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
"""Reset after initialization."""
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._leftover_chunk_buffer.clear()
self._speech_seconds_left = self.speech_seconds
self._silence_seconds_left = self.silence_seconds
self._timeout_seconds_left = self.timeout_seconds
self._reset_seconds_left = self.reset_seconds
self.in_command = False
def process(self, samples: bytes) -> bool:
"""Process 16-bit 16Khz mono audio samples.
def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
"""Process samples using external VAD.
Returns False when command is done.
"""
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
self.reset()
return False
return True
@property
def audio_buffer(self) -> bytes:
"""Get partial chunk in the audio buffer."""
return self._leftover_chunk_buffer.bytes()
def _process_chunk(self, chunk: bytes) -> bool:
"""Process a single chunk of 16-bit 16Khz mono audio.
Returns False when command is done.
"""
is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
self._timeout_seconds_left -= self._seconds_per_chunk
self._timeout_seconds_left -= chunk_seconds
if self._timeout_seconds_left <= 0:
self.reset()
return False
if not self.in_command:
if is_speech:
self._reset_seconds_left = self.reset_seconds
self._speech_seconds_left -= self._seconds_per_chunk
self._speech_seconds_left -= chunk_seconds
if self._speech_seconds_left <= 0:
# Inside voice command
self.in_command = True
else:
# Reset if enough silence
self._reset_seconds_left -= self._seconds_per_chunk
self._reset_seconds_left -= chunk_seconds
if self._reset_seconds_left <= 0:
self._speech_seconds_left = self.speech_seconds
elif not is_speech:
self._reset_seconds_left = self.reset_seconds
self._silence_seconds_left -= self._seconds_per_chunk
self._silence_seconds_left -= chunk_seconds
if self._silence_seconds_left <= 0:
self.reset()
return False
else:
# Reset if enough speech
self._reset_seconds_left -= self._seconds_per_chunk
self._reset_seconds_left -= chunk_seconds
if self._reset_seconds_left <= 0:
self._silence_seconds_left = self.silence_seconds
return True
def process_with_vad(
self,
chunk: bytes,
vad: VoiceActivityDetector,
leftover_chunk_buffer: AudioBuffer | None,
) -> bool:
"""Process an audio chunk using an external VAD.
A buffer is required if the VAD requires fixed-sized audio chunks (usually the case).
Returns False when voice command is finished.
"""
if vad.samples_per_chunk is None:
# No chunking
chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
is_speech = vad.is_speech(chunk)
return self.process(chunk_seconds, is_speech)
if leftover_chunk_buffer is None:
raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
# With chunking
seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
is_speech = vad.is_speech(vad_chunk)
if not self.process(seconds_per_chunk, is_speech):
return False
return True
@dataclass
class VoiceActivityTimeout:
@ -198,73 +225,43 @@ class VoiceActivityTimeout:
reset_seconds: float = 0.5
"""Seconds of speech before resetting timeout."""
vad_mode: int = 3
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
vad_samples_per_chunk: int = 480 # 30 ms
"""Must be 10, 20, or 30 ms at 16Khz."""
_silence_seconds_left: float = 0.0
"""Seconds left before considering voice command as stopped."""
_reset_seconds_left: float = 0.0
"""Seconds left before resetting start/stop time counters."""
_vad: webrtcvad.Vad = None
_leftover_chunk_buffer: AudioBuffer = field(init=False)
_bytes_per_chunk: int = field(init=False)
_seconds_per_chunk: float = field(init=False)
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
self._leftover_chunk_buffer = AudioBuffer(
self.vad_samples_per_chunk * _SAMPLE_WIDTH
)
"""Reset after initialization."""
self.reset()
def reset(self) -> None:
"""Reset all counters and state."""
self._leftover_chunk_buffer.clear()
self._silence_seconds_left = self.silence_seconds
self._reset_seconds_left = self.reset_seconds
def process(self, samples: bytes) -> bool:
"""Process 16-bit 16Khz mono audio samples.
def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
"""Process samples using external VAD.
Returns False when timeout is reached.
"""
for chunk in chunk_samples(
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
):
if not self._process_chunk(chunk):
return False
return True
def _process_chunk(self, chunk: bytes) -> bool:
"""Process a single chunk of 16-bit 16Khz mono audio.
Returns False when timeout is reached.
"""
if self._vad.is_speech(chunk, _SAMPLE_RATE):
if is_speech:
# Speech
self._reset_seconds_left -= self._seconds_per_chunk
self._reset_seconds_left -= chunk_seconds
if self._reset_seconds_left <= 0:
# Reset timeout
self._silence_seconds_left = self.silence_seconds
else:
# Silence
self._silence_seconds_left -= self._seconds_per_chunk
self._silence_seconds_left -= chunk_seconds
if self._silence_seconds_left <= 0:
# Timeout reached
self.reset()
return False
# Slowly build reset counter back up
self._reset_seconds_left = min(
self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
self.reset_seconds, self._reset_seconds_left + chunk_seconds
)
return True
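After this refactor the segmenter is a pure time-based state machine: it only consumes (chunk_seconds, is_speech) pairs, and process_with_vad adapts any VoiceActivityDetector, such as the new WebRtcVad, to that interface. A hedged sketch mirroring the VoIP usage below (pcm_chunks is an assumed iterable of 16-bit, 16 kHz mono PCM byte strings):

from homeassistant.components.assist_pipeline.vad import (
    AudioBuffer,
    VoiceCommandSegmenter,
    WebRtcVad,
)

SAMPLE_WIDTH = 2  # bytes per 16-bit sample

vad = WebRtcVad()
segmenter = VoiceCommandSegmenter()
# WebRtcVad only accepts fixed 10 ms chunks, so partial chunks are buffered
vad_buffer = AudioBuffer(vad.samples_per_chunk * SAMPLE_WIDTH)

for chunk in pcm_chunks:  # assumed: iterable of PCM byte strings
    if not segmenter.process_with_vad(chunk, vad, vad_buffer):
        break  # enough silence after speech: the voice command is finished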


@ -18,6 +18,7 @@ from homeassistant.util import language as language_util
from .const import DOMAIN
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
PipelineData,
PipelineError,
PipelineEvent,
@ -71,6 +72,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
vol.Optional("audio_seconds_to_buffer"): vol.Any(
float, int
),
# Audio enhancement
vol.Optional("noise_suppression_level"): int,
vol.Optional("auto_gain_dbfs"): int,
vol.Optional("volume_multiplier"): float,
# Advanced use cases/testing
vol.Optional("no_vad"): bool,
vol.Optional("no_chunking"): bool,
}
},
extra=vol.ALLOW_EXTRA,
@ -115,6 +123,7 @@ async def websocket_run(
handler_id: int | None = None
unregister_handler: Callable[[], None] | None = None
wake_word_settings: WakeWordSettings | None = None
audio_settings: AudioSettings | None = None
# Arguments to PipelineInput
input_args: dict[str, Any] = {
@ -124,13 +133,14 @@ async def websocket_run(
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
# Audio pipeline that will receive audio as binary websocket messages
msg_input = msg["input"]
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"]
incoming_sample_rate = msg_input["sample_rate"]
if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
)
async def stt_stream() -> AsyncGenerator[bytes, None]:
@ -166,6 +176,15 @@ async def websocket_run(
channel=stt.AudioChannels.CHANNEL_MONO,
)
input_args["stt_stream"] = stt_stream()
# Audio settings
audio_settings = AudioSettings(
noise_suppression_level=msg_input.get("noise_suppression_level", 0),
auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
volume_multiplier=msg_input.get("volume_multiplier", 1.0),
is_vad_enabled=not msg_input.get("no_vad", False),
is_chunking_enabled=not msg_input.get("no_chunking", False),
)
elif start_stage == PipelineStage.INTENT:
# Input to conversation agent
input_args["intent_input"] = msg["input"]["text"]
@ -185,6 +204,7 @@ async def websocket_run(
"timeout": timeout,
},
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
)
pipeline_input = PipelineInput(**input_args)
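On the websocket API, the enhancement options travel in the run message's input object. no_vad and no_chunking exist for advanced use and testing; chunking can only be disabled when VAD, noise suppression, and auto gain are all off (AudioSettings.__post_init__ raises otherwise). A sketch of such a message, assuming client is an authenticated websocket client as in the tests:

await client.send_json_auto_id(
    {
        "type": "assist_pipeline/run",
        "start_stage": "stt",
        "end_stage": "tts",
        "input": {
            "sample_rate": 16000,
            # Audio enhancement (all optional; 0 / 1.0 leave the audio untouched)
            "noise_suppression_level": 2,
            "auto_gain_dbfs": 15,
            "volume_multiplier": 2.0,
            # Advanced/testing only; valid only with all enhancements and VAD off
            # "no_vad": True,
            # "no_chunking": True,
        },
    }
)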


@ -29,8 +29,11 @@ from homeassistant.components.assist_pipeline import (
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceActivityDetector,
VoiceCommandSegmenter,
WebRtcVad,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
@ -225,11 +228,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
vad = WebRtcVad()
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)
speech_detected = await self._wait_for_speech(
segmenter,
vad,
chunk_buffer,
)
if not speech_detected:
@ -243,6 +248,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try:
async for chunk in self._segment_audio(
segmenter,
vad,
chunk_buffer,
):
yield chunk
@ -306,6 +312,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
chunk_buffer: MutableSequence[bytes],
):
"""Buffer audio chunks until speech is detected.
@ -317,12 +324,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
while chunk:
chunk_buffer.append(chunk)
segmenter.process(chunk)
segmenter.process_with_vad(chunk, vad, vad_buffer)
if segmenter.in_command:
# Buffer until command starts
if len(vad_buffer) > 0:
chunk_buffer.append(vad_buffer.bytes())
return True
async with asyncio.timeout(self.audio_timeout):
@ -333,6 +346,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished."""
@ -345,8 +359,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
while chunk:
if not segmenter.process(chunk):
if not segmenter.process_with_vad(chunk, vad, vad_buffer):
# Voice command is finished
break


@ -51,7 +51,7 @@ typing-extensions>=4.8.0,<5.0
ulid-transform==0.8.1
voluptuous-serialize==2.6.0
voluptuous==0.13.1
webrtcvad==2.0.10
webrtc-noise-gain==1.1.0
yarl==1.9.2
zeroconf==0.114.0


@ -2691,7 +2691,7 @@ waterfurnace==1.1.0
webexteamssdk==1.1.1
# homeassistant.components.assist_pipeline
webrtcvad==2.0.10
webrtc-noise-gain==1.1.0
# homeassistant.components.whirlpool
whirlpool-sixth-sense==0.18.4


@ -1994,7 +1994,7 @@ wallbox==0.4.12
watchdog==2.3.1
# homeassistant.components.assist_pipeline
webrtcvad==2.0.10
webrtc-noise-gain==1.1.0
# homeassistant.components.whirlpool
whirlpool-sixth-sense==0.18.4


@ -311,18 +311,6 @@
}),
'type': <PipelineEventType.STT_START: 'stt-start'>,
}),
dict({
'data': dict({
'timestamp': 0,
}),
'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
}),
dict({
'data': dict({
'timestamp': 1500,
}),
'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
}),
dict({
'data': dict({
'stt_output': dict({


@ -173,6 +173,87 @@
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
})
# ---
# name: test_audio_pipeline_with_enhancements
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.2
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.3
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_enhancements.4
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.5
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_enhancements.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_enhancements.7
None
# ---
# name: test_audio_pipeline_with_wake_word
dict({
'language': 'en',


@ -64,6 +64,9 @@ async def test_pipeline_from_audio_stream_auto(
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
@ -126,6 +129,9 @@ async def test_pipeline_from_audio_stream_legacy(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
@ -188,6 +194,9 @@ async def test_pipeline_from_audio_stream_entity(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
@ -251,6 +260,9 @@ async def test_pipeline_from_audio_stream_no_stt(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert not events
@ -312,44 +324,47 @@ async def test_pipeline_from_audio_stream_wake_word(
# [0, 2, ...]
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
async def audio_data():
yield wake_chunk_1 # 1 second
yield wake_chunk_2 # 1 second
# 1 second in 10 ms chunks
i = 0
while i < len(wake_chunk_1):
yield wake_chunk_1[i : i + bytes_per_chunk]
i += bytes_per_chunk
# 1 second in 10 ms chunks
i = 0
while i < len(wake_chunk_2):
yield wake_chunk_2[i : i + bytes_per_chunk]
i += bytes_per_chunk
yield b"wake word!"
yield b"part1"
yield b"part2"
yield b"end"
yield b""
def continue_stt(self, chunk):
# Ensure stt_vad_start event is triggered
self.in_command = True
# Stop on fake end chunk to trigger stt_vad_end
return chunk != b"end"
with patch(
"homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
continue_stt,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
)
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
assert process_events(events) == snapshot
@ -357,12 +372,14 @@ async def test_pipeline_from_audio_stream_wake_word(
# 2. queued audio (from mock wake word entity)
# 3. part1
# 4. part2
assert len(mock_stt_provider.received) == 4
assert len(mock_stt_provider.received) > 3
first_chunk = mock_stt_provider.received[0]
first_chunk = bytes(
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
)
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]
assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
async def test_pipeline_save_audio(
@ -410,6 +427,9 @@ async def test_pipeline_save_audio(
pipeline_id=pipeline.id,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
)
pipeline_dirs = list(temp_dir.iterdir())


@ -1,14 +1,15 @@
"""Tests for webrtcvad voice command segmenter."""
"""Tests for voice command segmenter."""
import itertools as it
from unittest.mock import patch
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VoiceActivityDetector,
VoiceCommandSegmenter,
chunk_samples,
)
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
_ONE_SECOND = 1.0
def test_silence() -> None:
@ -16,87 +17,85 @@ def test_silence() -> None:
segmenter = VoiceCommandSegmenter()
# True return value indicates voice command has not finished
assert segmenter.process(bytes(_ONE_SECOND * 3))
assert segmenter.process(_ONE_SECOND * 3, False)
def test_speech() -> None:
"""Test that silence + speech + silence triggers a voice command."""
def is_speech(self, chunk, sample_rate):
def is_speech(chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
):
segmenter = VoiceCommandSegmenter()
segmenter = VoiceCommandSegmenter()
# silence
assert segmenter.process(bytes(_ONE_SECOND))
# silence
assert segmenter.process(_ONE_SECOND, False)
# "speech"
assert segmenter.process(bytes([255] * _ONE_SECOND))
# "speech"
assert segmenter.process(_ONE_SECOND, True)
# silence
# False return value indicates voice command is finished
assert not segmenter.process(bytes(_ONE_SECOND))
# silence
# False return value indicates voice command is finished
assert not segmenter.process(_ONE_SECOND, False)
def test_audio_buffer() -> None:
"""Test audio buffer wrapping."""
def is_speech(self, chunk, sample_rate):
"""Disable VAD."""
return False
class DisabledVad(VoiceActivityDetector):
def is_speech(self, chunk):
return False
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
):
segmenter = VoiceCommandSegmenter()
bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
@property
def samples_per_chunk(self):
return 160 # 10 ms
with patch.object(
segmenter, "_process_chunk", return_value=True
) as mock_process:
# Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process(half_chunk)
vad = DisabledVad()
bytes_per_chunk = vad.samples_per_chunk * 2
vad_buffer = AudioBuffer(bytes_per_chunk)
segmenter = VoiceCommandSegmenter()
assert not mock_process.called
assert segmenter.audio_buffer == half_chunk
with patch.object(vad, "is_speech", return_value=False) as mock_process:
# Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process_with_vad(half_chunk, vad, vad_buffer)
# Fill and wrap with 1/4 chunk left over
three_quarters_chunk = bytes(
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
)
segmenter.process(three_quarters_chunk)
assert not mock_process.called
assert vad_buffer is not None
assert vad_buffer.bytes() == half_chunk
assert mock_process.call_count == 1
assert (
segmenter.audio_buffer
== three_quarters_chunk[
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
# Fill and wrap with 1/4 chunk left over
three_quarters_chunk = bytes(
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
)
segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
# Run 2 chunks through
segmenter.reset()
assert len(segmenter.audio_buffer) == 0
assert mock_process.call_count == 1
assert (
vad_buffer.bytes()
== three_quarters_chunk[
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
mock_process.reset_mock()
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
segmenter.process(two_chunks)
# Run 2 chunks through
segmenter.reset()
vad_buffer.clear()
assert len(vad_buffer) == 0
assert mock_process.call_count == 2
assert len(segmenter.audio_buffer) == 0
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
mock_process.reset_mock()
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
segmenter.process_with_vad(two_chunks, vad, vad_buffer)
assert mock_process.call_count == 2
assert len(vad_buffer) == 0
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
def test_partial_chunk() -> None:
@ -125,3 +124,43 @@ def test_chunk_samples_leftover() -> None:
assert len(chunks) == 1
assert leftover_chunk_buffer.bytes() == bytes([5, 6])
def test_vad_no_chunking() -> None:
"""Test VAD that doesn't require chunking."""
class VadNoChunk(VoiceActivityDetector):
def is_speech(self, chunk: bytes) -> bool:
return sum(chunk) > 0
@property
def samples_per_chunk(self) -> int | None:
return None
vad = VadNoChunk()
segmenter = VoiceCommandSegmenter(
speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
)
silence = bytes([0] * 16000)
speech = bytes([255] * (16000 // 2))
# Test with differently-sized chunks
assert vad.is_speech(speech)
assert not vad.is_speech(silence)
# Simulate voice command
assert segmenter.process_with_vad(silence, vad, None)
# begin
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# reset with silence
assert segmenter.process_with_vad(silence, vad, None)
# resume
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# end
assert segmenter.process_with_vad(silence, vad, None)
assert not segmenter.process_with_vad(silence, vad, None)


@ -107,6 +107,7 @@ async def test_audio_pipeline(
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# stt
@ -116,7 +117,7 @@ async def test_audio_pipeline(
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
@ -240,6 +241,8 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
"input": {
"sample_rate": 16000,
"timeout": 0,
"no_vad": True,
"no_chunking": True,
},
}
)
@ -253,6 +256,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# wake_word
@ -276,7 +280,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
@ -731,6 +735,7 @@ async def test_stt_stream_failed(
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# stt
@ -740,7 +745,7 @@ async def test_stt_stream_failed(
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(b"1")
await client.send_bytes(bytes([handler_id]))
# stt error
msg = await client.receive_json()
@ -1489,6 +1494,7 @@ async def test_audio_pipeline_debug(
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# stt
@ -1498,7 +1504,7 @@ async def test_audio_pipeline_debug(
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
@ -1699,3 +1705,103 @@ async def test_list_pipeline_languages_with_aliases(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"languages": ["he", "nb"]}
async def test_audio_pipeline_with_enhancements(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
# Enhancements
"noise_suppression_level": 2,
"auto_gain_dbfs": 15,
"volume_multiplier": 2.0,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# One second of silence.
# This will pass through the audio enhancement pipeline, but we don't test
# the actual output.
await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2))
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text-to-speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}


@ -21,7 +21,7 @@ async def test_pipeline(
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk, sample_rate):
def is_speech(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
@ -76,7 +76,7 @@ async def test_pipeline(
return ("mp3", b"")
with patch(
"webrtcvad.Vad.is_speech",
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -210,7 +210,7 @@ async def test_tts_timeout(
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk, sample_rate):
def is_speech(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
@ -269,7 +269,7 @@ async def test_tts_timeout(
return ("raw", bytes(0))
with patch(
"webrtcvad.Vad.is_speech",
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",