Clean up voice assistant integration (#90239)
* Clean up voice assistant * Reinstate auto-removed imports * Resample STT audio from 44.1 kHz to 16 kHz * Energy-based VAD for prototyping --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
7098debe09
commit
c3717f8182
5 changed files with 407 additions and 237 deletions
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
from hass_nabucasa.voice import VoiceError
|
||||
|
@ -20,6 +21,8 @@ from homeassistant.components.stt import (
|
|||
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SUPPORT_LANGUAGES = [
|
||||
"da-DK",
|
||||
"de-DE",
|
||||
|
@ -102,7 +105,8 @@ class CloudProvider(Provider):
|
|||
result = await self.cloud.voice.process_stt(
|
||||
stream, content, metadata.language
|
||||
)
|
||||
except VoiceError:
|
||||
except VoiceError as err:
|
||||
_LOGGER.debug("Voice error: %s", err)
|
||||
return SpeechResult(None, SpeechResultState.ERROR)
|
||||
|
||||
# Return Speech as Text
|
||||
|
|
|
@ -150,6 +150,7 @@ class PipelineRun:
|
|||
end_stage: PipelineStage
|
||||
event_callback: Callable[[PipelineEvent], None]
|
||||
language: str = None # type: ignore[assignment]
|
||||
runner_data: Any | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set language for pipeline."""
|
||||
|
@ -163,15 +164,14 @@ class PipelineRun:
|
|||
|
||||
def start(self):
|
||||
"""Emit run start event."""
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.RUN_START,
|
||||
{
|
||||
"pipeline": self.pipeline.name,
|
||||
"language": self.language,
|
||||
},
|
||||
)
|
||||
)
|
||||
data = {
|
||||
"pipeline": self.pipeline.name,
|
||||
"language": self.language,
|
||||
}
|
||||
if self.runner_data is not None:
|
||||
data["runner_data"] = self.runner_data
|
||||
|
||||
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
|
||||
def end(self):
|
||||
"""Emit run end event."""
|
||||
|
@ -200,41 +200,45 @@ class PipelineRun:
|
|||
|
||||
try:
|
||||
# Load provider
|
||||
stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
|
||||
stt_provider: stt.Provider = stt.async_get_provider(
|
||||
self.hass, self.pipeline.stt_engine
|
||||
)
|
||||
assert stt_provider is not None
|
||||
except Exception as src_error:
|
||||
stt_error = SpeechToTextError(
|
||||
_LOGGER.exception("No speech to text provider for %s", engine)
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-missing",
|
||||
message=f"No speech to text provider for: {engine}",
|
||||
) from src_error
|
||||
|
||||
if not stt_provider.check_metadata(metadata):
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-unsupported-metadata",
|
||||
message=f"Provider {engine} does not support input speech to text metadata",
|
||||
)
|
||||
_LOGGER.exception(stt_error.message)
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": stt_error.code, "message": stt_error.message},
|
||||
)
|
||||
)
|
||||
raise stt_error from src_error
|
||||
|
||||
try:
|
||||
# Transcribe audio stream
|
||||
result = await stt_provider.async_process_audio_stream(metadata, stream)
|
||||
assert (result.text is not None) and (
|
||||
result.result == stt.SpeechResultState.SUCCESS
|
||||
)
|
||||
except Exception as src_error:
|
||||
stt_error = SpeechToTextError(
|
||||
_LOGGER.exception("Unexpected error during speech to text")
|
||||
raise SpeechToTextError(
|
||||
code="stt-stream-failed",
|
||||
message="Unexpected error during speech to text",
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("speech-to-text result %s", result)
|
||||
|
||||
if result.result != stt.SpeechResultState.SUCCESS:
|
||||
raise SpeechToTextError(
|
||||
code="stt-stream-failed",
|
||||
message="Speech to text failed",
|
||||
)
|
||||
_LOGGER.exception(stt_error.message)
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": stt_error.code, "message": stt_error.message},
|
||||
)
|
||||
|
||||
if not result.text:
|
||||
raise SpeechToTextError(
|
||||
code="stt-no-text-recognized", message="No text recognized"
|
||||
)
|
||||
raise stt_error from src_error
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
|
@ -273,18 +277,13 @@ class PipelineRun:
|
|||
agent_id=self.pipeline.conversation_engine,
|
||||
)
|
||||
except Exception as src_error:
|
||||
intent_error = IntentRecognitionError(
|
||||
_LOGGER.exception("Unexpected error during intent recognition")
|
||||
raise IntentRecognitionError(
|
||||
code="intent-failed",
|
||||
message="Unexpected error during intent recognition",
|
||||
)
|
||||
_LOGGER.exception(intent_error.message)
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": intent_error.code, "message": intent_error.message},
|
||||
)
|
||||
)
|
||||
raise intent_error from src_error
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("conversation result %s", conversation_result)
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
|
@ -320,18 +319,13 @@ class PipelineRun:
|
|||
),
|
||||
)
|
||||
except Exception as src_error:
|
||||
tts_error = TextToSpeechError(
|
||||
_LOGGER.exception("Unexpected error during text to speech")
|
||||
raise TextToSpeechError(
|
||||
code="tts-failed",
|
||||
message="Unexpected error during text to speech",
|
||||
)
|
||||
_LOGGER.exception(tts_error.message)
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": tts_error.code, "message": tts_error.message},
|
||||
)
|
||||
)
|
||||
raise tts_error from src_error
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("TTS result %s", tts_media)
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
|
@ -377,31 +371,41 @@ class PipelineInput:
|
|||
run.start()
|
||||
current_stage = run.start_stage
|
||||
|
||||
# Speech to text
|
||||
intent_input = self.intent_input
|
||||
if current_stage == PipelineStage.STT:
|
||||
assert self.stt_metadata is not None
|
||||
assert self.stt_stream is not None
|
||||
intent_input = await run.speech_to_text(
|
||||
self.stt_metadata,
|
||||
self.stt_stream,
|
||||
)
|
||||
current_stage = PipelineStage.INTENT
|
||||
|
||||
if run.end_stage != PipelineStage.STT:
|
||||
tts_input = self.tts_input
|
||||
|
||||
if current_stage == PipelineStage.INTENT:
|
||||
assert intent_input is not None
|
||||
tts_input = await run.recognize_intent(
|
||||
intent_input, self.conversation_id
|
||||
try:
|
||||
# Speech to text
|
||||
intent_input = self.intent_input
|
||||
if current_stage == PipelineStage.STT:
|
||||
assert self.stt_metadata is not None
|
||||
assert self.stt_stream is not None
|
||||
intent_input = await run.speech_to_text(
|
||||
self.stt_metadata,
|
||||
self.stt_stream,
|
||||
)
|
||||
current_stage = PipelineStage.TTS
|
||||
current_stage = PipelineStage.INTENT
|
||||
|
||||
if run.end_stage != PipelineStage.INTENT:
|
||||
if current_stage == PipelineStage.TTS:
|
||||
assert tts_input is not None
|
||||
await run.text_to_speech(tts_input)
|
||||
if run.end_stage != PipelineStage.STT:
|
||||
tts_input = self.tts_input
|
||||
|
||||
if current_stage == PipelineStage.INTENT:
|
||||
assert intent_input is not None
|
||||
tts_input = await run.recognize_intent(
|
||||
intent_input, self.conversation_id
|
||||
)
|
||||
current_stage = PipelineStage.TTS
|
||||
|
||||
if run.end_stage != PipelineStage.INTENT:
|
||||
if current_stage == PipelineStage.TTS:
|
||||
assert tts_input is not None
|
||||
await run.text_to_speech(tts_input)
|
||||
|
||||
except PipelineError as err:
|
||||
run.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": err.code, "message": err.message},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
run.end()
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Voice Assistant Websocket API."""
|
||||
import asyncio
|
||||
import audioop
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
@ -12,6 +13,8 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from .pipeline import (
|
||||
DEFAULT_TIMEOUT,
|
||||
PipelineError,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineInput,
|
||||
PipelineRun,
|
||||
PipelineStage,
|
||||
|
@ -20,6 +23,10 @@ from .pipeline import (
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VAD_ENERGY_THRESHOLD = 1000
|
||||
_VAD_SPEECH_FRAMES = 25
|
||||
_VAD_SILENCE_FRAMES = 25
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
|
@ -27,6 +34,17 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
websocket_api.async_register_command(hass, websocket_run)
|
||||
|
||||
|
||||
def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float:
    """Compute RMS of debiased audio.

    The RMS of the fragment is negated, encoded as a single sample, and added
    to every sample of the fragment with ``audioop.add``; the RMS of that
    shifted fragment is returned.

    Args:
        audio_data: Raw PCM samples, ``width`` bytes per sample.
        width: Sample width in bytes, passed through to ``audioop``.

    Returns:
        The RMS of the shifted fragment as computed by ``audioop.rms``.
    """
    energy = -audioop.rms(audio_data, width)

    # Encode the (non-positive) offset as ONE sample of the requested width.
    # A fixed two-byte encoding only lines up with the fragment when
    # width == 2; for any other width the repeated offset would have a
    # different length than audio_data and audioop.add would reject it.
    # Little-endian two's complement yields [low byte, high byte] for width 2.
    energy_bytes = energy.to_bytes(width, "little", signed=True)

    sample_count = len(audio_data) // width
    debiased_energy = audioop.rms(
        audioop.add(audio_data, energy_bytes * sample_count, width),
        width,
    )

    return debiased_energy
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "voice_assistant/run",
|
||||
|
@ -49,6 +67,11 @@ async def websocket_run(
|
|||
) -> None:
|
||||
"""Run a pipeline."""
|
||||
language = msg.get("language", hass.config.language)
|
||||
|
||||
# Temporary workaround for language codes
|
||||
if language == "en":
|
||||
language = "en-US"
|
||||
|
||||
pipeline_id = msg.get("pipeline")
|
||||
pipeline = async_get_pipeline(
|
||||
hass,
|
||||
|
@ -79,8 +102,32 @@ async def websocket_run(
|
|||
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||
|
||||
async def stt_stream():
    """Yield resampled audio chunks until the stream or voice command ends.

    Chunks are taken from ``audio_queue``; a falsy (empty) chunk ends the
    stream. Each chunk is converted with ``audioop.ratecv`` (2-byte samples,
    1 channel, 44100 -> 16000) and classified as speech when its debiased
    energy exceeds ``_VAD_ENERGY_THRESHOLD``. The chunk that triggers the
    "stopped" condition is not yielded.
    """
    # NOTE(review): the counter is never reset once the command starts, so
    # stopping needs silent chunks to outnumber speech chunks by
    # _VAD_SPEECH_FRAMES + _VAD_SILENCE_FRAMES — confirm this is intended.
    resample_state = None
    speech_balance = 0
    command_active = False

    while True:
        raw_chunk = await audio_queue.get()
        if not raw_chunk:
            # An empty chunk marks the end of the audio stream
            break

        resampled, resample_state = audioop.ratecv(
            raw_chunk, 2, 1, 44100, 16000, resample_state
        )
        loud = _get_debiased_energy(resampled) > _VAD_ENERGY_THRESHOLD

        if not command_active:
            # Waiting for enough speech chunks to consider the command started
            if loud:
                speech_balance += 1
                if speech_balance >= _VAD_SPEECH_FRAMES:
                    command_active = True
                    _LOGGER.info("Voice command started")
        else:
            # Speech raises the balance, silence lowers it
            speech_balance += 1 if loud else -1
            if speech_balance <= -_VAD_SILENCE_FRAMES:
                _LOGGER.info("Voice command stopped")
                break

        yield resampled
|
||||
|
||||
def handle_binary(_hass, _connection, data: bytes):
|
||||
|
@ -119,6 +166,9 @@ async def websocket_run(
|
|||
event_callback=lambda event: connection.send_event(
|
||||
msg["id"], event.as_dict()
|
||||
),
|
||||
runner_data={
|
||||
"stt_binary_handler_id": handler_id,
|
||||
},
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
@ -130,16 +180,20 @@ async def websocket_run(
|
|||
# Confirm subscription
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
if handler_id is not None:
|
||||
# Send handler id to client
|
||||
connection.send_event(msg["id"], {"handler_id": handler_id})
|
||||
|
||||
try:
|
||||
# Task contains a timeout
|
||||
await run_task
|
||||
except PipelineError as error:
|
||||
# Report more specific error when possible
|
||||
connection.send_error(msg["id"], error.code, error.message)
|
||||
except asyncio.TimeoutError:
|
||||
connection.send_event(
|
||||
msg["id"],
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": "timeout", "message": "Timeout running pipeline"},
|
||||
),
|
||||
)
|
||||
finally:
|
||||
if unregister_handler is not None:
|
||||
# Unregister binary handler
|
||||
|
|
210
tests/components/voice_assistant/snapshots/test_websocket.ambr
Normal file
210
tests/components/voice_assistant/snapshots/test_websocket.ambr
Normal file
|
@ -0,0 +1,210 @@
|
|||
# serializer version: 1
|
||||
# name: test_audio_pipeline
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.2
|
||||
dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.3
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'intent_input': 'test transcript',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.4
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en-US',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.5
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_demo.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout.2
|
||||
dict({
|
||||
'code': 'timeout',
|
||||
'message': 'Timeout running pipeline',
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_provider_missing
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_provider_missing.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
'codec': 'pcm',
|
||||
'format': 'wav',
|
||||
'language': 'en-US',
|
||||
'sample_rate': 16000,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en-US',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_pipeline_timeout
|
||||
dict({
|
||||
'code': 'timeout',
|
||||
'message': 'Timeout running pipeline',
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed
|
||||
dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'tts_input': 'Lights are on.',
|
||||
})
|
||||
# ---
|
|
@ -4,6 +4,7 @@ from collections.abc import AsyncIterable
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -29,7 +30,7 @@ class MockSttProvider(stt.Provider):
|
|||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return [self.hass.config.language]
|
||||
return ["en-US"]
|
||||
|
||||
@property
|
||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
||||
|
@ -64,7 +65,11 @@ class MockSttProvider(stt.Provider):
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def init_components(hass):
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
mock_get_cache_files, # noqa: F811
|
||||
mock_init_cache_dir, # noqa: F811
|
||||
):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(
|
||||
|
@ -93,6 +98,7 @@ async def init_components(hass):
|
|||
async def test_text_only_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -114,38 +120,16 @@ async def test_text_only_pipeline(
|
|||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": "Are the lights on?",
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == {
|
||||
"intent_output": {
|
||||
"response": {
|
||||
"speech": {
|
||||
"plain": {
|
||||
"speech": "Sorry, I couldn't understand that",
|
||||
"extra_data": None,
|
||||
}
|
||||
},
|
||||
"card": {},
|
||||
"language": "en",
|
||||
"response_type": "error",
|
||||
"data": {"code": "no_intent_match"},
|
||||
},
|
||||
"conversation_id": None,
|
||||
}
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
|
@ -154,8 +138,7 @@ async def test_text_only_pipeline(
|
|||
|
||||
|
||||
async def test_audio_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with audio input/output."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -173,86 +156,40 @@ async def test_audio_pipeline(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# handler id
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["handler_id"] == 1
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"metadata": {
|
||||
"bit_rate": 16,
|
||||
"channel": 1,
|
||||
"codec": "pcm",
|
||||
"format": "wav",
|
||||
"language": "en",
|
||||
"sample_rate": 16000,
|
||||
},
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(b"1")
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-end"
|
||||
assert msg["event"]["data"] == {
|
||||
"stt_output": {"text": _TRANSCRIPT},
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": _TRANSCRIPT,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-end"
|
||||
assert msg["event"]["data"] == {
|
||||
"intent_output": {
|
||||
"response": {
|
||||
"speech": {
|
||||
"plain": {
|
||||
"speech": "Sorry, I couldn't understand that",
|
||||
"extra_data": None,
|
||||
}
|
||||
},
|
||||
"card": {},
|
||||
"language": "en",
|
||||
"response_type": "error",
|
||||
"data": {"code": "no_intent_match"},
|
||||
},
|
||||
"conversation_id": None,
|
||||
}
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# text to speech
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"tts_input": "Sorry, I couldn't understand that",
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-end"
|
||||
assert msg["event"]["data"] == {
|
||||
"tts_output": {
|
||||
"url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3",
|
||||
"mime_type": "audio/mpeg",
|
||||
},
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
|
@ -261,7 +198,10 @@ async def test_audio_pipeline(
|
|||
|
||||
|
||||
async def test_intent_timeout(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test partial pipeline run with conversation agent timeout."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -291,27 +231,24 @@ async def test_intent_timeout(
|
|||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# intent
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": "Are the lights on?",
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "timeout"
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
|
||||
async def test_text_pipeline_timeout(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test text-only pipeline run with immediate timeout."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -340,12 +277,15 @@ async def test_text_pipeline_timeout(
|
|||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "timeout"
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
|
||||
async def test_intent_failed(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test text-only pipeline run with conversation agent error."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -371,18 +311,12 @@ async def test_intent_failed(
|
|||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# intent start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "intent-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"intent_input": "Are the lights on?",
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# intent error
|
||||
msg = await client.receive_json()
|
||||
|
@ -391,7 +325,10 @@ async def test_intent_failed(
|
|||
|
||||
|
||||
async def test_audio_pipeline_timeout(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test audio pipeline run with immediate timeout."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -417,19 +354,16 @@ async def test_audio_pipeline_timeout(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# handler id
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["handler_id"] == 1
|
||||
|
||||
# timeout error
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "timeout"
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "timeout"
|
||||
|
||||
|
||||
async def test_stt_provider_missing(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
with patch(
|
||||
|
@ -451,32 +385,15 @@ async def test_stt_provider_missing(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# handler id
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["handler_id"] == 1
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"metadata": {
|
||||
"bit_rate": 16,
|
||||
"channel": 1,
|
||||
"codec": "pcm",
|
||||
"format": "wav",
|
||||
"language": "en",
|
||||
"sample_rate": 16000,
|
||||
},
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(b"1")
|
||||
|
@ -490,6 +407,7 @@ async def test_stt_provider_missing(
|
|||
async def test_stt_stream_failed(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
with patch(
|
||||
|
@ -511,32 +429,15 @@ async def test_stt_stream_failed(
|
|||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# handler id
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["handler_id"] == 1
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"metadata": {
|
||||
"bit_rate": 16,
|
||||
"channel": 1,
|
||||
"codec": "pcm",
|
||||
"format": "wav",
|
||||
"language": "en",
|
||||
"sample_rate": 16000,
|
||||
},
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(b"1")
|
||||
|
@ -548,7 +449,10 @@ async def test_stt_stream_failed(
|
|||
|
||||
|
||||
async def test_tts_failed(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test pipeline run with text to speech error."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -574,18 +478,12 @@ async def test_tts_failed(
|
|||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"pipeline": hass.config.language,
|
||||
"language": hass.config.language,
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# tts start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == {
|
||||
"engine": "default",
|
||||
"tts_input": "Lights are on.",
|
||||
}
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# tts error
|
||||
msg = await client.receive_json()
|
||||
|
|
Loading…
Add table
Reference in a new issue