Fix pipeline restart in VoIP (#126668)
This commit is contained in:
parent
739165585a
commit
86f8901c96
3 changed files with 34 additions and 116 deletions
|
@ -14,11 +14,7 @@ import wave
|
|||
from voip_utils import RtpDatagramProtocol
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
|
@ -31,7 +27,6 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
from .util import queue_to_iterable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
@ -101,9 +96,9 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
self.config_entry = config_entry
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._audio_chunk_timeout: float = 2.0
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._run_pipeline_task: asyncio.Task | None = None
|
||||
self._pipeline_had_error: bool = False
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_extra_timeout: float = 1.0
|
||||
|
@ -161,11 +156,11 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
if self._run_pipeline_task is None:
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self._clear_audio_queue()
|
||||
self._tts_done.clear()
|
||||
self._run_pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
|
@ -173,26 +168,27 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
async def _run_pipeline(self) -> None:
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
async def stt_stream():
|
||||
while True:
|
||||
async with asyncio.timeout(self._audio_chunk_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||
|
||||
try:
|
||||
self._tts_done.clear()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||
await self.async_accept_pipeline_from_satellite(
|
||||
audio_stream=queue_to_iterable(
|
||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||
),
|
||||
audio_stream=stt_stream(),
|
||||
)
|
||||
|
||||
if self._pipeline_had_error:
|
||||
|
@ -204,20 +200,15 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||
self.disconnect()
|
||||
self._clear_audio_queue()
|
||||
except TimeoutError:
|
||||
self.disconnect() # caller hung up
|
||||
finally:
|
||||
self.voip_device.set_is_active(False)
|
||||
# Stop audio stream
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
self.voip_device.set_is_active(False)
|
||||
self._run_pipeline_task = None
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
|
@ -247,6 +238,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||
self._pipeline_had_error = True
|
||||
_LOGGER.warning(event)
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
|
@ -264,6 +256,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
# Don't overlap TTS and processing beep
|
||||
_LOGGER.debug("Waiting for processing tone")
|
||||
await self._processing_tone_done.wait()
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
|
@ -297,12 +290,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
# Update satellite state
|
||||
self.tts_response_finished()
|
||||
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
"""Voip util functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Queue, timeout as async_timeout
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_DataT = TypeVar("_DataT", default=Any)
|
||||
|
||||
|
||||
async def queue_to_iterable(
|
||||
queue: Queue[_DataT], timeout: float | None = None
|
||||
) -> AsyncIterable[_DataT]:
|
||||
"""Stream items from a queue until None with an optional timeout per item."""
|
||||
if timeout is None:
|
||||
while (item := await queue.get()) is not None:
|
||||
yield item
|
||||
else:
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
||||
while item is not None:
|
||||
yield item
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
|
@ -1,47 +0,0 @@
|
|||
"""Test VoIP utils."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.voip.util import queue_to_iterable
|
||||
|
||||
|
||||
async def test_queue_to_iterable() -> None:
|
||||
"""Test queue_to_iterable."""
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
expected_items = list(range(10))
|
||||
|
||||
for i in expected_items:
|
||||
await queue.put(i)
|
||||
|
||||
# Will terminate the stream
|
||||
await queue.put(None)
|
||||
|
||||
actual_items = [item async for item in queue_to_iterable(queue)]
|
||||
|
||||
assert expected_items == actual_items
|
||||
|
||||
# Check timeout
|
||||
assert queue.empty()
|
||||
|
||||
# Time out on first item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for _item in queue_to_iterable(queue, timeout=0.01):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check timeout on second item
|
||||
assert queue.empty()
|
||||
await queue.put(12345)
|
||||
|
||||
# Time out on second item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for item in queue_to_iterable(queue, timeout=0.01):
|
||||
if item != 12345:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert queue.empty()
|
Loading…
Add table
Reference in a new issue