From a98be9dc84e68e6e469496f25433378622e107e2 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 20 Apr 2023 15:01:31 +0200 Subject: [PATCH] Test specifying pipeline in calls to async_pipeline_from_audio_stream (#91739) --- tests/components/assist_pipeline/conftest.py | 98 +++++++++- .../assist_pipeline/snapshots/test_init.ambr | 170 +++++++++++++++++- tests/components/assist_pipeline/test_init.py | 140 ++++++++++++++- 3 files changed, 398 insertions(+), 10 deletions(-) diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index b010236af09..57fb19dcf59 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -1,18 +1,27 @@ """Test fixtures for voice assistant.""" -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Generator from typing import Any -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest from homeassistant.components import stt, tts from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection +from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_setup_component -from tests.common import MockModule, mock_integration, mock_platform +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + mock_config_flow, + mock_integration, + mock_platform, +) from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import mock_get_cache_files, mock_init_cache_dir, @@ -21,7 +30,7 @@ from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unuse _TRANSCRIPT = "test transcript" -class MockSttProvider(stt.Provider): +class BaseProvider: """Mock STT provider.""" def __init__(self, hass: HomeAssistant, text: str) -> None: @@ -71,6 +80,14 @@ class MockSttProvider(stt.Provider): return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS) +class MockSttProvider(BaseProvider, stt.Provider): + """Mock provider.""" + + +class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity): + """Mock provider entity.""" + + class MockTTSProvider(tts.Provider): """Mock TTS provider.""" @@ -119,27 +136,96 @@ async def mock_stt_provider(hass) -> MockSttProvider: return MockSttProvider(hass, _TRANSCRIPT) +@pytest.fixture +def mock_stt_provider_entity(hass) -> MockSttProviderEntity: + """Test provider entity fixture.""" + return MockSttProviderEntity(hass, _TRANSCRIPT) + + +class MockSttPlatform(MockPlatform): + """Provide a fake STT platform.""" + + def __init__(self, *, async_get_engine, **kwargs): + """Initialize the stt platform.""" + super().__init__(**kwargs) + self.async_get_engine = async_get_engine + + +class MockFlow(ConfigFlow): + """Test flow.""" + + +@pytest.fixture +def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: + """Mock config flow.""" + mock_platform(hass, "test.config_flow") + + with mock_config_flow("test", MockFlow): + yield + + @pytest.fixture async def init_components( hass: HomeAssistant, mock_stt_provider: MockSttProvider, + mock_stt_provider_entity: MockSttProviderEntity, + config_flow_fixture, mock_get_cache_files, # noqa: F811 mock_init_cache_dir, # noqa: F811, ): """Initialize relevant components with empty configs.""" - mock_integration(hass, MockModule(domain="test")) + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload up test config entry.""" + await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN) + return True + + async def async_setup_entry_stt_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up test stt platform via config entry.""" + async_add_entities([mock_stt_provider_entity]) + + mock_integration( + hass, + MockModule( + "test", + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) mock_platform(hass, "test.tts", MockTTS()) mock_platform( hass, "test.stt", - Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)), + MockSttPlatform( + async_get_engine=AsyncMock(return_value=mock_stt_provider), + async_setup_entry=async_setup_entry_stt_platform, + ), ) + mock_platform(hass, "test.config_flow") assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component(hass, "assist_pipeline", {}) + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + @pytest.fixture def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 2e8b7cfb4dd..938f15e2024 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_pipeline_from_audio_stream +# name: test_pipeline_from_audio_stream_auto list([ dict({ 'data': dict({ @@ -83,3 +83,171 @@ }), ]) # --- +# name: test_pipeline_from_audio_stream_entity + list([ + dict({ + 'data': dict({ + 'language': 'en-US', + 'pipeline': 'test_name', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'language': 'en-US', + 'sample_rate': , + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'tts_input': "Sorry, I couldn't understand that", + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- +# name: test_pipeline_from_audio_stream_legacy + list([ + dict({ + 'data': dict({ + 'language': 'en-US', + 'pipeline': 'test_name', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'language': 'en-US', + 'sample_rate': , + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'tts_input': "Sorry, I couldn't understand that", + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 7bdfc2e7b9e..0c0f051b999 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -6,11 +6,21 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, stt from homeassistant.core import Context, HomeAssistant +from .conftest import MockSttProvider, MockSttProviderEntity -async def test_pipeline_from_audio_stream( - hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion +from tests.typing import WebSocketGenerator + + +async def test_pipeline_from_audio_stream_auto( + hass: HomeAssistant, + mock_stt_provider: MockSttProvider, + init_components, + snapshot: SnapshotAssertion, ) -> None: - """Test creating a pipeline from an audio stream.""" + """Test creating a pipeline from an audio stream. + + In this test, no pipeline is specified. + """ events = [] @@ -42,3 +52,127 @@ async def test_pipeline_from_audio_stream( assert processed == snapshot assert mock_stt_provider.received == [b"part1", b"part2"] + + +async def test_pipeline_from_audio_stream_legacy( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + mock_stt_provider: MockSttProvider, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test creating a pipeline from an audio stream. + + In this test, a pipeline using a legacy stt engine is used. + """ + client = await hass_ws_client(hass) + + events = [] + + async def audio_data(): + yield b"part1" + yield b"part2" + yield b"" + + # Create a pipeline using an stt entity + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "language": "en-US", + "name": "test_name", + "stt_engine": "test", + "tts_engine": "test", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + + # Use the created pipeline + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + Context(), + events.append, + stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + pipeline_id=pipeline_id, + ) + + processed = [] + for event in events: + as_dict = asdict(event) + as_dict.pop("timestamp") + processed.append(as_dict) + + assert processed == snapshot + assert mock_stt_provider.received == [b"part1", b"part2"] + + +async def test_pipeline_from_audio_stream_entity( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + mock_stt_provider_entity: MockSttProviderEntity, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test creating a pipeline from an audio stream. + + In this test, a pipeline using am stt entity is used. + """ + client = await hass_ws_client(hass) + + events = [] + + async def audio_data(): + yield b"part1" + yield b"part2" + yield b"" + + # Create a pipeline using an stt entity + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "language": "en-US", + "name": "test_name", + "stt_engine": mock_stt_provider_entity.entity_id, + "tts_engine": "test", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + + # Use the created pipeline + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + Context(), + events.append, + stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + pipeline_id=pipeline_id, + ) + + processed = [] + for event in events: + as_dict = asdict(event) + as_dict.pop("timestamp") + processed.append(as_dict) + + assert processed == snapshot + assert mock_stt_provider_entity.received == [b"part1", b"part2"]