Use webrtcvad to detect silence in pipelines (#90610)
* Add webrtcvad requirement * Use webrtcvad for voice command segmenting * Add vad test
This commit is contained in:
parent
44b35fea47
commit
90d81e9844
7 changed files with 180 additions and 37 deletions
|
@ -5,5 +5,6 @@
|
||||||
"dependencies": ["conversation", "stt", "tts"],
|
"dependencies": ["conversation", "stt", "tts"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"quality_scale": "internal"
|
"quality_scale": "internal",
|
||||||
|
"requirements": ["webrtcvad==2.0.10"]
|
||||||
}
|
}
|
||||||
|
|
128
homeassistant/components/voice_assistant/vad.py
Normal file
128
homeassistant/components/voice_assistant/vad.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
"""Voice activity detection."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import webrtcvad
|
||||||
|
|
||||||
|
_SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VoiceCommandSegmenter:
    """Segments an audio stream into voice commands using webrtcvad.

    Feed 16-bit 16 kHz mono PCM to process(); it returns False once a
    complete voice command has been heard (speech followed by enough
    silence) or the overall timeout has elapsed, and resets itself for
    the next command.
    """

    vad_mode: int = 3
    # Aggressiveness in filtering out non-speech. 3 is the most aggressive.

    vad_frames: int = 480  # 30 ms
    # Frame size must be 10, 20, or 30 ms at 16 kHz for webrtcvad.

    speech_seconds: float = 0.3
    # Seconds of speech before the voice command has started.

    silence_seconds: float = 0.5
    # Seconds of silence after the voice command has ended.

    timeout_seconds: float = 15.0
    # Maximum number of seconds before process() gives up and returns False.

    reset_seconds: float = 1.0
    # Seconds of contrary audio before the start/stop counters are reset.

    _in_command: bool = False
    # True if currently inside a voice command.

    _speech_seconds_left: float = 0.0
    # Seconds left before considering the voice command as started.

    _silence_seconds_left: float = 0.0
    # Seconds left before considering the voice command as stopped.

    _timeout_seconds_left: float = 0.0
    # Seconds left before considering the voice command timed out.

    _reset_seconds_left: float = 0.0
    # Seconds left before resetting start/stop time counters.

    _vad: webrtcvad.Vad | None = None
    # Created in __post_init__ from vad_mode.

    _audio_buffer: bytes = field(default_factory=bytes)
    # Unprocessed audio carried over between process() calls.

    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
    _seconds_per_chunk: float = 0.03  # 30 ms

    def __post_init__(self) -> None:
        """Initialize VAD and derive chunk sizes from vad_frames."""
        self._vad = webrtcvad.Vad(self.vad_mode)
        self._bytes_per_chunk = self.vad_frames * 2  # 16-bit samples
        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
        self.reset()

    def reset(self) -> None:
        """Reset all counters and state."""
        self._audio_buffer = b""
        self._speech_seconds_left = self.speech_seconds
        self._silence_seconds_left = self.silence_seconds
        self._timeout_seconds_left = self.timeout_seconds
        self._reset_seconds_left = self.reset_seconds
        self._in_command = False

    def process(self, samples: bytes) -> bool:
        """Process 16-bit 16 kHz mono audio samples.

        Returns False when the voice command is done (or timed out);
        the segmenter is reset before returning so it can be reused.
        """
        self._audio_buffer += samples

        # webrtcvad only accepts fixed 10/20/30 ms frames, so process the
        # buffered audio in whole chunks and keep any partial remainder.
        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
        for chunk_idx in range(num_chunks):
            chunk_offset = chunk_idx * self._bytes_per_chunk
            chunk = self._audio_buffer[
                chunk_offset : chunk_offset + self._bytes_per_chunk
            ]
            if not self._process_chunk(chunk):
                self.reset()
                return False

        if num_chunks > 0:
            # Remove processed audio from the buffer
            self._audio_buffer = self._audio_buffer[
                num_chunks * self._bytes_per_chunk :
            ]

        return True

    def _process_chunk(self, chunk: bytes) -> bool:
        """Process a single chunk of 16-bit 16 kHz mono audio.

        Returns False when the voice command is done.
        """
        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)

        # Timeout advances on every chunk, speech or not.
        self._timeout_seconds_left -= self._seconds_per_chunk
        if self._timeout_seconds_left <= 0:
            return False

        if not self._in_command:
            if is_speech:
                self._reset_seconds_left = self.reset_seconds
                self._speech_seconds_left -= self._seconds_per_chunk
                if self._speech_seconds_left <= 0:
                    # Enough speech accumulated: inside voice command now
                    self._in_command = True
            else:
                # Reset the speech counter if enough silence accumulates
                self._reset_seconds_left -= self._seconds_per_chunk
                if self._reset_seconds_left <= 0:
                    self._speech_seconds_left = self.speech_seconds
        else:
            if not is_speech:
                self._reset_seconds_left = self.reset_seconds
                self._silence_seconds_left -= self._seconds_per_chunk
                if self._silence_seconds_left <= 0:
                    # Enough trailing silence: command is finished
                    return False
            else:
                # Reset the silence counter if enough speech accumulates
                self._reset_seconds_left -= self._seconds_per_chunk
                if self._reset_seconds_left <= 0:
                    self._silence_seconds_left = self.silence_seconds

        return True
|
|
@ -20,15 +20,12 @@ from .pipeline import (
|
||||||
PipelineStage,
|
PipelineStage,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
)
|
)
|
||||||
|
from .vad import VoiceCommandSegmenter
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 30
|
DEFAULT_TIMEOUT = 30
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_VAD_ENERGY_THRESHOLD = 1000
|
|
||||||
_VAD_SPEECH_FRAMES = 25
|
|
||||||
_VAD_SILENCE_FRAMES = 25
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
|
@ -36,17 +33,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
websocket_api.async_register_command(hass, websocket_run)
|
websocket_api.async_register_command(hass, websocket_run)
|
||||||
|
|
||||||
|
|
||||||
def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float:
|
|
||||||
"""Compute RMS of debiased audio."""
|
|
||||||
energy = -audioop.rms(audio_data, width)
|
|
||||||
energy_bytes = bytes([energy & 0xFF, (energy >> 8) & 0xFF])
|
|
||||||
debiased_energy = audioop.rms(
|
|
||||||
audioop.add(audio_data, energy_bytes * (len(audio_data) // width), width), width
|
|
||||||
)
|
|
||||||
|
|
||||||
return debiased_energy
|
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "voice_assistant/run",
|
vol.Required("type"): "voice_assistant/run",
|
||||||
|
@ -105,30 +91,14 @@ async def websocket_run(
|
||||||
|
|
||||||
async def stt_stream():
    # Async generator: pulls raw audio off audio_queue (closure variable
    # from the enclosing websocket handler) and yields 16 kHz chunks until
    # the voice command ends.
    state = None  # audioop.ratecv resampler state, carried between chunks
    segmenter = VoiceCommandSegmenter()

    # Yield until we receive an empty chunk
    while chunk := await audio_queue.get():
        # Resample 44.1 kHz 16-bit mono down to the 16 kHz the VAD expects.
        chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
        if not segmenter.process(chunk):
            # Voice command is finished
            break
        yield chunk
|
||||||
|
|
||||||
|
|
|
@ -2619,6 +2619,9 @@ waterfurnace==1.1.0
|
||||||
# homeassistant.components.cisco_webex_teams
|
# homeassistant.components.cisco_webex_teams
|
||||||
webexteamssdk==1.1.1
|
webexteamssdk==1.1.1
|
||||||
|
|
||||||
|
# homeassistant.components.voice_assistant
|
||||||
|
webrtcvad==2.0.10
|
||||||
|
|
||||||
# homeassistant.components.whirlpool
|
# homeassistant.components.whirlpool
|
||||||
whirlpool-sixth-sense==0.18.2
|
whirlpool-sixth-sense==0.18.2
|
||||||
|
|
||||||
|
|
|
@ -1877,6 +1877,9 @@ wallbox==0.4.12
|
||||||
# homeassistant.components.folder_watcher
|
# homeassistant.components.folder_watcher
|
||||||
watchdog==2.3.1
|
watchdog==2.3.1
|
||||||
|
|
||||||
|
# homeassistant.components.voice_assistant
|
||||||
|
webrtcvad==2.0.10
|
||||||
|
|
||||||
# homeassistant.components.whirlpool
|
# homeassistant.components.whirlpool
|
||||||
whirlpool-sixth-sense==0.18.2
|
whirlpool-sixth-sense==0.18.2
|
||||||
|
|
||||||
|
|
38
tests/components/voice_assistant/test_vad.py
Normal file
38
tests/components/voice_assistant/test_vad.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
"""Tests for webrtcvad voice command segmenter."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant.components.voice_assistant.vad import VoiceCommandSegmenter
|
||||||
|
|
||||||
|
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||||
|
|
||||||
|
|
||||||
|
def test_silence() -> None:
    """Test that 3 seconds of silence does not trigger a voice command."""
    segmenter = VoiceCommandSegmenter()

    # process() returning True means the voice command has not finished.
    three_seconds_of_silence = bytes(_ONE_SECOND * 3)
    assert segmenter.process(three_seconds_of_silence)
|
||||||
|
|
||||||
|
|
||||||
|
def test_speech() -> None:
    """Test that silence + speech + silence triggers a voice command."""

    def is_speech(self, chunk, sample_rate):
        """Anything non-zero is speech."""
        return sum(chunk) > 0

    with patch(
        "webrtcvad.Vad.is_speech",
        new=is_speech,
    ):
        segmenter = VoiceCommandSegmenter()

        silence = bytes(_ONE_SECOND)
        speech = bytes([255]) * _ONE_SECOND

        # Leading silence: command still open
        assert segmenter.process(silence)

        # "Speech": command has started but not finished
        assert segmenter.process(speech)

        # Trailing silence: a False result means the command is finished
        assert not segmenter.process(silence)
|
|
@ -75,7 +75,7 @@ class MockSTT:
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
) -> tts.Provider:
|
) -> stt.Provider:
|
||||||
"""Set up a mock speech component."""
|
"""Set up a mock speech component."""
|
||||||
return MockSttProvider(hass, _TRANSCRIPT)
|
return MockSttProvider(hass, _TRANSCRIPT)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue