Compare commits: dev...wyoming-mp (6 commits)

SHA1:
28a9ceee3e
c64ac285e0
d57d8e21f6
86a6b941d2
fa7a5c7bc1
3a8a8e8813
4 changed files with 144 additions and 22 deletions

@@ -3,6 +3,7 @@
   "name": "Wyoming Protocol",
   "codeowners": ["@balloob", "@synesthesiam"],
   "config_flow": true,
+  "dependencies": ["ffmpeg"],
   "documentation": "https://www.home-assistant.io/integrations/wyoming",
   "iot_class": "local_push",
   "requirements": ["wyoming==0.0.1"]

@@ -1,4 +1,5 @@
 """Support for Wyoming text to speech services."""
+import asyncio
 from collections import defaultdict
 import io
 import logging
@@ -8,7 +9,7 @@ from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
 from wyoming.client import AsyncTcpClient
 from wyoming.tts import Synthesize
 
-from homeassistant.components import tts
+from homeassistant.components import ffmpeg, tts
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
@@ -18,6 +19,7 @@ from .data import WyomingService
 from .error import WyomingError
 
 _LOGGER = logging.getLogger(__name__)
+_DEFAULT_FORMAT = "mp3"
 
 
 async def async_setup_entry(
@@ -27,9 +29,10 @@ async def async_setup_entry(
 ) -> None:
     """Set up Wyoming speech to text."""
     service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
+    ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
     async_add_entities(
         [
-            WyomingTtsProvider(config_entry, service),
+            WyomingTtsProvider(config_entry, service, ffmpeg_manager),
         ]
     )
 
@@ -41,9 +44,11 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
         self,
         config_entry: ConfigEntry,
         service: WyomingService,
+        ffmpeg_manager: ffmpeg.FFmpegManager,
     ) -> None:
         """Set up provider."""
         self.service = service
+        self._ffmpeg_manager = ffmpeg_manager
         self._tts_service = next(tts for tts in service.info.tts if tts.installed)
 
         voice_languages: set[str] = set()
@@ -87,7 +92,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
     @property
     def default_options(self):
         """Return a dict include default options."""
-        return {tts.ATTR_AUDIO_OUTPUT: "wav"}
+        return {tts.ATTR_AUDIO_OUTPUT: _DEFAULT_FORMAT}
 
     @callback
     def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
@@ -129,9 +134,20 @@
         except (OSError, WyomingError):
             return (None, None)
 
-        if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
+        if options is None:
+            output_format = _DEFAULT_FORMAT
+        else:
+            output_format = options.get(tts.ATTR_AUDIO_OUTPUT, _DEFAULT_FORMAT)
+
+        if output_format == "wav":
+            # Already WAV data
             return ("wav", data)
+
+        if output_format != "raw":
+            # Convert with ffmpeg
+            converted_data = await self._convert_audio(data, output_format)
+            return (output_format, converted_data)
 
         # Raw output (convert to 16Khz, 16-bit mono)
         with io.BytesIO(data) as wav_io:
             wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
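The hunk above stops right where the raw branch opens the synthesized WAV container; the actual frame extraction sits in unchanged lines outside the diff, and only its `return ("raw", raw_data)` tail appears in the next hunk. A minimal standalone sketch of that step, assuming a single readframes() call (the function name here is hypothetical, not the PR's code):

import io
import wave


def wav_to_raw_pcm(wav_bytes: bytes) -> bytes:
    """Extract raw PCM frames (16 kHz, 16-bit mono in this integration) from WAV bytes."""
    with io.BytesIO(wav_bytes) as wav_io, wave.open(wav_io, "rb") as wav_reader:
        # readframes() returns the bare sample data without the RIFF/WAV header
        return wav_reader.readframes(wav_reader.getnframes())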
@@ -153,3 +169,33 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
             )
 
         return ("raw", raw_data)
+
+    async def _convert_audio(self, wav_data: bytes, output_format: str) -> bytes:
+        """Convert from WAV to a different format using ffmpeg asynchronously."""
+        ffmpeg_input = [
+            "-f",
+            "wav",
+            "-i",
+            "pipe:",  # input from stdin
+        ]
+        ffmpeg_output = [
+            "-f",
+            output_format,
+        ]
+
+        if output_format == "mp3":
+            ffmpeg_output.extend(["-q:a", "0"])  # max quality
+
+        ffmpeg_output.append("pipe:")  # output to stdout
+
+        ffmpeg_proc = await asyncio.create_subprocess_exec(
+            self._ffmpeg_manager.binary,
+            *ffmpeg_input,
+            *ffmpeg_output,
+            stdin=asyncio.subprocess.PIPE,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.DEVNULL,
+        )
+
+        stdout, _stderr = await ffmpeg_proc.communicate(input=wav_data)
+        return stdout

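For reference, a minimal standalone sketch of the WAV-to-MP3 piping that the new `_convert_audio` method performs above. It assumes an `ffmpeg` binary reachable on PATH, whereas the integration resolves the managed binary through `ffmpeg.get_ffmpeg_manager(hass).binary`:

import asyncio


async def wav_to_mp3(wav_data: bytes, ffmpeg_bin: str = "ffmpeg") -> bytes:
    """Pipe WAV bytes through ffmpeg and return MP3 bytes (standalone sketch)."""
    proc = await asyncio.create_subprocess_exec(
        ffmpeg_bin,
        "-f", "wav", "-i", "pipe:",  # read WAV from stdin
        "-f", "mp3", "-q:a", "0",  # encode highest-quality VBR MP3
        "pipe:",  # write the result to stdout
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.DEVNULL,
    )
    mp3_data, _ = await proc.communicate(input=wav_data)
    return mp3_data

Using communicate() to feed stdin and drain stdout in one call avoids the pipe-buffer deadlocks that manual stream handling can run into, which is presumably why the diff takes the same approach.
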
@@ -10,6 +10,39 @@
     }),
   ])
 # ---
+# name: test_get_tts_audio.1
+  list([
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize',
+    }),
+  ])
+# ---
+# name: test_get_tts_audio_format[raw]
+  list([
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize',
+    }),
+  ])
+# ---
+# name: test_get_tts_audio_format[wav]
+  list([
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize',
+    }),
+  ])
+# ---
 # name: test_get_tts_audio_raw
   list([
     dict({
@@ -21,3 +54,14 @@
     }),
   ])
 # ---
+# name: test_no_options
+  list([
+    dict({
+      'data': dict({
+        'text': 'Hello world',
+      }),
+      'payload': None,
+      'type': 'synthesize',
+    }),
+  ])
+# ---

@@ -32,6 +32,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
 
     assert entity.supported_languages == ["en-US"]
+    assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
     assert entity.default_language == "en-US"
     voices = entity.async_get_supported_voices("en-US")
     assert len(voices) == 1
     assert voices[0].name == "Test Voice"
@@ -39,6 +40,33 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
     assert not entity.async_get_supported_voices("de-DE")
 
 
+async def test_no_options(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
+    """Test options=None."""
+    audio = bytes(100)
+    audio_events = [
+        AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
+        AudioStop().event(),
+    ]
+
+    state = hass.states.get("tts.test_tts")
+    assert state is not None
+
+    entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
+    assert entity is not None
+
+    with patch(
+        "homeassistant.components.wyoming.tts.AsyncTcpClient",
+        MockAsyncTcpClient(audio_events),
+    ) as mock_client:
+        extension, data = await entity.async_get_tts_audio(
+            "Hello world", "en-US", options=None
+        )
+
+    assert extension == "mp3"
+    assert data is not None
+    assert mock_client.written == snapshot
+
+
 async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
     """Test get audio."""
     audio = bytes(100)
@@ -56,21 +84,16 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
             tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
         )
 
-    assert extension == "wav"
-    assert data is not None
-    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
-        assert wav_file.getframerate() == 16000
-        assert wav_file.getsampwidth() == 2
-        assert wav_file.getnchannels() == 1
-        assert wav_file.readframes(wav_file.getnframes()) == audio
-
-    assert mock_client.written == snapshot
+    assert extension == "mp3"
+    assert data is not None
+    assert mock_client.written == snapshot
 
 
-async def test_get_tts_audio_raw(
-    hass: HomeAssistant, init_wyoming_tts, snapshot
+@pytest.mark.parametrize("audio_format", ("wav", "raw"))
+async def test_get_tts_audio_format(
+    hass: HomeAssistant, init_wyoming_tts, snapshot, audio_format: str
 ) -> None:
-    """Test get raw audio."""
+    """Test get audio in a specific format."""
     audio = bytes(100)
     audio_events = [
         AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
@@ -88,12 +111,22 @@ async def test_get_tts_audio_raw(
                 "Hello world",
                 "tts.test_tts",
                 "en-US",
-                options={tts.ATTR_AUDIO_OUTPUT: "raw"},
+                options={tts.ATTR_AUDIO_OUTPUT: audio_format},
             ),
         )
 
-    assert extension == "raw"
-    assert data == audio
+    assert extension == audio_format
+
+    if audio_format == "raw":
+        assert data == audio
+    else:
+        # Verify WAV audio
+        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
+            assert wav_file.getframerate() == 16000
+            assert wav_file.getsampwidth() == 2
+            assert wav_file.getnchannels() == 1
+            assert wav_file.readframes(wav_file.getnframes()) == audio
+
     assert mock_client.written == snapshot
 
 
@@ -133,7 +166,5 @@ async def test_get_tts_audio_audio_oserror(
     ):
         await tts.async_get_media_source_audio(
             hass,
-            tts.generate_media_source_id(
-                hass, "Hello world", "tts.test_tts", hass.config.language
-            ),
+            tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
         )

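One note on the new tests: the mp3 paths assert only that data is not None rather than comparing bytes, presumably because encoded output is not byte-stable across ffmpeg versions. A hypothetical stricter check (not part of this PR) could still confirm the payload starts like an MP3 stream:

def looks_like_mp3(data: bytes) -> bool:
    """Heuristic check: MP3 payloads begin with an ID3 tag or an MPEG frame-sync byte."""
    return data.startswith(b"ID3") or (len(data) > 0 and data[0] == 0xFF)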