Automatically convert TTS audio to MP3 on demand (#102814)

* Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio

* Move conversion into SpeechManager

* Handle None case for expected_extension

* Only use ATTR_AUDIO_OUTPUT

* Prefer MP3 in pipelines

* Automatically convert to mp3 on demand

* Add preferred audio format

* Break out preferred format

* Add ATTR_BLOCKING to allow async fetching

* Make a copy of supported options

* Fix MaryTTS tests

* Update ESPHome to use "wav" instead of "raw"

* Clean up tests, remove blocking

* Clean up rest of TTS tests

* Fix ESPHome tests

* More test coverage
This commit is contained in:
Michael Hansen 2023-11-06 14:26:00 -06:00 committed by GitHub
parent 054089291f
commit ae516ffbb5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 723 additions and 241 deletions

View file

@ -13,6 +13,8 @@ import logging
import mimetypes
import os
import re
import subprocess
import tempfile
from typing import Any, TypedDict, final
from aiohttp import web
@ -20,7 +22,7 @@ import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
@ -72,11 +74,15 @@ __all__ = [
"async_get_media_source_audio",
"async_support_options",
"ATTR_AUDIO_OUTPUT",
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
"PLATFORM_SCHEMA_BASE",
"PLATFORM_SCHEMA",
"SampleFormat",
"Provider",
"TtsAudioType",
"Voice",
@ -86,6 +92,9 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"
@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
return languages
async def async_convert_audio(
    hass: HomeAssistant,
    from_extension: str,
    audio_bytes: bytes,
    to_extension: str,
    to_sample_rate: int | None = None,
    to_sample_channels: int | None = None,
) -> bytes:
    """Convert audio to a preferred format without blocking the caller.

    Looks up the ffmpeg manager for this Home Assistant instance and hands the
    blocking conversion (_convert_audio) to the executor, returning the
    converted bytes once it completes.

    from_extension / to_extension are passed to ffmpeg as the input and output
    format names. to_sample_rate and to_sample_channels are optional; when left
    as None the corresponding ffmpeg option is not set.
    """
    manager = ffmpeg.get_ffmpeg_manager(hass)

    def _run_conversion() -> bytes:
        """Perform the blocking ffmpeg conversion (called by the executor)."""
        return _convert_audio(
            manager.binary,
            from_extension,
            audio_bytes,
            to_extension,
            to_sample_rate=to_sample_rate,
            to_sample_channels=to_sample_channels,
        )

    return await hass.async_add_executor_job(_run_conversion)
def _convert_audio(
    ffmpeg_binary: str,
    from_extension: str,
    audio_bytes: bytes,
    to_extension: str,
    to_sample_rate: int | None = None,
    to_sample_channels: int | None = None,
) -> bytes:
    """Convert audio to a preferred format using ffmpeg.

    This function blocks: it spawns ffmpeg, feeds audio_bytes to its stdin and
    waits for it to exit.

    Args:
        ffmpeg_binary: Path or name of the ffmpeg executable to run.
        from_extension: Input format, passed to ffmpeg as the input "-f" value.
        audio_bytes: Source audio data, written to ffmpeg's stdin.
        to_extension: Output format, passed to ffmpeg as the output "-f" value
            and used as the temporary file's suffix.
        to_sample_rate: Optional output sample rate ("-ar"); omitted when None.
        to_sample_channels: Optional output channel count ("-ac"); omitted
            when None.

    Returns:
        The converted audio as bytes.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero return code. ffmpeg's
            stderr output is logged before raising.
    """
    # We have to use a temporary file here because some formats like WAV store
    # the length of the file in the header, and therefore cannot be written in a
    # streaming fashion.
    with tempfile.NamedTemporaryFile(
        mode="wb+", suffix=f".{to_extension}"
    ) as output_file:
        # input
        command = [
            ffmpeg_binary,
            "-y",  # overwrite temp file
            "-f",
            from_extension,
            "-i",
            "pipe:",  # input from stdin
        ]

        # output
        command.extend(["-f", to_extension])

        if to_sample_rate is not None:
            command.extend(["-ar", str(to_sample_rate)])

        if to_sample_channels is not None:
            command.extend(["-ac", str(to_sample_channels)])

        if to_extension == "mp3":
            # Max quality for MP3
            command.extend(["-q:a", "0"])

        # ffmpeg writes to the temporary file by name; we read it back through
        # our own handle once the process has finished.
        command.append(output_file.name)

        with subprocess.Popen(
            command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
        ) as proc:
            # stdout is not piped, so the first element is always None.
            _, stderr = proc.communicate(input=audio_bytes)
            if proc.returncode != 0:
                # Decode leniently: a strict decode could raise
                # UnicodeDecodeError on non-UTF-8 diagnostics and mask the
                # real failure reported below.
                _LOGGER.error(stderr.decode(errors="replace"))
                raise RuntimeError(
                    f"Unexpected error while running ffmpeg with arguments: {command}. See log for details."
                )

        output_file.seek(0)
        return output_file.read()
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines)
@ -482,7 +568,18 @@ class SpeechManager:
merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {})
supported_options = engine_instance.supported_options or []
supported_options = list(engine_instance.supported_options or [])
# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)
invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
@ -520,12 +617,7 @@ class SpeechManager:
# Load speech from engine into memory
else:
filename = await self._async_get_tts_audio(
engine_instance,
cache_key,
message,
use_cache,
language,
options,
engine_instance, cache_key, message, use_cache, language, options
)
return f"/api/tts_proxy/{filename}"
@ -590,10 +682,10 @@ class SpeechManager:
This method is a coroutine.
"""
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
expected_extension = None
options = options or {}
# Default to MP3 unless a different format is preferred
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")
async def get_tts_data() -> str:
"""Handle data available."""
@ -614,8 +706,27 @@ class SpeechManager:
f"No TTS from {engine_instance.name} for '{message}'"
)
# Only convert if we have a preferred format different than the
# expected format from the TTS system, or if a specific sample
# rate/format/channel count is requested.
needs_conversion = (
(final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
)
if needs_conversion:
data = await async_convert_audio(
self.hass,
extension,
data,
to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
)
# Create file infos
filename = f"{cache_key}.{extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()
# Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
@ -626,10 +737,11 @@ class SpeechManager:
)
# Save to memory
if extension == "mp3":
if final_extension == "mp3":
data = self.write_tags(
filename, data, engine_instance.name, message, language, options
)
self._async_store_to_memcache(cache_key, filename, data)
if cache:
@ -641,9 +753,6 @@ class SpeechManager:
audio_task = self.hass.async_create_task(get_tts_data())
if expected_extension is None:
return await audio_task
def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
@ -651,7 +760,7 @@ class SpeechManager:
audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{expected_extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
@ -747,11 +856,12 @@ class SpeechManager:
raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key)
content, _ = mimetypes.guess_type(filename)
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
content, _ = mimetypes.guess_type(filename)
return content, cached["voice"]
@staticmethod

View file

@ -3,7 +3,7 @@
"name": "Text-to-speech (TTS)",
"after_dependencies": ["media_player"],
"codeowners": ["@home-assistant/core", "@pvizeli"],
"dependencies": ["http"],
"dependencies": ["http", "ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/tts",
"integration_type": "entity",
"loggers": ["mutagen"],