Fix TTS streaming for VoIP (#104620)
* Use wav instead of raw tts audio in voip
* More tests
* Use mock TTS dir
commit a894146cee (parent 47426a3ddc)
2 changed files with 239 additions and 8 deletions
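The heart of the change shows up in the integration diff below: the assist pipeline is asked for "wav" output instead of "raw", and _send_tts now parses the WAV header to confirm 16 kHz, 16-bit, mono PCM before streaming frames over RTP. As a rough standalone sketch of that check using only the stdlib wave module (a hypothetical helper for illustration, not code from this commit):

    import io
    import wave

    def pcm_from_wav(data: bytes) -> bytes:
        """Return 16 kHz, 16-bit, mono PCM frames from WAV bytes, or raise ValueError."""
        # The WAV container carries its own rate/width/channels, so the caller can
        # verify the format instead of trusting opaque raw PCM.
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            if (
                wav_file.getframerate() != 16000
                or wav_file.getsampwidth() != 2
                or wav_file.getnchannels() != 1
            ):
                raise ValueError("Expected 16000 Hz, 16-bit, mono WAV audio")
            return wav_file.readframes(wav_file.getnframes())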
@@ -5,10 +5,12 @@ import asyncio
 from collections import deque
 from collections.abc import AsyncIterable, MutableSequence, Sequence
 from functools import partial
+import io
 import logging
 from pathlib import Path
 import time
 from typing import TYPE_CHECKING
+import wave
 
 from voip_utils import (
     CallInfo,
@@ -285,7 +287,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
                 ),
                 conversation_id=self._conversation_id,
                 device_id=self.voip_device.device_id,
-                tts_audio_output="raw",
+                tts_audio_output="wav",
             )
 
             if self._pipeline_error:
@@ -402,11 +404,32 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         if self.transport is None:
             return
 
-        _extension, audio_bytes = await tts.async_get_media_source_audio(
+        extension, data = await tts.async_get_media_source_audio(
             self.hass,
             media_id,
         )
 
+        if extension != "wav":
+            raise ValueError(f"Only WAV audio can be streamed, got {extension}")
+
+        with io.BytesIO(data) as wav_io:
+            with wave.open(wav_io, "rb") as wav_file:
+                sample_rate = wav_file.getframerate()
+                sample_width = wav_file.getsampwidth()
+                sample_channels = wav_file.getnchannels()
+
+                if (
+                    (sample_rate != 16000)
+                    or (sample_width != 2)
+                    or (sample_channels != 1)
+                ):
+                    raise ValueError(
+                        "Expected rate/width/channels as 16000/2/1,"
+                        f" got {sample_rate}/{sample_width}/{sample_channels}"
+                    )
+
+                audio_bytes = wav_file.readframes(wav_file.getnframes())
+
         _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
 
         # Time out 1 second after TTS audio should be finished
@@ -414,7 +437,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
             tts_seconds = tts_samples / RATE
 
             async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
-                # Assume TTS audio is 16Khz 16-bit mono
+                # TTS audio is 16Khz 16-bit mono
                 await self._async_send_audio(audio_bytes)
         except asyncio.TimeoutError as err:
             _LOGGER.warning("TTS timeout")
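The timeout hunk above is unchanged in substance: the expected playback time is the PCM sample count divided by the 16 kHz rate, plus tts_extra_timeout of slack. The same duration can be read directly off a decoded WAV payload; a small illustrative helper (hypothetical, not part of this commit):

    import io
    import wave

    def wav_seconds(data: bytes) -> float:
        """Playback duration of a WAV payload, in seconds."""
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            return wav_file.getnframes() / wav_file.getframerate()

    # One second of 16 kHz, 16-bit, mono audio is 16000 frames (32000 bytes of PCM),
    # so wav_seconds(...) returns 1.0 for such a payload.

The remaining hunks touch the VoIP protocol tests.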
@@ -1,7 +1,9 @@
 """Test VoIP protocol."""
 import asyncio
+import io
 import time
 from unittest.mock import AsyncMock, Mock, patch
+import wave
 
 import pytest
 
@@ -14,6 +16,24 @@ _ONE_SECOND = 16000 * 2  # 16Khz 16-bit
 _MEDIA_ID = "12345"
 
 
+@pytest.fixture(autouse=True)
+def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
+    """Mock the TTS cache dir with empty dir."""
+    return mock_tts_cache_dir
+
+
+def _empty_wav() -> bytes:
+    """Return bytes of an empty WAV file."""
+    with io.BytesIO() as wav_io:
+        wav_file: wave.Wave_write = wave.open(wav_io, "wb")
+        with wav_file:
+            wav_file.setframerate(16000)
+            wav_file.setsampwidth(2)
+            wav_file.setnchannels(1)
+
+        return wav_io.getvalue()
+
+
 async def test_pipeline(
     hass: HomeAssistant,
     voip_device: VoIPDevice,
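_empty_wav() produces a header-only file: a valid 16 kHz, 16-bit, mono WAV with zero audio frames, so it satisfies the new format check in _send_tts while contributing no playback time. A quick sanity check of that claim (illustrative only, assumes the helper above is in scope):

    import io
    import wave

    with io.BytesIO(_empty_wav()) as wav_io, wave.open(wav_io, "rb") as wav_file:
        assert wav_file.getframerate() == 16000
        assert wav_file.getsampwidth() == 2
        assert wav_file.getnchannels() == 1
        assert wav_file.getnframes() == 0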
@@ -72,8 +92,7 @@ async def test_pipeline(
         media_source_id: str,
     ) -> tuple[str, bytes]:
         assert media_source_id == _MEDIA_ID
-
-        return ("mp3", b"")
+        return ("wav", _empty_wav())
 
     with patch(
         "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@@ -266,7 +285,7 @@ async def test_tts_timeout(
         media_source_id: str,
     ) -> tuple[str, bytes]:
         # Should time out immediately
-        return ("raw", bytes(0))
+        return ("wav", _empty_wav())
 
     with patch(
         "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@@ -305,8 +324,197 @@ async def test_tts_timeout(
 
             done.set()
 
-        rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio)
-        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)
+        rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]
+        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]
 
+        # silence
+        rtp_protocol.on_chunk(bytes(_ONE_SECOND))
+
+        # "speech"
+        rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
+
+        # silence (assumes relaxed VAD sensitivity)
+        rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
+
+        # Wait for mock pipeline to exhaust the audio stream
+        async with asyncio.timeout(1):
+            await done.wait()
+
+
+async def test_tts_wrong_extension(
+    hass: HomeAssistant,
+    voip_device: VoIPDevice,
+) -> None:
+    """Test that TTS will only stream WAV audio."""
+    assert await async_setup_component(hass, "voip", {})
+
+    def is_speech(self, chunk):
+        """Anything non-zero is speech."""
+        return sum(chunk) > 0
+
+    done = asyncio.Event()
+
+    async def async_pipeline_from_audio_stream(*args, **kwargs):
+        stt_stream = kwargs["stt_stream"]
+        event_callback = kwargs["event_callback"]
+        async for _chunk in stt_stream:
+            # Stream will end when VAD detects end of "speech"
+            pass
+
+        # Fake intent result
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.INTENT_END,
+                data={
+                    "intent_output": {
+                        "conversation_id": "fake-conversation",
+                    }
+                },
+            )
+        )
+
+        # Proceed with media output
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.TTS_END,
+                data={"tts_output": {"media_id": _MEDIA_ID}},
+            )
+        )
+
+    async def async_get_media_source_audio(
+        hass: HomeAssistant,
+        media_source_id: str,
+    ) -> tuple[str, bytes]:
+        # Should fail because it's not "wav"
+        return ("mp3", b"")
+
+    with patch(
+        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
+        new=is_speech,
+    ), patch(
+        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
+        new=async_pipeline_from_audio_stream,
+    ), patch(
+        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
+        new=async_get_media_source_audio,
+    ):
+        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
+            hass,
+            hass.config.language,
+            voip_device,
+            Context(),
+            opus_payload_type=123,
+        )
+        rtp_protocol.transport = Mock()
+
+        original_send_tts = rtp_protocol._send_tts
+
+        async def send_tts(*args, **kwargs):
+            # Call original then end test successfully
+            with pytest.raises(ValueError):
+                await original_send_tts(*args, **kwargs)
+
+            done.set()
+
+        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]
+
+        # silence
+        rtp_protocol.on_chunk(bytes(_ONE_SECOND))
+
+        # "speech"
+        rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
+
+        # silence (assumes relaxed VAD sensitivity)
+        rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
+
+        # Wait for mock pipeline to exhaust the audio stream
+        async with asyncio.timeout(1):
+            await done.wait()
+
+
+async def test_tts_wrong_wav_format(
+    hass: HomeAssistant,
+    voip_device: VoIPDevice,
+) -> None:
+    """Test that TTS will only stream WAV audio with a specific format."""
+    assert await async_setup_component(hass, "voip", {})
+
+    def is_speech(self, chunk):
+        """Anything non-zero is speech."""
+        return sum(chunk) > 0
+
+    done = asyncio.Event()
+
+    async def async_pipeline_from_audio_stream(*args, **kwargs):
+        stt_stream = kwargs["stt_stream"]
+        event_callback = kwargs["event_callback"]
+        async for _chunk in stt_stream:
+            # Stream will end when VAD detects end of "speech"
+            pass
+
+        # Fake intent result
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.INTENT_END,
+                data={
+                    "intent_output": {
+                        "conversation_id": "fake-conversation",
+                    }
+                },
+            )
+        )
+
+        # Proceed with media output
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.TTS_END,
+                data={"tts_output": {"media_id": _MEDIA_ID}},
+            )
+        )
+
+    async def async_get_media_source_audio(
+        hass: HomeAssistant,
+        media_source_id: str,
+    ) -> tuple[str, bytes]:
+        # Should fail because it's not 16Khz, 16-bit mono
+        with io.BytesIO() as wav_io:
+            wav_file: wave.Wave_write = wave.open(wav_io, "wb")
+            with wav_file:
+                wav_file.setframerate(22050)
+                wav_file.setsampwidth(2)
+                wav_file.setnchannels(2)
+
+            return ("wav", wav_io.getvalue())
+
+    with patch(
+        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
+        new=is_speech,
+    ), patch(
+        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
+        new=async_pipeline_from_audio_stream,
+    ), patch(
+        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
+        new=async_get_media_source_audio,
+    ):
+        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
+            hass,
+            hass.config.language,
+            voip_device,
+            Context(),
+            opus_payload_type=123,
+        )
+        rtp_protocol.transport = Mock()
+
+        original_send_tts = rtp_protocol._send_tts
+
+        async def send_tts(*args, **kwargs):
+            # Call original then end test successfully
+            with pytest.raises(ValueError):
+                await original_send_tts(*args, **kwargs)
+
+            done.set()
+
+        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]
+
         # silence
         rtp_protocol.on_chunk(bytes(_ONE_SECOND))