Use webrtc-noise-gain for audio enhancement in Assist pipelines (#100698)
* Use webrtc-noise-gain instead of webrtcvad package
* Switching to ProcessedAudioChunk
* Refactor VAD and fix tests
* Add vad no chunking test
* Add test that runs audio enhancements
This commit is contained in:
parent
a4f7f3ba7e
commit
785618909a
15 changed files with 707 additions and 258 deletions
|
@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||||
from .const import DATA_CONFIG, DOMAIN
|
from .const import DATA_CONFIG, DOMAIN
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
|
AudioSettings,
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventCallback,
|
PipelineEventCallback,
|
||||||
|
@ -33,6 +34,7 @@ __all__ = (
|
||||||
"async_get_pipelines",
|
"async_get_pipelines",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_pipeline_from_audio_stream",
|
"async_pipeline_from_audio_stream",
|
||||||
|
"AudioSettings",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
"PipelineEvent",
|
"PipelineEvent",
|
||||||
"PipelineEventType",
|
"PipelineEventType",
|
||||||
|
@ -71,6 +73,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
tts_audio_output: str | None = None,
|
tts_audio_output: str | None = None,
|
||||||
wake_word_settings: WakeWordSettings | None = None,
|
wake_word_settings: WakeWordSettings | None = None,
|
||||||
|
audio_settings: AudioSettings | None = None,
|
||||||
device_id: str | None = None,
|
device_id: str | None = None,
|
||||||
start_stage: PipelineStage = PipelineStage.STT,
|
start_stage: PipelineStage = PipelineStage.STT,
|
||||||
end_stage: PipelineStage = PipelineStage.TTS,
|
end_stage: PipelineStage = PipelineStage.TTS,
|
||||||
|
@ -93,6 +96,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
event_callback=event_callback,
|
event_callback=event_callback,
|
||||||
tts_audio_output=tts_audio_output,
|
tts_audio_output=tts_audio_output,
|
||||||
wake_word_settings=wake_word_settings,
|
wake_word_settings=wake_word_settings,
|
||||||
|
audio_settings=audio_settings or AudioSettings(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
|
@ -6,5 +6,5 @@
|
||||||
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
|
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"quality_scale": "internal",
|
"quality_scale": "internal",
|
||||||
"requirements": ["webrtcvad==2.0.10"]
|
"requirements": ["webrtc-noise-gain==1.1.0"]
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Classes for voice assistant pipelines."""
|
"""Classes for voice assistant pipelines."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import array
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
@ -10,10 +12,11 @@ from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import time
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, Final, cast
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from webrtc_noise_gain import AudioProcessor
|
||||||
|
|
||||||
from homeassistant.components import (
|
from homeassistant.components import (
|
||||||
conversation,
|
conversation,
|
||||||
|
@ -54,8 +57,7 @@ from .error import (
|
||||||
WakeWordDetectionError,
|
WakeWordDetectionError,
|
||||||
WakeWordTimeoutError,
|
WakeWordTimeoutError,
|
||||||
)
|
)
|
||||||
from .ring_buffer import RingBuffer
|
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
|
||||||
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -95,6 +97,9 @@ STORED_PIPELINE_RUNS = 10
|
||||||
|
|
||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
AUDIO_PROCESSOR_SAMPLES: Final = 160 # 10 ms @ 16 Khz
|
||||||
|
AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
|
||||||
|
|
||||||
|
|
||||||
async def _async_resolve_default_pipeline_settings(
|
async def _async_resolve_default_pipeline_settings(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -393,6 +398,60 @@ class WakeWordSettings:
|
||||||
"""Seconds of audio to buffer before detection and forward to STT."""
|
"""Seconds of audio to buffer before detection and forward to STT."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AudioSettings:
    """Immutable audio-processing options for a pipeline run.

    The values are range-checked when the instance is created; see
    ``__post_init__`` for the exact rules.
    """

    # Noise suppression strength: 0 turns it off, 4 is the strongest setting.
    noise_suppression_level: int = 0

    # Automatic gain in dBFS: 0 turns it off, 31 is the maximum.
    auto_gain_dbfs: int = 0

    # Factor applied directly to the PCM samples (1.0 leaves audio untouched,
    # 2.0 is twice as loud).
    volume_multiplier: float = 1.0

    # Whether VAD decides where the voice command ends.
    is_vad_enabled: bool = True

    # Whether audio is automatically split into 10 ms chunks (needed for VAD
    # and the other processor features).
    is_chunking_enabled: bool = True

    @property
    def needs_processor(self) -> bool:
        """Tell whether any feature that requires an audio processor is on."""
        enhancements_on = (self.noise_suppression_level > 0) or (
            self.auto_gain_dbfs > 0
        )
        return self.is_vad_enabled or enhancements_on

    def __post_init__(self) -> None:
        """Reject out-of-range or mutually inconsistent settings.

        Raises:
            ValueError: if a level is outside its allowed range, or if a
                processor is needed while chunking is disabled.
        """
        suppression = self.noise_suppression_level
        if suppression < 0 or suppression > 4:
            raise ValueError("noise_suppression_level must be in [0, 4]")

        gain = self.auto_gain_dbfs
        if gain < 0 or gain > 31:
            raise ValueError("auto_gain_dbfs must be in [0, 31]")

        # The processor only accepts fixed-size chunks, so it cannot be used
        # when chunking has been switched off.
        if not self.is_chunking_enabled and self.needs_processor:
            raise ValueError("Chunking must be enabled for audio processing")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ProcessedAudioChunk:
|
||||||
|
"""Processed audio chunk and metadata."""
|
||||||
|
|
||||||
|
audio: bytes
|
||||||
|
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
|
||||||
|
|
||||||
|
timestamp_ms: int
|
||||||
|
"""Timestamp relative to start of audio stream (milliseconds)"""
|
||||||
|
|
||||||
|
is_speech: bool | None
|
||||||
|
"""True if audio chunk likely contains speech, False if not, None if unknown"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineRun:
|
class PipelineRun:
|
||||||
"""Running context for a pipeline."""
|
"""Running context for a pipeline."""
|
||||||
|
@ -408,6 +467,7 @@ class PipelineRun:
|
||||||
intent_agent: str | None = None
|
intent_agent: str | None = None
|
||||||
tts_audio_output: str | None = None
|
tts_audio_output: str | None = None
|
||||||
wake_word_settings: WakeWordSettings | None = None
|
wake_word_settings: WakeWordSettings | None = None
|
||||||
|
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
||||||
|
|
||||||
id: str = field(default_factory=ulid_util.ulid)
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
||||||
|
@ -422,6 +482,12 @@ class PipelineRun:
|
||||||
debug_recording_queue: Queue[str | bytes | None] | None = None
|
debug_recording_queue: Queue[str | bytes | None] | None = None
|
||||||
"""Queue to communicate with debug recording thread"""
|
"""Queue to communicate with debug recording thread"""
|
||||||
|
|
||||||
|
audio_processor: AudioProcessor | None = None
|
||||||
|
"""VAD/noise suppression/auto gain"""
|
||||||
|
|
||||||
|
audio_processor_buffer: AudioBuffer = field(init=False)
|
||||||
|
"""Buffer used when splitting audio into chunks for audio processing"""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
@ -439,6 +505,14 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
|
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
|
||||||
|
|
||||||
|
# Initialize with audio settings
|
||||||
|
self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
|
||||||
|
if self.audio_settings.needs_processor:
|
||||||
|
self.audio_processor = AudioProcessor(
|
||||||
|
self.audio_settings.auto_gain_dbfs,
|
||||||
|
self.audio_settings.noise_suppression_level,
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def process_event(self, event: PipelineEvent) -> None:
|
def process_event(self, event: PipelineEvent) -> None:
|
||||||
"""Log an event and call listener."""
|
"""Log an event and call listener."""
|
||||||
|
@ -499,8 +573,8 @@ class PipelineRun:
|
||||||
|
|
||||||
async def wake_word_detection(
|
async def wake_word_detection(
|
||||||
self,
|
self,
|
||||||
stream: AsyncIterable[bytes],
|
stream: AsyncIterable[ProcessedAudioChunk],
|
||||||
audio_chunks_for_stt: list[bytes],
|
audio_chunks_for_stt: list[ProcessedAudioChunk],
|
||||||
) -> wake_word.DetectionResult | None:
|
) -> wake_word.DetectionResult | None:
|
||||||
"""Run wake-word-detection portion of pipeline. Returns detection result."""
|
"""Run wake-word-detection portion of pipeline. Returns detection result."""
|
||||||
metadata_dict = asdict(
|
metadata_dict = asdict(
|
||||||
|
@ -541,12 +615,13 @@ class PipelineRun:
|
||||||
|
|
||||||
# Audio chunk buffer. This audio will be forwarded to speech-to-text
|
# Audio chunk buffer. This audio will be forwarded to speech-to-text
|
||||||
# after wake-word-detection.
|
# after wake-word-detection.
|
||||||
num_audio_bytes_to_buffer = int(
|
num_audio_chunks_to_buffer = int(
|
||||||
wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz
|
(wake_word_settings.audio_seconds_to_buffer * 16000)
|
||||||
|
/ AUDIO_PROCESSOR_SAMPLES
|
||||||
)
|
)
|
||||||
stt_audio_buffer: RingBuffer | None = None
|
stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
|
||||||
if num_audio_bytes_to_buffer > 0:
|
if num_audio_chunks_to_buffer > 0:
|
||||||
stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
|
stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Detect wake word(s)
|
# Detect wake word(s)
|
||||||
|
@ -562,7 +637,7 @@ class PipelineRun:
|
||||||
if stt_audio_buffer is not None:
|
if stt_audio_buffer is not None:
|
||||||
# All audio kept from right before the wake word was detected as
|
# All audio kept from right before the wake word was detected as
|
||||||
# a single chunk.
|
# a single chunk.
|
||||||
audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
|
audio_chunks_for_stt.extend(stt_audio_buffer)
|
||||||
except WakeWordTimeoutError:
|
except WakeWordTimeoutError:
|
||||||
_LOGGER.debug("Timeout during wake word detection")
|
_LOGGER.debug("Timeout during wake word detection")
|
||||||
raise
|
raise
|
||||||
|
@ -586,7 +661,11 @@ class PipelineRun:
|
||||||
# speech-to-text so the user does not have to pause before
|
# speech-to-text so the user does not have to pause before
|
||||||
# speaking the voice command.
|
# speaking the voice command.
|
||||||
for chunk_ts in result.queued_audio:
|
for chunk_ts in result.queued_audio:
|
||||||
audio_chunks_for_stt.append(chunk_ts[0])
|
audio_chunks_for_stt.append(
|
||||||
|
ProcessedAudioChunk(
|
||||||
|
audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
wake_word_output = asdict(result)
|
wake_word_output = asdict(result)
|
||||||
|
|
||||||
|
@ -604,8 +683,8 @@ class PipelineRun:
|
||||||
|
|
||||||
async def _wake_word_audio_stream(
|
async def _wake_word_audio_stream(
|
||||||
self,
|
self,
|
||||||
audio_stream: AsyncIterable[bytes],
|
audio_stream: AsyncIterable[ProcessedAudioChunk],
|
||||||
stt_audio_buffer: RingBuffer | None,
|
stt_audio_buffer: deque[ProcessedAudioChunk] | None,
|
||||||
wake_word_vad: VoiceActivityTimeout | None,
|
wake_word_vad: VoiceActivityTimeout | None,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
sample_width: int = 2,
|
sample_width: int = 2,
|
||||||
|
@ -615,22 +694,21 @@ class PipelineRun:
|
||||||
Adds audio to a ring buffer that will be forwarded to speech-to-text after
|
Adds audio to a ring buffer that will be forwarded to speech-to-text after
|
||||||
detection. Times out if VAD detects enough silence.
|
detection. Times out if VAD detects enough silence.
|
||||||
"""
|
"""
|
||||||
ms_per_sample = sample_rate // 1000
|
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
||||||
timestamp_ms = 0
|
|
||||||
async for chunk in audio_stream:
|
async for chunk in audio_stream:
|
||||||
if self.debug_recording_queue is not None:
|
if self.debug_recording_queue is not None:
|
||||||
self.debug_recording_queue.put_nowait(chunk)
|
self.debug_recording_queue.put_nowait(chunk.audio)
|
||||||
|
|
||||||
yield chunk, timestamp_ms
|
yield chunk.audio, chunk.timestamp_ms
|
||||||
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
|
||||||
|
|
||||||
# Wake-word-detection occurs *after* the wake word was actually
|
# Wake-word-detection occurs *after* the wake word was actually
|
||||||
# spoken. Keeping audio right before detection allows the voice
|
# spoken. Keeping audio right before detection allows the voice
|
||||||
# command to be spoken immediately after the wake word.
|
# command to be spoken immediately after the wake word.
|
||||||
if stt_audio_buffer is not None:
|
if stt_audio_buffer is not None:
|
||||||
stt_audio_buffer.put(chunk)
|
stt_audio_buffer.append(chunk)
|
||||||
|
|
||||||
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
|
if wake_word_vad is not None:
|
||||||
|
if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
|
||||||
raise WakeWordTimeoutError(
|
raise WakeWordTimeoutError(
|
||||||
code="wake-word-timeout", message="Wake word was not detected"
|
code="wake-word-timeout", message="Wake word was not detected"
|
||||||
)
|
)
|
||||||
|
@ -666,7 +744,7 @@ class PipelineRun:
|
||||||
async def speech_to_text(
|
async def speech_to_text(
|
||||||
self,
|
self,
|
||||||
metadata: stt.SpeechMetadata,
|
metadata: stt.SpeechMetadata,
|
||||||
stream: AsyncIterable[bytes],
|
stream: AsyncIterable[ProcessedAudioChunk],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
||||||
if isinstance(self.stt_provider, stt.Provider):
|
if isinstance(self.stt_provider, stt.Provider):
|
||||||
|
@ -690,11 +768,13 @@ class PipelineRun:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Transcribe audio stream
|
# Transcribe audio stream
|
||||||
|
stt_vad: VoiceCommandSegmenter | None = None
|
||||||
|
if self.audio_settings.is_vad_enabled:
|
||||||
|
stt_vad = VoiceCommandSegmenter()
|
||||||
|
|
||||||
result = await self.stt_provider.async_process_audio_stream(
|
result = await self.stt_provider.async_process_audio_stream(
|
||||||
metadata,
|
metadata,
|
||||||
self._speech_to_text_stream(
|
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||||
audio_stream=stream, stt_vad=VoiceCommandSegmenter()
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||||
|
@ -731,26 +811,25 @@ class PipelineRun:
|
||||||
|
|
||||||
async def _speech_to_text_stream(
|
async def _speech_to_text_stream(
|
||||||
self,
|
self,
|
||||||
audio_stream: AsyncIterable[bytes],
|
audio_stream: AsyncIterable[ProcessedAudioChunk],
|
||||||
stt_vad: VoiceCommandSegmenter | None,
|
stt_vad: VoiceCommandSegmenter | None,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
sample_width: int = 2,
|
sample_width: int = 2,
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
|
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
|
||||||
ms_per_sample = sample_rate // 1000
|
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
||||||
sent_vad_start = False
|
sent_vad_start = False
|
||||||
timestamp_ms = 0
|
|
||||||
async for chunk in audio_stream:
|
async for chunk in audio_stream:
|
||||||
if self.debug_recording_queue is not None:
|
if self.debug_recording_queue is not None:
|
||||||
self.debug_recording_queue.put_nowait(chunk)
|
self.debug_recording_queue.put_nowait(chunk.audio)
|
||||||
|
|
||||||
if stt_vad is not None:
|
if stt_vad is not None:
|
||||||
if not stt_vad.process(chunk):
|
if not stt_vad.process(chunk_seconds, chunk.is_speech):
|
||||||
# Silence detected at the end of voice command
|
# Silence detected at the end of voice command
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.STT_VAD_END,
|
PipelineEventType.STT_VAD_END,
|
||||||
{"timestamp": timestamp_ms},
|
{"timestamp": chunk.timestamp_ms},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
@ -760,13 +839,12 @@ class PipelineRun:
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.STT_VAD_START,
|
PipelineEventType.STT_VAD_START,
|
||||||
{"timestamp": timestamp_ms},
|
{"timestamp": chunk.timestamp_ms},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sent_vad_start = True
|
sent_vad_start = True
|
||||||
|
|
||||||
yield chunk
|
yield chunk.audio
|
||||||
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
|
||||||
|
|
||||||
async def prepare_recognize_intent(self) -> None:
|
async def prepare_recognize_intent(self) -> None:
|
||||||
"""Prepare recognizing an intent."""
|
"""Prepare recognizing an intent."""
|
||||||
|
@ -977,6 +1055,94 @@ class PipelineRun:
|
||||||
self.debug_recording_queue = None
|
self.debug_recording_queue = None
|
||||||
self.debug_recording_thread = None
|
self.debug_recording_thread = None
|
||||||
|
|
||||||
|
async def process_volume_only(
    self,
    audio_stream: AsyncIterable[bytes],
    sample_rate: int = 16000,
    sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
    """Apply volume transformation only (no VAD/audio enhancements) with optional chunking.

    Args:
        audio_stream: Incoming raw PCM audio chunks.
        sample_rate: Samples per second of the incoming audio.
        sample_width: Bytes per sample of the incoming audio.

    Yields:
        ProcessedAudioChunk objects with ``is_speech=None`` (no VAD is run)
        and a ``timestamp_ms`` that starts at 0 and advances by the duration
        of each yielded chunk.
    """
    ms_per_sample = sample_rate // 1000

    # Duration of one fixed-size chunk. The chunk size is given in bytes, so
    # convert bytes -> samples -> milliseconds. Dividing the *sample* count by
    # sample_width (as was done previously) halved the result: 5 ms per chunk
    # instead of 10 ms with the default arguments.
    ms_per_chunk = (AUDIO_PROCESSOR_BYTES // sample_width) // ms_per_sample
    timestamp_ms = 0

    async for chunk in audio_stream:
        if self.audio_settings.volume_multiplier != 1.0:
            chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)

        if self.audio_settings.is_chunking_enabled:
            # Re-split into fixed-size chunks of AUDIO_PROCESSOR_BYTES.
            # NOTE(review): chunk_samples presumably keeps any trailing partial
            # chunk in audio_processor_buffer for the next call - confirm in vad.py.
            for chunk_10ms in chunk_samples(
                chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
            ):
                yield ProcessedAudioChunk(
                    audio=chunk_10ms,
                    timestamp_ms=timestamp_ms,
                    is_speech=None,  # no VAD
                )
                timestamp_ms += ms_per_chunk
        else:
            # No chunking: pass the chunk through at whatever size it arrived
            # and advance the clock by its actual duration.
            yield ProcessedAudioChunk(
                audio=chunk,
                timestamp_ms=timestamp_ms,
                is_speech=None,  # no VAD
            )
            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
||||||
|
|
||||||
|
async def process_enhance_audio(
    self,
    audio_stream: AsyncIterable[bytes],
    sample_rate: int = 16000,
    sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk, None]:
    """Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation.

    Requires ``self.audio_processor`` to be set.

    Args:
        audio_stream: Incoming raw PCM audio chunks.
        sample_rate: Samples per second of the incoming audio.
        sample_width: Bytes per sample of the incoming audio.

    Yields:
        ProcessedAudioChunk objects carrying the processor's output audio and
        speech flag, with ``timestamp_ms`` starting at 0 and advancing by one
        chunk duration per yielded chunk.
    """
    assert self.audio_processor is not None

    ms_per_sample = sample_rate // 1000

    # Duration of one processor chunk. The chunk size is given in bytes, so
    # convert bytes -> samples -> milliseconds. Dividing the *sample* count by
    # sample_width (as was done previously) halved the result: 5 ms per chunk
    # instead of 10 ms with the default arguments.
    ms_per_chunk = (AUDIO_PROCESSOR_BYTES // sample_width) // ms_per_sample
    timestamp_ms = 0

    async for dirty_samples in audio_stream:
        if self.audio_settings.volume_multiplier != 1.0:
            # Static gain
            dirty_samples = _multiply_volume(
                dirty_samples, self.audio_settings.volume_multiplier
            )

        # Split into 10ms chunks for audio enhancements/VAD.
        # NOTE(review): chunk_samples presumably keeps any trailing partial
        # chunk in audio_processor_buffer for the next call - confirm in vad.py.
        for dirty_10ms_chunk in chunk_samples(
            dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
        ):
            ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
            yield ProcessedAudioChunk(
                audio=ap_result.audio,
                timestamp_ms=timestamp_ms,
                is_speech=ap_result.is_speech,
            )

            timestamp_ms += ms_per_chunk
|
||||||
|
|
||||||
|
|
||||||
|
def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
    """Multiply 16-bit PCM samples by a constant.

    Args:
        chunk: Audio as signed 16-bit samples in native byte order.
        volume_multiplier: Factor applied to every sample.

    Returns:
        The scaled audio as bytes, the same length as ``chunk``. Products
        outside the signed 16-bit range saturate at -32768 / 32767, and
        fractional results are truncated toward zero.
    """
    return array.array(
        "h",
        (
            # Clamp to the full signed 16-bit range before truncating. The
            # lower bound is -32768 (not -32767) so the most negative valid
            # sample can still be produced.
            int(max(-32768, min(32767, value * volume_multiplier)))
            for value in array.array("h", chunk)
        ),
    ).tobytes()
|
||||||
|
|
||||||
|
|
||||||
def _pipeline_debug_recording_thread_proc(
|
def _pipeline_debug_recording_thread_proc(
|
||||||
run_recording_dir: Path,
|
run_recording_dir: Path,
|
||||||
|
@ -1042,14 +1208,23 @@ class PipelineInput:
|
||||||
"""Run pipeline."""
|
"""Run pipeline."""
|
||||||
self.run.start(device_id=self.device_id)
|
self.run.start(device_id=self.device_id)
|
||||||
current_stage: PipelineStage | None = self.run.start_stage
|
current_stage: PipelineStage | None = self.run.start_stage
|
||||||
stt_audio_buffer: list[bytes] = []
|
stt_audio_buffer: list[ProcessedAudioChunk] = []
|
||||||
|
stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
|
||||||
|
|
||||||
|
if self.stt_stream is not None:
|
||||||
|
if self.run.audio_settings.needs_processor:
|
||||||
|
# VAD/noise suppression/auto gain/volume
|
||||||
|
stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
|
||||||
|
else:
|
||||||
|
# Volume multiplier only
|
||||||
|
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if current_stage == PipelineStage.WAKE_WORD:
|
if current_stage == PipelineStage.WAKE_WORD:
|
||||||
# wake-word-detection
|
# wake-word-detection
|
||||||
assert self.stt_stream is not None
|
assert stt_processed_stream is not None
|
||||||
detect_result = await self.run.wake_word_detection(
|
detect_result = await self.run.wake_word_detection(
|
||||||
self.stt_stream, stt_audio_buffer
|
stt_processed_stream, stt_audio_buffer
|
||||||
)
|
)
|
||||||
if detect_result is None:
|
if detect_result is None:
|
||||||
# No wake word. Abort the rest of the pipeline.
|
# No wake word. Abort the rest of the pipeline.
|
||||||
|
@ -1062,28 +1237,30 @@ class PipelineInput:
|
||||||
intent_input = self.intent_input
|
intent_input = self.intent_input
|
||||||
if current_stage == PipelineStage.STT:
|
if current_stage == PipelineStage.STT:
|
||||||
assert self.stt_metadata is not None
|
assert self.stt_metadata is not None
|
||||||
assert self.stt_stream is not None
|
assert stt_processed_stream is not None
|
||||||
|
|
||||||
stt_stream = self.stt_stream
|
stt_input_stream = stt_processed_stream
|
||||||
|
|
||||||
if stt_audio_buffer:
|
if stt_audio_buffer:
|
||||||
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
|
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
|
||||||
# This is basically an async itertools.chain.
|
# This is basically an async itertools.chain.
|
||||||
async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
|
async def buffer_then_audio_stream() -> AsyncGenerator[
|
||||||
|
ProcessedAudioChunk, None
|
||||||
|
]:
|
||||||
# Buffered audio
|
# Buffered audio
|
||||||
for chunk in stt_audio_buffer:
|
for chunk in stt_audio_buffer:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Streamed audio
|
# Streamed audio
|
||||||
assert self.stt_stream is not None
|
assert stt_processed_stream is not None
|
||||||
async for chunk in self.stt_stream:
|
async for chunk in stt_processed_stream:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
stt_stream = buffer_then_audio_stream()
|
stt_input_stream = buffer_then_audio_stream()
|
||||||
|
|
||||||
intent_input = await self.run.speech_to_text(
|
intent_input = await self.run.speech_to_text(
|
||||||
self.stt_metadata,
|
self.stt_metadata,
|
||||||
stt_stream,
|
stt_input_stream,
|
||||||
)
|
)
|
||||||
current_stage = PipelineStage.INTENT
|
current_stage = PipelineStage.INTENT
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
"""Voice activity detection."""
|
"""Voice activity detection."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Final
|
from typing import Final, cast
|
||||||
|
|
||||||
import webrtcvad
|
from webrtc_noise_gain import AudioProcessor
|
||||||
|
|
||||||
_SAMPLE_RATE: Final = 16000 # Hz
|
_SAMPLE_RATE: Final = 16000 # Hz
|
||||||
_SAMPLE_WIDTH: Final = 2 # bytes
|
_SAMPLE_WIDTH: Final = 2 # bytes
|
||||||
|
@ -32,6 +33,38 @@ class VadSensitivity(StrEnum):
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceActivityDetector(ABC):
    """Abstract interface that every voice activity detector (VAD) implements."""

    @property
    @abstractmethod
    def samples_per_chunk(self) -> int | None:
        """Give the chunk size in samples, or None when no chunking is required."""

    @abstractmethod
    def is_speech(self, chunk: bytes) -> bool:
        """Tell whether the given audio chunk contains speech."""
|
||||||
|
|
||||||
|
|
||||||
|
class WebRtcVad(VoiceActivityDetector):
    """Voice activity detector backed by the WebRTC audio processor."""

    def __init__(self) -> None:
        """Create the underlying audio processor."""
        # Both processor arguments are 0: only the speech flag is wanted here,
        # with no noise suppression or auto gain.
        self._audio_processor = AudioProcessor(0, 0)

    @property
    def samples_per_chunk(self) -> int | None:
        """Give the number of samples in 10 ms of audio."""
        # One hundredth of a second at the module sample rate.
        return _SAMPLE_RATE // 100

    def is_speech(self, chunk: bytes) -> bool:
        """Tell whether the given audio chunk contains speech."""
        processed = self._audio_processor.Process10ms(chunk)
        return cast(bool, processed.is_speech)
|
||||||
|
|
||||||
|
|
||||||
class AudioBuffer:
|
class AudioBuffer:
|
||||||
"""Fixed-sized audio buffer with variable internal length."""
|
"""Fixed-sized audio buffer with variable internal length."""
|
||||||
|
|
||||||
|
@ -73,13 +106,7 @@ class AudioBuffer:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceCommandSegmenter:
|
class VoiceCommandSegmenter:
|
||||||
"""Segments an audio stream into voice commands using webrtcvad."""
|
"""Segments an audio stream into voice commands."""
|
||||||
|
|
||||||
vad_mode: int = 3
|
|
||||||
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
|
|
||||||
|
|
||||||
vad_samples_per_chunk: int = 480 # 30 ms
|
|
||||||
"""Must be 10, 20, or 30 ms at 16Khz."""
|
|
||||||
|
|
||||||
speech_seconds: float = 0.3
|
speech_seconds: float = 0.3
|
||||||
"""Seconds of speech before voice command has started."""
|
"""Seconds of speech before voice command has started."""
|
||||||
|
@ -108,85 +135,85 @@ class VoiceCommandSegmenter:
|
||||||
_reset_seconds_left: float = 0.0
|
_reset_seconds_left: float = 0.0
|
||||||
"""Seconds left before resetting start/stop time counters."""
|
"""Seconds left before resetting start/stop time counters."""
|
||||||
|
|
||||||
_vad: webrtcvad.Vad = None
|
|
||||||
_leftover_chunk_buffer: AudioBuffer = field(init=False)
|
|
||||||
_bytes_per_chunk: int = field(init=False)
|
|
||||||
_seconds_per_chunk: float = field(init=False)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Initialize VAD."""
|
"""Reset after initialization."""
|
||||||
self._vad = webrtcvad.Vad(self.vad_mode)
|
|
||||||
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
|
|
||||||
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
|
|
||||||
self._leftover_chunk_buffer = AudioBuffer(
|
|
||||||
self.vad_samples_per_chunk * _SAMPLE_WIDTH
|
|
||||||
)
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset all counters and state."""
|
"""Reset all counters and state."""
|
||||||
self._leftover_chunk_buffer.clear()
|
|
||||||
self._speech_seconds_left = self.speech_seconds
|
self._speech_seconds_left = self.speech_seconds
|
||||||
self._silence_seconds_left = self.silence_seconds
|
self._silence_seconds_left = self.silence_seconds
|
||||||
self._timeout_seconds_left = self.timeout_seconds
|
self._timeout_seconds_left = self.timeout_seconds
|
||||||
self._reset_seconds_left = self.reset_seconds
|
self._reset_seconds_left = self.reset_seconds
|
||||||
self.in_command = False
|
self.in_command = False
|
||||||
|
|
||||||
def process(self, samples: bytes) -> bool:
|
def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
|
||||||
"""Process 16-bit 16Khz mono audio samples.
|
"""Process samples using external VAD.
|
||||||
|
|
||||||
Returns False when command is done.
|
Returns False when command is done.
|
||||||
"""
|
"""
|
||||||
for chunk in chunk_samples(
|
self._timeout_seconds_left -= chunk_seconds
|
||||||
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
|
|
||||||
):
|
|
||||||
if not self._process_chunk(chunk):
|
|
||||||
self.reset()
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def audio_buffer(self) -> bytes:
|
|
||||||
"""Get partial chunk in the audio buffer."""
|
|
||||||
return self._leftover_chunk_buffer.bytes()
|
|
||||||
|
|
||||||
def _process_chunk(self, chunk: bytes) -> bool:
|
|
||||||
"""Process a single chunk of 16-bit 16Khz mono audio.
|
|
||||||
|
|
||||||
Returns False when command is done.
|
|
||||||
"""
|
|
||||||
is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
|
|
||||||
|
|
||||||
self._timeout_seconds_left -= self._seconds_per_chunk
|
|
||||||
if self._timeout_seconds_left <= 0:
|
if self._timeout_seconds_left <= 0:
|
||||||
|
self.reset()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not self.in_command:
|
if not self.in_command:
|
||||||
if is_speech:
|
if is_speech:
|
||||||
self._reset_seconds_left = self.reset_seconds
|
self._reset_seconds_left = self.reset_seconds
|
||||||
self._speech_seconds_left -= self._seconds_per_chunk
|
self._speech_seconds_left -= chunk_seconds
|
||||||
if self._speech_seconds_left <= 0:
|
if self._speech_seconds_left <= 0:
|
||||||
# Inside voice command
|
# Inside voice command
|
||||||
self.in_command = True
|
self.in_command = True
|
||||||
else:
|
else:
|
||||||
# Reset if enough silence
|
# Reset if enough silence
|
||||||
self._reset_seconds_left -= self._seconds_per_chunk
|
self._reset_seconds_left -= chunk_seconds
|
||||||
if self._reset_seconds_left <= 0:
|
if self._reset_seconds_left <= 0:
|
||||||
self._speech_seconds_left = self.speech_seconds
|
self._speech_seconds_left = self.speech_seconds
|
||||||
elif not is_speech:
|
elif not is_speech:
|
||||||
self._reset_seconds_left = self.reset_seconds
|
self._reset_seconds_left = self.reset_seconds
|
||||||
self._silence_seconds_left -= self._seconds_per_chunk
|
self._silence_seconds_left -= chunk_seconds
|
||||||
if self._silence_seconds_left <= 0:
|
if self._silence_seconds_left <= 0:
|
||||||
|
self.reset()
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
# Reset if enough speech
|
# Reset if enough speech
|
||||||
self._reset_seconds_left -= self._seconds_per_chunk
|
self._reset_seconds_left -= chunk_seconds
|
||||||
if self._reset_seconds_left <= 0:
|
if self._reset_seconds_left <= 0:
|
||||||
self._silence_seconds_left = self.silence_seconds
|
self._silence_seconds_left = self.silence_seconds
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def process_with_vad(
|
||||||
|
self,
|
||||||
|
chunk: bytes,
|
||||||
|
vad: VoiceActivityDetector,
|
||||||
|
leftover_chunk_buffer: AudioBuffer | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Process an audio chunk using an external VAD.
|
||||||
|
|
||||||
|
A buffer is required if the VAD requires fixed-sized audio chunks (usually the case).
|
||||||
|
|
||||||
|
Returns False when voice command is finished.
|
||||||
|
"""
|
||||||
|
if vad.samples_per_chunk is None:
|
||||||
|
# No chunking
|
||||||
|
chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
|
||||||
|
is_speech = vad.is_speech(chunk)
|
||||||
|
return self.process(chunk_seconds, is_speech)
|
||||||
|
|
||||||
|
if leftover_chunk_buffer is None:
|
||||||
|
raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
|
||||||
|
|
||||||
|
# With chunking
|
||||||
|
seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
|
||||||
|
bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
|
||||||
|
for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
|
||||||
|
is_speech = vad.is_speech(vad_chunk)
|
||||||
|
if not self.process(seconds_per_chunk, is_speech):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceActivityTimeout:
|
class VoiceActivityTimeout:
|
||||||
|
@ -198,73 +225,43 @@ class VoiceActivityTimeout:
|
||||||
reset_seconds: float = 0.5
|
reset_seconds: float = 0.5
|
||||||
"""Seconds of speech before resetting timeout."""
|
"""Seconds of speech before resetting timeout."""
|
||||||
|
|
||||||
vad_mode: int = 3
|
|
||||||
"""Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
|
|
||||||
|
|
||||||
vad_samples_per_chunk: int = 480 # 30 ms
|
|
||||||
"""Must be 10, 20, or 30 ms at 16Khz."""
|
|
||||||
|
|
||||||
_silence_seconds_left: float = 0.0
|
_silence_seconds_left: float = 0.0
|
||||||
"""Seconds left before considering voice command as stopped."""
|
"""Seconds left before considering voice command as stopped."""
|
||||||
|
|
||||||
_reset_seconds_left: float = 0.0
|
_reset_seconds_left: float = 0.0
|
||||||
"""Seconds left before resetting start/stop time counters."""
|
"""Seconds left before resetting start/stop time counters."""
|
||||||
|
|
||||||
_vad: webrtcvad.Vad = None
|
|
||||||
_leftover_chunk_buffer: AudioBuffer = field(init=False)
|
|
||||||
_bytes_per_chunk: int = field(init=False)
|
|
||||||
_seconds_per_chunk: float = field(init=False)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Initialize VAD."""
|
"""Reset after initialization."""
|
||||||
self._vad = webrtcvad.Vad(self.vad_mode)
|
|
||||||
self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
|
|
||||||
self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
|
|
||||||
self._leftover_chunk_buffer = AudioBuffer(
|
|
||||||
self.vad_samples_per_chunk * _SAMPLE_WIDTH
|
|
||||||
)
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset all counters and state."""
|
"""Reset all counters and state."""
|
||||||
self._leftover_chunk_buffer.clear()
|
|
||||||
self._silence_seconds_left = self.silence_seconds
|
self._silence_seconds_left = self.silence_seconds
|
||||||
self._reset_seconds_left = self.reset_seconds
|
self._reset_seconds_left = self.reset_seconds
|
||||||
|
|
||||||
def process(self, samples: bytes) -> bool:
|
def process(self, chunk_seconds: float, is_speech: bool | None) -> bool:
|
||||||
"""Process 16-bit 16Khz mono audio samples.
|
"""Process samples using external VAD.
|
||||||
|
|
||||||
Returns False when timeout is reached.
|
Returns False when timeout is reached.
|
||||||
"""
|
"""
|
||||||
for chunk in chunk_samples(
|
if is_speech:
|
||||||
samples, self._bytes_per_chunk, self._leftover_chunk_buffer
|
|
||||||
):
|
|
||||||
if not self._process_chunk(chunk):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _process_chunk(self, chunk: bytes) -> bool:
|
|
||||||
"""Process a single chunk of 16-bit 16Khz mono audio.
|
|
||||||
|
|
||||||
Returns False when timeout is reached.
|
|
||||||
"""
|
|
||||||
if self._vad.is_speech(chunk, _SAMPLE_RATE):
|
|
||||||
# Speech
|
# Speech
|
||||||
self._reset_seconds_left -= self._seconds_per_chunk
|
self._reset_seconds_left -= chunk_seconds
|
||||||
if self._reset_seconds_left <= 0:
|
if self._reset_seconds_left <= 0:
|
||||||
# Reset timeout
|
# Reset timeout
|
||||||
self._silence_seconds_left = self.silence_seconds
|
self._silence_seconds_left = self.silence_seconds
|
||||||
else:
|
else:
|
||||||
# Silence
|
# Silence
|
||||||
self._silence_seconds_left -= self._seconds_per_chunk
|
self._silence_seconds_left -= chunk_seconds
|
||||||
if self._silence_seconds_left <= 0:
|
if self._silence_seconds_left <= 0:
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
|
self.reset()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Slowly build reset counter back up
|
# Slowly build reset counter back up
|
||||||
self._reset_seconds_left = min(
|
self._reset_seconds_left = min(
|
||||||
self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk
|
self.reset_seconds, self._reset_seconds_left + chunk_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -18,6 +18,7 @@ from homeassistant.util import language as language_util
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
|
AudioSettings,
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
|
@ -71,6 +72,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
vol.Optional("audio_seconds_to_buffer"): vol.Any(
|
vol.Optional("audio_seconds_to_buffer"): vol.Any(
|
||||||
float, int
|
float, int
|
||||||
),
|
),
|
||||||
|
# Audio enhancement
|
||||||
|
vol.Optional("noise_suppression_level"): int,
|
||||||
|
vol.Optional("auto_gain_dbfs"): int,
|
||||||
|
vol.Optional("volume_multiplier"): float,
|
||||||
|
# Advanced use cases/testing
|
||||||
|
vol.Optional("no_vad"): bool,
|
||||||
|
vol.Optional("no_chunking"): bool,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
@ -115,6 +123,7 @@ async def websocket_run(
|
||||||
handler_id: int | None = None
|
handler_id: int | None = None
|
||||||
unregister_handler: Callable[[], None] | None = None
|
unregister_handler: Callable[[], None] | None = None
|
||||||
wake_word_settings: WakeWordSettings | None = None
|
wake_word_settings: WakeWordSettings | None = None
|
||||||
|
audio_settings: AudioSettings | None = None
|
||||||
|
|
||||||
# Arguments to PipelineInput
|
# Arguments to PipelineInput
|
||||||
input_args: dict[str, Any] = {
|
input_args: dict[str, Any] = {
|
||||||
|
@ -124,13 +133,14 @@ async def websocket_run(
|
||||||
|
|
||||||
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
|
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
|
||||||
# Audio pipeline that will receive audio as binary websocket messages
|
# Audio pipeline that will receive audio as binary websocket messages
|
||||||
|
msg_input = msg["input"]
|
||||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||||
incoming_sample_rate = msg["input"]["sample_rate"]
|
incoming_sample_rate = msg_input["sample_rate"]
|
||||||
|
|
||||||
if start_stage == PipelineStage.WAKE_WORD:
|
if start_stage == PipelineStage.WAKE_WORD:
|
||||||
wake_word_settings = WakeWordSettings(
|
wake_word_settings = WakeWordSettings(
|
||||||
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
|
||||||
audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
|
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||||
|
@ -166,6 +176,15 @@ async def websocket_run(
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
)
|
)
|
||||||
input_args["stt_stream"] = stt_stream()
|
input_args["stt_stream"] = stt_stream()
|
||||||
|
|
||||||
|
# Audio settings
|
||||||
|
audio_settings = AudioSettings(
|
||||||
|
noise_suppression_level=msg_input.get("noise_suppression_level", 0),
|
||||||
|
auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
|
||||||
|
volume_multiplier=msg_input.get("volume_multiplier", 1.0),
|
||||||
|
is_vad_enabled=not msg_input.get("no_vad", False),
|
||||||
|
is_chunking_enabled=not msg_input.get("no_chunking", False),
|
||||||
|
)
|
||||||
elif start_stage == PipelineStage.INTENT:
|
elif start_stage == PipelineStage.INTENT:
|
||||||
# Input to conversation agent
|
# Input to conversation agent
|
||||||
input_args["intent_input"] = msg["input"]["text"]
|
input_args["intent_input"] = msg["input"]["text"]
|
||||||
|
@ -185,6 +204,7 @@ async def websocket_run(
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
},
|
},
|
||||||
wake_word_settings=wake_word_settings,
|
wake_word_settings=wake_word_settings,
|
||||||
|
audio_settings=audio_settings or AudioSettings(),
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline_input = PipelineInput(**input_args)
|
pipeline_input = PipelineInput(**input_args)
|
||||||
|
|
|
@ -29,8 +29,11 @@ from homeassistant.components.assist_pipeline import (
|
||||||
select as pipeline_select,
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_pipeline.vad import (
|
from homeassistant.components.assist_pipeline.vad import (
|
||||||
|
AudioBuffer,
|
||||||
VadSensitivity,
|
VadSensitivity,
|
||||||
|
VoiceActivityDetector,
|
||||||
VoiceCommandSegmenter,
|
VoiceCommandSegmenter,
|
||||||
|
WebRtcVad,
|
||||||
)
|
)
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
@ -225,11 +228,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
try:
|
try:
|
||||||
# Wait for speech before starting pipeline
|
# Wait for speech before starting pipeline
|
||||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||||
|
vad = WebRtcVad()
|
||||||
chunk_buffer: deque[bytes] = deque(
|
chunk_buffer: deque[bytes] = deque(
|
||||||
maxlen=self.buffered_chunks_before_speech,
|
maxlen=self.buffered_chunks_before_speech,
|
||||||
)
|
)
|
||||||
speech_detected = await self._wait_for_speech(
|
speech_detected = await self._wait_for_speech(
|
||||||
segmenter,
|
segmenter,
|
||||||
|
vad,
|
||||||
chunk_buffer,
|
chunk_buffer,
|
||||||
)
|
)
|
||||||
if not speech_detected:
|
if not speech_detected:
|
||||||
|
@ -243,6 +248,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
try:
|
try:
|
||||||
async for chunk in self._segment_audio(
|
async for chunk in self._segment_audio(
|
||||||
segmenter,
|
segmenter,
|
||||||
|
vad,
|
||||||
chunk_buffer,
|
chunk_buffer,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
@ -306,6 +312,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
async def _wait_for_speech(
|
async def _wait_for_speech(
|
||||||
self,
|
self,
|
||||||
segmenter: VoiceCommandSegmenter,
|
segmenter: VoiceCommandSegmenter,
|
||||||
|
vad: VoiceActivityDetector,
|
||||||
chunk_buffer: MutableSequence[bytes],
|
chunk_buffer: MutableSequence[bytes],
|
||||||
):
|
):
|
||||||
"""Buffer audio chunks until speech is detected.
|
"""Buffer audio chunks until speech is detected.
|
||||||
|
@ -317,12 +324,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
async with asyncio.timeout(self.audio_timeout):
|
||||||
chunk = await self._audio_queue.get()
|
chunk = await self._audio_queue.get()
|
||||||
|
|
||||||
|
assert vad.samples_per_chunk is not None
|
||||||
|
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
|
||||||
|
|
||||||
while chunk:
|
while chunk:
|
||||||
chunk_buffer.append(chunk)
|
chunk_buffer.append(chunk)
|
||||||
|
|
||||||
segmenter.process(chunk)
|
segmenter.process_with_vad(chunk, vad, vad_buffer)
|
||||||
if segmenter.in_command:
|
if segmenter.in_command:
|
||||||
# Buffer until command starts
|
# Buffer until command starts
|
||||||
|
if len(vad_buffer) > 0:
|
||||||
|
chunk_buffer.append(vad_buffer.bytes())
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
async with asyncio.timeout(self.audio_timeout):
|
||||||
|
@ -333,6 +346,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
async def _segment_audio(
|
async def _segment_audio(
|
||||||
self,
|
self,
|
||||||
segmenter: VoiceCommandSegmenter,
|
segmenter: VoiceCommandSegmenter,
|
||||||
|
vad: VoiceActivityDetector,
|
||||||
chunk_buffer: Sequence[bytes],
|
chunk_buffer: Sequence[bytes],
|
||||||
) -> AsyncIterable[bytes]:
|
) -> AsyncIterable[bytes]:
|
||||||
"""Yield audio chunks until voice command has finished."""
|
"""Yield audio chunks until voice command has finished."""
|
||||||
|
@ -345,8 +359,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
async with asyncio.timeout(self.audio_timeout):
|
||||||
chunk = await self._audio_queue.get()
|
chunk = await self._audio_queue.get()
|
||||||
|
|
||||||
|
assert vad.samples_per_chunk is not None
|
||||||
|
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
|
||||||
|
|
||||||
while chunk:
|
while chunk:
|
||||||
if not segmenter.process(chunk):
|
if not segmenter.process_with_vad(chunk, vad, vad_buffer):
|
||||||
# Voice command is finished
|
# Voice command is finished
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ typing-extensions>=4.8.0,<5.0
|
||||||
ulid-transform==0.8.1
|
ulid-transform==0.8.1
|
||||||
voluptuous-serialize==2.6.0
|
voluptuous-serialize==2.6.0
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
webrtcvad==2.0.10
|
webrtc-noise-gain==1.1.0
|
||||||
yarl==1.9.2
|
yarl==1.9.2
|
||||||
zeroconf==0.114.0
|
zeroconf==0.114.0
|
||||||
|
|
||||||
|
|
|
@ -2691,7 +2691,7 @@ waterfurnace==1.1.0
|
||||||
webexteamssdk==1.1.1
|
webexteamssdk==1.1.1
|
||||||
|
|
||||||
# homeassistant.components.assist_pipeline
|
# homeassistant.components.assist_pipeline
|
||||||
webrtcvad==2.0.10
|
webrtc-noise-gain==1.1.0
|
||||||
|
|
||||||
# homeassistant.components.whirlpool
|
# homeassistant.components.whirlpool
|
||||||
whirlpool-sixth-sense==0.18.4
|
whirlpool-sixth-sense==0.18.4
|
||||||
|
|
|
@ -1994,7 +1994,7 @@ wallbox==0.4.12
|
||||||
watchdog==2.3.1
|
watchdog==2.3.1
|
||||||
|
|
||||||
# homeassistant.components.assist_pipeline
|
# homeassistant.components.assist_pipeline
|
||||||
webrtcvad==2.0.10
|
webrtc-noise-gain==1.1.0
|
||||||
|
|
||||||
# homeassistant.components.whirlpool
|
# homeassistant.components.whirlpool
|
||||||
whirlpool-sixth-sense==0.18.4
|
whirlpool-sixth-sense==0.18.4
|
||||||
|
|
|
@ -311,18 +311,6 @@
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'timestamp': 0,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'timestamp': 1500,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'stt_output': dict({
|
'stt_output': dict({
|
||||||
|
|
|
@ -173,6 +173,87 @@
|
||||||
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
|
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 30,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.2
|
||||||
|
dict({
|
||||||
|
'stt_output': dict({
|
||||||
|
'text': 'test transcript',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.3
|
||||||
|
dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'homeassistant',
|
||||||
|
'intent_input': 'test transcript',
|
||||||
|
'language': 'en',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.4
|
||||||
|
dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'code': 'no_intent_match',
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'error',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': "Sorry, I couldn't understand that",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.5
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'language': 'en-US',
|
||||||
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
|
'voice': 'james_earl_jones',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.6
|
||||||
|
dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_audio_pipeline_with_enhancements.7
|
||||||
|
None
|
||||||
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word
|
# name: test_audio_pipeline_with_wake_word
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
|
|
|
@ -64,6 +64,9 @@ async def test_pipeline_from_audio_stream_auto(
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
),
|
),
|
||||||
stt_stream=audio_data(),
|
stt_stream=audio_data(),
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
@ -126,6 +129,9 @@ async def test_pipeline_from_audio_stream_legacy(
|
||||||
),
|
),
|
||||||
stt_stream=audio_data(),
|
stt_stream=audio_data(),
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
@ -188,6 +194,9 @@ async def test_pipeline_from_audio_stream_entity(
|
||||||
),
|
),
|
||||||
stt_stream=audio_data(),
|
stt_stream=audio_data(),
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
@ -251,6 +260,9 @@ async def test_pipeline_from_audio_stream_no_stt(
|
||||||
),
|
),
|
||||||
stt_stream=audio_data(),
|
stt_stream=audio_data(),
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not events
|
assert not events
|
||||||
|
@ -312,26 +324,26 @@ async def test_pipeline_from_audio_stream_wake_word(
|
||||||
# [0, 2, ...]
|
# [0, 2, ...]
|
||||||
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
|
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
|
||||||
|
|
||||||
|
bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield wake_chunk_1 # 1 second
|
# 1 second in 10 ms chunks
|
||||||
yield wake_chunk_2 # 1 second
|
i = 0
|
||||||
|
while i < len(wake_chunk_1):
|
||||||
|
yield wake_chunk_1[i : i + bytes_per_chunk]
|
||||||
|
i += bytes_per_chunk
|
||||||
|
|
||||||
|
# 1 second in 30 ms chunks
|
||||||
|
i = 0
|
||||||
|
while i < len(wake_chunk_2):
|
||||||
|
yield wake_chunk_2[i : i + bytes_per_chunk]
|
||||||
|
i += bytes_per_chunk
|
||||||
|
|
||||||
yield b"wake word!"
|
yield b"wake word!"
|
||||||
yield b"part1"
|
yield b"part1"
|
||||||
yield b"part2"
|
yield b"part2"
|
||||||
yield b"end"
|
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
def continue_stt(self, chunk):
|
|
||||||
# Ensure stt_vad_start event is triggered
|
|
||||||
self.in_command = True
|
|
||||||
|
|
||||||
# Stop on fake end chunk to trigger stt_vad_end
|
|
||||||
return chunk != b"end"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
|
|
||||||
continue_stt,
|
|
||||||
):
|
|
||||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
|
@ -349,6 +361,9 @@ async def test_pipeline_from_audio_stream_wake_word(
|
||||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||||
audio_seconds_to_buffer=1.5
|
audio_seconds_to_buffer=1.5
|
||||||
),
|
),
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
@ -357,12 +372,14 @@ async def test_pipeline_from_audio_stream_wake_word(
|
||||||
# 2. queued audio (from mock wake word entity)
|
# 2. queued audio (from mock wake word entity)
|
||||||
# 3. part1
|
# 3. part1
|
||||||
# 4. part2
|
# 4. part2
|
||||||
assert len(mock_stt_provider.received) == 4
|
assert len(mock_stt_provider.received) > 3
|
||||||
|
|
||||||
first_chunk = mock_stt_provider.received[0]
|
first_chunk = bytes(
|
||||||
|
[c_byte for c in mock_stt_provider.received[:-3] for c_byte in c]
|
||||||
|
)
|
||||||
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
|
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
|
||||||
|
|
||||||
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]
|
assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_save_audio(
|
async def test_pipeline_save_audio(
|
||||||
|
@ -410,6 +427,9 @@ async def test_pipeline_save_audio(
|
||||||
pipeline_id=pipeline.id,
|
pipeline_id=pipeline.id,
|
||||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||||
end_stage=assist_pipeline.PipelineStage.STT,
|
end_stage=assist_pipeline.PipelineStage.STT,
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(
|
||||||
|
is_vad_enabled=False, is_chunking_enabled=False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline_dirs = list(temp_dir.iterdir())
|
pipeline_dirs = list(temp_dir.iterdir())
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
"""Tests for webrtcvad voice command segmenter."""
|
"""Tests for voice command segmenter."""
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.vad import (
|
from homeassistant.components.assist_pipeline.vad import (
|
||||||
AudioBuffer,
|
AudioBuffer,
|
||||||
|
VoiceActivityDetector,
|
||||||
VoiceCommandSegmenter,
|
VoiceCommandSegmenter,
|
||||||
chunk_samples,
|
chunk_samples,
|
||||||
)
|
)
|
||||||
|
|
||||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
_ONE_SECOND = 1.0
|
||||||
|
|
||||||
|
|
||||||
def test_silence() -> None:
|
def test_silence() -> None:
|
||||||
|
@ -16,66 +17,63 @@ def test_silence() -> None:
|
||||||
segmenter = VoiceCommandSegmenter()
|
segmenter = VoiceCommandSegmenter()
|
||||||
|
|
||||||
# True return value indicates voice command has not finished
|
# True return value indicates voice command has not finished
|
||||||
assert segmenter.process(bytes(_ONE_SECOND * 3))
|
assert segmenter.process(_ONE_SECOND * 3, False)
|
||||||
|
|
||||||
|
|
||||||
def test_speech() -> None:
|
def test_speech() -> None:
|
||||||
"""Test that silence + speech + silence triggers a voice command."""
|
"""Test that silence + speech + silence triggers a voice command."""
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
def is_speech(chunk):
|
||||||
"""Anything non-zero is speech."""
|
"""Anything non-zero is speech."""
|
||||||
return sum(chunk) > 0
|
return sum(chunk) > 0
|
||||||
|
|
||||||
with patch(
|
|
||||||
"webrtcvad.Vad.is_speech",
|
|
||||||
new=is_speech,
|
|
||||||
):
|
|
||||||
segmenter = VoiceCommandSegmenter()
|
segmenter = VoiceCommandSegmenter()
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
assert segmenter.process(bytes(_ONE_SECOND))
|
assert segmenter.process(_ONE_SECOND, False)
|
||||||
|
|
||||||
# "speech"
|
# "speech"
|
||||||
assert segmenter.process(bytes([255] * _ONE_SECOND))
|
assert segmenter.process(_ONE_SECOND, True)
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
# False return value indicates voice command is finished
|
# False return value indicates voice command is finished
|
||||||
assert not segmenter.process(bytes(_ONE_SECOND))
|
assert not segmenter.process(_ONE_SECOND, False)
|
||||||
|
|
||||||
|
|
||||||
def test_audio_buffer() -> None:
|
def test_audio_buffer() -> None:
|
||||||
"""Test audio buffer wrapping."""
|
"""Test audio buffer wrapping."""
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
class DisabledVad(VoiceActivityDetector):
|
||||||
"""Disable VAD."""
|
def is_speech(self, chunk):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with patch(
|
@property
|
||||||
"webrtcvad.Vad.is_speech",
|
def samples_per_chunk(self):
|
||||||
new=is_speech,
|
return 160 # 10 ms
|
||||||
):
|
|
||||||
segmenter = VoiceCommandSegmenter()
|
|
||||||
bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
|
|
||||||
|
|
||||||
with patch.object(
|
vad = DisabledVad()
|
||||||
segmenter, "_process_chunk", return_value=True
|
bytes_per_chunk = vad.samples_per_chunk * 2
|
||||||
) as mock_process:
|
vad_buffer = AudioBuffer(bytes_per_chunk)
|
||||||
|
segmenter = VoiceCommandSegmenter()
|
||||||
|
|
||||||
|
with patch.object(vad, "is_speech", return_value=False) as mock_process:
|
||||||
# Partially fill audio buffer
|
# Partially fill audio buffer
|
||||||
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
|
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
|
||||||
segmenter.process(half_chunk)
|
segmenter.process_with_vad(half_chunk, vad, vad_buffer)
|
||||||
|
|
||||||
assert not mock_process.called
|
assert not mock_process.called
|
||||||
assert segmenter.audio_buffer == half_chunk
|
assert vad_buffer is not None
|
||||||
|
assert vad_buffer.bytes() == half_chunk
|
||||||
|
|
||||||
# Fill and wrap with 1/4 chunk left over
|
# Fill and wrap with 1/4 chunk left over
|
||||||
three_quarters_chunk = bytes(
|
three_quarters_chunk = bytes(
|
||||||
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
|
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
|
||||||
)
|
)
|
||||||
segmenter.process(three_quarters_chunk)
|
segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
|
||||||
|
|
||||||
assert mock_process.call_count == 1
|
assert mock_process.call_count == 1
|
||||||
assert (
|
assert (
|
||||||
segmenter.audio_buffer
|
vad_buffer.bytes()
|
||||||
== three_quarters_chunk[
|
== three_quarters_chunk[
|
||||||
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
|
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
|
||||||
]
|
]
|
||||||
|
@ -87,14 +85,15 @@ def test_audio_buffer() -> None:
|
||||||
|
|
||||||
# Run 2 chunks through
|
# Run 2 chunks through
|
||||||
segmenter.reset()
|
segmenter.reset()
|
||||||
assert len(segmenter.audio_buffer) == 0
|
vad_buffer.clear()
|
||||||
|
assert len(vad_buffer) == 0
|
||||||
|
|
||||||
mock_process.reset_mock()
|
mock_process.reset_mock()
|
||||||
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
|
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
|
||||||
segmenter.process(two_chunks)
|
segmenter.process_with_vad(two_chunks, vad, vad_buffer)
|
||||||
|
|
||||||
assert mock_process.call_count == 2
|
assert mock_process.call_count == 2
|
||||||
assert len(segmenter.audio_buffer) == 0
|
assert len(vad_buffer) == 0
|
||||||
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
|
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
|
||||||
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
|
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
|
||||||
|
|
||||||
|
@ -125,3 +124,43 @@ def test_chunk_samples_leftover() -> None:
|
||||||
|
|
||||||
assert len(chunks) == 1
|
assert len(chunks) == 1
|
||||||
assert leftover_chunk_buffer.bytes() == bytes([5, 6])
|
assert leftover_chunk_buffer.bytes() == bytes([5, 6])
|
||||||
|
|
||||||
|
|
||||||
|
def test_vad_no_chunking() -> None:
|
||||||
|
"""Test VAD that doesn't require chunking."""
|
||||||
|
|
||||||
|
class VadNoChunk(VoiceActivityDetector):
|
||||||
|
def is_speech(self, chunk: bytes) -> bool:
|
||||||
|
return sum(chunk) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def samples_per_chunk(self) -> int | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
vad = VadNoChunk()
|
||||||
|
segmenter = VoiceCommandSegmenter(
|
||||||
|
speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
|
||||||
|
)
|
||||||
|
silence = bytes([0] * 16000)
|
||||||
|
speech = bytes([255] * (16000 // 2))
|
||||||
|
|
||||||
|
# Test with differently-sized chunks
|
||||||
|
assert vad.is_speech(speech)
|
||||||
|
assert not vad.is_speech(silence)
|
||||||
|
|
||||||
|
# Simulate voice command
|
||||||
|
assert segmenter.process_with_vad(silence, vad, None)
|
||||||
|
# begin
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
# reset with silence
|
||||||
|
assert segmenter.process_with_vad(silence, vad, None)
|
||||||
|
# resume
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
assert segmenter.process_with_vad(speech, vad, None)
|
||||||
|
# end
|
||||||
|
assert segmenter.process_with_vad(silence, vad, None)
|
||||||
|
assert not segmenter.process_with_vad(silence, vad, None)
|
||||||
|
|
|
@ -107,6 +107,7 @@ async def test_audio_pipeline(
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
msg["event"]["data"]["pipeline"] = ANY
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# stt
|
# stt
|
||||||
|
@ -116,7 +117,7 @@ async def test_audio_pipeline(
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(bytes([1]))
|
await client.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-end"
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
@ -240,6 +241,8 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": 16000,
|
||||||
"timeout": 0,
|
"timeout": 0,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -253,6 +256,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
msg["event"]["data"]["pipeline"] = ANY
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# wake_word
|
# wake_word
|
||||||
|
@ -276,7 +280,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(bytes([1]))
|
await client.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-end"
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
@ -731,6 +735,7 @@ async def test_stt_stream_failed(
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
msg["event"]["data"]["pipeline"] = ANY
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# stt
|
# stt
|
||||||
|
@ -740,7 +745,7 @@ async def test_stt_stream_failed(
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(b"1")
|
await client.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
# stt error
|
# stt error
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
@ -1489,6 +1494,7 @@ async def test_audio_pipeline_debug(
|
||||||
assert msg["event"]["type"] == "run-start"
|
assert msg["event"]["type"] == "run-start"
|
||||||
msg["event"]["data"]["pipeline"] = ANY
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# stt
|
# stt
|
||||||
|
@ -1498,7 +1504,7 @@ async def test_audio_pipeline_debug(
|
||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(bytes([1]))
|
await client.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-end"
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
@ -1699,3 +1705,103 @@ async def test_list_pipeline_languages_with_aliases(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {"languages": ["he", "nb"]}
|
assert msg["result"] == {"languages": ["he", "nb"]}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_audio_pipeline_with_enhancements(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with audio input/output."""
|
||||||
|
events = []
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
# Enhancements
|
||||||
|
"noise_suppression_level": 2,
|
||||||
|
"auto_gain_dbfs": 15,
|
||||||
|
"volume_multiplier": 2.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# One second of silence.
|
||||||
|
# This will pass through the audio enhancement pipeline, but we don't test
|
||||||
|
# the actual output.
|
||||||
|
await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2))
|
||||||
|
|
||||||
|
# End of audio stream (handler id + empty payload)
|
||||||
|
await client.send_bytes(bytes([handler_id]))
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# intent
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# text-to-speech
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||||
|
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline_debug/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"pipeline_run_id": pipeline_run_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"events": events}
|
||||||
|
|
|
@ -21,7 +21,7 @@ async def test_pipeline(
|
||||||
"""Test that pipeline function is called from RTP protocol."""
|
"""Test that pipeline function is called from RTP protocol."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
def is_speech(self, chunk):
|
||||||
"""Anything non-zero is speech."""
|
"""Anything non-zero is speech."""
|
||||||
return sum(chunk) > 0
|
return sum(chunk) > 0
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ async def test_pipeline(
|
||||||
return ("mp3", b"")
|
return ("mp3", b"")
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"webrtcvad.Vad.is_speech",
|
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
|
||||||
new=is_speech,
|
new=is_speech,
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||||
|
@ -210,7 +210,7 @@ async def test_tts_timeout(
|
||||||
"""Test that TTS will time out based on its length."""
|
"""Test that TTS will time out based on its length."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
def is_speech(self, chunk):
|
||||||
"""Anything non-zero is speech."""
|
"""Anything non-zero is speech."""
|
||||||
return sum(chunk) > 0
|
return sum(chunk) > 0
|
||||||
|
|
||||||
|
@ -269,7 +269,7 @@ async def test_tts_timeout(
|
||||||
return ("raw", bytes(0))
|
return ("raw", bytes(0))
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"webrtcvad.Vad.is_speech",
|
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
|
||||||
new=is_speech,
|
new=is_speech,
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||||
|
|
Loading…
Add table
Reference in a new issue