Allow targeting conversation agent as pipeline (#119556)

* Allow targetting conversation agent as pipeline

* Test that we can use a conversation entity as an assist pipeline

* Add test for WS get

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2024-07-09 17:56:53 +02:00 committed by GitHub
parent 69ed730101
commit 154da1b18b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 111 additions and 4 deletions

View file

@ -259,6 +259,22 @@ async def async_create_default_pipeline(
return await pipeline_store.async_create_item(pipeline_settings)
@callback
def _async_get_pipeline_from_conversation_entity(
hass: HomeAssistant, entity_id: str
) -> Pipeline:
"""Get a pipeline by conversation entity ID."""
entity = hass.states.get(entity_id)
settings = _async_resolve_default_pipeline_settings(
hass,
pipeline_name=entity.name if entity else entity_id,
conversation_engine_id=entity_id,
)
settings["id"] = entity_id
return Pipeline.from_json(settings)
@callback
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline."""
@ -268,6 +284,9 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
# A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
if pipeline_id.startswith("conversation."):
return _async_get_pipeline_from_conversation_entity(hass, pipeline_id)
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
# If invalid pipeline ID was specified
@ -1670,6 +1689,12 @@ class PipelineStorageCollectionWebsocket(
if item_id is None:
item_id = self.storage_collection.async_get_preferred_item()
if item_id.startswith("conversation.") and hass.states.get(item_id):
connection.send_result(
msg["id"], _async_get_pipeline_from_conversation_entity(hass, item_id)
)
return
if item_id not in self.storage_collection.data:
connection.send_error(
msg["id"],

View file

@ -663,7 +663,10 @@
# name: test_stt_stream_failed.2
None
# ---
# name: test_text_only_pipeline
# name: test_text_only_pipeline.3
None
# ---
# name: test_text_only_pipeline[extra_msg0]
dict({
'language': 'en',
'pipeline': <ANY>,
@ -673,7 +676,7 @@
}),
})
# ---
# name: test_text_only_pipeline.1
# name: test_text_only_pipeline[extra_msg0].1
dict({
'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id',
@ -682,7 +685,7 @@
'language': 'en',
})
# ---
# name: test_text_only_pipeline.2
# name: test_text_only_pipeline[extra_msg0].2
dict({
'intent_output': dict({
'conversation_id': None,
@ -704,7 +707,51 @@
}),
})
# ---
# name: test_text_only_pipeline.3
# name: test_text_only_pipeline[extra_msg0].3
None
# ---
# name: test_text_only_pipeline[extra_msg1]
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 300,
}),
})
# ---
# name: test_text_only_pipeline[extra_msg1].1
dict({
'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id',
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
'language': 'en',
})
# ---
# name: test_text_only_pipeline[extra_msg1].2
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_valid_targets',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Sorry, I am not aware of any area called are',
}),
}),
}),
}),
})
# ---
# name: test_text_only_pipeline[extra_msg1].3
None
# ---
# name: test_text_pipeline_timeout

View file

@ -5,6 +5,7 @@ import base64
from typing import Any
from unittest.mock import ANY, patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN
@ -23,11 +24,19 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@pytest.mark.parametrize(
"extra_msg",
[
{},
{"pipeline": "conversation.home_assistant"},
],
)
async def test_text_only_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
extra_msg: dict[str, Any],
) -> None:
"""Test events from a pipeline run with text input (no STT/TTS)."""
events = []
@ -42,6 +51,7 @@ async def test_text_only_pipeline(
"conversation_id": "mock-conversation-id",
"device_id": "mock-device-id",
}
| extra_msg
)
# result
@ -1180,6 +1190,31 @@ async def test_get_pipeline(
"wake_word_id": None,
}
# Get conversation agent as pipeline
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
"pipeline_id": "conversation.home_assistant",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en",
"id": ANY,
"language": "en",
"name": "Home Assistant",
# It found these defaults
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",