diff --git a/homeassistant/components/voice_assistant/__init__.py b/homeassistant/components/voice_assistant/__init__.py index 2ae169a28eb..8a2c04d8301 100644 --- a/homeassistant/components/voice_assistant/__init__.py +++ b/homeassistant/components/voice_assistant/__init__.py @@ -1,12 +1,33 @@ """The Voice Assistant integration.""" from __future__ import annotations -from homeassistant.core import HomeAssistant +from collections.abc import AsyncIterable + +from homeassistant.components import stt +from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType from .const import DOMAIN +from .error import PipelineNotFound +from .pipeline import ( + PipelineEvent, + PipelineEventCallback, + PipelineEventType, + PipelineInput, + PipelineRun, + PipelineStage, + async_get_pipeline, +) from .websocket_api import async_register_websocket_api +__all__ = ( + "DOMAIN", + "async_setup", + "async_pipeline_from_audio_stream", + "PipelineEvent", + "PipelineEventType", +) + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Voice Assistant integration.""" @@ -14,3 +35,55 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async_register_websocket_api(hass) return True + + +async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + event_callback: PipelineEventCallback, + stt_metadata: stt.SpeechMetadata, + stt_stream: AsyncIterable[bytes], + language: str | None = None, + pipeline_id: str | None = None, + conversation_id: str | None = None, + context: Context | None = None, +) -> None: + """Create an audio pipeline from an audio stream.""" + if language is None: + language = hass.config.language + + # Temporary workaround for language codes + if language == "en": + language = "en-US" + + if stt_metadata.language == "": + stt_metadata.language = language + + if context is None: + context = Context() + + pipeline = async_get_pipeline( + hass, + pipeline_id=pipeline_id, + language=language, + ) + if pipeline is None: + raise PipelineNotFound( + "pipeline_not_found", f"Pipeline {pipeline_id} not found" + ) + + pipeline_input = PipelineInput( + conversation_id=conversation_id, + stt_metadata=stt_metadata, + stt_stream=stt_stream, + run=PipelineRun( + hass, + context=context, + pipeline=pipeline, + start_stage=PipelineStage.STT, + end_stage=PipelineStage.TTS, + event_callback=event_callback, + ), + ) + + await pipeline_input.validate() + await pipeline_input.execute() diff --git a/homeassistant/components/voice_assistant/error.py b/homeassistant/components/voice_assistant/error.py new file mode 100644 index 00000000000..2a52bf82c8e --- /dev/null +++ b/homeassistant/components/voice_assistant/error.py @@ -0,0 +1,30 @@ +"""Voice Assistant errors.""" + +from homeassistant.exceptions import HomeAssistantError + + +class PipelineError(HomeAssistantError): + """Base class for pipeline errors.""" + + def __init__(self, code: str, message: str) -> None: + """Set error message.""" + self.code = code + self.message = message + + super().__init__(f"Pipeline error code={code}, message={message}") + + +class PipelineNotFound(PipelineError): + """Unspecified pipeline picked.""" + + +class SpeechToTextError(PipelineError): + """Error in speech to text portion of pipeline.""" + + +class IntentRecognitionError(PipelineError): + """Error in intent recognition portion of pipeline.""" + + +class TextToSpeechError(PipelineError): + """Error in text to speech portion of pipeline.""" diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py index b41ab8ef9f7..7c909c32819 100644 --- a/homeassistant/components/voice_assistant/pipeline.py +++ b/homeassistant/components/voice_assistant/pipeline.py @@ -16,6 +16,12 @@ from homeassistant.core import Context, HomeAssistant, callback from homeassistant.util.dt import utcnow from .const import DOMAIN +from .error import ( + IntentRecognitionError, + PipelineError, + SpeechToTextError, + TextToSpeechError, +) _LOGGER = logging.getLogger(__name__) @@ -39,29 +45,6 @@ def async_get_pipeline( ) -class PipelineError(Exception): - """Base class for pipeline errors.""" - - def __init__(self, code: str, message: str) -> None: - """Set error message.""" - self.code = code - self.message = message - - super().__init__(f"Pipeline error code={code}, message={message}") - - -class SpeechToTextError(PipelineError): - """Error in speech to text portion of pipeline.""" - - -class IntentRecognitionError(PipelineError): - """Error in intent recognition portion of pipeline.""" - - -class TextToSpeechError(PipelineError): - """Error in text to speech portion of pipeline.""" - - class PipelineEventType(StrEnum): """Event types emitted during a pipeline run.""" @@ -93,6 +76,9 @@ class PipelineEvent: } +PipelineEventCallback = Callable[[PipelineEvent], None] + + @dataclass class Pipeline: """A voice assistant pipeline.""" @@ -146,7 +132,7 @@ class PipelineRun: pipeline: Pipeline start_stage: PipelineStage end_stage: PipelineStage - event_callback: Callable[[PipelineEvent], None] + event_callback: PipelineEventCallback language: str = None # type: ignore[assignment] runner_data: Any | None = None stt_provider: stt.Provider | None = None diff --git a/tests/common.py b/tests/common.py index f7a2c04a5f5..632294a50fb 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1268,7 +1268,7 @@ def mock_integration( def mock_import_platform(platform_name: str) -> NoReturn: raise ImportError( - f"Mocked unable to import platform '{platform_name}'", + f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'", name=f"{integration.pkg_path}.{platform_name}", ) diff --git a/tests/components/voice_assistant/conftest.py b/tests/components/voice_assistant/conftest.py new file mode 100644 index 00000000000..86da6334e09 --- /dev/null +++ b/tests/components/voice_assistant/conftest.py @@ -0,0 +1,139 @@ +"""Test fixtures for voice assistant.""" +from collections.abc import AsyncIterable +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest + +from homeassistant.components import stt, tts +from homeassistant.core import HomeAssistant +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.setup import async_setup_component + +from tests.common import MockModule, mock_integration, mock_platform +from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import + mock_get_cache_files, + mock_init_cache_dir, +) + +_TRANSCRIPT = "test transcript" + + +class MockSttProvider(stt.Provider): + """Mock STT provider.""" + + def __init__(self, hass: HomeAssistant, text: str) -> None: + """Init test provider.""" + self.hass = hass + self.text = text + self.received = [] + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return ["en-US"] + + @property + def supported_formats(self) -> list[stt.AudioFormats]: + """Return a list of supported formats.""" + return [stt.AudioFormats.WAV] + + @property + def supported_codecs(self) -> list[stt.AudioCodecs]: + """Return a list of supported codecs.""" + return [stt.AudioCodecs.PCM] + + @property + def supported_bit_rates(self) -> list[stt.AudioBitRates]: + """Return a list of supported bitrates.""" + return [stt.AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[stt.AudioSampleRates]: + """Return a list of supported samplerates.""" + return [stt.AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[stt.AudioChannels]: + """Return a list of supported channels.""" + return [stt.AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes] + ) -> stt.SpeechResult: + """Process an audio stream.""" + async for data in stream: + if not data: + break + self.received.append(data) + return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS) + + +class MockTTSProvider(tts.Provider): + """Mock TTS provider.""" + + name = "Test" + + @property + def default_language(self) -> str: + """Return the default language.""" + return "en" + + @property + def supported_languages(self) -> list[str]: + """Return list of supported languages.""" + return ["en-US"] + + @property + def supported_options(self) -> list[str]: + """Return list of supported options like voice, emotions.""" + return ["voice", "age"] + + def get_tts_audio( + self, message: str, language: str, options: dict[str, Any] | None = None + ) -> tts.TtsAudioType: + """Load TTS data.""" + return ("mp3", b"") + + +class MockTTS: + """A mock TTS platform.""" + + PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA + + async def async_get_engine( + self, + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> tts.Provider: + """Set up a mock speech component.""" + return MockTTSProvider() + + +@pytest.fixture +async def mock_stt_provider(hass) -> MockSttProvider: + """Mock STT provider.""" + return MockSttProvider(hass, _TRANSCRIPT) + + +@pytest.fixture(autouse=True) +async def init_components( + hass: HomeAssistant, + mock_stt_provider: MockSttProvider, + mock_get_cache_files, # noqa: F811 + mock_init_cache_dir, # noqa: F811, +): + """Initialize relevant components with empty configs.""" + mock_integration(hass, MockModule(domain="test")) + mock_platform(hass, "test.tts", MockTTS()) + mock_platform( + hass, + "test.stt", + Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)), + ) + + assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) + assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) + assert await async_setup_component(hass, "media_source", {}) + assert await async_setup_component(hass, "voice_assistant", {}) diff --git a/tests/components/voice_assistant/snapshots/test_init.ambr b/tests/components/voice_assistant/snapshots/test_init.ambr new file mode 100644 index 00000000000..459bfca01dd --- /dev/null +++ b/tests/components/voice_assistant/snapshots/test_init.ambr @@ -0,0 +1,85 @@ +# serializer version: 1 +# name: test_pipeline_from_audio_stream + list([ + dict({ + 'data': dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'language': 'en-US', + 'sample_rate': , + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'tts_input': "Sorry, I couldn't understand that", + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + }), + 'type': , + }), + ]) +# --- diff --git a/tests/components/voice_assistant/test_init.py b/tests/components/voice_assistant/test_init.py new file mode 100644 index 00000000000..1178f94c60c --- /dev/null +++ b/tests/components/voice_assistant/test_init.py @@ -0,0 +1,42 @@ +"""Test Voice Assistant init.""" + +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components import stt, voice_assistant +from homeassistant.core import HomeAssistant + + +async def test_pipeline_from_audio_stream( + hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion +) -> None: + """Test creating a pipeline from an audio stream.""" + + events = [] + + async def audio_data(): + yield b"part1" + yield b"part2" + yield b"" + + await voice_assistant.async_pipeline_from_audio_stream( + hass, + events.append, + stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + ) + + processed = [] + for event in events: + as_dict = event.as_dict() + as_dict.pop("timestamp") + processed.append(as_dict) + + assert processed == snapshot + assert mock_stt_provider.received == [b"part1", b"part2"] diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py index 54fe51a7a22..2acb954c87d 100644 --- a/tests/components/voice_assistant/test_websocket.py +++ b/tests/components/voice_assistant/test_websocket.py @@ -1,143 +1,13 @@ """Websocket tests for Voice Assistant integration.""" import asyncio -from collections.abc import AsyncIterable -from typing import Any from unittest.mock import MagicMock, patch -import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components import stt, tts from homeassistant.core import HomeAssistant -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from homeassistant.setup import async_setup_component -from tests.common import MockModule, mock_integration, mock_platform -from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import - mock_get_cache_files, - mock_init_cache_dir, -) from tests.typing import WebSocketGenerator -_TRANSCRIPT = "test transcript" - - -class MockSttProvider(stt.Provider): - """Mock STT provider.""" - - def __init__(self, hass: HomeAssistant, text: str) -> None: - """Init test provider.""" - self.hass = hass - self.text = text - - @property - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - return ["en-US"] - - @property - def supported_formats(self) -> list[stt.AudioFormats]: - """Return a list of supported formats.""" - return [stt.AudioFormats.WAV] - - @property - def supported_codecs(self) -> list[stt.AudioCodecs]: - """Return a list of supported codecs.""" - return [stt.AudioCodecs.PCM] - - @property - def supported_bit_rates(self) -> list[stt.AudioBitRates]: - """Return a list of supported bitrates.""" - return [stt.AudioBitRates.BITRATE_16] - - @property - def supported_sample_rates(self) -> list[stt.AudioSampleRates]: - """Return a list of supported samplerates.""" - return [stt.AudioSampleRates.SAMPLERATE_16000] - - @property - def supported_channels(self) -> list[stt.AudioChannels]: - """Return a list of supported channels.""" - return [stt.AudioChannels.CHANNEL_MONO] - - async def async_process_audio_stream( - self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes] - ) -> stt.SpeechResult: - """Process an audio stream.""" - return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS) - - -class MockSTT: - """A mock STT platform.""" - - async def async_get_engine( - self, - hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, - ) -> stt.Provider: - """Set up a mock speech component.""" - return MockSttProvider(hass, _TRANSCRIPT) - - -class MockTTSProvider(tts.Provider): - """Mock TTS provider.""" - - name = "Test" - - @property - def default_language(self) -> str: - """Return the default language.""" - return "en" - - @property - def supported_languages(self) -> list[str]: - """Return list of supported languages.""" - return ["en-US"] - - @property - def supported_options(self) -> list[str]: - """Return list of supported options like voice, emotions.""" - return ["voice", "age"] - - def get_tts_audio( - self, message: str, language: str, options: dict[str, Any] | None = None - ) -> tts.TtsAudioType: - """Load TTS dat.""" - return ("mp3", b"") - - -class MockTTS: - """A mock TTS platform.""" - - PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA - - async def async_get_engine( - self, - hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, - ) -> tts.Provider: - """Set up a mock speech component.""" - return MockTTSProvider() - - -@pytest.fixture(autouse=True) -async def init_components( - hass: HomeAssistant, - mock_get_cache_files, # noqa: F811 - mock_init_cache_dir, # noqa: F811 -): - """Initialize relevant components with empty configs.""" - mock_integration(hass, MockModule(domain="test")) - mock_platform(hass, "test.tts", MockTTS()) - mock_platform(hass, "test.stt", MockSTT()) - - assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) - assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) - assert await async_setup_component(hass, "media_source", {}) - assert await async_setup_component(hass, "voice_assistant", {}) - async def test_text_only_pipeline( hass: HomeAssistant, @@ -211,7 +81,7 @@ async def test_audio_pipeline( assert msg["event"]["data"] == snapshot # End of audio stream (handler id + empty payload) - await client.send_bytes(b"1") + await client.send_bytes(bytes([1])) msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" @@ -438,7 +308,7 @@ async def test_stt_stream_failed( ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" with patch( - "tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream", + "tests.components.voice_assistant.conftest.MockSttProvider.async_process_audio_stream", new=MagicMock(side_effect=RuntimeError), ): client = await hass_ws_client(hass)