Move legacy stt (#90776)
* Move legacy stt to separate module * Remove case for None as provider * Add error log for unknown platform * Add some tests
This commit is contained in:
parent
584066b809
commit
535fb34207
5 changed files with 322 additions and 158 deletions
|
@ -1,11 +1,8 @@
|
|||
"""Provide functionality to STT."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import asdict, dataclass
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -17,10 +14,8 @@ from aiohttp.web_exceptions import (
|
|||
)
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
|
@ -31,159 +26,40 @@ from .const import (
|
|||
AudioSampleRates,
|
||||
SpeechResultState,
|
||||
)
|
||||
from .legacy import (
|
||||
Provider,
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
async_get_provider,
|
||||
async_setup_legacy,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_provider(
|
||||
hass: HomeAssistant, domain: str | None = None
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
if domain:
|
||||
return hass.data[DOMAIN].get(domain)
|
||||
|
||||
if not hass.data[DOMAIN]:
|
||||
return None
|
||||
|
||||
if "cloud" in hass.data[DOMAIN]:
|
||||
return hass.data[DOMAIN]["cloud"]
|
||||
|
||||
return next(iter(hass.data[DOMAIN].values()))
|
||||
__all__ = [
|
||||
"async_get_provider",
|
||||
"AudioBitRates",
|
||||
"AudioChannels",
|
||||
"AudioCodecs",
|
||||
"AudioFormats",
|
||||
"AudioSampleRates",
|
||||
"DOMAIN",
|
||||
"Provider",
|
||||
"SpeechMetadata",
|
||||
"SpeechResult",
|
||||
"SpeechResultState",
|
||||
]
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
providers = hass.data[DOMAIN] = {}
|
||||
platform_setups = async_setup_legacy(hass, config)
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
if platform_setups:
|
||||
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||
|
||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
provider = await platform.async_get_engine(hass, p_config, discovery_info)
|
||||
if provider is None:
|
||||
_LOGGER.error("Error setting up platform %s", p_type)
|
||||
return
|
||||
|
||||
provider.name = p_type
|
||||
provider.hass = hass
|
||||
|
||||
providers[provider.name] = provider
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
setup_tasks = [
|
||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
]
|
||||
|
||||
if setup_tasks:
|
||||
await asyncio.wait(setup_tasks)
|
||||
|
||||
# Add discovery support
|
||||
async def async_platform_discovered(platform, info):
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
|
||||
hass.http.register_view(SpeechToTextView(providers))
|
||||
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechMetadata:
|
||||
"""Metadata of audio stream."""
|
||||
|
||||
language: str
|
||||
format: AudioFormats
|
||||
codec: AudioCodecs
|
||||
bit_rate: AudioBitRates
|
||||
sample_rate: AudioSampleRates
|
||||
channel: AudioChannels
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Finish initializing the metadata."""
|
||||
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||
self.channel = AudioChannels(int(self.channel))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechResult:
|
||||
"""Result of audio Speech."""
|
||||
|
||||
text: str | None
|
||||
result: SpeechResultState
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
"""Represent a single STT provider."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
name: str | None = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_formats(self) -> list[AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_codecs(self) -> list[AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||
"""Return a list of supported bit rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||
"""Return a list of supported sample rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream to STT service.
|
||||
|
||||
Only streaming of content are allow!
|
||||
"""
|
||||
|
||||
@callback
|
||||
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
||||
"""Check if given metadata supported by this provider."""
|
||||
if (
|
||||
metadata.language not in self.supported_languages
|
||||
or metadata.format not in self.supported_formats
|
||||
or metadata.codec not in self.supported_codecs
|
||||
or metadata.bit_rate not in self.supported_bit_rates
|
||||
or metadata.sample_rate not in self.supported_sample_rates
|
||||
or metadata.channel not in self.supported_channels
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SpeechToTextView(HomeAssistantView):
|
||||
"""STT view to generate a text from audio stream."""
|
||||
|
||||
|
@ -203,7 +79,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||
|
||||
# Get metadata
|
||||
try:
|
||||
metadata = metadata_from_header(request)
|
||||
metadata = _metadata_from_header(request)
|
||||
except ValueError as err:
|
||||
raise HTTPBadRequest(text=str(err)) from err
|
||||
|
||||
|
@ -237,7 +113,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||
)
|
||||
|
||||
|
||||
def metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
"""Extract STT metadata from header.
|
||||
|
||||
X-Speech-Content:
|
||||
|
|
169
homeassistant/components/stt/legacy.py
Normal file
169
homeassistant/components/stt/legacy.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
"""Handle legacy speech to text platforms."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Coroutine
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
SpeechResultState,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_provider(
|
||||
hass: HomeAssistant, domain: str | None = None
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
if domain:
|
||||
return hass.data[DOMAIN].get(domain)
|
||||
|
||||
if not hass.data[DOMAIN]:
|
||||
return None
|
||||
|
||||
if "cloud" in hass.data[DOMAIN]:
|
||||
return hass.data[DOMAIN]["cloud"]
|
||||
|
||||
return next(iter(hass.data[DOMAIN].values()))
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_legacy(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> list[Coroutine[Any, Any, None]]:
|
||||
"""Set up legacy speech to text providers."""
|
||||
providers = hass.data[DOMAIN] = {}
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
|
||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
_LOGGER.error("Unknown speech to text platform specified")
|
||||
return
|
||||
|
||||
try:
|
||||
provider = await platform.async_get_engine(hass, p_config, discovery_info)
|
||||
|
||||
provider.name = p_type
|
||||
provider.hass = hass
|
||||
|
||||
providers[provider.name] = provider
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
# Add discovery support
|
||||
async def async_platform_discovered(platform, info):
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
|
||||
return [
|
||||
async_setup_platform(p_type, p_config)
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechMetadata:
|
||||
"""Metadata of audio stream."""
|
||||
|
||||
language: str
|
||||
format: AudioFormats
|
||||
codec: AudioCodecs
|
||||
bit_rate: AudioBitRates
|
||||
sample_rate: AudioSampleRates
|
||||
channel: AudioChannels
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Finish initializing the metadata."""
|
||||
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||
self.channel = AudioChannels(int(self.channel))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechResult:
|
||||
"""Result of audio Speech."""
|
||||
|
||||
text: str | None
|
||||
result: SpeechResultState
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
"""Represent a single STT provider."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
name: str | None = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_formats(self) -> list[AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_codecs(self) -> list[AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||
"""Return a list of supported bit rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||
"""Return a list of supported sample rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream to STT service.
|
||||
|
||||
Only streaming of content are allow!
|
||||
"""
|
||||
|
||||
@callback
|
||||
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
||||
"""Check if given metadata supported by this provider."""
|
||||
if (
|
||||
metadata.language not in self.supported_languages
|
||||
or metadata.format not in self.supported_formats
|
||||
or metadata.codec not in self.supported_codecs
|
||||
or metadata.bit_rate not in self.supported_bit_rates
|
||||
or metadata.sample_rate not in self.supported_sample_rates
|
||||
or metadata.channel not in self.supported_channels
|
||||
):
|
||||
return False
|
||||
return True
|
56
tests/components/stt/common.py
Normal file
56
tests/components/stt/common.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
"""Provide common test tools for STT."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.stt import Provider
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from tests.common import MockPlatform, mock_platform
|
||||
|
||||
|
||||
class MockSTTPlatform(MockPlatform):
|
||||
"""Help to set up test stt service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
async_get_engine: Callable[
|
||||
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
|
||||
Coroutine[Any, Any, Provider | None],
|
||||
]
|
||||
| None = None,
|
||||
get_engine: Callable[
|
||||
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
|
||||
]
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""Return the stt service."""
|
||||
super().__init__()
|
||||
if get_engine:
|
||||
self.get_engine = get_engine
|
||||
if async_get_engine:
|
||||
self.async_get_engine = async_get_engine
|
||||
|
||||
|
||||
def mock_stt_platform(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
integration: str = "stt",
|
||||
async_get_engine: Callable[
|
||||
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
|
||||
Coroutine[Any, Any, Provider | None],
|
||||
]
|
||||
| None = None,
|
||||
get_engine: Callable[
|
||||
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
|
||||
]
|
||||
| None = None,
|
||||
):
|
||||
"""Specialize the mock platform for stt."""
|
||||
loaded_platform = MockSTTPlatform(async_get_engine, get_engine)
|
||||
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||
|
||||
return loaded_platform
|
|
@ -1,7 +1,8 @@
|
|||
"""Test STT component setup."""
|
||||
from collections.abc import AsyncIterable
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -20,7 +21,8 @@ from homeassistant.components.stt import (
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_platform
|
||||
from .common import mock_stt_platform
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
|
@ -31,7 +33,7 @@ class MockProvider(Provider):
|
|||
|
||||
def __init__(self) -> None:
|
||||
"""Init test provider."""
|
||||
self.calls = []
|
||||
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
|
@ -81,10 +83,15 @@ def mock_provider() -> MockProvider:
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||
) -> None:
|
||||
"""Set up a test provider."""
|
||||
mock_platform(
|
||||
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider))
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
"test",
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||
|
||||
|
|
56
tests/components/stt/test_legacy.py
Normal file
56
tests/components/stt/test_legacy.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
"""Test the legacy stt setup."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.stt import Provider
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from .common import mock_stt_platform
|
||||
|
||||
|
||||
async def test_invalid_platform(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test platform setup with an invalid platform."""
|
||||
await async_load_platform(
|
||||
hass,
|
||||
"stt",
|
||||
"bad_stt",
|
||||
{"stt": [{"platform": "bad_stt"}]},
|
||||
hass_config={"stt": [{"platform": "bad_stt"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Unknown speech to text platform specified" in caplog.text
|
||||
|
||||
|
||||
async def test_platform_setup_with_error(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test platform setup with an error during setup."""
|
||||
|
||||
async def async_get_engine(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> Provider:
|
||||
"""Raise exception during platform setup."""
|
||||
raise Exception("Setup error") # pylint: disable=broad-exception-raised
|
||||
|
||||
mock_stt_platform(hass, tmp_path, "bad_stt", async_get_engine=async_get_engine)
|
||||
|
||||
await async_load_platform(
|
||||
hass,
|
||||
"stt",
|
||||
"bad_stt",
|
||||
{},
|
||||
hass_config={"stt": [{"platform": "bad_stt"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Error setting up platform: bad_stt" in caplog.text
|
Loading…
Add table
Reference in a new issue