Add wyoming integration with stt (#91579)
* Add wyoming integration with stt/tts * Forward config entry setup * Use SpeechToTextEntity * Add strings to config flow * Move connection into config flow * Add tests * On load/unload used platforms * Tweaks * Add unload test * Fix stt * Add missing file * Add test for no services * Improve coverage * Finish test coverage --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
f74103c57e
commit
85d57a046c
19 changed files with 683 additions and 0 deletions
|
@ -1369,6 +1369,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/worldclock/ @fabaff
|
||||
/homeassistant/components/ws66i/ @ssaenger
|
||||
/tests/components/ws66i/ @ssaenger
|
||||
/homeassistant/components/wyoming/ @balloob @synesthesiam
|
||||
/tests/components/wyoming/ @balloob @synesthesiam
|
||||
/homeassistant/components/xbox/ @hunterjm
|
||||
/tests/components/xbox/ @hunterjm
|
||||
/homeassistant/components/xiaomi_aqara/ @danielhiversen @syssi
|
||||
|
|
44
homeassistant/components/wyoming/__init__.py
Normal file
44
homeassistant/components/wyoming/__init__.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
"""The Wyoming integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Load Wyoming."""
|
||||
service = await WyomingService.create(entry.data["host"], entry.data["port"])
|
||||
|
||||
if service is None:
|
||||
raise ConfigEntryNotReady("Unable to connect")
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
entry,
|
||||
service.platforms,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Wyoming."""
|
||||
service: WyomingService = hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(
|
||||
entry,
|
||||
service.platforms,
|
||||
)
|
||||
if unload_ok:
|
||||
del hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
return unload_ok
|
56
homeassistant/components/wyoming/config_flow.py
Normal file
56
homeassistant/components/wyoming/config_flow.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
"""Config flow for Wyoming integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST): str,
|
||||
vol.Required(CONF_PORT): int,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Wyoming integration."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step."""
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||
)
|
||||
|
||||
service = await WyomingService.create(
|
||||
user_input[CONF_HOST],
|
||||
user_input[CONF_PORT],
|
||||
)
|
||||
|
||||
if service is None:
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=STEP_USER_DATA_SCHEMA,
|
||||
errors={"base": "cannot_connect"},
|
||||
)
|
||||
|
||||
# ASR = automated speech recognition (STT)
|
||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
||||
if not asr_installed:
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
name = asr_installed[0].name
|
||||
|
||||
return self.async_create_entry(title=name, data=user_input)
|
7
homeassistant/components/wyoming/const.py
Normal file
7
homeassistant/components/wyoming/const.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
"""Constants for the Wyoming integration."""
|
||||
|
||||
DOMAIN = "wyoming"
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
SAMPLE_WIDTH = 2
|
||||
SAMPLE_CHANNELS = 1
|
66
homeassistant/components/wyoming/data.py
Normal file
66
homeassistant/components/wyoming/data.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
"""Base class for Wyoming providers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import async_timeout
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.info import Describe, Info
|
||||
|
||||
from homeassistant.const import Platform
|
||||
|
||||
from .error import WyomingError
|
||||
|
||||
_INFO_TIMEOUT = 1
|
||||
_INFO_RETRY_WAIT = 2
|
||||
_INFO_RETRIES = 3
|
||||
|
||||
|
||||
class WyomingService:
|
||||
"""Hold info for Wyoming service."""
|
||||
|
||||
def __init__(self, host: str, port: int, info: Info) -> None:
|
||||
"""Initialize Wyoming service."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.info = info
|
||||
platforms = []
|
||||
if info.asr:
|
||||
platforms.append(Platform.STT)
|
||||
self.platforms = platforms
|
||||
|
||||
@classmethod
|
||||
async def create(cls, host: str, port: int) -> WyomingService | None:
|
||||
"""Create a Wyoming service."""
|
||||
info = await load_wyoming_info(host, port)
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
return cls(host, port, info)
|
||||
|
||||
|
||||
async def load_wyoming_info(host: str, port: int) -> Info | None:
|
||||
"""Load info from Wyoming server."""
|
||||
wyoming_info: Info | None = None
|
||||
|
||||
for _ in range(_INFO_RETRIES):
|
||||
try:
|
||||
async with AsyncTcpClient(host, port) as client:
|
||||
with async_timeout.timeout(_INFO_TIMEOUT):
|
||||
# Describe -> Info
|
||||
await client.write_event(Describe().event())
|
||||
while True:
|
||||
event = await client.read_event()
|
||||
if event is None:
|
||||
raise WyomingError(
|
||||
"Connection closed unexpectedly",
|
||||
)
|
||||
|
||||
if Info.is_type(event.type):
|
||||
wyoming_info = Info.from_event(event)
|
||||
break
|
||||
except (asyncio.TimeoutError, OSError, WyomingError):
|
||||
# Sleep and try again
|
||||
await asyncio.sleep(_INFO_RETRY_WAIT)
|
||||
|
||||
return wyoming_info
|
6
homeassistant/components/wyoming/error.py
Normal file
6
homeassistant/components/wyoming/error.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
"""Errors for the Wyoming integration."""
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
|
||||
class WyomingError(HomeAssistantError):
|
||||
"""Base class for Wyoming errors."""
|
9
homeassistant/components/wyoming/manifest.json
Normal file
9
homeassistant/components/wyoming/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "wyoming",
|
||||
"name": "Wyoming Protocol",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==0.0.1"]
|
||||
}
|
18
homeassistant/components/wyoming/strings.json
Normal file
18
homeassistant/components/wyoming/strings.json
Normal file
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"data": {
|
||||
"host": "[%key:common::config_flow::data::host%]",
|
||||
"port": "[%key:common::config_flow::data::port%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
|
||||
},
|
||||
"abort": {
|
||||
"no_services": "No services found at endpoint"
|
||||
}
|
||||
}
|
||||
}
|
129
homeassistant/components/wyoming/stt.py
Normal file
129
homeassistant/components/wyoming/stt.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
"""Support for Wyoming speech to text services."""
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming speech to text."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSttProvider(config_entry, service),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingSttProvider(stt.SpeechToTextEntity):
|
||||
"""Wyoming speech to text provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.service = service
|
||||
asr_service = service.info.asr[0]
|
||||
|
||||
model_languages: set[str] = set()
|
||||
for asr_model in asr_service.models:
|
||||
if asr_model.installed:
|
||||
model_languages.update(asr_model.languages)
|
||||
|
||||
self._supported_languages = list(model_languages)
|
||||
self._attr_name = asr_service.name
|
||||
self._attr_unique_id = f"{config_entry.entry_id}-stt"
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return self._supported_languages
|
||||
|
||||
@property
|
||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
return [stt.AudioFormats.WAV]
|
||||
|
||||
@property
|
||||
def supported_codecs(self) -> list[stt.AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
return [stt.AudioCodecs.PCM]
|
||||
|
||||
@property
|
||||
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
|
||||
"""Return a list of supported bitrates."""
|
||||
return [stt.AudioBitRates.BITRATE_16]
|
||||
|
||||
@property
|
||||
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
|
||||
"""Return a list of supported samplerates."""
|
||||
return [stt.AudioSampleRates.SAMPLERATE_16000]
|
||||
|
||||
@property
|
||||
def supported_channels(self) -> list[stt.AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
return [stt.AudioChannels.CHANNEL_MONO]
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> stt.SpeechResult:
|
||||
"""Process an audio stream to STT service."""
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
await client.write_event(
|
||||
AudioStart(
|
||||
rate=SAMPLE_RATE,
|
||||
width=SAMPLE_WIDTH,
|
||||
channels=SAMPLE_CHANNELS,
|
||||
).event(),
|
||||
)
|
||||
|
||||
async for audio_bytes in stream:
|
||||
chunk = AudioChunk(
|
||||
rate=SAMPLE_RATE,
|
||||
width=SAMPLE_WIDTH,
|
||||
channels=SAMPLE_CHANNELS,
|
||||
audio=audio_bytes,
|
||||
)
|
||||
await client.write_event(chunk.event())
|
||||
|
||||
await client.write_event(AudioStop().event())
|
||||
|
||||
while True:
|
||||
event = await client.read_event()
|
||||
if event is None:
|
||||
_LOGGER.debug("Connection lost")
|
||||
return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
|
||||
|
||||
if Transcript.is_type(event.type):
|
||||
transcript = Transcript.from_event(event)
|
||||
text = transcript.text
|
||||
break
|
||||
|
||||
except (OSError, WyomingError) as err:
|
||||
_LOGGER.exception("Error processing audio stream: %s", err)
|
||||
return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
|
||||
|
||||
return stt.SpeechResult(
|
||||
text,
|
||||
stt.SpeechResultState.SUCCESS,
|
||||
)
|
|
@ -502,6 +502,7 @@ FLOWS = {
|
|||
"wolflink",
|
||||
"workday",
|
||||
"ws66i",
|
||||
"wyoming",
|
||||
"xbox",
|
||||
"xiaomi_aqara",
|
||||
"xiaomi_ble",
|
||||
|
|
|
@ -6257,6 +6257,12 @@
|
|||
"config_flow": false,
|
||||
"iot_class": "cloud_polling"
|
||||
},
|
||||
"wyoming": {
|
||||
"name": "Wyoming Protocol",
|
||||
"integration_type": "hub",
|
||||
"config_flow": true,
|
||||
"iot_class": "local_push"
|
||||
},
|
||||
"x10": {
|
||||
"name": "Heyu X10",
|
||||
"integration_type": "hub",
|
||||
|
|
|
@ -2651,6 +2651,9 @@ wled==0.16.0
|
|||
# homeassistant.components.wolflink
|
||||
wolf_smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==0.0.1
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
||||
|
|
|
@ -1909,6 +1909,9 @@ wled==0.16.0
|
|||
# homeassistant.components.wolflink
|
||||
wolf_smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==0.0.1
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
||||
|
|
22
tests/components/wyoming/__init__.py
Normal file
22
tests/components/wyoming/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""Tests for the Wyoming integration."""
|
||||
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
|
||||
|
||||
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
|
||||
STT_INFO = Info(
|
||||
asr=[
|
||||
AsrProgram(
|
||||
name="Test ASR",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
models=[
|
||||
AsrModel(
|
||||
name="Test Model",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
EMPTY_INFO = Info()
|
46
tests/components/wyoming/conftest.py
Normal file
46
tests/components/wyoming/conftest.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
"""Common fixtures for the Wyoming tests."""
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import STT_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
||||
"""Override async_setup_entry."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.async_setup_entry", return_value=True
|
||||
) as mock_setup_entry:
|
||||
yield mock_setup_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test ASR",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=STT_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
42
tests/components/wyoming/snapshots/test_stt.ambr
Normal file
42
tests/components/wyoming/snapshots/test_stt.ambr
Normal file
|
@ -0,0 +1,42 @@
|
|||
# serializer version: 1
|
||||
# name: test_streaming_audio
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'channels': 1,
|
||||
'rate': 16000,
|
||||
'timestamp': None,
|
||||
'width': 2,
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'audio-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'channels': 1,
|
||||
'rate': 16000,
|
||||
'timestamp': None,
|
||||
'width': 2,
|
||||
}),
|
||||
'payload': 'chunk1',
|
||||
'type': 'audio-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'channels': 1,
|
||||
'rate': 16000,
|
||||
'timestamp': None,
|
||||
'width': 2,
|
||||
}),
|
||||
'payload': 'chunk2',
|
||||
'type': 'audio-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'timestamp': None,
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'audio-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
87
tests/components/wyoming/test_config_flow.py
Normal file
87
tests/components/wyoming/test_config_flow.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
"""Test the Wyoming config flow."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.wyoming.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from . import EMPTY_INFO, STT_INFO
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||
|
||||
|
||||
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||
"""Test we get the form."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["errors"] is None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=STT_INFO,
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Test ASR"
|
||||
assert result2["data"] == {
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
|
||||
"""Test we handle cannot connect error."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=None,
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
},
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": "cannot_connect"}
|
||||
|
||||
|
||||
async def test_no_supported_services(hass: HomeAssistant) -> None:
|
||||
"""Test we handle no supported services error."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=EMPTY_INFO,
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"host": "1.1.1.1",
|
||||
"port": 1234,
|
||||
},
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.ABORT
|
||||
assert result2["reason"] == "no_services"
|
21
tests/components/wyoming/test_init.py
Normal file
21
tests/components/wyoming/test_init.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
"""Test init."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
||||
"""Test we handle cannot connect error."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=None,
|
||||
):
|
||||
assert not await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
|
||||
|
||||
async def test_unload(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt
|
||||
) -> None:
|
||||
"""Test unload."""
|
||||
assert await hass.config_entries.async_unload(config_entry.entry_id)
|
115
tests/components/wyoming/test_stt.py
Normal file
115
tests/components/wyoming/test_stt.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
"""Test stt."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from wyoming.event import Event
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
class MockAsyncTcpClient:
|
||||
"""Mock AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses) -> None:
|
||||
"""Initialize."""
|
||||
self.host = None
|
||||
self.port = None
|
||||
self.written = []
|
||||
self.responses = responses
|
||||
|
||||
async def write_event(self, event):
|
||||
"""Send."""
|
||||
self.written.append(event)
|
||||
|
||||
async def read_event(self):
|
||||
"""Receive."""
|
||||
return self.responses.pop(0)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Exit."""
|
||||
|
||||
def __call__(self, host, port):
|
||||
"""Call."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
return self
|
||||
|
||||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test streaming audio."""
|
||||
state = hass.states.get("stt.wyoming")
|
||||
assert state is not None
|
||||
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
|
||||
assert entity.supported_languages == ["en-US"]
|
||||
assert entity.supported_formats == [stt.AudioFormats.WAV]
|
||||
assert entity.supported_codecs == [stt.AudioCodecs.PCM]
|
||||
assert entity.supported_bit_rates == [stt.AudioBitRates.BITRATE_16]
|
||||
assert entity.supported_sample_rates == [stt.AudioSampleRates.SAMPLERATE_16000]
|
||||
assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]
|
||||
|
||||
|
||||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||
"""Test streaming audio."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
yield "chunk2"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]),
|
||||
) as mock_client:
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.SUCCESS
|
||||
assert result.text == "Hello world"
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_streaming_audio_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_stt
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
):
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.ERROR
|
||||
assert result.text is None
|
||||
|
||||
|
||||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test streaming audio and error raising."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
|
||||
async def audio_stream():
|
||||
yield "chunk1"
|
||||
|
||||
mock_client = MockAsyncTcpClient(
|
||||
[Event(type="transcript", data={"text": "Hello world"})]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.ERROR
|
||||
assert result.text is None
|
Loading…
Add table
Add a link
Reference in a new issue