diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index 70618ab38ef..5abc5524a88 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -16,6 +16,8 @@ from homeassistant.components.stt import ( SpeechResult, SpeechResultState, ) +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant from .const import DOMAIN @@ -43,7 +45,10 @@ SUPPORT_LANGUAGES = [ ] -async def async_get_engine(hass, config, discovery_info=None): +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, +) -> Provider: """Set up Cloud speech component.""" cloud: Cloud = hass.data[DOMAIN] diff --git a/homeassistant/components/demo/stt.py b/homeassistant/components/demo/stt.py index 9c3cf89d80e..1d2c61d0a53 100644 --- a/homeassistant/components/demo/stt.py +++ b/homeassistant/components/demo/stt.py @@ -14,16 +14,15 @@ from homeassistant.components.stt import ( SpeechResult, SpeechResultState, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType SUPPORT_LANGUAGES = ["en", "de"] -async def async_get_engine( +async def async_setup_entry( hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, + config_entry: ConfigEntry, ) -> Provider: """Set up Demo speech component.""" return DemoProvider() diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 94e08d25363..084c3627e26 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import asyncio from dataclasses import asdict, dataclass import logging from typing import Any @@ -16,10 +15,10 @@ from aiohttp.web_exceptions import ( ) from homeassistant.components.http import HomeAssistantView +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_per_platform, discovery +from homeassistant.helpers import engine_component from homeassistant.helpers.typing import ConfigType -from homeassistant.setup import async_prepare_setup_platform from .const import ( DOMAIN, @@ -35,60 +34,45 @@ _LOGGER = logging.getLogger(__name__) @callback -def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider: +def async_get_provider( + hass: HomeAssistant, provider: str | None = None +) -> Provider | None: """Return provider.""" - if domain is None: - domain = next(iter(hass.data[DOMAIN])) + component: engine_component.EngineComponent[Provider] | None = hass.data.get(DOMAIN) - return hass.data[DOMAIN][domain] + if component is None: + return None + + if provider is None: + providers = component.async_get_engines() + return providers[0] if providers else None + + return component.async_get_engine(provider) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" - providers = hass.data[DOMAIN] = {} - - async def async_setup_platform(p_type, p_config=None, discovery_info=None): - """Set up a TTS platform.""" - if p_config is None: - p_config = {} - - platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) - if platform is None: - return - - try: - provider = await platform.async_get_engine(hass, p_config, discovery_info) - if provider is None: - _LOGGER.error("Error setting up platform %s", p_type) - return - - provider.name = p_type - provider.hass = hass - - providers[provider.name] = provider - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error setting up platform: %s", p_type) - return - - setup_tasks = [ - asyncio.create_task(async_setup_platform(p_type, p_config)) - for p_type, p_config in config_per_platform(config, DOMAIN) - ] - - if setup_tasks: - await asyncio.wait(setup_tasks) - - # Add discovery support - async def async_platform_discovered(platform, info): - """Handle for discovered platform.""" - await async_setup_platform(platform, discovery_info=info) - - discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) - - hass.http.register_view(SpeechToTextView(providers)) + engines: engine_component.EngineComponent[ + Provider + ] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config) + engines.async_setup_discovery() + hass.data[DOMAIN] = engines + hass.http.register_view(SpeechToTextView(engines)) return True +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN] + return await component.async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN] + return await component.async_unload_entry(entry) + + @dataclass class SpeechMetadata: """Metadata of audio stream.""" @@ -115,7 +99,7 @@ class SpeechResult: result: SpeechResultState -class Provider(ABC): +class Provider(engine_component.Engine, ABC): """Represent a single STT provider.""" hass: HomeAssistant | None = None @@ -182,15 +166,15 @@ class SpeechToTextView(HomeAssistantView): url = "/api/stt/{provider}" name = "api:stt:provider" - def __init__(self, providers: dict[str, Provider]) -> None: + def __init__(self, engines: engine_component.EngineComponent[Provider]) -> None: """Initialize a tts view.""" - self.providers = providers + self.engines = engines async def post(self, request: web.Request, provider: str) -> web.Response: """Convert Speech (audio) to text.""" - if provider not in self.providers: + stt_provider = self.engines.async_get_engine(provider) + if stt_provider is None: raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] # Get metadata try: @@ -212,9 +196,9 @@ class SpeechToTextView(HomeAssistantView): async def get(self, request: web.Request, provider: str) -> web.Response: """Return provider specific audio information.""" - if provider not in self.providers: + stt_provider = self.engines.async_get_engine(provider) + if stt_provider is None: raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] return self.json( { diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py new file mode 100644 index 00000000000..e058b7dfc41 --- /dev/null +++ b/homeassistant/helpers/engine_component.py @@ -0,0 +1,168 @@ +"""Engine component helper.""" +from __future__ import annotations + +from collections.abc import Awaitable, Callable +import logging +from types import ModuleType +from typing import Generic, Protocol, TypeVar + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.setup import async_prepare_setup_platform + +from . import discovery +from .typing import ConfigType, DiscoveryInfoType + + +class Engine: + """Base class for Home Assistant engines.""" + + async def async_internal_added_to_hass(self) -> None: + """Run when engine about to be added to Home Assistant. + + Not to be extended by integrations. + """ + + async def async_added_to_hass(self) -> None: + """Run when engine about to be added to Home Assistant.""" + + async def async_internal_will_remove_from_hass(self) -> None: + """Prepare to remove the engine from Home Assistant. + + Not to be extended by integrations. + """ + + async def async_will_remove_from_hass(self) -> None: + """Prepare to remove the engine from Home Assistant.""" + + +_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True) + + +class EnginePlatformModule(Protocol[_EngineT_co]): + """Protocol type for engine platform modules.""" + + async def async_setup_entry( + self, + hass: HomeAssistant, + entry: ConfigEntry, + ) -> _EngineT_co: + """Set up an integration platform from a config entry.""" + + async def async_setup_platform( + self, + hass: HomeAssistant, + ) -> _EngineT_co: + """Set up an integration platform async.""" + + +class EngineComponent(Generic[_EngineT_co]): + """Track engines for a component.""" + + def __init__( + self, + logger: logging.Logger, + domain: str, + hass: HomeAssistant, + config: ConfigType, + ) -> None: + """Initialize the engine component.""" + self.logger = logger + self.domain = domain + self.hass = hass + self.config = config + self._engines: dict[str, _EngineT_co] = {} + + @callback + def async_get_engine(self, config_entry_id: str) -> _EngineT_co | None: + """Return a wrapped engine.""" + return self._engines.get(config_entry_id) + + @callback + def async_get_engines(self) -> list[_EngineT_co]: + """Return a wrapped engine.""" + return list(self._engines.values()) + + @callback + def async_setup_discovery(self) -> None: + """Initialize the engine component discovery.""" + + async def async_platform_discovered( + platform: str, info: DiscoveryInfoType | None + ) -> None: + """Handle for discovered platform.""" + await self.async_setup_domain(platform) + + discovery.async_listen_platform( + self.hass, self.domain, async_platform_discovered + ) + + async def async_setup_domain(self, domain: str) -> bool: + """Set up an integration.""" + + async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co: + return await platform.async_setup_platform(self.hass) + + return await self._async_do_setup(domain, domain, setup) + + async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: + """Set up a config entry.""" + + async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co: + return await platform.async_setup_entry(self.hass, config_entry) + + return await self._async_do_setup( + config_entry.entry_id, config_entry.domain, setup + ) + + async def _async_do_setup( + self, + key: str, + platform_domain: str, + get_setup_coro: Callable[[ModuleType], Awaitable[_EngineT_co]], + ) -> bool: + """Set up an entry.""" + platform = await async_prepare_setup_platform( + self.hass, self.config, self.domain, platform_domain + ) + + if platform is None: + return False + + if key in self._engines: + raise ValueError("Config entry has already been setup!") + + try: + engine = await get_setup_coro(platform) + await engine.async_internal_added_to_hass() + await engine.async_added_to_hass() + except Exception: # pylint: disable=broad-except + self.logger.exception( + "Error getting engine for %s (%s)", key, platform_domain + ) + return False + + self._engines[key] = engine + return True + + async def async_unload_domain(self, domain: str) -> bool: + """Unload a domain.""" + return await self._async_do_unload(domain) + + async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: + """Unload a config entry.""" + return await self._async_do_unload(config_entry.entry_id) + + async def _async_do_unload(self, key: str) -> bool: + """Unload an engine.""" + if (engine := self._engines.pop(key, None)) is None: + raise ValueError("Config entry was never loaded!") + + try: + await engine.async_internal_will_remove_from_hass() + await engine.async_will_remove_from_hass() + except Exception: # pylint: disable=broad-except + self.logger.exception("Error unloading entry %s", key) + return False + + return True diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index e36b8af3f6c..bb34b68bba2 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -18,9 +18,8 @@ from homeassistant.components.stt import ( async_get_provider, ) from homeassistant.core import HomeAssistant -from homeassistant.setup import async_setup_component -from tests.common import mock_platform +from tests.common import MockConfigEntry, mock_platform from tests.typing import ClientSessionGenerator @@ -80,21 +79,35 @@ def mock_provider() -> MockProvider: return MockProvider() +@pytest.fixture +def mock_config_entry(hass) -> MockConfigEntry: + """Mock config entry.""" + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + return config_entry + + @pytest.fixture(autouse=True) -async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def mock_setup( + hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry +) -> None: """Set up a test provider.""" mock_platform( - hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider)) + hass, + "test.stt", + Mock(async_setup_entry=AsyncMock(return_value=mock_provider)), ) - assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}}) + assert await hass.config_entries.async_forward_entry_setup(mock_config_entry, "stt") async def test_get_provider_info( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, ) -> None: """Test engine that doesn't exist.""" client = await hass_client() - response = await client.get("/api/stt/test") + response = await client.get(f"/api/stt/{mock_config_entry.entry_id}") assert response.status == HTTPStatus.OK assert await response.json() == { "languages": ["en"], @@ -116,12 +129,14 @@ async def test_get_non_existing_provider_info( async def test_stream_audio( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, ) -> None: """Test streaming audio and getting response.""" client = await hass_client() response = await client.post( - "/api/stt/test", + f"/api/stt/{mock_config_entry.entry_id}", headers={ "X-Speech-Content": ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" @@ -155,6 +170,7 @@ async def test_stream_audio( async def test_metadata_errors( hass: HomeAssistant, hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, header: str | None, status: int, error: str, @@ -165,11 +181,15 @@ async def test_metadata_errors( if header: headers["X-Speech-Content"] = header - response = await client.post("/api/stt/test", headers=headers) + response = await client.post( + f"/api/stt/{mock_config_entry.entry_id}", headers=headers + ) assert response.status == status assert await response.text() == error -async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def test_get_provider( + hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry +) -> None: """Test we can get STT providers.""" - assert mock_provider == async_get_provider(hass, "test") + assert mock_provider == async_get_provider(hass, mock_config_entry.entry_id)