Add wyoming integration with stt (#91579)

* Add wyoming integration with stt/tts

* Forward config entry setup

* Use SpeechToTextEntity

* Add strings to config flow

* Move connection into config flow

* Add tests

* Only load/unload used platforms

* Tweaks

* Add unload test

* Fix stt

* Add missing file

* Add test for no services

* Improve coverage

* Finish test coverage

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-04-19 05:10:59 -05:00 committed by GitHub
parent f74103c57e
commit 85d57a046c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 683 additions and 0 deletions

View file

@ -1369,6 +1369,8 @@ build.json @home-assistant/supervisor
/tests/components/worldclock/ @fabaff
/homeassistant/components/ws66i/ @ssaenger
/tests/components/ws66i/ @ssaenger
/homeassistant/components/wyoming/ @balloob @synesthesiam
/tests/components/wyoming/ @balloob @synesthesiam
/homeassistant/components/xbox/ @hunterjm
/tests/components/xbox/ @hunterjm
/homeassistant/components/xiaomi_aqara/ @danielhiversen @syssi

View file

@ -0,0 +1,44 @@
"""The Wyoming integration."""
from __future__ import annotations
import logging
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from .const import DOMAIN
from .data import WyomingService
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a Wyoming config entry.

    Connects to the server stored in the entry data, keeps the resulting
    service object in ``hass.data[DOMAIN]`` under the entry id and forwards
    the entry to the platforms that the service reports.

    Raises:
        ConfigEntryNotReady: If no service could be created for the
            configured host and port.
    """
    host = entry.data["host"]
    port = entry.data["port"]

    wyoming_service = await WyomingService.create(host, port)
    if wyoming_service is None:
        raise ConfigEntryNotReady("Unable to connect")

    domain_data = hass.data.setdefault(DOMAIN, {})
    domain_data[entry.entry_id] = wyoming_service

    platforms = wyoming_service.platforms
    await hass.config_entries.async_forward_entry_setups(entry, platforms)
    return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a Wyoming config entry.

    Unloads the platforms recorded on the stored service and, only when
    that succeeds, removes the service from ``hass.data[DOMAIN]``.

    Returns:
        The result of unloading the platforms.
    """
    wyoming_service: WyomingService = hass.data[DOMAIN][entry.entry_id]
    platforms = wyoming_service.platforms

    platforms_unloaded = await hass.config_entries.async_unload_platforms(
        entry, platforms
    )
    if not platforms_unloaded:
        # Keep the stored service so a later unload attempt can still find it.
        return platforms_unloaded

    del hass.data[DOMAIN][entry.entry_id]
    return platforms_unloaded

View file

@ -0,0 +1,56 @@
"""Config flow for Wyoming integration."""
from __future__ import annotations
from typing import Any
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import CONF_HOST, CONF_PORT
from homeassistant.data_entry_flow import FlowResult
from .const import DOMAIN
from .data import WyomingService
# Form schema for the "user" step: both the host (str) and the port (int)
# are required; they are later passed to WyomingService.create().
STEP_USER_DATA_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_HOST): str,
        vol.Required(CONF_PORT): int,
    }
)
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Config flow that registers a Wyoming server by host and port."""

    VERSION = 1

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Handle the step started by the user.

        Without input the form is shown. With input, the server is
        contacted; a failed connection re-shows the form with a
        ``cannot_connect`` error, a server without any installed ASR
        program aborts the flow with ``no_services``, and otherwise an
        entry titled after the first installed ASR program is created.
        """
        if user_input is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA
            )

        host, port = user_input[CONF_HOST], user_input[CONF_PORT]
        service = await WyomingService.create(host, port)
        if service is None:
            connect_errors = {"base": "cannot_connect"}
            return self.async_show_form(
                step_id="user",
                data_schema=STEP_USER_DATA_SCHEMA,
                errors=connect_errors,
            )

        # ASR = automated speech recognition (STT)
        first_installed = next(
            (program for program in service.info.asr if program.installed),
            None,
        )
        if first_installed is None:
            return self.async_abort(reason="no_services")

        return self.async_create_entry(title=first_installed.name, data=user_input)

View file

@ -0,0 +1,7 @@
"""Constants for the Wyoming integration."""
# Integration domain; also used as the key for this integration in hass.data.
DOMAIN = "wyoming"
# Audio parameters the STT entity sends with its AudioStart/AudioChunk events.
# presumably Hz — the entity advertises SAMPLERATE_16000; confirm against wyoming docs
SAMPLE_RATE = 16000
# presumably bytes per sample (the entity advertises BITRATE_16) — TODO confirm
SAMPLE_WIDTH = 2
# Single channel; the entity advertises CHANNEL_MONO.
SAMPLE_CHANNELS = 1

View file

@ -0,0 +1,66 @@
"""Base class for Wyoming providers."""
from __future__ import annotations
import asyncio
import async_timeout
from wyoming.client import AsyncTcpClient
from wyoming.info import Describe, Info
from homeassistant.const import Platform
from .error import WyomingError
# Seconds allowed for the Describe -> Info exchange (passed to async_timeout).
_INFO_TIMEOUT = 1
# Seconds to sleep (asyncio.sleep) after a failed exchange.
_INFO_RETRY_WAIT = 2
# Number of times load_wyoming_info runs its connect-and-describe loop.
_INFO_RETRIES = 3
class WyomingService:
    """Connection details and advertised info for one Wyoming server."""

    def __init__(self, host: str, port: int, info: Info) -> None:
        """Store the server address and info and derive the platforms.

        The STT platform is listed only when the info contains at least
        one ASR entry; otherwise the platform list is empty.
        """
        self.host = host
        self.port = port
        self.info = info
        self.platforms = [Platform.STT] if info.asr else []

    @classmethod
    async def create(cls, host: str, port: int) -> WyomingService | None:
        """Build a service from the info loaded from ``host``/``port``.

        Returns:
            A new service, or ``None`` when no info could be loaded.
        """
        loaded_info = await load_wyoming_info(host, port)
        return None if loaded_info is None else cls(host, port, loaded_info)
async def load_wyoming_info(host: str, port: int) -> Info | None:
    """Load info from a Wyoming server.

    Opens a TCP connection, sends a Describe event and reads events until
    an Info event arrives. Up to ``_INFO_RETRIES`` attempts are made, each
    exchange limited to ``_INFO_TIMEOUT`` seconds, with a pause of
    ``_INFO_RETRY_WAIT`` seconds between failed attempts.

    Args:
        host: Host name or address of the server.
        port: TCP port of the server.

    Returns:
        The Info reported by the server, or ``None`` if every attempt
        timed out, failed with an ``OSError`` or lost the connection.
    """
    for attempt in range(_INFO_RETRIES):
        try:
            async with AsyncTcpClient(host, port) as client:
                # `async with` is the form documented by async_timeout; the
                # plain `with` form is deprecated in recent releases.
                async with async_timeout.timeout(_INFO_TIMEOUT):
                    # Describe -> Info
                    await client.write_event(Describe().event())
                    while True:
                        event = await client.read_event()
                        if event is None:
                            raise WyomingError(
                                "Connection closed unexpectedly",
                            )

                        if Info.is_type(event.type):
                            # Return right away: a bare `break` here would
                            # only leave the read loop and the outer loop
                            # would reconnect for every remaining retry.
                            return Info.from_event(event)
        except (asyncio.TimeoutError, OSError, WyomingError):
            # Sleep and try again, but do not delay the caller after the
            # final attempt has already failed.
            if attempt + 1 < _INFO_RETRIES:
                await asyncio.sleep(_INFO_RETRY_WAIT)

    return None

View file

@ -0,0 +1,6 @@
"""Errors for the Wyoming integration."""
from homeassistant.exceptions import HomeAssistantError
class WyomingError(HomeAssistantError):
    """Base class for Wyoming errors.

    Raised by this integration, for example when the server closes the
    connection before the expected reply arrives.
    """

View file

@ -0,0 +1,9 @@
{
"domain": "wyoming",
"name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==0.0.1"]
}

View file

@ -0,0 +1,18 @@
{
"config": {
"step": {
"user": {
"data": {
"host": "[%key:common::config_flow::data::host%]",
"port": "[%key:common::config_flow::data::port%]"
}
}
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
},
"abort": {
"no_services": "No services found at endpoint"
}
}
}

View file

@ -0,0 +1,129 @@
"""Support for Wyoming speech to text services."""
from collections.abc import AsyncIterable
import logging
from wyoming.asr import Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
from .data import WyomingService
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Create the speech-to-text entity for a Wyoming config entry.

    The service stored in ``hass.data[DOMAIN]`` for this entry is wrapped
    in a single ``WyomingSttProvider`` and handed to ``async_add_entities``.
    """
    wyoming_service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
    stt_entity = WyomingSttProvider(config_entry, wyoming_service)
    async_add_entities([stt_entity])
class WyomingSttProvider(stt.SpeechToTextEntity):
    """Wyoming speech to text provider.

    Every transcription opens a new TCP connection to the server, streams
    the audio as Wyoming audio events and waits for a transcript event.
    """

    def __init__(
        self,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Set up provider.

        Args:
            config_entry: Entry this entity belongs to; its ``entry_id``
                is used to build the unique id.
            service: Server address and info used for naming, language
                discovery and later connections.
        """
        self.service = service

        # Prefer the first *installed* ASR program so the entity agrees with
        # the config flow, which titles the entry after the first installed
        # program. Fall back to the first advertised program otherwise.
        asr_programs = service.info.asr
        asr_service = next(
            (program for program in asr_programs if program.installed),
            asr_programs[0],
        )

        model_languages: set[str] = set()
        for asr_model in asr_service.models:
            if asr_model.installed:
                model_languages.update(asr_model.languages)

        # Sorted so the advertised order does not depend on set iteration.
        self._supported_languages = sorted(model_languages)
        self._attr_name = asr_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-stt"

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return self._supported_languages

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return a list of supported formats."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return a list of supported codecs."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return a list of supported bitrates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return a list of supported channels."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream to STT service.

        Sends AudioStart, one AudioChunk per item of ``stream`` and
        AudioStop, then reads events until a transcript arrives.

        Args:
            metadata: Not used by this implementation.
            stream: Audio data to forward to the server.

        Returns:
            A SUCCESS result carrying the transcript text, or an ERROR
            result with no text when the connection is lost or an
            ``OSError``/``WyomingError`` is raised.
        """
        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                await client.write_event(
                    AudioStart(
                        rate=SAMPLE_RATE,
                        width=SAMPLE_WIDTH,
                        channels=SAMPLE_CHANNELS,
                    ).event(),
                )

                async for audio_bytes in stream:
                    chunk = AudioChunk(
                        rate=SAMPLE_RATE,
                        width=SAMPLE_WIDTH,
                        channels=SAMPLE_CHANNELS,
                        audio=audio_bytes,
                    )
                    await client.write_event(chunk.event())

                await client.write_event(AudioStop().event())

                # Events of other types are skipped until the transcript.
                while True:
                    event = await client.read_event()
                    if event is None:
                        _LOGGER.debug("Connection lost")
                        return stt.SpeechResult(None, stt.SpeechResultState.ERROR)

                    if Transcript.is_type(event.type):
                        transcript = Transcript.from_event(event)
                        text = transcript.text
                        break

        except (OSError, WyomingError) as err:
            _LOGGER.exception("Error processing audio stream: %s", err)
            return stt.SpeechResult(None, stt.SpeechResultState.ERROR)

        return stt.SpeechResult(
            text,
            stt.SpeechResultState.SUCCESS,
        )

View file

@ -502,6 +502,7 @@ FLOWS = {
"wolflink",
"workday",
"ws66i",
"wyoming",
"xbox",
"xiaomi_aqara",
"xiaomi_ble",

View file

@ -6257,6 +6257,12 @@
"config_flow": false,
"iot_class": "cloud_polling"
},
"wyoming": {
"name": "Wyoming Protocol",
"integration_type": "hub",
"config_flow": true,
"iot_class": "local_push"
},
"x10": {
"name": "Heyu X10",
"integration_type": "hub",

View file

@ -2651,6 +2651,9 @@ wled==0.16.0
# homeassistant.components.wolflink
wolf_smartset==0.1.11
# homeassistant.components.wyoming
wyoming==0.0.1
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -1909,6 +1909,9 @@ wled==0.16.0
# homeassistant.components.wolflink
wolf_smartset==0.1.11
# homeassistant.components.wyoming
wyoming==0.0.1
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -0,0 +1,22 @@
"""Tests for the Wyoming integration."""
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
# Attribution reused by the fake ASR program and its model.
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
# Info describing one installed ASR program with one installed en-US model.
STT_INFO = Info(
    asr=[
        AsrProgram(
            name="Test ASR",
            installed=True,
            attribution=TEST_ATTR,
            models=[
                AsrModel(
                    name="Test Model",
                    installed=True,
                    attribution=TEST_ATTR,
                    languages=["en-US"],
                )
            ],
        )
    ]
)
# Info built with no arguments; the config flow test uses it for the
# "no_services" abort path.
EMPTY_INFO = Info()

View file

@ -0,0 +1,46 @@
"""Common fixtures for the Wyoming tests."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import STT_INFO
from tests.common import MockConfigEntry
@pytest.fixture
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
    """Patch the integration's async_setup_entry to return True.

    Yields the mock so tests can inspect how it was called.
    """
    setup_patcher = patch(
        "homeassistant.components.wyoming.async_setup_entry", return_value=True
    )
    with setup_patcher as patched_setup:
        yield patched_setup
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
    """Return a Wyoming mock config entry that is already added to hass."""
    entry_data = {
        "host": "1.2.3.4",
        "port": 1234,
    }
    mock_entry = MockConfigEntry(domain="wyoming", data=entry_data, title="Test ASR")
    mock_entry.add_to_hass(hass)
    return mock_entry
@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry):
    """Set up the config entry while the server info lookup returns STT_INFO."""
    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=STT_INFO,
    )
    with info_patcher:
        await hass.config_entries.async_setup(config_entry.entry_id)

View file

@ -0,0 +1,42 @@
# serializer version: 1
# name: test_streaming_audio
list([
dict({
'data': dict({
'channels': 1,
'rate': 16000,
'timestamp': None,
'width': 2,
}),
'payload': None,
'type': 'audio-start',
}),
dict({
'data': dict({
'channels': 1,
'rate': 16000,
'timestamp': None,
'width': 2,
}),
'payload': 'chunk1',
'type': 'audio-chunk',
}),
dict({
'data': dict({
'channels': 1,
'rate': 16000,
'timestamp': None,
'width': 2,
}),
'payload': 'chunk2',
'type': 'audio-chunk',
}),
dict({
'data': dict({
'timestamp': None,
}),
'payload': None,
'type': 'audio-stop',
}),
])
# ---

View file

@ -0,0 +1,87 @@
"""Test the Wyoming config flow."""
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant import config_entries
from homeassistant.components.wyoming.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from . import EMPTY_INFO, STT_INFO
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
    """Test that the user step shows a form and then creates an entry."""
    form_result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert form_result["type"] == FlowResultType.FORM
    assert form_result["errors"] is None

    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=STT_INFO,
    )
    with info_patcher:
        entry_result = await hass.config_entries.flow.async_configure(
            form_result["flow_id"],
            {
                "host": "1.1.1.1",
                "port": 1234,
            },
        )
        await hass.async_block_till_done()

    assert entry_result["type"] == FlowResultType.CREATE_ENTRY
    # The title comes from the first installed ASR program in STT_INFO.
    assert entry_result["title"] == "Test ASR"
    assert entry_result["data"] == {
        "host": "1.1.1.1",
        "port": 1234,
    }
    assert len(mock_setup_entry.mock_calls) == 1
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
    """Test that a failed info lookup re-shows the form with an error."""
    form_result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=None,
    )
    with info_patcher:
        error_result = await hass.config_entries.flow.async_configure(
            form_result["flow_id"],
            {
                "host": "1.1.1.1",
                "port": 1234,
            },
        )

    assert error_result["type"] == FlowResultType.FORM
    assert error_result["errors"] == {"base": "cannot_connect"}
async def test_no_supported_services(hass: HomeAssistant) -> None:
    """Test that the flow aborts when the server reports no services."""
    form_result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=EMPTY_INFO,
    )
    with info_patcher:
        abort_result = await hass.config_entries.flow.async_configure(
            form_result["flow_id"],
            {
                "host": "1.1.1.1",
                "port": 1234,
            },
        )

    assert abort_result["type"] == FlowResultType.ABORT
    assert abort_result["reason"] == "no_services"

View file

@ -0,0 +1,21 @@
"""Test init."""
from unittest.mock import patch
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
    """Test that entry setup fails when no server info can be loaded."""
    info_patcher = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=None,
    )
    with info_patcher:
        setup_ok = await hass.config_entries.async_setup(config_entry.entry_id)
        assert not setup_ok
async def test_unload(
    hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt
) -> None:
    """Test that a set-up entry can be unloaded."""
    unloaded = await hass.config_entries.async_unload(config_entry.entry_id)
    assert unloaded

View file

@ -0,0 +1,115 @@
"""Test stt."""
from __future__ import annotations
from unittest.mock import patch
from wyoming.event import Event
from homeassistant.components import stt
from homeassistant.core import HomeAssistant
class MockAsyncTcpClient:
    """Stand-in for AsyncTcpClient that records writes and replays responses.

    The instance is callable so it can replace the client class itself:
    calling it with a host and port stores both and returns the instance.
    """

    def __init__(self, responses) -> None:
        """Keep the canned responses and start with nothing written."""
        self.responses = responses
        self.written = []
        self.host = None
        self.port = None

    def __call__(self, host, port):
        """Record the requested address and hand back this same mock."""
        self.host, self.port = host, port
        return self

    async def __aenter__(self):
        """Enter the context; the mock acts as its own client."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the context without suppressing any exception."""
        return None

    async def write_event(self, event):
        """Remember an event that was sent."""
        self.written.append(event)

    async def read_event(self):
        """Return the next canned response, removing it from the list."""
        next_response = self.responses.pop(0)
        return next_response
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
    """Test the languages and audio settings reported by the STT entity."""
    assert hass.states.get("stt.wyoming") is not None

    stt_entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")

    assert stt_entity.supported_languages == ["en-US"]
    assert stt_entity.supported_formats == [stt.AudioFormats.WAV]
    assert stt_entity.supported_codecs == [stt.AudioCodecs.PCM]
    assert stt_entity.supported_bit_rates == [stt.AudioBitRates.BITRATE_16]
    assert stt_entity.supported_sample_rates == [
        stt.AudioSampleRates.SAMPLERATE_16000
    ]
    assert stt_entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
    """Test that streamed audio yields the transcript and expected events."""
    stt_entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")

    async def audio_stream():
        for piece in ("chunk1", "chunk2"):
            yield piece

    fake_client = MockAsyncTcpClient(
        [Event(type="transcript", data={"text": "Hello world"})]
    )
    with patch(
        "homeassistant.components.wyoming.stt.AsyncTcpClient",
        fake_client,
    ):
        speech = await stt_entity.async_process_audio_stream(None, audio_stream())

    assert speech.result == stt.SpeechResultState.SUCCESS
    assert speech.text == "Hello world"
    # Everything the entity wrote to the client must match the snapshot.
    assert fake_client.written == snapshot
async def test_streaming_audio_connection_lost(
    hass: HomeAssistant, init_wyoming_stt
) -> None:
    """Test that a None event from the client produces an error result."""
    stt_entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")

    async def audio_stream():
        yield "chunk1"

    # A single None response makes read_event report a closed connection.
    fake_client = MockAsyncTcpClient([None])
    with patch(
        "homeassistant.components.wyoming.stt.AsyncTcpClient",
        fake_client,
    ):
        speech = await stt_entity.async_process_audio_stream(None, audio_stream())

    assert speech.result == stt.SpeechResultState.ERROR
    assert speech.text is None
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
    """Test that an OSError while reading produces an error result."""
    stt_entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")

    async def audio_stream():
        yield "chunk1"

    fake_client = MockAsyncTcpClient(
        [Event(type="transcript", data={"text": "Hello world"})]
    )

    with patch(
        "homeassistant.components.wyoming.stt.AsyncTcpClient",
        fake_client,
    ):
        # Reading fails before the canned transcript can be returned.
        with patch.object(fake_client, "read_event", side_effect=OSError("Boom!")):
            speech = await stt_entity.async_process_audio_stream(
                None, audio_stream()
            )

    assert speech.result == stt.SpeechResultState.ERROR
    assert speech.text is None