VoIP audio queue (#91577)

* Clear audio queue after every conversation turn

* Stream STT audio when voice command starts
This commit is contained in:
Michael Hansen 2023-04-17 21:51:14 -05:00 committed by GitHub
parent aeb19831d2
commit 95d16c9829
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 22 deletions

View file

@ -28,7 +28,7 @@ class VoiceCommandSegmenter:
reset_seconds: float = 1.0
"""Seconds before reset start/stop time counters."""
_in_command: bool = False
in_command: bool = False
"""True if inside voice command."""
_speech_seconds_left: float = 0.0
@ -62,7 +62,7 @@ class VoiceCommandSegmenter:
self._silence_seconds_left = self.silence_seconds
self._timeout_seconds_left = self.timeout_seconds
self._reset_seconds_left = self.reset_seconds
self._in_command = False
self.in_command = False
def process(self, samples: bytes) -> bool:
"""Process a 16-bit 16Khz mono audio samples.
@ -101,13 +101,13 @@ class VoiceCommandSegmenter:
if self._timeout_seconds_left <= 0:
return False
if not self._in_command:
if not self.in_command:
if is_speech:
self._reset_seconds_left = self.reset_seconds
self._speech_seconds_left -= self._seconds_per_chunk
if self._speech_seconds_left <= 0:
# Inside voice command
self._in_command = True
self.in_command = True
else:
# Reset if enough silence
self._reset_seconds_left -= self._seconds_per_chunk

View file

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable
import logging
import time
from typing import TYPE_CHECKING
@ -26,6 +28,7 @@ from .const import DOMAIN
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
_BUFFERED_CHUNKS_BEFORE_SPEECH = 100 # ~2 seconds
_LOGGER = logging.getLogger(__name__)
@ -95,9 +98,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
# Clear audio queue
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
self._clear_audio_queue()
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.hass.async_create_background_task(
@ -114,23 +115,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
_LOGGER.debug("Starting pipeline")
async def stt_stream():
segmenter = VoiceCommandSegmenter()
try:
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with async_timeout.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
while chunk:
if not segmenter.process(chunk):
# Voice command is finished
break
async for chunk in self._segment_audio():
yield chunk
async with async_timeout.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
except asyncio.TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Audio timeout")
@ -138,6 +125,8 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
if self.transport is not None:
self.transport.close()
self.transport = None
finally:
self._clear_audio_queue()
try:
# Run pipeline with a timeout
@ -172,6 +161,40 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
# Allow pipeline to run again
self._pipeline_task = None
async def _segment_audio(self) -> AsyncIterable[bytes]:
segmenter = VoiceCommandSegmenter()
chunk_buffer: deque[bytes] = deque(maxlen=_BUFFERED_CHUNKS_BEFORE_SPEECH)
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with async_timeout.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
while chunk:
if not segmenter.process(chunk):
# Voice command is finished
break
if segmenter.in_command:
if chunk_buffer:
# Release audio in buffer first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
chunk_buffer.clear()
yield chunk
else:
# Buffer until command starts
chunk_buffer.append(chunk)
async with async_timeout.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
def _clear_audio_queue(self) -> None:
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def _event_callback(self, event: PipelineEvent):
if not event.data:
return