Fix TTS streaming for VoIP (#104620)

* Use wav instead of raw tts audio in voip

* More tests

* Use mock TTS dir
This commit is contained in:
Michael Hansen 2023-11-29 11:07:22 -06:00 committed by GitHub
parent 47426a3ddc
commit a894146cee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 239 additions and 8 deletions

View file

@ -5,10 +5,12 @@ import asyncio
from collections import deque from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial from functools import partial
import io
import logging import logging
from pathlib import Path from pathlib import Path
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import wave
from voip_utils import ( from voip_utils import (
CallInfo, CallInfo,
@ -285,7 +287,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
), ),
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
device_id=self.voip_device.device_id, device_id=self.voip_device.device_id,
tts_audio_output="raw", tts_audio_output="wav",
) )
if self._pipeline_error: if self._pipeline_error:
@ -402,11 +404,32 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
if self.transport is None: if self.transport is None:
return return
_extension, audio_bytes = await tts.async_get_media_source_audio( extension, data = await tts.async_get_media_source_audio(
self.hass, self.hass,
media_id, media_id,
) )
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != 16000)
or (sample_width != 2)
or (sample_channels != 1)
):
raise ValueError(
    "Expected rate/width/channels as 16000/2/1,"
    f" got {sample_rate}/{sample_width}/{sample_channels}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes)) _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Time out 1 second after TTS audio should be finished # Time out 1 second after TTS audio should be finished
@ -414,7 +437,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
tts_seconds = tts_samples / RATE tts_seconds = tts_samples / RATE
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout): async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
# Assume TTS audio is 16Khz 16-bit mono # TTS audio is 16Khz 16-bit mono
await self._async_send_audio(audio_bytes) await self._async_send_audio(audio_bytes)
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
_LOGGER.warning("TTS timeout") _LOGGER.warning("TTS timeout")

View file

@ -1,7 +1,9 @@
"""Test VoIP protocol.""" """Test VoIP protocol."""
import asyncio import asyncio
import io
import time import time
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import wave
import pytest import pytest
@ -14,6 +16,24 @@ _ONE_SECOND = 16000 * 2 # 16Khz 16-bit
_MEDIA_ID = "12345" _MEDIA_ID = "12345"
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
    """Mock the TTS cache dir with empty dir.

    Autouse wrapper around the shared ``mock_tts_cache_dir`` fixture so every
    test in this module gets the mocked cache directory without requesting it
    explicitly (keeps TTS cache writes off the real filesystem).
    """
    return mock_tts_cache_dir
def _empty_wav() -> bytes:
    """Return the bytes of a zero-frame WAV file (16 kHz, 16-bit, mono)."""
    with io.BytesIO() as wav_io:
        # Closing the writer (end of the inner ``with``) flushes the header,
        # so getvalue() afterwards yields a valid, frame-less WAV container.
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(16000)
            wav_file.setsampwidth(2)
            wav_file.setnchannels(1)
        return wav_io.getvalue()
async def test_pipeline( async def test_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
voip_device: VoIPDevice, voip_device: VoIPDevice,
@ -72,8 +92,7 @@ async def test_pipeline(
media_source_id: str, media_source_id: str,
) -> tuple[str, bytes]: ) -> tuple[str, bytes]:
assert media_source_id == _MEDIA_ID assert media_source_id == _MEDIA_ID
return ("wav", _empty_wav())
return ("mp3", b"")
with patch( with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@ -266,7 +285,7 @@ async def test_tts_timeout(
media_source_id: str, media_source_id: str,
) -> tuple[str, bytes]: ) -> tuple[str, bytes]:
# Should time out immediately # Should time out immediately
return ("raw", bytes(0)) return ("wav", _empty_wav())
with patch( with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@ -305,8 +324,197 @@ async def test_tts_timeout(
done.set() done.set()
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_tts_wrong_extension(
    hass: HomeAssistant,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will only stream WAV audio."""
    assert await async_setup_component(hass, "voip", {})

    finished = asyncio.Event()

    def is_speech(self, chunk):
        """Anything non-zero is speech."""
        return sum(chunk) > 0

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        event_callback = kwargs["event_callback"]

        # Drain the STT stream; it ends when VAD detects end of "speech".
        async for _audio in kwargs["stt_stream"]:
            pass

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Should fail because it's not "wav"
        return ("mp3", b"")

    with patch(
        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
        new=is_speech,
    ), patch(
        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ), patch(
        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
        new=async_get_media_source_audio,
    ):
        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass,
            hass.config.language,
            voip_device,
            Context(),
            opus_payload_type=123,
        )
        rtp_protocol.transport = Mock()

        original_send_tts = rtp_protocol._send_tts

        async def send_tts(*args, **kwargs):
            # The non-WAV extension must raise; then end the test successfully.
            with pytest.raises(ValueError):
                await original_send_tts(*args, **kwargs)

            finished.set()

        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]

        # silence
        rtp_protocol.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence (assumes relaxed VAD sensitivity)
        rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))

        # Wait for the mocked pipeline to exhaust the audio stream
        async with asyncio.timeout(1):
            await finished.wait()
async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_END,
data={
"intent_output": {
"conversation_id": "fake-conversation",
}
},
)
)
# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
)
)
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should fail because it's not 16Khz, 16-bit mono
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(2)
return ("wav", wav_io.getvalue())
with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
original_send_tts = rtp_protocol._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
with pytest.raises(ValueError):
await original_send_tts(*args, **kwargs)
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence # silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND)) rtp_protocol.on_chunk(bytes(_ONE_SECOND))