From aeccd525f54d17ada32b3c8e22743f2c2e0f9340 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 13 Mar 2023 23:56:24 -0400 Subject: [PATCH 1/5] Introduce engine component --- homeassistant/components/cloud/stt.py | 7 +- homeassistant/components/demo/stt.py | 7 +- homeassistant/components/stt/__init__.py | 90 +++++++++-------------- homeassistant/helpers/engine.py | 26 +++++++ homeassistant/helpers/engine_component.py | 86 ++++++++++++++++++++++ homeassistant/helpers/engine_platform.py | 65 ++++++++++++++++ tests/components/stt/test_init.py | 44 ++++++++--- 7 files changed, 253 insertions(+), 72 deletions(-) create mode 100644 homeassistant/helpers/engine.py create mode 100644 homeassistant/helpers/engine_component.py create mode 100644 homeassistant/helpers/engine_platform.py diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index 70618ab38ef..5abc5524a88 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -16,6 +16,8 @@ from homeassistant.components.stt import ( SpeechResult, SpeechResultState, ) +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant from .const import DOMAIN @@ -43,7 +45,10 @@ SUPPORT_LANGUAGES = [ ] -async def async_get_engine(hass, config, discovery_info=None): +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, +) -> Provider: """Set up Cloud speech component.""" cloud: Cloud = hass.data[DOMAIN] diff --git a/homeassistant/components/demo/stt.py b/homeassistant/components/demo/stt.py index 9c3cf89d80e..1d2c61d0a53 100644 --- a/homeassistant/components/demo/stt.py +++ b/homeassistant/components/demo/stt.py @@ -14,16 +14,15 @@ from homeassistant.components.stt import ( SpeechResult, SpeechResultState, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType SUPPORT_LANGUAGES = ["en", "de"] -async def async_get_engine( +async def async_setup_entry( hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, + config_entry: ConfigEntry, ) -> Provider: """Set up Demo speech component.""" return DemoProvider() diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 94e08d25363..caa094f88e3 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import asyncio from dataclasses import asdict, dataclass import logging from typing import Any @@ -16,10 +15,10 @@ from aiohttp.web_exceptions import ( ) from homeassistant.components.http import HomeAssistantView +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_per_platform, discovery +from homeassistant.helpers import engine, engine_component from homeassistant.helpers.typing import ConfigType -from homeassistant.setup import async_prepare_setup_platform from .const import ( DOMAIN, @@ -35,60 +34,41 @@ _LOGGER = logging.getLogger(__name__) @callback -def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider: +def async_get_provider( + hass: HomeAssistant, provider: str | None = None +) -> Provider | None: """Return provider.""" - if domain is None: - domain = next(iter(hass.data[DOMAIN])) + component: engine_component.EngineComponent[Provider] | None = hass.data.get(DOMAIN) - return hass.data[DOMAIN][domain] + if component is None: + return None + + if provider is None: + providers = component.async_get_engines() + return providers[0] if providers else None + + return component.async_get_engine(provider) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" - providers = hass.data[DOMAIN] = {} - - async def async_setup_platform(p_type, p_config=None, discovery_info=None): - """Set up a TTS platform.""" - if p_config is None: - p_config = {} - - platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) - if platform is None: - return - - try: - provider = await platform.async_get_engine(hass, p_config, discovery_info) - if provider is None: - _LOGGER.error("Error setting up platform %s", p_type) - return - - provider.name = p_type - provider.hass = hass - - providers[provider.name] = provider - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error setting up platform: %s", p_type) - return - - setup_tasks = [ - asyncio.create_task(async_setup_platform(p_type, p_config)) - for p_type, p_config in config_per_platform(config, DOMAIN) - ] - - if setup_tasks: - await asyncio.wait(setup_tasks) - - # Add discovery support - async def async_platform_discovered(platform, info): - """Handle for discovered platform.""" - await async_setup_platform(platform, discovery_info=info) - - discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) - - hass.http.register_view(SpeechToTextView(providers)) + hass.data[DOMAIN] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config) + hass.http.register_view(SpeechToTextView(hass.data[DOMAIN])) return True +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN] + return await component.async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN] + return await component.async_unload_entry(entry) + + @dataclass class SpeechMetadata: """Metadata of audio stream.""" @@ -115,7 +95,7 @@ class SpeechResult: result: SpeechResultState -class Provider(ABC): +class Provider(engine.Engine, ABC): """Represent a single STT provider.""" hass: HomeAssistant | None = None @@ -182,15 +162,15 @@ class SpeechToTextView(HomeAssistantView): url = "/api/stt/{provider}" name = "api:stt:provider" - def __init__(self, providers: dict[str, Provider]) -> None: + def __init__(self, engines: engine_component.EngineComponent[Provider]) -> None: """Initialize a tts view.""" - self.providers = providers + self.engines = engines async def post(self, request: web.Request, provider: str) -> web.Response: """Convert Speech (audio) to text.""" - if provider not in self.providers: + stt_provider = self.engines.async_get_engine(provider) + if stt_provider is None: raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] # Get metadata try: @@ -212,9 +192,9 @@ class SpeechToTextView(HomeAssistantView): async def get(self, request: web.Request, provider: str) -> web.Response: """Return provider specific audio information.""" - if provider not in self.providers: + stt_provider = self.engines.async_get_engine(provider) + if stt_provider is None: raise HTTPNotFound() - stt_provider: Provider = self.providers[provider] return self.json( { diff --git a/homeassistant/helpers/engine.py b/homeassistant/helpers/engine.py new file mode 100644 index 00000000000..99a7e9661a7 --- /dev/null +++ b/homeassistant/helpers/engine.py @@ -0,0 +1,26 @@ +"""Base class for Home Assistant engines.""" + + +class Engine: + """Base class for Home Assistant engines.""" + + async def async_internal_added_to_hass(self) -> None: + """Run when service about to be added to hass. + + Not to be extended by integrations. + """ + + async def async_added_to_hass(self) -> None: + """Run when service about to be added to hass. + + Not to be extended by integrations. + """ + + async def async_internal_will_remove_from_hass(self) -> None: + """Prepare to remove the service from Home Assistant. + + Not to be extended by integrations. + """ + + async def async_will_remove_from_hass(self) -> None: + """Prepare to remove the service from Home Assistant.""" diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py new file mode 100644 index 00000000000..75873a9237a --- /dev/null +++ b/homeassistant/helpers/engine_component.py @@ -0,0 +1,86 @@ +"""Service platform helper.""" +from __future__ import annotations + +import logging +from typing import Generic, TypeVar + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.setup import async_prepare_setup_platform + +from .engine import Engine +from .engine_platform import EnginePlatform +from .typing import ConfigType + +_EngineT = TypeVar("_EngineT", bound=Engine) + + +class EngineComponent(Generic[_EngineT]): + """Track engines for a component.""" + + def __init__( + self, + logger: logging.Logger, + domain: str, + hass: HomeAssistant, + config: ConfigType, + ) -> None: + """Initialize the engine component.""" + self.logger = logger + self.domain = domain + self.hass = hass + self.config = config + self._platforms: dict[str, EnginePlatform[_EngineT]] = {} + + @callback + def async_get_engine(self, config_entry_id: str) -> _EngineT | None: + """Return a wrapped engine.""" + platform = self._platforms.get(config_entry_id) + return None if platform is None else platform.engine + + @callback + def async_get_engines(self) -> list[_EngineT]: + """Return a wrapped engine.""" + return [ + platform.engine + for platform in self._platforms.values() + if platform.engine is not None + ] + + async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: + """Set up a config entry.""" + platform_type = config_entry.domain + platform = await async_prepare_setup_platform( + self.hass, + # In future PR we should make hass_config part of the constructor + # params. + self.config or {}, + self.domain, + platform_type, + ) + + if platform is None: + return False + + key = config_entry.entry_id + + if key in self._platforms: + raise ValueError("Config entry has already been setup!") + + self._platforms[key] = EnginePlatform( + self.logger, + self.hass, + config_entry, + platform, + ) + + return await self._platforms[key].async_setup_entry() + + async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: + """Unload a config entry.""" + key = config_entry.entry_id + + if (platform := self._platforms.pop(key, None)) is None: + raise ValueError("Config entry was never loaded!") + + return await platform.async_unload_entry() diff --git a/homeassistant/helpers/engine_platform.py b/homeassistant/helpers/engine_platform.py new file mode 100644 index 00000000000..95397213dd5 --- /dev/null +++ b/homeassistant/helpers/engine_platform.py @@ -0,0 +1,65 @@ +"""Service platform helper.""" +import logging +from typing import Generic, Protocol, TypeVar + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + +from .engine import Engine + +_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True) + + +class EnginePlatformModule(Protocol[_EngineT_co]): + """Protocol type for engine platform modules.""" + + async def async_setup_entry( + self, + hass: HomeAssistant, + entry: ConfigEntry, + ) -> _EngineT_co: + """Set up an integration platform from a config entry.""" + + +class EnginePlatform(Generic[_EngineT_co]): + """Track engines for a platform.""" + + def __init__( + self, + logger: logging.Logger, + hass: HomeAssistant, + config_entry: ConfigEntry, + platform: EnginePlatformModule, + ) -> None: + """Initialize the engine platform.""" + self.logger = logger + self.hass = hass + self.config_entry = config_entry + self.platform = platform + self.engine: _EngineT_co | None = None + + async def async_setup_entry(self) -> bool: + """Set up a config entry.""" + try: + engine = await self.platform.async_setup_entry(self.hass, self.config_entry) + except Exception: # pylint: disable=broad-except + self.logger.exception( + "Error setting up entry %s", self.config_entry.entry_id + ) + return False + + await engine.async_internal_added_to_hass() + await engine.async_added_to_hass() + + self.engine = engine + return True + + async def async_unload_entry(self) -> bool: + """Unload a config entry.""" + if self.engine is None: + return True + + await self.engine.async_internal_will_remove_from_hass() + await self.engine.async_will_remove_from_hass() + self.engine = None + return True diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index e36b8af3f6c..bb34b68bba2 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -18,9 +18,8 @@ from homeassistant.components.stt import ( async_get_provider, ) from homeassistant.core import HomeAssistant -from homeassistant.setup import async_setup_component -from tests.common import mock_platform +from tests.common import MockConfigEntry, mock_platform from tests.typing import ClientSessionGenerator @@ -80,21 +79,35 @@ def mock_provider() -> MockProvider: return MockProvider() +@pytest.fixture +def mock_config_entry(hass) -> MockConfigEntry: + """Mock config entry.""" + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + return config_entry + + @pytest.fixture(autouse=True) -async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def mock_setup( + hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry +) -> None: """Set up a test provider.""" mock_platform( - hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider)) + hass, + "test.stt", + Mock(async_setup_entry=AsyncMock(return_value=mock_provider)), ) - assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}}) + assert await hass.config_entries.async_forward_entry_setup(mock_config_entry, "stt") async def test_get_provider_info( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, ) -> None: """Test engine that doesn't exist.""" client = await hass_client() - response = await client.get("/api/stt/test") + response = await client.get(f"/api/stt/{mock_config_entry.entry_id}") assert response.status == HTTPStatus.OK assert await response.json() == { "languages": ["en"], @@ -116,12 +129,14 @@ async def test_get_non_existing_provider_info( async def test_stream_audio( - hass: HomeAssistant, hass_client: ClientSessionGenerator + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, ) -> None: """Test streaming audio and getting response.""" client = await hass_client() response = await client.post( - "/api/stt/test", + f"/api/stt/{mock_config_entry.entry_id}", headers={ "X-Speech-Content": ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" @@ -155,6 +170,7 @@ async def test_stream_audio( async def test_metadata_errors( hass: HomeAssistant, hass_client: ClientSessionGenerator, + mock_config_entry: MockConfigEntry, header: str | None, status: int, error: str, @@ -165,11 +181,15 @@ async def test_metadata_errors( if header: headers["X-Speech-Content"] = header - response = await client.post("/api/stt/test", headers=headers) + response = await client.post( + f"/api/stt/{mock_config_entry.entry_id}", headers=headers + ) assert response.status == status assert await response.text() == error -async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None: +async def test_get_provider( + hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry +) -> None: """Test we can get STT providers.""" - assert mock_provider == async_get_provider(hass, "test") + assert mock_provider == async_get_provider(hass, mock_config_entry.entry_id) From 98c292eb561ffa350e1fd1596da1773bfa16da37 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 14 Mar 2023 21:32:34 -0400 Subject: [PATCH 2/5] Simplify engine component --- homeassistant/components/stt/__init__.py | 4 +- homeassistant/helpers/engine.py | 26 ------- homeassistant/helpers/engine_component.py | 91 ++++++++++++++++------- homeassistant/helpers/engine_platform.py | 65 ---------------- 4 files changed, 68 insertions(+), 118 deletions(-) delete mode 100644 homeassistant/helpers/engine.py delete mode 100644 homeassistant/helpers/engine_platform.py diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index caa094f88e3..cb1c646ef79 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -17,7 +17,7 @@ from aiohttp.web_exceptions import ( from homeassistant.components.http import HomeAssistantView from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import engine, engine_component +from homeassistant.helpers import engine_component from homeassistant.helpers.typing import ConfigType from .const import ( @@ -95,7 +95,7 @@ class SpeechResult: result: SpeechResultState -class Provider(engine.Engine, ABC): +class Provider(engine_component.Engine, ABC): """Represent a single STT provider.""" hass: HomeAssistant | None = None diff --git a/homeassistant/helpers/engine.py b/homeassistant/helpers/engine.py deleted file mode 100644 index 99a7e9661a7..00000000000 --- a/homeassistant/helpers/engine.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Base class for Home Assistant engines.""" - - -class Engine: - """Base class for Home Assistant engines.""" - - async def async_internal_added_to_hass(self) -> None: - """Run when service about to be added to hass. - - Not to be extended by integrations. - """ - - async def async_added_to_hass(self) -> None: - """Run when service about to be added to hass. - - Not to be extended by integrations. - """ - - async def async_internal_will_remove_from_hass(self) -> None: - """Prepare to remove the service from Home Assistant. - - Not to be extended by integrations. - """ - - async def async_will_remove_from_hass(self) -> None: - """Prepare to remove the service from Home Assistant.""" diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py index 75873a9237a..b925f5da028 100644 --- a/homeassistant/helpers/engine_component.py +++ b/homeassistant/helpers/engine_component.py @@ -2,20 +2,55 @@ from __future__ import annotations import logging -from typing import Generic, TypeVar +from typing import Generic, Protocol, TypeVar, cast from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.setup import async_prepare_setup_platform -from .engine import Engine -from .engine_platform import EnginePlatform from .typing import ConfigType -_EngineT = TypeVar("_EngineT", bound=Engine) + +class Engine: + """Base class for Home Assistant engines.""" + + async def async_internal_added_to_hass(self) -> None: + """Run when service about to be added to hass. + + Not to be extended by integrations. + """ + + async def async_added_to_hass(self) -> None: + """Run when service about to be added to hass. + + Not to be extended by integrations. + """ + + async def async_internal_will_remove_from_hass(self) -> None: + """Prepare to remove the service from Home Assistant. + + Not to be extended by integrations. + """ + + async def async_will_remove_from_hass(self) -> None: + """Prepare to remove the service from Home Assistant.""" -class EngineComponent(Generic[_EngineT]): +_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True) + + +class EnginePlatformModule(Protocol[_EngineT_co]): + """Protocol type for engine platform modules.""" + + async def async_setup_entry( + self, + hass: HomeAssistant, + entry: ConfigEntry, + ) -> _EngineT_co: + """Set up an integration platform from a config entry.""" + + +class EngineComponent(Generic[_EngineT_co]): """Track engines for a component.""" def __init__( @@ -30,22 +65,17 @@ class EngineComponent(Generic[_EngineT]): self.domain = domain self.hass = hass self.config = config - self._platforms: dict[str, EnginePlatform[_EngineT]] = {} + self._engines: dict[str, _EngineT_co] = {} @callback - def async_get_engine(self, config_entry_id: str) -> _EngineT | None: + def async_get_engine(self, config_entry_id: str) -> _EngineT_co | None: """Return a wrapped engine.""" - platform = self._platforms.get(config_entry_id) - return None if platform is None else platform.engine + return self._engines.get(config_entry_id) @callback - def async_get_engines(self) -> list[_EngineT]: + def async_get_engines(self) -> list[_EngineT_co]: """Return a wrapped engine.""" - return [ - platform.engine - for platform in self._platforms.values() - if platform.engine is not None - ] + return list(self._engines.values()) async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: """Set up a config entry.""" @@ -64,23 +94,34 @@ class EngineComponent(Generic[_EngineT]): key = config_entry.entry_id - if key in self._platforms: + if key in self._engines: raise ValueError("Config entry has already been setup!") - self._platforms[key] = EnginePlatform( - self.logger, - self.hass, - config_entry, - platform, - ) + try: + engine = await cast( + EnginePlatformModule[_EngineT_co], platform + ).async_setup_entry(self.hass, config_entry) + await engine.async_internal_added_to_hass() + await engine.async_added_to_hass() + except Exception: # pylint: disable=broad-except + self.logger.exception("Error setting up entry %s", config_entry.entry_id) + return False - return await self._platforms[key].async_setup_entry() + self._engines[key] = engine + return True async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: """Unload a config entry.""" key = config_entry.entry_id - if (platform := self._platforms.pop(key, None)) is None: + if (engine := self._engines.pop(key, None)) is None: raise ValueError("Config entry was never loaded!") - return await platform.async_unload_entry() + try: + await engine.async_internal_will_remove_from_hass() + await engine.async_will_remove_from_hass() + except Exception: # pylint: disable=broad-except + self.logger.exception("Error unloading entry %s", config_entry.entry_id) + return False + + return True diff --git a/homeassistant/helpers/engine_platform.py b/homeassistant/helpers/engine_platform.py deleted file mode 100644 index 95397213dd5..00000000000 --- a/homeassistant/helpers/engine_platform.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Service platform helper.""" -import logging -from typing import Generic, Protocol, TypeVar - -from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant - -from .engine import Engine - -_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True) - - -class EnginePlatformModule(Protocol[_EngineT_co]): - """Protocol type for engine platform modules.""" - - async def async_setup_entry( - self, - hass: HomeAssistant, - entry: ConfigEntry, - ) -> _EngineT_co: - """Set up an integration platform from a config entry.""" - - -class EnginePlatform(Generic[_EngineT_co]): - """Track engines for a platform.""" - - def __init__( - self, - logger: logging.Logger, - hass: HomeAssistant, - config_entry: ConfigEntry, - platform: EnginePlatformModule, - ) -> None: - """Initialize the engine platform.""" - self.logger = logger - self.hass = hass - self.config_entry = config_entry - self.platform = platform - self.engine: _EngineT_co | None = None - - async def async_setup_entry(self) -> bool: - """Set up a config entry.""" - try: - engine = await self.platform.async_setup_entry(self.hass, self.config_entry) - except Exception: # pylint: disable=broad-except - self.logger.exception( - "Error setting up entry %s", self.config_entry.entry_id - ) - return False - - await engine.async_internal_added_to_hass() - await engine.async_added_to_hass() - - self.engine = engine - return True - - async def async_unload_entry(self) -> bool: - """Unload a config entry.""" - if self.engine is None: - return True - - await self.engine.async_internal_will_remove_from_hass() - await self.engine.async_will_remove_from_hass() - self.engine = None - return True From 7dfb3f9af1da936f7688fd9aed8179222e3e9402 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 14 Mar 2023 21:52:52 -0400 Subject: [PATCH 3/5] Fix comment --- homeassistant/helpers/engine_component.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py index b925f5da028..2258eb91b4d 100644 --- a/homeassistant/helpers/engine_component.py +++ b/homeassistant/helpers/engine_component.py @@ -21,10 +21,7 @@ class Engine: """ async def async_added_to_hass(self) -> None: - """Run when service about to be added to hass. - - Not to be extended by integrations. - """ + """Run when service about to be added to hass.""" async def async_internal_will_remove_from_hass(self) -> None: """Prepare to remove the service from Home Assistant. From ec12436d1058549b50ba90c449e662d3b9218dff Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 14 Mar 2023 23:19:20 -0400 Subject: [PATCH 4/5] Update comments --- homeassistant/helpers/engine_component.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py index 2258eb91b4d..d699804359d 100644 --- a/homeassistant/helpers/engine_component.py +++ b/homeassistant/helpers/engine_component.py @@ -1,4 +1,4 @@ -"""Service platform helper.""" +"""Engine component helper.""" from __future__ import annotations import logging @@ -15,22 +15,22 @@ class Engine: """Base class for Home Assistant engines.""" async def async_internal_added_to_hass(self) -> None: - """Run when service about to be added to hass. + """Run when engine about to be added to Home Assistant. Not to be extended by integrations. """ async def async_added_to_hass(self) -> None: - """Run when service about to be added to hass.""" + """Run when engine about to be added to Home Assistant.""" async def async_internal_will_remove_from_hass(self) -> None: - """Prepare to remove the service from Home Assistant. + """Prepare to remove the engine from Home Assistant. Not to be extended by integrations. """ async def async_will_remove_from_hass(self) -> None: - """Prepare to remove the service from Home Assistant.""" + """Prepare to remove the engine from Home Assistant.""" _EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True) @@ -76,14 +76,8 @@ class EngineComponent(Generic[_EngineT_co]): async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: """Set up a config entry.""" - platform_type = config_entry.domain platform = await async_prepare_setup_platform( - self.hass, - # In future PR we should make hass_config part of the constructor - # params. - self.config or {}, - self.domain, - platform_type, + self.hass, self.config, self.domain, config_entry.domain ) if platform is None: From fbec2c63db8f95e868e7f1ee53ea3a70d7bf87db Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 15 Mar 2023 15:13:10 -0400 Subject: [PATCH 5/5] Add discovery support to engine component --- homeassistant/components/stt/__init__.py | 8 ++- homeassistant/helpers/engine_component.py | 72 +++++++++++++++++++---- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index cb1c646ef79..084c3627e26 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -52,8 +52,12 @@ def async_get_provider( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" - hass.data[DOMAIN] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config) - hass.http.register_view(SpeechToTextView(hass.data[DOMAIN])) + engines: engine_component.EngineComponent[ + Provider + ] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config) + engines.async_setup_discovery() + hass.data[DOMAIN] = engines + hass.http.register_view(SpeechToTextView(engines)) return True diff --git a/homeassistant/helpers/engine_component.py b/homeassistant/helpers/engine_component.py index d699804359d..e058b7dfc41 100644 --- a/homeassistant/helpers/engine_component.py +++ b/homeassistant/helpers/engine_component.py @@ -1,14 +1,17 @@ """Engine component helper.""" from __future__ import annotations +from collections.abc import Awaitable, Callable import logging -from typing import Generic, Protocol, TypeVar, cast +from types import ModuleType +from typing import Generic, Protocol, TypeVar from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.setup import async_prepare_setup_platform -from .typing import ConfigType +from . import discovery +from .typing import ConfigType, DiscoveryInfoType class Engine: @@ -46,6 +49,12 @@ class EnginePlatformModule(Protocol[_EngineT_co]): ) -> _EngineT_co: """Set up an integration platform from a config entry.""" + async def async_setup_platform( + self, + hass: HomeAssistant, + ) -> _EngineT_co: + """Set up an integration platform async.""" + class EngineComponent(Generic[_EngineT_co]): """Track engines for a component.""" @@ -74,37 +83,78 @@ class EngineComponent(Generic[_EngineT_co]): """Return a wrapped engine.""" return list(self._engines.values()) + @callback + def async_setup_discovery(self) -> None: + """Initialize the engine component discovery.""" + + async def async_platform_discovered( + platform: str, info: DiscoveryInfoType | None + ) -> None: + """Handle for discovered platform.""" + await self.async_setup_domain(platform) + + discovery.async_listen_platform( + self.hass, self.domain, async_platform_discovered + ) + + async def async_setup_domain(self, domain: str) -> bool: + """Set up an integration.""" + + async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co: + return await platform.async_setup_platform(self.hass) + + return await self._async_do_setup(domain, domain, setup) + async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: """Set up a config entry.""" + + async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co: + return await platform.async_setup_entry(self.hass, config_entry) + + return await self._async_do_setup( + config_entry.entry_id, config_entry.domain, setup + ) + + async def _async_do_setup( + self, + key: str, + platform_domain: str, + get_setup_coro: Callable[[ModuleType], Awaitable[_EngineT_co]], + ) -> bool: + """Set up an entry.""" platform = await async_prepare_setup_platform( - self.hass, self.config, self.domain, config_entry.domain + self.hass, self.config, self.domain, platform_domain ) if platform is None: return False - key = config_entry.entry_id - if key in self._engines: raise ValueError("Config entry has already been setup!") try: - engine = await cast( - EnginePlatformModule[_EngineT_co], platform - ).async_setup_entry(self.hass, config_entry) + engine = await get_setup_coro(platform) await engine.async_internal_added_to_hass() await engine.async_added_to_hass() except Exception: # pylint: disable=broad-except - self.logger.exception("Error setting up entry %s", config_entry.entry_id) + self.logger.exception( + "Error getting engine for %s (%s)", key, platform_domain + ) return False self._engines[key] = engine return True + async def async_unload_domain(self, domain: str) -> bool: + """Unload a domain.""" + return await self._async_do_unload(domain) + async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: """Unload a config entry.""" - key = config_entry.entry_id + return await self._async_do_unload(config_entry.entry_id) + async def _async_do_unload(self, key: str) -> bool: + """Unload an engine.""" if (engine := self._engines.pop(key, None)) is None: raise ValueError("Config entry was never loaded!") @@ -112,7 +162,7 @@ class EngineComponent(Generic[_EngineT_co]): await engine.async_internal_will_remove_from_hass() await engine.async_will_remove_from_hass() except Exception: # pylint: disable=broad-except - self.logger.exception("Error unloading entry %s", config_entry.entry_id) + self.logger.exception("Error unloading entry %s", key) return False return True