diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py
index 70618ab38ef..bdce055c3c4 100644
--- a/homeassistant/components/cloud/stt.py
+++ b/homeassistant/components/cloud/stt.py
@@ -1,7 +1,8 @@
 """Support for the cloud for speech to text service."""
 from __future__ import annotations
 
-from aiohttp import StreamReader
+from collections.abc import AsyncIterable
+
 from hass_nabucasa import Cloud
 from hass_nabucasa.voice import VoiceError
 
@@ -88,7 +89,7 @@ class CloudProvider(Provider):
         return [AudioChannels.CHANNEL_MONO]
 
     async def async_process_audio_stream(
-        self, metadata: SpeechMetadata, stream: StreamReader
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
     ) -> SpeechResult:
         """Process an audio stream to STT service."""
         content = (
diff --git a/homeassistant/components/demo/stt.py b/homeassistant/components/demo/stt.py
index 9c3cf89d80e..923092fad20 100644
--- a/homeassistant/components/demo/stt.py
+++ b/homeassistant/components/demo/stt.py
@@ -1,7 +1,7 @@
 """Support for the demo for speech to text service."""
 from __future__ import annotations
 
-from aiohttp import StreamReader
+from collections.abc import AsyncIterable
 
 from homeassistant.components.stt import (
     AudioBitRates,
@@ -63,12 +63,12 @@ class DemoProvider(Provider):
         return [AudioChannels.CHANNEL_STEREO]
 
     async def async_process_audio_stream(
-        self, metadata: SpeechMetadata, stream: StreamReader
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
     ) -> SpeechResult:
         """Process an audio stream to STT service."""
 
         # Read available data
-        async for _ in stream.iter_chunked(4096):
+        async for _ in stream:
             pass
 
         return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py
index 94e08d25363..63199402194 100644
--- a/homeassistant/components/stt/__init__.py
+++ b/homeassistant/components/stt/__init__.py
@@ -3,11 +3,12 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 import asyncio
+from collections.abc import AsyncIterable
 from dataclasses import asdict, dataclass
 import logging
 from typing import Any
 
-from aiohttp import StreamReader, web
+from aiohttp import web
 from aiohttp.hdrs import istr
 from aiohttp.web_exceptions import (
     HTTPBadRequest,
@@ -153,7 +154,7 @@ class Provider(ABC):
 
     @abstractmethod
     async def async_process_audio_stream(
-        self, metadata: SpeechMetadata, stream: StreamReader
+        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
     ) -> SpeechResult:
         """Process an audio stream to STT service.
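The signature change above is the core of this diff: STT providers now accept any `AsyncIterable[bytes]` rather than an aiohttp `StreamReader`, so audio no longer has to come from an HTTP request body. A minimal sketch of what now satisfies the interface (`chunked_audio` and `count_bytes` are hypothetical helper names, not part of the PR):

```python
from collections.abc import AsyncIterable


async def chunked_audio(chunks: list[bytes]) -> AsyncIterable[bytes]:
    """Any async generator of bytes satisfies the new stream parameter."""
    for chunk in chunks:
        yield chunk


async def count_bytes(stream: AsyncIterable[bytes]) -> int:
    """Mirror how providers now read audio: plain `async for`, no iter_chunked()."""
    total = 0
    async for chunk in stream:
        total += len(chunk)
    return total
```

Because the parameter is an abstract iterable, tests and the websocket layer below can feed audio from an `asyncio.Queue`-backed generator instead of a real request stream.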
diff --git a/homeassistant/components/voice_assistant/manifest.json b/homeassistant/components/voice_assistant/manifest.json
index 6d353660b31..644c49e9459 100644
--- a/homeassistant/components/voice_assistant/manifest.json
+++ b/homeassistant/components/voice_assistant/manifest.json
@@ -2,7 +2,7 @@
   "domain": "voice_assistant",
   "name": "Voice Assistant",
   "codeowners": ["@balloob", "@synesthesiam"],
-  "dependencies": ["conversation"],
+  "dependencies": ["conversation", "stt", "tts"],
   "documentation": "https://www.home-assistant.io/integrations/voice_assistant",
   "iot_class": "local_push",
   "quality_scale": "internal"
diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py
index 0b55d724554..0070154bd40 100644
--- a/homeassistant/components/voice_assistant/pipeline.py
+++ b/homeassistant/components/voice_assistant/pipeline.py
@@ -1,33 +1,80 @@
 """Classes for voice assistant pipelines."""
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 import asyncio
-from collections.abc import Callable
-from dataclasses import dataclass, field
+from collections.abc import AsyncIterable, Callable
+from dataclasses import asdict, dataclass, field
+import logging
 from typing import Any
 
 from homeassistant.backports.enum import StrEnum
-from homeassistant.components import conversation
-from homeassistant.components.media_source import async_resolve_media
+from homeassistant.components import conversation, media_source, stt
 from homeassistant.components.tts.media_source import (
     generate_media_source_id as tts_generate_media_source_id,
 )
-from homeassistant.core import Context, HomeAssistant
+from homeassistant.core import Context, HomeAssistant, callback
 from homeassistant.util.dt import utcnow
 
+from .const import DOMAIN
+
 DEFAULT_TIMEOUT = 30  # seconds
 
+_LOGGER = logging.getLogger(__name__)
+
+
+@callback
+def async_get_pipeline(
+    hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
+) -> Pipeline | None:
+    """Get a pipeline by id or create one for a language."""
+    if pipeline_id is not None:
+        return hass.data[DOMAIN].get(pipeline_id)
+
+    # Construct a pipeline for the required/configured language
+    language = language or hass.config.language
+    return Pipeline(
+        name=language,
+        language=language,
+        stt_engine=None,  # first engine
+        conversation_engine=None,  # first agent
+        tts_engine=None,  # first engine
+    )
+
+
+class PipelineError(Exception):
+    """Base class for pipeline errors."""
+
+    def __init__(self, code: str, message: str) -> None:
+        """Set error message."""
+        self.code = code
+        self.message = message
+
+        super().__init__(f"Pipeline error code={code}, message={message}")
+
+
+class SpeechToTextError(PipelineError):
+    """Error in speech to text portion of pipeline."""
+
+
+class IntentRecognitionError(PipelineError):
+    """Error in intent recognition portion of pipeline."""
+
+
+class TextToSpeechError(PipelineError):
+    """Error in text to speech portion of pipeline."""
+
 
 class PipelineEventType(StrEnum):
     """Event types emitted during a pipeline run."""
 
     RUN_START = "run-start"
-    RUN_FINISH = "run-finish"
+    RUN_END = "run-end"
+    STT_START = "stt-start"
+    STT_END = "stt-end"
     INTENT_START = "intent-start"
-    INTENT_FINISH = "intent-finish"
+    INTENT_END = "intent-end"
     TTS_START = "tts-start"
-    TTS_FINISH = "tts-finish"
+    TTS_END = "tts-end"
     ERROR = "error"
 
 
@@ -54,10 +101,44 @@ class Pipeline:
 
     name: str
     language: str | None
+    stt_engine: str | None
    conversation_engine: str | None
     tts_engine: str | None
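The event names move from `*-finish` to `*-end`, and the new `stt-start`/`stt-end` pair slots in before the intent stage. A hypothetical reference list (derived from `PipelineEventType` above, confirmed by the tests at the bottom of this diff) of the order a client sees for a full run:

```python
# Event order for a full stt -> intent -> tts run:
EXPECTED_EVENT_ORDER = [
    "run-start",
    "stt-start",
    "stt-end",
    "intent-start",
    "intent-end",
    "tts-start",
    "tts-end",
    "run-end",
]
# On failure, an "error" event replaces the failing stage's *-end event,
# and the run raises the matching PipelineError subclass.
```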
+
+
+class PipelineStage(StrEnum):
+    """Stages of a pipeline."""
+
+    STT = "stt"
+    INTENT = "intent"
+    TTS = "tts"
+
+
+PIPELINE_STAGE_ORDER = [
+    PipelineStage.STT,
+    PipelineStage.INTENT,
+    PipelineStage.TTS,
+]
+
+
+class PipelineRunValidationError(Exception):
+    """Error when a pipeline run is not valid."""
+
+
+class InvalidPipelineStagesError(PipelineRunValidationError):
+    """Error when given an invalid combination of start/end stages."""
+
+    def __init__(
+        self,
+        start_stage: PipelineStage,
+        end_stage: PipelineStage,
+    ) -> None:
+        """Set error message."""
+        super().__init__(
+            f"Invalid stage combination: start={start_stage}, end={end_stage}"
+        )
+
+
 @dataclass
 class PipelineRun:
     """Running context for a pipeline."""
 
@@ -65,6 +146,8 @@ class PipelineRun:
     hass: HomeAssistant
     context: Context
     pipeline: Pipeline
+    start_stage: PipelineStage
+    end_stage: PipelineStage
     event_callback: Callable[[PipelineEvent], None]
     language: str = None  # type: ignore[assignment]
 
     def __post_init__(self):
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
 
+        # stt -> intent -> tts
+        if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
+            self.start_stage
+        ):
+            raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
+
     def start(self):
         """Emit run start event."""
         self.event_callback(
@@ -84,18 +173,86 @@ class PipelineRun:
             )
         )
 
-    def finish(self):
-        """Emit run finish event."""
+    def end(self):
+        """Emit run end event."""
         self.event_callback(
             PipelineEvent(
-                PipelineEventType.RUN_FINISH,
+                PipelineEventType.RUN_END,
             )
         )
 
+    async def speech_to_text(
+        self,
+        metadata: stt.SpeechMetadata,
+        stream: AsyncIterable[bytes],
+    ) -> str:
+        """Run speech to text portion of pipeline. Returns the spoken text."""
+        engine = self.pipeline.stt_engine or "default"
+        self.event_callback(
+            PipelineEvent(
+                PipelineEventType.STT_START,
+                {
+                    "engine": engine,
+                    "metadata": asdict(metadata),
+                },
+            )
+        )
+
+        try:
+            # Load provider
+            stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
+            assert stt_provider is not None
+        except Exception as src_error:
+            stt_error = SpeechToTextError(
+                code="stt-provider-missing",
+                message=f"No speech to text provider for: {engine}",
+            )
+            _LOGGER.exception(stt_error.message)
+            self.event_callback(
+                PipelineEvent(
+                    PipelineEventType.ERROR,
+                    {"code": stt_error.code, "message": stt_error.message},
+                )
+            )
+            raise stt_error from src_error
+
+        try:
+            # Transcribe audio stream
+            result = await stt_provider.async_process_audio_stream(metadata, stream)
+            assert (result.text is not None) and (
+                result.result == stt.SpeechResultState.SUCCESS
+            )
+        except Exception as src_error:
+            stt_error = SpeechToTextError(
+                code="stt-stream-failed",
+                message="Unexpected error during speech to text",
+            )
+            _LOGGER.exception(stt_error.message)
+            self.event_callback(
+                PipelineEvent(
+                    PipelineEventType.ERROR,
+                    {"code": stt_error.code, "message": stt_error.message},
+                )
+            )
+            raise stt_error from src_error
+
+        self.event_callback(
+            PipelineEvent(
+                PipelineEventType.STT_END,
+                {
+                    "stt_output": {
+                        "text": result.text,
+                    }
+                },
+            )
+        )
+
+        return result.text
+
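Every stage reports through the same `event_callback`, so the caller decides where events go (the websocket handler below forwards them to the client). A sketch of a minimal collecting callback, like the one the removed pipeline test built with `MagicMock` (the `events` list is illustrative, not from the PR):

```python
from homeassistant.components.voice_assistant.pipeline import (
    PipelineEvent,
    PipelineEventType,
)

events: list[PipelineEvent] = []


def event_callback(event: PipelineEvent) -> None:
    """PipelineRun invokes this for every *-start, *-end, and error event."""
    events.append(event)
    if event.type == PipelineEventType.ERROR:
        print("pipeline failed:", event.data["code"])
```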
Returns text to speak.""" self.event_callback( PipelineEvent( PipelineEventType.INTENT_START, @@ -106,23 +263,39 @@ class PipelineRun: ) ) - conversation_result = await conversation.async_converse( - hass=self.hass, - text=intent_input, - conversation_id=conversation_id, - context=self.context, - language=self.language, - agent_id=self.pipeline.conversation_engine, - ) + try: + conversation_result = await conversation.async_converse( + hass=self.hass, + text=intent_input, + conversation_id=conversation_id, + context=self.context, + language=self.language, + agent_id=self.pipeline.conversation_engine, + ) + except Exception as src_error: + intent_error = IntentRecognitionError( + code="intent-failed", + message="Unexpected error during intent recognition", + ) + _LOGGER.exception(intent_error.message) + self.event_callback( + PipelineEvent( + PipelineEventType.ERROR, + {"code": intent_error.code, "message": intent_error.message}, + ) + ) + raise intent_error from src_error self.event_callback( PipelineEvent( - PipelineEventType.INTENT_FINISH, + PipelineEventType.INTENT_END, {"intent_output": conversation_result.as_dict()}, ) ) - return conversation_result + speech = conversation_result.response.speech.get("plain", {}).get("speech", "") + + return speech async def text_to_speech(self, tts_input: str) -> str: """Run text to speech portion of pipeline. Returns URL of TTS audio.""" @@ -136,29 +309,57 @@ class PipelineRun: ) ) - tts_media = await async_resolve_media( - self.hass, - tts_generate_media_source_id( + try: + # Synthesize audio and get URL + tts_media = await media_source.async_resolve_media( self.hass, - tts_input, - engine=self.pipeline.tts_engine, - ), - ) - tts_url = tts_media.url + tts_generate_media_source_id( + self.hass, + tts_input, + engine=self.pipeline.tts_engine, + ), + ) + except Exception as src_error: + tts_error = TextToSpeechError( + code="tts-failed", + message="Unexpected error during text to speech", + ) + _LOGGER.exception(tts_error.message) + self.event_callback( + PipelineEvent( + PipelineEventType.ERROR, + {"code": tts_error.code, "message": tts_error.message}, + ) + ) + raise tts_error from src_error self.event_callback( PipelineEvent( - PipelineEventType.TTS_FINISH, - {"tts_output": tts_url}, + PipelineEventType.TTS_END, + {"tts_output": asdict(tts_media)}, ) ) - return tts_url + return tts_media.url @dataclass -class PipelineRequest(ABC): - """Request to for a pipeline run.""" +class PipelineInput: + """Input to a pipeline run.""" + + stt_metadata: stt.SpeechMetadata | None = None + """Metadata of stt input audio. Required when start_stage = stt.""" + + stt_stream: AsyncIterable[bytes] | None = None + """Input audio for stt. Required when start_stage = stt.""" + + intent_input: str | None = None + """Input for conversation agent. Required when start_stage = intent.""" + + tts_input: str | None = None + """Input for text to speech. 
Required when start_stage = tts.""" + + conversation_id: str | None = None async def execute( self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT @@ -169,47 +370,60 @@ class PipelineRequest(ABC): timeout=timeout, ) - @abstractmethod async def _execute(self, run: PipelineRun): - """Run pipeline with request info and context.""" + self._validate(run.start_stage) - -@dataclass -class TextPipelineRequest(PipelineRequest): - """Request to run the text portion only of a pipeline.""" - - intent_input: str - conversation_id: str | None = None - - async def _execute( - self, - run: PipelineRun, - ): + # stt -> intent -> tts run.start() - await run.recognize_intent(self.intent_input, self.conversation_id) - run.finish() + current_stage = run.start_stage + # Speech to text + intent_input = self.intent_input + if current_stage == PipelineStage.STT: + assert self.stt_metadata is not None + assert self.stt_stream is not None + intent_input = await run.speech_to_text( + self.stt_metadata, + self.stt_stream, + ) + current_stage = PipelineStage.INTENT -@dataclass -class AudioPipelineRequest(PipelineRequest): - """Request to full pipeline from audio input (stt) to audio output (tts).""" + if run.end_stage != PipelineStage.STT: + tts_input = self.tts_input - intent_input: str # this will be changed to stt audio - conversation_id: str | None = None + if current_stage == PipelineStage.INTENT: + assert intent_input is not None + tts_input = await run.recognize_intent( + intent_input, self.conversation_id + ) + current_stage = PipelineStage.TTS - async def _execute(self, run: PipelineRun): - run.start() + if run.end_stage != PipelineStage.INTENT: + if current_stage == PipelineStage.TTS: + assert tts_input is not None + await run.text_to_speech(tts_input) - # stt will go here + run.end() - conversation_result = await run.recognize_intent( - self.intent_input, self.conversation_id - ) + def _validate(self, stage: PipelineStage): + """Validate pipeline input against start stage.""" + if stage == PipelineStage.STT: + if self.stt_metadata is None: + raise PipelineRunValidationError( + "stt_metadata is required for speech to text" + ) - tts_input = conversation_result.response.speech.get("plain", {}).get( - "speech", "" - ) - - await run.text_to_speech(tts_input) - - run.finish() + if self.stt_stream is None: + raise PipelineRunValidationError( + "stt_stream is required for speech to text" + ) + elif stage == PipelineStage.INTENT: + if self.intent_input is None: + raise PipelineRunValidationError( + "intent_input is required for intent recognition" + ) + elif stage == PipelineStage.TTS: + if self.tts_input is None: + raise PipelineRunValidationError( + "tts_input is required for text to speech" + ) diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py index 54e87e292a1..cc4799f13e7 100644 --- a/homeassistant/components/voice_assistant/websocket_api.py +++ b/homeassistant/components/voice_assistant/websocket_api.py @@ -1,13 +1,24 @@ """Voice Assistant Websocket API.""" +import asyncio +from collections.abc import Callable +import logging from typing import Any import voluptuous as vol -from homeassistant.components import websocket_api +from homeassistant.components import stt, websocket_api from homeassistant.core import HomeAssistant, callback -from .const import DOMAIN -from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest +from .pipeline import ( + DEFAULT_TIMEOUT, + PipelineError, + 
diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py
index 54e87e292a1..cc4799f13e7 100644
--- a/homeassistant/components/voice_assistant/websocket_api.py
+++ b/homeassistant/components/voice_assistant/websocket_api.py
@@ -1,13 +1,24 @@
 """Voice Assistant Websocket API."""
+import asyncio
+from collections.abc import Callable
+import logging
 from typing import Any
 
 import voluptuous as vol
 
-from homeassistant.components import websocket_api
+from homeassistant.components import stt, websocket_api
 from homeassistant.core import HomeAssistant, callback
 
-from .const import DOMAIN
-from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest
+from .pipeline import (
+    DEFAULT_TIMEOUT,
+    PipelineError,
+    PipelineInput,
+    PipelineRun,
+    PipelineStage,
+    async_get_pipeline,
+)
+
+_LOGGER = logging.getLogger(__name__)
 
 
 @callback
@@ -19,9 +30,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
 @websocket_api.websocket_command(
     {
         vol.Required("type"): "voice_assistant/run",
+        # pylint: disable-next=unnecessary-lambda
+        vol.Required("start_stage"): lambda val: PipelineStage(val),
+        # pylint: disable-next=unnecessary-lambda
+        vol.Required("end_stage"): lambda val: PipelineStage(val),
+        vol.Optional("input"): {"text": str},
         vol.Optional("language"): str,
         vol.Optional("pipeline"): str,
-        vol.Required("intent_input"): str,
         vol.Optional("conversation_id"): vol.Any(str, None),
         vol.Optional("timeout"): vol.Any(float, int),
     }
@@ -33,39 +48,74 @@ async def websocket_run(
     msg: dict[str, Any],
 ) -> None:
     """Run a pipeline."""
+    language = msg.get("language", hass.config.language)
     pipeline_id = msg.get("pipeline")
-    if pipeline_id is not None:
-        pipeline = hass.data[DOMAIN].get(pipeline_id)
-        if pipeline is None:
-            connection.send_error(
-                msg["id"],
-                "pipeline_not_found",
-                f"Pipeline not found: {pipeline_id}",
-            )
-            return
+    pipeline = async_get_pipeline(
+        hass,
+        pipeline_id=pipeline_id,
+        language=language,
+    )
+    if pipeline is None:
+        connection.send_error(
+            msg["id"],
+            "pipeline-not-found",
+            f"Pipeline not found: id={pipeline_id}, language={language}",
+        )
+        return
 
-    else:
-        # Construct a pipeline for the required/configured language
-        language = msg.get("language", hass.config.language)
-        pipeline = Pipeline(
-            name=language,
-            language=language,
-            conversation_engine=None,
-            tts_engine=None,
+    timeout = msg.get("timeout", DEFAULT_TIMEOUT)
+    start_stage = PipelineStage(msg["start_stage"])
+    end_stage = PipelineStage(msg["end_stage"])
+    handler_id: int | None = None
+    unregister_handler: Callable[[], None] | None = None
+
+    # Arguments to PipelineInput
+    input_args: dict[str, Any] = {
+        "conversation_id": msg.get("conversation_id"),
+    }
+
+    if start_stage == PipelineStage.STT:
+        # Audio pipeline that will receive audio as binary websocket messages
+        audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
+
+        async def stt_stream():
+            # Yield until we receive an empty chunk
+            while chunk := await audio_queue.get():
+                yield chunk
+
+        def handle_binary(_hass, _connection, data: bytes):
+            # Forward to STT audio stream
+            audio_queue.put_nowait(data)
+
+        handler_id, unregister_handler = connection.async_register_binary_handler(
+            handle_binary
         )
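The queue-backed generator above is the bridge between binary websocket messages and the STT provider: `handle_binary` feeds the queue, and the first empty chunk ends iteration. The same pattern in isolation, as a runnable sketch (names are mine, not the PR's):

```python
import asyncio


async def main() -> None:
    audio_queue: asyncio.Queue[bytes] = asyncio.Queue()

    async def stt_stream():
        # Stop iterating on the first empty chunk, like the handler above
        while chunk := await audio_queue.get():
            yield chunk

    audio_queue.put_nowait(b"chunk-1")
    audio_queue.put_nowait(b"chunk-2")
    audio_queue.put_nowait(b"")  # end-of-stream marker

    assert [c async for c in stt_stream()] == [b"chunk-1", b"chunk-2"]


asyncio.run(main())
```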
- timeout = msg.get("timeout", DEFAULT_TIMEOUT) + # Audio input must be raw PCM at 16Khz with 16-bit mono samples + input_args["stt_metadata"] = stt.SpeechMetadata( + language=language, + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ) + input_args["stt_stream"] = stt_stream() + elif start_stage == PipelineStage.INTENT: + # Input to conversation agent + input_args["intent_input"] = msg["input"]["text"] + elif start_stage == PipelineStage.TTS: + # Input to text to speech system + input_args["tts_input"] = msg["input"]["text"] + run_task = hass.async_create_task( - TextPipelineRequest( - intent_input=msg["intent_input"], - conversation_id=msg.get("conversation_id"), - ).execute( + PipelineInput(**input_args).execute( PipelineRun( hass, - connection.context(msg), - pipeline, + context=connection.context(msg), + pipeline=pipeline, + start_stage=start_stage, + end_stage=end_stage, event_callback=lambda event: connection.send_event( msg["id"], event.as_dict() ), @@ -77,7 +127,20 @@ async def websocket_run( # Cancel pipeline if user unsubscribes connection.subscriptions[msg["id"]] = run_task.cancel + # Confirm subscription connection.send_result(msg["id"]) - # Task contains a timeout - await run_task + if handler_id is not None: + # Send handler id to client + connection.send_event(msg["id"], {"handler_id": handler_id}) + + try: + # Task contains a timeout + await run_task + except PipelineError as error: + # Report more specific error when possible + connection.send_error(msg["id"], error.code, error.message) + finally: + if unregister_handler is not None: + # Unregister binary handler + unregister_handler() diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index e36b8af3f6c..3d20dbc5403 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,5 +1,5 @@ """Test STT component setup.""" -from asyncio import StreamReader +from collections.abc import AsyncIterable from http import HTTPStatus from unittest.mock import AsyncMock, Mock @@ -64,7 +64,7 @@ class MockProvider(Provider): return [AudioChannels.CHANNEL_MONO] async def async_process_audio_stream( - self, metadata: SpeechMetadata, stream: StreamReader + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] ) -> SpeechResult: """Process an audio stream.""" self.calls.append((metadata, stream)) diff --git a/tests/components/voice_assistant/test_pipeline.py b/tests/components/voice_assistant/test_pipeline.py deleted file mode 100644 index 343719a49fd..00000000000 --- a/tests/components/voice_assistant/test_pipeline.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Pipeline tests for Voice Assistant integration.""" -from unittest.mock import MagicMock, patch - -import pytest - -from homeassistant.components.voice_assistant.pipeline import ( - AudioPipelineRequest, - Pipeline, - PipelineEventType, - PipelineRun, -) -from homeassistant.core import Context -from homeassistant.setup import async_setup_component - -from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import - mock_get_cache_files, - mock_init_cache_dir, -) - - -@pytest.fixture(autouse=True) -async def init_components(hass): - """Initialize relevant components with empty configs.""" - assert await async_setup_component(hass, "voice_assistant", {}) - - -@pytest.fixture -async def mock_get_tts_audio(hass): - """Set up media source.""" - assert await 
async_setup_component(hass, "media_source", {}) - assert await async_setup_component( - hass, - "tts", - { - "tts": { - "platform": "demo", - } - }, - ) - - with patch( - "homeassistant.components.demo.tts.DemoProvider.get_tts_audio", - return_value=("mp3", b""), - ) as mock_get_tts: - yield mock_get_tts - - -async def test_audio_pipeline(hass, mock_get_tts_audio): - """Run audio pipeline with mock TTS.""" - pipeline = Pipeline( - name="test", - language=hass.config.language, - conversation_engine=None, - tts_engine=None, - ) - - event_callback = MagicMock() - await AudioPipelineRequest(intent_input="Are the lights on?").execute( - PipelineRun( - hass, - context=Context(), - pipeline=pipeline, - event_callback=event_callback, - language=hass.config.language, - ) - ) - - calls = event_callback.mock_calls - assert calls[0].args[0].type == PipelineEventType.RUN_START - assert calls[0].args[0].data == { - "pipeline": "test", - "language": hass.config.language, - } - - assert calls[1].args[0].type == PipelineEventType.INTENT_START - assert calls[1].args[0].data == { - "engine": "default", - "intent_input": "Are the lights on?", - } - assert calls[2].args[0].type == PipelineEventType.INTENT_FINISH - assert calls[2].args[0].data == { - "intent_output": { - "conversation_id": None, - "response": { - "card": {}, - "data": {"code": "no_intent_match"}, - "language": hass.config.language, - "response_type": "error", - "speech": { - "plain": { - "extra_data": None, - "speech": "Sorry, I couldn't understand that", - } - }, - }, - } - } - - assert calls[3].args[0].type == PipelineEventType.TTS_START - assert calls[3].args[0].data == { - "engine": "default", - "tts_input": "Sorry, I couldn't understand that", - } - assert calls[4].args[0].type == PipelineEventType.TTS_FINISH - assert ( - calls[4].args[0].data["tts_output"] - == f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3" - ) - - assert calls[5].args[0].type == PipelineEventType.RUN_FINISH diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py index 2fec6cdfb03..a1ba8b5f7cb 100644 --- a/tests/components/voice_assistant/test_websocket.py +++ b/tests/components/voice_assistant/test_websocket.py @@ -1,20 +1,94 @@ """Websocket tests for Voice Assistant integration.""" import asyncio -from unittest.mock import patch +from collections.abc import AsyncIterable +from unittest.mock import MagicMock, patch import pytest +from homeassistant.components import stt from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import + mock_get_cache_files, + mock_init_cache_dir, +) from tests.typing import WebSocketGenerator +_TRANSCRIPT = "test transcript" + + +class MockSttProvider(stt.Provider): + """Mock STT provider.""" + + def __init__(self, hass: HomeAssistant, text: str) -> None: + """Init test provider.""" + self.hass = hass + self.text = text + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return [self.hass.config.language] + + @property + def supported_formats(self) -> list[stt.AudioFormats]: + """Return a list of supported formats.""" + return [stt.AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[stt.AudioCodecs]: + """Return a list of supported codecs.""" + return [stt.AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> 
+        """Return a list of supported bitrates."""
+        return [stt.AudioBitRates.BITRATE_16]
+
+    @property
+    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
+        """Return a list of supported samplerates."""
+        return [stt.AudioSampleRates.SAMPLERATE_16000]
+
+    @property
+    def supported_channels(self) -> list[stt.AudioChannels]:
+        """Return a list of supported channels."""
+        return [stt.AudioChannels.CHANNEL_MONO]
+
+    async def async_process_audio_stream(
+        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
+    ) -> stt.SpeechResult:
+        """Process an audio stream."""
+        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
+
 
 @pytest.fixture(autouse=True)
 async def init_components(hass):
     """Initialize relevant components with empty configs."""
+    assert await async_setup_component(hass, "media_source", {})
+    assert await async_setup_component(
+        hass,
+        "tts",
+        {
+            "tts": {
+                "platform": "demo",
+            }
+        },
+    )
+    assert await async_setup_component(hass, "stt", {})
+
+    # mock_platform fails because it can't import
+    hass.data[stt.DOMAIN] = {"test": MockSttProvider(hass, _TRANSCRIPT)}
+
     assert await async_setup_component(hass, "voice_assistant", {})
 
+    with patch(
+        "homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
+        return_value=("mp3", b""),
+    ) as mock_get_tts:
+        yield mock_get_tts
+
 
 async def test_text_only_pipeline(
     hass: HomeAssistant,
@@ -27,7 +101,9 @@ async def test_text_only_pipeline(
         {
             "id": 5,
             "type": "voice_assistant/run",
-            "intent_input": "Are the lights on?",
+            "start_stage": "intent",
+            "end_stage": "intent",
+            "input": {"text": "Are the lights on?"},
         }
     )
 
@@ -52,7 +128,7 @@ async def test_text_only_pipeline(
     }
 
     msg = await client.receive_json()
-    assert msg["event"]["type"] == "intent-finish"
+    assert msg["event"]["type"] == "intent-end"
     assert msg["event"]["data"] == {
         "intent_output": {
             "response": {
@@ -71,13 +147,120 @@ async def test_text_only_pipeline(
         }
     }
 
-    # run finish
+    # run end
     msg = await client.receive_json()
-    assert msg["event"]["type"] == "run-finish"
+    assert msg["event"]["type"] == "run-end"
     assert msg["event"]["data"] == {}
 
 
-async def test_conversation_timeout(
+async def test_audio_pipeline(
+    hass: HomeAssistant,
+    hass_ws_client: WebSocketGenerator,
+) -> None:
+    """Test events from a pipeline run with audio input/output."""
+    client = await hass_ws_client(hass)
+
+    await client.send_json(
+        {
+            "id": 5,
+            "type": "voice_assistant/run",
+            "start_stage": "stt",
+            "end_stage": "tts",
+        }
+    )
+
+    # result
+    msg = await client.receive_json()
+    assert msg["success"]
+
+    # handler id
+    msg = await client.receive_json()
+    assert msg["event"]["handler_id"] == 1
+
+    # run start
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "run-start"
+    assert msg["event"]["data"] == {
+        "pipeline": hass.config.language,
+        "language": hass.config.language,
+    }
+
+    # stt
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "stt-start"
+    assert msg["event"]["data"] == {
+        "engine": "default",
+        "metadata": {
+            "bit_rate": 16,
+            "channel": 1,
+            "codec": "pcm",
+            "format": "wav",
+            "language": "en",
+            "sample_rate": 16000,
+        },
+    }
+
+    # End of audio stream (handler id + empty payload)
+    await client.send_bytes(b"1")
+
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "stt-end"
+    assert msg["event"]["data"] == {
+        "stt_output": {"text": _TRANSCRIPT},
+    }
+
+    # intent
+    msg = await client.receive_json()
+    assert msg["event"]["type"] == "intent-start"
msg["event"]["data"] == { + "engine": "default", + "intent_input": _TRANSCRIPT, + } + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-end" + assert msg["event"]["data"] == { + "intent_output": { + "response": { + "speech": { + "plain": { + "speech": "Sorry, I couldn't understand that", + "extra_data": None, + } + }, + "card": {}, + "language": "en", + "response_type": "error", + "data": {"code": "no_intent_match"}, + }, + "conversation_id": None, + } + } + + # text to speech + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-start" + assert msg["event"]["data"] == { + "engine": "default", + "tts_input": "Sorry, I couldn't understand that", + } + + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-end" + assert msg["event"]["data"] == { + "tts_output": { + "url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3", + "mime_type": "audio/mpeg", + }, + } + + # run end + msg = await client.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] == {} + + +async def test_intent_timeout( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components ) -> None: """Test partial pipeline run with conversation agent timeout.""" @@ -94,7 +277,9 @@ async def test_conversation_timeout( { "id": 5, "type": "voice_assistant/run", - "intent_input": "Are the lights on?", + "start_stage": "intent", + "end_stage": "intent", + "input": {"text": "Are the lights on?"}, "timeout": 0.00001, } ) @@ -125,24 +310,26 @@ async def test_conversation_timeout( assert msg["error"]["code"] == "timeout" -async def test_pipeline_timeout( +async def test_text_pipeline_timeout( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components ) -> None: - """Test pipeline run with immediate timeout.""" + """Test text-only pipeline run with immediate timeout.""" client = await hass_ws_client(hass) async def sleepy_run(*args, **kwargs): await asyncio.sleep(3600) with patch( - "homeassistant.components.voice_assistant.pipeline.TextPipelineRequest._execute", + "homeassistant.components.voice_assistant.pipeline.PipelineInput._execute", new=sleepy_run, ): await client.send_json( { "id": 5, "type": "voice_assistant/run", - "intent_input": "Are the lights on?", + "start_stage": "intent", + "end_stage": "intent", + "input": {"text": "Are the lights on?"}, "timeout": 0.0001, } ) @@ -155,3 +342,273 @@ async def test_pipeline_timeout( msg = await client.receive_json() assert not msg["success"] assert msg["error"]["code"] == "timeout" + + +async def test_intent_failed( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test text-only pipeline run with conversation agent error.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.conversation.async_converse", + new=MagicMock(return_value=RuntimeError), + ): + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "intent", + "end_stage": "intent", + "input": {"text": "Are the lights on?"}, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + assert msg["event"]["data"] == { + "pipeline": hass.config.language, + "language": hass.config.language, + } + + # intent start + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-start" + assert msg["event"]["data"] == { + "engine": 
"default", + "intent_input": "Are the lights on?", + } + + # intent error + msg = await client.receive_json() + assert msg["event"]["type"] == "error" + assert msg["event"]["data"]["code"] == "intent-failed" + + +async def test_audio_pipeline_timeout( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test audio pipeline run with immediate timeout.""" + client = await hass_ws_client(hass) + + async def sleepy_run(*args, **kwargs): + await asyncio.sleep(3600) + + with patch( + "homeassistant.components.voice_assistant.pipeline.PipelineInput._execute", + new=sleepy_run, + ): + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "stt", + "end_stage": "tts", + "timeout": 0.0001, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # handler id + msg = await client.receive_json() + assert msg["event"]["handler_id"] == 1 + + # timeout error + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "timeout" + + +async def test_stt_provider_missing( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test events from a pipeline run with a non-existent STT provider.""" + with patch( + "homeassistant.components.stt.async_get_provider", + new=MagicMock(return_value=None), + ): + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "stt", + "end_stage": "tts", + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # handler id + msg = await client.receive_json() + assert msg["event"]["handler_id"] == 1 + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + assert msg["event"]["data"] == { + "pipeline": hass.config.language, + "language": hass.config.language, + } + + # stt + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == { + "engine": "default", + "metadata": { + "bit_rate": 16, + "channel": 1, + "codec": "pcm", + "format": "wav", + "language": "en", + "sample_rate": 16000, + }, + } + + # End of audio stream (handler id + empty payload) + await client.send_bytes(b"1") + + # stt error + msg = await client.receive_json() + assert msg["event"]["type"] == "error" + assert msg["event"]["data"]["code"] == "stt-provider-missing" + + +async def test_stt_stream_failed( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test events from a pipeline run with a non-existent STT provider.""" + with patch( + "tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream", + new=MagicMock(side_effect=RuntimeError), + ): + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "stt", + "end_stage": "tts", + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # handler id + msg = await client.receive_json() + assert msg["event"]["handler_id"] == 1 + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + assert msg["event"]["data"] == { + "pipeline": hass.config.language, + "language": hass.config.language, + } + + # stt + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == { + "engine": "default", + "metadata": { + "bit_rate": 16, + "channel": 1, + "codec": "pcm", + "format": 
"wav", + "language": "en", + "sample_rate": 16000, + }, + } + + # End of audio stream (handler id + empty payload) + await client.send_bytes(b"1") + + # stt error + msg = await client.receive_json() + assert msg["event"]["type"] == "error" + assert msg["event"]["data"]["code"] == "stt-stream-failed" + + +async def test_tts_failed( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test pipeline run with text to speech error.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.media_source.async_resolve_media", + new=MagicMock(return_value=RuntimeError), + ): + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "tts", + "end_stage": "tts", + "input": {"text": "Lights are on."}, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + assert msg["event"]["data"] == { + "pipeline": hass.config.language, + "language": hass.config.language, + } + + # tts start + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-start" + assert msg["event"]["data"] == { + "engine": "default", + "tts_input": "Lights are on.", + } + + # tts error + msg = await client.receive_json() + assert msg["event"]["type"] == "error" + assert msg["event"]["data"]["code"] == "tts-failed" + + +async def test_invalid_stage_order( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test pipeline run with invalid stage order.""" + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "voice_assistant/run", + "start_stage": "tts", + "end_stage": "stt", + "input": {"text": "Lights are on."}, + } + ) + + # result + msg = await client.receive_json() + assert not msg["success"]