Compare commits

...
Sign in to create a new pull request.

6 commits

Author SHA1 Message Date
Michael Hansen
28a9ceee3e Fix parameterize 2023-05-10 12:57:22 -05:00
Michael Hansen
c64ac285e0 Increase test coverage 2023-05-10 12:57:22 -05:00
Michael Hansen
d57d8e21f6 Appeasing the codebot 2023-05-10 12:57:22 -05:00
Michael Hansen
86a6b941d2 Test empty options too 2023-05-10 12:57:22 -05:00
Michael Hansen
fa7a5c7bc1 Fix test 2023-05-10 12:57:22 -05:00
Michael Hansen
3a8a8e8813 Use ffmpeg to convert Piper audio to mp3 2023-05-10 12:57:22 -05:00
4 changed files with 144 additions and 22 deletions

View file

@ -3,6 +3,7 @@
"name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==0.0.1"]

View file

@ -1,4 +1,5 @@
"""Support for Wyoming text to speech services."""
import asyncio
from collections import defaultdict
import io
import logging
@ -8,7 +9,7 @@ from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize
from homeassistant.components import tts
from homeassistant.components import ffmpeg, tts
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -18,6 +19,7 @@ from .data import WyomingService
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
_DEFAULT_FORMAT = "mp3"
async def async_setup_entry(
@ -27,9 +29,10 @@ async def async_setup_entry(
) -> None:
"""Set up Wyoming speech to text."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
async_add_entities(
[
WyomingTtsProvider(config_entry, service),
WyomingTtsProvider(config_entry, service, ffmpeg_manager),
]
)
@ -41,9 +44,11 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
self,
config_entry: ConfigEntry,
service: WyomingService,
ffmpeg_manager: ffmpeg.FFmpegManager,
) -> None:
"""Set up provider."""
self.service = service
self._ffmpeg_manager = ffmpeg_manager
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
voice_languages: set[str] = set()
@ -87,7 +92,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
@property
def default_options(self):
"""Return a dict include default options."""
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
return {tts.ATTR_AUDIO_OUTPUT: _DEFAULT_FORMAT}
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
@ -129,9 +134,20 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError):
return (None, None)
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
if options is None:
output_format = _DEFAULT_FORMAT
else:
output_format = options.get(tts.ATTR_AUDIO_OUTPUT, _DEFAULT_FORMAT)
if output_format == "wav":
# Already WAV data
return ("wav", data)
if output_format != "raw":
# Convert with ffmpeg
converted_data = await self._convert_audio(data, output_format)
return (output_format, converted_data)
# Raw output (convert to 16 kHz, 16-bit mono)
with io.BytesIO(data) as wav_io:
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
@ -153,3 +169,33 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
)
return ("raw", raw_data)
async def _convert_audio(self, wav_data: bytes, output_format: str) -> bytes:
    """Convert WAV audio to a different format by piping it through ffmpeg.

    wav_data is written to ffmpeg's stdin and the converted audio is read
    back from its stdout. output_format is handed to ffmpeg's "-f" output
    option unchanged.

    Returns the bytes ffmpeg wrote to stdout. If ffmpeg exits with a
    non-zero status the failure is logged (including ffmpeg's stderr) and
    whatever was produced is still returned, so the result may be empty or
    truncated in that case.
    """
    command = [
        self._ffmpeg_manager.binary,
        # Input: WAV container read from stdin
        "-f",
        "wav",
        "-i",
        "pipe:",
        # Output: requested container/codec
        "-f",
        output_format,
    ]

    if output_format == "mp3":
        command.extend(["-q:a", "0"])  # max quality

    command.append("pipe:")  # output to stdout

    ffmpeg_proc = await asyncio.create_subprocess_exec(
        *command,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        # Keep stderr so a failed conversion can be diagnosed from the log
        stderr=asyncio.subprocess.PIPE,
    )

    # communicate() feeds stdin and drains stdout/stderr concurrently,
    # which avoids deadlocking on full pipe buffers.
    stdout, stderr = await ffmpeg_proc.communicate(input=wav_data)

    if ffmpeg_proc.returncode != 0:
        _LOGGER.error(
            "ffmpeg exited with code %s while converting audio to %s: %s",
            ffmpeg_proc.returncode,
            output_format,
            stderr.decode("utf-8", errors="replace").strip(),
        )

    return stdout

View file

@ -10,6 +10,39 @@
}),
])
# ---
# name: test_get_tts_audio.1
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_format[raw]
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_format[wav]
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_raw
list([
dict({
@ -21,3 +54,14 @@
}),
])
# ---
# name: test_no_options
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---

View file

@ -32,6 +32,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
assert entity.supported_languages == ["en-US"]
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
assert entity.default_language == "en-US"
voices = entity.async_get_supported_voices("en-US")
assert len(voices) == 1
assert voices[0].name == "Test Voice"
@ -39,6 +40,33 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
assert not entity.async_get_supported_voices("de-DE")
async def test_no_options(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
    """Check that passing options=None falls back to the default mp3 output."""
    # 100 zero bytes stand in for the synthesized PCM payload.
    pcm_payload = bytes(100)
    chunk_event = AudioChunk(
        audio=pcm_payload, rate=16000, width=2, channels=1
    ).event()
    stop_event = AudioStop().event()

    assert hass.states.get("tts.test_tts") is not None

    tts_entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
    assert tts_entity is not None

    fake_client = MockAsyncTcpClient([chunk_event, stop_event])
    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient", fake_client
    ) as mock_client:
        result = await tts_entity.async_get_tts_audio(
            "Hello world", "en-US", options=None
        )

    extension, data = result
    assert extension == "mp3"
    assert data is not None
    # The events sent to the Wyoming server must match the stored snapshot.
    assert mock_client.written == snapshot
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
"""Test get audio."""
audio = bytes(100)
@ -56,21 +84,16 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)
assert extension == "wav"
assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
assert extension == "mp3"
assert data is not None
assert mock_client.written == snapshot
async def test_get_tts_audio_raw(
hass: HomeAssistant, init_wyoming_tts, snapshot
@pytest.mark.parametrize("audio_format", ("wav", "raw"))
async def test_get_tts_audio_format(
hass: HomeAssistant, init_wyoming_tts, snapshot, audio_format: str
) -> None:
"""Test get raw audio."""
"""Test get audio in a specific format."""
audio = bytes(100)
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
@ -88,12 +111,22 @@ async def test_get_tts_audio_raw(
"Hello world",
"tts.test_tts",
"en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
options={tts.ATTR_AUDIO_OUTPUT: audio_format},
),
)
assert extension == "raw"
assert data == audio
assert extension == audio_format
if audio_format == "raw":
assert data == audio
else:
# Verify WAV audio
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot
@ -133,7 +166,5 @@ async def test_get_tts_audio_audio_oserror(
):
await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass, "Hello world", "tts.test_tts", hass.config.language
),
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)