hass-core/tests/components/voice_assistant/test_websocket.py
"""Websocket tests for Voice Assistant integration."""
import asyncio
from collections.abc import AsyncIterable
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import stt, tts
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_integration, mock_platform
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
from tests.typing import WebSocketGenerator
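# Canned transcript that MockSttProvider returns for every processed audio stream.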
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
"""Mock STT provider."""
def __init__(self, hass: HomeAssistant, text: str) -> None:
"""Init test provider."""
self.hass = hass
self.text = text
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["en-US"]
@property
def supported_formats(self) -> list[stt.AudioFormats]:
"""Return a list of supported formats."""
return [stt.AudioFormats.WAV]
@property
def supported_codecs(self) -> list[stt.AudioCodecs]:
"""Return a list of supported codecs."""
return [stt.AudioCodecs.PCM]
@property
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
"""Return a list of supported bitrates."""
return [stt.AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
"""Return a list of supported samplerates."""
return [stt.AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[stt.AudioChannels]:
"""Return a list of supported channels."""
return [stt.AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
) -> stt.SpeechResult:
"""Process an audio stream."""
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSTT:
"""A mock STT platform."""
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> stt.Provider:
"""Set up a mock speech component."""
return MockSttProvider(hass, _TRANSCRIPT)
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
name = "Test"
@property
def default_language(self) -> str:
"""Return the default language."""
return "en"
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return ["en-US"]
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> tts.TtsAudioType:
"""Load TTS dat."""
return ("mp3", b"")
class MockTTS:
"""A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> tts.Provider:
"""Set up a mock speech component."""
return MockTTSProvider()
@pytest.fixture(autouse=True)
async def init_components(
hass: HomeAssistant,
mock_get_cache_files, # noqa: F811
mock_init_cache_dir, # noqa: F811
):
"""Initialize relevant components with empty configs."""
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS())
mock_platform(hass, "test.stt", MockSTT())
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "voice_assistant", {})
async def test_text_only_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with text input (no STT/TTS)."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == {}
async def test_audio_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
) -> None:
"""Test events from a pipeline run with audio input/output."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
# End of audio stream (handler id + empty payload)
await client.send_bytes(b"1")
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
# text to speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == {}
async def test_intent_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test partial pipeline run with conversation agent timeout."""
client = await hass_ws_client(hass)
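    # Replacement conversation agent that never returns, so the pipeline timeout fires first.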
async def sleepy_converse(*args, **kwargs):
await asyncio.sleep(3600)
with patch(
"homeassistant.components.conversation.async_converse",
new=sleepy_converse,
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
"timeout": 0.1,
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
# timeout error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
async def test_text_pipeline_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test text-only pipeline run with immediate timeout."""
client = await hass_ws_client(hass)
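    # Replacement pipeline run that never finishes, so the very short timeout fires immediately.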
async def sleepy_run(*args, **kwargs):
await asyncio.sleep(3600)
with patch(
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
new=sleepy_run,
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
"timeout": 0.0001,
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# timeout error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
async def test_intent_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test text-only pipeline run with conversation agent error."""
client = await hass_ws_client(hass)
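    # Patching async_converse with a plain MagicMock (no awaitable result) makes the intent
    # stage fail; the test expects an error event with code "intent-failed".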
with patch(
"homeassistant.components.conversation.async_converse",
new=MagicMock(return_value=RuntimeError),
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# intent start
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
# intent error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "intent-failed"
async def test_audio_pipeline_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test audio pipeline run with immediate timeout."""
client = await hass_ws_client(hass)
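    # Replacement pipeline run that never finishes, so the timeout error is emitted.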
async def sleepy_run(*args, **kwargs):
await asyncio.sleep(3600)
with patch(
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
new=sleepy_run,
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
"timeout": 0.0001,
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# timeout error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "timeout"
async def test_stt_provider_missing(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
with patch(
"homeassistant.components.stt.async_get_provider",
new=MagicMock(return_value=None),
):
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "stt-provider-missing"
async def test_stt_stream_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
with patch(
"tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream",
new=MagicMock(side_effect=RuntimeError),
):
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
# End of audio stream (handler id + empty payload)
await client.send_bytes(b"1")
# stt error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "stt-stream-failed"
async def test_tts_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
client = await hass_ws_client(hass)
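    # Patching media_source.async_resolve_media with a plain MagicMock (no awaitable result)
    # makes the TTS stage fail; the test expects an error event with code "tts-failed".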
with patch(
"homeassistant.components.media_source.async_resolve_media",
new=MagicMock(return_value=RuntimeError),
):
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
# tts start
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
# tts error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "tts-failed"
async def test_invalid_stage_order(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test pipeline run with invalid stage order."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "voice_assistant/run",
"start_stage": "tts",
"end_stage": "stt",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]