Use webrtc-noise-gain for audio enhancement in Assist pipelines (#100698)
* Use webrtc-noise-gain instead of webrtcvad package
* Switching to ProcessedAudioChunk
* Refactor VAD and fix tests
* Add vad no chunking test
* Add test that runs audio enhancements
parent a4f7f3ba7e
commit 785618909a
15 changed files with 707 additions and 258 deletions
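As context for the diffs below, here is a minimal sketch (not part of the commit) of how a caller opts into the new audio enhancements. It only uses the AudioSettings fields and the async_pipeline_from_audio_stream() parameter introduced in this change; the wrapper function name and its arguments are hypothetical:

    from homeassistant.components.assist_pipeline import (
        AudioSettings,
        async_pipeline_from_audio_stream,
    )

    async def run_enhanced_pipeline(hass, context, event_callback, stt_metadata, stt_stream):
        """Sketch: run a pipeline with noise suppression, auto gain, and static volume."""
        # Validation in AudioSettings.__post_init__ rejects out-of-range values.
        settings = AudioSettings(
            noise_suppression_level=2,  # 0 (off) to 4 (max)
            auto_gain_dbfs=15,          # 0 (off) to 31 (max)
            volume_multiplier=2.0,      # applied directly to PCM samples
        )
        await async_pipeline_from_audio_stream(
            hass,
            context=context,
            event_callback=event_callback,
            stt_metadata=stt_metadata,
            stt_stream=stt_stream,
            audio_settings=settings,
        )
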
@@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
 from .const import DATA_CONFIG, DOMAIN
 from .error import PipelineNotFound
 from .pipeline import (
+    AudioSettings,
     Pipeline,
     PipelineEvent,
     PipelineEventCallback,
@@ -33,6 +34,7 @@ __all__ = (
     "async_get_pipelines",
     "async_setup",
     "async_pipeline_from_audio_stream",
+    "AudioSettings",
     "Pipeline",
     "PipelineEvent",
     "PipelineEventType",
@@ -71,6 +73,7 @@ async def async_pipeline_from_audio_stream(
     conversation_id: str | None = None,
     tts_audio_output: str | None = None,
     wake_word_settings: WakeWordSettings | None = None,
+    audio_settings: AudioSettings | None = None,
     device_id: str | None = None,
     start_stage: PipelineStage = PipelineStage.STT,
     end_stage: PipelineStage = PipelineStage.TTS,
@@ -93,6 +96,7 @@ async def async_pipeline_from_audio_stream(
             event_callback=event_callback,
             tts_audio_output=tts_audio_output,
             wake_word_settings=wake_word_settings,
+            audio_settings=audio_settings or AudioSettings(),
         ),
     )
     await pipeline_input.validate()

@@ -6,5 +6,5 @@
   "documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
   "iot_class": "local_push",
   "quality_scale": "internal",
-  "requirements": ["webrtcvad==2.0.10"]
+  "requirements": ["webrtc-noise-gain==1.1.0"]
 }

@@ -1,7 +1,9 @@
 """Classes for voice assistant pipelines."""
 from __future__ import annotations
 
+import array
 import asyncio
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
 from dataclasses import asdict, dataclass, field
 from enum import StrEnum
@@ -10,10 +12,11 @@ from pathlib import Path
 from queue import Queue
 from threading import Thread
 import time
-from typing import Any, cast
+from typing import Any, Final, cast
 import wave
 
 import voluptuous as vol
+from webrtc_noise_gain import AudioProcessor
 
 from homeassistant.components import (
     conversation,
@@ -54,8 +57,7 @@ from .error import (
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
-from .ring_buffer import RingBuffer
-from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
+from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -95,6 +97,9 @@ STORED_PIPELINE_RUNS = 10
 
 SAVE_DELAY = 10
 
+AUDIO_PROCESSOR_SAMPLES: Final = 160  # 10 ms @ 16 Khz
+AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2  # 16-bit samples
+
 
 async def _async_resolve_default_pipeline_settings(
     hass: HomeAssistant,
@@ -393,6 +398,60 @@ class WakeWordSettings:
     """Seconds of audio to buffer before detection and forward to STT."""
 
 
+@dataclass(frozen=True)
+class AudioSettings:
+    """Settings for pipeline audio processing."""
+
+    noise_suppression_level: int = 0
+    """Level of noise suppression (0 = disabled, 4 = max)"""
+
+    auto_gain_dbfs: int = 0
+    """Amount of automatic gain in dbFS (0 = disabled, 31 = max)"""
+
+    volume_multiplier: float = 1.0
+    """Multiplier used directly on PCM samples (1.0 = no change, 2.0 = twice as loud)"""
+
+    is_vad_enabled: bool = True
+    """True if VAD is used to determine the end of the voice command."""
+
+    is_chunking_enabled: bool = True
+    """True if audio is automatically split into 10 ms chunks (required for VAD, etc.)"""
+
+    def __post_init__(self) -> None:
+        """Verify settings post-initialization."""
+        if (self.noise_suppression_level < 0) or (self.noise_suppression_level > 4):
+            raise ValueError("noise_suppression_level must be in [0, 4]")
+
+        if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31):
+            raise ValueError("auto_gain_dbfs must be in [0, 31]")
+
+        if self.needs_processor and (not self.is_chunking_enabled):
+            raise ValueError("Chunking must be enabled for audio processing")
+
+    @property
+    def needs_processor(self) -> bool:
+        """True if an audio processor is needed."""
+        return (
+            self.is_vad_enabled
+            or (self.noise_suppression_level > 0)
+            or (self.auto_gain_dbfs > 0)
+        )
+
+
+@dataclass(frozen=True, slots=True)
+class ProcessedAudioChunk:
+    """Processed audio chunk and metadata."""
+
+    audio: bytes
+    """Raw PCM audio @ 16Khz with 16-bit mono samples"""
+
+    timestamp_ms: int
+    """Timestamp relative to start of audio stream (milliseconds)"""
+
+    is_speech: bool | None
+    """True if audio chunk likely contains speech, False if not, None if unknown"""
+
+
 @dataclass
 class PipelineRun:
     """Running context for a pipeline."""
@@ -408,6 +467,7 @@ class PipelineRun:
     intent_agent: str | None = None
     tts_audio_output: str | None = None
     wake_word_settings: WakeWordSettings | None = None
+    audio_settings: AudioSettings = field(default_factory=AudioSettings)
 
     id: str = field(default_factory=ulid_util.ulid)
     stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
@@ -422,6 +482,12 @@ class PipelineRun:
     debug_recording_queue: Queue[str | bytes | None] | None = None
     """Queue to communicate with debug recording thread"""
 
+    audio_processor: AudioProcessor | None = None
+    """VAD/noise suppression/auto gain"""
+
+    audio_processor_buffer: AudioBuffer = field(init=False)
+    """Buffer used when splitting audio into chunks for audio processing"""
+
     def __post_init__(self) -> None:
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
@@ -439,6 +505,14 @@ class PipelineRun:
         )
         pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
 
+        # Initialize with audio settings
+        self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
+        if self.audio_settings.needs_processor:
+            self.audio_processor = AudioProcessor(
+                self.audio_settings.auto_gain_dbfs,
+                self.audio_settings.noise_suppression_level,
+            )
+
     @callback
     def process_event(self, event: PipelineEvent) -> None:
         """Log an event and call listener."""
@@ -499,8 +573,8 @@ class PipelineRun:
 
     async def wake_word_detection(
         self,
-        stream: AsyncIterable[bytes],
-        audio_chunks_for_stt: list[bytes],
+        stream: AsyncIterable[ProcessedAudioChunk],
+        audio_chunks_for_stt: list[ProcessedAudioChunk],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@@ -541,12 +615,13 @@ class PipelineRun:
 
         # Audio chunk buffer. This audio will be forwarded to speech-to-text
        # after wake-word-detection.
-        num_audio_bytes_to_buffer = int(
-            wake_word_settings.audio_seconds_to_buffer * 16000 * 2  # 16-bit @ 16Khz
+        num_audio_chunks_to_buffer = int(
+            (wake_word_settings.audio_seconds_to_buffer * 16000)
+            / AUDIO_PROCESSOR_SAMPLES
         )
-        stt_audio_buffer: RingBuffer | None = None
-        if num_audio_bytes_to_buffer > 0:
-            stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
+        stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
+        if num_audio_chunks_to_buffer > 0:
+            stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
 
         try:
             # Detect wake word(s)
@@ -562,7 +637,7 @@ class PipelineRun:
             if stt_audio_buffer is not None:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
-                audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
+                audio_chunks_for_stt.extend(stt_audio_buffer)
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -586,7 +661,11 @@ class PipelineRun:
             # speech-to-text so the user does not have to pause before
             # speaking the voice command.
             for chunk_ts in result.queued_audio:
-                audio_chunks_for_stt.append(chunk_ts[0])
+                audio_chunks_for_stt.append(
+                    ProcessedAudioChunk(
+                        audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
+                    )
+                )
 
         wake_word_output = asdict(result)
 
@@ -604,8 +683,8 @@ class PipelineRun:
 
     async def _wake_word_audio_stream(
         self,
-        audio_stream: AsyncIterable[bytes],
-        stt_audio_buffer: RingBuffer | None,
+        audio_stream: AsyncIterable[ProcessedAudioChunk],
+        stt_audio_buffer: deque[ProcessedAudioChunk] | None,
         wake_word_vad: VoiceActivityTimeout | None,
         sample_rate: int = 16000,
         sample_width: int = 2,
@@ -615,25 +694,24 @@ class PipelineRun:
         Adds audio to a ring buffer that will be forwarded to speech-to-text after
         detection. Times out if VAD detects enough silence.
         """
-        ms_per_sample = sample_rate // 1000
-        timestamp_ms = 0
+        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         async for chunk in audio_stream:
             if self.debug_recording_queue is not None:
-                self.debug_recording_queue.put_nowait(chunk)
+                self.debug_recording_queue.put_nowait(chunk.audio)
 
-            yield chunk, timestamp_ms
-            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+            yield chunk.audio, chunk.timestamp_ms
 
             # Wake-word-detection occurs *after* the wake word was actually
             # spoken. Keeping audio right before detection allows the voice
             # command to be spoken immediately after the wake word.
             if stt_audio_buffer is not None:
-                stt_audio_buffer.put(chunk)
+                stt_audio_buffer.append(chunk)
 
-            if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
-                raise WakeWordTimeoutError(
-                    code="wake-word-timeout", message="Wake word was not detected"
-                )
+            if wake_word_vad is not None:
+                if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
+                    raise WakeWordTimeoutError(
+                        code="wake-word-timeout", message="Wake word was not detected"
+                    )
 
     async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
         """Prepare speech-to-text."""
@@ -666,7 +744,7 @@ class PipelineRun:
     async def speech_to_text(
         self,
         metadata: stt.SpeechMetadata,
-        stream: AsyncIterable[bytes],
+        stream: AsyncIterable[ProcessedAudioChunk],
     ) -> str:
         """Run speech-to-text portion of pipeline. Returns the spoken text."""
         if isinstance(self.stt_provider, stt.Provider):
@@ -690,11 +768,13 @@ class PipelineRun:
 
         try:
             # Transcribe audio stream
+            stt_vad: VoiceCommandSegmenter | None = None
+            if self.audio_settings.is_vad_enabled:
+                stt_vad = VoiceCommandSegmenter()
+
             result = await self.stt_provider.async_process_audio_stream(
                 metadata,
-                self._speech_to_text_stream(
-                    audio_stream=stream, stt_vad=VoiceCommandSegmenter()
-                ),
+                self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
             )
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during speech-to-text")
@@ -731,26 +811,25 @@ class PipelineRun:
 
     async def _speech_to_text_stream(
         self,
-        audio_stream: AsyncIterable[bytes],
+        audio_stream: AsyncIterable[ProcessedAudioChunk],
         stt_vad: VoiceCommandSegmenter | None,
         sample_rate: int = 16000,
         sample_width: int = 2,
     ) -> AsyncGenerator[bytes, None]:
         """Yield audio chunks until VAD detects silence or speech-to-text completes."""
-        ms_per_sample = sample_rate // 1000
+        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         sent_vad_start = False
-        timestamp_ms = 0
         async for chunk in audio_stream:
             if self.debug_recording_queue is not None:
-                self.debug_recording_queue.put_nowait(chunk)
+                self.debug_recording_queue.put_nowait(chunk.audio)
 
             if stt_vad is not None:
-                if not stt_vad.process(chunk):
+                if not stt_vad.process(chunk_seconds, chunk.is_speech):
                     # Silence detected at the end of voice command
                     self.process_event(
                         PipelineEvent(
                             PipelineEventType.STT_VAD_END,
-                            {"timestamp": timestamp_ms},
+                            {"timestamp": chunk.timestamp_ms},
                         )
                     )
                     break
@@ -760,13 +839,12 @@ class PipelineRun:
                     self.process_event(
                         PipelineEvent(
                             PipelineEventType.STT_VAD_START,
-                            {"timestamp": timestamp_ms},
+                            {"timestamp": chunk.timestamp_ms},
                         )
                     )
                     sent_vad_start = True
 
-            yield chunk
-            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+            yield chunk.audio
 
     async def prepare_recognize_intent(self) -> None:
         """Prepare recognizing an intent."""
@@ -977,6 +1055,94 @@ class PipelineRun:
         self.debug_recording_queue = None
         self.debug_recording_thread = None
 
+    async def process_volume_only(
+        self,
+        audio_stream: AsyncIterable[bytes],
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+    ) -> AsyncGenerator[ProcessedAudioChunk, None]:
+        """Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
+        ms_per_sample = sample_rate // 1000
+        ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
+        timestamp_ms = 0
+
+        async for chunk in audio_stream:
+            if self.audio_settings.volume_multiplier != 1.0:
+                chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
+
+            if self.audio_settings.is_chunking_enabled:
+                # 10 ms chunking
+                for chunk_10ms in chunk_samples(
+                    chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
+                ):
+                    yield ProcessedAudioChunk(
+                        audio=chunk_10ms,
+                        timestamp_ms=timestamp_ms,
+                        is_speech=None,  # no VAD
+                    )
+                    timestamp_ms += ms_per_chunk
+            else:
+                # No chunking
+                yield ProcessedAudioChunk(
+                    audio=chunk,
+                    timestamp_ms=timestamp_ms,
+                    is_speech=None,  # no VAD
+                )
+                timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
+    async def process_enhance_audio(
+        self,
+        audio_stream: AsyncIterable[bytes],
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+    ) -> AsyncGenerator[ProcessedAudioChunk, None]:
+        """Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
+        assert self.audio_processor is not None
+
+        ms_per_sample = sample_rate // 1000
+        ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
+        timestamp_ms = 0
+
+        async for dirty_samples in audio_stream:
+            if self.audio_settings.volume_multiplier != 1.0:
+                # Static gain
+                dirty_samples = _multiply_volume(
+                    dirty_samples, self.audio_settings.volume_multiplier
+                )
+
+            # Split into 10ms chunks for audio enhancements/VAD
+            for dirty_10ms_chunk in chunk_samples(
+                dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
+            ):
+                ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
+                yield ProcessedAudioChunk(
+                    audio=ap_result.audio,
+                    timestamp_ms=timestamp_ms,
+                    is_speech=ap_result.is_speech,
+                )
+
+                timestamp_ms += ms_per_chunk
+
+
+def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
+    """Multiplies 16-bit PCM samples by a constant."""
+    return array.array(
+        "h",
+        [
+            int(
+                # Clamp to signed 16-bit range
+                max(
+                    -32767,
+                    min(
+                        32767,
+                        value * volume_multiplier,
+                    ),
+                )
+            )
+            for value in array.array("h", chunk)
+        ],
+    ).tobytes()
+
+
 def _pipeline_debug_recording_thread_proc(
     run_recording_dir: Path,
@@ -1042,14 +1208,23 @@ class PipelineInput:
         """Run pipeline."""
         self.run.start(device_id=self.device_id)
         current_stage: PipelineStage | None = self.run.start_stage
-        stt_audio_buffer: list[bytes] = []
+        stt_audio_buffer: list[ProcessedAudioChunk] = []
+        stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
+
+        if self.stt_stream is not None:
+            if self.run.audio_settings.needs_processor:
+                # VAD/noise suppression/auto gain/volume
+                stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
+            else:
+                # Volume multiplier only
+                stt_processed_stream = self.run.process_volume_only(self.stt_stream)
 
         try:
             if current_stage == PipelineStage.WAKE_WORD:
                 # wake-word-detection
-                assert self.stt_stream is not None
+                assert stt_processed_stream is not None
                 detect_result = await self.run.wake_word_detection(
-                    self.stt_stream, stt_audio_buffer
+                    stt_processed_stream, stt_audio_buffer
                 )
                 if detect_result is None:
                     # No wake word. Abort the rest of the pipeline.
@@ -1062,28 +1237,30 @@ class PipelineInput:
             intent_input = self.intent_input
             if current_stage == PipelineStage.STT:
                 assert self.stt_metadata is not None
-                assert self.stt_stream is not None
+                assert stt_processed_stream is not None
 
-                stt_stream = self.stt_stream
+                stt_input_stream = stt_processed_stream
 
                 if stt_audio_buffer:
                     # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
                     # This is basically an async itertools.chain.
-                    async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
+                    async def buffer_then_audio_stream() -> AsyncGenerator[
+                        ProcessedAudioChunk, None
+                    ]:
                         # Buffered audio
                         for chunk in stt_audio_buffer:
                             yield chunk
 
                         # Streamed audio
-                        assert self.stt_stream is not None
-                        async for chunk in self.stt_stream:
+                        assert stt_processed_stream is not None
+                        async for chunk in stt_processed_stream:
                             yield chunk
 
-                    stt_stream = buffer_then_audio_stream()
+                    stt_input_stream = buffer_then_audio_stream()
 
                 intent_input = await self.run.speech_to_text(
                     self.stt_metadata,
-                    stt_stream,
+                    stt_input_stream,
                 )
                 current_stage = PipelineStage.INTENT

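The core of process_enhance_audio() above is the webrtc-noise-gain API: AudioProcessor takes the auto gain (dbFS) and noise suppression level, and Process10ms() consumes exactly 160 samples (320 bytes) of 16-bit mono PCM at 16 kHz, returning the cleaned audio plus a speech flag. A standalone sketch of that loop, with a hypothetical input WAV path:

    import wave

    from webrtc_noise_gain import AudioProcessor

    SAMPLES_10MS = 160             # 10 ms @ 16 kHz
    BYTES_10MS = SAMPLES_10MS * 2  # 16-bit samples

    processor = AudioProcessor(15, 2)  # auto_gain_dbfs=15, noise_suppression_level=2

    with wave.open("speech_16khz_mono.wav", "rb") as wav_file:
        while True:
            chunk = wav_file.readframes(SAMPLES_10MS)
            if len(chunk) < BYTES_10MS:
                break  # partial chunk; the pipeline buffers these (see AudioBuffer below)
            result = processor.Process10ms(chunk)
            # result.audio is the enhanced PCM; result.is_speech feeds the segmenter
            print(result.is_speech, len(result.audio))
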
@@ -1,12 +1,13 @@
 """Voice activity detection."""
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import StrEnum
-from typing import Final
+from typing import Final, cast
 
-import webrtcvad
+from webrtc_noise_gain import AudioProcessor
 
 _SAMPLE_RATE: Final = 16000  # Hz
 _SAMPLE_WIDTH: Final = 2  # bytes
@@ -32,6 +33,38 @@ class VadSensitivity(StrEnum):
         return 1.0
 
 
+class VoiceActivityDetector(ABC):
+    """Base class for voice activity detectors (VAD)."""
+
+    @abstractmethod
+    def is_speech(self, chunk: bytes) -> bool:
+        """Return True if audio chunk contains speech."""
+
+    @property
+    @abstractmethod
+    def samples_per_chunk(self) -> int | None:
+        """Return number of samples per chunk or None if chunking is not required."""
+
+
+class WebRtcVad(VoiceActivityDetector):
+    """Voice activity detector based on webrtc."""
+
+    def __init__(self) -> None:
+        """Initialize webrtcvad."""
+        # Just VAD: no noise suppression or auto gain
+        self._audio_processor = AudioProcessor(0, 0)
+
+    def is_speech(self, chunk: bytes) -> bool:
+        """Return True if audio chunk contains speech."""
+        result = self._audio_processor.Process10ms(chunk)
+        return cast(bool, result.is_speech)
+
+    @property
+    def samples_per_chunk(self) -> int | None:
+        """Return 10 ms."""
+        return int(0.01 * _SAMPLE_RATE)  # 10 ms
+
+
 class AudioBuffer:
     """Fixed-sized audio buffer with variable internal length."""
 
@@ -73,13 +106,7 @@ class AudioBuffer:
 
 @dataclass
 class VoiceCommandSegmenter:
-    """Segments an audio stream into voice commands using webrtcvad."""
-
-    vad_mode: int = 3
-    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
-
-    vad_samples_per_chunk: int = 480  # 30 ms
-    """Must be 10, 20, or 30 ms at 16Khz."""
+    """Segments an audio stream into voice commands."""
 
     speech_seconds: float = 0.3
     """Seconds of speech before voice command has started."""
@@ -108,85 +135,85 @@ class VoiceCommandSegmenter:
     _reset_seconds_left: float = 0.0
     """Seconds left before resetting start/stop time counters."""
 
-    _vad: webrtcvad.Vad = None
-    _leftover_chunk_buffer: AudioBuffer = field(init=False)
-    _bytes_per_chunk: int = field(init=False)
-    _seconds_per_chunk: float = field(init=False)
-
     def __post_init__(self) -> None:
-        """Initialize VAD."""
-        self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
-        self._leftover_chunk_buffer = AudioBuffer(
-            self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        )
+        """Reset after initialization."""
         self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._leftover_chunk_buffer.clear()
         self._speech_seconds_left = self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
         self._reset_seconds_left = self.reset_seconds
         self.in_command = False
 
-    def process(self, samples: bytes) -> bool:
-        """Process 16-bit 16Khz mono audio samples.
+    def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
+        """Process samples using external VAD.
 
         Returns False when command is done.
         """
-        for chunk in chunk_samples(
-            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
-        ):
-            if not self._process_chunk(chunk):
-                self.reset()
-                return False
-
-        return True
-
-    @property
-    def audio_buffer(self) -> bytes:
-        """Get partial chunk in the audio buffer."""
-        return self._leftover_chunk_buffer.bytes()
-
-    def _process_chunk(self, chunk: bytes) -> bool:
-        """Process a single chunk of 16-bit 16Khz mono audio.
-
-        Returns False when command is done.
-        """
-        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
-
-        self._timeout_seconds_left -= self._seconds_per_chunk
+        self._timeout_seconds_left -= chunk_seconds
         if self._timeout_seconds_left <= 0:
             self.reset()
             return False
 
         if not self.in_command:
             if is_speech:
                 self._reset_seconds_left = self.reset_seconds
-                self._speech_seconds_left -= self._seconds_per_chunk
+                self._speech_seconds_left -= chunk_seconds
                 if self._speech_seconds_left <= 0:
                     # Inside voice command
                     self.in_command = True
             else:
                 # Reset if enough silence
-                self._reset_seconds_left -= self._seconds_per_chunk
+                self._reset_seconds_left -= chunk_seconds
                 if self._reset_seconds_left <= 0:
                     self._speech_seconds_left = self.speech_seconds
         elif not is_speech:
             self._reset_seconds_left = self.reset_seconds
-            self._silence_seconds_left -= self._seconds_per_chunk
+            self._silence_seconds_left -= chunk_seconds
             if self._silence_seconds_left <= 0:
                 self.reset()
                 return False
         else:
             # Reset if enough speech
-            self._reset_seconds_left -= self._seconds_per_chunk
+            self._reset_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 self._silence_seconds_left = self.silence_seconds
 
         return True
 
+    def process_with_vad(
+        self,
+        chunk: bytes,
+        vad: VoiceActivityDetector,
+        leftover_chunk_buffer: AudioBuffer | None,
+    ) -> bool:
+        """Process an audio chunk using an external VAD.
+
+        A buffer is required if the VAD requires fixed-sized audio chunks (usually the case).
+
+        Returns False when voice command is finished.
+        """
+        if vad.samples_per_chunk is None:
+            # No chunking
+            chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
+            is_speech = vad.is_speech(chunk)
+            return self.process(chunk_seconds, is_speech)
+
+        if leftover_chunk_buffer is None:
+            raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
+
+        # With chunking
+        seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
+        bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
+        for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
+            is_speech = vad.is_speech(vad_chunk)
+            if not self.process(seconds_per_chunk, is_speech):
+                return False
+
+        return True
+
 
 @dataclass
 class VoiceActivityTimeout:
@@ -198,73 +225,43 @@ class VoiceActivityTimeout:
     reset_seconds: float = 0.5
     """Seconds of speech before resetting timeout."""
 
-    vad_mode: int = 3
-    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
-
-    vad_samples_per_chunk: int = 480  # 30 ms
-    """Must be 10, 20, or 30 ms at 16Khz."""
-
     _silence_seconds_left: float = 0.0
     """Seconds left before considering voice command as stopped."""
 
     _reset_seconds_left: float = 0.0
     """Seconds left before resetting start/stop time counters."""
 
-    _vad: webrtcvad.Vad = None
-    _leftover_chunk_buffer: AudioBuffer = field(init=False)
-    _bytes_per_chunk: int = field(init=False)
-    _seconds_per_chunk: float = field(init=False)
-
     def __post_init__(self) -> None:
-        """Initialize VAD."""
-        self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
-        self._leftover_chunk_buffer = AudioBuffer(
-            self.vad_samples_per_chunk * _SAMPLE_WIDTH
-        )
+        """Reset after initialization."""
        self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._leftover_chunk_buffer.clear()
         self._silence_seconds_left = self.silence_seconds
         self._reset_seconds_left = self.reset_seconds
 
-    def process(self, samples: bytes) -> bool:
-        """Process 16-bit 16Khz mono audio samples.
+    def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
+        """Process samples using external VAD.
 
         Returns False when timeout is reached.
         """
-        for chunk in chunk_samples(
-            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
-        ):
-            if not self._process_chunk(chunk):
-                return False
-
-        return True
-
-    def _process_chunk(self, chunk: bytes) -> bool:
-        """Process a single chunk of 16-bit 16Khz mono audio.
-
-        Returns False when timeout is reached.
-        """
-        if self._vad.is_speech(chunk, _SAMPLE_RATE):
+        if is_speech:
             # Speech
-            self._reset_seconds_left -= self._seconds_per_chunk
+            self._reset_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 # Reset timeout
                 self._silence_seconds_left = self.silence_seconds
         else:
             # Silence
-            self._silence_seconds_left -= self._seconds_per_chunk
+            self._silence_seconds_left -= chunk_seconds
             if self._silence_seconds_left <= 0:
                 # Timeout reached
                 self.reset()
                 return False
 
             # Slowly build reset counter back up
             self._reset_seconds_left = min(
-                self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
+                self.reset_seconds, self._reset_seconds_left + chunk_seconds
             )
 
         return True

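The VoiceActivityDetector ABC above decouples the segmenter from any particular VAD library: the segmenter now consumes (chunk_seconds, is_speech) pairs, and process_with_vad() bridges raw byte streams to a detector. A minimal sketch (not part of the commit) of plugging in a custom detector; the EnergyVad class and its threshold are illustrative only:

    from homeassistant.components.assist_pipeline.vad import (
        AudioBuffer,
        VoiceActivityDetector,
        VoiceCommandSegmenter,
    )

    class EnergyVad(VoiceActivityDetector):
        """Toy detector: any non-zero sample counts as speech."""

        def is_speech(self, chunk: bytes) -> bool:
            return sum(chunk) > 0

        @property
        def samples_per_chunk(self) -> int | None:
            return None  # accepts arbitrary chunk sizes, so no buffer needed

    vad = EnergyVad()
    segmenter = VoiceCommandSegmenter()

    # With samples_per_chunk=None no AudioBuffer is required; a chunked VAD
    # (like WebRtcVad) would pass AudioBuffer(vad.samples_per_chunk * 2) instead.
    keep_going = segmenter.process_with_vad(bytes(16000), vad, None)
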
@@ -18,6 +18,7 @@ from homeassistant.util import language as language_util
 from .const import DOMAIN
 from .error import PipelineNotFound
 from .pipeline import (
+    AudioSettings,
     PipelineData,
     PipelineError,
     PipelineEvent,
@@ -71,6 +72,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
                     vol.Optional("audio_seconds_to_buffer"): vol.Any(
                         float, int
                     ),
+                    # Audio enhancement
+                    vol.Optional("noise_suppression_level"): int,
+                    vol.Optional("auto_gain_dbfs"): int,
+                    vol.Optional("volume_multiplier"): float,
+                    # Advanced use cases/testing
+                    vol.Optional("no_vad"): bool,
+                    vol.Optional("no_chunking"): bool,
                 }
             },
             extra=vol.ALLOW_EXTRA,
@@ -115,6 +123,7 @@ async def websocket_run(
     handler_id: int | None = None
     unregister_handler: Callable[[], None] | None = None
     wake_word_settings: WakeWordSettings | None = None
+    audio_settings: AudioSettings | None = None
 
     # Arguments to PipelineInput
     input_args: dict[str, Any] = {
@@ -124,13 +133,14 @@ async def websocket_run(
 
     if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
         # Audio pipeline that will receive audio as binary websocket messages
+        msg_input = msg["input"]
         audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
-        incoming_sample_rate = msg["input"]["sample_rate"]
+        incoming_sample_rate = msg_input["sample_rate"]
 
         if start_stage == PipelineStage.WAKE_WORD:
             wake_word_settings = WakeWordSettings(
                 timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
-                audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
+                audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
             )
 
         async def stt_stream() -> AsyncGenerator[bytes, None]:
@@ -166,6 +176,15 @@ async def websocket_run(
             channel=stt.AudioChannels.CHANNEL_MONO,
         )
         input_args["stt_stream"] = stt_stream()
+
+        # Audio settings
+        audio_settings = AudioSettings(
+            noise_suppression_level=msg_input.get("noise_suppression_level", 0),
+            auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
+            volume_multiplier=msg_input.get("volume_multiplier", 1.0),
+            is_vad_enabled=not msg_input.get("no_vad", False),
+            is_chunking_enabled=not msg_input.get("no_chunking", False),
+        )
     elif start_stage == PipelineStage.INTENT:
         # Input to conversation agent
         input_args["intent_input"] = msg["input"]["text"]
@@ -185,6 +204,7 @@ async def websocket_run(
            "timeout": timeout,
         },
         wake_word_settings=wake_word_settings,
+        audio_settings=audio_settings or AudioSettings(),
     )
 
     pipeline_input = PipelineInput(**input_args)

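With these schema additions, a websocket client can request enhancements per run. This snippet from the new test added later in this commit shows the message shape (no_vad and no_chunking remain available for advanced/testing scenarios):

    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "tts",
            "input": {
                "sample_rate": 16000,
                "noise_suppression_level": 2,
                "auto_gain_dbfs": 15,
                "volume_multiplier": 2.0,
            },
        }
    )
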
@@ -29,8 +29,11 @@ from homeassistant.components.assist_pipeline import (
     select as pipeline_select,
 )
 from homeassistant.components.assist_pipeline.vad import (
+    AudioBuffer,
     VadSensitivity,
+    VoiceActivityDetector,
     VoiceCommandSegmenter,
+    WebRtcVad,
 )
 from homeassistant.const import __version__
 from homeassistant.core import Context, HomeAssistant
@@ -225,11 +228,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         try:
             # Wait for speech before starting pipeline
             segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
+            vad = WebRtcVad()
             chunk_buffer: deque[bytes] = deque(
                 maxlen=self.buffered_chunks_before_speech,
             )
             speech_detected = await self._wait_for_speech(
                 segmenter,
+                vad,
                 chunk_buffer,
             )
             if not speech_detected:
@@ -243,6 +248,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
             try:
                 async for chunk in self._segment_audio(
                     segmenter,
+                    vad,
                     chunk_buffer,
                 ):
                     yield chunk
@@ -306,6 +312,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
     async def _wait_for_speech(
         self,
         segmenter: VoiceCommandSegmenter,
+        vad: VoiceActivityDetector,
         chunk_buffer: MutableSequence[bytes],
     ):
         """Buffer audio chunks until speech is detected.
@@ -317,12 +324,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         async with asyncio.timeout(self.audio_timeout):
             chunk = await self._audio_queue.get()
 
+        assert vad.samples_per_chunk is not None
+        vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
+
         while chunk:
             chunk_buffer.append(chunk)
 
-            segmenter.process(chunk)
+            segmenter.process_with_vad(chunk, vad, vad_buffer)
             if segmenter.in_command:
+                # Buffer until command starts
+                if len(vad_buffer) > 0:
+                    chunk_buffer.append(vad_buffer.bytes())
+
                 return True
 
             async with asyncio.timeout(self.audio_timeout):
@@ -333,6 +346,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
     async def _segment_audio(
         self,
         segmenter: VoiceCommandSegmenter,
+        vad: VoiceActivityDetector,
         chunk_buffer: Sequence[bytes],
     ) -> AsyncIterable[bytes]:
         """Yield audio chunks until voice command has finished."""
@@ -345,8 +359,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         async with asyncio.timeout(self.audio_timeout):
             chunk = await self._audio_queue.get()
 
+        assert vad.samples_per_chunk is not None
+        vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
+
         while chunk:
-            if not segmenter.process(chunk):
+            if not segmenter.process_with_vad(chunk, vad, vad_buffer):
                 # Voice command is finished
                 break

@@ -51,7 +51,7 @@ typing-extensions>=4.8.0,<5.0
 ulid-transform==0.8.1
 voluptuous-serialize==2.6.0
 voluptuous==0.13.1
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0
 yarl==1.9.2
 zeroconf==0.114.0

@@ -2691,7 +2691,7 @@ waterfurnace==1.1.0
 webexteamssdk==1.1.1
 
 # homeassistant.components.assist_pipeline
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0
 
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.4

@@ -1994,7 +1994,7 @@ wallbox==0.4.12
 watchdog==2.3.1
 
 # homeassistant.components.assist_pipeline
-webrtcvad==2.0.10
+webrtc-noise-gain==1.1.0
 
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.4

@@ -311,18 +311,6 @@
       }),
       'type': <PipelineEventType.STT_START: 'stt-start'>,
     }),
-    dict({
-      'data': dict({
-        'timestamp': 0,
-      }),
-      'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
-    }),
-    dict({
-      'data': dict({
-        'timestamp': 1500,
-      }),
-      'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
-    }),
     dict({
       'data': dict({
         'stt_output': dict({

@@ -173,6 +173,87 @@
     'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
   })
 # ---
+# name: test_audio_pipeline_with_enhancements
+  dict({
+    'language': 'en',
+    'pipeline': <ANY>,
+    'runner_data': dict({
+      'stt_binary_handler_id': 1,
+      'timeout': 30,
+    }),
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.1
+  dict({
+    'engine': 'test',
+    'metadata': dict({
+      'bit_rate': 16,
+      'channel': 1,
+      'codec': 'pcm',
+      'format': 'wav',
+      'language': 'en-US',
+      'sample_rate': 16000,
+    }),
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.2
+  dict({
+    'stt_output': dict({
+      'text': 'test transcript',
+    }),
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.3
+  dict({
+    'conversation_id': None,
+    'device_id': None,
+    'engine': 'homeassistant',
+    'intent_input': 'test transcript',
+    'language': 'en',
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.4
+  dict({
+    'intent_output': dict({
+      'conversation_id': None,
+      'response': dict({
+        'card': dict({
+        }),
+        'data': dict({
+          'code': 'no_intent_match',
+        }),
+        'language': 'en',
+        'response_type': 'error',
+        'speech': dict({
+          'plain': dict({
+            'extra_data': None,
+            'speech': "Sorry, I couldn't understand that",
+          }),
+        }),
+      }),
+    }),
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.5
+  dict({
+    'engine': 'test',
+    'language': 'en-US',
+    'tts_input': "Sorry, I couldn't understand that",
+    'voice': 'james_earl_jones',
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.6
+  dict({
+    'tts_output': dict({
+      'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
+      'mime_type': 'audio/mpeg',
+      'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
+    }),
+  })
+# ---
+# name: test_audio_pipeline_with_enhancements.7
+  None
+# ---
 # name: test_audio_pipeline_with_wake_word
   dict({
     'language': 'en',

@@ -64,6 +64,9 @@ async def test_pipeline_from_audio_stream_auto(
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
         stt_stream=audio_data(),
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
     )
 
     assert process_events(events) == snapshot
@@ -126,6 +129,9 @@ async def test_pipeline_from_audio_stream_legacy(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
     )
 
     assert process_events(events) == snapshot
@@ -188,6 +194,9 @@ async def test_pipeline_from_audio_stream_entity(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
     )
 
     assert process_events(events) == snapshot
@@ -251,6 +260,9 @@ async def test_pipeline_from_audio_stream_no_stt(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
     )
 
     assert not events
@@ -312,44 +324,47 @@ async def test_pipeline_from_audio_stream_wake_word(
     # [0, 2, ...]
     wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
 
+    bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
+
     async def audio_data():
-        yield wake_chunk_1  # 1 second
-        yield wake_chunk_2  # 1 second
+        # 1 second in 10 ms chunks
+        i = 0
+        while i < len(wake_chunk_1):
+            yield wake_chunk_1[i : i + bytes_per_chunk]
+            i += bytes_per_chunk
+
+        # 1 second in 30 ms chunks
+        i = 0
+        while i < len(wake_chunk_2):
+            yield wake_chunk_2[i : i + bytes_per_chunk]
+            i += bytes_per_chunk
+
         yield b"wake word!"
         yield b"part1"
         yield b"part2"
-        yield b"end"
         yield b""
 
-    def continue_stt(self, chunk):
-        # Ensure stt_vad_start event is triggered
-        self.in_command = True
-
-        # Stop on fake end chunk to trigger stt_vad_end
-        return chunk != b"end"
-
-    with patch(
-        "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
-        continue_stt,
-    ):
-        await assist_pipeline.async_pipeline_from_audio_stream(
-            hass,
-            context=Context(),
-            event_callback=events.append,
-            stt_metadata=stt.SpeechMetadata(
-                language="",
-                format=stt.AudioFormats.WAV,
-                codec=stt.AudioCodecs.PCM,
-                bit_rate=stt.AudioBitRates.BITRATE_16,
-                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-                channel=stt.AudioChannels.CHANNEL_MONO,
-            ),
-            stt_stream=audio_data(),
-            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
-            wake_word_settings=assist_pipeline.WakeWordSettings(
-                audio_seconds_to_buffer=1.5
-            ),
-        )
+    await assist_pipeline.async_pipeline_from_audio_stream(
+        hass,
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
+            language="",
+            format=stt.AudioFormats.WAV,
+            codec=stt.AudioCodecs.PCM,
+            bit_rate=stt.AudioBitRates.BITRATE_16,
+            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+            channel=stt.AudioChannels.CHANNEL_MONO,
+        ),
+        stt_stream=audio_data(),
+        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+        wake_word_settings=assist_pipeline.WakeWordSettings(
+            audio_seconds_to_buffer=1.5
+        ),
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
+    )
 
     assert process_events(events) == snapshot
 
@@ -357,12 +372,14 @@ async def test_pipeline_from_audio_stream_wake_word(
     # 2. queued audio (from mock wake word entity)
     # 3. part1
     # 4. part2
-    assert len(mock_stt_provider.received) == 4
+    assert len(mock_stt_provider.received) > 3
 
-    first_chunk = mock_stt_provider.received[0]
+    first_chunk = bytes(
+        [c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
+    )
     assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
 
-    assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]
+    assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
 
 
 async def test_pipeline_save_audio(
@@ -410,6 +427,9 @@ async def test_pipeline_save_audio(
         pipeline_id=pipeline.id,
         start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
         end_stage=assist_pipeline.PipelineStage.STT,
+        audio_settings=assist_pipeline.AudioSettings(
+            is_vad_enabled=False, is_chunking_enabled=False
+        ),
     )
 
     pipeline_dirs = list(temp_dir.iterdir())

@@ -1,14 +1,15 @@
-"""Tests for webrtcvad voice command segmenter."""
+"""Tests for voice command segmenter."""
 import itertools as it
 from unittest.mock import patch
 
 from homeassistant.components.assist_pipeline.vad import (
     AudioBuffer,
+    VoiceActivityDetector,
     VoiceCommandSegmenter,
     chunk_samples,
 )
 
-_ONE_SECOND = 16000 * 2  # 16Khz 16-bit
+_ONE_SECOND = 1.0
 
 
 def test_silence() -> None:
@@ -16,87 +17,85 @@ def test_silence() -> None:
     segmenter = VoiceCommandSegmenter()
 
     # True return value indicates voice command has not finished
-    assert segmenter.process(bytes(_ONE_SECOND * 3))
+    assert segmenter.process(_ONE_SECOND * 3, False)
 
 
 def test_speech() -> None:
     """Test that silence + speech + silence triggers a voice command."""
 
-    def is_speech(self, chunk, sample_rate):
+    def is_speech(chunk):
         """Anything non-zero is speech."""
         return sum(chunk) > 0
 
-    with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ):
-        segmenter = VoiceCommandSegmenter()
+    segmenter = VoiceCommandSegmenter()
 
-        # silence
-        assert segmenter.process(bytes(_ONE_SECOND))
+    # silence
+    assert segmenter.process(_ONE_SECOND, False)
 
-        # "speech"
-        assert segmenter.process(bytes([255] * _ONE_SECOND))
+    # "speech"
+    assert segmenter.process(_ONE_SECOND, True)
 
-        # silence
-        # False return value indicates voice command is finished
-        assert not segmenter.process(bytes(_ONE_SECOND))
+    # silence
+    # False return value indicates voice command is finished
+    assert not segmenter.process(_ONE_SECOND, False)
 
 
 def test_audio_buffer() -> None:
     """Test audio buffer wrapping."""
 
-    def is_speech(self, chunk, sample_rate):
-        """Disable VAD."""
-        return False
+    class DisabledVad(VoiceActivityDetector):
+        def is_speech(self, chunk):
+            return False
 
-    with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ):
-        segmenter = VoiceCommandSegmenter()
-        bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
+        @property
+        def samples_per_chunk(self):
+            return 160  # 10 ms
 
-        with patch.object(
-            segmenter, "_process_chunk", return_value=True
-        ) as mock_process:
-            # Partially fill audio buffer
-            half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
-            segmenter.process(half_chunk)
+    vad = DisabledVad()
+    bytes_per_chunk = vad.samples_per_chunk * 2
+    vad_buffer = AudioBuffer(bytes_per_chunk)
+    segmenter = VoiceCommandSegmenter()
 
-            assert not mock_process.called
-            assert segmenter.audio_buffer == half_chunk
+    with patch.object(vad, "is_speech", return_value=False) as mock_process:
+        # Partially fill audio buffer
+        half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
+        segmenter.process_with_vad(half_chunk, vad, vad_buffer)
 
-            # Fill and wrap with 1/4 chunk left over
-            three_quarters_chunk = bytes(
-                it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
-            )
-            segmenter.process(three_quarters_chunk)
+        assert not mock_process.called
+        assert vad_buffer is not None
+        assert vad_buffer.bytes() == half_chunk
 
-            assert mock_process.call_count == 1
-            assert (
-                segmenter.audio_buffer
-                == three_quarters_chunk[
-                    len(three_quarters_chunk) - (bytes_per_chunk // 4) :
-                ]
-            )
-            assert (
-                mock_process.call_args[0][0]
-                == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
-            )
+        # Fill and wrap with 1/4 chunk left over
+        three_quarters_chunk = bytes(
+            it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
+        )
+        segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
 
-            # Run 2 chunks through
-            segmenter.reset()
-            assert len(segmenter.audio_buffer) == 0
+        assert mock_process.call_count == 1
+        assert (
+            vad_buffer.bytes()
+            == three_quarters_chunk[
+                len(three_quarters_chunk) - (bytes_per_chunk // 4) :
+            ]
+        )
+        assert (
+            mock_process.call_args[0][0]
+            == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
+        )
 
-            mock_process.reset_mock()
-            two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
-            segmenter.process(two_chunks)
+        # Run 2 chunks through
+        segmenter.reset()
+        vad_buffer.clear()
+        assert len(vad_buffer) == 0
 
-            assert mock_process.call_count == 2
-            assert len(segmenter.audio_buffer) == 0
-            assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
-            assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
+        mock_process.reset_mock()
+        two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
+        segmenter.process_with_vad(two_chunks, vad, vad_buffer)
+
+        assert mock_process.call_count == 2
+        assert len(vad_buffer) == 0
+        assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
+        assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
 
 
 def test_partial_chunk() -> None:
@@ -125,3 +124,43 @@ def test_chunk_samples_leftover() -> None:
 
     assert len(chunks) == 1
     assert leftover_chunk_buffer.bytes() == bytes([5, 6])
+
+
+def test_vad_no_chunking() -> None:
+    """Test VAD that doesn't require chunking."""
+
+    class VadNoChunk(VoiceActivityDetector):
+        def is_speech(self, chunk: bytes) -> bool:
+            return sum(chunk) > 0
+
+        @property
+        def samples_per_chunk(self) -> int | None:
+            return None
+
+    vad = VadNoChunk()
+    segmenter = VoiceCommandSegmenter(
+        speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
+    )
+    silence = bytes([0] * 16000)
+    speech = bytes([255] * (16000 // 2))
+
+    # Test with differently-sized chunks
+    assert vad.is_speech(speech)
+    assert not vad.is_speech(silence)
+
+    # Simulate voice command
+    assert segmenter.process_with_vad(silence, vad, None)
+    # begin
+    assert segmenter.process_with_vad(speech, vad, None)
+    assert segmenter.process_with_vad(speech, vad, None)
+    assert segmenter.process_with_vad(speech, vad, None)
+    # reset with silence
+    assert segmenter.process_with_vad(silence, vad, None)
+    # resume
+    assert segmenter.process_with_vad(speech, vad, None)
+    assert segmenter.process_with_vad(speech, vad, None)
+    assert segmenter.process_with_vad(speech, vad, None)
+    assert segmenter.process_with_vad(speech, vad, None)
+    # end
+    assert segmenter.process_with_vad(silence, vad, None)
+    assert not segmenter.process_with_vad(silence, vad, None)

@@ -107,6 +107,7 @@ async def test_audio_pipeline(
     assert msg["event"]["type"] == "run-start"
     msg["event"]["data"]["pipeline"] = ANY
     assert msg["event"]["data"] == snapshot
+    handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
     events.append(msg["event"])
 
     # stt
@@ -116,7 +117,7 @@ async def test_audio_pipeline(
     events.append(msg["event"])
 
     # End of audio stream (handler id + empty payload)
-    await client.send_bytes(bytes([1]))
+    await client.send_bytes(bytes([handler_id]))
 
     msg = await client.receive_json()
     assert msg["event"]["type"] == "stt-end"
@@ -240,6 +241,8 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
             "input": {
                 "sample_rate": 16000,
                 "timeout": 0,
+                "no_vad": True,
+                "no_chunking": True,
             },
         }
     )
@@ -253,6 +256,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
     assert msg["event"]["type"] == "run-start"
     msg["event"]["data"]["pipeline"] = ANY
     assert msg["event"]["data"] == snapshot
+    handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
     events.append(msg["event"])
 
     # wake_word
@@ -276,7 +280,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
     events.append(msg["event"])
 
     # End of audio stream (handler id + empty payload)
-    await client.send_bytes(bytes([1]))
+    await client.send_bytes(bytes([handler_id]))
 
     msg = await client.receive_json()
     assert msg["event"]["type"] == "stt-end"
@@ -731,6 +735,7 @@ async def test_stt_stream_failed(
     assert msg["event"]["type"] == "run-start"
     msg["event"]["data"]["pipeline"] = ANY
     assert msg["event"]["data"] == snapshot
+    handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
     events.append(msg["event"])
 
     # stt
@@ -740,7 +745,7 @@ async def test_stt_stream_failed(
     events.append(msg["event"])
 
     # End of audio stream (handler id + empty payload)
-    await client.send_bytes(b"1")
+    await client.send_bytes(bytes([handler_id]))
 
     # stt error
     msg = await client.receive_json()
@@ -1489,6 +1494,7 @@ async def test_audio_pipeline_debug(
     assert msg["event"]["type"] == "run-start"
     msg["event"]["data"]["pipeline"] = ANY
     assert msg["event"]["data"] == snapshot
+    handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
     events.append(msg["event"])
 
     # stt
@@ -1498,7 +1504,7 @@ async def test_audio_pipeline_debug(
     events.append(msg["event"])
 
     # End of audio stream (handler id + empty payload)
-    await client.send_bytes(bytes([1]))
+    await client.send_bytes(bytes([handler_id]))
 
     msg = await client.receive_json()
     assert msg["event"]["type"] == "stt-end"
@@ -1699,3 +1705,103 @@ async def test_list_pipeline_languages_with_aliases(
     msg = await client.receive_json()
     assert msg["success"]
     assert msg["result"] == {"languages": ["he", "nb"]}
+
+
+async def test_audio_pipeline_with_enhancements(
+    hass: HomeAssistant,
+    hass_ws_client: WebSocketGenerator,
+    init_components,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test events from a pipeline run with audio input/output."""
+    events = []
+    client = await hass_ws_client(hass)
+
+    await client.send_json_auto_id(
+        {
+            "type": "assist_pipeline/run",
+            "start_stage": "stt",
+            "end_stage": "tts",
+            "input": {
+                "sample_rate": 16000,
+                # Enhancements
+                "noise_suppression_level": 2,
+                "auto_gain_dbfs": 15,
+                "volume_multiplier": 2.0,
+            },
+        }
+    )
+
+    # result
+    msg = await client.receive_json()
+    assert msg["success"]
+
+    # run start
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "run-start"
+    msg["event"]["data"]["pipeline"] = ANY
+    assert msg["event"]["data"] == snapshot
+    handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
+    events.append(msg["event"])
+
+    # stt
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "stt-start"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    # One second of silence.
+    # This will pass through the audio enhancement pipeline, but we don't test
+    # the actual output.
+    await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2))
+
+    # End of audio stream (handler id + empty payload)
+    await client.send_bytes(bytes([handler_id]))
+
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "stt-end"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    # intent
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "intent-start"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "intent-end"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    # text-to-speech
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "tts-start"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "tts-end"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    # run end
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "run-end"
+    assert msg["event"]["data"] == snapshot
+    events.append(msg["event"])
+
+    pipeline_data: PipelineData = hass.data[DOMAIN]
+    pipeline_id = list(pipeline_data.pipeline_runs)[0]
+    pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
+
+    await client.send_json_auto_id(
+        {
+            "type": "assist_pipeline/pipeline_debug/get",
+            "pipeline_id": pipeline_id,
+            "pipeline_run_id": pipeline_run_id,
+        }
+    )
+    msg = await client.receive_json()
+    assert msg["success"]
+    assert msg["result"] == {"events": events}

@@ -21,7 +21,7 @@ async def test_pipeline(
     """Test that pipeline function is called from RTP protocol."""
     assert await async_setup_component(hass, "voip", {})
 
-    def is_speech(self, chunk, sample_rate):
+    def is_speech(self, chunk):
         """Anything non-zero is speech."""
         return sum(chunk) > 0
 
@@ -76,7 +76,7 @@ async def test_pipeline(
         return ("mp3", b"")
 
     with patch(
-        "webrtcvad.Vad.is_speech",
+        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
         new=is_speech,
     ), patch(
         "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@@ -210,7 +210,7 @@ async def test_tts_timeout(
     """Test that TTS will time out based on its length."""
     assert await async_setup_component(hass, "voip", {})
 
-    def is_speech(self, chunk, sample_rate):
+    def is_speech(self, chunk):
         """Anything non-zero is speech."""
         return sum(chunk) > 0
 
@@ -269,7 +269,7 @@ async def test_tts_timeout(
         return ("raw", bytes(0))
 
     with patch(
-        "webrtcvad.Vad.is_speech",
+        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
        new=is_speech,
     ), patch(
         "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",