Compare commits
6 commits
dev
...
wyoming-mp
Author | SHA1 | Date | |
---|---|---|---|
|
28a9ceee3e | ||
|
c64ac285e0 | ||
|
d57d8e21f6 | ||
|
86a6b941d2 | ||
|
fa7a5c7bc1 | ||
|
3a8a8e8813 |
4 changed files with 144 additions and 22 deletions
|
@ -3,6 +3,7 @@
|
||||||
"name": "Wyoming Protocol",
|
"name": "Wyoming Protocol",
|
||||||
"codeowners": ["@balloob", "@synesthesiam"],
|
"codeowners": ["@balloob", "@synesthesiam"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
|
"dependencies": ["ffmpeg"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"requirements": ["wyoming==0.0.1"]
|
"requirements": ["wyoming==0.0.1"]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Support for Wyoming text to speech services."""
|
"""Support for Wyoming text to speech services."""
|
||||||
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
@ -8,7 +9,7 @@ from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||||
from wyoming.client import AsyncTcpClient
|
from wyoming.client import AsyncTcpClient
|
||||||
from wyoming.tts import Synthesize
|
from wyoming.tts import Synthesize
|
||||||
|
|
||||||
from homeassistant.components import tts
|
from homeassistant.components import ffmpeg, tts
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
@ -18,6 +19,7 @@ from .data import WyomingService
|
||||||
from .error import WyomingError
|
from .error import WyomingError
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
_DEFAULT_FORMAT = "mp3"
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
|
@ -27,9 +29,10 @@ async def async_setup_entry(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Wyoming speech to text."""
|
"""Set up Wyoming speech to text."""
|
||||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
|
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
WyomingTtsProvider(config_entry, service),
|
WyomingTtsProvider(config_entry, service, ffmpeg_manager),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,9 +44,11 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||||
self,
|
self,
|
||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
service: WyomingService,
|
service: WyomingService,
|
||||||
|
ffmpeg_manager: ffmpeg.FFmpegManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up provider."""
|
"""Set up provider."""
|
||||||
self.service = service
|
self.service = service
|
||||||
|
self._ffmpeg_manager = ffmpeg_manager
|
||||||
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||||
|
|
||||||
voice_languages: set[str] = set()
|
voice_languages: set[str] = set()
|
||||||
|
@ -87,7 +92,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||||
@property
|
@property
|
||||||
def default_options(self):
|
def default_options(self):
|
||||||
"""Return a dict include default options."""
|
"""Return a dict include default options."""
|
||||||
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
|
return {tts.ATTR_AUDIO_OUTPUT: _DEFAULT_FORMAT}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||||
|
@ -129,9 +134,20 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||||
except (OSError, WyomingError):
|
except (OSError, WyomingError):
|
||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
|
if options is None:
|
||||||
|
output_format = _DEFAULT_FORMAT
|
||||||
|
else:
|
||||||
|
output_format = options.get(tts.ATTR_AUDIO_OUTPUT, _DEFAULT_FORMAT)
|
||||||
|
|
||||||
|
if output_format == "wav":
|
||||||
|
# Already WAV data
|
||||||
return ("wav", data)
|
return ("wav", data)
|
||||||
|
|
||||||
|
if output_format != "raw":
|
||||||
|
# Convert with ffmpeg
|
||||||
|
converted_data = await self._convert_audio(data, output_format)
|
||||||
|
return (output_format, converted_data)
|
||||||
|
|
||||||
# Raw output (convert to 16Khz, 16-bit mono)
|
# Raw output (convert to 16Khz, 16-bit mono)
|
||||||
with io.BytesIO(data) as wav_io:
|
with io.BytesIO(data) as wav_io:
|
||||||
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
|
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
|
||||||
|
@ -153,3 +169,33 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||||
)
|
)
|
||||||
|
|
||||||
return ("raw", raw_data)
|
return ("raw", raw_data)
|
||||||
|
|
||||||
|
async def _convert_audio(self, wav_data: bytes, output_format: str) -> bytes:
|
||||||
|
"""Convert from WAV to a different format using ffmpeg asynchronously."""
|
||||||
|
ffmpeg_input = [
|
||||||
|
"-f",
|
||||||
|
"wav",
|
||||||
|
"-i",
|
||||||
|
"pipe:", # input from stdin
|
||||||
|
]
|
||||||
|
ffmpeg_output = [
|
||||||
|
"-f",
|
||||||
|
output_format,
|
||||||
|
]
|
||||||
|
|
||||||
|
if output_format == "mp3":
|
||||||
|
ffmpeg_output.extend(["-q:a", "0"]) # max quality
|
||||||
|
|
||||||
|
ffmpeg_output.append("pipe:") # output to stdout
|
||||||
|
|
||||||
|
ffmpeg_proc = await asyncio.create_subprocess_exec(
|
||||||
|
self._ffmpeg_manager.binary,
|
||||||
|
*ffmpeg_input,
|
||||||
|
*ffmpeg_output,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, _stderr = await ffmpeg_proc.communicate(input=wav_data)
|
||||||
|
return stdout
|
||||||
|
|
|
@ -10,6 +10,39 @@
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_get_tts_audio.1
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_get_tts_audio_format[raw]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_get_tts_audio_format[wav]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_get_tts_audio_raw
|
# name: test_get_tts_audio_raw
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
@ -21,3 +54,14 @@
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_no_options
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
|
|
@ -32,6 +32,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
||||||
|
|
||||||
assert entity.supported_languages == ["en-US"]
|
assert entity.supported_languages == ["en-US"]
|
||||||
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
||||||
|
assert entity.default_language == "en-US"
|
||||||
voices = entity.async_get_supported_voices("en-US")
|
voices = entity.async_get_supported_voices("en-US")
|
||||||
assert len(voices) == 1
|
assert len(voices) == 1
|
||||||
assert voices[0].name == "Test Voice"
|
assert voices[0].name == "Test Voice"
|
||||||
|
@ -39,6 +40,33 @@ async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
||||||
assert not entity.async_get_supported_voices("de-DE")
|
assert not entity.async_get_supported_voices("de-DE")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_options(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
|
||||||
|
"""Test options=None."""
|
||||||
|
audio = bytes(100)
|
||||||
|
audio_events = [
|
||||||
|
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||||
|
AudioStop().event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
state = hass.states.get("tts.test_tts")
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient(audio_events),
|
||||||
|
) as mock_client:
|
||||||
|
extension, data = await entity.async_get_tts_audio(
|
||||||
|
"Hello world", "en-US", options=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extension == "mp3"
|
||||||
|
assert data is not None
|
||||||
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
|
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
|
||||||
"""Test get audio."""
|
"""Test get audio."""
|
||||||
audio = bytes(100)
|
audio = bytes(100)
|
||||||
|
@ -56,21 +84,16 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
|
||||||
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert extension == "wav"
|
assert extension == "mp3"
|
||||||
assert data is not None
|
assert data is not None
|
||||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
|
||||||
assert wav_file.getframerate() == 16000
|
|
||||||
assert wav_file.getsampwidth() == 2
|
|
||||||
assert wav_file.getnchannels() == 1
|
|
||||||
assert wav_file.readframes(wav_file.getnframes()) == audio
|
|
||||||
|
|
||||||
assert mock_client.written == snapshot
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_get_tts_audio_raw(
|
@pytest.mark.parametrize("audio_format", ("wav", "raw"))
|
||||||
hass: HomeAssistant, init_wyoming_tts, snapshot
|
async def test_get_tts_audio_format(
|
||||||
|
hass: HomeAssistant, init_wyoming_tts, snapshot, audio_format: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test get raw audio."""
|
"""Test get audio in a specific format."""
|
||||||
audio = bytes(100)
|
audio = bytes(100)
|
||||||
audio_events = [
|
audio_events = [
|
||||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||||
|
@ -88,12 +111,22 @@ async def test_get_tts_audio_raw(
|
||||||
"Hello world",
|
"Hello world",
|
||||||
"tts.test_tts",
|
"tts.test_tts",
|
||||||
"en-US",
|
"en-US",
|
||||||
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
options={tts.ATTR_AUDIO_OUTPUT: audio_format},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert extension == "raw"
|
assert extension == audio_format
|
||||||
|
|
||||||
|
if audio_format == "raw":
|
||||||
assert data == audio
|
assert data == audio
|
||||||
|
else:
|
||||||
|
# Verify WAV audio
|
||||||
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
|
assert wav_file.getframerate() == 16000
|
||||||
|
assert wav_file.getsampwidth() == 2
|
||||||
|
assert wav_file.getnchannels() == 1
|
||||||
|
assert wav_file.readframes(wav_file.getnframes()) == audio
|
||||||
|
|
||||||
assert mock_client.written == snapshot
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,7 +166,5 @@ async def test_get_tts_audio_audio_oserror(
|
||||||
):
|
):
|
||||||
await tts.async_get_media_source_audio(
|
await tts.async_get_media_source_audio(
|
||||||
hass,
|
hass,
|
||||||
tts.generate_media_source_id(
|
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
||||||
hass, "Hello world", "tts.test_tts", hass.config.language
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue