diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py
index 120f2d9559b..74bc94e7dc5 100644
--- a/homeassistant/components/voip/voip.py
+++ b/homeassistant/components/voip/voip.py
@@ -5,10 +5,12 @@ import asyncio
 from collections import deque
 from collections.abc import AsyncIterable, MutableSequence, Sequence
 from functools import partial
+import io
 import logging
 from pathlib import Path
 import time
 from typing import TYPE_CHECKING
+import wave
 
 from voip_utils import (
     CallInfo,
@@ -285,7 +287,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
                     ),
                     conversation_id=self._conversation_id,
                     device_id=self.voip_device.device_id,
-                    tts_audio_output="raw",
+                    tts_audio_output="wav",
                 )
 
             if self._pipeline_error:
@@ -402,11 +404,32 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
             if self.transport is None:
                 return
 
-            _extension, audio_bytes = await tts.async_get_media_source_audio(
+            extension, data = await tts.async_get_media_source_audio(
                 self.hass,
                 media_id,
             )
 
+            if extension != "wav":
+                raise ValueError(f"Only WAV audio can be streamed, got {extension}")
+
+            with io.BytesIO(data) as wav_io:
+                with wave.open(wav_io, "rb") as wav_file:
+                    sample_rate = wav_file.getframerate()
+                    sample_width = wav_file.getsampwidth()
+                    sample_channels = wav_file.getnchannels()
+
+                    if (
+                        (sample_rate != 16000)
+                        or (sample_width != 2)
+                        or (sample_channels != 1)
+                    ):
+                        raise ValueError(
+                            "Expected rate/width/channels as 16000/2/1,"
+                            f" got {sample_rate}/{sample_width}/{sample_channels}"
+                        )
+
+                    audio_bytes = wav_file.readframes(wav_file.getnframes())
+
             _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
 
             # Time out 1 second after TTS audio should be finished
@@ -414,7 +437,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
             tts_seconds = tts_samples / RATE
 
             async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
-                # Assume TTS audio is 16Khz 16-bit mono
+                # TTS audio is 16Khz 16-bit mono
                 await self._async_send_audio(audio_bytes)
         except asyncio.TimeoutError as err:
             _LOGGER.warning("TTS timeout")
diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py
index f82a00087c6..692896c6dfa 100644
--- a/tests/components/voip/test_voip.py
+++ b/tests/components/voip/test_voip.py
@@ -1,7 +1,9 @@
 """Test VoIP protocol."""
 import asyncio
+import io
 import time
 from unittest.mock import AsyncMock, Mock, patch
+import wave
 
 import pytest
 
@@ -14,6 +16,24 @@ _ONE_SECOND = 16000 * 2  # 16Khz 16-bit
 _MEDIA_ID = "12345"
 
 
+@pytest.fixture(autouse=True)
+def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
+    """Mock the TTS cache dir with empty dir."""
+    return mock_tts_cache_dir
+
+
+def _empty_wav() -> bytes:
+    """Return bytes of an empty WAV file."""
+    with io.BytesIO() as wav_io:
+        wav_file: wave.Wave_write = wave.open(wav_io, "wb")
+        with wav_file:
+            wav_file.setframerate(16000)
+            wav_file.setsampwidth(2)
+            wav_file.setnchannels(1)
+
+        return wav_io.getvalue()
+
+
 async def test_pipeline(
     hass: HomeAssistant,
     voip_device: VoIPDevice,
@@ -72,8 +92,7 @@ async def test_pipeline(
         media_source_id: str,
     ) -> tuple[str, bytes]:
         assert media_source_id == _MEDIA_ID
-
-        return ("mp3", b"")
+        return ("wav", _empty_wav())
 
     with patch(
         "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@@ -266,7 +285,7 @@ async def test_tts_timeout(
         media_source_id: str,
     ) -> tuple[str, bytes]:
         # Should time out immediately
-        return ("raw", bytes(0))
+        return ("wav", _empty_wav())
 
     with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", @@ -305,8 +324,197 @@ async def test_tts_timeout( done.set() - rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) - rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) + rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign] + rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign] + + # silence + rtp_protocol.on_chunk(bytes(_ONE_SECOND)) + + # "speech" + rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) + + # silence (assumes relaxed VAD sensitivity) + rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4)) + + # Wait for mock pipeline to exhaust the audio stream + async with asyncio.timeout(1): + await done.wait() + + +async def test_tts_wrong_extension( + hass: HomeAssistant, + voip_device: VoIPDevice, +) -> None: + """Test that TTS will only stream WAV audio.""" + assert await async_setup_component(hass, "voip", {}) + + def is_speech(self, chunk): + """Anything non-zero is speech.""" + return sum(chunk) > 0 + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, **kwargs): + stt_stream = kwargs["stt_stream"] + event_callback = kwargs["event_callback"] + async for _chunk in stt_stream: + # Stream will end when VAD detects end of "speech" + pass + + # Fake intent result + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.INTENT_END, + data={ + "intent_output": { + "conversation_id": "fake-conversation", + } + }, + ) + ) + + # Proceed with media output + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.TTS_END, + data={"tts_output": {"media_id": _MEDIA_ID}}, + ) + ) + + async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + # Should fail because it's not "wav" + return ("mp3", b"") + + with patch( + "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech", + new=is_speech, + ), patch( + "homeassistant.components.voip.voip.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), patch( + "homeassistant.components.voip.voip.tts.async_get_media_source_audio", + new=async_get_media_source_audio, + ): + rtp_protocol = voip.voip.PipelineRtpDatagramProtocol( + hass, + hass.config.language, + voip_device, + Context(), + opus_payload_type=123, + ) + rtp_protocol.transport = Mock() + + original_send_tts = rtp_protocol._send_tts + + async def send_tts(*args, **kwargs): + # Call original then end test successfully + with pytest.raises(ValueError): + await original_send_tts(*args, **kwargs) + + done.set() + + rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign] + + # silence + rtp_protocol.on_chunk(bytes(_ONE_SECOND)) + + # "speech" + rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) + + # silence (assumes relaxed VAD sensitivity) + rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4)) + + # Wait for mock pipeline to exhaust the audio stream + async with asyncio.timeout(1): + await done.wait() + + +async def test_tts_wrong_wav_format( + hass: HomeAssistant, + voip_device: VoIPDevice, +) -> None: + """Test that TTS will only stream WAV audio with a specific format.""" + assert await async_setup_component(hass, "voip", {}) + + def is_speech(self, chunk): + """Anything non-zero is speech.""" + return sum(chunk) > 0 + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, **kwargs): + 
+        stt_stream = kwargs["stt_stream"]
+        event_callback = kwargs["event_callback"]
+        async for _chunk in stt_stream:
+            # Stream will end when VAD detects end of "speech"
+            pass
+
+        # Fake intent result
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.INTENT_END,
+                data={
+                    "intent_output": {
+                        "conversation_id": "fake-conversation",
+                    }
+                },
+            )
+        )
+
+        # Proceed with media output
+        event_callback(
+            assist_pipeline.PipelineEvent(
+                type=assist_pipeline.PipelineEventType.TTS_END,
+                data={"tts_output": {"media_id": _MEDIA_ID}},
+            )
+        )
+
+    async def async_get_media_source_audio(
+        hass: HomeAssistant,
+        media_source_id: str,
+    ) -> tuple[str, bytes]:
+        # Should fail because it's not 16Khz, 16-bit mono
+        with io.BytesIO() as wav_io:
+            wav_file: wave.Wave_write = wave.open(wav_io, "wb")
+            with wav_file:
+                wav_file.setframerate(22050)
+                wav_file.setsampwidth(2)
+                wav_file.setnchannels(2)
+
+            return ("wav", wav_io.getvalue())
+
+    with patch(
+        "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
+        new=is_speech,
+    ), patch(
+        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
+        new=async_pipeline_from_audio_stream,
+    ), patch(
+        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
+        new=async_get_media_source_audio,
+    ):
+        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
+            hass,
+            hass.config.language,
+            voip_device,
+            Context(),
+            opus_payload_type=123,
+        )
+        rtp_protocol.transport = Mock()
+
+        original_send_tts = rtp_protocol._send_tts
+
+        async def send_tts(*args, **kwargs):
+            # Call original then end test successfully
+            with pytest.raises(ValueError):
+                await original_send_tts(*args, **kwargs)
+
+            done.set()
+
+        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]
 
         # silence
         rtp_protocol.on_chunk(bytes(_ONE_SECOND))
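
Note: the validation that this patch adds to `_send_tts` can be exercised outside Home Assistant with only the standard library. The sketch below is illustrative, not part of the patch: `make_wav` and `extract_pcm` are hypothetical helper names, but they build an in-memory WAV the same way the test's `_empty_wav()` fixture does and apply the same extension and rate/width/channel checks before extracting the raw 16Khz 16-bit mono PCM frames that would be streamed over RTP.

```python
"""Standalone sketch of the WAV checks added in this patch (helper names are illustrative)."""
import io
import wave


def make_wav(rate: int = 16000, width: int = 2, channels: int = 1, seconds: float = 0.1) -> bytes:
    """Build an in-memory WAV file of silence, mirroring the _empty_wav() test helper."""
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(rate)
            wav_file.setsampwidth(width)
            wav_file.setnchannels(channels)
            wav_file.writeframes(bytes(int(rate * seconds) * width * channels))

        return wav_io.getvalue()


def extract_pcm(extension: str, data: bytes) -> bytes:
    """Reject non-WAV or wrongly formatted audio, then return the raw PCM frames."""
    if extension != "wav":
        raise ValueError(f"Only WAV audio can be streamed, got {extension}")

    with io.BytesIO(data) as wav_io:
        with wave.open(wav_io, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            sample_width = wav_file.getsampwidth()
            sample_channels = wav_file.getnchannels()

            if (sample_rate != 16000) or (sample_width != 2) or (sample_channels != 1):
                raise ValueError(
                    "Expected rate/width/channels as 16000/2/1,"
                    f" got {sample_rate}/{sample_width}/{sample_channels}"
                )

            return wav_file.readframes(wav_file.getnframes())


if __name__ == "__main__":
    # Accepted: matches the 16000/2/1 format the RTP sender assumes
    pcm = extract_pcm("wav", make_wav())
    print(f"{len(pcm)} byte(s) of PCM ready to stream")

    # Rejected: wrong container, then wrong WAV format (the two new test cases)
    for ext, data in [("mp3", b""), ("wav", make_wav(rate=22050, channels=2))]:
        try:
            extract_pcm(ext, data)
        except ValueError as err:
            print(f"rejected: {err}")
```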