Use webrtcvad to detect silence in pipelines (#90610)
* Add webrtcvad requirement * Use webrtcvad for voice command segmenting * Add vad test
This commit is contained in:
parent
44b35fea47
commit
90d81e9844
7 changed files with 180 additions and 37 deletions
|
@ -5,5 +5,6 @@
|
|||
"dependencies": ["conversation", "stt", "tts"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal"
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["webrtcvad==2.0.10"]
|
||||
}
|
||||
|
|
128
homeassistant/components/voice_assistant/vad.py
Normal file
128
homeassistant/components/voice_assistant/vad.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
"""Voice activity detection."""
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import webrtcvad
|
||||
|
||||
_SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
@dataclass
class VoiceCommandSegmenter:
    """Segments an audio stream into voice commands using webrtcvad.

    Feed 16-bit 16Khz mono PCM into process(); it returns True while the
    voice command is still in progress and False once the command has
    finished or the timeout has elapsed.
    """

    vad_mode: int = 3
    # Aggressiveness in filtering out non-speech. 3 is the most aggressive.

    vad_frames: int = 480
    # Samples per VAD chunk. webrtcvad only accepts 10, 20, or 30 ms frames,
    # i.e. 160, 320, or 480 samples at 16Khz (480 = 30 ms).

    speech_seconds: float = 0.3
    # Seconds of speech before voice command has started.

    silence_seconds: float = 0.5
    # Seconds of silence after voice command has ended.

    timeout_seconds: float = 15.0
    # Maximum number of seconds before process() gives up and returns False.

    reset_seconds: float = 1.0
    # Seconds before resetting the start/stop time counters.

    _in_command: bool = False
    # True if inside voice command.

    _speech_seconds_left: float = 0.0
    # Seconds left before considering voice command as started.

    _silence_seconds_left: float = 0.0
    # Seconds left before considering voice command as stopped.

    _timeout_seconds_left: float = 0.0
    # Seconds left before considering voice command timed out.

    _reset_seconds_left: float = 0.0
    # Seconds left before resetting start/stop time counters.

    _vad: webrtcvad.Vad = None  # created in __post_init__
    _audio_buffer: bytes = field(default_factory=bytes)
    _bytes_per_chunk: int = 480 * 2  # recomputed in __post_init__ (16-bit samples)
    _seconds_per_chunk: float = 0.03  # recomputed in __post_init__ (30 ms)

    def __post_init__(self) -> None:
        """Initialize VAD and derive chunk sizes from vad_frames."""
        self._vad = webrtcvad.Vad(self.vad_mode)
        self._bytes_per_chunk = self.vad_frames * 2  # 16-bit = 2 bytes/sample
        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
        self.reset()

    def reset(self) -> None:
        """Reset all counters and state."""
        self._audio_buffer = b""
        self._speech_seconds_left = self.speech_seconds
        self._silence_seconds_left = self.silence_seconds
        self._timeout_seconds_left = self.timeout_seconds
        self._reset_seconds_left = self.reset_seconds
        self._in_command = False

    def process(self, samples: bytes) -> bool:
        """Process 16-bit 16Khz mono audio samples.

        Returns False when the voice command is done (or has timed out);
        True while more audio is expected. Bytes that do not fill a whole
        VAD chunk are buffered until the next call.
        """
        self._audio_buffer += samples

        # webrtcvad only accepts 10, 20, or 30 ms frames; process whole
        # chunks and keep the remainder buffered.
        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
        for chunk_idx in range(num_chunks):
            chunk_offset = chunk_idx * self._bytes_per_chunk
            chunk = self._audio_buffer[
                chunk_offset : chunk_offset + self._bytes_per_chunk
            ]
            if not self._process_chunk(chunk):
                self.reset()
                return False

        if num_chunks > 0:
            # Remove processed chunks from buffer
            self._audio_buffer = self._audio_buffer[
                num_chunks * self._bytes_per_chunk :
            ]

        return True

    def _process_chunk(self, chunk: bytes) -> bool:
        """Process a single chunk of 16-bit 16Khz mono audio.

        Returns False when the voice command is done or has timed out.
        """
        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)

        self._timeout_seconds_left -= self._seconds_per_chunk
        if self._timeout_seconds_left <= 0:
            # Give up waiting for a complete command.
            return False

        if not self._in_command:
            if is_speech:
                self._reset_seconds_left = self.reset_seconds
                self._speech_seconds_left -= self._seconds_per_chunk
                if self._speech_seconds_left <= 0:
                    # Enough continuous speech: inside voice command now.
                    self._in_command = True
            else:
                # Reset the speech counter after enough silence.
                self._reset_seconds_left -= self._seconds_per_chunk
                if self._reset_seconds_left <= 0:
                    self._speech_seconds_left = self.speech_seconds
        else:
            if not is_speech:
                self._reset_seconds_left = self.reset_seconds
                self._silence_seconds_left -= self._seconds_per_chunk
                if self._silence_seconds_left <= 0:
                    # Enough trailing silence: the command is finished.
                    return False
            else:
                # Reset the silence counter after enough speech.
                self._reset_seconds_left -= self._seconds_per_chunk
                if self._reset_seconds_left <= 0:
                    self._silence_seconds_left = self.silence_seconds

        return True
|
|
@ -20,15 +20,12 @@ from .pipeline import (
|
|||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
)
|
||||
from .vad import VoiceCommandSegmenter
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VAD_ENERGY_THRESHOLD = 1000
|
||||
_VAD_SPEECH_FRAMES = 25
|
||||
_VAD_SILENCE_FRAMES = 25
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
|
@ -36,17 +33,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
websocket_api.async_register_command(hass, websocket_run)
|
||||
|
||||
|
||||
def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float:
|
||||
"""Compute RMS of debiased audio."""
|
||||
energy = -audioop.rms(audio_data, width)
|
||||
energy_bytes = bytes([energy & 0xFF, (energy >> 8) & 0xFF])
|
||||
debiased_energy = audioop.rms(
|
||||
audioop.add(audio_data, energy_bytes * (len(audio_data) // width), width), width
|
||||
)
|
||||
|
||||
return debiased_energy
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "voice_assistant/run",
|
||||
|
@ -105,30 +91,14 @@ async def websocket_run(
|
|||
|
||||
async def stt_stream():
|
||||
state = None
|
||||
speech_count = 0
|
||||
in_voice_command = False
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
|
||||
# Yield until we receive an empty chunk
|
||||
while chunk := await audio_queue.get():
|
||||
chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
|
||||
is_speech = _get_debiased_energy(chunk) > _VAD_ENERGY_THRESHOLD
|
||||
|
||||
if in_voice_command:
|
||||
if is_speech:
|
||||
speech_count += 1
|
||||
else:
|
||||
speech_count -= 1
|
||||
|
||||
if speech_count <= -_VAD_SILENCE_FRAMES:
|
||||
_LOGGER.info("Voice command stopped")
|
||||
break
|
||||
else:
|
||||
if is_speech:
|
||||
speech_count += 1
|
||||
|
||||
if speech_count >= _VAD_SPEECH_FRAMES:
|
||||
in_voice_command = True
|
||||
_LOGGER.info("Voice command started")
|
||||
if not segmenter.process(chunk):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
|
|
|
@ -2619,6 +2619,9 @@ waterfurnace==1.1.0
|
|||
# homeassistant.components.cisco_webex_teams
|
||||
webexteamssdk==1.1.1
|
||||
|
||||
# homeassistant.components.voice_assistant
|
||||
webrtcvad==2.0.10
|
||||
|
||||
# homeassistant.components.whirlpool
|
||||
whirlpool-sixth-sense==0.18.2
|
||||
|
||||
|
|
|
@ -1877,6 +1877,9 @@ wallbox==0.4.12
|
|||
# homeassistant.components.folder_watcher
|
||||
watchdog==2.3.1
|
||||
|
||||
# homeassistant.components.voice_assistant
|
||||
webrtcvad==2.0.10
|
||||
|
||||
# homeassistant.components.whirlpool
|
||||
whirlpool-sixth-sense==0.18.2
|
||||
|
||||
|
|
38
tests/components/voice_assistant/test_vad.py
Normal file
38
tests/components/voice_assistant/test_vad.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
"""Tests for webrtcvad voice command segmenter."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.voice_assistant.vad import VoiceCommandSegmenter
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
||||
|
||||
def test_silence() -> None:
    """Ensure three seconds of pure silence never completes a voice command."""
    three_seconds_of_silence = bytes(_ONE_SECOND * 3)

    # process() keeps returning True while no voice command has finished.
    assert VoiceCommandSegmenter().process(three_seconds_of_silence)
|
||||
|
||||
|
||||
def test_speech() -> None:
    """Test that silence + speech + silence triggers a voice command."""

    def is_speech(self, chunk, sample_rate):
        """Anything non-zero is speech."""
        return sum(chunk) > 0

    with patch(
        "webrtcvad.Vad.is_speech",
        new=is_speech,
    ):
        segmenter = VoiceCommandSegmenter()

        leading_silence = bytes(_ONE_SECOND)
        loud_audio = bytes([255] * _ONE_SECOND)
        trailing_silence = bytes(_ONE_SECOND)

        # Still listening through the leading silence and the "speech".
        assert segmenter.process(leading_silence)
        assert segmenter.process(loud_audio)

        # The trailing silence ends the command: process() returns False.
        assert not segmenter.process(trailing_silence)
|
|
@ -75,7 +75,7 @@ class MockSTT:
|
|||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> tts.Provider:
|
||||
) -> stt.Provider:
|
||||
"""Set up a mock speech component."""
|
||||
return MockSttProvider(hass, _TRANSCRIPT)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue