Wake word cleanup (#98652)

* Make arguments for async_pipeline_from_audio_stream keyword-only after hass

* Use a bytearray ring buffer

* Move generator outside

* Move stt stream generator outside

* Clean up execute

* Refactor VAD to use bytearray

* More tests

* Refactor chunk_samples to be more correct and robust

* Change AudioBuffer to use append instead of setitem

* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-08-25 12:28:48 -05:00 committed by GitHub
parent 49897341ba
commit 8768c39021
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 458 additions and 163 deletions

View file

@ -49,6 +49,7 @@ from .error import (
WakeWordDetectionError,
WakeWordTimeoutError,
)
from .ring_buffer import RingBuffer
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
_LOGGER = logging.getLogger(__name__)
@ -425,7 +426,6 @@ class PipelineRun:
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
# Need to add to pipeline store
engine = wake_word.async_default_engine(self.hass)
if engine is None:
raise WakeWordDetectionError(
@ -448,7 +448,7 @@ class PipelineRun:
async def wake_word_detection(
self,
stream: AsyncIterable[bytes],
audio_buffer: list[bytes],
audio_chunks_for_stt: list[bytes],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
@ -484,46 +484,29 @@ class PipelineRun:
# Use VAD to determine timeout
wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
# Audio chunk buffer.
audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2
# Audio chunk buffer. This audio will be forwarded to speech-to-text
# after wake-word-detection.
num_audio_bytes_to_buffer = int(
wake_word_settings.audio_seconds_to_buffer * 16000 * 2 # 16-bit @ 16Khz
)
audio_ring_buffer = b""
async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio with timestamps (milliseconds since start of stream)."""
nonlocal audio_ring_buffer
timestamp_ms = 0
async for chunk in stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Keeping audio right before wake word detection allows the
# voice command to be spoken immediately after the wake word.
if audio_bytes_to_buffer > 0:
audio_ring_buffer += chunk
if len(audio_ring_buffer) > audio_bytes_to_buffer:
# A proper ring buffer would be far more efficient
audio_ring_buffer = audio_ring_buffer[
len(audio_ring_buffer) - audio_bytes_to_buffer :
]
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
stt_audio_buffer: RingBuffer | None = None
if num_audio_bytes_to_buffer > 0:
stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
try:
# Detect wake word(s)
result = await self.wake_word_provider.async_process_audio_stream(
timestamped_stream()
_wake_word_audio_stream(
audio_stream=stream,
stt_audio_buffer=stt_audio_buffer,
wake_word_vad=wake_word_vad,
)
)
if audio_ring_buffer:
if stt_audio_buffer is not None:
# All audio kept from right before the wake word was detected as
# a single chunk.
audio_buffer.append(audio_ring_buffer)
audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
except WakeWordTimeoutError:
_LOGGER.debug("Timeout during wake word detection")
raise
@ -540,9 +523,14 @@ class PipelineRun:
wake_word_output: dict[str, Any] = {}
else:
if result.queued_audio:
# Add audio that was pending at detection
# Add audio that was pending at detection.
#
# Because detection occurs *after* the wake word was actually
# spoken, we need to make sure pending audio is forwarded to
# speech-to-text so the user does not have to pause before
# speaking the voice command.
for chunk_ts in result.queued_audio:
audio_buffer.append(chunk_ts[0])
audio_chunks_for_stt.append(chunk_ts[0])
wake_word_output = asdict(result)
@ -608,41 +596,12 @@ class PipelineRun:
)
try:
segmenter = VoiceCommandSegmenter()
async def segment_stream(
stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished."""
sent_vad_start = False
timestamp_ms = 0
async for chunk in stream:
if not segmenter.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if segmenter.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream(
metadata, segment_stream(stream)
metadata,
self._speech_to_text_stream(
audio_stream=stream, stt_vad=VoiceCommandSegmenter()
),
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
@ -677,6 +636,42 @@ class PipelineRun:
return result.text
async def _speech_to_text_stream(
self,
audio_stream: AsyncIterable[bytes],
stt_vad: VoiceCommandSegmenter | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
ms_per_sample = sample_rate // 1000
sent_vad_start = False
timestamp_ms = 0
async for chunk in audio_stream:
if stt_vad is not None:
if not stt_vad.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break
if stt_vad.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
async def prepare_recognize_intent(self) -> None:
"""Prepare recognizing an intent."""
agent_info = conversation.async_get_agent_info(
@ -861,13 +856,14 @@ class PipelineInput:
"""Run pipeline."""
self.run.start()
current_stage: PipelineStage | None = self.run.start_stage
audio_buffer: list[bytes] = []
stt_audio_buffer: list[bytes] = []
try:
if current_stage == PipelineStage.WAKE_WORD:
# wake-word-detection
assert self.stt_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, audio_buffer
self.stt_stream, stt_audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
@ -882,19 +878,22 @@ class PipelineInput:
assert self.stt_metadata is not None
assert self.stt_stream is not None
if audio_buffer:
stt_stream = self.stt_stream
async def buffered_stream() -> AsyncGenerator[bytes, None]:
for chunk in audio_buffer:
if stt_audio_buffer:
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
# This is basically an async itertools.chain.
async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
# Buffered audio
for chunk in stt_audio_buffer:
yield chunk
# Streamed audio
assert self.stt_stream is not None
async for chunk in self.stt_stream:
yield chunk
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
else:
stt_stream = self.stt_stream
stt_stream = buffer_then_audio_stream()
intent_input = await self.run.speech_to_text(
self.stt_metadata,
@ -906,6 +905,7 @@ class PipelineInput:
tts_input = self.tts_input
if current_stage == PipelineStage.INTENT:
# intent-recognition
assert intent_input is not None
tts_input = await self.run.recognize_intent(
intent_input,
@ -915,6 +915,7 @@ class PipelineInput:
current_stage = PipelineStage.TTS
if self.run.end_stage != PipelineStage.INTENT:
# text-to-speech
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
@ -999,6 +1000,36 @@ class PipelineInput:
await asyncio.gather(*prepare_tasks)
async def _wake_word_audio_stream(
audio_stream: AsyncIterable[bytes],
stt_audio_buffer: RingBuffer | None,
wake_word_vad: VoiceActivityTimeout | None,
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio chunks with timestamps (milliseconds since start of stream).
Adds audio to a ring buffer that will be forwarded to speech-to-text after
detection. Times out if VAD detects enough silence.
"""
ms_per_sample = sample_rate // 1000
timestamp_ms = 0
async for chunk in audio_stream:
yield chunk, timestamp_ms
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
# Wake-word-detection occurs *after* the wake word was actually
# spoken. Keeping audio right before detection allows the voice
# command to be spoken immediately after the wake word.
if stt_audio_buffer is not None:
stt_audio_buffer.put(chunk)
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
)
class PipelinePreferred(CollectionError):
"""Raised when attempting to delete the preferred pipelen."""