Add speech to text over binary websocket to pipeline (#90082)

* Allow passing binary to the WS connection

* Expand test coverage

* Test non-existing handler

* Add text to speech and stages to pipeline

* Default to "cloud" TTS when engine is None

* Refactor pipeline request to split text/audio

* Refactor with PipelineRun

* Generate pipeline from language

* Clean up

* Restore TTS code

* Add audio pipeline test

* Clean TTS cache in test

* Clean up tests and pipeline base class

* Stop pylint and pytest magics from fighting

* Include mock_get_cache_files

* Working on STT

* Preparing to test

* First successful test

* Send handler_id

* Allow signaling end of stream using empty payloads

* Store handlers in a list

* Handle binary handlers raising exceptions

* Add stt/tts dependencies to voice_assistant

* Include STT audio in pipeline test

* Working on tests

* Refactoring with stages

* Fix tests

* Add more tests

* Add method docs

* Change stt demo/cloud to AsyncIterable

* Add pipeline error events

* Move handler id to separate message before pipeline

* Add test for invalid stage order

* Change "finish" to "end"

* Use enum

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-03-23 13:44:19 -05:00 committed by GitHub
parent 185d6d74d7
commit 3e3ece4e56
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 860 additions and 234 deletions

View file

@ -1,7 +1,8 @@
"""Support for the cloud for speech to text service."""
from __future__ import annotations
from aiohttp import StreamReader
from collections.abc import AsyncIterable
from hass_nabucasa import Cloud
from hass_nabucasa.voice import VoiceError
@ -88,7 +89,7 @@ class CloudProvider(Provider):
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service."""
content = (

View file

@ -1,7 +1,7 @@
"""Support for the demo for speech to text service."""
from __future__ import annotations
from aiohttp import StreamReader
from collections.abc import AsyncIterable
from homeassistant.components.stt import (
AudioBitRates,
@ -63,12 +63,12 @@ class DemoProvider(Provider):
return [AudioChannels.CHANNEL_STEREO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service."""
# Read available data
async for _ in stream.iter_chunked(4096):
async for _ in stream:
pass
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)

View file

@ -3,11 +3,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import AsyncIterable
from dataclasses import asdict, dataclass
import logging
from typing import Any
from aiohttp import StreamReader, web
from aiohttp import web
from aiohttp.hdrs import istr
from aiohttp.web_exceptions import (
HTTPBadRequest,
@ -153,7 +154,7 @@ class Provider(ABC):
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service.

View file

@ -2,7 +2,7 @@
"domain": "voice_assistant",
"name": "Voice Assistant",
"codeowners": ["@balloob", "@synesthesiam"],
"dependencies": ["conversation"],
"dependencies": ["conversation", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
"iot_class": "local_push",
"quality_scale": "internal"

View file

@ -1,33 +1,80 @@
"""Classes for voice assistant pipelines."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from dataclasses import dataclass, field
from collections.abc import AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
import logging
from typing import Any
from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation
from homeassistant.components.media_source import async_resolve_media
from homeassistant.components import conversation, media_source, stt
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.util.dt import utcnow
from .const import DOMAIN
DEFAULT_TIMEOUT = 30 # seconds
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_pipeline(
    hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
) -> Pipeline | None:
    """Return the pipeline stored under an id, or build one for a language.

    With a ``pipeline_id`` the result is whatever ``hass.data[DOMAIN]`` holds
    under that id, or ``None`` when there is no such entry. Without one, a
    new pipeline is built for ``language``, falling back to the language
    configured on ``hass`` when ``language`` is empty or ``None``.
    """
    if pipeline_id is None:
        # No stored pipeline requested: construct one for the
        # required/configured language.
        effective_language = language or hass.config.language
        return Pipeline(
            name=effective_language,
            language=effective_language,
            stt_engine=None,  # first engine
            conversation_engine=None,  # first agent
            tts_engine=None,  # first engine
        )

    return hass.data[DOMAIN].get(pipeline_id)
class PipelineError(Exception):
    """Root of the pipeline error hierarchy.

    Carries a machine-readable ``code`` alongside a human-readable
    ``message``; both are also folded into the exception text.
    """

    def __init__(self, code: str, message: str) -> None:
        """Record the error code and message."""
        super().__init__(f"Pipeline error code={code}, message={message}")
        self.code = code
        self.message = message


class SpeechToTextError(PipelineError):
    """Raised when the speech to text stage of a pipeline fails."""


class IntentRecognitionError(PipelineError):
    """Raised when the intent recognition stage of a pipeline fails."""


class TextToSpeechError(PipelineError):
    """Raised when the text to speech stage of a pipeline fails."""
class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run."""
RUN_START = "run-start"
RUN_FINISH = "run-finish"
RUN_END = "run-end"
STT_START = "stt-start"
STT_END = "stt-end"
INTENT_START = "intent-start"
INTENT_FINISH = "intent-finish"
INTENT_END = "intent-end"
TTS_START = "tts-start"
TTS_FINISH = "tts-finish"
TTS_END = "tts-end"
ERROR = "error"
@ -54,10 +101,44 @@ class Pipeline:
name: str
language: str | None
stt_engine: str | None
conversation_engine: str | None
tts_engine: str | None
class PipelineStage(StrEnum):
    """Stages of a pipeline."""

    # Speech to text
    STT = "stt"
    # Intent recognition
    INTENT = "intent"
    # Text to speech
    TTS = "tts"


# Stages in execution order (stt -> intent -> tts). PipelineRun compares
# positions in this list to reject an end stage that precedes the start stage.
PIPELINE_STAGE_ORDER = [
    PipelineStage.STT,
    PipelineStage.INTENT,
    PipelineStage.TTS,
]
class PipelineRunValidationError(Exception):
    """Raised when the parameters of a pipeline run are not valid."""


class InvalidPipelineStagesError(PipelineRunValidationError):
    """Raised when the start/end stage pairing of a run is not allowed."""

    def __init__(
        self,
        start_stage: PipelineStage,
        end_stage: PipelineStage,
    ) -> None:
        """Build the error text from the offending stages."""
        description = (
            f"Invalid stage combination: start={start_stage}, end={end_stage}"
        )
        super().__init__(description)
@dataclass
class PipelineRun:
"""Running context for a pipeline."""
@ -65,6 +146,8 @@ class PipelineRun:
hass: HomeAssistant
context: Context
pipeline: Pipeline
start_stage: PipelineStage
end_stage: PipelineStage
event_callback: Callable[[PipelineEvent], None]
language: str = None # type: ignore[assignment]
@ -72,6 +155,12 @@ class PipelineRun:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
# stt -> intent -> tts
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
self.start_stage
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
def start(self):
"""Emit run start event."""
self.event_callback(
@ -84,18 +173,86 @@ class PipelineRun:
)
)
def finish(self):
"""Emit run finish event."""
def end(self):
"""Emit run end event."""
self.event_callback(
PipelineEvent(
PipelineEventType.RUN_FINISH,
PipelineEventType.RUN_END,
)
)
async def speech_to_text(
    self,
    metadata: stt.SpeechMetadata,
    stream: AsyncIterable[bytes],
) -> str:
    """Run speech to text portion of pipeline. Returns the spoken text.

    Emits an STT_START event before any work is done and an STT_END event
    carrying the transcript on success.

    Raises:
        SpeechToTextError: with code "stt-provider-missing" when no provider
            can be loaded, or "stt-stream-failed" when transcription raises
            or does not produce a successful result with text. An ERROR
            event with the same code/message is emitted before raising.
    """
    engine = self.pipeline.stt_engine or "default"
    self.event_callback(
        PipelineEvent(
            PipelineEventType.STT_START,
            {
                "engine": engine,
                "metadata": asdict(metadata),
            },
        )
    )

    try:
        # Load provider
        stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
        if stt_provider is None:
            # Explicit check rather than `assert`: assertions are removed
            # under "python -O", which would let a missing provider through
            # and misreport it later as a stream failure.
            raise LookupError(f"No speech to text provider for: {engine}")
    except Exception as src_error:
        stt_error = SpeechToTextError(
            code="stt-provider-missing",
            message=f"No speech to text provider for: {engine}",
        )
        _LOGGER.exception(stt_error.message)
        self.event_callback(
            PipelineEvent(
                PipelineEventType.ERROR,
                {"code": stt_error.code, "message": stt_error.message},
            )
        )
        raise stt_error from src_error

    try:
        # Transcribe audio stream
        result = await stt_provider.async_process_audio_stream(metadata, stream)
        if (result.text is None) or (
            result.result != stt.SpeechResultState.SUCCESS
        ):
            # Explicit check rather than `assert` so an unsuccessful result
            # is still rejected when assertions are disabled.
            raise RuntimeError(
                f"Speech to text did not succeed: result={result.result}"
            )
    except Exception as src_error:
        stt_error = SpeechToTextError(
            code="stt-stream-failed",
            message="Unexpected error during speech to text",
        )
        _LOGGER.exception(stt_error.message)
        self.event_callback(
            PipelineEvent(
                PipelineEventType.ERROR,
                {"code": stt_error.code, "message": stt_error.message},
            )
        )
        raise stt_error from src_error

    self.event_callback(
        PipelineEvent(
            PipelineEventType.STT_END,
            {
                "stt_output": {
                    "text": result.text,
                }
            },
        )
    )

    return result.text
async def recognize_intent(
self, intent_input: str, conversation_id: str | None
) -> conversation.ConversationResult:
"""Run intent recognition portion of pipeline."""
) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak."""
self.event_callback(
PipelineEvent(
PipelineEventType.INTENT_START,
@ -106,23 +263,39 @@ class PipelineRun:
)
)
conversation_result = await conversation.async_converse(
hass=self.hass,
text=intent_input,
conversation_id=conversation_id,
context=self.context,
language=self.language,
agent_id=self.pipeline.conversation_engine,
)
try:
conversation_result = await conversation.async_converse(
hass=self.hass,
text=intent_input,
conversation_id=conversation_id,
context=self.context,
language=self.language,
agent_id=self.pipeline.conversation_engine,
)
except Exception as src_error:
intent_error = IntentRecognitionError(
code="intent-failed",
message="Unexpected error during intent recognition",
)
_LOGGER.exception(intent_error.message)
self.event_callback(
PipelineEvent(
PipelineEventType.ERROR,
{"code": intent_error.code, "message": intent_error.message},
)
)
raise intent_error from src_error
self.event_callback(
PipelineEvent(
PipelineEventType.INTENT_FINISH,
PipelineEventType.INTENT_END,
{"intent_output": conversation_result.as_dict()},
)
)
return conversation_result
speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
return speech
async def text_to_speech(self, tts_input: str) -> str:
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
@ -136,29 +309,57 @@ class PipelineRun:
)
)
tts_media = await async_resolve_media(
self.hass,
tts_generate_media_source_id(
try:
# Synthesize audio and get URL
tts_media = await media_source.async_resolve_media(
self.hass,
tts_input,
engine=self.pipeline.tts_engine,
),
)
tts_url = tts_media.url
tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.pipeline.tts_engine,
),
)
except Exception as src_error:
tts_error = TextToSpeechError(
code="tts-failed",
message="Unexpected error during text to speech",
)
_LOGGER.exception(tts_error.message)
self.event_callback(
PipelineEvent(
PipelineEventType.ERROR,
{"code": tts_error.code, "message": tts_error.message},
)
)
raise tts_error from src_error
self.event_callback(
PipelineEvent(
PipelineEventType.TTS_FINISH,
{"tts_output": tts_url},
PipelineEventType.TTS_END,
{"tts_output": asdict(tts_media)},
)
)
return tts_url
return tts_media.url
@dataclass
class PipelineRequest(ABC):
"""Request to for a pipeline run."""
class PipelineInput:
"""Input to a pipeline run."""
stt_metadata: stt.SpeechMetadata | None = None
"""Metadata of stt input audio. Required when start_stage = stt."""
stt_stream: AsyncIterable[bytes] | None = None
"""Input audio for stt. Required when start_stage = stt."""
intent_input: str | None = None
"""Input for conversation agent. Required when start_stage = intent."""
tts_input: str | None = None
"""Input for text to speech. Required when start_stage = tts."""
conversation_id: str | None = None
async def execute(
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
@ -169,47 +370,60 @@ class PipelineRequest(ABC):
timeout=timeout,
)
@abstractmethod
async def _execute(self, run: PipelineRun):
"""Run pipeline with request info and context."""
self._validate(run.start_stage)
@dataclass
class TextPipelineRequest(PipelineRequest):
"""Request to run the text portion only of a pipeline."""
intent_input: str
conversation_id: str | None = None
async def _execute(
self,
run: PipelineRun,
):
# stt -> intent -> tts
run.start()
await run.recognize_intent(self.intent_input, self.conversation_id)
run.finish()
current_stage = run.start_stage
# Speech to text
intent_input = self.intent_input
if current_stage == PipelineStage.STT:
assert self.stt_metadata is not None
assert self.stt_stream is not None
intent_input = await run.speech_to_text(
self.stt_metadata,
self.stt_stream,
)
current_stage = PipelineStage.INTENT
@dataclass
class AudioPipelineRequest(PipelineRequest):
"""Request to full pipeline from audio input (stt) to audio output (tts)."""
if run.end_stage != PipelineStage.STT:
tts_input = self.tts_input
intent_input: str # this will be changed to stt audio
conversation_id: str | None = None
if current_stage == PipelineStage.INTENT:
assert intent_input is not None
tts_input = await run.recognize_intent(
intent_input, self.conversation_id
)
current_stage = PipelineStage.TTS
async def _execute(self, run: PipelineRun):
run.start()
if run.end_stage != PipelineStage.INTENT:
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await run.text_to_speech(tts_input)
# stt will go here
run.end()
conversation_result = await run.recognize_intent(
self.intent_input, self.conversation_id
)
def _validate(self, stage: PipelineStage):
"""Validate pipeline input against start stage."""
if stage == PipelineStage.STT:
if self.stt_metadata is None:
raise PipelineRunValidationError(
"stt_metadata is required for speech to text"
)
tts_input = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
await run.text_to_speech(tts_input)
run.finish()
if self.stt_stream is None:
raise PipelineRunValidationError(
"stt_stream is required for speech to text"
)
elif stage == PipelineStage.INTENT:
if self.intent_input is None:
raise PipelineRunValidationError(
"intent_input is required for intent recognition"
)
elif stage == PipelineStage.TTS:
if self.tts_input is None:
raise PipelineRunValidationError(
"tts_input is required for text to speech"
)

View file

@ -1,13 +1,24 @@
"""Voice Assistant Websocket API."""
import asyncio
from collections.abc import Callable
import logging
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components import stt, websocket_api
from homeassistant.core import HomeAssistant, callback
from .const import DOMAIN
from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest
from .pipeline import (
DEFAULT_TIMEOUT,
PipelineError,
PipelineInput,
PipelineRun,
PipelineStage,
async_get_pipeline,
)
_LOGGER = logging.getLogger(__name__)
@callback
@ -19,9 +30,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
@websocket_api.websocket_command(
{
vol.Required("type"): "voice_assistant/run",
# pylint: disable-next=unnecessary-lambda
vol.Required("start_stage"): lambda val: PipelineStage(val),
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): {"text": str},
vol.Optional("language"): str,
vol.Optional("pipeline"): str,
vol.Required("intent_input"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
}
@ -33,39 +48,74 @@ async def websocket_run(
msg: dict[str, Any],
) -> None:
"""Run a pipeline."""
language = msg.get("language", hass.config.language)
pipeline_id = msg.get("pipeline")
if pipeline_id is not None:
pipeline = hass.data[DOMAIN].get(pipeline_id)
if pipeline is None:
connection.send_error(
msg["id"],
"pipeline_not_found",
f"Pipeline not found: {pipeline_id}",
)
return
pipeline = async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,
)
if pipeline is None:
connection.send_error(
msg["id"],
"pipeline-not-found",
f"Pipeline not found: id={pipeline_id}, language={language}",
)
return
else:
# Construct a pipeline for the required/configured language
language = msg.get("language", hass.config.language)
pipeline = Pipeline(
name=language,
language=language,
conversation_engine=None,
tts_engine=None,
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
start_stage = PipelineStage(msg["start_stage"])
end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None
unregister_handler: Callable[[], None] | None = None
# Arguments to PipelineInput
input_args: dict[str, Any] = {
"conversation_id": msg.get("conversation_id"),
}
if start_stage == PipelineStage.STT:
# Audio pipeline that will receive audio as binary websocket messages
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
async def stt_stream():
# Yield until we receive an empty chunk
while chunk := await audio_queue.get():
yield chunk
def handle_binary(_hass, _connection, data: bytes):
# Forward to STT audio stream
audio_queue.put_nowait(data)
handler_id, unregister_handler = connection.async_register_binary_handler(
handle_binary
)
# Run pipeline with a timeout.
# Events are sent over the websocket connection.
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
input_args["stt_metadata"] = stt.SpeechMetadata(
language=language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
input_args["stt_stream"] = stt_stream()
elif start_stage == PipelineStage.INTENT:
# Input to conversation agent
input_args["intent_input"] = msg["input"]["text"]
elif start_stage == PipelineStage.TTS:
# Input to text to speech system
input_args["tts_input"] = msg["input"]["text"]
run_task = hass.async_create_task(
TextPipelineRequest(
intent_input=msg["intent_input"],
conversation_id=msg.get("conversation_id"),
).execute(
PipelineInput(**input_args).execute(
PipelineRun(
hass,
connection.context(msg),
pipeline,
context=connection.context(msg),
pipeline=pipeline,
start_stage=start_stage,
end_stage=end_stage,
event_callback=lambda event: connection.send_event(
msg["id"], event.as_dict()
),
@ -77,7 +127,20 @@ async def websocket_run(
# Cancel pipeline if user unsubscribes
connection.subscriptions[msg["id"]] = run_task.cancel
# Confirm subscription
connection.send_result(msg["id"])
# Task contains a timeout
await run_task
if handler_id is not None:
# Send handler id to client
connection.send_event(msg["id"], {"handler_id": handler_id})
try:
# Task contains a timeout
await run_task
except PipelineError as error:
# Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message)
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()

View file

@ -1,5 +1,5 @@
"""Test STT component setup."""
from asyncio import StreamReader
from collections.abc import AsyncIterable
from http import HTTPStatus
from unittest.mock import AsyncMock, Mock
@ -64,7 +64,7 @@ class MockProvider(Provider):
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: StreamReader
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))

View file

@ -1,110 +0,0 @@
"""Pipeline tests for Voice Assistant integration."""
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components.voice_assistant.pipeline import (
AudioPipelineRequest,
Pipeline,
PipelineEventType,
PipelineRun,
)
from homeassistant.core import Context
from homeassistant.setup import async_setup_component
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
@pytest.fixture(autouse=True)
async def init_components(hass):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "voice_assistant", {})
@pytest.fixture
async def mock_get_tts_audio(hass):
"""Set up media source."""
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(
hass,
"tts",
{
"tts": {
"platform": "demo",
}
},
)
with patch(
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
return_value=("mp3", b""),
) as mock_get_tts:
yield mock_get_tts
async def test_audio_pipeline(hass, mock_get_tts_audio):
"""Run audio pipeline with mock TTS."""
pipeline = Pipeline(
name="test",
language=hass.config.language,
conversation_engine=None,
tts_engine=None,
)
event_callback = MagicMock()
await AudioPipelineRequest(intent_input="Are the lights on?").execute(
PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
event_callback=event_callback,
language=hass.config.language,
)
)
calls = event_callback.mock_calls
assert calls[0].args[0].type == PipelineEventType.RUN_START
assert calls[0].args[0].data == {
"pipeline": "test",
"language": hass.config.language,
}
assert calls[1].args[0].type == PipelineEventType.INTENT_START
assert calls[1].args[0].data == {
"engine": "default",
"intent_input": "Are the lights on?",
}
assert calls[2].args[0].type == PipelineEventType.INTENT_FINISH
assert calls[2].args[0].data == {
"intent_output": {
"conversation_id": None,
"response": {
"card": {},
"data": {"code": "no_intent_match"},
"language": hass.config.language,
"response_type": "error",
"speech": {
"plain": {
"extra_data": None,
"speech": "Sorry, I couldn't understand that",
}
},
},
}
}
assert calls[3].args[0].type == PipelineEventType.TTS_START
assert calls[3].args[0].data == {
"engine": "default",
"tts_input": "Sorry, I couldn't understand that",
}
assert calls[4].args[0].type == PipelineEventType.TTS_FINISH
assert (
calls[4].args[0].data["tts_output"]
== f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3"
)
assert calls[5].args[0].type == PipelineEventType.RUN_FINISH

View file

@ -1,20 +1,94 @@
"""Websocket tests for Voice Assistant integration."""
import asyncio
from unittest.mock import patch
from collections.abc import AsyncIterable
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components import stt
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
from tests.typing import WebSocketGenerator
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
    """Mock STT provider that always reports a fixed transcript."""

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Init test provider.

        ``text`` is the transcript returned for every audio stream.
        """
        self.hass = hass
        self.text = text

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages (only the configured one)."""
        return [self.hass.config.language]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return a list of supported formats."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return a list of supported codecs."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return a list of supported bitrates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return a list of supported channels."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream.

        The stream is never read; a successful result with the configured
        text is returned immediately.
        """
        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
@pytest.fixture(autouse=True)
async def init_components(hass):
    """Initialize relevant components with empty configs.

    Sets up media_source, the demo TTS platform, stt and voice_assistant,
    registers a MockSttProvider, and yields the mock that replaces
    DemoProvider.get_tts_audio for the duration of the test.
    """
    assert await async_setup_component(hass, "media_source", {})
    assert await async_setup_component(
        hass,
        "tts",
        {
            "tts": {
                "platform": "demo",
            }
        },
    )
    assert await async_setup_component(hass, "stt", {})

    # mock_platform fails because it can't import
    # so the provider is written into hass.data directly, after stt is set up.
    hass.data[stt.DOMAIN] = {"test": MockSttProvider(hass, _TRANSCRIPT)}

    assert await async_setup_component(hass, "voice_assistant", {})

    # Stub out audio synthesis: the demo provider returns an empty mp3 payload.
    with patch(
        "homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
        return_value=("mp3", b""),
    ) as mock_get_tts:
        yield mock_get_tts
async def test_text_only_pipeline(
hass: HomeAssistant,
@ -27,7 +101,9 @@ async def test_text_only_pipeline(
{
"id": 5,
"type": "voice_assistant/run",
"intent_input": "Are the lights on?",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
}
)
@ -52,7 +128,7 @@ async def test_text_only_pipeline(
}
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-finish"
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == {
"intent_output": {
"response": {
@ -71,13 +147,120 @@ async def test_text_only_pipeline(
}
}
# run finish
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-finish"
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == {}
async def test_conversation_timeout(
async def test_audio_pipeline(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test events from a pipeline run with audio input/output.

    Runs stt -> intent -> tts and checks every event in the order it is
    received over the websocket.
    """
    client = await hass_ws_client(hass)

    await client.send_json(
        {
            "id": 5,
            "type": "voice_assistant/run",
            "start_stage": "stt",
            "end_stage": "tts",
        }
    )

    # result
    msg = await client.receive_json()
    assert msg["success"]

    # handler id
    msg = await client.receive_json()
    assert msg["event"]["handler_id"] == 1

    # run start
    msg = await client.receive_json()
    assert msg["event"]["type"] == "run-start"
    assert msg["event"]["data"] == {
        "pipeline": hass.config.language,
        "language": hass.config.language,
    }

    # stt
    msg = await client.receive_json()
    assert msg["event"]["type"] == "stt-start"
    assert msg["event"]["data"] == {
        "engine": "default",
        "metadata": {
            "bit_rate": 16,
            "channel": 1,
            "codec": "pcm",
            "format": "wav",
            "language": "en",
            "sample_rate": 16000,
        },
    }

    # End of audio stream (handler id + empty payload)
    # NOTE(review): b"1" is the single ASCII byte 0x31 (49), not the integer
    # handler id 1 asserted above. MockSttProvider never reads the stream, so
    # this test passes either way -- confirm the intended framing
    # (presumably bytes([1])).
    await client.send_bytes(b"1")

    msg = await client.receive_json()
    assert msg["event"]["type"] == "stt-end"
    assert msg["event"]["data"] == {
        "stt_output": {"text": _TRANSCRIPT},
    }

    # intent
    msg = await client.receive_json()
    assert msg["event"]["type"] == "intent-start"
    assert msg["event"]["data"] == {
        "engine": "default",
        "intent_input": _TRANSCRIPT,
    }

    msg = await client.receive_json()
    assert msg["event"]["type"] == "intent-end"
    assert msg["event"]["data"] == {
        "intent_output": {
            "response": {
                "speech": {
                    "plain": {
                        "speech": "Sorry, I couldn't understand that",
                        "extra_data": None,
                    }
                },
                "card": {},
                "language": "en",
                "response_type": "error",
                "data": {"code": "no_intent_match"},
            },
            "conversation_id": None,
        }
    }

    # text to speech
    msg = await client.receive_json()
    assert msg["event"]["type"] == "tts-start"
    assert msg["event"]["data"] == {
        "engine": "default",
        "tts_input": "Sorry, I couldn't understand that",
    }

    msg = await client.receive_json()
    assert msg["event"]["type"] == "tts-end"
    assert msg["event"]["data"] == {
        "tts_output": {
            "url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3",
            "mime_type": "audio/mpeg",
        },
    }

    # run end
    msg = await client.receive_json()
    assert msg["event"]["type"] == "run-end"
    assert msg["event"]["data"] == {}
async def test_intent_timeout(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test partial pipeline run with conversation agent timeout."""
@ -94,7 +277,9 @@ async def test_conversation_timeout(
{
"id": 5,
"type": "voice_assistant/run",
"intent_input": "Are the lights on?",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
"timeout": 0.00001,
}
)
@ -125,24 +310,26 @@ async def test_conversation_timeout(
assert msg["error"]["code"] == "timeout"
async def test_pipeline_timeout(
async def test_text_pipeline_timeout(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test pipeline run with immediate timeout."""
"""Test text-only pipeline run with immediate timeout."""
client = await hass_ws_client(hass)
async def sleepy_run(*args, **kwargs):
await asyncio.sleep(3600)
with patch(
"homeassistant.components.voice_assistant.pipeline.TextPipelineRequest._execute",
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
new=sleepy_run,
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"intent_input": "Are the lights on?",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
"timeout": 0.0001,
}
)
@ -155,3 +342,273 @@ async def test_pipeline_timeout(
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "timeout"
async def test_intent_failed(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test text-only pipeline run with conversation agent error.

    The conversation agent is patched to raise, and the run is expected to
    emit run-start, intent-start and then an error event with code
    "intent-failed".
    """
    client = await hass_ws_client(hass)

    # side_effect makes the mock raise RuntimeError when called.
    # return_value=RuntimeError would only hand back the exception class, and
    # the test would pass by accident because awaiting a class raises TypeError.
    with patch(
        "homeassistant.components.conversation.async_converse",
        new=MagicMock(side_effect=RuntimeError),
    ):
        await client.send_json(
            {
                "id": 5,
                "type": "voice_assistant/run",
                "start_stage": "intent",
                "end_stage": "intent",
                "input": {"text": "Are the lights on?"},
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == {
            "pipeline": hass.config.language,
            "language": hass.config.language,
        }

        # intent start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "intent-start"
        assert msg["event"]["data"] == {
            "engine": "default",
            "intent_input": "Are the lights on?",
        }

        # intent error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "intent-failed"
async def test_audio_pipeline_timeout(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test audio pipeline run with immediate timeout.

    PipelineInput._execute is replaced by a coroutine that sleeps for an
    hour, so the very small timeout in the request is what ends the run.
    """
    client = await hass_ws_client(hass)

    async def sleepy_run(*args, **kwargs):
        # Stand-in for the pipeline body that never finishes in time.
        await asyncio.sleep(3600)

    with patch(
        "homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
        new=sleepy_run,
    ):
        await client.send_json(
            {
                "id": 5,
                "type": "voice_assistant/run",
                "start_stage": "stt",
                "end_stage": "tts",
                "timeout": 0.0001,
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # handler id
        msg = await client.receive_json()
        assert msg["event"]["handler_id"] == 1

        # timeout error
        msg = await client.receive_json()
        assert not msg["success"]
        assert msg["error"]["code"] == "timeout"
async def test_stt_provider_missing(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test events from a pipeline run with a non-existent STT provider.

    stt.async_get_provider is patched to return None, and the run is
    expected to emit an error event with code "stt-provider-missing".
    """
    with patch(
        "homeassistant.components.stt.async_get_provider",
        new=MagicMock(return_value=None),
    ):
        client = await hass_ws_client(hass)

        await client.send_json(
            {
                "id": 5,
                "type": "voice_assistant/run",
                "start_stage": "stt",
                "end_stage": "tts",
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # handler id
        msg = await client.receive_json()
        assert msg["event"]["handler_id"] == 1

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == {
            "pipeline": hass.config.language,
            "language": hass.config.language,
        }

        # stt
        msg = await client.receive_json()
        assert msg["event"]["type"] == "stt-start"
        assert msg["event"]["data"] == {
            "engine": "default",
            "metadata": {
                "bit_rate": 16,
                "channel": 1,
                "codec": "pcm",
                "format": "wav",
                "language": "en",
                "sample_rate": 16000,
            },
        }

        # End of audio stream (handler id + empty payload)
        # NOTE(review): b"1" is the ASCII byte 0x31 (49), not the integer
        # handler id 1 asserted above -- confirm the intended framing
        # (presumably bytes([1])). The provider lookup fails before any audio
        # is read, so the outcome here does not depend on it.
        await client.send_bytes(b"1")

        # stt error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "stt-provider-missing"
async def test_stt_stream_failed(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
) -> None:
    """Test events from a pipeline run where STT audio processing fails.

    MockSttProvider.async_process_audio_stream is patched to raise, and the
    run is expected to emit an error event with code "stt-stream-failed".
    """
    with patch(
        "tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream",
        new=MagicMock(side_effect=RuntimeError),
    ):
        client = await hass_ws_client(hass)

        await client.send_json(
            {
                "id": 5,
                "type": "voice_assistant/run",
                "start_stage": "stt",
                "end_stage": "tts",
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # handler id
        msg = await client.receive_json()
        assert msg["event"]["handler_id"] == 1

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == {
            "pipeline": hass.config.language,
            "language": hass.config.language,
        }

        # stt
        msg = await client.receive_json()
        assert msg["event"]["type"] == "stt-start"
        assert msg["event"]["data"] == {
            "engine": "default",
            "metadata": {
                "bit_rate": 16,
                "channel": 1,
                "codec": "pcm",
                "format": "wav",
                "language": "en",
                "sample_rate": 16000,
            },
        }

        # End of audio stream (handler id + empty payload)
        # NOTE(review): b"1" is the ASCII byte 0x31 (49), not the integer
        # handler id 1 asserted above -- confirm the intended framing
        # (presumably bytes([1])). The patched provider raises as soon as it
        # is called, so the outcome here does not depend on it.
        await client.send_bytes(b"1")

        # stt error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "stt-stream-failed"
async def test_tts_failed(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test pipeline run with text to speech error.

    Media resolution is patched to raise, and the run is expected to emit
    run-start, tts-start and then an error event with code "tts-failed".
    """
    client = await hass_ws_client(hass)

    # side_effect makes the mock raise RuntimeError when called.
    # return_value=RuntimeError would only hand back the exception class, and
    # the test would pass by accident because awaiting a class raises TypeError.
    with patch(
        "homeassistant.components.media_source.async_resolve_media",
        new=MagicMock(side_effect=RuntimeError),
    ):
        await client.send_json(
            {
                "id": 5,
                "type": "voice_assistant/run",
                "start_stage": "tts",
                "end_stage": "tts",
                "input": {"text": "Lights are on."},
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == {
            "pipeline": hass.config.language,
            "language": hass.config.language,
        }

        # tts start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "tts-start"
        assert msg["event"]["data"] == {
            "engine": "default",
            "tts_input": "Lights are on.",
        }

        # tts error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "tts-failed"
async def test_invalid_stage_order(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test pipeline run with invalid stage order.

    The request starts at tts and ends at stt, i.e. the end stage comes
    before the start stage, and the command is expected to be rejected.
    """
    client = await hass_ws_client(hass)

    await client.send_json(
        {
            "id": 5,
            "type": "voice_assistant/run",
            "start_stage": "tts",
            "end_stage": "stt",
            "input": {"text": "Lights are on."},
        }
    )

    # result
    # The first (and only checked) reply is a failed result, not an event.
    msg = await client.receive_json()
    assert not msg["success"]