Wake word cleanup (#98652)

* Make all arguments to async_pipeline_from_audio_stream after hass keyword-only (see the signature sketch below)

* Use a bytearray ring buffer (see the AudioBuffer sketch below)

* Move generator outside

* Move stt stream generator outside

* Clean up execute

* Refactor VAD to use bytearray

* More tests

* Refactor chunk_samples to be more correct and robust (see the sketch at the end of this page)

* Change AudioBuffer to use append instead of setitem

* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
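
Note: the first bullet is visible throughout the diff below, where every call
site now passes context, event_callback, stt_metadata, and stt_stream by
keyword. A minimal sketch of the corresponding signature change, assuming
keyword-only parameters after hass; only names visible at the call sites are
shown, and the defaults are assumptions:

    # Sketch only -- parameter list abridged to what the call sites below use.
    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        *,  # everything after hass is now keyword-only
        context: Context,
        event_callback: PipelineEventCallback,
        stt_metadata: stt.SpeechMetadata,
        stt_stream: AsyncIterable[bytes],
        pipeline_id: str | None = None,  # assumed default
        start_stage: PipelineStage = PipelineStage.STT,  # assumed default
        wake_word_settings: WakeWordSettings | None = None,  # assumed default
    ) -> None:
        ...

With this change, positional calls like async_pipeline_from_audio_stream(hass,
Context(), events.append, ...) raise a TypeError, which is why every test below
is updated to keyword form.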
Author: Michael Hansen (committed by GitHub)
Date: 2023-08-25 12:28:48 -05:00
commit 8768c39021
parent 49897341ba

9 changed files with 458 additions and 163 deletions
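
Note: the bytearray buffering changes ("Use a bytearray ring buffer", "Change
AudioBuffer to use append instead of setitem") are not part of the test file
shown below. A minimal sketch of an append-based, fixed-size AudioBuffer,
assuming a preallocated bytearray plus an explicit valid length; the class and
method names follow the bullets, but the commit's exact implementation may
differ:

    class AudioBuffer:
        """Fixed-size buffer filled by appending rather than index assignment."""

        def __init__(self, maxlen: int) -> None:
            self._buffer = bytearray(maxlen)  # preallocated once, reused
            self._length = 0  # number of valid bytes currently stored

        @property
        def length(self) -> int:
            return self._length

        def clear(self) -> None:
            # Reuse the underlying bytearray; just reset the valid length.
            self._length = 0

        def append(self, data: bytes) -> None:
            if self._length + len(data) > len(self._buffer):
                raise ValueError("Data would exceed buffer size")
            self._buffer[self._length : self._length + len(data)] = data
            self._length += len(data)

        def bytes(self) -> bytes:
            # Return only the valid portion.
            return bytes(self._buffer[: self._length])

The wake word pre-roll exercised by audio_seconds_to_buffer=1.5 in the test
below is a related ring buffer that keeps only the most recent audio.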

tests/components/assist_pipeline/test_init.py

@@ -1,7 +1,7 @@
 """Test Voice Assistant init."""
 from dataclasses import asdict
 import itertools as it
-from unittest.mock import ANY
+from unittest.mock import ANY, patch
 
 import pytest
 from syrupy.assertion import SnapshotAssertion
@@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
 
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
     )
 
     assert process_events(events) == snapshot
@@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
     with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id=pipeline_id,
         )
 
@@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
     with pytest.raises(assist_pipeline.PipelineNotFound):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id="blah",
         )
 
@@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
         yield b"wake word"
         yield b"part1"
         yield b"part2"
+        yield b"end"
         yield b""
 
-    await assist_pipeline.async_pipeline_from_audio_stream(
-        hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
-            language="",
-            format=stt.AudioFormats.WAV,
-            codec=stt.AudioCodecs.PCM,
-            bit_rate=stt.AudioBitRates.BITRATE_16,
-            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-            channel=stt.AudioChannels.CHANNEL_MONO,
-        ),
-        audio_data(),
-        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
-        wake_word_settings=assist_pipeline.WakeWordSettings(
-            audio_seconds_to_buffer=1.5
-        ),
-    )
+    def continue_stt(self, chunk):
+        # Ensure stt_vad_start event is triggered
+        self.in_command = True
+
+        # Stop on fake end chunk to trigger stt_vad_end
+        return chunk != b"end"
+
+    with patch(
+        "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
+        continue_stt,
+    ):
+        await assist_pipeline.async_pipeline_from_audio_stream(
+            hass,
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
+                language="",
+                format=stt.AudioFormats.WAV,
+                codec=stt.AudioCodecs.PCM,
+                bit_rate=stt.AudioBitRates.BITRATE_16,
+                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                channel=stt.AudioChannels.CHANNEL_MONO,
+            ),
+            stt_stream=audio_data(),
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+        )
 
     assert process_events(events) == snapshot
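
Note: chunk_samples is also not visible in this test file. A sketch of the
"more correct and robust" chunking described in the commit message, assuming
it yields fixed-size chunks and carries partial chunks across calls in the
AudioBuffer sketched above; this is an illustration, not the commit's exact
code:

    from collections.abc import Iterable

    def chunk_samples(
        samples: bytes,
        bytes_per_chunk: int,
        leftover_chunk_buffer: AudioBuffer,
    ) -> Iterable[bytes]:
        """Yield fixed-size chunks, saving leftover bytes between calls."""
        if leftover_chunk_buffer.length > 0:
            bytes_needed = bytes_per_chunk - leftover_chunk_buffer.length
            if len(samples) < bytes_needed:
                # Still not enough for a full chunk; keep it all for next call.
                leftover_chunk_buffer.append(samples)
                return
            # Complete the leftover chunk first.
            leftover_chunk_buffer.append(samples[:bytes_needed])
            yield leftover_chunk_buffer.bytes()
            leftover_chunk_buffer.clear()
            samples = samples[bytes_needed:]

        # Yield as many full chunks as the remaining samples allow.
        num_chunks = len(samples) // bytes_per_chunk
        for i in range(num_chunks):
            offset = i * bytes_per_chunk
            yield samples[offset : offset + bytes_per_chunk]

        # Save any trailing partial chunk for the next call.
        rest = samples[num_chunks * bytes_per_chunk :]
        if rest:
            leftover_chunk_buffer.append(rest)

Carrying the remainder in a persistent buffer keeps VAD and wake word engines
fed with exact chunk sizes even when the incoming stream delivers audio in
arbitrary-sized pieces.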