Allow passing in device_id to pipeline run WS API (#95139)

This commit is contained in:
Paulus Schoutsen 2023-06-23 22:29:56 -04:00 committed by GitHub
parent 3f10233833
commit c42d0feec1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 1 deletions

View file

@ -512,6 +512,8 @@ class PipelineRun:
"engine": self.intent_agent, "engine": self.intent_agent,
"language": self.pipeline.conversation_language, "language": self.pipeline.conversation_language,
"intent_input": intent_input, "intent_input": intent_input,
"conversation_id": conversation_id,
"device_id": device_id,
}, },
) )
) )

View file

@ -56,6 +56,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
vol.Optional("input"): dict, vol.Optional("input"): dict,
vol.Optional("pipeline"): str, vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("device_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int), vol.Optional("timeout"): vol.Any(float, int),
}, },
), ),
@ -105,6 +106,7 @@ async def websocket_run(
# Arguments to PipelineInput # Arguments to PipelineInput
input_args: dict[str, Any] = { input_args: dict[str, Any] = {
"conversation_id": msg.get("conversation_id"), "conversation_id": msg.get("conversation_id"),
"device_id": msg.get("device_id"),
} }
if start_stage == PipelineStage.STT: if start_stage == PipelineStage.STT:

View file

@ -32,6 +32,8 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
@ -119,6 +121,8 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en-US', 'language': 'en-US',
@ -206,6 +210,8 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en-US', 'language': 'en-US',

View file

@ -31,6 +31,8 @@
# --- # ---
# name: test_audio_pipeline.3 # name: test_audio_pipeline.3
dict({ dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
@ -107,6 +109,8 @@
# --- # ---
# name: test_audio_pipeline_debug.3 # name: test_audio_pipeline_debug.3
dict({ dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': 'en', 'language': 'en',
@ -163,6 +167,8 @@
# --- # ---
# name: test_intent_failed.1 # name: test_intent_failed.1
dict({ dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',
@ -180,6 +186,8 @@
# --- # ---
# name: test_intent_timeout.1 # name: test_intent_timeout.1
dict({ dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',
@ -249,6 +257,8 @@
# --- # ---
# name: test_text_only_pipeline.1 # name: test_text_only_pipeline.1
dict({ dict({
'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id',
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': 'en', 'language': 'en',

View file

@ -28,6 +28,8 @@ async def test_text_only_pipeline(
"start_stage": "intent", "start_stage": "intent",
"end_stage": "intent", "end_stage": "intent",
"input": {"text": "Are the lights on?"}, "input": {"text": "Are the lights on?"},
"conversation_id": "mock-conversation-id",
"device_id": "mock-device-id",
} }
) )
@ -954,7 +956,6 @@ async def test_list_pipelines(
) -> None: ) -> None:
"""Test we can list pipelines.""" """Test we can list pipelines."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
hass.data[DOMAIN]
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json() msg = await client.receive_json()