From 535fb342077310482edfc5056141ac2de3e0855f Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Tue, 4 Apr 2023 14:52:36 +0200 Subject: [PATCH] Move legacy stt (#90776) * Move legacy stt to separate module * Remove case for None as provider * Add error log for unknown platform * Add some tests --- homeassistant/components/stt/__init__.py | 180 ++++------------------- homeassistant/components/stt/legacy.py | 169 +++++++++++++++++++++ tests/components/stt/common.py | 56 +++++++ tests/components/stt/test_init.py | 19 ++- tests/components/stt/test_legacy.py | 56 +++++++ 5 files changed, 322 insertions(+), 158 deletions(-) create mode 100644 homeassistant/components/stt/legacy.py create mode 100644 tests/components/stt/common.py create mode 100644 tests/components/stt/test_legacy.py diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index b858cc743a2..50eac1dfeb7 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -1,11 +1,8 @@ """Provide functionality to STT.""" from __future__ import annotations -from abc import ABC, abstractmethod import asyncio -from collections.abc import AsyncIterable -from dataclasses import asdict, dataclass -import logging +from dataclasses import asdict from typing import Any from aiohttp import web @@ -17,10 +14,8 @@ from aiohttp.web_exceptions import ( ) from homeassistant.components.http import HomeAssistantView -from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_per_platform, discovery +from homeassistant.core import HomeAssistant from homeassistant.helpers.typing import ConfigType -from homeassistant.setup import async_prepare_setup_platform from .const import ( DOMAIN, @@ -31,159 +26,40 @@ from .const import ( AudioSampleRates, SpeechResultState, ) +from .legacy import ( + Provider, + SpeechMetadata, + SpeechResult, + async_get_provider, + async_setup_legacy, +) -_LOGGER = logging.getLogger(__name__) - - -@callback -def async_get_provider( - hass: HomeAssistant, domain: str | None = None -) -> Provider | None: - """Return provider.""" - if domain: - return hass.data[DOMAIN].get(domain) - - if not hass.data[DOMAIN]: - return None - - if "cloud" in hass.data[DOMAIN]: - return hass.data[DOMAIN]["cloud"] - - return next(iter(hass.data[DOMAIN].values())) +__all__ = [ + "async_get_provider", + "AudioBitRates", + "AudioChannels", + "AudioCodecs", + "AudioFormats", + "AudioSampleRates", + "DOMAIN", + "Provider", + "SpeechMetadata", + "SpeechResult", + "SpeechResultState", +] async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" - providers = hass.data[DOMAIN] = {} + platform_setups = async_setup_legacy(hass, config) - async def async_setup_platform(p_type, p_config=None, discovery_info=None): - """Set up a TTS platform.""" - if p_config is None: - p_config = {} + if platform_setups: + await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups]) - platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) - if platform is None: - return - - try: - provider = await platform.async_get_engine(hass, p_config, discovery_info) - if provider is None: - _LOGGER.error("Error setting up platform %s", p_type) - return - - provider.name = p_type - provider.hass = hass - - providers[provider.name] = provider - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error setting up platform: %s", p_type) - return - - setup_tasks = [ - asyncio.create_task(async_setup_platform(p_type, p_config)) - for p_type, p_config in config_per_platform(config, DOMAIN) - ] - - if setup_tasks: - await asyncio.wait(setup_tasks) - - # Add discovery support - async def async_platform_discovered(platform, info): - """Handle for discovered platform.""" - await async_setup_platform(platform, discovery_info=info) - - discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) - - hass.http.register_view(SpeechToTextView(providers)) + hass.http.register_view(SpeechToTextView(hass.data[DOMAIN])) return True -@dataclass -class SpeechMetadata: - """Metadata of audio stream.""" - - language: str - format: AudioFormats - codec: AudioCodecs - bit_rate: AudioBitRates - sample_rate: AudioSampleRates - channel: AudioChannels - - def __post_init__(self) -> None: - """Finish initializing the metadata.""" - self.bit_rate = AudioBitRates(int(self.bit_rate)) - self.sample_rate = AudioSampleRates(int(self.sample_rate)) - self.channel = AudioChannels(int(self.channel)) - - -@dataclass -class SpeechResult: - """Result of audio Speech.""" - - text: str | None - result: SpeechResultState - - -class Provider(ABC): - """Represent a single STT provider.""" - - hass: HomeAssistant | None = None - name: str | None = None - - @property - @abstractmethod - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - - @property - @abstractmethod - def supported_formats(self) -> list[AudioFormats]: - """Return a list of supported formats.""" - - @property - @abstractmethod - def supported_codecs(self) -> list[AudioCodecs]: - """Return a list of supported codecs.""" - - @property - @abstractmethod - def supported_bit_rates(self) -> list[AudioBitRates]: - """Return a list of supported bit rates.""" - - @property - @abstractmethod - def supported_sample_rates(self) -> list[AudioSampleRates]: - """Return a list of supported sample rates.""" - - @property - @abstractmethod - def supported_channels(self) -> list[AudioChannels]: - """Return a list of supported channels.""" - - @abstractmethod - async def async_process_audio_stream( - self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] - ) -> SpeechResult: - """Process an audio stream to STT service. - - Only streaming of content are allow! - """ - - @callback - def check_metadata(self, metadata: SpeechMetadata) -> bool: - """Check if given metadata supported by this provider.""" - if ( - metadata.language not in self.supported_languages - or metadata.format not in self.supported_formats - or metadata.codec not in self.supported_codecs - or metadata.bit_rate not in self.supported_bit_rates - or metadata.sample_rate not in self.supported_sample_rates - or metadata.channel not in self.supported_channels - ): - return False - return True - - class SpeechToTextView(HomeAssistantView): """STT view to generate a text from audio stream.""" @@ -203,7 +79,7 @@ class SpeechToTextView(HomeAssistantView): # Get metadata try: - metadata = metadata_from_header(request) + metadata = _metadata_from_header(request) except ValueError as err: raise HTTPBadRequest(text=str(err)) from err @@ -237,7 +113,7 @@ class SpeechToTextView(HomeAssistantView): ) -def metadata_from_header(request: web.Request) -> SpeechMetadata: +def _metadata_from_header(request: web.Request) -> SpeechMetadata: """Extract STT metadata from header. X-Speech-Content: diff --git a/homeassistant/components/stt/legacy.py b/homeassistant/components/stt/legacy.py new file mode 100644 index 00000000000..2f826e0be9e --- /dev/null +++ b/homeassistant/components/stt/legacy.py @@ -0,0 +1,169 @@ +"""Handle legacy speech to text platforms.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterable, Coroutine +from dataclasses import dataclass +import logging +from typing import Any + +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import config_per_platform, discovery +from homeassistant.helpers.typing import ConfigType +from homeassistant.setup import async_prepare_setup_platform + +from .const import ( + DOMAIN, + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + SpeechResultState, +) + +_LOGGER = logging.getLogger(__name__) + + +@callback +def async_get_provider( + hass: HomeAssistant, domain: str | None = None +) -> Provider | None: + """Return provider.""" + if domain: + return hass.data[DOMAIN].get(domain) + + if not hass.data[DOMAIN]: + return None + + if "cloud" in hass.data[DOMAIN]: + return hass.data[DOMAIN]["cloud"] + + return next(iter(hass.data[DOMAIN].values())) + + +@callback +def async_setup_legacy( + hass: HomeAssistant, config: ConfigType +) -> list[Coroutine[Any, Any, None]]: + """Set up legacy speech to text providers.""" + providers = hass.data[DOMAIN] = {} + + async def async_setup_platform(p_type, p_config=None, discovery_info=None): + """Set up a TTS platform.""" + if p_config is None: + p_config = {} + + platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) + if platform is None: + _LOGGER.error("Unknown speech to text platform specified") + return + + try: + provider = await platform.async_get_engine(hass, p_config, discovery_info) + + provider.name = p_type + provider.hass = hass + + providers[provider.name] = provider + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error setting up platform: %s", p_type) + return + + # Add discovery support + async def async_platform_discovered(platform, info): + """Handle for discovered platform.""" + await async_setup_platform(platform, discovery_info=info) + + discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) + + return [ + async_setup_platform(p_type, p_config) + for p_type, p_config in config_per_platform(config, DOMAIN) + ] + + +@dataclass +class SpeechMetadata: + """Metadata of audio stream.""" + + language: str + format: AudioFormats + codec: AudioCodecs + bit_rate: AudioBitRates + sample_rate: AudioSampleRates + channel: AudioChannels + + def __post_init__(self) -> None: + """Finish initializing the metadata.""" + self.bit_rate = AudioBitRates(int(self.bit_rate)) + self.sample_rate = AudioSampleRates(int(self.sample_rate)) + self.channel = AudioChannels(int(self.channel)) + + +@dataclass +class SpeechResult: + """Result of audio Speech.""" + + text: str | None + result: SpeechResultState + + +class Provider(ABC): + """Represent a single STT provider.""" + + hass: HomeAssistant | None = None + name: str | None = None + + @property + @abstractmethod + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + + @property + @abstractmethod + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + + @property + @abstractmethod + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + + @property + @abstractmethod + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bit rates.""" + + @property + @abstractmethod + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported sample rates.""" + + @property + @abstractmethod + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + + @abstractmethod + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream to STT service. + + Only streaming of content are allow! + """ + + @callback + def check_metadata(self, metadata: SpeechMetadata) -> bool: + """Check if given metadata supported by this provider.""" + if ( + metadata.language not in self.supported_languages + or metadata.format not in self.supported_formats + or metadata.codec not in self.supported_codecs + or metadata.bit_rate not in self.supported_bit_rates + or metadata.sample_rate not in self.supported_sample_rates + or metadata.channel not in self.supported_channels + ): + return False + return True diff --git a/tests/components/stt/common.py b/tests/components/stt/common.py new file mode 100644 index 00000000000..0fe4d5b80d1 --- /dev/null +++ b/tests/components/stt/common.py @@ -0,0 +1,56 @@ +"""Provide common test tools for STT.""" +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from pathlib import Path +from typing import Any + +from homeassistant.components.stt import Provider +from homeassistant.core import HomeAssistant +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType + +from tests.common import MockPlatform, mock_platform + + +class MockSTTPlatform(MockPlatform): + """Help to set up test stt service.""" + + def __init__( + self, + async_get_engine: Callable[ + [HomeAssistant, ConfigType, DiscoveryInfoType | None], + Coroutine[Any, Any, Provider | None], + ] + | None = None, + get_engine: Callable[ + [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None + ] + | None = None, + ) -> None: + """Return the stt service.""" + super().__init__() + if get_engine: + self.get_engine = get_engine + if async_get_engine: + self.async_get_engine = async_get_engine + + +def mock_stt_platform( + hass: HomeAssistant, + tmp_path: Path, + integration: str = "stt", + async_get_engine: Callable[ + [HomeAssistant, ConfigType, DiscoveryInfoType | None], + Coroutine[Any, Any, Provider | None], + ] + | None = None, + get_engine: Callable[ + [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None + ] + | None = None, +): + """Specialize the mock platform for stt.""" + loaded_platform = MockSTTPlatform(async_get_engine, get_engine) + mock_platform(hass, f"{integration}.stt", loaded_platform) + + return loaded_platform diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 3d20dbc5403..e021d73c290 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,7 +1,8 @@ """Test STT component setup.""" from collections.abc import AsyncIterable from http import HTTPStatus -from unittest.mock import AsyncMock, Mock +from pathlib import Path +from unittest.mock import AsyncMock import pytest @@ -20,7 +21,8 @@ from homeassistant.components.stt import ( from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from tests.common import mock_platform +from .common import mock_stt_platform + from tests.typing import ClientSessionGenerator @@ -31,7 +33,7 @@ class MockProvider(Provider): def __init__(self) -> None: """Init test provider.""" - self.calls = [] + self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = [] @property def supported_languages(self) -> list[str]: @@ -81,10 +83,15 @@ def mock_provider() -> MockProvider: @pytest.fixture(autouse=True) -async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def mock_setup( + hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider +) -> None: """Set up a test provider.""" - mock_platform( - hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider)) + mock_stt_platform( + hass, + tmp_path, + "test", + async_get_engine=AsyncMock(return_value=mock_provider), ) assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}}) diff --git a/tests/components/stt/test_legacy.py b/tests/components/stt/test_legacy.py new file mode 100644 index 00000000000..a95a1f0f6f4 --- /dev/null +++ b/tests/components/stt/test_legacy.py @@ -0,0 +1,56 @@ +"""Test the legacy stt setup.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from homeassistant.components.stt import Provider +from homeassistant.core import HomeAssistant +from homeassistant.helpers.discovery import async_load_platform +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType + +from .common import mock_stt_platform + + +async def test_invalid_platform( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path +) -> None: + """Test platform setup with an invalid platform.""" + await async_load_platform( + hass, + "stt", + "bad_stt", + {"stt": [{"platform": "bad_stt"}]}, + hass_config={"stt": [{"platform": "bad_stt"}]}, + ) + await hass.async_block_till_done() + + assert "Unknown speech to text platform specified" in caplog.text + + +async def test_platform_setup_with_error( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path +) -> None: + """Test platform setup with an error during setup.""" + + async def async_get_engine( + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> Provider: + """Raise exception during platform setup.""" + raise Exception("Setup error") # pylint: disable=broad-exception-raised + + mock_stt_platform(hass, tmp_path, "bad_stt", async_get_engine=async_get_engine) + + await async_load_platform( + hass, + "stt", + "bad_stt", + {}, + hass_config={"stt": [{"platform": "bad_stt"}]}, + ) + await hass.async_block_till_done() + + assert "Error setting up platform: bad_stt" in caplog.text