Compare commits

...
Sign in to create a new pull request.

6 commits

Author SHA1 Message Date
Michael Hansen
28a9ceee3e Fix parameterize 2023-05-10 12:57:22 -05:00
Michael Hansen
c64ac285e0 Increase test coverage 2023-05-10 12:57:22 -05:00
Michael Hansen
d57d8e21f6 Appeasing the codebot 2023-05-10 12:57:22 -05:00
Michael Hansen
86a6b941d2 Test empty options too 2023-05-10 12:57:22 -05:00
Michael Hansen
fa7a5c7bc1 Fix test 2023-05-10 12:57:22 -05:00
Michael Hansen
3a8a8e8813 Use ffmpeg to convert Piper audio to mp3 2023-05-10 12:57:22 -05:00
4 changed files with 144 additions and 22 deletions

View file

@ -3,6 +3,7 @@
"name": "Wyoming Protocol", "name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"], "codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true, "config_flow": true,
"dependencies": ["ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/wyoming", "documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["wyoming==0.0.1"] "requirements": ["wyoming==0.0.1"]

View file

@ -1,4 +1,5 @@
"""Support for Wyoming text to speech services.""" """Support for Wyoming text to speech services."""
import asyncio
from collections import defaultdict from collections import defaultdict
import io import io
import logging import logging
@ -8,7 +9,7 @@ from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize from wyoming.tts import Synthesize
from homeassistant.components import tts from homeassistant.components import ffmpeg, tts
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -18,6 +19,7 @@ from .data import WyomingService
from .error import WyomingError from .error import WyomingError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_DEFAULT_FORMAT = "mp3"
async def async_setup_entry( async def async_setup_entry(
@ -27,9 +29,10 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up Wyoming speech to text.""" """Set up Wyoming speech to text."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
async_add_entities( async_add_entities(
[ [
WyomingTtsProvider(config_entry, service), WyomingTtsProvider(config_entry, service, ffmpeg_manager),
] ]
) )
@ -41,9 +44,11 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
self, self,
config_entry: ConfigEntry, config_entry: ConfigEntry,
service: WyomingService, service: WyomingService,
ffmpeg_manager: ffmpeg.FFmpegManager,
) -> None: ) -> None:
"""Set up provider.""" """Set up provider."""
self.service = service self.service = service
self._ffmpeg_manager = ffmpeg_manager
self._tts_service = next(tts for tts in service.info.tts if tts.installed) self._tts_service = next(tts for tts in service.info.tts if tts.installed)
voice_languages: set[str] = set() voice_languages: set[str] = set()
@ -87,7 +92,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
@property @property
def default_options(self): def default_options(self):
"""Return a dict include default options.""" """Return a dict include default options."""
return {tts.ATTR_AUDIO_OUTPUT: "wav"} return {tts.ATTR_AUDIO_OUTPUT: _DEFAULT_FORMAT}
@callback @callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
@ -129,9 +134,20 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError): except (OSError, WyomingError):
return (None, None) return (None, None)
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"): if options is None:
output_format = _DEFAULT_FORMAT
else:
output_format = options.get(tts.ATTR_AUDIO_OUTPUT, _DEFAULT_FORMAT)
if output_format == "wav":
# Already WAV data
return ("wav", data) return ("wav", data)
if output_format != "raw":
# Convert with ffmpeg
converted_data = await self._convert_audio(data, output_format)
return (output_format, converted_data)
# Raw output (convert to 16Khz, 16-bit mono) # Raw output (convert to 16Khz, 16-bit mono)
with io.BytesIO(data) as wav_io: with io.BytesIO(data) as wav_io:
wav_reader: wave.Wave_read = wave.open(wav_io, "rb") wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
@ -153,3 +169,33 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
) )
return ("raw", raw_data) return ("raw", raw_data)
async def _convert_audio(self, wav_data: bytes, output_format: str) -> bytes:
"""Convert from WAV to a different format using ffmpeg asynchronously."""
ffmpeg_input = [
"-f",
"wav",
"-i",
"pipe:", # input from stdin
]
ffmpeg_output = [
"-f",
output_format,
]
if output_format == "mp3":
ffmpeg_output.extend(["-q:a", "0"]) # max quality
ffmpeg_output.append("pipe:") # output to stdout
ffmpeg_proc = await asyncio.create_subprocess_exec(
self._ffmpeg_manager.binary,
*ffmpeg_input,
*ffmpeg_output,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL,
)
stdout, _stderr = await ffmpeg_proc.communicate(input=wav_data)
return stdout

View file

@ -10,6 +10,39 @@
}), }),
]) ])
# --- # ---
# name: test_get_tts_audio.1
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_format[raw]
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_format[wav]
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_raw # name: test_get_tts_audio_raw
list([ list([
dict({ dict({
@ -21,3 +54,14 @@
}), }),
]) ])
# --- # ---
# name: test_no_options
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---

View file

@ -32,6 +32,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
assert entity.supported_languages == ["en-US"] assert entity.supported_languages == ["en-US"]
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE] assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
assert entity.default_language == "en-US"
voices = entity.async_get_supported_voices("en-US") voices = entity.async_get_supported_voices("en-US")
assert len(voices) == 1 assert len(voices) == 1
assert voices[0].name == "Test Voice" assert voices[0].name == "Test Voice"
@ -39,6 +40,33 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
assert not entity.async_get_supported_voices("de-DE") assert not entity.async_get_supported_voices("de-DE")
async def test_no_options(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
"""Test options=None."""
audio = bytes(100)
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
state = hass.states.get("tts.test_tts")
assert state is not None
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
assert entity is not None
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events),
) as mock_client:
extension, data = await entity.async_get_tts_audio(
"Hello world", "en-US", options=None
)
assert extension == "mp3"
assert data is not None
assert mock_client.written == snapshot
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None: async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
"""Test get audio.""" """Test get audio."""
audio = bytes(100) audio = bytes(100)
@ -56,21 +84,16 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"), tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
) )
assert extension == "wav" assert extension == "mp3"
assert data is not None assert data is not None
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot assert mock_client.written == snapshot
async def test_get_tts_audio_raw( @pytest.mark.parametrize("audio_format", ("wav", "raw"))
hass: HomeAssistant, init_wyoming_tts, snapshot async def test_get_tts_audio_format(
hass: HomeAssistant, init_wyoming_tts, snapshot, audio_format: str
) -> None: ) -> None:
"""Test get raw audio.""" """Test get audio in a specific format."""
audio = bytes(100) audio = bytes(100)
audio_events = [ audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
@ -88,12 +111,22 @@ async def test_get_tts_audio_raw(
"Hello world", "Hello world",
"tts.test_tts", "tts.test_tts",
"en-US", "en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"}, options={tts.ATTR_AUDIO_OUTPUT: audio_format},
), ),
) )
assert extension == "raw" assert extension == audio_format
if audio_format == "raw":
assert data == audio assert data == audio
else:
# Verify WAV audio
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
assert wav_file.getframerate() == 16000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 1
assert wav_file.readframes(wav_file.getnframes()) == audio
assert mock_client.written == snapshot assert mock_client.written == snapshot
@ -133,7 +166,5 @@ async def test_get_tts_audio_audio_oserror(
): ):
await tts.async_get_media_source_audio( await tts.async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
hass, "Hello world", "tts.test_tts", hass.config.language
),
) )