hass-core/homeassistant/components/stt/__init__.py
2023-03-15 15:13:10 -04:00

257 lines
7.6 KiB
Python

"""Provide functionality to STT."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
import logging
from typing import Any
from aiohttp import StreamReader, web
from aiohttp.hdrs import istr
from aiohttp.web_exceptions import (
HTTPBadRequest,
HTTPNotFound,
HTTPUnsupportedMediaType,
)
from homeassistant.components.http import HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import engine_component
from homeassistant.helpers.typing import ConfigType
from .const import (
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechResultState,
)
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_provider(
    hass: HomeAssistant, provider: str | None = None
) -> Provider | None:
    """Look up an STT provider.

    With an explicit ``provider`` the engine registered under that name is
    requested from the STT component. Without one, the first engine the
    component reports is used. ``None`` is returned when no STT component is
    stored in ``hass.data`` or when the component has no engines.
    """
    component: engine_component.EngineComponent[Provider] | None = hass.data.get(DOMAIN)
    if component is None:
        return None

    # An explicit name always wins over the "first available" fallback.
    if provider is not None:
        return component.async_get_engine(provider)

    engines = component.async_get_engines()
    if not engines:
        return None
    return engines[0]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the STT integration.

    Creates the engine component, starts its discovery handling, stores it in
    ``hass.data`` under the STT domain and registers the HTTP view that
    exposes the engines.
    """
    component: engine_component.EngineComponent[Provider] = (
        engine_component.EngineComponent(_LOGGER, DOMAIN, hass, config)
    )
    component.async_setup_discovery()

    hass.data[DOMAIN] = component
    view = SpeechToTextView(component)
    hass.http.register_view(view)
    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a config entry by delegating to the STT engine component."""
    stt_component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
    was_set_up = await stt_component.async_setup_entry(entry)
    return was_set_up
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry by delegating to the STT engine component."""
    stt_component: engine_component.EngineComponent[Provider] = hass.data[DOMAIN]
    was_unloaded = await stt_component.async_unload_entry(entry)
    return was_unloaded
@dataclass
class SpeechMetadata:
    """Describe the audio stream that is handed to an STT provider."""

    # The field order defines the positional order of the generated __init__.
    language: str
    format: AudioFormats
    codec: AudioCodecs
    bit_rate: AudioBitRates
    sample_rate: AudioSampleRates
    channel: AudioChannels

    def __post_init__(self) -> None:
        """Coerce the numeric fields into their enum types.

        ``bit_rate``, ``sample_rate`` and ``channel`` are passed through
        ``int()`` and then through the matching enum constructor, in that
        order. ``language``, ``format`` and ``codec`` are kept as given.
        """
        bit_rate_value = int(self.bit_rate)
        self.bit_rate = AudioBitRates(bit_rate_value)

        sample_rate_value = int(self.sample_rate)
        self.sample_rate = AudioSampleRates(sample_rate_value)

        channel_value = int(self.channel)
        self.channel = AudioChannels(channel_value)
@dataclass
class SpeechResult:
    """Outcome of processing an audio stream with an STT provider."""

    # Recognized text; the annotation allows None.
    text: str | None
    # State reported for the processing run.
    result: SpeechResultState
class Provider(engine_component.Engine, ABC):
    """Abstract base for a single speech-to-text provider.

    Subclasses declare which audio parameters they accept through the
    ``supported_*`` properties and implement the actual recognition in
    ``async_process_audio_stream``.
    """

    hass: HomeAssistant | None = None
    name: str | None = None

    @property
    @abstractmethod
    def supported_languages(self) -> list[str]:
        """Return the languages this provider supports."""

    @property
    @abstractmethod
    def supported_formats(self) -> list[AudioFormats]:
        """Return the audio formats this provider supports."""

    @property
    @abstractmethod
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return the audio codecs this provider supports."""

    @property
    @abstractmethod
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return the bit rates this provider supports."""

    @property
    @abstractmethod
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return the sample rates this provider supports."""

    @property
    @abstractmethod
    def supported_channels(self) -> list[AudioChannels]:
        """Return the channel layouts this provider supports."""

    @abstractmethod
    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: StreamReader
    ) -> SpeechResult:
        """Send an audio stream to the STT service and return its result.

        Only streaming of content is allowed.
        """

    @callback
    def check_metadata(self, metadata: SpeechMetadata) -> bool:
        """Return True when every metadata field is supported by this provider."""
        # Checks short-circuit on the first unsupported field, in this order.
        return (
            metadata.language in self.supported_languages
            and metadata.format in self.supported_formats
            and metadata.codec in self.supported_codecs
            and metadata.bit_rate in self.supported_bit_rates
            and metadata.sample_rate in self.supported_sample_rates
            and metadata.channel in self.supported_channels
        )
class SpeechToTextView(HomeAssistantView):
    """HTTP view that turns an audio stream into text via an STT provider."""

    requires_auth = True
    url = "/api/stt/{provider}"
    name = "api:stt:provider"

    def __init__(self, engines: engine_component.EngineComponent[Provider]) -> None:
        """Initialize the STT view with the engine component it serves."""
        self.engines = engines

    async def post(self, request: web.Request, provider: str) -> web.Response:
        """Convert speech (audio) from the request body to text.

        Raises HTTPNotFound for an unknown provider, HTTPBadRequest when the
        metadata header cannot be parsed and HTTPUnsupportedMediaType when the
        provider does not support the given metadata.
        """
        engine = self.engines.async_get_engine(provider)
        if engine is None:
            raise HTTPNotFound()

        # Parse the stream description sent by the client.
        try:
            stream_metadata = metadata_from_header(request)
        except ValueError as err:
            raise HTTPBadRequest(text=str(err)) from err

        # Reject audio parameters the provider cannot handle.
        if not engine.check_metadata(stream_metadata):
            raise HTTPUnsupportedMediaType()

        # Hand the request body to the provider as a stream.
        speech_result = await engine.async_process_audio_stream(
            stream_metadata, request.content
        )
        return self.json(asdict(speech_result))

    async def get(self, request: web.Request, provider: str) -> web.Response:
        """Return the audio parameters supported by the given provider.

        Raises HTTPNotFound for an unknown provider.
        """
        engine = self.engines.async_get_engine(provider)
        if engine is None:
            raise HTTPNotFound()

        capabilities = {
            "languages": engine.supported_languages,
            "formats": engine.supported_formats,
            "codecs": engine.supported_codecs,
            "sample_rates": engine.supported_sample_rates,
            "bit_rates": engine.supported_bit_rates,
            "channels": engine.supported_channels,
        }
        return self.json(capabilities)
def metadata_from_header(request: web.Request) -> SpeechMetadata:
    """Extract STT metadata from the X-Speech-Content request header.

    The header is a ``;``-separated list of ``key=value`` pairs, for example:

    X-Speech-Content:
        format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de

    Empty segments, such as the one produced by a trailing ``;``, are ignored.

    Raises:
        ValueError: If the header is missing, contains an unknown field, lacks
            one of the required fields, or carries a value that
            ``SpeechMetadata`` rejects.
    """
    try:
        data = request.headers[istr("X-Speech-Content")].split(";")
    except KeyError as err:
        raise ValueError("Missing X-Speech-Content header") from err

    fields = (
        "language",
        "format",
        "codec",
        "bit_rate",
        "sample_rate",
        "channel",
    )

    # Convert header data into keyword arguments.
    args: dict[str, Any] = {}
    for entry in data:
        entry = entry.strip()
        if not entry:
            # A trailing ";" or ";;" yields an empty segment; it carries no
            # field and must not be reported as an invalid one.
            continue
        key, _, value = entry.partition("=")
        if key not in fields:
            raise ValueError(f"Invalid field {key}")
        args[key] = value

    for field in fields:
        if field not in args:
            raise ValueError(f"Missing {field} in X-Speech-Content header")

    try:
        return SpeechMetadata(
            language=args["language"],
            format=args["format"],
            codec=args["codec"],
            bit_rate=args["bit_rate"],
            sample_rate=args["sample_rate"],
            channel=args["channel"],
        )
    except (TypeError, ValueError) as err:
        # int() and the enum constructors used by SpeechMetadata signal bad
        # values with ValueError, so both types are wrapped to give every
        # malformed value the same error prefix.
        raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err