Speech to Text component (#28434)

* Initial commit for STT

* Fix code review
This commit is contained in:
Pascal Vizeli 2019-11-04 13:10:42 +01:00 committed by GitHub
parent 33c8cba30d
commit 99c0559a0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 436 additions and 1 deletions

View file

@ -280,6 +280,7 @@ homeassistant/components/sql/* @dgomes
homeassistant/components/statistics/* @fabaff
homeassistant/components/stiebel_eltron/* @fucm
homeassistant/components/stream/* @hunterjm
homeassistant/components/stt/* @pvizeli
homeassistant/components/suez_water/* @ooii
homeassistant/components/sun/* @Swamp-Ig
homeassistant/components/supla/* @mwegrzynek

View file

@ -25,6 +25,7 @@ COMPONENTS_WITH_DEMO_PLATFORM = [
"media_player",
"notify",
"sensor",
"stt",
"switch",
"tts",
"mailbox",

View file

@ -0,0 +1,60 @@
"""Support for the demo for speech to text service."""
from typing import List
from aiohttp import StreamReader
from homeassistant.components.stt import Provider, SpeechMetadata, SpeechResult
from homeassistant.components.stt.const import (
AudioBitrates,
AudioFormats,
AudioSamplerates,
AudioCodecs,
SpeechResultState,
)
SUPPORT_LANGUAGES = ["en", "de"]
async def async_get_engine(hass, config):
"""Set up Demo speech component."""
return DemoProvider()
class DemoProvider(Provider):
"""Demo speech API provider."""
@property
def supported_languages(self) -> List[str]:
"""Return a list of supported languages."""
return SUPPORT_LANGUAGES
@property
def supported_formats(self) -> List[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV]
@property
def supported_codecs(self) -> List[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM]
@property
def supported_bitrates(self) -> List[AudioBitrates]:
"""Return a list of supported bitrates."""
return [AudioBitrates.BITRATE_16]
@property
def supported_samplerates(self) -> List[AudioSamplerates]:
"""Return a list of supported samplerates."""
return [AudioSamplerates.SAMPLERATE_16000, AudioSamplerates.SAMPLERATE_44100]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
) -> SpeechResult:
"""Process an audio stream to STT service."""
# Read available data
async for _ in stream.iter_chunked(4096):
pass
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)

View file

@ -1,4 +1,4 @@
"""Support for the demo speech service."""
"""Support for the demo for text to speech service."""
import os
import voluptuous as vol

View file

@ -0,0 +1,217 @@
"""Provide functionality to STT."""
from abc import ABC, abstractmethod
import asyncio
import logging
from typing import Dict, List, Optional
from aiohttp import StreamReader, web
from aiohttp.hdrs import istr
from aiohttp.web_exceptions import (
HTTPNotFound,
HTTPUnsupportedMediaType,
HTTPBadRequest,
)
import attr
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import callback
from homeassistant.helpers import config_per_platform
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DOMAIN,
AudioBitrates,
AudioCodecs,
AudioFormats,
AudioSamplerates,
SpeechResultState,
)
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass: HomeAssistantType, config):
"""Set up STT."""
providers = {}
async def async_setup_platform(p_type, p_config, disc_info=None):
"""Set up a TTS platform."""
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
return
try:
provider = await platform.async_get_engine(hass, p_config)
if provider is None:
_LOGGER.error("Error setting up platform %s", p_type)
return
provider.name = p_type
provider.hass = hass
providers[provider.name] = provider
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
setup_tasks = [
async_setup_platform(p_type, p_config)
for p_type, p_config in config_per_platform(config, DOMAIN)
]
if setup_tasks:
await asyncio.wait(setup_tasks)
hass.http.register_view(SpeechToTextView(providers))
return True
@attr.s
class SpeechMetadata:
"""Metadata of audio stream."""
language: str = attr.ib()
format: AudioFormats = attr.ib()
codec: AudioCodecs = attr.ib()
bitrate: AudioBitrates = attr.ib(converter=int)
samplerate: AudioSamplerates = attr.ib(converter=int)
@attr.s
class SpeechResult:
"""Result of audio Speech."""
text: str = attr.ib()
result: SpeechResultState = attr.ib()
class Provider(ABC):
"""Represent a single STT provider."""
hass: Optional[HomeAssistantType] = None
name: Optional[str] = None
@property
@abstractmethod
def supported_languages(self) -> List[str]:
"""Return a list of supported languages."""
@property
@abstractmethod
def supported_formats(self) -> List[AudioFormats]:
"""Return a list of supported formats."""
@property
@abstractmethod
def supported_codecs(self) -> List[AudioCodecs]:
"""Return a list of supported codecs."""
@property
@abstractmethod
def supported_bitrates(self) -> List[AudioBitrates]:
"""Return a list of supported bitrates."""
@property
@abstractmethod
def supported_samplerates(self) -> List[AudioSamplerates]:
"""Return a list of supported samplerates."""
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
) -> SpeechResult:
"""Process an audio stream to STT service.
Only streaming of content are allow!
"""
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bitrate not in self.supported_bitrates
or metadata.samplerate not in self.supported_samplerates
):
return False
return True
class SpeechToTextView(HomeAssistantView):
"""STT view to generate a text from audio stream."""
requires_auth = True
url = "/api/stt/{provider}"
name = "api:stt:provider"
def __init__(self, providers: Dict[str, Provider]) -> None:
"""Initialize a tts view."""
self.providers = providers
@staticmethod
def _metadata_from_header(request: web.Request) -> Optional[SpeechMetadata]:
"""Extract metadata from header.
X-Speech-Content: format=wav; codec=pcm; samplerate=16000; bitrate=16; language=de_de
"""
try:
data = request.headers[istr("X-Speech-Content")].split(";")
except KeyError:
_LOGGER.warning("Missing X-Speech-Content")
return None
# Convert Header data
args = dict()
for value in data:
value = value.strip()
args[value.partition("=")[0]] = value.partition("=")[2]
try:
return SpeechMetadata(**args)
except TypeError as err:
_LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
return None
async def post(self, request: web.Request, provider: str) -> web.Response:
"""Convert Speech (audio) to text."""
if provider not in self.providers:
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
# Get metadata
metadata = self._metadata_from_header(request)
if not metadata:
raise HTTPBadRequest()
# Check format
if not stt_provider.check_metadata(metadata):
raise HTTPUnsupportedMediaType()
# Process audio stream
result = await stt_provider.async_process_audio_stream(
metadata, request.content
)
# Return result
return self.json(attr.asdict(result))
async def get(self, request: web.Request, provider: str) -> web.Response:
"""Return provider specific audio information."""
if provider not in self.providers:
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
return self.json(
{
"languages": stt_provider.supported_languages,
"formats": stt_provider.supported_formats,
"codecs": stt_provider.supported_codecs,
"samplerates": stt_provider.supported_samplerates,
"bitrates": stt_provider.supported_bitrates,
}
)

View file

@ -0,0 +1,48 @@
"""STT constante."""
from enum import Enum
DOMAIN = "stt"
class AudioCodecs(str, Enum):
"""Supported Audio codecs."""
PCM = "pcm"
OPUS = "opus"
class AudioFormats(str, Enum):
"""Supported Audio formats."""
WAV = "wav"
OGG = "ogg"
class AudioBitrates(int, Enum):
"""Supported Audio bitrates."""
BITRATE_8 = 8
BITRATE_16 = 16
BITRATE_24 = 24
BITRATE_32 = 32
class AudioSamplerates(int, Enum):
"""Supported Audio samplerates."""
SAMPLERATE_8000 = 8000
SAMPLERATE_11000 = 11000
SAMPLERATE_16000 = 16000
SAMPLERATE_18900 = 18900
SAMPLERATE_22000 = 22000
SAMPLERATE_32000 = 32000
SAMPLERATE_37800 = 37800
SAMPLERATE_44100 = 44100
SAMPLERATE_48000 = 48000
class SpeechResultState(str, Enum):
"""Result state of speech."""
SUCCESS = "success"
ERROR = "error"

View file

@ -0,0 +1,8 @@
{
"domain": "stt",
"name": "Stt",
"documentation": "https://www.home-assistant.io/integrations/stt",
"requirements": [],
"dependencies": ["http"],
"codeowners": ["@pvizeli"]
}

View file

@ -0,0 +1 @@
# Describes the format for available STT services

View file

@ -0,0 +1,69 @@
"""The tests for the demo stt component."""
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.components import stt
@pytest.fixture(autouse=True)
def setup_comp(hass):
"""Set up demo component."""
hass.loop.run_until_complete(
async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}})
)
async def test_demo_settings(hass_client):
"""Test retrieve settings from demo provider."""
client = await hass_client()
response = await client.get("/api/stt/demo")
response_data = await response.json()
assert response.status == 200
assert response_data == {
"languages": ["en", "de"],
"bitrates": [16],
"samplerates": [16000, 44100],
"formats": ["wav"],
"codecs": ["pcm"],
}
async def test_demo_speech_no_metadata(hass_client):
"""Test retrieve settings from demo provider."""
client = await hass_client()
response = await client.post("/api/stt/demo", data=b"Test")
assert response.status == 400
async def test_demo_speech_wrong_metadata(hass_client):
"""Test retrieve settings from demo provider."""
client = await hass_client()
response = await client.post(
"/api/stt/demo",
headers={
"X-Speech-Content": "format=wav; codec=pcm; samplerate=8000; bitrate=16; language=de"
},
data=b"Test",
)
assert response.status == 415
async def test_demo_speech(hass_client):
"""Test retrieve settings from demo provider."""
client = await hass_client()
response = await client.post(
"/api/stt/demo",
headers={
"X-Speech-Content": "format=wav; codec=pcm; samplerate=16000; bitrate=16; language=de"
},
data=b"Test",
)
response_data = await response.json()
assert response.status == 200
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}

View file

@ -0,0 +1 @@
"""Speech to text tests."""

View file

@ -0,0 +1,29 @@
"""Test STT component setup."""
from homeassistant.setup import async_setup_component
from homeassistant.components import stt
async def test_setup_comp(hass):
"""Set up demo component."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
async def test_demo_settings_not_exists(hass, hass_client):
"""Test retrieve settings from demo provider."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
client = await hass_client()
response = await client.get("/api/stt/beer")
assert response.status == 404
async def test_demo_speech_not_exists(hass, hass_client):
"""Test retrieve settings from demo provider."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
client = await hass_client()
response = await client.post("/api/stt/beer", data=b"test")
assert response.status == 404