Wake word cleanup (#98652)
* Make arguments for async_pipeline_from_audio_stream keyword-only after hass
* Use a bytearray ring buffer
* Move generator outside
* Move stt stream generator outside
* Clean up execute
* Refactor VAD to use bytearray
* More tests
* Refactor chunk_samples to be more correct and robust
* Change AudioBuffer to use append instead of setitem
* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
49897341ba
commit
8768c39021
9 changed files with 458 additions and 163 deletions
|
@ -1,7 +1,7 @@
|
|||
"""Test Voice Assistant init."""
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
|
|||
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
|
@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
stt_stream=audio_data(),
|
||||
)
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
|
|||
# Use the created pipeline
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
|
@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
|
@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
|
|||
# Use the created pipeline
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
|
@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
|
@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
|
|||
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
|
@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
|
@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||
with pytest.raises(assist_pipeline.PipelineNotFound):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
|
@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id="blah",
|
||||
)
|
||||
|
||||
|
@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||
yield b"wake word"
|
||||
yield b"part1"
|
||||
yield b"part2"
|
||||
yield b"end"
|
||||
yield b""
|
||||
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||
audio_seconds_to_buffer=1.5
|
||||
),
|
||||
)
|
||||
def continue_stt(self, chunk):
|
||||
# Ensure stt_vad_start event is triggered
|
||||
self.in_command = True
|
||||
|
||||
# Stop on fake end chunk to trigger stt_vad_end
|
||||
return chunk != b"end"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
|
||||
continue_stt,
|
||||
):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||
audio_seconds_to_buffer=1.5
|
||||
),
|
||||
)
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue