Compare commits

...
Sign in to create a new pull request.

5 commits

Author SHA1 Message Date
Paulus Schoutsen
fbec2c63db Add discovery support to engine component 2023-03-15 15:13:10 -04:00
Paulus Schoutsen
ec12436d10 Update comments 2023-03-14 23:19:20 -04:00
Paulus Schoutsen
7dfb3f9af1 Fix comment 2023-03-14 21:52:52 -04:00
Paulus Schoutsen
98c292eb56 Simplify engine component 2023-03-14 21:32:34 -04:00
Paulus Schoutsen
aeccd525f5 Introduce engine component 2023-03-14 00:04:16 -04:00
5 changed files with 248 additions and 72 deletions

View file

@ -16,6 +16,8 @@ from homeassistant.components.stt import (
SpeechResult,
SpeechResultState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from .const import DOMAIN
@ -43,7 +45,10 @@ SUPPORT_LANGUAGES = [
]
async def async_get_engine(hass, config, discovery_info=None):
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
) -> Provider:
"""Set up Cloud speech component."""
cloud: Cloud = hass.data[DOMAIN]

View file

@ -14,16 +14,15 @@ from homeassistant.components.stt import (
SpeechResult,
SpeechResultState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
SUPPORT_LANGUAGES = ["en", "de"]
async def async_get_engine(
async def async_setup_entry(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
config_entry: ConfigEntry,
) -> Provider:
"""Set up Demo speech component."""
return DemoProvider()

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from dataclasses import asdict, dataclass
import logging
from typing import Any
@ -16,10 +15,10 @@ from aiohttp.web_exceptions import (
)
from homeassistant.components.http import HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers import engine_component
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DOMAIN,
@ -35,60 +34,45 @@ _LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
def async_get_provider(
hass: HomeAssistant, provider: str | None = None
) -> Provider | None:
"""Return provider."""
if domain is None:
domain = next(iter(hass.data[DOMAIN]))
component: engine_component.EngineComponent[Provider] | None = hass.data.get(DOMAIN)
return hass.data[DOMAIN][domain]
if component is None:
return None
if provider is None:
providers = component.async_get_engines()
return providers[0] if providers else None
return component.async_get_engine(provider)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
providers = hass.data[DOMAIN] = {}
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
return
try:
provider = await platform.async_get_engine(hass, p_config, discovery_info)
if provider is None:
_LOGGER.error("Error setting up platform %s", p_type)
return
provider.name = p_type
provider.hass = hass
providers[provider.name] = provider
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
setup_tasks = [
asyncio.create_task(async_setup_platform(p_type, p_config))
for p_type, p_config in config_per_platform(config, DOMAIN)
]
if setup_tasks:
await asyncio.wait(setup_tasks)
# Add discovery support
async def async_platform_discovered(platform, info):
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
hass.http.register_view(SpeechToTextView(providers))
engines: engine_component.EngineComponent[
Provider
] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config)
engines.async_setup_discovery()
hass.data[DOMAIN] = engines
hass.http.register_view(SpeechToTextView(engines))
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
@dataclass
class SpeechMetadata:
"""Metadata of audio stream."""
@ -115,7 +99,7 @@ class SpeechResult:
result: SpeechResultState
class Provider(ABC):
class Provider(engine_component.Engine, ABC):
"""Represent a single STT provider."""
hass: HomeAssistant | None = None
@ -182,15 +166,15 @@ class SpeechToTextView(HomeAssistantView):
url = "/api/stt/{provider}"
name = "api:stt:provider"
def __init__(self, providers: dict[str, Provider]) -> None:
def __init__(self, engines: engine_component.EngineComponent[Provider]) -> None:
"""Initialize a tts view."""
self.providers = providers
self.engines = engines
async def post(self, request: web.Request, provider: str) -> web.Response:
"""Convert Speech (audio) to text."""
if provider not in self.providers:
stt_provider = self.engines.async_get_engine(provider)
if stt_provider is None:
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
# Get metadata
try:
@ -212,9 +196,9 @@ class SpeechToTextView(HomeAssistantView):
async def get(self, request: web.Request, provider: str) -> web.Response:
"""Return provider specific audio information."""
if provider not in self.providers:
stt_provider = self.engines.async_get_engine(provider)
if stt_provider is None:
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
return self.json(
{

View file

@ -0,0 +1,168 @@
"""Engine component helper."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
import logging
from types import ModuleType
from typing import Generic, Protocol, TypeVar
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.setup import async_prepare_setup_platform
from . import discovery
from .typing import ConfigType, DiscoveryInfoType
class Engine:
"""Base class for Home Assistant engines."""
async def async_internal_added_to_hass(self) -> None:
"""Run when engine about to be added to Home Assistant.
Not to be extended by integrations.
"""
async def async_added_to_hass(self) -> None:
"""Run when engine about to be added to Home Assistant."""
async def async_internal_will_remove_from_hass(self) -> None:
"""Prepare to remove the engine from Home Assistant.
Not to be extended by integrations.
"""
async def async_will_remove_from_hass(self) -> None:
"""Prepare to remove the engine from Home Assistant."""
_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True)
class EnginePlatformModule(Protocol[_EngineT_co]):
"""Protocol type for engine platform modules."""
async def async_setup_entry(
self,
hass: HomeAssistant,
entry: ConfigEntry,
) -> _EngineT_co:
"""Set up an integration platform from a config entry."""
async def async_setup_platform(
self,
hass: HomeAssistant,
) -> _EngineT_co:
"""Set up an integration platform async."""
class EngineComponent(Generic[_EngineT_co]):
"""Track engines for a component."""
def __init__(
self,
logger: logging.Logger,
domain: str,
hass: HomeAssistant,
config: ConfigType,
) -> None:
"""Initialize the engine component."""
self.logger = logger
self.domain = domain
self.hass = hass
self.config = config
self._engines: dict[str, _EngineT_co] = {}
@callback
def async_get_engine(self, config_entry_id: str) -> _EngineT_co | None:
"""Return a wrapped engine."""
return self._engines.get(config_entry_id)
@callback
def async_get_engines(self) -> list[_EngineT_co]:
"""Return a wrapped engine."""
return list(self._engines.values())
@callback
def async_setup_discovery(self) -> None:
"""Initialize the engine component discovery."""
async def async_platform_discovered(
platform: str, info: DiscoveryInfoType | None
) -> None:
"""Handle for discovered platform."""
await self.async_setup_domain(platform)
discovery.async_listen_platform(
self.hass, self.domain, async_platform_discovered
)
async def async_setup_domain(self, domain: str) -> bool:
"""Set up an integration."""
async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co:
return await platform.async_setup_platform(self.hass)
return await self._async_do_setup(domain, domain, setup)
async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
"""Set up a config entry."""
async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co:
return await platform.async_setup_entry(self.hass, config_entry)
return await self._async_do_setup(
config_entry.entry_id, config_entry.domain, setup
)
async def _async_do_setup(
self,
key: str,
platform_domain: str,
get_setup_coro: Callable[[ModuleType], Awaitable[_EngineT_co]],
) -> bool:
"""Set up an entry."""
platform = await async_prepare_setup_platform(
self.hass, self.config, self.domain, platform_domain
)
if platform is None:
return False
if key in self._engines:
raise ValueError("Config entry has already been setup!")
try:
engine = await get_setup_coro(platform)
await engine.async_internal_added_to_hass()
await engine.async_added_to_hass()
except Exception: # pylint: disable=broad-except
self.logger.exception(
"Error getting engine for %s (%s)", key, platform_domain
)
return False
self._engines[key] = engine
return True
async def async_unload_domain(self, domain: str) -> bool:
"""Unload a domain."""
return await self._async_do_unload(domain)
async def async_unload_entry(self, config_entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await self._async_do_unload(config_entry.entry_id)
async def _async_do_unload(self, key: str) -> bool:
"""Unload an engine."""
if (engine := self._engines.pop(key, None)) is None:
raise ValueError("Config entry was never loaded!")
try:
await engine.async_internal_will_remove_from_hass()
await engine.async_will_remove_from_hass()
except Exception: # pylint: disable=broad-except
self.logger.exception("Error unloading entry %s", key)
return False
return True

View file

@ -18,9 +18,8 @@ from homeassistant.components.stt import (
async_get_provider,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import mock_platform
from tests.common import MockConfigEntry, mock_platform
from tests.typing import ClientSessionGenerator
@ -80,21 +79,35 @@ def mock_provider() -> MockProvider:
return MockProvider()
@pytest.fixture
def mock_config_entry(hass) -> MockConfigEntry:
"""Mock config entry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
return config_entry
@pytest.fixture(autouse=True)
async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None:
async def mock_setup(
hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry
) -> None:
"""Set up a test provider."""
mock_platform(
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider))
hass,
"test.stt",
Mock(async_setup_entry=AsyncMock(return_value=mock_provider)),
)
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
assert await hass.config_entries.async_forward_entry_setup(mock_config_entry, "stt")
async def test_get_provider_info(
hass: HomeAssistant, hass_client: ClientSessionGenerator
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test engine that doesn't exist."""
client = await hass_client()
response = await client.get("/api/stt/test")
response = await client.get(f"/api/stt/{mock_config_entry.entry_id}")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["en"],
@ -116,12 +129,14 @@ async def test_get_non_existing_provider_info(
async def test_stream_audio(
hass: HomeAssistant, hass_client: ClientSessionGenerator
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
response = await client.post(
"/api/stt/test",
f"/api/stt/{mock_config_entry.entry_id}",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
@ -155,6 +170,7 @@ async def test_stream_audio(
async def test_metadata_errors(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_config_entry: MockConfigEntry,
header: str | None,
status: int,
error: str,
@ -165,11 +181,15 @@ async def test_metadata_errors(
if header:
headers["X-Speech-Content"] = header
response = await client.post("/api/stt/test", headers=headers)
response = await client.post(
f"/api/stt/{mock_config_entry.entry_id}", headers=headers
)
assert response.status == status
assert await response.text() == error
async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None:
async def test_get_provider(
hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry
) -> None:
"""Test we can get STT providers."""
assert mock_provider == async_get_provider(hass, "test")
assert mock_provider == async_get_provider(hass, mock_config_entry.entry_id)