Run pipeline from audio stream function (#90748)
* Run pipeline from audio stream function * Fix tests --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
4f1574b859
commit
6e4c78686e
8 changed files with 383 additions and 158 deletions
|
@ -1,12 +1,33 @@
|
||||||
"""The Voice Assistant integration."""
|
"""The Voice Assistant integration."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from collections.abc import AsyncIterable
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
from .error import PipelineNotFound
|
||||||
|
from .pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventCallback,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineInput,
|
||||||
|
PipelineRun,
|
||||||
|
PipelineStage,
|
||||||
|
async_get_pipeline,
|
||||||
|
)
|
||||||
from .websocket_api import async_register_websocket_api
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"DOMAIN",
|
||||||
|
"async_setup",
|
||||||
|
"async_pipeline_from_audio_stream",
|
||||||
|
"PipelineEvent",
|
||||||
|
"PipelineEventType",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up Voice Assistant integration."""
|
"""Set up Voice Assistant integration."""
|
||||||
|
@ -14,3 +35,55 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
    hass: HomeAssistant,
    event_callback: PipelineEventCallback,
    stt_metadata: stt.SpeechMetadata,
    stt_stream: AsyncIterable[bytes],
    language: str | None = None,
    pipeline_id: str | None = None,
    conversation_id: str | None = None,
    context: Context | None = None,
) -> None:
    """Build a pipeline run for an audio stream and execute it.

    The run starts at the speech-to-text stage and ends at the
    text-to-speech stage; ``event_callback`` is handed to the run.

    When ``language`` is omitted, ``hass.config.language`` is used. If
    ``stt_metadata.language`` is an empty string, it is filled in (in place)
    with the resolved language. A new ``Context`` is created when none is
    supplied.

    Raises PipelineNotFound if ``async_get_pipeline`` returns ``None``.
    """
    resolved_language = hass.config.language if language is None else language

    # Temporary workaround for language codes
    if resolved_language == "en":
        resolved_language = "en-US"

    # Only fill in the metadata language when the caller left it blank.
    if stt_metadata.language == "":
        stt_metadata.language = resolved_language

    run_context = Context() if context is None else context

    pipeline = async_get_pipeline(
        hass, pipeline_id=pipeline_id, language=resolved_language
    )
    if pipeline is None:
        raise PipelineNotFound(
            "pipeline_not_found", f"Pipeline {pipeline_id} not found"
        )

    run = PipelineRun(
        hass,
        context=run_context,
        pipeline=pipeline,
        start_stage=PipelineStage.STT,
        end_stage=PipelineStage.TTS,
        event_callback=event_callback,
    )
    pipeline_input = PipelineInput(
        conversation_id=conversation_id,
        stt_metadata=stt_metadata,
        stt_stream=stt_stream,
        run=run,
    )

    # Validate before executing.
    await pipeline_input.validate()
    await pipeline_input.execute()
|
||||||
|
|
30
homeassistant/components/voice_assistant/error.py
Normal file
30
homeassistant/components/voice_assistant/error.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
"""Voice Assistant errors."""
|
||||||
|
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineError(HomeAssistantError):
    """Common base for errors raised by the voice assistant pipeline."""

    def __init__(self, code: str, message: str) -> None:
        """Keep the error code and message and build the exception text."""
        self.code, self.message = code, message

        text = f"Pipeline error code={code}, message={message}"
        super().__init__(text)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineNotFound(PipelineError):
    """Error raised when the requested pipeline cannot be found."""
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextError(PipelineError):
    """Pipeline error for the speech-to-text portion of a pipeline."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntentRecognitionError(PipelineError):
    """Pipeline error for the intent recognition portion of a pipeline."""
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechError(PipelineError):
    """Pipeline error for the text-to-speech portion of a pipeline."""
|
|
@ -16,6 +16,12 @@ from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
from .error import (
|
||||||
|
IntentRecognitionError,
|
||||||
|
PipelineError,
|
||||||
|
SpeechToTextError,
|
||||||
|
TextToSpeechError,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -39,29 +45,6 @@ def async_get_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PipelineError(Exception):
|
|
||||||
"""Base class for pipeline errors."""
|
|
||||||
|
|
||||||
def __init__(self, code: str, message: str) -> None:
|
|
||||||
"""Set error message."""
|
|
||||||
self.code = code
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
super().__init__(f"Pipeline error code={code}, message={message}")
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextError(PipelineError):
|
|
||||||
"""Error in speech to text portion of pipeline."""
|
|
||||||
|
|
||||||
|
|
||||||
class IntentRecognitionError(PipelineError):
|
|
||||||
"""Error in intent recognition portion of pipeline."""
|
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechError(PipelineError):
|
|
||||||
"""Error in text to speech portion of pipeline."""
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineEventType(StrEnum):
|
class PipelineEventType(StrEnum):
|
||||||
"""Event types emitted during a pipeline run."""
|
"""Event types emitted during a pipeline run."""
|
||||||
|
|
||||||
|
@ -93,6 +76,9 @@ class PipelineEvent:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PipelineEventCallback = Callable[[PipelineEvent], None]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""A voice assistant pipeline."""
|
"""A voice assistant pipeline."""
|
||||||
|
@ -146,7 +132,7 @@ class PipelineRun:
|
||||||
pipeline: Pipeline
|
pipeline: Pipeline
|
||||||
start_stage: PipelineStage
|
start_stage: PipelineStage
|
||||||
end_stage: PipelineStage
|
end_stage: PipelineStage
|
||||||
event_callback: Callable[[PipelineEvent], None]
|
event_callback: PipelineEventCallback
|
||||||
language: str = None # type: ignore[assignment]
|
language: str = None # type: ignore[assignment]
|
||||||
runner_data: Any | None = None
|
runner_data: Any | None = None
|
||||||
stt_provider: stt.Provider | None = None
|
stt_provider: stt.Provider | None = None
|
||||||
|
|
|
@ -1268,7 +1268,7 @@ def mock_integration(
|
||||||
|
|
||||||
def mock_import_platform(platform_name: str) -> NoReturn:
|
def mock_import_platform(platform_name: str) -> NoReturn:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Mocked unable to import platform '{platform_name}'",
|
f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'",
|
||||||
name=f"{integration.pkg_path}.{platform_name}",
|
name=f"{integration.pkg_path}.{platform_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
139
tests/components/voice_assistant/conftest.py
Normal file
139
tests/components/voice_assistant/conftest.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
"""Test fixtures for voice assistant."""
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import stt, tts
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import MockModule, mock_integration, mock_platform
|
||||||
|
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||||
|
mock_get_cache_files,
|
||||||
|
mock_init_cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TRANSCRIPT = "test transcript"
|
||||||
|
|
||||||
|
|
||||||
|
class MockSttProvider(stt.Provider):
    """Speech-to-text provider stub that records audio and returns fixed text."""

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Store the transcript to return and start with no received audio."""
        self.hass = hass
        self.text = text
        # Non-empty chunks seen by async_process_audio_stream, in arrival order.
        self.received = []

    @property
    def supported_languages(self) -> list[str]:
        """Languages this mock accepts."""
        return ["en-US"]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Audio container formats this mock accepts."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Audio codecs this mock accepts."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Bit rates this mock accepts."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Sample rates this mock accepts."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Channel layouts this mock accepts."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Collect chunks until an empty one, then return the fixed transcript."""
        async for chunk in stream:
            # An empty chunk ends consumption; it is not recorded.
            if not chunk:
                break
            self.received.append(chunk)

        result_state = stt.SpeechResultState.SUCCESS
        return stt.SpeechResult(self.text, result_state)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTTSProvider(tts.Provider):
    """Text-to-speech provider stub that returns empty MP3 data."""

    name = "Test"

    @property
    def default_language(self) -> str:
        """Language reported as the default."""
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Languages this mock accepts."""
        return ["en-US"]

    @property
    def supported_options(self) -> list[str]:
        """Option names this mock accepts."""
        return ["voice", "age"]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> tts.TtsAudioType:
        """Return an ("mp3", empty bytes) pair regardless of the arguments."""
        extension, audio = "mp3", b""
        return (extension, audio)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTTS:
    """Stand-in TTS platform that hands out MockTTSProvider engines."""

    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> tts.Provider:
        """Create and return a new MockTTSProvider; arguments are unused."""
        provider = MockTTSProvider()
        return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def mock_stt_provider(hass) -> MockSttProvider:
    """Provide a MockSttProvider that returns the shared test transcript."""
    provider = MockSttProvider(hass, _TRANSCRIPT)
    return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
async def init_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
):
    """Register the mock "test" platforms and set up the needed components.

    Sets up, in order: tts and stt (both using the "test" platform),
    media_source and voice_assistant.
    """
    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS())

    # The STT platform always returns the shared provider fixture, so tests
    # can inspect the provider instance afterwards.
    stt_platform = Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider))
    mock_platform(hass, "test.stt", stt_platform)

    component_configs = (
        (tts.DOMAIN, {"tts": {"platform": "test"}}),
        (stt.DOMAIN, {"stt": {"platform": "test"}}),
        ("media_source", {}),
        ("voice_assistant", {}),
    )
    for domain, config in component_configs:
        assert await async_setup_component(hass, domain, config)
|
85
tests/components/voice_assistant/snapshots/test_init.ambr
Normal file
85
tests/components/voice_assistant/snapshots/test_init.ambr
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# serializer version: 1
|
||||||
|
# name: test_pipeline_from_audio_stream
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'language': 'en-US',
|
||||||
|
'pipeline': 'en-US',
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||||
|
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||||
|
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||||
|
'format': <AudioFormats.WAV: 'wav'>,
|
||||||
|
'language': 'en-US',
|
||||||
|
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'stt_output': dict({
|
||||||
|
'text': 'test transcript',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.STT_END: 'stt-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'homeassistant',
|
||||||
|
'intent_input': 'test transcript',
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'code': 'no_intent_match',
|
||||||
|
}),
|
||||||
|
'language': 'en-US',
|
||||||
|
'response_type': 'error',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': "Sorry, I couldn't understand that",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
42
tests/components/voice_assistant/test_init.py
Normal file
42
tests/components/voice_assistant/test_init.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
"""Test Voice Assistant init."""
|
||||||
|
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.components import stt, voice_assistant
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_from_audio_stream(
    hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion
) -> None:
    """Run a pipeline from an audio stream and check its events and audio."""
    events = []

    async def audio_data():
        # Two audio chunks followed by an empty chunk.
        for chunk in (b"part1", b"part2", b""):
            yield chunk

    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )

    await voice_assistant.async_pipeline_from_audio_stream(
        hass,
        events.append,
        metadata,
        audio_data(),
    )

    # Remove the timestamp from each event before comparing to the snapshot.
    processed = []
    for event in events:
        event_dict = event.as_dict()
        del event_dict["timestamp"]
        processed.append(event_dict)

    assert processed == snapshot
    assert mock_stt_provider.received == [b"part1", b"part2"]
|
|
@ -1,143 +1,13 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
||||||
from homeassistant.setup import async_setup_component
|
|
||||||
|
|
||||||
from tests.common import MockModule, mock_integration, mock_platform
|
|
||||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
|
||||||
mock_get_cache_files,
|
|
||||||
mock_init_cache_dir,
|
|
||||||
)
|
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
_TRANSCRIPT = "test transcript"
|
|
||||||
|
|
||||||
|
|
||||||
class MockSttProvider(stt.Provider):
|
|
||||||
"""Mock STT provider."""
|
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, text: str) -> None:
|
|
||||||
"""Init test provider."""
|
|
||||||
self.hass = hass
|
|
||||||
self.text = text
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
return ["en-US"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
|
||||||
"""Return a list of supported formats."""
|
|
||||||
return [stt.AudioFormats.WAV]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_codecs(self) -> list[stt.AudioCodecs]:
|
|
||||||
"""Return a list of supported codecs."""
|
|
||||||
return [stt.AudioCodecs.PCM]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
|
|
||||||
"""Return a list of supported bitrates."""
|
|
||||||
return [stt.AudioBitRates.BITRATE_16]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
|
|
||||||
"""Return a list of supported samplerates."""
|
|
||||||
return [stt.AudioSampleRates.SAMPLERATE_16000]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_channels(self) -> list[stt.AudioChannels]:
|
|
||||||
"""Return a list of supported channels."""
|
|
||||||
return [stt.AudioChannels.CHANNEL_MONO]
|
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
|
||||||
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
|
|
||||||
) -> stt.SpeechResult:
|
|
||||||
"""Process an audio stream."""
|
|
||||||
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
|
|
||||||
|
|
||||||
|
|
||||||
class MockSTT:
|
|
||||||
"""A mock STT platform."""
|
|
||||||
|
|
||||||
async def async_get_engine(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
config: ConfigType,
|
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
|
||||||
) -> stt.Provider:
|
|
||||||
"""Set up a mock speech component."""
|
|
||||||
return MockSttProvider(hass, _TRANSCRIPT)
|
|
||||||
|
|
||||||
|
|
||||||
class MockTTSProvider(tts.Provider):
|
|
||||||
"""Mock TTS provider."""
|
|
||||||
|
|
||||||
name = "Test"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def default_language(self) -> str:
|
|
||||||
"""Return the default language."""
|
|
||||||
return "en"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return list of supported languages."""
|
|
||||||
return ["en-US"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_options(self) -> list[str]:
|
|
||||||
"""Return list of supported options like voice, emotions."""
|
|
||||||
return ["voice", "age"]
|
|
||||||
|
|
||||||
def get_tts_audio(
|
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
|
||||||
) -> tts.TtsAudioType:
|
|
||||||
"""Load TTS dat."""
|
|
||||||
return ("mp3", b"")
|
|
||||||
|
|
||||||
|
|
||||||
class MockTTS:
|
|
||||||
"""A mock TTS platform."""
|
|
||||||
|
|
||||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
|
||||||
|
|
||||||
async def async_get_engine(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
config: ConfigType,
|
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
|
||||||
) -> tts.Provider:
|
|
||||||
"""Set up a mock speech component."""
|
|
||||||
return MockTTSProvider()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
async def init_components(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_get_cache_files, # noqa: F811
|
|
||||||
mock_init_cache_dir, # noqa: F811
|
|
||||||
):
|
|
||||||
"""Initialize relevant components with empty configs."""
|
|
||||||
mock_integration(hass, MockModule(domain="test"))
|
|
||||||
mock_platform(hass, "test.tts", MockTTS())
|
|
||||||
mock_platform(hass, "test.stt", MockSTT())
|
|
||||||
|
|
||||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
|
||||||
assert await async_setup_component(hass, "media_source", {})
|
|
||||||
assert await async_setup_component(hass, "voice_assistant", {})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_text_only_pipeline(
|
async def test_text_only_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -211,7 +81,7 @@ async def test_audio_pipeline(
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(b"1")
|
await client.send_bytes(bytes([1]))
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "stt-end"
|
assert msg["event"]["type"] == "stt-end"
|
||||||
|
@ -438,7 +308,7 @@ async def test_stt_stream_failed(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
with patch(
|
with patch(
|
||||||
"tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream",
|
"tests.components.voice_assistant.conftest.MockSttProvider.async_process_audio_stream",
|
||||||
new=MagicMock(side_effect=RuntimeError),
|
new=MagicMock(side_effect=RuntimeError),
|
||||||
):
|
):
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue