Add option to save Assist pipeline audio (#98928)
* Add pipeline option to save wake/stt audio to media * Add debug_recording_dir to assist_pipeline YAML config * Clean up and additional tests * Remove I/O in event loop * Organize saved audio by pipeline name and device id * Record wake/stt debug audio in separate thread * Fix after rebase * Use timestamp instead of pipeline id for directory name * Add WAV write error test * Join thread in executor
This commit is contained in:
parent
de30712d76
commit
054a63c3a2
7 changed files with 417 additions and 49 deletions
|
@ -3,12 +3,13 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import AsyncIterable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
Pipeline,
|
||||
|
@ -39,11 +40,15 @@ __all__ = (
|
|||
"WakeWordSettings",
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{vol.Optional(DOMAIN): {vol.Optional("debug_recording_dir"): str}}
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Assist pipeline integration."""
|
||||
hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
|
||||
|
||||
await async_setup_pipeline_store(hass)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
|
|
|
@ -1,2 +1,4 @@
|
|||
"""Constants for the Assist pipeline integration."""
|
||||
DOMAIN = "assist_pipeline"
|
||||
|
||||
DATA_CONFIG = f"{DOMAIN}.config"
|
||||
|
|
|
@ -6,7 +6,12 @@ from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
import time
|
||||
from typing import Any, cast
|
||||
import wave
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -39,7 +44,7 @@ from homeassistant.util import (
|
|||
)
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DATA_CONFIG, DOMAIN
|
||||
from .error import (
|
||||
IntentRecognitionError,
|
||||
PipelineError,
|
||||
|
@ -378,6 +383,12 @@ class PipelineRun:
|
|||
wake_word_engine: str = field(init=False)
|
||||
wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False)
|
||||
|
||||
debug_recording_thread: Thread | None = None
|
||||
"""Thread that records audio to debug_recording_dir"""
|
||||
|
||||
debug_recording_queue: Queue[str | bytes | None] | None = None
|
||||
"""Queue to communicate with debug recording thread"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
|
@ -405,8 +416,10 @@ class PipelineRun:
|
|||
return
|
||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event)
|
||||
|
||||
def start(self) -> None:
|
||||
def start(self, device_id: str | None) -> None:
|
||||
"""Emit run start event."""
|
||||
self._start_debug_recording_thread(device_id)
|
||||
|
||||
data = {
|
||||
"pipeline": self.pipeline.id,
|
||||
"language": self.language,
|
||||
|
@ -416,8 +429,12 @@ class PipelineRun:
|
|||
|
||||
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
|
||||
def end(self) -> None:
|
||||
async def end(self) -> None:
|
||||
"""Emit run end event."""
|
||||
# Stop the recording thread before emitting run-end.
|
||||
# This ensures that files are properly closed if the event handler reads them.
|
||||
await self._stop_debug_recording_thread()
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.RUN_END,
|
||||
|
@ -475,6 +492,9 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_engine}")
|
||||
|
||||
wake_word_settings = self.wake_word_settings or WakeWordSettings()
|
||||
|
||||
wake_word_vad: VoiceActivityTimeout | None = None
|
||||
|
@ -496,7 +516,7 @@ class PipelineRun:
|
|||
try:
|
||||
# Detect wake word(s)
|
||||
result = await self.wake_word_provider.async_process_audio_stream(
|
||||
_wake_word_audio_stream(
|
||||
self._wake_word_audio_stream(
|
||||
audio_stream=stream,
|
||||
stt_audio_buffer=stt_audio_buffer,
|
||||
wake_word_vad=wake_word_vad,
|
||||
|
@ -546,6 +566,39 @@ class PipelineRun:
|
|||
|
||||
return result
|
||||
|
||||
async def _wake_word_audio_stream(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
stt_audio_buffer: RingBuffer | None,
|
||||
wake_word_vad: VoiceActivityTimeout | None,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
) -> AsyncIterable[tuple[bytes, int]]:
|
||||
"""Yield audio chunks with timestamps (milliseconds since start of stream).
|
||||
|
||||
Adds audio to a ring buffer that will be forwarded to speech-to-text after
|
||||
detection. Times out if VAD detects enough silence.
|
||||
"""
|
||||
ms_per_sample = sample_rate // 1000
|
||||
timestamp_ms = 0
|
||||
async for chunk in audio_stream:
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(chunk)
|
||||
|
||||
yield chunk, timestamp_ms
|
||||
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
||||
|
||||
# Wake-word-detection occurs *after* the wake word was actually
|
||||
# spoken. Keeping audio right before detection allows the voice
|
||||
# command to be spoken immediately after the wake word.
|
||||
if stt_audio_buffer is not None:
|
||||
stt_audio_buffer.put(chunk)
|
||||
|
||||
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
|
||||
raise WakeWordTimeoutError(
|
||||
code="wake-word-timeout", message="Wake word was not detected"
|
||||
)
|
||||
|
||||
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||
"""Prepare speech-to-text."""
|
||||
# pipeline.stt_engine can't be None or this function is not called
|
||||
|
@ -595,6 +648,10 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
if self.debug_recording_queue is not None:
|
||||
# New recording
|
||||
self.debug_recording_queue.put_nowait(f"01_stt-{engine}")
|
||||
|
||||
try:
|
||||
# Transcribe audio stream
|
||||
result = await self.stt_provider.async_process_audio_stream(
|
||||
|
@ -648,6 +705,9 @@ class PipelineRun:
|
|||
sent_vad_start = False
|
||||
timestamp_ms = 0
|
||||
async for chunk in audio_stream:
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(chunk)
|
||||
|
||||
if stt_vad is not None:
|
||||
if not stt_vad.process(chunk):
|
||||
# Silence detected at the end of voice command
|
||||
|
@ -829,6 +889,96 @@ class PipelineRun:
|
|||
|
||||
return tts_media.url
|
||||
|
||||
def _start_debug_recording_thread(self, device_id: str | None) -> None:
|
||||
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
|
||||
if self.debug_recording_thread is not None:
|
||||
# Already started
|
||||
return
|
||||
|
||||
# Directory to save audio for each pipeline run.
|
||||
# Configured in YAML for assist_pipeline.
|
||||
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
|
||||
"debug_recording_dir"
|
||||
):
|
||||
if device_id is None:
|
||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||
run_recording_dir = (
|
||||
Path(debug_recording_dir)
|
||||
/ self.pipeline.name
|
||||
/ str(time.monotonic_ns())
|
||||
)
|
||||
else:
|
||||
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
|
||||
run_recording_dir = (
|
||||
Path(debug_recording_dir)
|
||||
/ device_id
|
||||
/ self.pipeline.name
|
||||
/ str(time.monotonic_ns())
|
||||
)
|
||||
|
||||
self.debug_recording_queue = Queue()
|
||||
self.debug_recording_thread = Thread(
|
||||
target=_pipeline_debug_recording_thread_proc,
|
||||
args=(run_recording_dir, self.debug_recording_queue),
|
||||
daemon=True,
|
||||
)
|
||||
self.debug_recording_thread.start()
|
||||
|
||||
async def _stop_debug_recording_thread(self) -> None:
|
||||
"""Stop recording thread."""
|
||||
if (self.debug_recording_thread is None) or (
|
||||
self.debug_recording_queue is None
|
||||
):
|
||||
# Not running
|
||||
return
|
||||
|
||||
# Signal thread to stop gracefully
|
||||
self.debug_recording_queue.put(None)
|
||||
|
||||
# Wait until the thread has finished to ensure that files are fully written
|
||||
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
|
||||
|
||||
self.debug_recording_queue = None
|
||||
self.debug_recording_thread = None
|
||||
|
||||
|
||||
def _pipeline_debug_recording_thread_proc(
|
||||
run_recording_dir: Path,
|
||||
queue: Queue[str | bytes | None],
|
||||
message_timeout: float = 5,
|
||||
) -> None:
|
||||
wav_writer: wave.Wave_write | None = None
|
||||
|
||||
try:
|
||||
_LOGGER.debug("Saving wake/stt audio to %s", run_recording_dir)
|
||||
run_recording_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
while True:
|
||||
message = queue.get(timeout=message_timeout)
|
||||
if message is None:
|
||||
# Stop signal
|
||||
break
|
||||
|
||||
if isinstance(message, str):
|
||||
# New WAV file name
|
||||
if wav_writer is not None:
|
||||
wav_writer.close()
|
||||
|
||||
wav_path = run_recording_dir / f"{message}.wav"
|
||||
wav_writer = wave.open(str(wav_path), "wb")
|
||||
wav_writer.setframerate(16000)
|
||||
wav_writer.setsampwidth(2)
|
||||
wav_writer.setnchannels(1)
|
||||
elif isinstance(message, bytes):
|
||||
# Chunk of 16-bit mono audio at 16Khz
|
||||
if wav_writer is not None:
|
||||
wav_writer.writeframes(message)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
_LOGGER.exception("Unexpected error in debug recording thread")
|
||||
finally:
|
||||
if wav_writer is not None:
|
||||
wav_writer.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineInput:
|
||||
|
@ -854,7 +1004,7 @@ class PipelineInput:
|
|||
|
||||
async def execute(self) -> None:
|
||||
"""Run pipeline."""
|
||||
self.run.start()
|
||||
self.run.start(device_id=self.device_id)
|
||||
current_stage: PipelineStage | None = self.run.start_stage
|
||||
stt_audio_buffer: list[bytes] = []
|
||||
|
||||
|
@ -867,7 +1017,7 @@ class PipelineInput:
|
|||
)
|
||||
if detect_result is None:
|
||||
# No wake word. Abort the rest of the pipeline.
|
||||
self.run.end()
|
||||
await self.run.end()
|
||||
return
|
||||
|
||||
current_stage = PipelineStage.STT
|
||||
|
@ -927,9 +1077,10 @@ class PipelineInput:
|
|||
{"code": err.code, "message": err.message},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.run.end()
|
||||
finally:
|
||||
# Always end the run since it needs to shut down the debug recording
|
||||
# thread, etc.
|
||||
await self.run.end()
|
||||
|
||||
async def validate(self) -> None:
|
||||
"""Validate pipeline input against start stage."""
|
||||
|
@ -1000,36 +1151,6 @@ class PipelineInput:
|
|||
await asyncio.gather(*prepare_tasks)
|
||||
|
||||
|
||||
async def _wake_word_audio_stream(
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
stt_audio_buffer: RingBuffer | None,
|
||||
wake_word_vad: VoiceActivityTimeout | None,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
) -> AsyncIterable[tuple[bytes, int]]:
|
||||
"""Yield audio chunks with timestamps (milliseconds since start of stream).
|
||||
|
||||
Adds audio to a ring buffer that will be forwarded to speech-to-text after
|
||||
detection. Times out if VAD detects enough silence.
|
||||
"""
|
||||
ms_per_sample = sample_rate // 1000
|
||||
timestamp_ms = 0
|
||||
async for chunk in audio_stream:
|
||||
yield chunk, timestamp_ms
|
||||
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
||||
|
||||
# Wake-word-detection occurs *after* the wake word was actually
|
||||
# spoken. Keeping audio right before detection allows the voice
|
||||
# command to be spoken immediately after the wake word.
|
||||
if stt_audio_buffer is not None:
|
||||
stt_audio_buffer.put(chunk)
|
||||
|
||||
if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
|
||||
raise WakeWordTimeoutError(
|
||||
code="wake-word-timeout", message="Wake word was not detected"
|
||||
)
|
||||
|
||||
|
||||
class PipelinePreferred(CollectionError):
|
||||
"""Raised when attempting to delete the preferred pipelen."""
|
||||
|
||||
|
|
|
@ -191,7 +191,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
async for chunk, timestamp in stream:
|
||||
if chunk == b"wake word":
|
||||
if chunk.startswith(b"wake word"):
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=self.supported_wake_words[0].ww_id,
|
||||
timestamp=timestamp,
|
||||
|
@ -301,7 +301,6 @@ async def init_supporting_components(
|
|||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||
# assert await async_setup_component(hass, wake_word.DOMAIN, {"wake_word": {}})
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
|
|
|
@ -77,6 +77,9 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.7
|
||||
None
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -155,6 +158,9 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.7
|
||||
None
|
||||
# ---
|
||||
# name: test_audio_pipeline_no_wake_word_engine
|
||||
dict({
|
||||
'code': 'wake-engine-missing',
|
||||
|
@ -364,6 +370,9 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.9
|
||||
None
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -392,6 +401,9 @@
|
|||
'message': 'Wake word was not detected',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.3
|
||||
None
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -411,6 +423,9 @@
|
|||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed.2
|
||||
None
|
||||
# ---
|
||||
# name: test_intent_timeout
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -431,6 +446,9 @@
|
|||
})
|
||||
# ---
|
||||
# name: test_intent_timeout.2
|
||||
None
|
||||
# ---
|
||||
# name: test_intent_timeout.3
|
||||
dict({
|
||||
'code': 'timeout',
|
||||
'message': 'Timeout running pipeline',
|
||||
|
@ -482,6 +500,9 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed.2
|
||||
None
|
||||
# ---
|
||||
# name: test_text_only_pipeline
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
@ -523,6 +544,9 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.3
|
||||
None
|
||||
# ---
|
||||
# name: test_text_pipeline_timeout
|
||||
dict({
|
||||
'code': 'timeout',
|
||||
|
@ -547,3 +571,6 @@
|
|||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed.2
|
||||
None
|
||||
# ---
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
"""Test Voice Assistant init."""
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from unittest.mock import ANY, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
|
||||
|
||||
|
@ -305,7 +309,7 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||
async def audio_data():
|
||||
yield wake_chunk_1 # 1 second
|
||||
yield wake_chunk_2 # 1 second
|
||||
yield b"wake word"
|
||||
yield b"wake word!"
|
||||
yield b"part1"
|
||||
yield b"part2"
|
||||
yield b"end"
|
||||
|
@ -353,3 +357,183 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
|
||||
|
||||
assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]
|
||||
|
||||
|
||||
async def test_pipeline_save_audio(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_supporting_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test saving audio during a pipeline run."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
# Enable audio recording to temporary directory
|
||||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
)
|
||||
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass)
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Pad out to an even number of bytes since these "samples" will be saved
|
||||
# as 16-bit values.
|
||||
async def audio_data():
|
||||
yield b"wake word_"
|
||||
# queued audio
|
||||
yield b"part1_"
|
||||
yield b"part2_"
|
||||
yield b""
|
||||
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline.id,
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.STT,
|
||||
)
|
||||
|
||||
pipeline_dirs = list(temp_dir.iterdir())
|
||||
|
||||
# Only one pipeline run
|
||||
# <debug_recording_dir>/<pipeline.name>/<run.id>
|
||||
assert len(pipeline_dirs) == 1
|
||||
assert pipeline_dirs[0].is_dir()
|
||||
assert pipeline_dirs[0].name == pipeline.name
|
||||
|
||||
# Wake and stt files
|
||||
run_dirs = list(pipeline_dirs[0].iterdir())
|
||||
assert run_dirs[0].is_dir()
|
||||
run_files = list(run_dirs[0].iterdir())
|
||||
|
||||
assert len(run_files) == 2
|
||||
wake_file = run_files[0] if "wake" in run_files[0].name else run_files[1]
|
||||
stt_file = run_files[0] if "stt" in run_files[0].name else run_files[1]
|
||||
assert wake_file != stt_file
|
||||
|
||||
# Verify wake file
|
||||
with wave.open(str(wake_file), "rb") as wake_wav:
|
||||
wake_data = wake_wav.readframes(wake_wav.getnframes())
|
||||
assert wake_data == b"wake word_"
|
||||
|
||||
# Verify stt file
|
||||
with wave.open(str(stt_file), "rb") as stt_wav:
|
||||
stt_data = stt_wav.readframes(stt_wav.getnframes())
|
||||
assert stt_data == b"queued audiopart1_part2_"
|
||||
|
||||
|
||||
async def test_pipeline_saved_audio_with_device_id(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_supporting_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that saved audio directory uses device id."""
|
||||
device_id = "test-device-id"
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
# Enable audio recording to temporary directory
|
||||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
if event.type == "run-end":
|
||||
# Verify that saved audio directory is named after device id
|
||||
device_dirs = list(temp_dir.iterdir())
|
||||
assert device_dirs[0].name == device_id
|
||||
|
||||
async def audio_data():
|
||||
yield b"not used"
|
||||
|
||||
# Force a timeout during wake word detection
|
||||
with patch.object(
|
||||
mock_wake_word_provider_entity,
|
||||
"async_process_audio_stream",
|
||||
side_effect=assist_pipeline.error.WakeWordTimeoutError(
|
||||
code="timeout", message="timeout"
|
||||
),
|
||||
):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.STT,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
|
||||
async def test_pipeline_saved_audio_write_error(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_supporting_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that saved audio thread closes WAV file even if there's a write error."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
# Enable audio recording to temporary directory
|
||||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"assist_pipeline",
|
||||
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
if event.type == "run-end":
|
||||
# Verify WAV file exists, but contains no data
|
||||
pipeline_dirs = list(temp_dir.iterdir())
|
||||
run_dirs = list(pipeline_dirs[0].iterdir())
|
||||
wav_path = next(run_dirs[0].iterdir())
|
||||
with wave.open(str(wav_path), "rb") as wav_file:
|
||||
assert wav_file.getnframes() == 0
|
||||
|
||||
async def audio_data():
|
||||
yield b"not used"
|
||||
|
||||
# Force a timeout during wake word detection
|
||||
with patch("wave.Wave_write.writeframes", raises=RuntimeError()):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.STT,
|
||||
)
|
||||
|
|
|
@ -58,7 +58,7 @@ async def test_text_only_pipeline(
|
|||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
@ -148,7 +148,7 @@ async def test_audio_pipeline(
|
|||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
@ -215,6 +215,12 @@ async def test_audio_pipeline_with_wake_word_timeout(
|
|||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
|
||||
async def test_audio_pipeline_with_wake_word_no_timeout(
|
||||
hass: HomeAssistant,
|
||||
|
@ -302,7 +308,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
|||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
@ -429,6 +435,12 @@ async def test_intent_timeout(
|
|||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# run-end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
|
@ -550,6 +562,12 @@ async def test_intent_failed(
|
|||
assert msg["event"]["data"]["code"] == "intent-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
@ -730,6 +748,12 @@ async def test_stt_stream_failed(
|
|||
assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
@ -792,6 +816,12 @@ async def test_tts_failed(
|
|||
assert msg["event"]["data"]["code"] == "tts-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
|
@ -1460,7 +1490,7 @@ async def test_audio_pipeline_debug(
|
|||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] is None
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# Get the id of the pipeline
|
||||
|
|
Loading…
Add table
Reference in a new issue