diff --git a/homeassistant/components/voip/assist_satellite.py b/homeassistant/components/voip/assist_satellite.py index 2f37a8a63e1..6eb1aee209f 100644 --- a/homeassistant/components/voip/assist_satellite.py +++ b/homeassistant/components/voip/assist_satellite.py @@ -14,11 +14,7 @@ import wave from voip_utils import RtpDatagramProtocol from homeassistant.components import tts -from homeassistant.components.assist_pipeline import ( - PipelineEvent, - PipelineEventType, - PipelineNotFound, -) +from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType from homeassistant.components.assist_satellite import ( AssistSatelliteConfiguration, AssistSatelliteEntity, @@ -31,7 +27,6 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH from .devices import VoIPDevice from .entity import VoIPEntity -from .util import queue_to_iterable if TYPE_CHECKING: from . import DomainData @@ -101,9 +96,9 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol self.config_entry = config_entry - self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() + self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() self._audio_chunk_timeout: float = 2.0 - self._pipeline_task: asyncio.Task | None = None + self._run_pipeline_task: asyncio.Task | None = None self._pipeline_had_error: bool = False self._tts_done = asyncio.Event() self._tts_extra_timeout: float = 1.0 @@ -161,11 +156,11 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol def on_chunk(self, audio_bytes: bytes) -> None: """Handle raw audio chunk.""" - if self._pipeline_task is None: - self._clear_audio_queue() - + if self._run_pipeline_task is None: # Run pipeline until voice command finishes, then start over - self._pipeline_task = self.config_entry.async_create_background_task( + self._clear_audio_queue() + self._tts_done.clear() + self._run_pipeline_task = self.config_entry.async_create_background_task( self.hass, self._run_pipeline(), "voip_pipeline_run", @@ -173,27 +168,28 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol self._audio_queue.put_nowait(audio_bytes) - async def _run_pipeline( - self, - ) -> None: - """Forward audio to pipeline STT and handle TTS.""" + async def _run_pipeline(self) -> None: + _LOGGER.debug("Starting pipeline") + self.async_set_context(Context(user_id=self.config_entry.data["user"])) self.voip_device.set_is_active(True) + async def stt_stream(): + while True: + async with asyncio.timeout(self._audio_chunk_timeout): + chunk = await self._audio_queue.get() + if not chunk: + break + + yield chunk + # Play listening tone at the start of each cycle await self._play_tone(Tones.LISTENING, silence_before=0.2) try: - self._tts_done.clear() - - # Run pipeline with a timeout - _LOGGER.debug("Starting pipeline") - async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC): - await self.async_accept_pipeline_from_satellite( - audio_stream=queue_to_iterable( - self._audio_queue, timeout=self._audio_chunk_timeout - ), - ) + await self.async_accept_pipeline_from_satellite( + audio_stream=stt_stream(), + ) if self._pipeline_had_error: self._pipeline_had_error = False @@ -204,20 +200,15 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol # This is set in _send_tts and has a timeout that's based on the # length of the TTS audio. await self._tts_done.wait() - - _LOGGER.debug("Pipeline finished") - except PipelineNotFound: - _LOGGER.warning("Pipeline not found") - except (asyncio.CancelledError, TimeoutError): - # Expected after caller hangs up - _LOGGER.debug("Pipeline cancelled or timed out") - self.disconnect() - self._clear_audio_queue() + except TimeoutError: + self.disconnect() # caller hung up finally: - self.voip_device.set_is_active(False) + # Stop audio stream + await self._audio_queue.put(None) - # Allow pipeline to run again - self._pipeline_task = None + self.voip_device.set_is_active(False) + self._run_pipeline_task = None + _LOGGER.debug("Pipeline finished") def _clear_audio_queue(self) -> None: """Ensure audio queue is empty.""" @@ -247,6 +238,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol elif event.type == PipelineEventType.ERROR: # Play error tone instead of wait for TTS when pipeline is finished. self._pipeline_had_error = True + _LOGGER.warning(event) async def _send_tts(self, media_id: str) -> None: """Send TTS audio to caller via RTP.""" @@ -264,6 +256,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol if (self._tones & Tones.PROCESSING) == Tones.PROCESSING: # Don't overlap TTS and processing beep + _LOGGER.debug("Waiting for processing tone") await self._processing_tone_done.wait() with io.BytesIO(data) as wav_io: @@ -297,12 +290,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol _LOGGER.warning("TTS timeout") raise finally: - # Signal pipeline to restart - self._tts_done.set() - # Update satellite state self.tts_response_finished() + # Signal pipeline to restart + self._tts_done.set() + async def _async_send_audio(self, audio_bytes: bytes, **kwargs): """Send audio in executor.""" await self.hass.async_add_executor_job( diff --git a/homeassistant/components/voip/util.py b/homeassistant/components/voip/util.py deleted file mode 100644 index bfda96ba810..00000000000 --- a/homeassistant/components/voip/util.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Voip util functions.""" - -from __future__ import annotations - -from asyncio import Queue, timeout as async_timeout -from collections.abc import AsyncIterable -from typing import Any - -from typing_extensions import TypeVar - -_DataT = TypeVar("_DataT", default=Any) - - -async def queue_to_iterable( - queue: Queue[_DataT], timeout: float | None = None -) -> AsyncIterable[_DataT]: - """Stream items from a queue until None with an optional timeout per item.""" - if timeout is None: - while (item := await queue.get()) is not None: - yield item - else: - async with async_timeout(timeout): - item = await queue.get() - - while item is not None: - yield item - async with async_timeout(timeout): - item = await queue.get() diff --git a/tests/components/voip/test_util.py b/tests/components/voip/test_util.py deleted file mode 100644 index 85dfdbac2be..00000000000 --- a/tests/components/voip/test_util.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Test VoIP utils.""" - -import asyncio - -import pytest - -from homeassistant.components.voip.util import queue_to_iterable - - -async def test_queue_to_iterable() -> None: - """Test queue_to_iterable.""" - queue: asyncio.Queue[int | None] = asyncio.Queue() - expected_items = list(range(10)) - - for i in expected_items: - await queue.put(i) - - # Will terminate the stream - await queue.put(None) - - actual_items = [item async for item in queue_to_iterable(queue)] - - assert expected_items == actual_items - - # Check timeout - assert queue.empty() - - # Time out on first item - async with asyncio.timeout(1): - with pytest.raises(asyncio.TimeoutError): # noqa: PT012 - # Should time out very quickly - async for _item in queue_to_iterable(queue, timeout=0.01): - await asyncio.sleep(1) - - # Check timeout on second item - assert queue.empty() - await queue.put(12345) - - # Time out on second item - async with asyncio.timeout(1): - with pytest.raises(asyncio.TimeoutError): # noqa: PT012 - # Should time out very quickly - async for item in queue_to_iterable(queue, timeout=0.01): - if item != 12345: - await asyncio.sleep(1) - - assert queue.empty()