Use first media player announcement format for TTS (#125237)

* Use ANNOUNCEMENT format from first media player for TTS

* Fix formatting

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-09-06 10:57:09 -05:00 committed by GitHub
parent 20639b0f02
commit ee59303d3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 123 additions and 4 deletions

View file

@ -6,12 +6,14 @@ import asyncio
from collections.abc import AsyncIterable
from functools import partial
import io
from itertools import chain
import logging
import socket
from typing import Any, cast
import wave
from aioesphomeapi import (
MediaPlayerFormatPurpose,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
@ -288,6 +290,18 @@ class EsphomeAssistSatellite(
end_stage = PipelineStage.TTS
if feature_flags & VoiceAssistantFeature.SPEAKER:
# Stream WAV audio
self._attr_tts_options = {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
else:
# ANNOUNCEMENT format from media player
self._update_tts_format()
# Run the pipeline
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
self.entry_data.async_set_assist_pipeline_state(True)
@ -340,6 +354,19 @@ class EsphomeAssistSatellite(
timer_info.is_active,
)
def _update_tts_format(self) -> None:
    """Adopt the first media player ANNOUNCEMENT format as the TTS options.

    The supported formats stored per media player in the entry data are
    scanned in dictionary order. The first format whose purpose is
    ANNOUNCEMENT supplies the preferred format, sample rate and channel
    count. When no such format exists the current TTS options are left
    untouched.
    """
    every_format = chain.from_iterable(
        self.entry_data.media_player_formats.values()
    )
    announcement = next(
        (
            candidate
            for candidate in every_format
            if candidate.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT
        ),
        None,
    )
    if announcement is None:
        # Nothing advertises an announcement format; keep existing options
        return

    self._attr_tts_options = {
        tts.ATTR_PREFERRED_FORMAT: announcement.format,
        tts.ATTR_PREFERRED_SAMPLE_RATE: announcement.sample_rate,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: announcement.num_channels,
        # Sample width is a fixed 2 bytes; it is not read from the format
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
    }
async def _stream_tts_audio(
self,
media_id: str,

View file

@ -31,6 +31,7 @@ from aioesphomeapi import (
LightInfo,
LockInfo,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
NumberInfo,
SelectInfo,
SensorInfo,
@ -148,6 +149,9 @@ class RuntimeEntryData:
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict)
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
default_factory=lambda: defaultdict(list)
)
@property
def name(self) -> str:

View file

@ -3,7 +3,7 @@
from __future__ import annotations
from functools import partial
from typing import Any
from typing import Any, cast
from aioesphomeapi import (
EntityInfo,
@ -66,6 +66,9 @@ class EsphomeMediaPlayer(
if self._static_info.supports_pause:
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
self._attr_supported_features = flags
self._entry_data.media_player_formats[self.entity_id] = cast(
MediaPlayerInfo, static_info
).supported_formats
@property
@esphome_state_property
@ -103,6 +106,11 @@ class EsphomeMediaPlayer(
self._key, media_url=media_id, announcement=announcement
)
async def async_will_remove_from_hass(self) -> None:
    """Run the parent removal hook, then forget this entity's formats."""
    await super().async_will_remove_from_hass()
    formats_by_entity = self._entry_data.media_player_formats
    # The None default makes this a no-op when no entry is registered
    formats_by_entity.pop(self.entity_id, None)
async def async_browse_media(
self,
media_content_type: MediaType | str | None = None,