Don't return TTS URL in Assist pipeline (#105164)
* Don't return TTS URL * Add test for empty queue
This commit is contained in:
parent
6666b796f2
commit
4c4ad9404f
2 changed files with 69 additions and 5 deletions
|
@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field
|
|||
from enum import StrEnum
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from queue import Empty, Queue
|
||||
from threading import Thread
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
|
@ -1010,8 +1010,8 @@ class PipelineRun:
|
|||
self.tts_engine = engine
|
||||
self.tts_options = tts_options
|
||||
|
||||
async def text_to_speech(self, tts_input: str) -> str:
|
||||
"""Run text-to-speech portion of pipeline. Returns URL of TTS audio."""
|
||||
async def text_to_speech(self, tts_input: str) -> None:
|
||||
"""Run text-to-speech portion of pipeline."""
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_START,
|
||||
|
@ -1058,8 +1058,6 @@ class PipelineRun:
|
|||
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
||||
)
|
||||
|
||||
return tts_media.url
|
||||
|
||||
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
||||
"""Forward audio chunk to various capturing mechanisms."""
|
||||
if self.debug_recording_queue is not None:
|
||||
|
@ -1246,6 +1244,8 @@ def _pipeline_debug_recording_thread_proc(
|
|||
# Chunk of 16-bit mono audio at 16Khz
|
||||
if wav_writer is not None:
|
||||
wav_writer.writeframes(message)
|
||||
except Empty:
|
||||
pass # occurs when pipeline has unexpected error
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
_LOGGER.exception("Unexpected error in debug recording thread")
|
||||
finally:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test Voice Assistant init."""
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from pathlib import Path
|
||||
|
@ -569,6 +570,69 @@ async def test_pipeline_saved_audio_write_error(
|
|||
)
|
||||
|
||||
|
||||
async def test_pipeline_saved_audio_empty_queue(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_supporting_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that saved audio thread closes WAV file even if there's an empty queue."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
# Enable audio recording to temporary directory
|
||||
temp_dir = Path(temp_dir_str)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||
)
|
||||
|
||||
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||
if event.type == "run-end":
|
||||
# Verify WAV file exists, but contains no data
|
||||
pipeline_dirs = list(temp_dir.iterdir())
|
||||
run_dirs = list(pipeline_dirs[0].iterdir())
|
||||
wav_path = next(run_dirs[0].iterdir())
|
||||
with wave.open(str(wav_path), "rb") as wav_file:
|
||||
assert wav_file.getnframes() == 0
|
||||
|
||||
async def audio_data():
|
||||
# Force timeout in _pipeline_debug_recording_thread_proc
|
||||
await asyncio.sleep(1)
|
||||
yield b"not used"
|
||||
|
||||
# Wrap original function to time out immediately
|
||||
_pipeline_debug_recording_thread_proc = (
|
||||
assist_pipeline.pipeline._pipeline_debug_recording_thread_proc
|
||||
)
|
||||
|
||||
def proc_wrapper(run_recording_dir, queue):
|
||||
_pipeline_debug_recording_thread_proc(
|
||||
run_recording_dir, queue, message_timeout=0
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc",
|
||||
proc_wrapper,
|
||||
):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.STT,
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_detection_aborted(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
|
|
Loading…
Add table
Reference in a new issue