Wyoming tts (#91712)

* Add tts entity

* Add tts entity and tests

* Re-add name to TextToSpeechEntity

* Fix linting

* Fix ruff linting

* Support voice attr (unused)

* Remove async_get_text_to_speech_entity

* Move name property to Wyoming TTS entity

* Fix id

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-04-23 13:06:56 -05:00 committed by GitHub
parent f4df0ca50a
commit b6f2b29a99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 529 additions and 58 deletions

View file

@ -52,10 +52,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# ASR = automated speech recognition (STT)
asr_installed = [asr for asr in service.info.asr if asr.installed]
if not asr_installed:
return self.async_abort(reason="no_services")
tts_installed = [tts for tts in service.info.tts if tts.installed]
if asr_installed:
name = asr_installed[0].name
elif tts_installed:
name = tts_installed[0].name
else:
return self.async_abort(reason="no_services")
return self.async_create_entry(title=name, data=user_input)

View file

@ -25,8 +25,10 @@ class WyomingService:
self.port = port
self.info = info
platforms = []
if info.asr:
if any(asr.installed for asr in info.asr):
platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS)
self.platforms = platforms
@classmethod
@ -39,14 +41,20 @@ class WyomingService:
return cls(host, port, info)
async def load_wyoming_info(host: str, port: int) -> Info | None:
async def load_wyoming_info(
host: str,
port: int,
retries: int = _INFO_RETRIES,
retry_wait: float = _INFO_RETRY_WAIT,
timeout: float = _INFO_TIMEOUT,
) -> Info | None:
"""Load info from Wyoming server."""
wyoming_info: Info | None = None
for _ in range(_INFO_RETRIES):
for _ in range(retries + 1):
try:
async with AsyncTcpClient(host, port) as client:
with async_timeout.timeout(_INFO_TIMEOUT):
with async_timeout.timeout(timeout):
# Describe -> Info
await client.write_event(Describe().event())
while True:
@ -58,9 +66,12 @@ async def load_wyoming_info(host: str, port: int) -> Info | None:
if Info.is_type(event.type):
wyoming_info = Info.from_event(event)
break
break # while
if wyoming_info is not None:
break # for
except (asyncio.TimeoutError, OSError, WyomingError):
# Sleep and try again
await asyncio.sleep(_INFO_RETRY_WAIT)
await asyncio.sleep(retry_wait)
return wyoming_info

View file

@ -0,0 +1,161 @@
"""Support for Wyoming text to speech services."""
from collections import defaultdict
import io
import logging
import wave
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize
from homeassistant.components import tts
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up Wyoming text to speech.

    Looks up the ``WyomingService`` stored in ``hass.data[DOMAIN]`` under the
    config entry's id and registers a single ``WyomingTtsProvider`` for it.

    Args:
        hass: Home Assistant instance holding the per-entry service objects.
        config_entry: Config entry this platform is being set up for.
        async_add_entities: Callback used to register the new entity.
    """
    service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
    async_add_entities(
        [
            WyomingTtsProvider(config_entry, service),
        ]
    )
class WyomingTtsProvider(tts.TextToSpeechEntity):
    """Wyoming text to speech provider.

    Represents the first installed TTS program reported by a Wyoming server
    and synthesizes speech by exchanging events with that server over TCP.
    """

    def __init__(
        self,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Set up provider.

        Args:
            config_entry: Config entry; its ``entry_id`` seeds the unique id.
            service: Wyoming service supplying host, port and server info.
        """
        self.service = service
        # The loop variable is not named ``tts`` so it cannot be confused with
        # the imported ``tts`` component module used below.
        self._tts_service = next(
            program for program in service.info.tts if program.installed
        )

        # Maps language -> installed voices that support it.
        self._voices: dict[str, list[tts.Voice]] = defaultdict(list)
        for voice in self._tts_service.voices:
            if not voice.installed:
                continue

            for language in voice.languages:
                self._voices[language].append(
                    tts.Voice(
                        voice_id=voice.name,
                        name=voice.name,
                    )
                )

        # Dict keys keep insertion order, so languages are listed in the order
        # the server reported them. Building this from a set made the first
        # element (used as the default language) depend on hash ordering.
        self._supported_languages: list[str] = list(self._voices)

        self._attr_name = self._tts_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-tts"

    @property
    def name(self) -> str | None:
        """Return the name of the provider entity."""
        # Only one entity is allowed per platform for now.
        return self._tts_service.name

    @property
    def default_language(self) -> str | None:
        """Return the first supported language, or None if there are none."""
        if not self._supported_languages:
            return None

        return self._supported_languages[0]

    @property
    def supported_languages(self) -> list[str]:
        """Return list of supported languages."""
        return self._supported_languages

    @property
    def supported_options(self) -> list[str]:
        """Return list of supported options like voice, emotion."""
        return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]

    @property
    def default_options(self) -> dict[str, str]:
        """Return a dict of default options (WAV audio output)."""
        return {tts.ATTR_AUDIO_OUTPUT: "wav"}

    @callback
    def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
        """Return a list of supported voices for a language.

        Returns None when no installed voice supports ``language``.
        """
        return self._voices.get(language)

    async def async_get_tts_audio(self, message, language, options=None):
        """Synthesize ``message`` via the Wyoming server over TCP.

        Args:
            message: Text to synthesize.
            language: Requested language; not sent to the server here.
            options: Optional dict of TTS options. Only the audio output
                option is read; a missing value is treated as "wav".

        Returns:
            ``("wav", data)`` with a complete WAV file, ``("raw", data)`` with
            16 kHz / 16-bit / mono samples when a non-"wav" output is
            requested, or ``(None, None)`` if the connection is lost, an
            ``OSError``/``WyomingError`` occurs, or no audio was received.
        """
        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                await client.write_event(Synthesize(message).event())

                with io.BytesIO() as wav_io:
                    wav_writer: wave.Wave_write | None = None
                    try:
                        while True:
                            event = await client.read_event()
                            if event is None:
                                _LOGGER.debug("Connection lost")
                                return (None, None)

                            if AudioStop.is_type(event.type):
                                break

                            if AudioChunk.is_type(event.type):
                                chunk = AudioChunk.from_event(event)
                                if wav_writer is None:
                                    # The first chunk defines the WAV format.
                                    wav_writer = wave.open(wav_io, "wb")
                                    wav_writer.setframerate(chunk.rate)
                                    wav_writer.setsampwidth(chunk.width)
                                    wav_writer.setnchannels(chunk.channels)

                                wav_writer.writeframes(chunk.audio)
                    finally:
                        # Close while the buffer is still open so the WAV
                        # header is finalized on every exit path.
                        if wav_writer is not None:
                            wav_writer.close()

                    if wav_writer is None:
                        # AudioStop arrived without any audio; an empty buffer
                        # is not a valid WAV file, so report a failure.
                        _LOGGER.debug("No audio received")
                        return (None, None)

                    data = wav_io.getvalue()

        except (OSError, WyomingError):
            return (None, None)

        # ``.get`` avoids a KeyError when options are given without an audio
        # output entry; "wav" mirrors ``default_options``.
        if (options is None) or (options.get(tts.ATTR_AUDIO_OUTPUT, "wav") == "wav"):
            return ("wav", data)

        # Raw output (convert to 16Khz, 16-bit mono)
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            source_chunk = AudioChunk(
                audio=wav_reader.readframes(wav_reader.getnframes()),
                rate=wav_reader.getframerate(),
                width=wav_reader.getsampwidth(),
                channels=wav_reader.getnchannels(),
            )

        raw_data = (
            AudioChunkConverter(
                rate=16000,
                width=2,
                channels=1,
            )
            .convert(source_chunk)
            .audio
        )

        return ("raw", raw_data)

View file

@ -1,5 +1,5 @@
"""Tests for the Wyoming integration."""
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
from wyoming.info import AsrModel, AsrProgram, Attribution, Info, TtsProgram, TtsVoice
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
STT_INFO = Info(
@ -19,4 +19,53 @@ STT_INFO = Info(
)
]
)
# Server info fixture advertising one installed TTS program ("Test TTS")
# that offers a single installed voice ("Test Voice") for en-US.
TTS_INFO = Info(
    tts=[
        TtsProgram(
            name="Test TTS",
            installed=True,
            attribution=TEST_ATTR,
            voices=[
                TtsVoice(
                    name="Test Voice",
                    installed=True,
                    attribution=TEST_ATTR,
                    languages=["en-US"],
                )
            ],
        )
    ]
)
# Info fixture built with no arguments, i.e. with no programs passed in.
EMPTY_INFO = Info()
class MockAsyncTcpClient:
    """Test double for AsyncTcpClient.

    The instance is callable with ``(host, port)`` and returns itself, so it
    can be patched in place of the client class. Written events are recorded
    in ``written``; reads are served from the front of ``responses``.
    """

    def __init__(self, responses) -> None:
        """Store the canned responses and reset the recorded state."""
        self.responses = responses
        self.written = []
        self.host = None
        self.port = None

    def __call__(self, host, port):
        """Remember the requested address and hand back this same mock."""
        self.host, self.port = host, port
        return self

    async def __aenter__(self):
        """Enter the async context, yielding the mock itself."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the async context without suppressing exceptions."""
        return None

    async def write_event(self, event):
        """Record an outgoing event."""
        self.written.append(event)

    async def read_event(self):
        """Return the next canned response, removing it from the queue."""
        next_response = self.responses.pop(0)
        return next_response

View file

@ -7,7 +7,7 @@ import pytest
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import STT_INFO
from . import STT_INFO, TTS_INFO
from tests.common import MockConfigEntry
@ -22,7 +22,7 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
def stt_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
@ -37,10 +37,35 @@ def config_entry(hass: HomeAssistant) -> ConfigEntry:
@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry):
"""Initialize Wyoming."""
def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test TTS",
)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
"""Initialize Wyoming STT."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=STT_INFO,
):
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.config_entries.async_setup(stt_config_entry.entry_id)
@pytest.fixture
async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
"""Initialize Wyoming TTS."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=TTS_INFO,
):
await hass.config_entries.async_setup(tts_config_entry.entry_id)

View file

@ -0,0 +1,11 @@
# serializer version: 1
# name: test_load_info
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'describe',
}),
])
# ---

View file

@ -0,0 +1,23 @@
# serializer version: 1
# name: test_get_tts_audio
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---
# name: test_get_tts_audio_raw
list([
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize',
}),
])
# ---

View file

@ -10,7 +10,7 @@ from homeassistant.components.wyoming.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from . import EMPTY_INFO, STT_INFO
from . import EMPTY_INFO, STT_INFO, TTS_INFO
from tests.common import MockConfigEntry
@ -26,7 +26,7 @@ ADDON_DISCOVERY = HassioServiceInfo(
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
async def test_form_stt(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
"""Test we get the form."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -56,6 +56,36 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_tts(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
"""Test we get the form."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] is None
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=TTS_INFO,
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
"host": "1.1.1.1",
"port": 1234,
},
)
await hass.async_block_till_done()
assert result2["type"] == FlowResultType.CREATE_ENTRY
assert result2["title"] == "Test TTS"
assert result2["data"] == {
"host": "1.1.1.1",
"port": 1234,
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
"""Test we handle cannot connect error."""
result = await hass.config_entries.flow.async_init(

View file

@ -0,0 +1,40 @@
"""Test tts."""
from __future__ import annotations
from unittest.mock import patch
from homeassistant.components.wyoming.data import load_wyoming_info
from homeassistant.core import HomeAssistant
from . import STT_INFO, MockAsyncTcpClient
async def test_load_info(hass: HomeAssistant, snapshot) -> None:
    """Load info through a mocked client and check the request that was sent."""
    fake_client = MockAsyncTcpClient([STT_INFO.event()])

    with patch(
        "homeassistant.components.wyoming.data.AsyncTcpClient",
        fake_client,
    ) as mock_client:
        info = await load_wyoming_info("localhost", 1234)

    # The canned Info event must round-trip unchanged.
    assert info == STT_INFO
    assert mock_client.written == snapshot
async def test_load_info_oserror(hass: HomeAssistant) -> None:
    """Loading info yields None when reading from the client raises OSError."""
    mock_client = MockAsyncTcpClient([STT_INFO.event()])
    client_patch = patch(
        "homeassistant.components.wyoming.data.AsyncTcpClient",
        mock_client,
    )
    failing_read = patch.object(
        mock_client, "read_event", side_effect=OSError("Boom!")
    )

    with client_patch, failing_read:
        # No retries and no wait keep the failing path fast.
        info = await load_wyoming_info(
            "localhost",
            1234,
            retries=0,
            retry_wait=0,
            timeout=0.001,
        )

    assert info is None

View file

@ -5,17 +5,19 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
async def test_cannot_connect(
hass: HomeAssistant, stt_config_entry: ConfigEntry
) -> None:
"""Test we handle cannot connect error."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=None,
):
assert not await hass.config_entries.async_setup(config_entry.entry_id)
assert not await hass.config_entries.async_setup(stt_config_entry.entry_id)
async def test_unload(
hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt
hass: HomeAssistant, stt_config_entry: ConfigEntry, init_wyoming_stt
) -> None:
"""Test unload."""
assert await hass.config_entries.async_unload(config_entry.entry_id)
assert await hass.config_entries.async_unload(stt_config_entry.entry_id)

View file

@ -3,50 +3,21 @@ from __future__ import annotations
from unittest.mock import patch
from wyoming.event import Event
from wyoming.asr import Transcript
from homeassistant.components import stt
from homeassistant.core import HomeAssistant
class MockAsyncTcpClient:
"""Mock AsyncTcpClient."""
def __init__(self, responses) -> None:
"""Initialize."""
self.host = None
self.port = None
self.written = []
self.responses = responses
async def write_event(self, event):
"""Send."""
self.written.append(event)
async def read_event(self):
"""Receive."""
return self.responses.pop(0)
async def __aenter__(self):
"""Enter."""
return self
async def __aexit__(self, exc_type, exc, tb):
"""Exit."""
def __call__(self, host, port):
"""Call."""
self.host = host
self.port = port
return self
from . import MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test streaming audio."""
"""Test supported properties."""
state = hass.states.get("stt.wyoming")
assert state is not None
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
assert entity is not None
assert entity.supported_languages == ["en-US"]
assert entity.supported_formats == [stt.AudioFormats.WAV]
@ -59,6 +30,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
"""Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
assert entity is not None
async def audio_stream():
yield "chunk1"
@ -66,7 +38,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
with patch(
"homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]),
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
) as mock_client:
result = await entity.async_process_audio_stream(None, audio_stream())
@ -80,6 +52,7 @@ async def test_streaming_audio_connection_lost(
) -> None:
"""Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
assert entity is not None
async def audio_stream():
yield "chunk1"
@ -97,13 +70,12 @@ async def test_streaming_audio_connection_lost(
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
assert entity is not None
async def audio_stream():
yield "chunk1"
mock_client = MockAsyncTcpClient(
[Event(type="transcript", data={"text": "Hello world"})]
)
mock_client = MockAsyncTcpClient([Transcript(text="Hello world").event()])
with patch(
"homeassistant.components.wyoming.stt.AsyncTcpClient",

View file

@ -0,0 +1,143 @@
"""Test tts."""
from __future__ import annotations
import io
from unittest.mock import patch
import wave
import pytest
from wyoming.audio import AudioChunk, AudioStop
from homeassistant.components import tts
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_component import DATA_INSTANCES
from . import MockAsyncTcpClient
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
init_cache_dir_side_effect,
mock_get_cache_files,
mock_init_cache_dir,
)
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
    """Check the languages, options and voices exposed by the TTS entity."""
    entity_id = "tts.test_tts"
    assert hass.states.get(entity_id) is not None

    entity = hass.data[DATA_INSTANCES]["tts"].get_entity(entity_id)
    assert entity is not None

    assert entity.supported_languages == ["en-US"]
    assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]

    english_voices = entity.async_get_supported_voices("en-US")
    assert len(english_voices) == 1
    only_voice = english_voices[0]
    assert only_voice.name == "Test Voice"
    assert only_voice.voice_id == "Test Voice"

    # A language without voices yields a falsy result.
    assert not entity.async_get_supported_voices("de-DE")
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
    """Synthesize with default options and verify the returned WAV file."""
    audio = bytes(100)
    fake_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        fake_client,
    ) as mock_client:
        media_source_id = tts.generate_media_source_id(
            hass, "Hello world", "tts.test_tts", hass.config.language
        )
        extension, data = await tts.async_get_media_source_audio(
            hass, media_source_id
        )

    assert extension == "wav"
    assert data is not None

    # The WAV header must reflect the format of the chunk that was sent.
    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        assert wav_file.getframerate() == 16000
        assert wav_file.getsampwidth() == 2
        assert wav_file.getnchannels() == 1
        assert wav_file.readframes(wav_file.getnframes()) == audio

    assert mock_client.written == snapshot
async def test_get_tts_audio_raw(
    hass: HomeAssistant, init_wyoming_tts, snapshot
) -> None:
    """Synthesize with raw audio output and verify the samples pass through."""
    audio = bytes(100)
    fake_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        fake_client,
    ) as mock_client:
        media_source_id = tts.generate_media_source_id(
            hass,
            "Hello world",
            "tts.test_tts",
            hass.config.language,
            options={tts.ATTR_AUDIO_OUTPUT: "raw"},
        )
        extension, data = await tts.async_get_media_source_audio(
            hass, media_source_id
        )

    assert extension == "raw"
    assert data == audio
    assert mock_client.written == snapshot
async def test_get_tts_audio_connection_lost(
    hass: HomeAssistant, init_wyoming_tts
) -> None:
    """Getting audio raises when the client read returns None."""
    client_patch = patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        MockAsyncTcpClient([None]),
    )

    with client_patch, pytest.raises(HomeAssistantError):
        media_source_id = tts.generate_media_source_id(
            hass, "Hello world", "tts.test_tts", hass.config.language
        )
        await tts.async_get_media_source_audio(hass, media_source_id)
async def test_get_tts_audio_audio_oserror(
    hass: HomeAssistant, init_wyoming_tts
) -> None:
    """Getting audio raises when reading from the client fails with OSError."""
    audio = bytes(100)
    mock_client = MockAsyncTcpClient(
        [
            AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
            AudioStop().event(),
        ]
    )
    client_patch = patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        mock_client,
    )
    failing_read = patch.object(
        mock_client, "read_event", side_effect=OSError("Boom!")
    )

    with client_patch, failing_read, pytest.raises(HomeAssistantError):
        media_source_id = tts.generate_media_source_id(
            hass, "Hello world", "tts.test_tts", hass.config.language
        )
        await tts.async_get_media_source_audio(hass, media_source_id)