diff --git a/CODEOWNERS b/CODEOWNERS index eb5a560f063..672d83c6956 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1369,6 +1369,8 @@ build.json @home-assistant/supervisor /tests/components/worldclock/ @fabaff /homeassistant/components/ws66i/ @ssaenger /tests/components/ws66i/ @ssaenger +/homeassistant/components/wyoming/ @balloob @synesthesiam +/tests/components/wyoming/ @balloob @synesthesiam /homeassistant/components/xbox/ @hunterjm /tests/components/xbox/ @hunterjm /homeassistant/components/xiaomi_aqara/ @danielhiversen @syssi diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py new file mode 100644 index 00000000000..8676365212a --- /dev/null +++ b/homeassistant/components/wyoming/__init__.py @@ -0,0 +1,44 @@ +"""The Wyoming integration.""" +from __future__ import annotations + +import logging + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryNotReady + +from .const import DOMAIN +from .data import WyomingService + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Load Wyoming.""" + service = await WyomingService.create(entry.data["host"], entry.data["port"]) + + if service is None: + raise ConfigEntryNotReady("Unable to connect") + + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service + + await hass.config_entries.async_forward_entry_setups( + entry, + service.platforms, + ) + + return True + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload Wyoming.""" + service: WyomingService = hass.data[DOMAIN][entry.entry_id] + + unload_ok = await hass.config_entries.async_unload_platforms( + entry, + service.platforms, + ) + if unload_ok: + del hass.data[DOMAIN][entry.entry_id] + + return unload_ok diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py new file mode 100644 index 00000000000..2788f30aeef --- /dev/null +++ b/homeassistant/components/wyoming/config_flow.py @@ -0,0 +1,56 @@ +"""Config flow for Wyoming integration.""" +from __future__ import annotations + +from typing import Any + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.const import CONF_HOST, CONF_PORT +from homeassistant.data_entry_flow import FlowResult + +from .const import DOMAIN +from .data import WyomingService + +STEP_USER_DATA_SCHEMA = vol.Schema( + { + vol.Required(CONF_HOST): str, + vol.Required(CONF_PORT): int, + } +) + + +class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): + """Handle a config flow for Wyoming integration.""" + + VERSION = 1 + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the initial step.""" + if user_input is None: + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA + ) + + service = await WyomingService.create( + user_input[CONF_HOST], + user_input[CONF_PORT], + ) + + if service is None: + return self.async_show_form( + step_id="user", + data_schema=STEP_USER_DATA_SCHEMA, + errors={"base": "cannot_connect"}, + ) + + # ASR = automated speech recognition (STT) + asr_installed = [asr for asr in service.info.asr if asr.installed] + if not asr_installed: + return self.async_abort(reason="no_services") + + name = asr_installed[0].name + + return self.async_create_entry(title=name, data=user_input) diff --git a/homeassistant/components/wyoming/const.py b/homeassistant/components/wyoming/const.py new file mode 100644 index 00000000000..26443cc11eb --- /dev/null +++ b/homeassistant/components/wyoming/const.py @@ -0,0 +1,7 @@ +"""Constants for the Wyoming integration.""" + +DOMAIN = "wyoming" + +SAMPLE_RATE = 16000 +SAMPLE_WIDTH = 2 +SAMPLE_CHANNELS = 1 diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py new file mode 100644 index 00000000000..f5f869d8e68 --- /dev/null +++ b/homeassistant/components/wyoming/data.py @@ -0,0 +1,66 @@ +"""Base class for Wyoming providers.""" +from __future__ import annotations + +import asyncio + +import async_timeout +from wyoming.client import AsyncTcpClient +from wyoming.info import Describe, Info + +from homeassistant.const import Platform + +from .error import WyomingError + +_INFO_TIMEOUT = 1 +_INFO_RETRY_WAIT = 2 +_INFO_RETRIES = 3 + + +class WyomingService: + """Hold info for Wyoming service.""" + + def __init__(self, host: str, port: int, info: Info) -> None: + """Initialize Wyoming service.""" + self.host = host + self.port = port + self.info = info + platforms = [] + if info.asr: + platforms.append(Platform.STT) + self.platforms = platforms + + @classmethod + async def create(cls, host: str, port: int) -> WyomingService | None: + """Create a Wyoming service.""" + info = await load_wyoming_info(host, port) + if info is None: + return None + + return cls(host, port, info) + + +async def load_wyoming_info(host: str, port: int) -> Info | None: + """Load info from Wyoming server.""" + wyoming_info: Info | None = None + + for _ in range(_INFO_RETRIES): + try: + async with AsyncTcpClient(host, port) as client: + with async_timeout.timeout(_INFO_TIMEOUT): + # Describe -> Info + await client.write_event(Describe().event()) + while True: + event = await client.read_event() + if event is None: + raise WyomingError( + "Connection closed unexpectedly", + ) + + if Info.is_type(event.type): + wyoming_info = Info.from_event(event) + break + except (asyncio.TimeoutError, OSError, WyomingError): + # Sleep and try again + await asyncio.sleep(_INFO_RETRY_WAIT) + + return wyoming_info diff --git a/homeassistant/components/wyoming/error.py b/homeassistant/components/wyoming/error.py new file mode 100644 index 00000000000..40b2e70ce69 --- /dev/null +++ b/homeassistant/components/wyoming/error.py @@ -0,0 +1,6 @@ +"""Errors for the Wyoming integration.""" +from homeassistant.exceptions import HomeAssistantError + + +class WyomingError(HomeAssistantError): + """Base class for Wyoming errors.""" diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json new file mode 100644 index 00000000000..9ad8092bb8c --- /dev/null +++ b/homeassistant/components/wyoming/manifest.json @@ -0,0 +1,9 @@ +{ + "domain": "wyoming", + "name": "Wyoming Protocol", + "codeowners": ["@balloob", "@synesthesiam"], + "config_flow": true, + "documentation": "https://www.home-assistant.io/integrations/wyoming", + "iot_class": "local_push", + "requirements": ["wyoming==0.0.1"] +} diff --git a/homeassistant/components/wyoming/strings.json b/homeassistant/components/wyoming/strings.json new file mode 100644 index 00000000000..76f6b837b80 --- /dev/null +++ b/homeassistant/components/wyoming/strings.json @@ -0,0 +1,18 @@ +{ + "config": { + "step": { + "user": { + "data": { + "host": "[%key:common::config_flow::data::host%]", + "port": "[%key:common::config_flow::data::port%]" + } + } + }, + "error": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + }, + "abort": { + "no_services": "No services found at endpoint" + } + } +} diff --git a/homeassistant/components/wyoming/stt.py b/homeassistant/components/wyoming/stt.py new file mode 100644 index 00000000000..8d3f6534502 --- /dev/null +++ b/homeassistant/components/wyoming/stt.py @@ -0,0 +1,129 @@ +"""Support for Wyoming speech to text services.""" +from collections.abc import AsyncIterable +import logging + +from wyoming.asr import Transcript +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.client import AsyncTcpClient + +from homeassistant.components import stt +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH +from .data import WyomingService +from .error import WyomingError + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming speech to text.""" + service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + async_add_entities( + [ + WyomingSttProvider(config_entry, service), + ] + ) + + +class WyomingSttProvider(stt.SpeechToTextEntity): + """Wyoming speech to text provider.""" + + def __init__( + self, + config_entry: ConfigEntry, + service: WyomingService, + ) -> None: + """Set up provider.""" + self.service = service + asr_service = service.info.asr[0] + + model_languages: set[str] = set() + for asr_model in asr_service.models: + if asr_model.installed: + model_languages.update(asr_model.languages) + + self._supported_languages = list(model_languages) + self._attr_name = asr_service.name + self._attr_unique_id = f"{config_entry.entry_id}-stt" + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return self._supported_languages + + @property + def supported_formats(self) -> list[stt.AudioFormats]: + """Return a list of supported formats.""" + return [stt.AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[stt.AudioCodecs]: + """Return a list of supported codecs.""" + return [stt.AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> list[stt.AudioBitRates]: + """Return a list of supported bitrates.""" + return [stt.AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[stt.AudioSampleRates]: + """Return a list of supported samplerates.""" + return [stt.AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[stt.AudioChannels]: + """Return a list of supported channels.""" + return [stt.AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes] + ) -> stt.SpeechResult: + """Process an audio stream to STT service.""" + try: + async with AsyncTcpClient(self.service.host, self.service.port) as client: + await client.write_event( + AudioStart( + rate=SAMPLE_RATE, + width=SAMPLE_WIDTH, + channels=SAMPLE_CHANNELS, + ).event(), + ) + + async for audio_bytes in stream: + chunk = AudioChunk( + rate=SAMPLE_RATE, + width=SAMPLE_WIDTH, + channels=SAMPLE_CHANNELS, + audio=audio_bytes, + ) + await client.write_event(chunk.event()) + + await client.write_event(AudioStop().event()) + + while True: + event = await client.read_event() + if event is None: + _LOGGER.debug("Connection lost") + return stt.SpeechResult(None, stt.SpeechResultState.ERROR) + + if Transcript.is_type(event.type): + transcript = Transcript.from_event(event) + text = transcript.text + break + + except (OSError, WyomingError) as err: + _LOGGER.exception("Error processing audio stream: %s", err) + return stt.SpeechResult(None, stt.SpeechResultState.ERROR) + + return stt.SpeechResult( + text, + stt.SpeechResultState.SUCCESS, + ) diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index 2d07ac554cb..074fc1ee858 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -502,6 +502,7 @@ FLOWS = { "wolflink", "workday", "ws66i", + "wyoming", "xbox", "xiaomi_aqara", "xiaomi_ble", diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index cb2d1076181..a8216f21db1 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -6257,6 +6257,12 @@ "config_flow": false, "iot_class": "cloud_polling" }, + "wyoming": { + "name": "Wyoming Protocol", + "integration_type": "hub", + "config_flow": true, + "iot_class": "local_push" + }, "x10": { "name": "Heyu X10", "integration_type": "hub", diff --git a/requirements_all.txt b/requirements_all.txt index 773d9e9ec53..fcdfc4d9675 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2651,6 +2651,9 @@ wled==0.16.0 # homeassistant.components.wolflink wolf_smartset==0.1.11 +# homeassistant.components.wyoming +wyoming==0.0.1 + # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 56589414ee2..0f09d164f77 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1909,6 +1909,9 @@ wled==0.16.0 # homeassistant.components.wolflink wolf_smartset==0.1.11 +# homeassistant.components.wyoming +wyoming==0.0.1 + # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py new file mode 100644 index 00000000000..5df845bb63a --- /dev/null +++ b/tests/components/wyoming/__init__.py @@ -0,0 +1,22 @@ +"""Tests for the Wyoming integration.""" +from wyoming.info import AsrModel, AsrProgram, Attribution, Info + +TEST_ATTR = Attribution(name="Test", url="http://www.test.com") +STT_INFO = Info( + asr=[ + AsrProgram( + name="Test ASR", + installed=True, + attribution=TEST_ATTR, + models=[ + AsrModel( + name="Test Model", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + ) + ], + ) + ] +) +EMPTY_INFO = Info() diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py new file mode 100644 index 00000000000..a3c83901453 --- /dev/null +++ b/tests/components/wyoming/conftest.py @@ -0,0 +1,46 @@ +"""Common fixtures for the Wyoming tests.""" +from collections.abc import Generator +from unittest.mock import AsyncMock, patch + +import pytest + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + +from . import STT_INFO + +from tests.common import MockConfigEntry + + +@pytest.fixture +def mock_setup_entry() -> Generator[AsyncMock, None, None]: + """Override async_setup_entry.""" + with patch( + "homeassistant.components.wyoming.async_setup_entry", return_value=True + ) as mock_setup_entry: + yield mock_setup_entry + + +@pytest.fixture +def config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test ASR", + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry): + """Initialize Wyoming.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=STT_INFO, + ): + await hass.config_entries.async_setup(config_entry.entry_id) diff --git a/tests/components/wyoming/snapshots/test_stt.ambr b/tests/components/wyoming/snapshots/test_stt.ambr new file mode 100644 index 00000000000..08fe6a1ef8e --- /dev/null +++ b/tests/components/wyoming/snapshots/test_stt.ambr @@ -0,0 +1,42 @@ +# serializer version: 1 +# name: test_streaming_audio + list([ + dict({ + 'data': dict({ + 'channels': 1, + 'rate': 16000, + 'timestamp': None, + 'width': 2, + }), + 'payload': None, + 'type': 'audio-start', + }), + dict({ + 'data': dict({ + 'channels': 1, + 'rate': 16000, + 'timestamp': None, + 'width': 2, + }), + 'payload': 'chunk1', + 'type': 'audio-chunk', + }), + dict({ + 'data': dict({ + 'channels': 1, + 'rate': 16000, + 'timestamp': None, + 'width': 2, + }), + 'payload': 'chunk2', + 'type': 'audio-chunk', + }), + dict({ + 'data': dict({ + 'timestamp': None, + }), + 'payload': None, + 'type': 'audio-stop', + }), + ]) +# --- diff --git a/tests/components/wyoming/test_config_flow.py b/tests/components/wyoming/test_config_flow.py new file mode 100644 index 00000000000..8a0bf4955e7 --- /dev/null +++ b/tests/components/wyoming/test_config_flow.py @@ -0,0 +1,87 @@ +"""Test the Wyoming config flow.""" +from unittest.mock import AsyncMock, patch + +import pytest + +from homeassistant import config_entries +from homeassistant.components.wyoming.const import DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + +from . import EMPTY_INFO, STT_INFO + +pytestmark = pytest.mark.usefixtures("mock_setup_entry") + + +async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: + """Test we get the form.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=STT_INFO, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "host": "1.1.1.1", + "port": 1234, + }, + ) + await hass.async_block_till_done() + + assert result2["type"] == FlowResultType.CREATE_ENTRY + assert result2["title"] == "Test ASR" + assert result2["data"] == { + "host": "1.1.1.1", + "port": 1234, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_form_cannot_connect(hass: HomeAssistant) -> None: + """Test we handle cannot connect error.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=None, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "host": "1.1.1.1", + "port": 1234, + }, + ) + + assert result2["type"] == FlowResultType.FORM + assert result2["errors"] == {"base": "cannot_connect"} + + +async def test_no_supported_services(hass: HomeAssistant) -> None: + """Test we handle no supported services error.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=EMPTY_INFO, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "host": "1.1.1.1", + "port": 1234, + }, + ) + + assert result2["type"] == FlowResultType.ABORT + assert result2["reason"] == "no_services" diff --git a/tests/components/wyoming/test_init.py b/tests/components/wyoming/test_init.py new file mode 100644 index 00000000000..1a8b89d9b5e --- /dev/null +++ b/tests/components/wyoming/test_init.py @@ -0,0 +1,21 @@ +"""Test init.""" +from unittest.mock import patch + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + + +async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None: + """Test we handle cannot connect error.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=None, + ): + assert not await hass.config_entries.async_setup(config_entry.entry_id) + + +async def test_unload( + hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt +) -> None: + """Test unload.""" + assert await hass.config_entries.async_unload(config_entry.entry_id) diff --git a/tests/components/wyoming/test_stt.py b/tests/components/wyoming/test_stt.py new file mode 100644 index 00000000000..1f73426e9f9 --- /dev/null +++ b/tests/components/wyoming/test_stt.py @@ -0,0 +1,115 @@ +"""Test stt.""" +from __future__ import annotations + +from unittest.mock import patch + +from wyoming.event import Event + +from homeassistant.components import stt +from homeassistant.core import HomeAssistant + + +class MockAsyncTcpClient: + """Mock AsyncTcpClient.""" + + def __init__(self, responses) -> None: + """Initialize.""" + self.host = None + self.port = None + self.written = [] + self.responses = responses + + async def write_event(self, event): + """Send.""" + self.written.append(event) + + async def read_event(self): + """Receive.""" + return self.responses.pop(0) + + async def __aenter__(self): + """Enter.""" + return self + + async def __aexit__(self, exc_type, exc, tb): + """Exit.""" + + def __call__(self, host, port): + """Call.""" + self.host = host + self.port = port + return self + + +async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: + """Test streaming audio.""" + state = hass.states.get("stt.wyoming") + assert state is not None + + entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + + assert entity.supported_languages == ["en-US"] + assert entity.supported_formats == [stt.AudioFormats.WAV] + assert entity.supported_codecs == [stt.AudioCodecs.PCM] + assert entity.supported_bit_rates == [stt.AudioBitRates.BITRATE_16] + assert entity.supported_sample_rates == [stt.AudioSampleRates.SAMPLERATE_16000] + assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO] + + +async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: + """Test streaming audio.""" + entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + + async def audio_stream(): + yield "chunk1" + yield "chunk2" + + with patch( + "homeassistant.components.wyoming.stt.AsyncTcpClient", + MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]), + ) as mock_client: + result = await entity.async_process_audio_stream(None, audio_stream()) + + assert result.result == stt.SpeechResultState.SUCCESS + assert result.text == "Hello world" + assert mock_client.written == snapshot + + +async def test_streaming_audio_connection_lost( + hass: HomeAssistant, init_wyoming_stt +) -> None: + """Test streaming audio and losing connection.""" + entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + + async def audio_stream(): + yield "chunk1" + + with patch( + "homeassistant.components.wyoming.stt.AsyncTcpClient", + MockAsyncTcpClient([None]), + ): + result = await entity.async_process_audio_stream(None, audio_stream()) + + assert result.result == stt.SpeechResultState.ERROR + assert result.text is None + + +async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: + """Test streaming audio and error raising.""" + entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + + async def audio_stream(): + yield "chunk1" + + mock_client = MockAsyncTcpClient( + [Event(type="transcript", data={"text": "Hello world"})] + ) + + with patch( + "homeassistant.components.wyoming.stt.AsyncTcpClient", + mock_client, + ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): + result = await entity.async_process_audio_stream(None, audio_stream()) + + assert result.result == stt.SpeechResultState.ERROR + assert result.text is None