Add speech to text over binary websocket to pipeline (#90082)
* Allow passing binary to the WS connection * Expand test coverage * Test non-existing handler * Add text to speech and stages to pipeline * Default to "cloud" TTS when engine is None * Refactor pipeline request to split text/audio * Refactor with PipelineRun * Generate pipeline from language * Clean up * Restore TTS code * Add audio pipeline test * Clean TTS cache in test * Clean up tests and pipeline base class * Stop pylint and pytest magics from fighting * Include mock_get_cache_files * Working on STT * Preparing to test * First successful test * Send handler_id * Allow signaling end of stream using empty payloads * Store handlers in a list * Handle binary handlers raising exceptions * Add stt/tts dependencies to voice_assistant * Include STT audio in pipeline test * Working on tests * Refactoring with stages * Fix tests * Add more tests * Add method docs * Change stt demo/cloud to AsyncIterable * Add pipeline error events * Move handler id to separate message before pipeline * Add test for invalid stage order * Change "finish" to "end" * Use enum --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
185d6d74d7
commit
3e3ece4e56
9 changed files with 860 additions and 234 deletions
|
@ -1,7 +1,8 @@
|
||||||
"""Support for the cloud for speech to text service."""
|
"""Support for the cloud for speech to text service."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from aiohttp import StreamReader
|
from collections.abc import AsyncIterable
|
||||||
|
|
||||||
from hass_nabucasa import Cloud
|
from hass_nabucasa import Cloud
|
||||||
from hass_nabucasa.voice import VoiceError
|
from hass_nabucasa.voice import VoiceError
|
||||||
|
|
||||||
|
@ -88,7 +89,7 @@ class CloudProvider(Provider):
|
||||||
return [AudioChannels.CHANNEL_MONO]
|
return [AudioChannels.CHANNEL_MONO]
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, metadata: SpeechMetadata, stream: StreamReader
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
) -> SpeechResult:
|
) -> SpeechResult:
|
||||||
"""Process an audio stream to STT service."""
|
"""Process an audio stream to STT service."""
|
||||||
content = (
|
content = (
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""Support for the demo for speech to text service."""
|
"""Support for the demo for speech to text service."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from aiohttp import StreamReader
|
from collections.abc import AsyncIterable
|
||||||
|
|
||||||
from homeassistant.components.stt import (
|
from homeassistant.components.stt import (
|
||||||
AudioBitRates,
|
AudioBitRates,
|
||||||
|
@ -63,12 +63,12 @@ class DemoProvider(Provider):
|
||||||
return [AudioChannels.CHANNEL_STEREO]
|
return [AudioChannels.CHANNEL_STEREO]
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, metadata: SpeechMetadata, stream: StreamReader
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
) -> SpeechResult:
|
) -> SpeechResult:
|
||||||
"""Process an audio stream to STT service."""
|
"""Process an audio stream to STT service."""
|
||||||
|
|
||||||
# Read available data
|
# Read available data
|
||||||
async for _ in stream.iter_chunked(4096):
|
async for _ in stream:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
|
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
|
||||||
|
|
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp import StreamReader, web
|
from aiohttp import web
|
||||||
from aiohttp.hdrs import istr
|
from aiohttp.hdrs import istr
|
||||||
from aiohttp.web_exceptions import (
|
from aiohttp.web_exceptions import (
|
||||||
HTTPBadRequest,
|
HTTPBadRequest,
|
||||||
|
@ -153,7 +154,7 @@ class Provider(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, metadata: SpeechMetadata, stream: StreamReader
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
) -> SpeechResult:
|
) -> SpeechResult:
|
||||||
"""Process an audio stream to STT service.
|
"""Process an audio stream to STT service.
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
"domain": "voice_assistant",
|
"domain": "voice_assistant",
|
||||||
"name": "Voice Assistant",
|
"name": "Voice Assistant",
|
||||||
"codeowners": ["@balloob", "@synesthesiam"],
|
"codeowners": ["@balloob", "@synesthesiam"],
|
||||||
"dependencies": ["conversation"],
|
"dependencies": ["conversation", "stt", "tts"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"quality_scale": "internal"
|
"quality_scale": "internal"
|
||||||
|
|
|
@ -1,33 +1,80 @@
|
||||||
"""Classes for voice assistant pipelines."""
|
"""Classes for voice assistant pipelines."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import AsyncIterable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.backports.enum import StrEnum
|
from homeassistant.backports.enum import StrEnum
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation, media_source, stt
|
||||||
from homeassistant.components.media_source import async_resolve_media
|
|
||||||
from homeassistant.components.tts.media_source import (
|
from homeassistant.components.tts.media_source import (
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 30 # seconds
|
DEFAULT_TIMEOUT = 30 # seconds
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_pipeline(
|
||||||
|
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
||||||
|
) -> Pipeline | None:
|
||||||
|
"""Get a pipeline by id or create one for a language."""
|
||||||
|
if pipeline_id is not None:
|
||||||
|
return hass.data[DOMAIN].get(pipeline_id)
|
||||||
|
|
||||||
|
# Construct a pipeline for the required/configured language
|
||||||
|
language = language or hass.config.language
|
||||||
|
return Pipeline(
|
||||||
|
name=language,
|
||||||
|
language=language,
|
||||||
|
stt_engine=None, # first engine
|
||||||
|
conversation_engine=None, # first agent
|
||||||
|
tts_engine=None, # first engine
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineError(Exception):
|
||||||
|
"""Base class for pipeline errors."""
|
||||||
|
|
||||||
|
def __init__(self, code: str, message: str) -> None:
|
||||||
|
"""Set error message."""
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
super().__init__(f"Pipeline error code={code}, message={message}")
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextError(PipelineError):
|
||||||
|
"""Error in speech to text portion of pipeline."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntentRecognitionError(PipelineError):
|
||||||
|
"""Error in intent recognition portion of pipeline."""
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechError(PipelineError):
|
||||||
|
"""Error in text to speech portion of pipeline."""
|
||||||
|
|
||||||
|
|
||||||
class PipelineEventType(StrEnum):
|
class PipelineEventType(StrEnum):
|
||||||
"""Event types emitted during a pipeline run."""
|
"""Event types emitted during a pipeline run."""
|
||||||
|
|
||||||
RUN_START = "run-start"
|
RUN_START = "run-start"
|
||||||
RUN_FINISH = "run-finish"
|
RUN_END = "run-end"
|
||||||
|
STT_START = "stt-start"
|
||||||
|
STT_END = "stt-end"
|
||||||
INTENT_START = "intent-start"
|
INTENT_START = "intent-start"
|
||||||
INTENT_FINISH = "intent-finish"
|
INTENT_END = "intent-end"
|
||||||
TTS_START = "tts-start"
|
TTS_START = "tts-start"
|
||||||
TTS_FINISH = "tts-finish"
|
TTS_END = "tts-end"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,10 +101,44 @@ class Pipeline:
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
language: str | None
|
language: str | None
|
||||||
|
stt_engine: str | None
|
||||||
conversation_engine: str | None
|
conversation_engine: str | None
|
||||||
tts_engine: str | None
|
tts_engine: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage(StrEnum):
|
||||||
|
"""Stages of a pipeline."""
|
||||||
|
|
||||||
|
STT = "stt"
|
||||||
|
INTENT = "intent"
|
||||||
|
TTS = "tts"
|
||||||
|
|
||||||
|
|
||||||
|
PIPELINE_STAGE_ORDER = [
|
||||||
|
PipelineStage.STT,
|
||||||
|
PipelineStage.INTENT,
|
||||||
|
PipelineStage.TTS,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunValidationError(Exception):
|
||||||
|
"""Error when a pipeline run is not valid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPipelineStagesError(PipelineRunValidationError):
|
||||||
|
"""Error when given an invalid combination of start/end stages."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
start_stage: PipelineStage,
|
||||||
|
end_stage: PipelineStage,
|
||||||
|
) -> None:
|
||||||
|
"""Set error message."""
|
||||||
|
super().__init__(
|
||||||
|
f"Invalid stage combination: start={start_stage}, end={end_stage}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineRun:
|
class PipelineRun:
|
||||||
"""Running context for a pipeline."""
|
"""Running context for a pipeline."""
|
||||||
|
@ -65,6 +146,8 @@ class PipelineRun:
|
||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
context: Context
|
context: Context
|
||||||
pipeline: Pipeline
|
pipeline: Pipeline
|
||||||
|
start_stage: PipelineStage
|
||||||
|
end_stage: PipelineStage
|
||||||
event_callback: Callable[[PipelineEvent], None]
|
event_callback: Callable[[PipelineEvent], None]
|
||||||
language: str = None # type: ignore[assignment]
|
language: str = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
@ -72,6 +155,12 @@ class PipelineRun:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
|
||||||
|
# stt -> intent -> tts
|
||||||
|
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
|
||||||
|
self.start_stage
|
||||||
|
):
|
||||||
|
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Emit run start event."""
|
"""Emit run start event."""
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
|
@ -84,18 +173,86 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def finish(self):
|
def end(self):
|
||||||
"""Emit run finish event."""
|
"""Emit run end event."""
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.RUN_FINISH,
|
PipelineEventType.RUN_END,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def speech_to_text(
|
||||||
|
self,
|
||||||
|
metadata: stt.SpeechMetadata,
|
||||||
|
stream: AsyncIterable[bytes],
|
||||||
|
) -> str:
|
||||||
|
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
||||||
|
engine = self.pipeline.stt_engine or "default"
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.STT_START,
|
||||||
|
{
|
||||||
|
"engine": engine,
|
||||||
|
"metadata": asdict(metadata),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load provider
|
||||||
|
stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
|
||||||
|
assert stt_provider is not None
|
||||||
|
except Exception as src_error:
|
||||||
|
stt_error = SpeechToTextError(
|
||||||
|
code="stt-provider-missing",
|
||||||
|
message=f"No speech to text provider for: {engine}",
|
||||||
|
)
|
||||||
|
_LOGGER.exception(stt_error.message)
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.ERROR,
|
||||||
|
{"code": stt_error.code, "message": stt_error.message},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise stt_error from src_error
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Transcribe audio stream
|
||||||
|
result = await stt_provider.async_process_audio_stream(metadata, stream)
|
||||||
|
assert (result.text is not None) and (
|
||||||
|
result.result == stt.SpeechResultState.SUCCESS
|
||||||
|
)
|
||||||
|
except Exception as src_error:
|
||||||
|
stt_error = SpeechToTextError(
|
||||||
|
code="stt-stream-failed",
|
||||||
|
message="Unexpected error during speech to text",
|
||||||
|
)
|
||||||
|
_LOGGER.exception(stt_error.message)
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.ERROR,
|
||||||
|
{"code": stt_error.code, "message": stt_error.message},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise stt_error from src_error
|
||||||
|
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.STT_END,
|
||||||
|
{
|
||||||
|
"stt_output": {
|
||||||
|
"text": result.text,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.text
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self, intent_input: str, conversation_id: str | None
|
self, intent_input: str, conversation_id: str | None
|
||||||
) -> conversation.ConversationResult:
|
) -> str:
|
||||||
"""Run intent recognition portion of pipeline."""
|
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_START,
|
PipelineEventType.INTENT_START,
|
||||||
|
@ -106,6 +263,7 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
conversation_result = await conversation.async_converse(
|
conversation_result = await conversation.async_converse(
|
||||||
hass=self.hass,
|
hass=self.hass,
|
||||||
text=intent_input,
|
text=intent_input,
|
||||||
|
@ -114,15 +272,30 @@ class PipelineRun:
|
||||||
language=self.language,
|
language=self.language,
|
||||||
agent_id=self.pipeline.conversation_engine,
|
agent_id=self.pipeline.conversation_engine,
|
||||||
)
|
)
|
||||||
|
except Exception as src_error:
|
||||||
|
intent_error = IntentRecognitionError(
|
||||||
|
code="intent-failed",
|
||||||
|
message="Unexpected error during intent recognition",
|
||||||
|
)
|
||||||
|
_LOGGER.exception(intent_error.message)
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.ERROR,
|
||||||
|
{"code": intent_error.code, "message": intent_error.message},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise intent_error from src_error
|
||||||
|
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_FINISH,
|
PipelineEventType.INTENT_END,
|
||||||
{"intent_output": conversation_result.as_dict()},
|
{"intent_output": conversation_result.as_dict()},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return conversation_result
|
speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
|
||||||
|
|
||||||
|
return speech
|
||||||
|
|
||||||
async def text_to_speech(self, tts_input: str) -> str:
|
async def text_to_speech(self, tts_input: str) -> str:
|
||||||
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||||
|
@ -136,7 +309,9 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
tts_media = await async_resolve_media(
|
try:
|
||||||
|
# Synthesize audio and get URL
|
||||||
|
tts_media = await media_source.async_resolve_media(
|
||||||
self.hass,
|
self.hass,
|
||||||
tts_generate_media_source_id(
|
tts_generate_media_source_id(
|
||||||
self.hass,
|
self.hass,
|
||||||
|
@ -144,21 +319,47 @@ class PipelineRun:
|
||||||
engine=self.pipeline.tts_engine,
|
engine=self.pipeline.tts_engine,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
tts_url = tts_media.url
|
except Exception as src_error:
|
||||||
|
tts_error = TextToSpeechError(
|
||||||
|
code="tts-failed",
|
||||||
|
message="Unexpected error during text to speech",
|
||||||
|
)
|
||||||
|
_LOGGER.exception(tts_error.message)
|
||||||
|
self.event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.ERROR,
|
||||||
|
{"code": tts_error.code, "message": tts_error.message},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise tts_error from src_error
|
||||||
|
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_FINISH,
|
PipelineEventType.TTS_END,
|
||||||
{"tts_output": tts_url},
|
{"tts_output": asdict(tts_media)},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return tts_url
|
return tts_media.url
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineRequest(ABC):
|
class PipelineInput:
|
||||||
"""Request to for a pipeline run."""
|
"""Input to a pipeline run."""
|
||||||
|
|
||||||
|
stt_metadata: stt.SpeechMetadata | None = None
|
||||||
|
"""Metadata of stt input audio. Required when start_stage = stt."""
|
||||||
|
|
||||||
|
stt_stream: AsyncIterable[bytes] | None = None
|
||||||
|
"""Input audio for stt. Required when start_stage = stt."""
|
||||||
|
|
||||||
|
intent_input: str | None = None
|
||||||
|
"""Input for conversation agent. Required when start_stage = intent."""
|
||||||
|
|
||||||
|
tts_input: str | None = None
|
||||||
|
"""Input for text to speech. Required when start_stage = tts."""
|
||||||
|
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
|
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
|
||||||
|
@ -169,47 +370,60 @@ class PipelineRequest(ABC):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def _execute(self, run: PipelineRun):
|
async def _execute(self, run: PipelineRun):
|
||||||
"""Run pipeline with request info and context."""
|
self._validate(run.start_stage)
|
||||||
|
|
||||||
|
# stt -> intent -> tts
|
||||||
@dataclass
|
|
||||||
class TextPipelineRequest(PipelineRequest):
|
|
||||||
"""Request to run the text portion only of a pipeline."""
|
|
||||||
|
|
||||||
intent_input: str
|
|
||||||
conversation_id: str | None = None
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
run: PipelineRun,
|
|
||||||
):
|
|
||||||
run.start()
|
run.start()
|
||||||
await run.recognize_intent(self.intent_input, self.conversation_id)
|
current_stage = run.start_stage
|
||||||
run.finish()
|
|
||||||
|
|
||||||
|
# Speech to text
|
||||||
@dataclass
|
intent_input = self.intent_input
|
||||||
class AudioPipelineRequest(PipelineRequest):
|
if current_stage == PipelineStage.STT:
|
||||||
"""Request to full pipeline from audio input (stt) to audio output (tts)."""
|
assert self.stt_metadata is not None
|
||||||
|
assert self.stt_stream is not None
|
||||||
intent_input: str # this will be changed to stt audio
|
intent_input = await run.speech_to_text(
|
||||||
conversation_id: str | None = None
|
self.stt_metadata,
|
||||||
|
self.stt_stream,
|
||||||
async def _execute(self, run: PipelineRun):
|
|
||||||
run.start()
|
|
||||||
|
|
||||||
# stt will go here
|
|
||||||
|
|
||||||
conversation_result = await run.recognize_intent(
|
|
||||||
self.intent_input, self.conversation_id
|
|
||||||
)
|
)
|
||||||
|
current_stage = PipelineStage.INTENT
|
||||||
|
|
||||||
tts_input = conversation_result.response.speech.get("plain", {}).get(
|
if run.end_stage != PipelineStage.STT:
|
||||||
"speech", ""
|
tts_input = self.tts_input
|
||||||
|
|
||||||
|
if current_stage == PipelineStage.INTENT:
|
||||||
|
assert intent_input is not None
|
||||||
|
tts_input = await run.recognize_intent(
|
||||||
|
intent_input, self.conversation_id
|
||||||
)
|
)
|
||||||
|
current_stage = PipelineStage.TTS
|
||||||
|
|
||||||
|
if run.end_stage != PipelineStage.INTENT:
|
||||||
|
if current_stage == PipelineStage.TTS:
|
||||||
|
assert tts_input is not None
|
||||||
await run.text_to_speech(tts_input)
|
await run.text_to_speech(tts_input)
|
||||||
|
|
||||||
run.finish()
|
run.end()
|
||||||
|
|
||||||
|
def _validate(self, stage: PipelineStage):
|
||||||
|
"""Validate pipeline input against start stage."""
|
||||||
|
if stage == PipelineStage.STT:
|
||||||
|
if self.stt_metadata is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"stt_metadata is required for speech to text"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.stt_stream is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"stt_stream is required for speech to text"
|
||||||
|
)
|
||||||
|
elif stage == PipelineStage.INTENT:
|
||||||
|
if self.intent_input is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"intent_input is required for intent recognition"
|
||||||
|
)
|
||||||
|
elif stage == PipelineStage.TTS:
|
||||||
|
if self.tts_input is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"tts_input is required for text to speech"
|
||||||
|
)
|
||||||
|
|
|
@ -1,13 +1,24 @@
|
||||||
"""Voice Assistant Websocket API."""
|
"""Voice Assistant Websocket API."""
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import stt, websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .pipeline import (
|
||||||
from .pipeline import DEFAULT_TIMEOUT, Pipeline, PipelineRun, TextPipelineRequest
|
DEFAULT_TIMEOUT,
|
||||||
|
PipelineError,
|
||||||
|
PipelineInput,
|
||||||
|
PipelineRun,
|
||||||
|
PipelineStage,
|
||||||
|
async_get_pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -19,9 +30,13 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "voice_assistant/run",
|
vol.Required("type"): "voice_assistant/run",
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||||
|
vol.Optional("input"): {"text": str},
|
||||||
vol.Optional("language"): str,
|
vol.Optional("language"): str,
|
||||||
vol.Optional("pipeline"): str,
|
vol.Optional("pipeline"): str,
|
||||||
vol.Required("intent_input"): str,
|
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
vol.Optional("timeout"): vol.Any(float, int),
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
}
|
}
|
||||||
|
@ -33,39 +48,74 @@ async def websocket_run(
|
||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a pipeline."""
|
"""Run a pipeline."""
|
||||||
|
language = msg.get("language", hass.config.language)
|
||||||
pipeline_id = msg.get("pipeline")
|
pipeline_id = msg.get("pipeline")
|
||||||
if pipeline_id is not None:
|
pipeline = async_get_pipeline(
|
||||||
pipeline = hass.data[DOMAIN].get(pipeline_id)
|
hass,
|
||||||
|
pipeline_id=pipeline_id,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
"pipeline_not_found",
|
"pipeline-not-found",
|
||||||
f"Pipeline not found: {pipeline_id}",
|
f"Pipeline not found: id={pipeline_id}, language={language}",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
else:
|
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||||
# Construct a pipeline for the required/configured language
|
start_stage = PipelineStage(msg["start_stage"])
|
||||||
language = msg.get("language", hass.config.language)
|
end_stage = PipelineStage(msg["end_stage"])
|
||||||
pipeline = Pipeline(
|
handler_id: int | None = None
|
||||||
name=language,
|
unregister_handler: Callable[[], None] | None = None
|
||||||
language=language,
|
|
||||||
conversation_engine=None,
|
# Arguments to PipelineInput
|
||||||
tts_engine=None,
|
input_args: dict[str, Any] = {
|
||||||
|
"conversation_id": msg.get("conversation_id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if start_stage == PipelineStage.STT:
|
||||||
|
# Audio pipeline that will receive audio as binary websocket messages
|
||||||
|
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||||
|
|
||||||
|
async def stt_stream():
|
||||||
|
# Yield until we receive an empty chunk
|
||||||
|
while chunk := await audio_queue.get():
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def handle_binary(_hass, _connection, data: bytes):
|
||||||
|
# Forward to STT audio stream
|
||||||
|
audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
handler_id, unregister_handler = connection.async_register_binary_handler(
|
||||||
|
handle_binary
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run pipeline with a timeout.
|
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
|
||||||
# Events are sent over the websocket connection.
|
input_args["stt_metadata"] = stt.SpeechMetadata(
|
||||||
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
language=language,
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
)
|
||||||
|
input_args["stt_stream"] = stt_stream()
|
||||||
|
elif start_stage == PipelineStage.INTENT:
|
||||||
|
# Input to conversation agent
|
||||||
|
input_args["intent_input"] = msg["input"]["text"]
|
||||||
|
elif start_stage == PipelineStage.TTS:
|
||||||
|
# Input to text to speech system
|
||||||
|
input_args["tts_input"] = msg["input"]["text"]
|
||||||
|
|
||||||
run_task = hass.async_create_task(
|
run_task = hass.async_create_task(
|
||||||
TextPipelineRequest(
|
PipelineInput(**input_args).execute(
|
||||||
intent_input=msg["intent_input"],
|
|
||||||
conversation_id=msg.get("conversation_id"),
|
|
||||||
).execute(
|
|
||||||
PipelineRun(
|
PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
connection.context(msg),
|
context=connection.context(msg),
|
||||||
pipeline,
|
pipeline=pipeline,
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
event_callback=lambda event: connection.send_event(
|
event_callback=lambda event: connection.send_event(
|
||||||
msg["id"], event.as_dict()
|
msg["id"], event.as_dict()
|
||||||
),
|
),
|
||||||
|
@ -77,7 +127,20 @@ async def websocket_run(
|
||||||
# Cancel pipeline if user unsubscribes
|
# Cancel pipeline if user unsubscribes
|
||||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||||
|
|
||||||
|
# Confirm subscription
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
if handler_id is not None:
|
||||||
|
# Send handler id to client
|
||||||
|
connection.send_event(msg["id"], {"handler_id": handler_id})
|
||||||
|
|
||||||
|
try:
|
||||||
# Task contains a timeout
|
# Task contains a timeout
|
||||||
await run_task
|
await run_task
|
||||||
|
except PipelineError as error:
|
||||||
|
# Report more specific error when possible
|
||||||
|
connection.send_error(msg["id"], error.code, error.message)
|
||||||
|
finally:
|
||||||
|
if unregister_handler is not None:
|
||||||
|
# Unregister binary handler
|
||||||
|
unregister_handler()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
from asyncio import StreamReader
|
from collections.abc import AsyncIterable
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class MockProvider(Provider):
|
||||||
return [AudioChannels.CHANNEL_MONO]
|
return [AudioChannels.CHANNEL_MONO]
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, metadata: SpeechMetadata, stream: StreamReader
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
) -> SpeechResult:
|
) -> SpeechResult:
|
||||||
"""Process an audio stream."""
|
"""Process an audio stream."""
|
||||||
self.calls.append((metadata, stream))
|
self.calls.append((metadata, stream))
|
||||||
|
|
|
@ -1,110 +0,0 @@
|
||||||
"""Pipeline tests for Voice Assistant integration."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from homeassistant.components.voice_assistant.pipeline import (
|
|
||||||
AudioPipelineRequest,
|
|
||||||
Pipeline,
|
|
||||||
PipelineEventType,
|
|
||||||
PipelineRun,
|
|
||||||
)
|
|
||||||
from homeassistant.core import Context
|
|
||||||
from homeassistant.setup import async_setup_component
|
|
||||||
|
|
||||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
|
||||||
mock_get_cache_files,
|
|
||||||
mock_init_cache_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
async def init_components(hass):
|
|
||||||
"""Initialize relevant components with empty configs."""
|
|
||||||
assert await async_setup_component(hass, "voice_assistant", {})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mock_get_tts_audio(hass):
|
|
||||||
"""Set up media source."""
|
|
||||||
assert await async_setup_component(hass, "media_source", {})
|
|
||||||
assert await async_setup_component(
|
|
||||||
hass,
|
|
||||||
"tts",
|
|
||||||
{
|
|
||||||
"tts": {
|
|
||||||
"platform": "demo",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
|
|
||||||
return_value=("mp3", b""),
|
|
||||||
) as mock_get_tts:
|
|
||||||
yield mock_get_tts
|
|
||||||
|
|
||||||
|
|
||||||
async def test_audio_pipeline(hass, mock_get_tts_audio):
|
|
||||||
"""Run audio pipeline with mock TTS."""
|
|
||||||
pipeline = Pipeline(
|
|
||||||
name="test",
|
|
||||||
language=hass.config.language,
|
|
||||||
conversation_engine=None,
|
|
||||||
tts_engine=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
event_callback = MagicMock()
|
|
||||||
await AudioPipelineRequest(intent_input="Are the lights on?").execute(
|
|
||||||
PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
event_callback=event_callback,
|
|
||||||
language=hass.config.language,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
calls = event_callback.mock_calls
|
|
||||||
assert calls[0].args[0].type == PipelineEventType.RUN_START
|
|
||||||
assert calls[0].args[0].data == {
|
|
||||||
"pipeline": "test",
|
|
||||||
"language": hass.config.language,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert calls[1].args[0].type == PipelineEventType.INTENT_START
|
|
||||||
assert calls[1].args[0].data == {
|
|
||||||
"engine": "default",
|
|
||||||
"intent_input": "Are the lights on?",
|
|
||||||
}
|
|
||||||
assert calls[2].args[0].type == PipelineEventType.INTENT_FINISH
|
|
||||||
assert calls[2].args[0].data == {
|
|
||||||
"intent_output": {
|
|
||||||
"conversation_id": None,
|
|
||||||
"response": {
|
|
||||||
"card": {},
|
|
||||||
"data": {"code": "no_intent_match"},
|
|
||||||
"language": hass.config.language,
|
|
||||||
"response_type": "error",
|
|
||||||
"speech": {
|
|
||||||
"plain": {
|
|
||||||
"extra_data": None,
|
|
||||||
"speech": "Sorry, I couldn't understand that",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert calls[3].args[0].type == PipelineEventType.TTS_START
|
|
||||||
assert calls[3].args[0].data == {
|
|
||||||
"engine": "default",
|
|
||||||
"tts_input": "Sorry, I couldn't understand that",
|
|
||||||
}
|
|
||||||
assert calls[4].args[0].type == PipelineEventType.TTS_FINISH
|
|
||||||
assert (
|
|
||||||
calls[4].args[0].data["tts_output"]
|
|
||||||
== f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert calls[5].args[0].type == PipelineEventType.RUN_FINISH
|
|
|
@ -1,20 +1,94 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import patch
|
from collections.abc import AsyncIterable
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||||
|
mock_get_cache_files,
|
||||||
|
mock_init_cache_dir,
|
||||||
|
)
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
_TRANSCRIPT = "test transcript"
|
||||||
|
|
||||||
|
|
||||||
|
class MockSttProvider(stt.Provider):
|
||||||
|
"""Mock STT provider."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, text: str) -> None:
|
||||||
|
"""Init test provider."""
|
||||||
|
self.hass = hass
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return [self.hass.config.language]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_formats(self) -> list[stt.AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
return [stt.AudioFormats.WAV]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_codecs(self) -> list[stt.AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
return [stt.AudioCodecs.PCM]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
|
||||||
|
"""Return a list of supported bitrates."""
|
||||||
|
return [stt.AudioBitRates.BITRATE_16]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
|
||||||
|
"""Return a list of supported samplerates."""
|
||||||
|
return [stt.AudioSampleRates.SAMPLERATE_16000]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_channels(self) -> list[stt.AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
return [stt.AudioChannels.CHANNEL_MONO]
|
||||||
|
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> stt.SpeechResult:
|
||||||
|
"""Process an audio stream."""
|
||||||
|
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def init_components(hass):
|
async def init_components(hass):
|
||||||
"""Initialize relevant components with empty configs."""
|
"""Initialize relevant components with empty configs."""
|
||||||
|
assert await async_setup_component(hass, "media_source", {})
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"tts",
|
||||||
|
{
|
||||||
|
"tts": {
|
||||||
|
"platform": "demo",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert await async_setup_component(hass, "stt", {})
|
||||||
|
|
||||||
|
# mock_platform fails because it can't import
|
||||||
|
hass.data[stt.DOMAIN] = {"test": MockSttProvider(hass, _TRANSCRIPT)}
|
||||||
|
|
||||||
assert await async_setup_component(hass, "voice_assistant", {})
|
assert await async_setup_component(hass, "voice_assistant", {})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
|
||||||
|
return_value=("mp3", b""),
|
||||||
|
) as mock_get_tts:
|
||||||
|
yield mock_get_tts
|
||||||
|
|
||||||
|
|
||||||
async def test_text_only_pipeline(
|
async def test_text_only_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -27,7 +101,9 @@ async def test_text_only_pipeline(
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"intent_input": "Are the lights on?",
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,7 +128,7 @@ async def test_text_only_pipeline(
|
||||||
}
|
}
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "intent-finish"
|
assert msg["event"]["type"] == "intent-end"
|
||||||
assert msg["event"]["data"] == {
|
assert msg["event"]["data"] == {
|
||||||
"intent_output": {
|
"intent_output": {
|
||||||
"response": {
|
"response": {
|
||||||
|
@ -71,13 +147,120 @@ async def test_text_only_pipeline(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# run finish
|
# run end
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "run-finish"
|
assert msg["event"]["type"] == "run-end"
|
||||||
assert msg["event"]["data"] == {}
|
assert msg["event"]["data"] == {}
|
||||||
|
|
||||||
|
|
||||||
async def test_conversation_timeout(
|
async def test_audio_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with audio input/output."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# handler id
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["handler_id"] == 1
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": hass.config.language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"metadata": {
|
||||||
|
"bit_rate": 16,
|
||||||
|
"channel": 1,
|
||||||
|
"codec": "pcm",
|
||||||
|
"format": "wav",
|
||||||
|
"language": "en",
|
||||||
|
"sample_rate": 16000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# End of audio stream (handler id + empty payload)
|
||||||
|
await client.send_bytes(b"1")
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"stt_output": {"text": _TRANSCRIPT},
|
||||||
|
}
|
||||||
|
|
||||||
|
# intent
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"intent_input": _TRANSCRIPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-end"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"intent_output": {
|
||||||
|
"response": {
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"speech": "Sorry, I couldn't understand that",
|
||||||
|
"extra_data": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"card": {},
|
||||||
|
"language": "en",
|
||||||
|
"response_type": "error",
|
||||||
|
"data": {"code": "no_intent_match"},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# text to speech
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"tts_input": "Sorry, I couldn't understand that",
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-end"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"tts_output": {
|
||||||
|
"url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3",
|
||||||
|
"mime_type": "audio/mpeg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intent_timeout(
|
||||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test partial pipeline run with conversation agent timeout."""
|
"""Test partial pipeline run with conversation agent timeout."""
|
||||||
|
@ -94,7 +277,9 @@ async def test_conversation_timeout(
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"intent_input": "Are the lights on?",
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
"timeout": 0.00001,
|
"timeout": 0.00001,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -125,24 +310,26 @@ async def test_conversation_timeout(
|
||||||
assert msg["error"]["code"] == "timeout"
|
assert msg["error"]["code"] == "timeout"
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_timeout(
|
async def test_text_pipeline_timeout(
|
||||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test pipeline run with immediate timeout."""
|
"""Test text-only pipeline run with immediate timeout."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
async def sleepy_run(*args, **kwargs):
|
async def sleepy_run(*args, **kwargs):
|
||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.voice_assistant.pipeline.TextPipelineRequest._execute",
|
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
||||||
new=sleepy_run,
|
new=sleepy_run,
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"intent_input": "Are the lights on?",
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
"timeout": 0.0001,
|
"timeout": 0.0001,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -155,3 +342,273 @@ async def test_pipeline_timeout(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert not msg["success"]
|
||||||
assert msg["error"]["code"] == "timeout"
|
assert msg["error"]["code"] == "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intent_failed(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test text-only pipeline run with conversation agent error."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.conversation.async_converse",
|
||||||
|
new=MagicMock(return_value=RuntimeError),
|
||||||
|
):
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": hass.config.language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# intent start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
}
|
||||||
|
|
||||||
|
# intent error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "error"
|
||||||
|
assert msg["event"]["data"]["code"] == "intent-failed"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_audio_pipeline_timeout(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test audio pipeline run with immediate timeout."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
async def sleepy_run(*args, **kwargs):
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
||||||
|
new=sleepy_run,
|
||||||
|
):
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"timeout": 0.0001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# handler id
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["handler_id"] == 1
|
||||||
|
|
||||||
|
# timeout error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_provider_missing(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.stt.async_get_provider",
|
||||||
|
new=MagicMock(return_value=None),
|
||||||
|
):
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# handler id
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["handler_id"] == 1
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": hass.config.language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"metadata": {
|
||||||
|
"bit_rate": 16,
|
||||||
|
"channel": 1,
|
||||||
|
"codec": "pcm",
|
||||||
|
"format": "wav",
|
||||||
|
"language": "en",
|
||||||
|
"sample_rate": 16000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# End of audio stream (handler id + empty payload)
|
||||||
|
await client.send_bytes(b"1")
|
||||||
|
|
||||||
|
# stt error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "error"
|
||||||
|
assert msg["event"]["data"]["code"] == "stt-provider-missing"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_stream_failed(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
|
with patch(
|
||||||
|
"tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream",
|
||||||
|
new=MagicMock(side_effect=RuntimeError),
|
||||||
|
):
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# handler id
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["handler_id"] == 1
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": hass.config.language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# stt
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "stt-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"metadata": {
|
||||||
|
"bit_rate": 16,
|
||||||
|
"channel": 1,
|
||||||
|
"codec": "pcm",
|
||||||
|
"format": "wav",
|
||||||
|
"language": "en",
|
||||||
|
"sample_rate": 16000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# End of audio stream (handler id + empty payload)
|
||||||
|
await client.send_bytes(b"1")
|
||||||
|
|
||||||
|
# stt error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "error"
|
||||||
|
assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_failed(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline run with text to speech error."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
new=MagicMock(return_value=RuntimeError),
|
||||||
|
):
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "tts",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {"text": "Lights are on."},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": hass.config.language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# tts start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"tts_input": "Lights are on.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# tts error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "error"
|
||||||
|
assert msg["event"]["data"]["code"] == "tts-failed"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_stage_order(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline run with invalid stage order."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"start_stage": "tts",
|
||||||
|
"end_stage": "stt",
|
||||||
|
"input": {"text": "Lights are on."},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
|
Loading…
Add table
Reference in a new issue