diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py
index c88e0e83334..8ea4617bbf3 100644
--- a/homeassistant/components/tts/__init__.py
+++ b/homeassistant/components/tts/__init__.py
@@ -16,7 +16,7 @@ import os
 import re
 import subprocess
 import tempfile
-from typing import Any, TypedDict, final
+from typing import Any, Final, TypedDict, final
 
 from aiohttp import web
 import mutagen
@@ -99,6 +99,13 @@ ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
 ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
 ATTR_VOICE = "voice"
 
+_DEFAULT_FORMAT = "mp3"
+_PREFERRED_FORMAT_OPTIONS: Final[set[str]] = {
+    ATTR_PREFERRED_FORMAT,
+    ATTR_PREFERRED_SAMPLE_RATE,
+    ATTR_PREFERRED_SAMPLE_CHANNELS,
+}
+
 CONF_LANG = "language"
 
 SERVICE_CLEAR_CACHE = "clear_cache"
@@ -569,25 +576,23 @@ class SpeechManager:
         ):
             raise HomeAssistantError(f"Language '{language}' not supported")
 
+        options = options or {}
+        supported_options = engine_instance.supported_options or []
+
         # Update default options with provided options
+        invalid_opts: list[str] = []
         merged_options = dict(engine_instance.default_options or {})
-        merged_options.update(options or {})
+        for option_name, option_value in options.items():
+            # Only count an option as invalid if it's not a "preferred format"
+            # option. These are used as hints to the TTS system if supported,
+            # and otherwise as parameters to ffmpeg conversion.
+            if (option_name in supported_options) or (
+                option_name in _PREFERRED_FORMAT_OPTIONS
+            ):
+                merged_options[option_name] = option_value
+            else:
+                invalid_opts.append(option_name)
 
-        supported_options = list(engine_instance.supported_options or [])
-
-        # ATTR_PREFERRED_* options are always "supported" since they're used to
-        # convert audio after the TTS has run (if necessary).
-        supported_options.extend(
-            (
-                ATTR_PREFERRED_FORMAT,
-                ATTR_PREFERRED_SAMPLE_RATE,
-                ATTR_PREFERRED_SAMPLE_CHANNELS,
-            )
-        )
-
-        invalid_opts = [
-            opt_name for opt_name in merged_options if opt_name not in supported_options
-        ]
-
         if invalid_opts:
             raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
@@ -687,10 +692,31 @@ class SpeechManager:
 
         This method is a coroutine.
         """
-        options = options or {}
+        options = dict(options or {})
+        supported_options = engine_instance.supported_options or []
 
-        # Default to MP3 unless a different format is preferred
-        final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
+        # Extract preferred format options.
+        #
+        # These options are used by Assist pipelines, etc. to get a format that
+        # the voice satellite will support.
+        #
+        # The TTS system ideally supports options directly so we won't have
+        # to convert with ffmpeg later. If not, we pop the options here and
+        # perform the conversion after receiving the audio.
+        if ATTR_PREFERRED_FORMAT in supported_options:
+            final_extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
+        else:
+            final_extension = options.pop(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
+
+        if ATTR_PREFERRED_SAMPLE_RATE in supported_options:
+            sample_rate = options.get(ATTR_PREFERRED_SAMPLE_RATE)
+        else:
+            sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
+
+        if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
+            sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
+        else:
+            sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
 
         async def get_tts_data() -> str:
             """Handle data available."""
@@ -716,8 +742,8 @@ class SpeechManager:
             # rate/format/channel count is requested.
             needs_conversion = (
                 (final_extension != extension)
-                or (ATTR_PREFERRED_SAMPLE_RATE in options)
-                or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
+                or (sample_rate is not None)
+                or (sample_channels is not None)
             )
 
             if needs_conversion:
@@ -726,8 +752,8 @@
                     extension,
                     data,
                     to_extension=final_extension,
-                    to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
-                    to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
+                    to_sample_rate=sample_rate,
+                    to_sample_channels=sample_channels,
                 )
 
         # Create file infos
diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py
index 8c5cfe9d599..9f098150288 100644
--- a/tests/components/assist_pipeline/conftest.py
+++ b/tests/components/assist_pipeline/conftest.py
@@ -111,6 +111,7 @@ class MockTTSProvider(tts.Provider):
             tts.Voice("fran_drescher", "Fran Drescher"),
         ]
     }
+    _supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
 
     @property
     def default_language(self) -> str:
@@ -130,7 +131,7 @@ class MockTTSProvider(tts.Provider):
     @property
     def supported_options(self) -> list[str]:
         """Return list of supported options like voice, emotions."""
-        return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
+        return self._supported_options
 
     def get_tts_audio(
         self, message: str, language: str, options: dict[str, Any]
diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py
index 81347e96235..c6f45044cb3 100644
--- a/tests/components/assist_pipeline/test_init.py
+++ b/tests/components/assist_pipeline/test_init.py
@@ -11,7 +11,7 @@ import wave
 import pytest
 from syrupy.assertion import SnapshotAssertion
 
-from homeassistant.components import assist_pipeline, stt, tts
+from homeassistant.components import assist_pipeline, media_source, stt, tts
 from homeassistant.components.assist_pipeline.const import (
     CONF_DEBUG_RECORDING_DIR,
     DOMAIN,
@@ -19,9 +19,14 @@ from homeassistant.components.assist_pipeline.const import (
 )
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.setup import async_setup_component
 
-from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
+from .conftest import (
+    MockSttProvider,
+    MockSttProviderEntity,
+    MockTTSProvider,
+    MockWakeWordEntity,
+)
 
-from tests.typing import WebSocketGenerator
+from tests.typing import ClientSessionGenerator, WebSocketGenerator
 
 BYTES_ONE_SECOND = 16000 * 2
@@ -729,15 +734,17 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
 
 async def test_tts_audio_output(
     hass: HomeAssistant,
-    mock_stt_provider: MockSttProvider,
+    hass_client: ClientSessionGenerator,
+    mock_tts_provider: MockTTSProvider,
     init_components,
     pipeline_data: assist_pipeline.pipeline.PipelineData,
     snapshot: SnapshotAssertion,
 ) -> None:
     """Test using tts_audio_output with wav sets options correctly."""
+    client = await hass_client()
+    assert await async_setup_component(hass, media_source.DOMAIN, {})
 
-    def event_callback(event):
-        pass
+    events: list[assist_pipeline.PipelineEvent] = []
 
     pipeline_store = pipeline_data.pipeline_store
     pipeline_id = pipeline_store.async_get_preferred_item()
@@ -753,7 +760,7 @@ async def test_tts_audio_output(
             pipeline=pipeline,
             start_stage=assist_pipeline.PipelineStage.TTS,
             end_stage=assist_pipeline.PipelineStage.TTS,
-            event_callback=event_callback,
+            event_callback=events.append,
             tts_audio_output="wav",
         ),
     )
@@ -764,3 +771,87 @@
     assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
     assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
     assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
+
+    with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
+        await pipeline_input.execute()
+
+        for event in events:
+            if event.type == assist_pipeline.PipelineEventType.TTS_END:
+                # We must fetch the media URL to trigger the TTS
+                assert event.data
+                media_id = event.data["tts_output"]["media_id"]
+                resolved = await media_source.async_resolve_media(hass, media_id, None)
+                await client.get(resolved.url)
+
+    # Ensure that no unsupported options were passed in
+    assert mock_get_tts_audio.called
+    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
+    extra_options = set(options).difference(mock_tts_provider.supported_options)
+    assert len(extra_options) == 0, extra_options
+
+
+async def test_tts_supports_preferred_format(
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    mock_tts_provider: MockTTSProvider,
+    init_components,
+    pipeline_data: assist_pipeline.pipeline.PipelineData,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test that preferred format options are given to the TTS system if supported."""
+    client = await hass_client()
+    assert await async_setup_component(hass, media_source.DOMAIN, {})
+
+    events: list[assist_pipeline.PipelineEvent] = []
+
+    pipeline_store = pipeline_data.pipeline_store
+    pipeline_id = pipeline_store.async_get_preferred_item()
+    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
+
+    pipeline_input = assist_pipeline.pipeline.PipelineInput(
+        tts_input="This is a test.",
+        conversation_id=None,
+        device_id=None,
+        run=assist_pipeline.pipeline.PipelineRun(
+            hass,
+            context=Context(),
+            pipeline=pipeline,
+            start_stage=assist_pipeline.PipelineStage.TTS,
+            end_stage=assist_pipeline.PipelineStage.TTS,
+            event_callback=events.append,
+            tts_audio_output="wav",
+        ),
+    )
+    await pipeline_input.validate()
+
+    # Make the TTS provider support preferred format options
+    supported_options = list(mock_tts_provider.supported_options or [])
+    supported_options.extend(
+        [
+            tts.ATTR_PREFERRED_FORMAT,
+            tts.ATTR_PREFERRED_SAMPLE_RATE,
+            tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
+        ]
+    )
+
+    with (
+        patch.object(mock_tts_provider, "_supported_options", supported_options),
+        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
+    ):
+        await pipeline_input.execute()
+
+        for event in events:
+            if event.type == assist_pipeline.PipelineEventType.TTS_END:
+                # We must fetch the media URL to trigger the TTS
+                assert event.data
+                media_id = event.data["tts_output"]["media_id"]
+                resolved = await media_source.async_resolve_media(hass, media_id, None)
+                await client.get(resolved.url)
+
+    assert mock_get_tts_audio.called
+    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
+
+    # We should have received preferred format options in get_tts_audio
+    assert tts.ATTR_PREFERRED_FORMAT in options
+    assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
+    assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options
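For integration authors, the practical consequence of this change is: an engine that lists the `ATTR_PREFERRED_*` options in its `supported_options` now receives them in the options dict passed to `get_tts_audio` and can honor them directly, while an engine that does not will have those options popped and applied afterwards via ffmpeg. Below is a minimal sketch (not part of this diff) of a legacy-style `tts.Provider` that advertises the preferred format options; the class name and the silent-WAV helper are hypothetical stand-ins for a real engine:

```python
import io
import wave
from typing import Any

from homeassistant.components import tts


class WavCapableProvider(tts.Provider):
    """Hypothetical provider that emits WAV at a requested rate/channel count."""

    @property
    def default_language(self) -> str:
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        return ["en"]

    @property
    def supported_options(self) -> list[str]:
        # Advertising ATTR_PREFERRED_* keeps these options in the merged
        # options dict, so they reach get_tts_audio as hints instead of
        # being popped for post-hoc ffmpeg conversion.
        return [
            "voice",
            tts.ATTR_PREFERRED_FORMAT,
            tts.ATTR_PREFERRED_SAMPLE_RATE,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        ]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> tuple[str, bytes]:
        # Honor the hints when present; fall back to engine defaults.
        sample_rate = options.get(tts.ATTR_PREFERRED_SAMPLE_RATE, 22050)
        channels = options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS, 1)
        return ("wav", self._silence(sample_rate, channels))

    def _silence(self, sample_rate: int, channels: int) -> bytes:
        """Return one second of silent WAV (stand-in for real synthesis)."""
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(channels)
            wav_file.setsampwidth(2)  # 16-bit PCM
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(bytes(2 * sample_rate * channels))
        return buffer.getvalue()
```

Note that `SpeechManager` still runs the ffmpeg conversion whenever a preferred sample rate or channel count was requested (`needs_conversion` checks `sample_rate is not None`), so the hints act as an optimization for the engine rather than a replacement for the conversion safety net.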