Move legacy stt (#90776)

* Move legacy stt to separate module

* Remove case for None as provider

* Add error log for unknown platform

* Add some tests
This commit is contained in:
Martin Hjelmare 2023-04-04 14:52:36 +02:00 committed by GitHub
parent 584066b809
commit 535fb34207
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 322 additions and 158 deletions

View file

@ -1,11 +1,8 @@
"""Provide functionality to STT."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import AsyncIterable
from dataclasses import asdict, dataclass
import logging
from dataclasses import asdict
from typing import Any
from aiohttp import web
@ -17,10 +14,8 @@ from aiohttp.web_exceptions import (
)
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DOMAIN,
@ -31,159 +26,40 @@ from .const import (
AudioSampleRates,
SpeechResultState,
)
from .legacy import (
Provider,
SpeechMetadata,
SpeechResult,
async_get_provider,
async_setup_legacy,
)
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(
hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
"""Return provider."""
if domain:
return hass.data[DOMAIN].get(domain)
if not hass.data[DOMAIN]:
return None
if "cloud" in hass.data[DOMAIN]:
return hass.data[DOMAIN]["cloud"]
return next(iter(hass.data[DOMAIN].values()))
__all__ = [
"async_get_provider",
"AudioBitRates",
"AudioChannels",
"AudioCodecs",
"AudioFormats",
"AudioSampleRates",
"DOMAIN",
"Provider",
"SpeechMetadata",
"SpeechResult",
"SpeechResultState",
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
providers = hass.data[DOMAIN] = {}
platform_setups = async_setup_legacy(hass, config)
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
if platform_setups:
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
return
try:
provider = await platform.async_get_engine(hass, p_config, discovery_info)
if provider is None:
_LOGGER.error("Error setting up platform %s", p_type)
return
provider.name = p_type
provider.hass = hass
providers[provider.name] = provider
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
setup_tasks = [
asyncio.create_task(async_setup_platform(p_type, p_config))
for p_type, p_config in config_per_platform(config, DOMAIN)
]
if setup_tasks:
await asyncio.wait(setup_tasks)
# Add discovery support
async def async_platform_discovered(platform, info):
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
hass.http.register_view(SpeechToTextView(providers))
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
return True
@dataclass
class SpeechMetadata:
"""Metadata of audio stream."""
language: str
format: AudioFormats
codec: AudioCodecs
bit_rate: AudioBitRates
sample_rate: AudioSampleRates
channel: AudioChannels
def __post_init__(self) -> None:
"""Finish initializing the metadata."""
self.bit_rate = AudioBitRates(int(self.bit_rate))
self.sample_rate = AudioSampleRates(int(self.sample_rate))
self.channel = AudioChannels(int(self.channel))
@dataclass
class SpeechResult:
"""Result of audio Speech."""
text: str | None
result: SpeechResultState
class Provider(ABC):
"""Represent a single STT provider."""
hass: HomeAssistant | None = None
name: str | None = None
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
@abstractmethod
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
@property
@abstractmethod
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
@property
@abstractmethod
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bit rates."""
@property
@abstractmethod
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported sample rates."""
@property
@abstractmethod
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service.
Only streaming of content are allow!
"""
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates
or metadata.channel not in self.supported_channels
):
return False
return True
class SpeechToTextView(HomeAssistantView):
"""STT view to generate a text from audio stream."""
@ -203,7 +79,7 @@ class SpeechToTextView(HomeAssistantView):
# Get metadata
try:
metadata = metadata_from_header(request)
metadata = _metadata_from_header(request)
except ValueError as err:
raise HTTPBadRequest(text=str(err)) from err
@ -237,7 +113,7 @@ class SpeechToTextView(HomeAssistantView):
)
def metadata_from_header(request: web.Request) -> SpeechMetadata:
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
"""Extract STT metadata from header.
X-Speech-Content:

View file

@ -0,0 +1,169 @@
"""Handle legacy speech to text platforms."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Coroutine
from dataclasses import dataclass
import logging
from typing import Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechResultState,
)
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(
hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
"""Return provider."""
if domain:
return hass.data[DOMAIN].get(domain)
if not hass.data[DOMAIN]:
return None
if "cloud" in hass.data[DOMAIN]:
return hass.data[DOMAIN]["cloud"]
return next(iter(hass.data[DOMAIN].values()))
@callback
def async_setup_legacy(
hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy speech to text providers."""
providers = hass.data[DOMAIN] = {}
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
_LOGGER.error("Unknown speech to text platform specified")
return
try:
provider = await platform.async_get_engine(hass, p_config, discovery_info)
provider.name = p_type
provider.hass = hass
providers[provider.name] = provider
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
# Add discovery support
async def async_platform_discovered(platform, info):
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
return [
async_setup_platform(p_type, p_config)
for p_type, p_config in config_per_platform(config, DOMAIN)
]
@dataclass
class SpeechMetadata:
"""Metadata of audio stream."""
language: str
format: AudioFormats
codec: AudioCodecs
bit_rate: AudioBitRates
sample_rate: AudioSampleRates
channel: AudioChannels
def __post_init__(self) -> None:
"""Finish initializing the metadata."""
self.bit_rate = AudioBitRates(int(self.bit_rate))
self.sample_rate = AudioSampleRates(int(self.sample_rate))
self.channel = AudioChannels(int(self.channel))
@dataclass
class SpeechResult:
"""Result of audio Speech."""
text: str | None
result: SpeechResultState
class Provider(ABC):
"""Represent a single STT provider."""
hass: HomeAssistant | None = None
name: str | None = None
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
@abstractmethod
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
@property
@abstractmethod
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
@property
@abstractmethod
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bit rates."""
@property
@abstractmethod
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported sample rates."""
@property
@abstractmethod
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service.
Only streaming of content are allow!
"""
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates
or metadata.channel not in self.supported_channels
):
return False
return True

View file

@ -0,0 +1,56 @@
"""Provide common test tools for STT."""
from __future__ import annotations
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any
from homeassistant.components.stt import Provider
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform, mock_platform
class MockSTTPlatform(MockPlatform):
"""Help to set up test stt service."""
def __init__(
self,
async_get_engine: Callable[
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
Coroutine[Any, Any, Provider | None],
]
| None = None,
get_engine: Callable[
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
]
| None = None,
) -> None:
"""Return the stt service."""
super().__init__()
if get_engine:
self.get_engine = get_engine
if async_get_engine:
self.async_get_engine = async_get_engine
def mock_stt_platform(
hass: HomeAssistant,
tmp_path: Path,
integration: str = "stt",
async_get_engine: Callable[
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
Coroutine[Any, Any, Provider | None],
]
| None = None,
get_engine: Callable[
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
]
| None = None,
):
"""Specialize the mock platform for stt."""
loaded_platform = MockSTTPlatform(async_get_engine, get_engine)
mock_platform(hass, f"{integration}.stt", loaded_platform)
return loaded_platform

View file

@ -1,7 +1,8 @@
"""Test STT component setup."""
from collections.abc import AsyncIterable
from http import HTTPStatus
from unittest.mock import AsyncMock, Mock
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
@ -20,7 +21,8 @@ from homeassistant.components.stt import (
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import mock_platform
from .common import mock_stt_platform
from tests.typing import ClientSessionGenerator
@ -31,7 +33,7 @@ class MockProvider(Provider):
def __init__(self) -> None:
"""Init test provider."""
self.calls = []
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
@property
def supported_languages(self) -> list[str]:
@ -81,10 +83,15 @@ def mock_provider() -> MockProvider:
@pytest.fixture(autouse=True)
async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None:
async def mock_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
) -> None:
"""Set up a test provider."""
mock_platform(
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider))
mock_stt_platform(
hass,
tmp_path,
"test",
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})

View file

@ -0,0 +1,56 @@
"""Test the legacy stt setup."""
from __future__ import annotations
from pathlib import Path
import pytest
from homeassistant.components.stt import Provider
from homeassistant.core import HomeAssistant
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .common import mock_stt_platform
async def test_invalid_platform(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
) -> None:
"""Test platform setup with an invalid platform."""
await async_load_platform(
hass,
"stt",
"bad_stt",
{"stt": [{"platform": "bad_stt"}]},
hass_config={"stt": [{"platform": "bad_stt"}]},
)
await hass.async_block_till_done()
assert "Unknown speech to text platform specified" in caplog.text
async def test_platform_setup_with_error(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
) -> None:
"""Test platform setup with an error during setup."""
async def async_get_engine(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider:
"""Raise exception during platform setup."""
raise Exception("Setup error") # pylint: disable=broad-exception-raised
mock_stt_platform(hass, tmp_path, "bad_stt", async_get_engine=async_get_engine)
await async_load_platform(
hass,
"stt",
"bad_stt",
{},
hass_config={"stt": [{"platform": "bad_stt"}]},
)
await hass.async_block_till_done()
assert "Error setting up platform: bad_stt" in caplog.text