Fix pipeline restart in VoIP (#126668)

This commit is contained in:
Michael Hansen 2024-09-24 14:24:42 -05:00 committed by GitHub
parent 739165585a
commit 86f8901c96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 116 deletions

View file

@ -14,11 +14,7 @@ import wave
from voip_utils import RtpDatagramProtocol
from homeassistant.components import tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
AssistSatelliteConfiguration,
AssistSatelliteEntity,
@ -31,7 +27,6 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity
from .util import queue_to_iterable
if TYPE_CHECKING:
from . import DomainData
@ -101,9 +96,9 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self.config_entry = config_entry
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._audio_chunk_timeout: float = 2.0
self._pipeline_task: asyncio.Task | None = None
self._run_pipeline_task: asyncio.Task | None = None
self._pipeline_had_error: bool = False
self._tts_done = asyncio.Event()
self._tts_extra_timeout: float = 1.0
@ -161,11 +156,11 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
self._clear_audio_queue()
if self._run_pipeline_task is None:
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.config_entry.async_create_background_task(
self._clear_audio_queue()
self._tts_done.clear()
self._run_pipeline_task = self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
@ -173,27 +168,28 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
self,
) -> None:
"""Forward audio to pipeline STT and handle TTS."""
async def _run_pipeline(self) -> None:
_LOGGER.debug("Starting pipeline")
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
self.voip_device.set_is_active(True)
async def stt_stream():
while True:
async with asyncio.timeout(self._audio_chunk_timeout):
chunk = await self._audio_queue.get()
if not chunk:
break
yield chunk
# Play listening tone at the start of each cycle
await self._play_tone(Tones.LISTENING, silence_before=0.2)
try:
self._tts_done.clear()
# Run pipeline with a timeout
_LOGGER.debug("Starting pipeline")
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
await self.async_accept_pipeline_from_satellite(
audio_stream=queue_to_iterable(
self._audio_queue, timeout=self._audio_chunk_timeout
),
)
await self.async_accept_pipeline_from_satellite(
audio_stream=stt_stream(),
)
if self._pipeline_had_error:
self._pipeline_had_error = False
@ -204,20 +200,15 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except (asyncio.CancelledError, TimeoutError):
# Expected after caller hangs up
_LOGGER.debug("Pipeline cancelled or timed out")
self.disconnect()
self._clear_audio_queue()
except TimeoutError:
self.disconnect() # caller hung up
finally:
self.voip_device.set_is_active(False)
# Stop audio stream
await self._audio_queue.put(None)
# Allow pipeline to run again
self._pipeline_task = None
self.voip_device.set_is_active(False)
self._run_pipeline_task = None
_LOGGER.debug("Pipeline finished")
def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty."""
@ -247,6 +238,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS when pipeline is finished.
self._pipeline_had_error = True
_LOGGER.warning(event)
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to caller via RTP."""
@ -264,6 +256,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
# Don't overlap TTS and processing beep
_LOGGER.debug("Waiting for processing tone")
await self._processing_tone_done.wait()
with io.BytesIO(data) as wav_io:
@ -297,12 +290,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
_LOGGER.warning("TTS timeout")
raise
finally:
# Signal pipeline to restart
self._tts_done.set()
# Update satellite state
self.tts_response_finished()
# Signal pipeline to restart
self._tts_done.set()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
"""Send audio in executor."""
await self.hass.async_add_executor_job(

View file

@ -1,28 +0,0 @@
"""Voip util functions."""
from __future__ import annotations
from asyncio import Queue, timeout as async_timeout
from collections.abc import AsyncIterable
from typing import Any
from typing_extensions import TypeVar
_DataT = TypeVar("_DataT", default=Any)
async def queue_to_iterable(
queue: Queue[_DataT], timeout: float | None = None
) -> AsyncIterable[_DataT]:
"""Stream items from a queue until None with an optional timeout per item."""
if timeout is None:
while (item := await queue.get()) is not None:
yield item
else:
async with async_timeout(timeout):
item = await queue.get()
while item is not None:
yield item
async with async_timeout(timeout):
item = await queue.get()

View file

@ -1,47 +0,0 @@
"""Test VoIP utils."""
import asyncio
import pytest
from homeassistant.components.voip.util import queue_to_iterable
async def test_queue_to_iterable() -> None:
"""Test queue_to_iterable."""
queue: asyncio.Queue[int | None] = asyncio.Queue()
expected_items = list(range(10))
for i in expected_items:
await queue.put(i)
# Will terminate the stream
await queue.put(None)
actual_items = [item async for item in queue_to_iterable(queue)]
assert expected_items == actual_items
# Check timeout
assert queue.empty()
# Time out on first item
async with asyncio.timeout(1):
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
# Should time out very quickly
async for _item in queue_to_iterable(queue, timeout=0.01):
await asyncio.sleep(1)
# Check timeout on second item
assert queue.empty()
await queue.put(12345)
# Time out on second item
async with asyncio.timeout(1):
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
# Should time out very quickly
async for item in queue_to_iterable(queue, timeout=0.01):
if item != 12345:
await asyncio.sleep(1)
assert queue.empty()