diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 7f87bd254d0..9a61346f673 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType from .const import DATA_CONFIG, DOMAIN from .error import PipelineNotFound from .pipeline import ( + AudioSettings, Pipeline, PipelineEvent, PipelineEventCallback, @@ -33,6 +34,7 @@ __all__ = ( "async_get_pipelines", "async_setup", "async_pipeline_from_audio_stream", + "AudioSettings", "Pipeline", "PipelineEvent", "PipelineEventType", @@ -71,6 +73,7 @@ async def async_pipeline_from_audio_stream( conversation_id: str | None = None, tts_audio_output: str | None = None, wake_word_settings: WakeWordSettings | None = None, + audio_settings: AudioSettings | None = None, device_id: str | None = None, start_stage: PipelineStage = PipelineStage.STT, end_stage: PipelineStage = PipelineStage.TTS, @@ -93,6 +96,7 @@ async def async_pipeline_from_audio_stream( event_callback=event_callback, tts_audio_output=tts_audio_output, wake_word_settings=wake_word_settings, + audio_settings=audio_settings or AudioSettings(), ), ) await pipeline_input.validate() diff --git a/homeassistant/components/assist_pipeline/manifest.json b/homeassistant/components/assist_pipeline/manifest.json index 1db415b29d2..1034d1b5f62 100644 --- a/homeassistant/components/assist_pipeline/manifest.json +++ b/homeassistant/components/assist_pipeline/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://www.home-assistant.io/integrations/assist_pipeline", "iot_class": "local_push", "quality_scale": "internal", - "requirements": ["webrtcvad==2.0.10"] + "requirements": ["webrtc-noise-gain==1.1.0"] } diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index e3b0eafda20..89bb9736737 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1,7 +1,9 @@ """Classes for voice assistant pipelines.""" from __future__ import annotations +import array import asyncio +from collections import deque from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable from dataclasses import asdict, dataclass, field from enum import StrEnum @@ -10,10 +12,11 @@ from pathlib import Path from queue import Queue from threading import Thread import time -from typing import Any, cast +from typing import Any, Final, cast import wave import voluptuous as vol +from webrtc_noise_gain import AudioProcessor from homeassistant.components import ( conversation, @@ -54,8 +57,7 @@ from .error import ( WakeWordDetectionError, WakeWordTimeoutError, ) -from .ring_buffer import RingBuffer -from .vad import VoiceActivityTimeout, VoiceCommandSegmenter +from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples _LOGGER = logging.getLogger(__name__) @@ -95,6 +97,9 @@ STORED_PIPELINE_RUNS = 10 SAVE_DELAY = 10 +AUDIO_PROCESSOR_SAMPLES: Final = 160 # 10 ms @ 16 Khz +AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples + async def _async_resolve_default_pipeline_settings( hass: HomeAssistant, @@ -393,6 +398,60 @@ class WakeWordSettings: """Seconds of audio to buffer before detection and forward to STT.""" +@dataclass(frozen=True) +class AudioSettings: + """Settings for pipeline audio processing.""" + + noise_suppression_level: int = 0 + """Level of noise suppression (0 = disabled, 4 = max)""" + + auto_gain_dbfs: int = 0 + """Amount of automatic gain in dbFS (0 = disabled, 31 = max)""" + + volume_multiplier: float = 1.0 + """Multiplier used directly on PCM samples (1.0 = no change, 2.0 = twice as loud)""" + + is_vad_enabled: bool = True + """True if VAD is used to determine the end of the voice command.""" + + is_chunking_enabled: bool = True + """True if audio is automatically split into 10 ms chunks (required for VAD, etc.)""" + + def __post_init__(self) -> None: + """Verify settings post-initialization.""" + if (self.noise_suppression_level < 0) or (self.noise_suppression_level > 4): + raise ValueError("noise_suppression_level must be in [0, 4]") + + if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31): + raise ValueError("auto_gain_dbfs must be in [0, 31]") + + if self.needs_processor and (not self.is_chunking_enabled): + raise ValueError("Chunking must be enabled for audio processing") + + @property + def needs_processor(self) -> bool: + """True if an audio processor is needed.""" + return ( + self.is_vad_enabled + or (self.noise_suppression_level > 0) + or (self.auto_gain_dbfs > 0) + ) + + +@dataclass(frozen=True, slots=True) +class ProcessedAudioChunk: + """Processed audio chunk and metadata.""" + + audio: bytes + """Raw PCM audio @ 16Khz with 16-bit mono samples""" + + timestamp_ms: int + """Timestamp relative to start of audio stream (milliseconds)""" + + is_speech: bool | None + """True if audio chunk likely contains speech, False if not, None if unknown""" + + @dataclass class PipelineRun: """Running context for a pipeline.""" @@ -408,6 +467,7 @@ class PipelineRun: intent_agent: str | None = None tts_audio_output: str | None = None wake_word_settings: WakeWordSettings | None = None + audio_settings: AudioSettings = field(default_factory=AudioSettings) id: str = field(default_factory=ulid_util.ulid) stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False) @@ -422,6 +482,12 @@ class PipelineRun: debug_recording_queue: Queue[str | bytes | None] | None = None """Queue to communicate with debug recording thread""" + audio_processor: AudioProcessor | None = None + """VAD/noise suppression/auto gain""" + + audio_processor_buffer: AudioBuffer = field(init=False) + """Buffer used when splitting audio into chunks for audio processing""" + def __post_init__(self) -> None: """Set language for pipeline.""" self.language = self.pipeline.language or self.hass.config.language @@ -439,6 +505,14 @@ class PipelineRun: ) pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug() + # Initialize with audio settings + self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES) + if self.audio_settings.needs_processor: + self.audio_processor = AudioProcessor( + self.audio_settings.auto_gain_dbfs, + self.audio_settings.noise_suppression_level, + ) + @callback def process_event(self, event: PipelineEvent) -> None: """Log an event and call listener.""" @@ -499,8 +573,8 @@ class PipelineRun: async def wake_word_detection( self, - stream: AsyncIterable[bytes], - audio_chunks_for_stt: list[bytes], + stream: AsyncIterable[ProcessedAudioChunk], + audio_chunks_for_stt: list[ProcessedAudioChunk], ) -> wake_word.DetectionResult | None: """Run wake-word-detection portion of pipeline. Returns detection result.""" metadata_dict = asdict( @@ -541,12 +615,13 @@ class PipelineRun: # Audio chunk buffer. This audio will be forwarded to speech-to-text # after wake-word-detection. - num_audio_bytes_to_buffer = int( - wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz + num_audio_chunks_to_buffer = int( + (wake_word_settings.audio_seconds_to_buffer * 16000) + / AUDIO_PROCESSOR_SAMPLES ) - stt_audio_buffer: RingBuffer | None = None - if num_audio_bytes_to_buffer > 0: - stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer) + stt_audio_buffer: deque[ProcessedAudioChunk] | None = None + if num_audio_chunks_to_buffer > 0: + stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer) try: # Detect wake word(s) @@ -562,7 +637,7 @@ class PipelineRun: if stt_audio_buffer is not None: # All audio kept from right before the wake word was detected as # a single chunk. - audio_chunks_for_stt.append(stt_audio_buffer.getvalue()) + audio_chunks_for_stt.extend(stt_audio_buffer) except WakeWordTimeoutError: _LOGGER.debug("Timeout during wake word detection") raise @@ -586,7 +661,11 @@ class PipelineRun: # speech-to-text so the user does not have to pause before # speaking the voice command. for chunk_ts in result.queued_audio: - audio_chunks_for_stt.append(chunk_ts[0]) + audio_chunks_for_stt.append( + ProcessedAudioChunk( + audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False + ) + ) wake_word_output = asdict(result) @@ -604,8 +683,8 @@ class PipelineRun: async def _wake_word_audio_stream( self, - audio_stream: AsyncIterable[bytes], - stt_audio_buffer: RingBuffer | None, + audio_stream: AsyncIterable[ProcessedAudioChunk], + stt_audio_buffer: deque[ProcessedAudioChunk] | None, wake_word_vad: VoiceActivityTimeout | None, sample_rate: int = 16000, sample_width: int = 2, @@ -615,25 +694,24 @@ class PipelineRun: Adds audio to a ring buffer that will be forwarded to speech-to-text after detection. Times out if VAD detects enough silence. """ - ms_per_sample = sample_rate // 1000 - timestamp_ms = 0 + chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate async for chunk in audio_stream: if self.debug_recording_queue is not None: - self.debug_recording_queue.put_nowait(chunk) + self.debug_recording_queue.put_nowait(chunk.audio) - yield chunk, timestamp_ms - timestamp_ms += (len(chunk) // sample_width) // ms_per_sample + yield chunk.audio, chunk.timestamp_ms # Wake-word-detection occurs *after* the wake word was actually # spoken. Keeping audio right before detection allows the voice # command to be spoken immediately after the wake word. if stt_audio_buffer is not None: - stt_audio_buffer.put(chunk) + stt_audio_buffer.append(chunk) - if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)): - raise WakeWordTimeoutError( - code="wake-word-timeout", message="Wake word was not detected" - ) + if wake_word_vad is not None: + if not wake_word_vad.process(chunk_seconds, chunk.is_speech): + raise WakeWordTimeoutError( + code="wake-word-timeout", message="Wake word was not detected" + ) async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: """Prepare speech-to-text.""" @@ -666,7 +744,7 @@ class PipelineRun: async def speech_to_text( self, metadata: stt.SpeechMetadata, - stream: AsyncIterable[bytes], + stream: AsyncIterable[ProcessedAudioChunk], ) -> str: """Run speech-to-text portion of pipeline. Returns the spoken text.""" if isinstance(self.stt_provider, stt.Provider): @@ -690,11 +768,13 @@ class PipelineRun: try: # Transcribe audio stream + stt_vad: VoiceCommandSegmenter | None = None + if self.audio_settings.is_vad_enabled: + stt_vad = VoiceCommandSegmenter() + result = await self.stt_provider.async_process_audio_stream( metadata, - self._speech_to_text_stream( - audio_stream=stream, stt_vad=VoiceCommandSegmenter() - ), + self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad), ) except Exception as src_error: _LOGGER.exception("Unexpected error during speech-to-text") @@ -731,26 +811,25 @@ class PipelineRun: async def _speech_to_text_stream( self, - audio_stream: AsyncIterable[bytes], + audio_stream: AsyncIterable[ProcessedAudioChunk], stt_vad: VoiceCommandSegmenter | None, sample_rate: int = 16000, sample_width: int = 2, ) -> AsyncGenerator[bytes, None]: """Yield audio chunks until VAD detects silence or speech-to-text completes.""" - ms_per_sample = sample_rate // 1000 + chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate sent_vad_start = False - timestamp_ms = 0 async for chunk in audio_stream: if self.debug_recording_queue is not None: - self.debug_recording_queue.put_nowait(chunk) + self.debug_recording_queue.put_nowait(chunk.audio) if stt_vad is not None: - if not stt_vad.process(chunk): + if not stt_vad.process(chunk_seconds, chunk.is_speech): # Silence detected at the end of voice command self.process_event( PipelineEvent( PipelineEventType.STT_VAD_END, - {"timestamp": timestamp_ms}, + {"timestamp": chunk.timestamp_ms}, ) ) break @@ -760,13 +839,12 @@ class PipelineRun: self.process_event( PipelineEvent( PipelineEventType.STT_VAD_START, - {"timestamp": timestamp_ms}, + {"timestamp": chunk.timestamp_ms}, ) ) sent_vad_start = True - yield chunk - timestamp_ms += (len(chunk) // sample_width) // ms_per_sample + yield chunk.audio async def prepare_recognize_intent(self) -> None: """Prepare recognizing an intent.""" @@ -977,6 +1055,94 @@ class PipelineRun: self.debug_recording_queue = None self.debug_recording_thread = None + async def process_volume_only( + self, + audio_stream: AsyncIterable[bytes], + sample_rate: int = 16000, + sample_width: int = 2, + ) -> AsyncGenerator[ProcessedAudioChunk, None]: + """Apply volume transformation only (no VAD/audio enhancements) with optional chunking.""" + ms_per_sample = sample_rate // 1000 + ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample + timestamp_ms = 0 + + async for chunk in audio_stream: + if self.audio_settings.volume_multiplier != 1.0: + chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier) + + if self.audio_settings.is_chunking_enabled: + # 10 ms chunking + for chunk_10ms in chunk_samples( + chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer + ): + yield ProcessedAudioChunk( + audio=chunk_10ms, + timestamp_ms=timestamp_ms, + is_speech=None, # no VAD + ) + timestamp_ms += ms_per_chunk + else: + # No chunking + yield ProcessedAudioChunk( + audio=chunk, + timestamp_ms=timestamp_ms, + is_speech=None, # no VAD + ) + timestamp_ms += (len(chunk) // sample_width) // ms_per_sample + + async def process_enhance_audio( + self, + audio_stream: AsyncIterable[bytes], + sample_rate: int = 16000, + sample_width: int = 2, + ) -> AsyncGenerator[ProcessedAudioChunk, None]: + """Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation.""" + assert self.audio_processor is not None + + ms_per_sample = sample_rate // 1000 + ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample + timestamp_ms = 0 + + async for dirty_samples in audio_stream: + if self.audio_settings.volume_multiplier != 1.0: + # Static gain + dirty_samples = _multiply_volume( + dirty_samples, self.audio_settings.volume_multiplier + ) + + # Split into 10ms chunks for audio enhancements/VAD + for dirty_10ms_chunk in chunk_samples( + dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer + ): + ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk) + yield ProcessedAudioChunk( + audio=ap_result.audio, + timestamp_ms=timestamp_ms, + is_speech=ap_result.is_speech, + ) + + timestamp_ms += ms_per_chunk + + +def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes: + """Multiplies 16-bit PCM samples by a constant.""" + return array.array( + "h", + [ + int( + # Clamp to signed 16-bit range + max( + -32767, + min( + 32767, + value * volume_multiplier, + ), + ) + ) + for value in array.array("h", chunk) + ], + ).tobytes() + def _pipeline_debug_recording_thread_proc( run_recording_dir: Path, @@ -1042,14 +1208,23 @@ class PipelineInput: """Run pipeline.""" self.run.start(device_id=self.device_id) current_stage: PipelineStage | None = self.run.start_stage - stt_audio_buffer: list[bytes] = [] + stt_audio_buffer: list[ProcessedAudioChunk] = [] + stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None + + if self.stt_stream is not None: + if self.run.audio_settings.needs_processor: + # VAD/noise suppression/auto gain/volume + stt_processed_stream = self.run.process_enhance_audio(self.stt_stream) + else: + # Volume multiplier only + stt_processed_stream = self.run.process_volume_only(self.stt_stream) try: if current_stage == PipelineStage.WAKE_WORD: # wake-word-detection - assert self.stt_stream is not None + assert stt_processed_stream is not None detect_result = await self.run.wake_word_detection( - self.stt_stream, stt_audio_buffer + stt_processed_stream, stt_audio_buffer ) if detect_result is None: # No wake word. Abort the rest of the pipeline. @@ -1062,28 +1237,30 @@ class PipelineInput: intent_input = self.intent_input if current_stage == PipelineStage.STT: assert self.stt_metadata is not None - assert self.stt_stream is not None + assert stt_processed_stream is not None - stt_stream = self.stt_stream + stt_input_stream = stt_processed_stream if stt_audio_buffer: # Send audio in the buffer first to speech-to-text, then move on to stt_stream. # This is basically an async itertools.chain. - async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]: + async def buffer_then_audio_stream() -> AsyncGenerator[ + ProcessedAudioChunk, None + ]: # Buffered audio for chunk in stt_audio_buffer: yield chunk # Streamed audio - assert self.stt_stream is not None - async for chunk in self.stt_stream: + assert stt_processed_stream is not None + async for chunk in stt_processed_stream: yield chunk - stt_stream = buffer_then_audio_stream() + stt_input_stream = buffer_then_audio_stream() intent_input = await self.run.speech_to_text( self.stt_metadata, - stt_stream, + stt_input_stream, ) current_stage = PipelineStage.INTENT diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py index 20a048d5621..30fad1c80d6 100644 --- a/homeassistant/components/assist_pipeline/vad.py +++ b/homeassistant/components/assist_pipeline/vad.py @@ -1,12 +1,13 @@ """Voice activity detection.""" from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Iterable -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import StrEnum -from typing import Final +from typing import Final, cast -import webrtcvad +from webrtc_noise_gain import AudioProcessor _SAMPLE_RATE: Final = 16000 # Hz _SAMPLE_WIDTH: Final = 2 # bytes @@ -32,6 +33,38 @@ class VadSensitivity(StrEnum): return 1.0 +class VoiceActivityDetector(ABC): + """Base class for voice activity detectors (VAD).""" + + @abstractmethod + def is_speech(self, chunk: bytes) -> bool: + """Return True if audio chunk contains speech.""" + + @property + @abstractmethod + def samples_per_chunk(self) -> int | None: + """Return number of samples per chunk or None if chunking is not required.""" + + +class WebRtcVad(VoiceActivityDetector): + """Voice activity detector based on webrtc.""" + + def __init__(self) -> None: + """Initialize webrtcvad.""" + # Just VAD: no noise suppression or auto gain + self._audio_processor = AudioProcessor(0, 0) + + def is_speech(self, chunk: bytes) -> bool: + """Return True if audio chunk contains speech.""" + result = self._audio_processor.Process10ms(chunk) + return cast(bool, result.is_speech) + + @property + def samples_per_chunk(self) -> int | None: + """Return 10 ms.""" + return int(0.01 * _SAMPLE_RATE) # 10 ms + + class AudioBuffer: """Fixed-sized audio buffer with variable internal length.""" @@ -73,13 +106,7 @@ class AudioBuffer: @dataclass class VoiceCommandSegmenter: - """Segments an audio stream into voice commands using webrtcvad.""" - - vad_mode: int = 3 - """Aggressiveness in filtering out non-speech. 3 is the most aggressive.""" - - vad_samples_per_chunk: int = 480 # 30 ms - """Must be 10, 20, or 30 ms at 16Khz.""" + """Segments an audio stream into voice commands.""" speech_seconds: float = 0.3 """Seconds of speech before voice command has started.""" @@ -108,85 +135,85 @@ class VoiceCommandSegmenter: _reset_seconds_left: float = 0.0 """Seconds left before resetting start/stop time counters.""" - _vad: webrtcvad.Vad = None - _leftover_chunk_buffer: AudioBuffer = field(init=False) - _bytes_per_chunk: int = field(init=False) - _seconds_per_chunk: float = field(init=False) - def __post_init__(self) -> None: - """Initialize VAD.""" - self._vad = webrtcvad.Vad(self.vad_mode) - self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH - self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE - self._leftover_chunk_buffer = AudioBuffer( - self.vad_samples_per_chunk * _SAMPLE_WIDTH - ) + """Reset after initialization.""" self.reset() def reset(self) -> None: """Reset all counters and state.""" - self._leftover_chunk_buffer.clear() self._speech_seconds_left = self.speech_seconds self._silence_seconds_left = self.silence_seconds self._timeout_seconds_left = self.timeout_seconds self._reset_seconds_left = self.reset_seconds self.in_command = False - def process(self, samples: bytes) -> bool: - """Process 16-bit 16Khz mono audio samples. + def process(self, chunk_seconds: float, is_speech: bool | None) -> bool: + """Process samples using external VAD. Returns False when command is done. """ - for chunk in chunk_samples( - samples, self._bytes_per_chunk, self._leftover_chunk_buffer - ): - if not self._process_chunk(chunk): - self.reset() - return False - - return True - - @property - def audio_buffer(self) -> bytes: - """Get partial chunk in the audio buffer.""" - return self._leftover_chunk_buffer.bytes() - - def _process_chunk(self, chunk: bytes) -> bool: - """Process a single chunk of 16-bit 16Khz mono audio. - - Returns False when command is done. - """ - is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE) - - self._timeout_seconds_left -= self._seconds_per_chunk + self._timeout_seconds_left -= chunk_seconds if self._timeout_seconds_left <= 0: + self.reset() return False if not self.in_command: if is_speech: self._reset_seconds_left = self.reset_seconds - self._speech_seconds_left -= self._seconds_per_chunk + self._speech_seconds_left -= chunk_seconds if self._speech_seconds_left <= 0: # Inside voice command self.in_command = True else: # Reset if enough silence - self._reset_seconds_left -= self._seconds_per_chunk + self._reset_seconds_left -= chunk_seconds if self._reset_seconds_left <= 0: self._speech_seconds_left = self.speech_seconds elif not is_speech: self._reset_seconds_left = self.reset_seconds - self._silence_seconds_left -= self._seconds_per_chunk + self._silence_seconds_left -= chunk_seconds if self._silence_seconds_left <= 0: + self.reset() return False else: # Reset if enough speech - self._reset_seconds_left -= self._seconds_per_chunk + self._reset_seconds_left -= chunk_seconds if self._reset_seconds_left <= 0: self._silence_seconds_left = self.silence_seconds return True + def process_with_vad( + self, + chunk: bytes, + vad: VoiceActivityDetector, + leftover_chunk_buffer: AudioBuffer | None, + ) -> bool: + """Process an audio chunk using an external VAD. + + A buffer is required if the VAD requires fixed-sized audio chunks (usually the case). + + Returns False when voice command is finished. + """ + if vad.samples_per_chunk is None: + # No chunking + chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE + is_speech = vad.is_speech(chunk) + return self.process(chunk_seconds, is_speech) + + if leftover_chunk_buffer is None: + raise ValueError("leftover_chunk_buffer is required when vad uses chunking") + + # With chunking + seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE + bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH + for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer): + is_speech = vad.is_speech(vad_chunk) + if not self.process(seconds_per_chunk, is_speech): + return False + + return True + @dataclass class VoiceActivityTimeout: @@ -198,73 +225,43 @@ class VoiceActivityTimeout: reset_seconds: float = 0.5 """Seconds of speech before resetting timeout.""" - vad_mode: int = 3 - """Aggressiveness in filtering out non-speech. 3 is the most aggressive.""" - - vad_samples_per_chunk: int = 480 # 30 ms - """Must be 10, 20, or 30 ms at 16Khz.""" - _silence_seconds_left: float = 0.0 """Seconds left before considering voice command as stopped.""" _reset_seconds_left: float = 0.0 """Seconds left before resetting start/stop time counters.""" - _vad: webrtcvad.Vad = None - _leftover_chunk_buffer: AudioBuffer = field(init=False) - _bytes_per_chunk: int = field(init=False) - _seconds_per_chunk: float = field(init=False) - def __post_init__(self) -> None: - """Initialize VAD.""" - self._vad = webrtcvad.Vad(self.vad_mode) - self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH - self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE - self._leftover_chunk_buffer = AudioBuffer( - self.vad_samples_per_chunk * _SAMPLE_WIDTH - ) + """Reset after initialization.""" self.reset() def reset(self) -> None: """Reset all counters and state.""" - self._leftover_chunk_buffer.clear() self._silence_seconds_left = self.silence_seconds self._reset_seconds_left = self.reset_seconds - def process(self, samples: bytes) -> bool: - """Process 16-bit 16Khz mono audio samples. + def process(self, chunk_seconds: float, is_speech: bool | None) -> bool: + """Process samples using external VAD. Returns False when timeout is reached. """ - for chunk in chunk_samples( - samples, self._bytes_per_chunk, self._leftover_chunk_buffer - ): - if not self._process_chunk(chunk): - return False - - return True - - def _process_chunk(self, chunk: bytes) -> bool: - """Process a single chunk of 16-bit 16Khz mono audio. - - Returns False when timeout is reached. - """ - if self._vad.is_speech(chunk, _SAMPLE_RATE): + if is_speech: # Speech - self._reset_seconds_left -= self._seconds_per_chunk + self._reset_seconds_left -= chunk_seconds if self._reset_seconds_left <= 0: # Reset timeout self._silence_seconds_left = self.silence_seconds else: # Silence - self._silence_seconds_left -= self._seconds_per_chunk + self._silence_seconds_left -= chunk_seconds if self._silence_seconds_left <= 0: # Timeout reached + self.reset() return False # Slowly build reset counter back up self._reset_seconds_left = min( - self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk + self.reset_seconds, self._reset_seconds_left + chunk_seconds ) return True diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 6d8fd02a217..bc542b5c32b 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -18,6 +18,7 @@ from homeassistant.util import language as language_util from .const import DOMAIN from .error import PipelineNotFound from .pipeline import ( + AudioSettings, PipelineData, PipelineError, PipelineEvent, @@ -71,6 +72,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: vol.Optional("audio_seconds_to_buffer"): vol.Any( float, int ), + # Audio enhancement + vol.Optional("noise_suppression_level"): int, + vol.Optional("auto_gain_dbfs"): int, + vol.Optional("volume_multiplier"): float, + # Advanced use cases/testing + vol.Optional("no_vad"): bool, + vol.Optional("no_chunking"): bool, } }, extra=vol.ALLOW_EXTRA, @@ -115,6 +123,7 @@ async def websocket_run( handler_id: int | None = None unregister_handler: Callable[[], None] | None = None wake_word_settings: WakeWordSettings | None = None + audio_settings: AudioSettings | None = None # Arguments to PipelineInput input_args: dict[str, Any] = { @@ -124,13 +133,14 @@ async def websocket_run( if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT): # Audio pipeline that will receive audio as binary websocket messages + msg_input = msg["input"] audio_queue: asyncio.Queue[bytes] = asyncio.Queue() - incoming_sample_rate = msg["input"]["sample_rate"] + incoming_sample_rate = msg_input["sample_rate"] if start_stage == PipelineStage.WAKE_WORD: wake_word_settings = WakeWordSettings( timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), - audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0), + audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0), ) async def stt_stream() -> AsyncGenerator[bytes, None]: @@ -166,6 +176,15 @@ async def websocket_run( channel=stt.AudioChannels.CHANNEL_MONO, ) input_args["stt_stream"] = stt_stream() + + # Audio settings + audio_settings = AudioSettings( + noise_suppression_level=msg_input.get("noise_suppression_level", 0), + auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0), + volume_multiplier=msg_input.get("volume_multiplier", 1.0), + is_vad_enabled=not msg_input.get("no_vad", False), + is_chunking_enabled=not msg_input.get("no_chunking", False), + ) elif start_stage == PipelineStage.INTENT: # Input to conversation agent input_args["intent_input"] = msg["input"]["text"] @@ -185,6 +204,7 @@ async def websocket_run( "timeout": timeout, }, wake_word_settings=wake_word_settings, + audio_settings=audio_settings or AudioSettings(), ) pipeline_input = PipelineInput(**input_args) diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index efa62e0e8f4..6ea97268684 100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -29,8 +29,11 @@ from homeassistant.components.assist_pipeline import ( select as pipeline_select, ) from homeassistant.components.assist_pipeline.vad import ( + AudioBuffer, VadSensitivity, + VoiceActivityDetector, VoiceCommandSegmenter, + WebRtcVad, ) from homeassistant.const import __version__ from homeassistant.core import Context, HomeAssistant @@ -225,11 +228,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): try: # Wait for speech before starting pipeline segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds) + vad = WebRtcVad() chunk_buffer: deque[bytes] = deque( maxlen=self.buffered_chunks_before_speech, ) speech_detected = await self._wait_for_speech( segmenter, + vad, chunk_buffer, ) if not speech_detected: @@ -243,6 +248,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): try: async for chunk in self._segment_audio( segmenter, + vad, chunk_buffer, ): yield chunk @@ -306,6 +312,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async def _wait_for_speech( self, segmenter: VoiceCommandSegmenter, + vad: VoiceActivityDetector, chunk_buffer: MutableSequence[bytes], ): """Buffer audio chunks until speech is detected. @@ -317,12 +324,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async with asyncio.timeout(self.audio_timeout): chunk = await self._audio_queue.get() + assert vad.samples_per_chunk is not None + vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH) + while chunk: chunk_buffer.append(chunk) - segmenter.process(chunk) + segmenter.process_with_vad(chunk, vad, vad_buffer) if segmenter.in_command: # Buffer until command starts + if len(vad_buffer) > 0: + chunk_buffer.append(vad_buffer.bytes()) + return True async with asyncio.timeout(self.audio_timeout): @@ -333,6 +346,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async def _segment_audio( self, segmenter: VoiceCommandSegmenter, + vad: VoiceActivityDetector, chunk_buffer: Sequence[bytes], ) -> AsyncIterable[bytes]: """Yield audio chunks until voice command has finished.""" @@ -345,8 +359,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async with asyncio.timeout(self.audio_timeout): chunk = await self._audio_queue.get() + assert vad.samples_per_chunk is not None + vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH) + while chunk: - if not segmenter.process(chunk): + if not segmenter.process_with_vad(chunk, vad, vad_buffer): # Voice command is finished break diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 893ac1bc26a..e6c019092f7 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -51,7 +51,7 @@ typing-extensions>=4.8.0,<5.0 ulid-transform==0.8.1 voluptuous-serialize==2.6.0 voluptuous==0.13.1 -webrtcvad==2.0.10 +webrtc-noise-gain==1.1.0 yarl==1.9.2 zeroconf==0.114.0 diff --git a/requirements_all.txt b/requirements_all.txt index 506cc9a5f96..738794f5a84 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2691,7 +2691,7 @@ waterfurnace==1.1.0 webexteamssdk==1.1.1 # homeassistant.components.assist_pipeline -webrtcvad==2.0.10 +webrtc-noise-gain==1.1.0 # homeassistant.components.whirlpool whirlpool-sixth-sense==0.18.4 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 62eade4c102..c93ff021492 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1994,7 +1994,7 @@ wallbox==0.4.12 watchdog==2.3.1 # homeassistant.components.assist_pipeline -webrtcvad==2.0.10 +webrtc-noise-gain==1.1.0 # homeassistant.components.whirlpool whirlpool-sixth-sense==0.18.4 diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index f80f294c09d..f36a334d97d 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -311,18 +311,6 @@ }), 'type': , }), - dict({ - 'data': dict({ - 'timestamp': 0, - }), - 'type': , - }), - dict({ - 'data': dict({ - 'timestamp': 1500, - }), - 'type': , - }), dict({ 'data': dict({ 'stt_output': dict({ diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index e8eb573b374..dd88997262f 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -173,6 +173,87 @@ 'message': 'No wake-word-detection provider for: wake_word.bad-entity-id', }) # --- +# name: test_audio_pipeline_with_enhancements + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 30, + }), + }) +# --- +# name: test_audio_pipeline_with_enhancements.1 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_enhancements.2 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_audio_pipeline_with_enhancements.3 + dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + 'language': 'en', + }) +# --- +# name: test_audio_pipeline_with_enhancements.4 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_audio_pipeline_with_enhancements.5 + dict({ + 'engine': 'test', + 'language': 'en-US', + 'tts_input': "Sorry, I couldn't understand that", + 'voice': 'james_earl_jones', + }) +# --- +# name: test_audio_pipeline_with_enhancements.6 + dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', + }), + }) +# --- +# name: test_audio_pipeline_with_enhancements.7 + None +# --- # name: test_audio_pipeline_with_wake_word dict({ 'language': 'en', diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 1a7362aab80..b41e23d7a0d 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -64,6 +64,9 @@ async def test_pipeline_from_audio_stream_auto( channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=audio_data(), + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), ) assert process_events(events) == snapshot @@ -126,6 +129,9 @@ async def test_pipeline_from_audio_stream_legacy( ), stt_stream=audio_data(), pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), ) assert process_events(events) == snapshot @@ -188,6 +194,9 @@ async def test_pipeline_from_audio_stream_entity( ), stt_stream=audio_data(), pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), ) assert process_events(events) == snapshot @@ -251,6 +260,9 @@ async def test_pipeline_from_audio_stream_no_stt( ), stt_stream=audio_data(), pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), ) assert not events @@ -312,44 +324,47 @@ async def test_pipeline_from_audio_stream_wake_word( # [0, 2, ...] wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND)) + bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND) + async def audio_data(): - yield wake_chunk_1 # 1 second - yield wake_chunk_2 # 1 second + # 1 second in 10 ms chunks + i = 0 + while i < len(wake_chunk_1): + yield wake_chunk_1[i : i + bytes_per_chunk] + i += bytes_per_chunk + + # 1 second in 30 ms chunks + i = 0 + while i < len(wake_chunk_2): + yield wake_chunk_2[i : i + bytes_per_chunk] + i += bytes_per_chunk + yield b"wake word!" yield b"part1" yield b"part2" - yield b"end" yield b"" - def continue_stt(self, chunk): - # Ensure stt_vad_start event is triggered - self.in_command = True - - # Stop on fake end chunk to trigger stt_vad_end - return chunk != b"end" - - with patch( - "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process", - continue_stt, - ): - await assist_pipeline.async_pipeline_from_audio_stream( - hass, - context=Context(), - event_callback=events.append, - stt_metadata=stt.SpeechMetadata( - language="", - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=audio_data(), - start_stage=assist_pipeline.PipelineStage.WAKE_WORD, - wake_word_settings=assist_pipeline.WakeWordSettings( - audio_seconds_to_buffer=1.5 - ), - ) + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + context=Context(), + event_callback=events.append, + stt_metadata=stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_data(), + start_stage=assist_pipeline.PipelineStage.WAKE_WORD, + wake_word_settings=assist_pipeline.WakeWordSettings( + audio_seconds_to_buffer=1.5 + ), + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), + ) assert process_events(events) == snapshot @@ -357,12 +372,14 @@ async def test_pipeline_from_audio_stream_wake_word( # 2. queued audio (from mock wake word entity) # 3. part1 # 4. part2 - assert len(mock_stt_provider.received) == 4 + assert len(mock_stt_provider.received) > 3 - first_chunk = mock_stt_provider.received[0] + first_chunk = bytes( + [c_byte for c in mock_stt_provider.received[:-3] for c_byte in c] + ) assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2 - assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"] + assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"] async def test_pipeline_save_audio( @@ -410,6 +427,9 @@ async def test_pipeline_save_audio( pipeline_id=pipeline.id, start_stage=assist_pipeline.PipelineStage.WAKE_WORD, end_stage=assist_pipeline.PipelineStage.STT, + audio_settings=assist_pipeline.AudioSettings( + is_vad_enabled=False, is_chunking_enabled=False + ), ) pipeline_dirs = list(temp_dir.iterdir()) diff --git a/tests/components/assist_pipeline/test_vad.py b/tests/components/assist_pipeline/test_vad.py index 4dc8c8f6197..57b567c49df 100644 --- a/tests/components/assist_pipeline/test_vad.py +++ b/tests/components/assist_pipeline/test_vad.py @@ -1,14 +1,15 @@ -"""Tests for webrtcvad voice command segmenter.""" +"""Tests for voice command segmenter.""" import itertools as it from unittest.mock import patch from homeassistant.components.assist_pipeline.vad import ( AudioBuffer, + VoiceActivityDetector, VoiceCommandSegmenter, chunk_samples, ) -_ONE_SECOND = 16000 * 2 # 16Khz 16-bit +_ONE_SECOND = 1.0 def test_silence() -> None: @@ -16,87 +17,85 @@ def test_silence() -> None: segmenter = VoiceCommandSegmenter() # True return value indicates voice command has not finished - assert segmenter.process(bytes(_ONE_SECOND * 3)) + assert segmenter.process(_ONE_SECOND * 3, False) def test_speech() -> None: """Test that silence + speech + silence triggers a voice command.""" - def is_speech(self, chunk, sample_rate): + def is_speech(chunk): """Anything non-zero is speech.""" return sum(chunk) > 0 - with patch( - "webrtcvad.Vad.is_speech", - new=is_speech, - ): - segmenter = VoiceCommandSegmenter() + segmenter = VoiceCommandSegmenter() - # silence - assert segmenter.process(bytes(_ONE_SECOND)) + # silence + assert segmenter.process(_ONE_SECOND, False) - # "speech" - assert segmenter.process(bytes([255] * _ONE_SECOND)) + # "speech" + assert segmenter.process(_ONE_SECOND, True) - # silence - # False return value indicates voice command is finished - assert not segmenter.process(bytes(_ONE_SECOND)) + # silence + # False return value indicates voice command is finished + assert not segmenter.process(_ONE_SECOND, False) def test_audio_buffer() -> None: """Test audio buffer wrapping.""" - def is_speech(self, chunk, sample_rate): - """Disable VAD.""" - return False + class DisabledVad(VoiceActivityDetector): + def is_speech(self, chunk): + return False - with patch( - "webrtcvad.Vad.is_speech", - new=is_speech, - ): - segmenter = VoiceCommandSegmenter() - bytes_per_chunk = segmenter.vad_samples_per_chunk * 2 + @property + def samples_per_chunk(self): + return 160 # 10 ms - with patch.object( - segmenter, "_process_chunk", return_value=True - ) as mock_process: - # Partially fill audio buffer - half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2)) - segmenter.process(half_chunk) + vad = DisabledVad() + bytes_per_chunk = vad.samples_per_chunk * 2 + vad_buffer = AudioBuffer(bytes_per_chunk) + segmenter = VoiceCommandSegmenter() - assert not mock_process.called - assert segmenter.audio_buffer == half_chunk + with patch.object(vad, "is_speech", return_value=False) as mock_process: + # Partially fill audio buffer + half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2)) + segmenter.process_with_vad(half_chunk, vad, vad_buffer) - # Fill and wrap with 1/4 chunk left over - three_quarters_chunk = bytes( - it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk)) - ) - segmenter.process(three_quarters_chunk) + assert not mock_process.called + assert vad_buffer is not None + assert vad_buffer.bytes() == half_chunk - assert mock_process.call_count == 1 - assert ( - segmenter.audio_buffer - == three_quarters_chunk[ - len(three_quarters_chunk) - (bytes_per_chunk // 4) : - ] - ) - assert ( - mock_process.call_args[0][0] - == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2] - ) + # Fill and wrap with 1/4 chunk left over + three_quarters_chunk = bytes( + it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk)) + ) + segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer) - # Run 2 chunks through - segmenter.reset() - assert len(segmenter.audio_buffer) == 0 + assert mock_process.call_count == 1 + assert ( + vad_buffer.bytes() + == three_quarters_chunk[ + len(three_quarters_chunk) - (bytes_per_chunk // 4) : + ] + ) + assert ( + mock_process.call_args[0][0] + == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2] + ) - mock_process.reset_mock() - two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2)) - segmenter.process(two_chunks) + # Run 2 chunks through + segmenter.reset() + vad_buffer.clear() + assert len(vad_buffer) == 0 - assert mock_process.call_count == 2 - assert len(segmenter.audio_buffer) == 0 - assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk] - assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:] + mock_process.reset_mock() + two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2)) + segmenter.process_with_vad(two_chunks, vad, vad_buffer) + + assert mock_process.call_count == 2 + assert len(vad_buffer) == 0 + assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk] + assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:] def test_partial_chunk() -> None: @@ -125,3 +124,43 @@ def test_chunk_samples_leftover() -> None: assert len(chunks) == 1 assert leftover_chunk_buffer.bytes() == bytes([5, 6]) + + +def test_vad_no_chunking() -> None: + """Test VAD that doesn't require chunking.""" + + class VadNoChunk(VoiceActivityDetector): + def is_speech(self, chunk: bytes) -> bool: + return sum(chunk) > 0 + + @property + def samples_per_chunk(self) -> int | None: + return None + + vad = VadNoChunk() + segmenter = VoiceCommandSegmenter( + speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5 + ) + silence = bytes([0] * 16000) + speech = bytes([255] * (16000 // 2)) + + # Test with differently-sized chunks + assert vad.is_speech(speech) + assert not vad.is_speech(silence) + + # Simulate voice command + assert segmenter.process_with_vad(silence, vad, None) + # begin + assert segmenter.process_with_vad(speech, vad, None) + assert segmenter.process_with_vad(speech, vad, None) + assert segmenter.process_with_vad(speech, vad, None) + # reset with silence + assert segmenter.process_with_vad(silence, vad, None) + # resume + assert segmenter.process_with_vad(speech, vad, None) + assert segmenter.process_with_vad(speech, vad, None) + assert segmenter.process_with_vad(speech, vad, None) + assert segmenter.process_with_vad(speech, vad, None) + # end + assert segmenter.process_with_vad(silence, vad, None) + assert not segmenter.process_with_vad(silence, vad, None) diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index e3561e77852..76ec88b009b 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -107,6 +107,7 @@ async def test_audio_pipeline( assert msg["event"]["type"] == "run-start" msg["event"]["data"]["pipeline"] = ANY assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] events.append(msg["event"]) # stt @@ -116,7 +117,7 @@ async def test_audio_pipeline( events.append(msg["event"]) # End of audio stream (handler id + empty payload) - await client.send_bytes(bytes([1])) + await client.send_bytes(bytes([handler_id])) msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" @@ -240,6 +241,8 @@ async def test_audio_pipeline_with_wake_word_no_timeout( "input": { "sample_rate": 16000, "timeout": 0, + "no_vad": True, + "no_chunking": True, }, } ) @@ -253,6 +256,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout( assert msg["event"]["type"] == "run-start" msg["event"]["data"]["pipeline"] = ANY assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] events.append(msg["event"]) # wake_word @@ -276,7 +280,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout( events.append(msg["event"]) # End of audio stream (handler id + empty payload) - await client.send_bytes(bytes([1])) + await client.send_bytes(bytes([handler_id])) msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" @@ -731,6 +735,7 @@ async def test_stt_stream_failed( assert msg["event"]["type"] == "run-start" msg["event"]["data"]["pipeline"] = ANY assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] events.append(msg["event"]) # stt @@ -740,7 +745,7 @@ async def test_stt_stream_failed( events.append(msg["event"]) # End of audio stream (handler id + empty payload) - await client.send_bytes(b"1") + await client.send_bytes(bytes([handler_id])) # stt error msg = await client.receive_json() @@ -1489,6 +1494,7 @@ async def test_audio_pipeline_debug( assert msg["event"]["type"] == "run-start" msg["event"]["data"]["pipeline"] = ANY assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] events.append(msg["event"]) # stt @@ -1498,7 +1504,7 @@ async def test_audio_pipeline_debug( events.append(msg["event"]) # End of audio stream (handler id + empty payload) - await client.send_bytes(bytes([1])) + await client.send_bytes(bytes([handler_id])) msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" @@ -1699,3 +1705,103 @@ async def test_list_pipeline_languages_with_aliases( msg = await client.receive_json() assert msg["success"] assert msg["result"] == {"languages": ["he", "nb"]} + + +async def test_audio_pipeline_with_enhancements( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test events from a pipeline run with audio input/output.""" + events = [] + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + # Enhancements + "noise_suppression_level": 2, + "auto_gain_dbfs": 15, + "volume_multiplier": 2.0, + }, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + events.append(msg["event"]) + + # stt + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # One second of silence. + # This will pass through the audio enhancement pipeline, but we don't test + # the actual output. + await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2)) + + # End of audio stream (handler id + empty payload) + await client.send_bytes(bytes([handler_id])) + + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # intent + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # text-to-speech + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # run end + msg = await client.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index 361e4e7f0e2..f82a00087c6 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -21,7 +21,7 @@ async def test_pipeline( """Test that pipeline function is called from RTP protocol.""" assert await async_setup_component(hass, "voip", {}) - def is_speech(self, chunk, sample_rate): + def is_speech(self, chunk): """Anything non-zero is speech.""" return sum(chunk) > 0 @@ -76,7 +76,7 @@ async def test_pipeline( return ("mp3", b"") with patch( - "webrtcvad.Vad.is_speech", + "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", new=is_speech, ), patch( "homeassistant.components.voip.voip.async_pipeline_from_audio_stream", @@ -210,7 +210,7 @@ async def test_tts_timeout( """Test that TTS will time out based on its length.""" assert await async_setup_component(hass, "voip", {}) - def is_speech(self, chunk, sample_rate): + def is_speech(self, chunk): """Anything non-zero is speech.""" return sum(chunk) > 0 @@ -269,7 +269,7 @@ async def test_tts_timeout( return ("raw", bytes(0)) with patch( - "webrtcvad.Vad.is_speech", + "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", new=is_speech, ), patch( "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",