Automatically convert TTS audio to MP3 on demand (#102814)

* Add ATTR_PREFERRED_FORMAT to TTS for auto-converting audio

* Move conversion into SpeechManager

* Handle None case for expected_extension

* Only use ATTR_AUDIO_OUTPUT

* Prefer MP3 in pipelines

* Automatically convert to mp3 on demand

* Add preferred audio format

* Break out preferred format

* Add ATTR_BLOCKING to allow async fetching

* Make a copy of supported options

* Fix MaryTTS tests

* Update ESPHome to use "wav" instead of "raw"

* Clean up tests, remove blocking

* Clean up rest of TTS tests

* Fix ESPHome tests

* More test coverage
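
As background for the bullet points above, here is a minimal sketch of how on-demand conversion to MP3 could work when a TTS provider returns another format. The function name and the ffmpeg invocation are assumptions for illustration, not the code this commit adds to SpeechManager:

import asyncio


async def convert_audio_to_mp3(audio: bytes, from_extension: str) -> bytes:
    """Pipe provider audio through ffmpeg and return MP3 bytes (illustrative sketch)."""
    proc = await asyncio.create_subprocess_exec(
        "ffmpeg",
        "-f", from_extension,  # input format reported by the provider, e.g. "wav"
        "-i", "pipe:",         # read the original audio from stdin
        "-f", "mp3",           # requested output format
        "pipe:",               # write the converted audio to stdout
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.DEVNULL,
    )
    mp3_data, _ = await proc.communicate(audio)
    if proc.returncode != 0:
        raise RuntimeError("ffmpeg failed to convert TTS audio to MP3")
    return mp3_data
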
Michael Hansen 2023-11-06 14:26:00 -06:00 committed by GitHub
parent 054089291f
commit ae516ffbb5
19 changed files with 723 additions and 241 deletions


@@ -3,9 +3,11 @@ from __future__ import annotations
 import asyncio
 from collections.abc import AsyncIterable, Callable
+import io
 import logging
 import socket
 from typing import cast
+import wave
 
 from aioesphomeapi import (
     VoiceAssistantAudioSettings,
@@ -88,6 +90,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         self.handle_event = handle_event
         self.handle_finished = handle_finished
         self._tts_done = asyncio.Event()
+        self._tts_task: asyncio.Task | None = None
 
     async def start_server(self) -> int:
         """Start accepting connections."""
@@ -189,7 +192,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
             if self.device_info.voice_assistant_version >= 2:
                 media_id = event.data["tts_output"]["media_id"]
-                self.hass.async_create_background_task(
+                self._tts_task = self.hass.async_create_background_task(
                     self._send_tts(media_id), "esphome_voice_assistant_tts"
                 )
             else:
@@ -228,7 +231,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
             audio_settings = VoiceAssistantAudioSettings()
 
         tts_audio_output = (
-            "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
+            "wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
         )
 
         _LOGGER.debug("Starting pipeline")
@@ -302,11 +305,32 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
             VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
         )
 
-        _extension, audio_bytes = await tts.async_get_media_source_audio(
+        extension, data = await tts.async_get_media_source_audio(
             self.hass,
             media_id,
         )
 
+        if extension != "wav":
+            raise ValueError(f"Only WAV audio can be streamed, got {extension}")
+
+        with io.BytesIO(data) as wav_io:
+            with wave.open(wav_io, "rb") as wav_file:
+                sample_rate = wav_file.getframerate()
+                sample_width = wav_file.getsampwidth()
+                sample_channels = wav_file.getnchannels()
+
+                if (
+                    (sample_rate != 16000)
+                    or (sample_width != 2)
+                    or (sample_channels != 1)
+                ):
+                    raise ValueError(
+                        "Expected rate/width/channels as 16000/2/1,"
+                        f" got {sample_rate}/{sample_width}/{sample_channels}"
+                    )
+
+                audio_bytes = wav_file.readframes(wav_file.getnframes())
+
         _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
 
         bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
@@ -330,4 +354,5 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         self.handle_event(
             VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
         )
+        self._tts_task = None
         self._tts_done.set()
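
As a footnote to the hunks above: the 16000/2/1 check pins the stream to 16 kHz, 16-bit, mono PCM, i.e. 16000 * 2 * 1 = 32 000 bytes per second of audio. A small self-contained sketch of how that constraint turns into fixed-duration UDP payloads follows; the 20 ms chunk length is an assumption for illustration, not a value taken from this diff:

def pcm_chunk_size(
    sample_rate: int = 16000, sample_width: int = 2, channels: int = 1, chunk_ms: int = 20
) -> int:
    """Bytes needed for one fixed-duration chunk of PCM audio (illustrative sketch)."""
    bytes_per_second = sample_rate * sample_width * channels  # 32000 for 16000/2/1
    return bytes_per_second * chunk_ms // 1000  # 640 bytes per 20 ms


def chunk_audio(audio_bytes: bytes, chunk_size: int = pcm_chunk_size()) -> list[bytes]:
    """Split raw PCM frames (like audio_bytes above) into UDP-sized chunks."""
    return [
        audio_bytes[i : i + chunk_size]
        for i in range(0, len(audio_bytes), chunk_size)
    ]
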