Don't return TTS URL in Assist pipeline (#105164)

* Don't return TTS URL

* Add test for empty queue
This commit is contained in:
Michael Hansen 2023-12-07 14:28:04 -06:00 committed by GitHub
parent 6666b796f2
commit 4c4ad9404f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 5 deletions

View file

@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field
from enum import StrEnum
import logging
from pathlib import Path
from queue import Queue
from queue import Empty, Queue
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Final, cast
@ -1010,8 +1010,8 @@ class PipelineRun:
self.tts_engine = engine
self.tts_options = tts_options
async def text_to_speech(self, tts_input: str) -> str:
"""Run text-to-speech portion of pipeline. Returns URL of TTS audio."""
async def text_to_speech(self, tts_input: str) -> None:
"""Run text-to-speech portion of pipeline."""
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,
@ -1058,8 +1058,6 @@ class PipelineRun:
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
)
return tts_media.url
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
"""Forward audio chunk to various capturing mechanisms."""
if self.debug_recording_queue is not None:
@ -1246,6 +1244,8 @@ def _pipeline_debug_recording_thread_proc(
# Chunk of 16-bit mono audio at 16Khz
if wav_writer is not None:
wav_writer.writeframes(message)
except Empty:
pass # occurs when pipeline has unexpected error
except Exception: # pylint: disable=broad-exception-caught
_LOGGER.exception("Unexpected error in debug recording thread")
finally:

View file

@ -1,4 +1,5 @@
"""Test Voice Assistant init."""
import asyncio
from dataclasses import asdict
import itertools as it
from pathlib import Path
@ -569,6 +570,69 @@ async def test_pipeline_saved_audio_write_error(
)
async def test_pipeline_saved_audio_empty_queue(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that saved audio thread closes WAV file even if there's an empty queue."""
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)
def event_callback(event: assist_pipeline.PipelineEvent):
if event.type == "run-end":
# Verify WAV file exists, but contains no data
pipeline_dirs = list(temp_dir.iterdir())
run_dirs = list(pipeline_dirs[0].iterdir())
wav_path = next(run_dirs[0].iterdir())
with wave.open(str(wav_path), "rb") as wav_file:
assert wav_file.getnframes() == 0
async def audio_data():
# Force timeout in _pipeline_debug_recording_thread_proc
await asyncio.sleep(1)
yield b"not used"
# Wrap original function to time out immediately
_pipeline_debug_recording_thread_proc = (
assist_pipeline.pipeline._pipeline_debug_recording_thread_proc
)
def proc_wrapper(run_recording_dir, queue):
_pipeline_debug_recording_thread_proc(
run_recording_dir, queue, message_timeout=0
)
with patch(
"homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc",
proc_wrapper,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=event_callback,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
)
async def test_wake_word_detection_aborted(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,