From 154da1b18b4dc31dab8d5cd5b9b7d8f9c1e0d584 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 9 Jul 2024 17:56:53 +0200 Subject: [PATCH] Allow targeting conversation agent as pipeline (#119556) * Allow targetting conversation agent as pipeline * Test that we can use a conversation entity as an assist pipeline * Add test for WS get --------- Co-authored-by: Michael Hansen --- .../components/assist_pipeline/pipeline.py | 25 +++++++++ .../snapshots/test_websocket.ambr | 55 +++++++++++++++++-- .../assist_pipeline/test_websocket.py | 35 ++++++++++++ 3 files changed, 111 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index d8fd15900b8..ecf361cb67c 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -259,6 +259,22 @@ async def async_create_default_pipeline( return await pipeline_store.async_create_item(pipeline_settings) +@callback +def _async_get_pipeline_from_conversation_entity( + hass: HomeAssistant, entity_id: str +) -> Pipeline: + """Get a pipeline by conversation entity ID.""" + entity = hass.states.get(entity_id) + settings = _async_resolve_default_pipeline_settings( + hass, + pipeline_name=entity.name if entity else entity_id, + conversation_engine_id=entity_id, + ) + settings["id"] = entity_id + + return Pipeline.from_json(settings) + + @callback def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline: """Get a pipeline by id or the preferred pipeline.""" @@ -268,6 +284,9 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P # A pipeline was not specified, use the preferred one pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() + if pipeline_id.startswith("conversation."): + return _async_get_pipeline_from_conversation_entity(hass, pipeline_id) + pipeline = pipeline_data.pipeline_store.data.get(pipeline_id) # If invalid pipeline ID was specified @@ -1670,6 +1689,12 @@ class PipelineStorageCollectionWebsocket( if item_id is None: item_id = self.storage_collection.async_get_preferred_item() + if item_id.startswith("conversation.") and hass.states.get(item_id): + connection.send_result( + msg["id"], _async_get_pipeline_from_conversation_entity(hass, item_id) + ) + return + if item_id not in self.storage_collection.data: connection.send_error( msg["id"], diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 2c506215c68..0b04b67bb22 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -663,7 +663,10 @@ # name: test_stt_stream_failed.2 None # --- -# name: test_text_only_pipeline +# name: test_text_only_pipeline.3 + None +# --- +# name: test_text_only_pipeline[extra_msg0] dict({ 'language': 'en', 'pipeline': , @@ -673,7 +676,7 @@ }), }) # --- -# name: test_text_only_pipeline.1 +# name: test_text_only_pipeline[extra_msg0].1 dict({ 'conversation_id': 'mock-conversation-id', 'device_id': 'mock-device-id', @@ -682,7 +685,7 @@ 'language': 'en', }) # --- -# name: test_text_only_pipeline.2 +# name: test_text_only_pipeline[extra_msg0].2 dict({ 'intent_output': dict({ 'conversation_id': None, @@ -704,7 +707,51 @@ }), }) # --- -# name: test_text_only_pipeline.3 +# name: test_text_only_pipeline[extra_msg0].3 + None +# --- +# name: test_text_only_pipeline[extra_msg1] + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': None, + 'timeout': 300, + }), + }) +# --- +# name: test_text_only_pipeline[extra_msg1].1 + dict({ + 'conversation_id': 'mock-conversation-id', + 'device_id': 'mock-device-id', + 'engine': 'conversation.home_assistant', + 'intent_input': 'Are the lights on?', + 'language': 'en', + }) +# --- +# name: test_text_only_pipeline[extra_msg1].2 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_valid_targets', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Sorry, I am not aware of any area called are', + }), + }), + }), + }), + }) +# --- +# name: test_text_only_pipeline[extra_msg1].3 None # --- # name: test_text_pipeline_timeout diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index e08dd9685ea..de8ddc7ccc7 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -5,6 +5,7 @@ import base64 from typing import Any from unittest.mock import ANY, patch +import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components.assist_pipeline.const import DOMAIN @@ -23,11 +24,19 @@ from tests.common import MockConfigEntry from tests.typing import WebSocketGenerator +@pytest.mark.parametrize( + "extra_msg", + [ + {}, + {"pipeline": "conversation.home_assistant"}, + ], +) async def test_text_only_pipeline( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components, snapshot: SnapshotAssertion, + extra_msg: dict[str, Any], ) -> None: """Test events from a pipeline run with text input (no STT/TTS).""" events = [] @@ -42,6 +51,7 @@ async def test_text_only_pipeline( "conversation_id": "mock-conversation-id", "device_id": "mock-device-id", } + | extra_msg ) # result @@ -1180,6 +1190,31 @@ async def test_get_pipeline( "wake_word_id": None, } + # Get conversation agent as pipeline + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/get", + "pipeline_id": "conversation.home_assistant", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "conversation.home_assistant", + "conversation_language": "en", + "id": ANY, + "language": "en", + "name": "Home Assistant", + # It found these defaults + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "james_earl_jones", + "wake_word_entity": None, + "wake_word_id": None, + } + await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/get",