Wyoming tts (#91712)
* Add tts entity * Add tts entity and tests * Re-add name to TextToSpeechEntity * Fix linting * Fix ruff linting * Support voice attr (unused) * Remove async_get_text_to_speech_entity * Move name property to Wyoming TTS entity * Fix id --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com> Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
f4df0ca50a
commit
b6f2b29a99
12 changed files with 529 additions and 58 deletions
|
@ -52,10 +52,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
# ASR = automated speech recognition (STT)
|
||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
||||
if not asr_installed:
|
||||
return self.async_abort(reason="no_services")
|
||||
tts_installed = [tts for tts in service.info.tts if tts.installed]
|
||||
|
||||
if asr_installed:
|
||||
name = asr_installed[0].name
|
||||
elif tts_installed:
|
||||
name = tts_installed[0].name
|
||||
else:
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
return self.async_create_entry(title=name, data=user_input)
|
||||
|
||||
|
|
|
@ -25,8 +25,10 @@ class WyomingService:
|
|||
self.port = port
|
||||
self.info = info
|
||||
platforms = []
|
||||
if info.asr:
|
||||
if any(asr.installed for asr in info.asr):
|
||||
platforms.append(Platform.STT)
|
||||
if any(tts.installed for tts in info.tts):
|
||||
platforms.append(Platform.TTS)
|
||||
self.platforms = platforms
|
||||
|
||||
@classmethod
|
||||
|
@ -39,14 +41,20 @@ class WyomingService:
|
|||
return cls(host, port, info)
|
||||
|
||||
|
||||
async def load_wyoming_info(host: str, port: int) -> Info | None:
|
||||
async def load_wyoming_info(
|
||||
host: str,
|
||||
port: int,
|
||||
retries: int = _INFO_RETRIES,
|
||||
retry_wait: float = _INFO_RETRY_WAIT,
|
||||
timeout: float = _INFO_TIMEOUT,
|
||||
) -> Info | None:
|
||||
"""Load info from Wyoming server."""
|
||||
wyoming_info: Info | None = None
|
||||
|
||||
for _ in range(_INFO_RETRIES):
|
||||
for _ in range(retries + 1):
|
||||
try:
|
||||
async with AsyncTcpClient(host, port) as client:
|
||||
with async_timeout.timeout(_INFO_TIMEOUT):
|
||||
with async_timeout.timeout(timeout):
|
||||
# Describe -> Info
|
||||
await client.write_event(Describe().event())
|
||||
while True:
|
||||
|
@ -58,9 +66,12 @@ async def load_wyoming_info(host: str, port: int) -> Info | None:
|
|||
|
||||
if Info.is_type(event.type):
|
||||
wyoming_info = Info.from_event(event)
|
||||
break
|
||||
break # while
|
||||
|
||||
if wyoming_info is not None:
|
||||
break # for
|
||||
except (asyncio.TimeoutError, OSError, WyomingError):
|
||||
# Sleep and try again
|
||||
await asyncio.sleep(_INFO_RETRY_WAIT)
|
||||
await asyncio.sleep(retry_wait)
|
||||
|
||||
return wyoming_info
|
||||
|
|
161
homeassistant/components/wyoming/tts.py
Normal file
161
homeassistant/components/wyoming/tts.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
"""Support for Wyoming text to speech services."""
|
||||
from collections import defaultdict
|
||||
import io
|
||||
import logging
|
||||
import wave
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.tts import Synthesize
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming speech to text."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingTtsProvider(config_entry, service),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
"""Wyoming text to speech provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.service = service
|
||||
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||
|
||||
voice_languages: set[str] = set()
|
||||
self._voices: dict[str, list[tts.Voice]] = defaultdict(list)
|
||||
for voice in self._tts_service.voices:
|
||||
if not voice.installed:
|
||||
continue
|
||||
|
||||
voice_languages.update(voice.languages)
|
||||
for language in voice.languages:
|
||||
self._voices[language].append(
|
||||
tts.Voice(
|
||||
voice_id=voice.name,
|
||||
name=voice.name,
|
||||
)
|
||||
)
|
||||
|
||||
self._supported_languages: list[str] = list(voice_languages)
|
||||
|
||||
self._attr_name = self._tts_service.name
|
||||
self._attr_unique_id = f"{config_entry.entry_id}-tts"
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
"""Return the name of the provider entity."""
|
||||
# Only one entity is allowed per platform for now.
|
||||
return self._tts_service.name
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Return default language."""
|
||||
if not self._supported_languages:
|
||||
return None
|
||||
|
||||
return self._supported_languages[0]
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""Return list of supported languages."""
|
||||
return self._supported_languages
|
||||
|
||||
@property
|
||||
def supported_options(self):
|
||||
"""Return list of supported options like voice, emotion."""
|
||||
return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
"""Return a dict include default options."""
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._voices.get(language)
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
"""Load TTS from UNIX socket."""
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
await client.write_event(Synthesize(message).event())
|
||||
|
||||
with io.BytesIO() as wav_io:
|
||||
wav_writer: wave.Wave_write | None = None
|
||||
while True:
|
||||
event = await client.read_event()
|
||||
if event is None:
|
||||
_LOGGER.debug("Connection lost")
|
||||
return (None, None)
|
||||
|
||||
if AudioStop.is_type(event.type):
|
||||
break
|
||||
|
||||
if AudioChunk.is_type(event.type):
|
||||
chunk = AudioChunk.from_event(event)
|
||||
if wav_writer is None:
|
||||
wav_writer = wave.open(wav_io, "wb")
|
||||
wav_writer.setframerate(chunk.rate)
|
||||
wav_writer.setsampwidth(chunk.width)
|
||||
wav_writer.setnchannels(chunk.channels)
|
||||
|
||||
wav_writer.writeframes(chunk.audio)
|
||||
|
||||
if wav_writer is not None:
|
||||
wav_writer.close()
|
||||
|
||||
data = wav_io.getvalue()
|
||||
|
||||
except (OSError, WyomingError):
|
||||
return (None, None)
|
||||
|
||||
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
|
||||
return ("wav", data)
|
||||
|
||||
# Raw output (convert to 16Khz, 16-bit mono)
|
||||
with io.BytesIO(data) as wav_io:
|
||||
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
|
||||
raw_data = (
|
||||
AudioChunkConverter(
|
||||
rate=16000,
|
||||
width=2,
|
||||
channels=1,
|
||||
)
|
||||
.convert(
|
||||
AudioChunk(
|
||||
audio=wav_reader.readframes(wav_reader.getnframes()),
|
||||
rate=wav_reader.getframerate(),
|
||||
width=wav_reader.getsampwidth(),
|
||||
channels=wav_reader.getnchannels(),
|
||||
)
|
||||
)
|
||||
.audio
|
||||
)
|
||||
|
||||
return ("raw", raw_data)
|
|
@ -1,5 +1,5 @@
|
|||
"""Tests for the Wyoming integration."""
|
||||
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
|
||||
from wyoming.info import AsrModel, AsrProgram, Attribution, Info, TtsProgram, TtsVoice
|
||||
|
||||
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
|
||||
STT_INFO = Info(
|
||||
|
@ -19,4 +19,53 @@ STT_INFO = Info(
|
|||
)
|
||||
]
|
||||
)
|
||||
TTS_INFO = Info(
|
||||
tts=[
|
||||
TtsProgram(
|
||||
name="Test TTS",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
voices=[
|
||||
TtsVoice(
|
||||
name="Test Voice",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
EMPTY_INFO = Info()
|
||||
|
||||
|
||||
class MockAsyncTcpClient:
|
||||
"""Mock AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses) -> None:
|
||||
"""Initialize."""
|
||||
self.host = None
|
||||
self.port = None
|
||||
self.written = []
|
||||
self.responses = responses
|
||||
|
||||
async def write_event(self, event):
|
||||
"""Send."""
|
||||
self.written.append(event)
|
||||
|
||||
async def read_event(self):
|
||||
"""Receive."""
|
||||
return self.responses.pop(0)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Exit."""
|
||||
|
||||
def __call__(self, host, port):
|
||||
"""Call."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
return self
|
||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import STT_INFO
|
||||
from . import STT_INFO, TTS_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -22,7 +22,7 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
def stt_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
|
@ -37,10 +37,35 @@ def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming."""
|
||||
def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test TTS",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming STT."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=STT_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.config_entries.async_setup(stt_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming TTS."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=TTS_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||
|
|
11
tests/components/wyoming/snapshots/test_data.ambr
Normal file
11
tests/components/wyoming/snapshots/test_data.ambr
Normal file
|
@ -0,0 +1,11 @@
|
|||
# serializer version: 1
|
||||
# name: test_load_info
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'describe',
|
||||
}),
|
||||
])
|
||||
# ---
|
23
tests/components/wyoming/snapshots/test_tts.ambr
Normal file
23
tests/components/wyoming/snapshots/test_tts.ambr
Normal file
|
@ -0,0 +1,23 @@
|
|||
# serializer version: 1
|
||||
# name: test_get_tts_audio
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_raw
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
])
|
||||
# ---
|
|
@ -10,7 +10,7 @@ from homeassistant.components.wyoming.const import DOMAIN
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from . import EMPTY_INFO, STT_INFO
|
||||
from . import EMPTY_INFO, STT_INFO, TTS_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -26,7 +26,7 @@ ADDON_DISCOVERY = HassioServiceInfo(
|
|||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||
|
||||
|
||||
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||
async def test_form_stt(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||
"""Test we get the form."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
|
@ -56,6 +56,36 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
|||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_tts(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||
"""Test we get the form."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] is None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=TTS_INFO,
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Test TTS"
|
||||
assert result2["data"] == {
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
|
||||
"""Test we handle cannot connect error."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
|
|
40
tests/components/wyoming/test_data.py
Normal file
40
tests/components/wyoming/test_data.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
"""Test tts."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.wyoming.data import load_wyoming_info
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import STT_INFO, MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_load_info(hass: HomeAssistant, snapshot) -> None:
|
||||
"""Test loading info."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
MockAsyncTcpClient([STT_INFO.event()]),
|
||||
) as mock_client:
|
||||
info = await load_wyoming_info("localhost", 1234)
|
||||
|
||||
assert info == STT_INFO
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_load_info_oserror(hass: HomeAssistant) -> None:
|
||||
"""Test loading info and error raising."""
|
||||
mock_client = MockAsyncTcpClient([STT_INFO.event()])
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||
info = await load_wyoming_info(
|
||||
"localhost",
|
||||
1234,
|
||||
retries=0,
|
||||
retry_wait=0,
|
||||
timeout=0.001,
|
||||
)
|
||||
|
||||
assert info is None
|
|
@ -5,17 +5,19 @@ from homeassistant.config_entries import ConfigEntry
|
|||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
||||
async def test_cannot_connect(
|
||||
hass: HomeAssistant, stt_config_entry: ConfigEntry
|
||||
) -> None:
|
||||
"""Test we handle cannot connect error."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=None,
|
||||
):
|
||||
assert not await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert not await hass.config_entries.async_setup(stt_config_entry.entry_id)
|
||||
|
||||
|
||||
async def test_unload(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt
|
||||
hass: HomeAssistant, stt_config_entry: ConfigEntry, init_wyoming_stt
|
||||
) -> None:
|
||||
"""Test unload."""
|
||||
assert await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert await hass.config_entries.async_unload(stt_config_entry.entry_id)
|
||||
|
|
|
@ -3,50 +3,21 @@ from __future__ import annotations
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
from wyoming.event import Event
|
||||
from wyoming.asr import Transcript
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
class MockAsyncTcpClient:
|
||||
"""Mock AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses) -> None:
|
||||
"""Initialize."""
|
||||
self.host = None
|
||||
self.port = None
|
||||
self.written = []
|
||||
self.responses = responses
|
||||
|
||||
async def write_event(self, event):
|
||||
"""Send."""
|
||||
self.written.append(event)
|
||||
|
||||
async def read_event(self):
|
||||
"""Receive."""
|
||||
return self.responses.pop(0)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Exit."""
|
||||
|
||||
def __call__(self, host, port):
|
||||
"""Call."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
return self
|
||||
from . import MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test streaming audio."""
|
||||
"""Test supported properties."""
|
||||
state = hass.states.get("stt.wyoming")
|
||||
assert state is not None
|
||||
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
assert entity is not None
|
||||
|
||||
assert entity.supported_languages == ["en-US"]
|
||||
assert entity.supported_formats == [stt.AudioFormats.WAV]
|
||||
|
@ -59,6 +30,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
|||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||
"""Test streaming audio."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
|
@ -66,7 +38,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
|
|||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]),
|
||||
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
|
||||
) as mock_client:
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
|
||||
|
@ -80,6 +52,7 @@ async def test_streaming_audio_connection_lost(
|
|||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
|
@ -97,13 +70,12 @@ async def test_streaming_audio_connection_lost(
|
|||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test streaming audio and error raising."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
|
||||
mock_client = MockAsyncTcpClient(
|
||||
[Event(type="transcript", data={"text": "Hello world"})]
|
||||
)
|
||||
mock_client = MockAsyncTcpClient([Transcript(text="Hello world").event()])
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
|
|
143
tests/components/wyoming/test_tts.py
Normal file
143
tests/components/wyoming/test_tts.py
Normal file
|
@ -0,0 +1,143 @@
|
|||
"""Test tts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from unittest.mock import patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_component import DATA_INSTANCES
|
||||
|
||||
from . import MockAsyncTcpClient
|
||||
|
||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||
init_cache_dir_side_effect,
|
||||
mock_get_cache_files,
|
||||
mock_init_cache_dir,
|
||||
)
|
||||
|
||||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
||||
"""Test supported properties."""
|
||||
state = hass.states.get("tts.test_tts")
|
||||
assert state is not None
|
||||
|
||||
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||
assert entity is not None
|
||||
|
||||
assert entity.supported_languages == ["en-US"]
|
||||
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
||||
voices = entity.async_get_supported_voices("en-US")
|
||||
assert len(voices) == 1
|
||||
assert voices[0].name == "Test Voice"
|
||||
assert voices[0].voice_id == "Test Voice"
|
||||
assert not entity.async_get_supported_voices("de-DE")
|
||||
|
||||
|
||||
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
|
||||
"""Test get audio."""
|
||||
audio = bytes(100)
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||
),
|
||||
)
|
||||
|
||||
assert extension == "wav"
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 16000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.readframes(wav_file.getnframes()) == audio
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_raw(
|
||||
hass: HomeAssistant, init_wyoming_tts, snapshot
|
||||
) -> None:
|
||||
"""Test get raw audio."""
|
||||
audio = bytes(100)
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass,
|
||||
"Hello world",
|
||||
"tts.test_tts",
|
||||
hass.config.language,
|
||||
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
||||
),
|
||||
)
|
||||
|
||||
assert extension == "raw"
|
||||
assert data == audio
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_tts
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
), pytest.raises(HomeAssistantError):
|
||||
await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_get_tts_audio_audio_oserror(
|
||||
hass: HomeAssistant, init_wyoming_tts
|
||||
) -> None:
|
||||
"""Test get audio and error raising."""
|
||||
audio = bytes(100)
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
mock_client = MockAsyncTcpClient(audio_events)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(
|
||||
mock_client, "read_event", side_effect=OSError("Boom!")
|
||||
), pytest.raises(
|
||||
HomeAssistantError
|
||||
):
|
||||
await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(
|
||||
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||
),
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue