Compare commits
5 commits
dev
...
engine-com
Author | SHA1 | Date | |
---|---|---|---|
|
fbec2c63db | ||
|
ec12436d10 | ||
|
7dfb3f9af1 | ||
|
98c292eb56 | ||
|
aeccd525f5 |
5 changed files with 248 additions and 72 deletions
|
@ -16,6 +16,8 @@ from homeassistant.components.stt import (
|
|||
SpeechResult,
|
||||
SpeechResultState,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
@ -43,7 +45,10 @@ SUPPORT_LANGUAGES = [
|
|||
]
|
||||
|
||||
|
||||
async def async_get_engine(hass, config, discovery_info=None):
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
) -> Provider:
|
||||
"""Set up Cloud speech component."""
|
||||
cloud: Cloud = hass.data[DOMAIN]
|
||||
|
||||
|
|
|
@ -14,16 +14,15 @@ from homeassistant.components.stt import (
|
|||
SpeechResult,
|
||||
SpeechResultState,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
SUPPORT_LANGUAGES = ["en", "de"]
|
||||
|
||||
|
||||
async def async_get_engine(
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
config_entry: ConfigEntry,
|
||||
) -> Provider:
|
||||
"""Set up Demo speech component."""
|
||||
return DemoProvider()
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from dataclasses import asdict, dataclass
|
||||
import logging
|
||||
from typing import Any
|
||||
|
@ -16,10 +15,10 @@ from aiohttp.web_exceptions import (
|
|||
)
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers import engine_component
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
|
@ -35,60 +34,45 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@callback
|
||||
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
|
||||
def async_get_provider(
|
||||
hass: HomeAssistant, provider: str | None = None
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
if domain is None:
|
||||
domain = next(iter(hass.data[DOMAIN]))
|
||||
component: engine_component.EngineComponent[Provider] | None = hass.data.get(DOMAIN)
|
||||
|
||||
return hass.data[DOMAIN][domain]
|
||||
if component is None:
|
||||
return None
|
||||
|
||||
if provider is None:
|
||||
providers = component.async_get_engines()
|
||||
return providers[0] if providers else None
|
||||
|
||||
return component.async_get_engine(provider)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
providers = hass.data[DOMAIN] = {}
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
|
||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
provider = await platform.async_get_engine(hass, p_config, discovery_info)
|
||||
if provider is None:
|
||||
_LOGGER.error("Error setting up platform %s", p_type)
|
||||
return
|
||||
|
||||
provider.name = p_type
|
||||
provider.hass = hass
|
||||
|
||||
providers[provider.name] = provider
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
setup_tasks = [
|
||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
]
|
||||
|
||||
if setup_tasks:
|
||||
await asyncio.wait(setup_tasks)
|
||||
|
||||
# Add discovery support
|
||||
async def async_platform_discovered(platform, info):
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
|
||||
hass.http.register_view(SpeechToTextView(providers))
|
||||
engines: engine_component.EngineComponent[
|
||||
Provider
|
||||
] = engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config)
|
||||
engines.async_setup_discovery()
|
||||
hass.data[DOMAIN] = engines
|
||||
hass.http.register_view(SpeechToTextView(engines))
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechMetadata:
|
||||
"""Metadata of audio stream."""
|
||||
|
@ -115,7 +99,7 @@ class SpeechResult:
|
|||
result: SpeechResultState
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
class Provider(engine_component.Engine, ABC):
|
||||
"""Represent a single STT provider."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
|
@ -182,15 +166,15 @@ class SpeechToTextView(HomeAssistantView):
|
|||
url = "/api/stt/{provider}"
|
||||
name = "api:stt:provider"
|
||||
|
||||
def __init__(self, providers: dict[str, Provider]) -> None:
|
||||
def __init__(self, engines: engine_component.EngineComponent[Provider]) -> None:
|
||||
"""Initialize a tts view."""
|
||||
self.providers = providers
|
||||
self.engines = engines
|
||||
|
||||
async def post(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Convert Speech (audio) to text."""
|
||||
if provider not in self.providers:
|
||||
stt_provider = self.engines.async_get_engine(provider)
|
||||
if stt_provider is None:
|
||||
raise HTTPNotFound()
|
||||
stt_provider: Provider = self.providers[provider]
|
||||
|
||||
# Get metadata
|
||||
try:
|
||||
|
@ -212,9 +196,9 @@ class SpeechToTextView(HomeAssistantView):
|
|||
|
||||
async def get(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Return provider specific audio information."""
|
||||
if provider not in self.providers:
|
||||
stt_provider = self.engines.async_get_engine(provider)
|
||||
if stt_provider is None:
|
||||
raise HTTPNotFound()
|
||||
stt_provider: Provider = self.providers[provider]
|
||||
|
||||
return self.json(
|
||||
{
|
||||
|
|
168
homeassistant/helpers/engine_component.py
Normal file
168
homeassistant/helpers/engine_component.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
"""Engine component helper."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Generic, Protocol, TypeVar
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
from . import discovery
|
||||
from .typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
|
||||
class Engine:
|
||||
"""Base class for Home Assistant engines."""
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Run when engine about to be added to Home Assistant.
|
||||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when engine about to be added to Home Assistant."""
|
||||
|
||||
async def async_internal_will_remove_from_hass(self) -> None:
|
||||
"""Prepare to remove the engine from Home Assistant.
|
||||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Prepare to remove the engine from Home Assistant."""
|
||||
|
||||
|
||||
_EngineT_co = TypeVar("_EngineT_co", bound=Engine, covariant=True)
|
||||
|
||||
|
||||
class EnginePlatformModule(Protocol[_EngineT_co]):
|
||||
"""Protocol type for engine platform modules."""
|
||||
|
||||
async def async_setup_entry(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
entry: ConfigEntry,
|
||||
) -> _EngineT_co:
|
||||
"""Set up an integration platform from a config entry."""
|
||||
|
||||
async def async_setup_platform(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
) -> _EngineT_co:
|
||||
"""Set up an integration platform async."""
|
||||
|
||||
|
||||
class EngineComponent(Generic[_EngineT_co]):
|
||||
"""Track engines for a component."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
domain: str,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
) -> None:
|
||||
"""Initialize the engine component."""
|
||||
self.logger = logger
|
||||
self.domain = domain
|
||||
self.hass = hass
|
||||
self.config = config
|
||||
self._engines: dict[str, _EngineT_co] = {}
|
||||
|
||||
@callback
|
||||
def async_get_engine(self, config_entry_id: str) -> _EngineT_co | None:
|
||||
"""Return a wrapped engine."""
|
||||
return self._engines.get(config_entry_id)
|
||||
|
||||
@callback
|
||||
def async_get_engines(self) -> list[_EngineT_co]:
|
||||
"""Return a wrapped engine."""
|
||||
return list(self._engines.values())
|
||||
|
||||
@callback
|
||||
def async_setup_discovery(self) -> None:
|
||||
"""Initialize the engine component discovery."""
|
||||
|
||||
async def async_platform_discovered(
|
||||
platform: str, info: DiscoveryInfoType | None
|
||||
) -> None:
|
||||
"""Handle for discovered platform."""
|
||||
await self.async_setup_domain(platform)
|
||||
|
||||
discovery.async_listen_platform(
|
||||
self.hass, self.domain, async_platform_discovered
|
||||
)
|
||||
|
||||
async def async_setup_domain(self, domain: str) -> bool:
|
||||
"""Set up an integration."""
|
||||
|
||||
async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co:
|
||||
return await platform.async_setup_platform(self.hass)
|
||||
|
||||
return await self._async_do_setup(domain, domain, setup)
|
||||
|
||||
async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
|
||||
async def setup(platform: EnginePlatformModule[_EngineT_co]) -> _EngineT_co:
|
||||
return await platform.async_setup_entry(self.hass, config_entry)
|
||||
|
||||
return await self._async_do_setup(
|
||||
config_entry.entry_id, config_entry.domain, setup
|
||||
)
|
||||
|
||||
async def _async_do_setup(
|
||||
self,
|
||||
key: str,
|
||||
platform_domain: str,
|
||||
get_setup_coro: Callable[[ModuleType], Awaitable[_EngineT_co]],
|
||||
) -> bool:
|
||||
"""Set up an entry."""
|
||||
platform = await async_prepare_setup_platform(
|
||||
self.hass, self.config, self.domain, platform_domain
|
||||
)
|
||||
|
||||
if platform is None:
|
||||
return False
|
||||
|
||||
if key in self._engines:
|
||||
raise ValueError("Config entry has already been setup!")
|
||||
|
||||
try:
|
||||
engine = await get_setup_coro(platform)
|
||||
await engine.async_internal_added_to_hass()
|
||||
await engine.async_added_to_hass()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception(
|
||||
"Error getting engine for %s (%s)", key, platform_domain
|
||||
)
|
||||
return False
|
||||
|
||||
self._engines[key] = engine
|
||||
return True
|
||||
|
||||
async def async_unload_domain(self, domain: str) -> bool:
|
||||
"""Unload a domain."""
|
||||
return await self._async_do_unload(domain)
|
||||
|
||||
async def async_unload_entry(self, config_entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await self._async_do_unload(config_entry.entry_id)
|
||||
|
||||
async def _async_do_unload(self, key: str) -> bool:
|
||||
"""Unload an engine."""
|
||||
if (engine := self._engines.pop(key, None)) is None:
|
||||
raise ValueError("Config entry was never loaded!")
|
||||
|
||||
try:
|
||||
await engine.async_internal_will_remove_from_hass()
|
||||
await engine.async_will_remove_from_hass()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception("Error unloading entry %s", key)
|
||||
return False
|
||||
|
||||
return True
|
|
@ -18,9 +18,8 @@ from homeassistant.components.stt import (
|
|||
async_get_provider,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_platform
|
||||
from tests.common import MockConfigEntry, mock_platform
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
|
@ -80,21 +79,35 @@ def mock_provider() -> MockProvider:
|
|||
return MockProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(hass) -> MockConfigEntry:
|
||||
"""Mock config entry."""
|
||||
config_entry = MockConfigEntry()
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Set up a test provider."""
|
||||
mock_platform(
|
||||
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider))
|
||||
hass,
|
||||
"test.stt",
|
||||
Mock(async_setup_entry=AsyncMock(return_value=mock_provider)),
|
||||
)
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||
assert await hass.config_entries.async_forward_entry_setup(mock_config_entry, "stt")
|
||||
|
||||
|
||||
async def test_get_provider_info(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
response = await client.get("/api/stt/test")
|
||||
response = await client.get(f"/api/stt/{mock_config_entry.entry_id}")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {
|
||||
"languages": ["en"],
|
||||
|
@ -116,12 +129,14 @@ async def test_get_non_existing_provider_info(
|
|||
|
||||
|
||||
async def test_stream_audio(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test streaming audio and getting response."""
|
||||
client = await hass_client()
|
||||
response = await client.post(
|
||||
"/api/stt/test",
|
||||
f"/api/stt/{mock_config_entry.entry_id}",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
|
@ -155,6 +170,7 @@ async def test_stream_audio(
|
|||
async def test_metadata_errors(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
header: str | None,
|
||||
status: int,
|
||||
error: str,
|
||||
|
@ -165,11 +181,15 @@ async def test_metadata_errors(
|
|||
if header:
|
||||
headers["X-Speech-Content"] = header
|
||||
|
||||
response = await client.post("/api/stt/test", headers=headers)
|
||||
response = await client.post(
|
||||
f"/api/stt/{mock_config_entry.entry_id}", headers=headers
|
||||
)
|
||||
assert response.status == status
|
||||
assert await response.text() == error
|
||||
|
||||
|
||||
async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||
async def test_get_provider(
|
||||
hass: HomeAssistant, mock_provider: MockProvider, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test we can get STT providers."""
|
||||
assert mock_provider == async_get_provider(hass, "test")
|
||||
assert mock_provider == async_get_provider(hass, mock_config_entry.entry_id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue