Allow targeting conversation agent as pipeline (#119556)
* Allow targetting conversation agent as pipeline * Test that we can use a conversation entity as an assist pipeline * Add test for WS get --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
69ed730101
commit
154da1b18b
3 changed files with 111 additions and 4 deletions
|
@ -259,6 +259,22 @@ async def async_create_default_pipeline(
|
|||
return await pipeline_store.async_create_item(pipeline_settings)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_get_pipeline_from_conversation_entity(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> Pipeline:
|
||||
"""Get a pipeline by conversation entity ID."""
|
||||
entity = hass.states.get(entity_id)
|
||||
settings = _async_resolve_default_pipeline_settings(
|
||||
hass,
|
||||
pipeline_name=entity.name if entity else entity_id,
|
||||
conversation_engine_id=entity_id,
|
||||
)
|
||||
settings["id"] = entity_id
|
||||
|
||||
return Pipeline.from_json(settings)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||
"""Get a pipeline by id or the preferred pipeline."""
|
||||
|
@ -268,6 +284,9 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
|||
# A pipeline was not specified, use the preferred one
|
||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||
|
||||
if pipeline_id.startswith("conversation."):
|
||||
return _async_get_pipeline_from_conversation_entity(hass, pipeline_id)
|
||||
|
||||
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||
|
||||
# If invalid pipeline ID was specified
|
||||
|
@ -1670,6 +1689,12 @@ class PipelineStorageCollectionWebsocket(
|
|||
if item_id is None:
|
||||
item_id = self.storage_collection.async_get_preferred_item()
|
||||
|
||||
if item_id.startswith("conversation.") and hass.states.get(item_id):
|
||||
connection.send_result(
|
||||
msg["id"], _async_get_pipeline_from_conversation_entity(hass, item_id)
|
||||
)
|
||||
return
|
||||
|
||||
if item_id not in self.storage_collection.data:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
|
|
|
@ -663,7 +663,10 @@
|
|||
# name: test_stt_stream_failed.2
|
||||
None
|
||||
# ---
|
||||
# name: test_text_only_pipeline
|
||||
# name: test_text_only_pipeline.3
|
||||
None
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg0]
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
|
@ -673,7 +676,7 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.1
|
||||
# name: test_text_only_pipeline[extra_msg0].1
|
||||
dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': 'mock-device-id',
|
||||
|
@ -682,7 +685,7 @@
|
|||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.2
|
||||
# name: test_text_only_pipeline[extra_msg0].2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
|
@ -704,7 +707,51 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.3
|
||||
# name: test_text_only_pipeline[extra_msg0].3
|
||||
None
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg1]
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg1].1
|
||||
dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': 'mock-device-id',
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg1].2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_valid_targets',
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Sorry, I am not aware of any area called are',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg1].3
|
||||
None
|
||||
# ---
|
||||
# name: test_text_pipeline_timeout
|
||||
|
|
|
@ -5,6 +5,7 @@ import base64
|
|||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
|
@ -23,11 +24,19 @@ from tests.common import MockConfigEntry
|
|||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extra_msg",
|
||||
[
|
||||
{},
|
||||
{"pipeline": "conversation.home_assistant"},
|
||||
],
|
||||
)
|
||||
async def test_text_only_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
extra_msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||
events = []
|
||||
|
@ -42,6 +51,7 @@ async def test_text_only_pipeline(
|
|||
"conversation_id": "mock-conversation-id",
|
||||
"device_id": "mock-device-id",
|
||||
}
|
||||
| extra_msg
|
||||
)
|
||||
|
||||
# result
|
||||
|
@ -1180,6 +1190,31 @@ async def test_get_pipeline(
|
|||
"wake_word_id": None,
|
||||
}
|
||||
|
||||
# Get conversation agent as pipeline
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
"pipeline_id": "conversation.home_assistant",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "conversation.home_assistant",
|
||||
"conversation_language": "en",
|
||||
"id": ANY,
|
||||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
# It found these defaults
|
||||
"stt_engine": "test",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
|
|
Loading…
Add table
Reference in a new issue