hass-core/homeassistant/components/cloud/stt.py
Michael Hansen 3e3ece4e56
Add speech to text over binary websocket to pipeline (#90082)
* Allow passing binary to the WS connection

* Expand test coverage

* Test non-existing handler

* Add text to speech and stages to pipeline

* Default to "cloud" TTS when engine is None

* Refactor pipeline request to split text/audio

* Refactor with PipelineRun

* Generate pipeline from language

* Clean up

* Restore TTS code

* Add audio pipeline test

* Clean TTS cache in test

* Clean up tests and pipeline base class

* Stop pylint and pytest magics from fighting

* Include mock_get_cache_files

* Working on STT

* Preparing to test

* First successful test

* Send handler_id

* Allow signaling end of stream using empty payloads

* Store handlers in a list

* Handle binary handlers raising exceptions

* Add stt/tts dependencies to voice_assistant

* Include STT audio in pipeline test

* Working on tests

* Refactoring with stages

* Fix tests

* Add more tests

* Add method docs

* Change stt demo/cloud to AsyncIterable

* Add pipeline error events

* Move handler id to separate message before pipeline

* Add test for invalid stage order

* Change "finish" to "end"

* Use enum

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
2023-03-23 14:44:19 -04:00

112 lines
2.8 KiB
Python

"""Support for the cloud for speech to text service."""
from __future__ import annotations
from collections.abc import AsyncIterable
from hass_nabucasa import Cloud
from hass_nabucasa.voice import VoiceError
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
)
from .const import DOMAIN
# Locale codes accepted by the Nabu Casa cloud STT endpoint. Exposed to
# Home Assistant via CloudProvider.supported_languages below; entries are
# IETF language-region tags (e.g. "en-US").
SUPPORT_LANGUAGES = [
"da-DK",
"de-DE",
"en-AU",
"en-CA",
"en-GB",
"en-US",
"es-ES",
"fi-FI",
"fr-CA",
"fr-FR",
"it-IT",
"ja-JP",
"nl-NL",
"pl-PL",
"pt-PT",
"ru-RU",
"sv-SE",
"th-TH",
"zh-CN",
"zh-HK",
]
async def async_get_engine(hass, config, discovery_info=None):
    """Set up Cloud speech component.

    Returns a CloudProvider wrapping the already-initialized Cloud
    instance stored under the integration's domain in hass.data.
    """
    return CloudProvider(hass.data[DOMAIN])
class CloudProvider(Provider):
    """Speech-to-text provider backed by the Nabu Casa cloud voice API."""

    def __init__(self, cloud: Cloud) -> None:
        """Store the authenticated cloud handle used for STT requests."""
        self.cloud = cloud

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return SUPPORT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bitrates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Stream audio to the cloud STT service and return the transcript.

        Builds a MIME content-type string from the stream metadata, forwards
        the audio chunks to the cloud, and maps any VoiceError to an ERROR
        result instead of raising.
        """
        content_type = (
            f"audio/{metadata.format!s}; codecs=audio/{metadata.codec!s};"
            " samplerate=16000"
        )

        # Forward the audio stream to the cloud STT endpoint.
        try:
            result = await self.cloud.voice.process_stt(
                stream, content_type, metadata.language
            )
        except VoiceError:
            # Cloud-side failure: surface an error result rather than raising.
            return SpeechResult(None, SpeechResultState.ERROR)

        # Translate the cloud response into a Home Assistant speech result.
        state = (
            SpeechResultState.SUCCESS if result.success else SpeechResultState.ERROR
        )
        return SpeechResult(result.text, state)