Automatically convert TTS audio to MP3 on demand (#102814)

* Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio

* Move conversion into SpeechManager

* Handle None case for expected_extension

* Only use ATTR_AUDIO_OUTPUT

* Prefer MP3 in pipelines

* Automatically convert to mp3 on demand

* Add preferred audio format

* Break out preferred format

* Add ATTR_BLOCKING to allow async fetching

* Make a copy of supported options

* Fix MaryTTS tests

* Update ESPHome to use "wav" instead of "raw"

* Clean up tests, remove blocking

* Clean up rest of TTS tests

* Fix ESPHome tests

* More test coverage
This commit is contained in:
Michael Hansen 2023-11-06 14:26:00 -06:00 committed by GitHub
parent 054089291f
commit ae516ffbb5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 723 additions and 241 deletions

View file

@ -971,12 +971,16 @@ class PipelineRun:
# pipeline.tts_engine can't be None or this function is not called # pipeline.tts_engine can't be None or this function is not called
engine = cast(str, self.pipeline.tts_engine) engine = cast(str, self.pipeline.tts_engine)
tts_options = {} tts_options: dict[str, Any] = {}
if self.pipeline.tts_voice is not None: if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
if self.tts_audio_output is not None: if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 Khz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
try: try:
options_supported = await tts.async_support_options( options_supported = await tts.async_support_options(

View file

@ -150,4 +150,4 @@ class CloudProvider(Provider):
_LOGGER.error("Voice error: %s", err) _LOGGER.error("Voice error: %s", err)
return (None, None) return (None, None)
return (str(options[ATTR_AUDIO_OUTPUT]), data) return (str(options[ATTR_AUDIO_OUTPUT].value), data)

View file

@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncIterable, Callable from collections.abc import AsyncIterable, Callable
import io
import logging import logging
import socket import socket
from typing import cast from typing import cast
import wave
from aioesphomeapi import ( from aioesphomeapi import (
VoiceAssistantAudioSettings, VoiceAssistantAudioSettings,
@ -88,6 +90,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event = handle_event self.handle_event = handle_event
self.handle_finished = handle_finished self.handle_finished = handle_finished
self._tts_done = asyncio.Event() self._tts_done = asyncio.Event()
self._tts_task: asyncio.Task | None = None
async def start_server(self) -> int: async def start_server(self) -> int:
"""Start accepting connections.""" """Start accepting connections."""
@ -189,7 +192,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
if self.device_info.voice_assistant_version >= 2: if self.device_info.voice_assistant_version >= 2:
media_id = event.data["tts_output"]["media_id"] media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task( self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts" self._send_tts(media_id), "esphome_voice_assistant_tts"
) )
else: else:
@ -228,7 +231,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
audio_settings = VoiceAssistantAudioSettings() audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = ( tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3" "wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
) )
_LOGGER.debug("Starting pipeline") _LOGGER.debug("Starting pipeline")
@ -302,11 +305,32 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {} VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
) )
_extension, audio_bytes = await tts.async_get_media_source_audio( extension, data = await tts.async_get_media_source_audio(
self.hass, self.hass,
media_id, media_id,
) )
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != 16000)
or (sample_width != 2)
or (sample_channels != 1)
):
raise ValueError(
"Expected rate/width/channels as 16000/2/1,"
" got {sample_rate}/{sample_width}/{sample_channels}}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %d bytes of audio", len(audio_bytes)) _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
@ -330,4 +354,5 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event( self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
) )
self._tts_task = None
self._tts_done.set() self._tts_done.set()

View file

@ -13,6 +13,8 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
import subprocess
import tempfile
from typing import Any, TypedDict, final from typing import Any, TypedDict, final
from aiohttp import web from aiohttp import web
@ -20,7 +22,7 @@ import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE, ATTR_MEDIA_ANNOUNCE,
@ -72,11 +74,15 @@ __all__ = [
"async_get_media_source_audio", "async_get_media_source_audio",
"async_support_options", "async_support_options",
"ATTR_AUDIO_OUTPUT", "ATTR_AUDIO_OUTPUT",
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"CONF_LANG", "CONF_LANG",
"DEFAULT_CACHE_DIR", "DEFAULT_CACHE_DIR",
"generate_media_source_id", "generate_media_source_id",
"PLATFORM_SCHEMA_BASE", "PLATFORM_SCHEMA_BASE",
"PLATFORM_SCHEMA", "PLATFORM_SCHEMA",
"SampleFormat",
"Provider", "Provider",
"TtsAudioType", "TtsAudioType",
"Voice", "Voice",
@ -86,6 +92,9 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform" ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output" ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id" ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice" ATTR_VOICE = "voice"
@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
return languages return languages
async def async_convert_audio(
    hass: HomeAssistant,
    from_extension: str,
    audio_bytes: bytes,
    to_extension: str,
    to_sample_rate: int | None = None,
    to_sample_channels: int | None = None,
) -> bytes:
    """Convert audio to a preferred format using ffmpeg.

    The blocking conversion is delegated to ``_convert_audio`` and scheduled
    through ``hass.async_add_executor_job``; the converted bytes are returned.
    """
    manager = ffmpeg.get_ffmpeg_manager(hass)

    def _run_conversion() -> bytes:
        """Invoke the blocking converter with the configured ffmpeg binary."""
        return _convert_audio(
            manager.binary,
            from_extension,
            audio_bytes,
            to_extension,
            to_sample_rate=to_sample_rate,
            to_sample_channels=to_sample_channels,
        )

    return await hass.async_add_executor_job(_run_conversion)
def _convert_audio(
    ffmpeg_binary: str,
    from_extension: str,
    audio_bytes: bytes,
    to_extension: str,
    to_sample_rate: int | None = None,
    to_sample_channels: int | None = None,
) -> bytes:
    """Convert audio to a preferred format using ffmpeg.

    The input audio is piped to ffmpeg on stdin. The output is written to a
    temporary file, which is read back and returned as bytes.

    ``to_sample_rate`` and ``to_sample_channels`` are passed to ffmpeg as
    ``-ar`` / ``-ac`` only when they are not None.

    Raises RuntimeError if ffmpeg exits with a non-zero return code; ffmpeg's
    stderr output is logged first.
    """
    # We have to use a temporary file here because some formats like WAV store
    # the length of the file in the header, and therefore cannot be written in a
    # streaming fashion.
    with tempfile.NamedTemporaryFile(
        mode="wb+", suffix=f".{to_extension}"
    ) as output_file:
        # input
        command = [
            ffmpeg_binary,
            "-y",  # overwrite temp file
            "-f",
            from_extension,
            "-i",
            "pipe:",  # input from stdin
        ]

        # output
        command.extend(["-f", to_extension])

        if to_sample_rate is not None:
            command.extend(["-ar", str(to_sample_rate)])

        if to_sample_channels is not None:
            command.extend(["-ac", str(to_sample_channels)])

        if to_extension == "mp3":
            # Max quality for MP3
            command.extend(["-q:a", "0"])

        command.append(output_file.name)

        with subprocess.Popen(
            command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
        ) as proc:
            _stdout, stderr = proc.communicate(input=audio_bytes)
            if proc.returncode != 0:
                # stderr is not guaranteed to be valid UTF-8; decode leniently
                # so a decoding failure cannot mask the RuntimeError below.
                _LOGGER.error(stderr.decode(errors="replace"))
                raise RuntimeError(
                    f"Unexpected error while running ffmpeg with arguments: {command}. See log for details."
                )

        output_file.seek(0)
        return output_file.read()
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS.""" """Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines) websocket_api.async_register_command(hass, websocket_list_engines)
@ -482,7 +568,18 @@ class SpeechManager:
merged_options = dict(engine_instance.default_options or {}) merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {}) merged_options.update(options or {})
supported_options = engine_instance.supported_options or [] supported_options = list(engine_instance.supported_options or [])
# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)
invalid_opts = [ invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options opt_name for opt_name in merged_options if opt_name not in supported_options
] ]
@ -520,12 +617,7 @@ class SpeechManager:
# Load speech from engine into memory # Load speech from engine into memory
else: else:
filename = await self._async_get_tts_audio( filename = await self._async_get_tts_audio(
engine_instance, engine_instance, cache_key, message, use_cache, language, options
cache_key,
message,
use_cache,
language,
options,
) )
return f"/api/tts_proxy/{filename}" return f"/api/tts_proxy/{filename}"
@ -590,10 +682,10 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
if options is not None and ATTR_AUDIO_OUTPUT in options: options = options or {}
expected_extension = options[ATTR_AUDIO_OUTPUT]
else: # Default to MP3 unless a different format is preferred
expected_extension = None final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
async def get_tts_data() -> str: async def get_tts_data() -> str:
"""Handle data available.""" """Handle data available."""
@ -614,8 +706,27 @@ class SpeechManager:
f"No TTS from {engine_instance.name} for '{message}'" f"No TTS from {engine_instance.name} for '{message}'"
) )
# Only convert if we have a preferred format different than the
# expected format from the TTS system, or if a specific sample
# rate/format/channel count is requested.
needs_conversion = (
(final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
)
if needs_conversion:
data = await async_convert_audio(
self.hass,
extension,
data,
to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
)
# Create file infos # Create file infos
filename = f"{cache_key}.{extension}".lower() filename = f"{cache_key}.{final_extension}".lower()
# Validate filename # Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match( if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
@ -626,10 +737,11 @@ class SpeechManager:
) )
# Save to memory # Save to memory
if extension == "mp3": if final_extension == "mp3":
data = self.write_tags( data = self.write_tags(
filename, data, engine_instance.name, message, language, options filename, data, engine_instance.name, message, language, options
) )
self._async_store_to_memcache(cache_key, filename, data) self._async_store_to_memcache(cache_key, filename, data)
if cache: if cache:
@ -641,9 +753,6 @@ class SpeechManager:
audio_task = self.hass.async_create_task(get_tts_data()) audio_task = self.hass.async_create_task(get_tts_data())
if expected_extension is None:
return await audio_task
def handle_error(_future: asyncio.Future) -> None: def handle_error(_future: asyncio.Future) -> None:
"""Handle error.""" """Handle error."""
if audio_task.exception(): if audio_task.exception():
@ -651,7 +760,7 @@ class SpeechManager:
audio_task.add_done_callback(handle_error) audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{expected_extension}".lower() filename = f"{cache_key}.{final_extension}".lower()
self.mem_cache[cache_key] = { self.mem_cache[cache_key] = {
"filename": filename, "filename": filename,
"voice": b"", "voice": b"",
@ -747,11 +856,12 @@ class SpeechManager:
raise HomeAssistantError(f"{cache_key} not in cache!") raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key) await self._async_file_to_mem(cache_key)
content, _ = mimetypes.guess_type(filename)
cached = self.mem_cache[cache_key] cached = self.mem_cache[cache_key]
if pending := cached.get("pending"): if pending := cached.get("pending"):
await pending await pending
cached = self.mem_cache[cache_key] cached = self.mem_cache[cache_key]
content, _ = mimetypes.guess_type(filename)
return content, cached["voice"] return content, cached["voice"]
@staticmethod @staticmethod

View file

@ -3,7 +3,7 @@
"name": "Text-to-speech (TTS)", "name": "Text-to-speech (TTS)",
"after_dependencies": ["media_player"], "after_dependencies": ["media_player"],
"codeowners": ["@home-assistant/core", "@pvizeli"], "codeowners": ["@home-assistant/core", "@pvizeli"],
"dependencies": ["http"], "dependencies": ["http", "ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/tts", "documentation": "https://www.home-assistant.io/integrations/tts",
"integration_type": "entity", "integration_type": "entity",
"loggers": ["mutagen"], "loggers": ["mutagen"],

View file

@ -4,7 +4,7 @@ import io
import logging import logging
import wave import wave
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop from wyoming.audio import AudioChunk, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.tts import Synthesize, SynthesizeVoice
@ -88,12 +88,16 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
@property @property
def supported_options(self): def supported_options(self):
"""Return list of supported options like voice, emotion.""" """Return list of supported options like voice, emotion."""
return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE, ATTR_SPEAKER] return [
tts.ATTR_AUDIO_OUTPUT,
tts.ATTR_VOICE,
ATTR_SPEAKER,
]
@property @property
def default_options(self): def default_options(self):
"""Return a dict include default options.""" """Return a dict include default options."""
return {tts.ATTR_AUDIO_OUTPUT: "wav"} return {}
@callback @callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
@ -143,27 +147,4 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError): except (OSError, WyomingError):
return (None, None) return (None, None)
if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
return ("wav", data) return ("wav", data)
# Raw output (convert to 16Khz, 16-bit mono)
with io.BytesIO(data) as wav_io:
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
raw_data = (
AudioChunkConverter(
rate=16000,
width=2,
channels=1,
)
.convert(
AudioChunk(
audio=wav_reader.readframes(wav_reader.getnframes()),
rate=wav_reader.getframerate(),
width=wav_reader.getsampwidth(),
channels=wav_reader.getnchannels(),
)
)
.audio
)
return ("raw", raw_data)

View file

@ -20,6 +20,7 @@ cryptography==41.0.4
dbus-fast==2.12.0 dbus-fast==2.12.0
fnv-hash-fast==0.5.0 fnv-hash-fast==0.5.0
ha-av==10.1.1 ha-av==10.1.1
ha-ffmpeg==3.1.0
hass-nabucasa==0.74.0 hass-nabucasa==0.74.0
hassil==1.2.5 hassil==1.2.5
home-assistant-bluetooth==1.10.4 home-assistant-bluetooth==1.10.4

View file

@ -9,7 +9,7 @@ import wave
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline.const import ( from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DOMAIN, DOMAIN,
@ -660,3 +660,42 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
assert run_1 == run_1 assert run_1 == run_1
assert run_1 != run_2 assert run_1 != run_2
assert run_1 != 1234 assert run_1 != 1234
async def test_tts_audio_output(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that tts_audio_output="wav" yields the preferred WAV TTS options."""

    def _ignore_event(event):
        pass

    # Resolve the preferred pipeline from the store.
    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    # A TTS-only run requesting WAV output.
    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.TTS,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=_ignore_event,
        tts_audio_output="wav",
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        conversation_id=None,
        device_id=None,
        run=run,
    )
    await pipeline_input.validate()

    # Verify TTS audio settings
    tts_options = pipeline_input.run.tts_options
    assert tts_options is not None
    assert tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
    assert tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1

View file

@ -1,8 +1,10 @@
"""Test ESPHome voice assistant server.""" """Test ESPHome voice assistant server."""
import asyncio import asyncio
import io
import socket import socket
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import wave
from aioesphomeapi import VoiceAssistantEventType from aioesphomeapi import VoiceAssistantEventType
import pytest import pytest
@ -340,9 +342,18 @@ async def test_send_tts(
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None: ) -> None:
"""Test the UDP server calls sendto to transmit audio data to device.""" """Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("raw", bytes(1024)), return_value=("wav", wav_bytes),
): ):
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
@ -360,6 +371,63 @@ async def test_send_tts(
voice_assistant_udp_server_v2.transport.sendto.assert_called() voice_assistant_udp_server_v2.transport.sendto.assert_called()
async def test_send_tts_wrong_sample_rate(
    hass: HomeAssistant,
    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
    """Test that WAV audio with an unexpected sample rate raises ValueError."""
    # Build an in-memory WAV whose sample rate differs from the expected one.
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(22050)  # should be 16000
            wav_file.setsampwidth(2)
            wav_file.setnchannels(1)
            wav_file.writeframes(bytes(_ONE_SECOND))

        # Read the buffer only after the WAV writer has been closed.
        wav_bytes = wav_io.getvalue()

    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", wav_bytes),
    ), pytest.raises(ValueError):
        voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)

        voice_assistant_udp_server_v2._event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={
                    "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
                },
            )
        )

        assert voice_assistant_udp_server_v2._tts_task is not None
        await voice_assistant_udp_server_v2._tts_task  # raises ValueError
async def test_send_tts_wrong_format(
    hass: HomeAssistant,
    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
    """Test that non-WAV audio is rejected with a ValueError."""
    server = voice_assistant_udp_server_v2
    tts_end_event = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )
    audio_patch = patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("raw", bytes(1024)),
    )

    with audio_patch, pytest.raises(ValueError):
        server.transport = Mock(spec=asyncio.DatagramTransport)

        server._event_callback(tts_end_event)

        assert server._tts_task is not None
        await server._tts_task  # raises ValueError
async def test_wake_word( async def test_wake_word(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_server_v2: VoiceAssistantUDPServer,

View file

@ -2,13 +2,14 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from gtts import gTTSError from gtts import gTTSError
import pytest import pytest
from homeassistant.components import media_source, tts from homeassistant.components import tts
from homeassistant.components.google_translate.const import CONF_TLD, DOMAIN from homeassistant.components.google_translate.const import CONF_TLD, DOMAIN
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
@ -18,10 +19,11 @@ from homeassistant.components.media_player import (
from homeassistant.config import async_process_ha_core_config from homeassistant.config import async_process_ha_core_config
from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, async_mock_service from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -35,15 +37,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir return mock_tts_cache_dir
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
@pytest.fixture @pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]: async def calls(hass: HomeAssistant) -> list[ServiceCall]:
"""Mock media player calls.""" """Mock media player calls."""
@ -128,6 +121,7 @@ async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -
async def test_tts_service( async def test_tts_service(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -142,9 +136,11 @@ async def test_tts_service(
) )
assert len(calls) == 1 assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -180,6 +176,7 @@ async def test_tts_service(
async def test_service_say_german_config( async def test_service_say_german_config(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -194,7 +191,10 @@ async def test_service_say_german_config(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -231,6 +231,7 @@ async def test_service_say_german_config(
async def test_service_say_german_service( async def test_service_say_german_service(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -245,7 +246,10 @@ async def test_service_say_german_service(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -281,6 +285,7 @@ async def test_service_say_german_service(
async def test_service_say_en_uk_config( async def test_service_say_en_uk_config(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -295,7 +300,10 @@ async def test_service_say_en_uk_config(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -332,6 +340,7 @@ async def test_service_say_en_uk_config(
async def test_service_say_en_uk_service( async def test_service_say_en_uk_service(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -346,7 +355,10 @@ async def test_service_say_en_uk_service(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -383,6 +395,7 @@ async def test_service_say_en_uk_service(
async def test_service_say_en_couk( async def test_service_say_en_couk(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -397,9 +410,11 @@ async def test_service_say_en_couk(
) )
assert len(calls) == 1 assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_gtts.mock_calls[0][2] == { assert mock_gtts.mock_calls[0][2] == {
"text": "There is a person at the front door.", "text": "There is a person at the front door.",
@ -434,6 +449,7 @@ async def test_service_say_en_couk(
async def test_service_say_error( async def test_service_say_error(
hass: HomeAssistant, hass: HomeAssistant,
mock_gtts: MagicMock, mock_gtts: MagicMock,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall], calls: list[ServiceCall],
setup: str, setup: str,
tts_service: str, tts_service: str,
@ -450,6 +466,8 @@ async def test_service_say_error(
) )
assert len(calls) == 1 assert len(calls) == 1
with pytest.raises(HomeAssistantError): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2

View file

@ -1,9 +1,12 @@
"""The tests for the MaryTTS speech platform.""" """The tests for the MaryTTS speech platform."""
from http import HTTPStatus
import io
from unittest.mock import patch from unittest.mock import patch
import wave
import pytest import pytest
from homeassistant.components import media_source, tts from homeassistant.components import tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP, DOMAIN as DOMAIN_MP,
@ -13,15 +16,19 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
async def get_media_source_url(hass, media_content_id): def get_empty_wav() -> bytes:
"""Get the media source url.""" """Get bytes for empty WAV file."""
if media_source.DOMAIN not in hass.config.components: with io.BytesIO() as wav_io:
assert await async_setup_component(hass, media_source.DOMAIN, {}) with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
resolved = await media_source.async_resolve_media(hass, media_content_id, None) return wav_io.getvalue()
return resolved.url
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -39,7 +46,9 @@ async def test_setup_component(hass: HomeAssistant) -> None:
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_service_say(hass: HomeAssistant) -> None: async def test_service_say(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -51,7 +60,7 @@ async def test_service_say(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.components.marytts.tts.MaryTTS.speak", "homeassistant.components.marytts.tts.MaryTTS.speak",
return_value=b"audio", return_value=get_empty_wav(),
) as mock_speak: ) as mock_speak:
await hass.services.async_call( await hass.services.async_call(
tts.DOMAIN, tts.DOMAIN,
@ -63,16 +72,22 @@ async def test_service_say(hass: HomeAssistant) -> None:
blocking=True, blocking=True,
) )
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
mock_speak.assert_called_once() mock_speak.assert_called_once()
mock_speak.assert_called_with("HomeAssistant", {}) mock_speak.assert_called_with("HomeAssistant", {})
assert len(calls) == 1 assert len(calls) == 1
assert url.endswith(".wav")
async def test_service_say_with_effect(hass: HomeAssistant) -> None: async def test_service_say_with_effect(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say with effects.""" """Test service call say with effects."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -84,7 +99,7 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.components.marytts.tts.MaryTTS.speak", "homeassistant.components.marytts.tts.MaryTTS.speak",
return_value=b"audio", return_value=get_empty_wav(),
) as mock_speak: ) as mock_speak:
await hass.services.async_call( await hass.services.async_call(
tts.DOMAIN, tts.DOMAIN,
@ -96,16 +111,22 @@ async def test_service_say_with_effect(hass: HomeAssistant) -> None:
blocking=True, blocking=True,
) )
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
mock_speak.assert_called_once() mock_speak.assert_called_once()
mock_speak.assert_called_with("HomeAssistant", {"Volume": "amount:2.0;"}) mock_speak.assert_called_with("HomeAssistant", {"Volume": "amount:2.0;"})
assert len(calls) == 1 assert len(calls) == 1
assert url.endswith(".wav")
async def test_service_say_http_error(hass: HomeAssistant) -> None: async def test_service_say_http_error(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -129,7 +150,11 @@ async def test_service_say_http_error(hass: HomeAssistant) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
with pytest.raises(Exception): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.NOT_FOUND
)
mock_speak.assert_called_once() mock_speak.assert_called_once()

View file

@ -1,10 +1,11 @@
"""Tests for Microsoft text-to-speech.""" """Tests for Microsoft text-to-speech."""
from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
from pycsspeechtts import pycsspeechtts from pycsspeechtts import pycsspeechtts
import pytest import pytest
from homeassistant.components import media_source, tts from homeassistant.components import tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP, DOMAIN as DOMAIN_MP,
@ -13,19 +14,12 @@ from homeassistant.components.media_player import (
from homeassistant.components.microsoft.tts import SUPPORTED_LANGUAGES from homeassistant.components.microsoft.tts import SUPPORTED_LANGUAGES
from homeassistant.config import async_process_ha_core_config from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound from homeassistant.exceptions import ServiceNotFound
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_mock_service from tests.common import async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
async def get_media_source_url(hass: HomeAssistant, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -58,7 +52,9 @@ def mock_tts():
yield mock_tts yield mock_tts
async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say.""" """Test service call say."""
await async_setup_component( await async_setup_component(
@ -76,9 +72,12 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
) )
assert len(calls) == 1 assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2
assert url.endswith(".mp3")
assert mock_tts.mock_calls[1][2] == { assert mock_tts.mock_calls[1][2] == {
"language": "en-us", "language": "en-us",
@ -93,7 +92,9 @@ async def test_service_say(hass: HomeAssistant, mock_tts, calls) -> None:
} }
async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say_en_gb_config(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with en-gb code in the config.""" """Test service call say with en-gb code in the config."""
await async_setup_component( await async_setup_component(
@ -120,7 +121,11 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == { assert mock_tts.mock_calls[1][2] == {
"language": "en-gb", "language": "en-gb",
@ -135,7 +140,9 @@ async def test_service_say_en_gb_config(hass: HomeAssistant, mock_tts, calls) ->
} }
async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say_en_gb_service(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with en-gb code in the service.""" """Test service call say with en-gb code in the service."""
await async_setup_component( await async_setup_component(
@ -157,7 +164,11 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == { assert mock_tts.mock_calls[1][2] == {
"language": "en-gb", "language": "en-gb",
@ -172,7 +183,9 @@ async def test_service_say_en_gb_service(hass: HomeAssistant, mock_tts, calls) -
} }
async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say_fa_ir_config(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with fa-ir code in the config.""" """Test service call say with fa-ir code in the config."""
await async_setup_component( await async_setup_component(
@ -199,7 +212,11 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == { assert mock_tts.mock_calls[1][2] == {
"language": "fa-ir", "language": "fa-ir",
@ -214,7 +231,9 @@ async def test_service_say_fa_ir_config(hass: HomeAssistant, mock_tts, calls) ->
} }
async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say_fa_ir_service(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with fa-ir code in the service.""" """Test service call say with fa-ir code in the service."""
config = { config = {
@ -240,7 +259,11 @@ async def test_service_say_fa_ir_service(hass: HomeAssistant, mock_tts, calls) -
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2
assert mock_tts.mock_calls[1][2] == { assert mock_tts.mock_calls[1][2] == {
"language": "fa-ir", "language": "fa-ir",
@ -295,7 +318,9 @@ async def test_invalid_language(hass: HomeAssistant, mock_tts, calls) -> None:
assert len(mock_tts.mock_calls) == 0 assert len(mock_tts.mock_calls) == 0
async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None: async def test_service_say_error(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts, calls
) -> None:
"""Test service call say with http error.""" """Test service call say with http error."""
mock_tts.return_value.speak.side_effect = pycsspeechtts.requests.HTTPError mock_tts.return_value.speak.side_effect = pycsspeechtts.requests.HTTPError
await async_setup_component( await async_setup_component(
@ -313,6 +338,9 @@ async def test_service_say_error(hass: HomeAssistant, mock_tts, calls) -> None:
) )
assert len(calls) == 1 assert len(calls) == 1
with pytest.raises(HomeAssistantError): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -32,6 +33,7 @@ from tests.common import (
mock_integration, mock_integration,
mock_platform, mock_platform,
) )
from tests.typing import ClientSessionGenerator
DEFAULT_LANG = "en_US" DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
@ -103,6 +105,20 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st
return resolved.url return resolved.url
async def retrieve_media(
    hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
) -> HTTPStatus:
    """Request the media behind a media-source id and return the HTTP status.

    The id is first resolved to a URL with ``get_media_source_url``; the URL
    is then fetched through a client obtained from ``hass_client``. Only the
    status of the response is returned, the body is not read.
    """
    media_url = await get_media_source_url(hass, media_content_id)

    # Let pending work finish, then request the media: resolving the id only
    # yields a URL, requesting it ensures the media has been generated.
    await hass.async_block_till_done()

    http_client = await hass_client()
    response = await http_client.get(media_url)
    return response.status
class BaseProvider: class BaseProvider:
"""Test speech API provider.""" """Test speech API provider."""

View file

@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from homeassistant.components import tts from homeassistant.components import ffmpeg, tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE, ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
@ -15,7 +15,6 @@ from homeassistant.components.media_player import (
SERVICE_PLAY_MEDIA, SERVICE_PLAY_MEDIA,
MediaType, MediaType,
) )
from homeassistant.components.media_source import Unresolvable
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant, State
@ -33,6 +32,7 @@ from .common import (
get_media_source_url, get_media_source_url,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
retrieve_media,
) )
from tests.common import async_mock_service, mock_restore_cache from tests.common import async_mock_service, mock_restore_cache
@ -75,7 +75,9 @@ async def test_default_entity_attributes() -> None:
async def test_config_entry_unload( async def test_config_entry_unload(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_entity: MockTTSEntity,
) -> None: ) -> None:
"""Test we can unload config entry.""" """Test we can unload config entry."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}" entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
@ -104,7 +106,12 @@ async def test_config_entry_unload(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
)
== HTTPStatus.OK
)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
@ -1159,6 +1166,7 @@ class MockEntityEmpty(MockTTSEntity):
) )
async def test_service_get_tts_error( async def test_service_get_tts_error(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str, setup: str,
tts_service: str, tts_service: str,
service_data: dict[str, Any], service_data: dict[str, Any],
@ -1173,8 +1181,10 @@ async def test_service_get_tts_error(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
with pytest.raises(Unresolvable): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
async def test_load_cache_legacy_retrieve_without_mem_cache( async def test_load_cache_legacy_retrieve_without_mem_cache(
@ -1454,7 +1464,11 @@ async def test_legacy_fetching_in_async(
# Test async_get_media_source_audio # Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message", "test", "en_US", None, None hass,
"test message",
"test",
"en_US",
cache=None,
) )
task = hass.async_create_task( task = hass.async_create_task(
@ -1508,16 +1522,6 @@ async def test_fetching_in_async(
class EntityWithAsyncFetching(MockTTSEntity): class EntityWithAsyncFetching(MockTTSEntity):
"""Entity that supports audio output option.""" """Entity that supports audio output option."""
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return [tts.ATTR_AUDIO_OUTPUT]
@property
def default_options(self) -> dict[str, str]:
"""Return a dict including the default options."""
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
@ -1527,7 +1531,11 @@ async def test_fetching_in_async(
# Test async_get_media_source_audio # Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id( media_source_id = tts.generate_media_source_id(
hass, "test message", "tts.test", "en_US", None, None hass,
"test message",
"tts.test",
"en_US",
cache=None,
) )
task = hass.async_create_task( task = hass.async_create_task(
@ -1751,3 +1759,12 @@ async def test_ws_list_voices(
{"voice_id": "fran_drescher", "name": "Fran Drescher"}, {"voice_id": "fran_drescher", "name": "Fran Drescher"},
] ]
} }
async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
    """Test that a failing audio conversion raises RuntimeError."""
    ffmpeg_ready = await async_setup_component(hass, ffmpeg.DOMAIN, {})
    assert ffmpeg_ready

    # Zero bytes of input stands in for a bad WAV file.
    bad_wav = bytes(0)
    with pytest.raises(RuntimeError):
        await tts.async_convert_audio(hass, "wav", bad_wav, "mp3")

View file

@ -1,4 +1,5 @@
"""Tests for TTS media source.""" """Tests for TTS media source."""
from http import HTTPStatus
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -14,8 +15,11 @@ from .common import (
MockTTSEntity, MockTTSEntity,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
retrieve_media,
) )
from tests.typing import ClientSessionGenerator
class MSEntity(MockTTSEntity): class MSEntity(MockTTSEntity):
"""Test speech API entity.""" """Test speech API entity."""
@ -88,16 +92,18 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None:
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)]) @pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None: async def test_legacy_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider
) -> None:
"""Test resolving legacy provider.""" """Test resolving legacy provider."""
await mock_setup(hass, mock_provider) await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio mock_get_tts_audio = mock_provider.get_tts_audio
media = await media_source.async_resolve_media( media_id = "media-source://tts/test?message=Hello%20World"
hass, "media-source://tts/test?message=Hello%20World", None media = await media_source.async_resolve_media(hass, media_id, None)
)
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg" assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
@ -107,13 +113,11 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
# Pass language and options # Pass language and options
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media( media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus"
hass, media = await media_source.async_resolve_media(hass, media_id, None)
"media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
)
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg" assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
@ -123,16 +127,18 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)]) @pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None: async def test_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity
) -> None:
"""Test resolving entity.""" """Test resolving entity."""
await mock_config_entry_setup(hass, mock_tts_entity) await mock_config_entry_setup(hass, mock_tts_entity)
mock_get_tts_audio = mock_tts_entity.get_tts_audio mock_get_tts_audio = mock_tts_entity.get_tts_audio
media = await media_source.async_resolve_media( media_id = "media-source://tts/tts.test?message=Hello%20World"
hass, "media-source://tts/tts.test?message=Hello%20World", None media = await media_source.async_resolve_media(hass, media_id, None)
)
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg" assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
@ -142,13 +148,13 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
# Pass language and options # Pass language and options
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media( media_id = (
hass, "media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus"
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
) )
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg" assert media.mime_type == "audio/mpeg"
assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]

View file

@ -4,18 +4,19 @@ from http import HTTPStatus
import pytest import pytest
from homeassistant.components import media_source, tts from homeassistant.components import tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP, DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA, SERVICE_PLAY_MEDIA,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
URL = "https://api.voicerss.org/" URL = "https://api.voicerss.org/"
FORM_DATA = { FORM_DATA = {
@ -38,15 +39,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir return mock_tts_cache_dir
async def get_media_source_url(hass, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
async def test_setup_component(hass: HomeAssistant) -> None: async def test_setup_component(hass: HomeAssistant) -> None:
"""Test setup component.""" """Test setup component."""
config = {tts.DOMAIN: {"platform": "voicerss", "api_key": "1234567xx"}} config = {tts.DOMAIN: {"platform": "voicerss", "api_key": "1234567xx"}}
@ -66,7 +58,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
async def test_service_say( async def test_service_say(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -90,14 +84,18 @@ async def test_service_say(
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
assert url.endswith(".mp3") await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_german_config( async def test_service_say_german_config(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say with german code in the config.""" """Test service call say with german code in the config."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -128,13 +126,18 @@ async def test_service_say_german_config(
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == form_data assert aioclient_mock.mock_calls[0][2] == form_data
async def test_service_say_german_service( async def test_service_say_german_service(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say with german code in the service.""" """Test service call say with german code in the service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -160,13 +163,18 @@ async def test_service_say_german_service(
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == form_data assert aioclient_mock.mock_calls[0][2] == form_data
async def test_service_say_error( async def test_service_say_error(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say with http response 400.""" """Test service call say with http response 400."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -189,14 +197,18 @@ async def test_service_say_error(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
with pytest.raises(HomeAssistantError): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_timeout( async def test_service_say_timeout(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say with http timeout.""" """Test service call say with http timeout."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -219,14 +231,18 @@ async def test_service_say_timeout(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
with pytest.raises(HomeAssistantError): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA
async def test_service_say_error_msg( async def test_service_say_error_msg(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
) -> None: ) -> None:
"""Test service call say with http error api message.""" """Test service call say with http error api message."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -254,7 +270,9 @@ async def test_service_say_error_msg(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
with pytest.raises(media_source.Unresolvable): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA

View file

@ -10,6 +10,39 @@
}), }),
]) ])
# --- # ---
# name: test_get_tts_audio_different_formats
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_different_formats.1
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_mp3
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_raw # name: test_get_tts_audio_raw
list([ list([
dict({ dict({

View file

@ -51,31 +51,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
AudioStop().event(), AudioStop().event(),
] ]
with patch( # Verify audio
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
async def test_get_tts_audio_raw(
hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
"""Test get raw audio."""
audio = bytes(100)
audio_events = [ audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(), AudioStop().event(),
@ -92,12 +68,83 @@ async def test_get_tts_audio_raw(
"Hello world", "Hello world",
"tts.test_tts", "tts.test_tts",
"en-US", "en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"}, options={tts.ATTR_PREFERRED_FORMAT: "wav"},
), ),
) )
assert extension == "raw" assert extension == "wav"
assert data == audio assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
async def test_get_tts_audio_different_formats(
hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
"""Test changing preferred audio format."""
audio = bytes(16000 * 2 * 1) # one second
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
# Request a different sample rate, etc.
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass,
"Hello world",
"tts.test_tts",
"en-US",
options={
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
},
),
)
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 48000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 2
assert wav_file.getnframes() == wav_file.getframerate() # one second
assert mock_client.written == snapshot
# MP3 is the default
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass,
"Hello world",
"tts.test_tts",
"en-US",
),
)
assert extension == "mp3"
assert b"ID3" in data
assert mock_client.written == snapshot assert mock_client.written == snapshot

View file

@ -4,7 +4,7 @@ from http import HTTPStatus
import pytest import pytest
from homeassistant.components import media_source, tts from homeassistant.components import tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID, ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP, DOMAIN as DOMAIN_MP,
@ -14,7 +14,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_mock_service from tests.common import assert_setup_component, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
URL = "https://tts.voicetech.yandex.net/generate?" URL = "https://tts.voicetech.yandex.net/generate?"
@ -30,15 +32,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
return mock_tts_cache_dir return mock_tts_cache_dir
async def get_media_source_url(hass, media_content_id):
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
async def test_setup_component(hass: HomeAssistant) -> None: async def test_setup_component(hass: HomeAssistant) -> None:
"""Test setup component.""" """Test setup component."""
config = {tts.DOMAIN: {"platform": "yandextts", "api_key": "1234567xx"}} config = {tts.DOMAIN: {"platform": "yandextts", "api_key": "1234567xx"}}
@ -58,7 +51,9 @@ async def test_setup_component_without_api_key(hass: HomeAssistant) -> None:
async def test_service_say( async def test_service_say(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -87,12 +82,18 @@ async def test_service_say(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_russian_config( async def test_service_say_russian_config(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -128,12 +129,18 @@ async def test_service_say_russian_config(
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_russian_service( async def test_service_say_russian_service(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -166,12 +173,18 @@ async def test_service_say_russian_service(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_timeout( async def test_service_say_timeout(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -207,13 +220,18 @@ async def test_service_say_timeout(
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
with pytest.raises(media_source.Unresolvable): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_http_error( async def test_service_say_http_error(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -248,12 +266,16 @@ async def test_service_say_http_error(
) )
assert len(calls) == 1 assert len(calls) == 1
with pytest.raises(media_source.Unresolvable): assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND
)
async def test_service_say_specified_speaker( async def test_service_say_specified_speaker(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -288,12 +310,18 @@ async def test_service_say_specified_speaker(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_emotion( async def test_service_say_specified_emotion(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -328,13 +356,18 @@ async def test_service_say_specified_emotion(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_low_speed( async def test_service_say_specified_low_speed(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -365,13 +398,18 @@ async def test_service_say_specified_low_speed(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_speed( async def test_service_say_specified_speed(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say.""" """Test service call say."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -400,13 +438,18 @@ async def test_service_say_specified_speed(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_service_say_specified_options( async def test_service_say_specified_options(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test service call say with options.""" """Test service call say with options."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -438,6 +481,9 @@ async def test_service_say_specified_options(
blocking=True, blocking=True,
) )
assert len(calls) == 1 assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1