Add WS API for debugging previous assist_pipeline runs (#91541)

* Add WS API for debugging previous assist_pipeline runs

* Improve typing
This commit is contained in:
Erik Montnemery 2023-04-17 17:48:02 +02:00 committed by GitHub
parent b597415b01
commit 0ecd23baee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 564 additions and 32 deletions

View file

@ -24,6 +24,7 @@ from homeassistant.helpers.collection import (
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DOMAIN
from .error import (
@ -46,6 +47,8 @@ STORAGE_FIELDS = {
vol.Required("tts_engine"): str,
}
STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
@ -53,14 +56,14 @@ async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or create one for a language."""
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
if pipeline_id is not None:
return pipeline_store.data.get(pipeline_id)
return pipeline_data.pipeline_store.data.get(pipeline_id)
# Construct a pipeline for the required/configured language
language = language or hass.config.language
return await pipeline_store.async_create_item(
return await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
@ -171,6 +174,8 @@ class PipelineRun:
tts_engine: str | None = None
tts_options: dict | None = None
id: str = field(default_factory=ulid_util.ulid)
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
@ -181,6 +186,23 @@ class PipelineRun:
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
if self.pipeline.id not in pipeline_data.pipeline_runs:
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_RUNS
)
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = []
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
self.event_callback(event)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
# This run has been evicted from the logged pipeline runs already
return
pipeline_data.pipeline_runs[self.pipeline.id][self.id].append(event)
def start(self) -> None:
"""Emit run start event."""
data = {
@ -190,11 +212,11 @@ class PipelineRun:
if self.runner_data is not None:
data["runner_data"] = self.runner_data
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
def end(self) -> None:
"""Emit run end event."""
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.RUN_END,
)
@ -233,7 +255,7 @@ class PipelineRun:
engine = self.stt_provider.name
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.STT_START,
{
@ -268,7 +290,7 @@ class PipelineRun:
code="stt-no-text-recognized", message="No text recognized"
)
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.STT_END,
{
@ -306,7 +328,7 @@ class PipelineRun:
if self.intent_agent is None:
raise RuntimeError("Recognize intent was not prepared")
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.INTENT_START,
{
@ -334,7 +356,7 @@ class PipelineRun:
_LOGGER.debug("conversation result %s", conversation_result)
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.INTENT_END,
{"intent_output": conversation_result.as_dict()},
@ -379,7 +401,7 @@ class PipelineRun:
if self.tts_engine is None:
raise RuntimeError("Text to speech was not prepared")
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,
{
@ -412,7 +434,7 @@ class PipelineRun:
_LOGGER.debug("TTS result %s", tts_media)
self.event_callback(
self.process_event(
PipelineEvent(
PipelineEventType.TTS_END,
{
@ -480,7 +502,7 @@ class PipelineInput:
await self.run.text_to_speech(tts_input)
except PipelineError as err:
self.run.event_callback(
self.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": err.code, "message": err.message},
@ -691,6 +713,14 @@ class PipelineStorageCollectionWebsocket(
connection.send_result(msg["id"])
@dataclass
class PipelineData:
"""Store and debug data stored in hass.data."""
pipeline_runs: dict[str, LimitedSizeDict[str, list[PipelineEvent]]]
pipeline_store: PipelineStorageCollection
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
@ -700,4 +730,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
PipelineStorageCollectionWebsocket(
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)
hass.data[DOMAIN] = pipeline_store
hass.data[DOMAIN] = PipelineData({}, pipeline_store)

View file

@ -12,7 +12,9 @@ from homeassistant.components import stt, websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from .const import DOMAIN
from .pipeline import (
PipelineData,
PipelineError,
PipelineEvent,
PipelineEventType,
@ -69,6 +71,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
),
),
)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run)
@websocket_api.async_response
@ -193,14 +197,82 @@ async def websocket_run(
async with async_timeout.timeout(timeout):
await run_task
except asyncio.TimeoutError:
connection.send_event(
msg["id"],
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
),
)
)
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/pipeline_debug/list",
vol.Required("pipeline_id"): str,
}
)
def websocket_list_runs(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List pipeline runs for which debug data is available."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
if pipeline_id not in pipeline_data.pipeline_runs:
connection.send_result(msg["id"], {"pipeline_runs": []})
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
connection.send_result(msg["id"], {"pipeline_runs": list(pipeline_runs)})
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
vol.Required("pipeline_id"): str,
vol.Required("pipeline_run_id"): str,
}
)
def websocket_get_run(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get debug data for a pipeline run."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
pipeline_run_id = msg["pipeline_run_id"]
if pipeline_id not in pipeline_data.pipeline_runs:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"pipeline_id {pipeline_id} not found",
)
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
if pipeline_run_id not in pipeline_runs:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"pipeline_run_id {pipeline_run_id} not found",
)
return
connection.send_result(
msg["id"],
{"events": pipeline_runs[pipeline_run_id]},
)

View file

@ -14,6 +14,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.limited_size_dict import LimitedSizeDict
from . import websocket_api
from .const import (
@ -24,7 +25,6 @@ from .const import (
DEFAULT_STORED_TRACES,
)
from .models import ActionTrace, BaseTrace, RestoredTrace
from .utils import LimitedSizeDict
_LOGGER = logging.getLogger(__name__)

View file

@ -72,6 +72,79 @@
}),
})
# ---
# name: test_audio_pipeline_debug
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_debug.1
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_debug.2
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_debug.3
dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
})
# ---
# name: test_audio_pipeline_debug.4
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_debug.5
dict({
'engine': 'test',
'tts_input': "Sorry, I couldn't understand that",
})
# ---
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
})
# ---
# name: test_intent_failed
dict({
'language': 'en-US',

View file

@ -5,6 +5,7 @@ from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
PipelineData,
PipelineStorageCollection,
)
from homeassistant.core import HomeAssistant
@ -42,7 +43,8 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
]
pipeline_ids = []
store1: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
store1 = pipeline_data.pipeline_store
for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
assert len(store1.data) == 3
@ -103,6 +105,7 @@ async def test_loading_datasets_from_storage(
assert await async_setup_component(hass, "assist_pipeline", {})
store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 3
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"

View file

@ -5,10 +5,7 @@ from unittest.mock import ANY, MagicMock, patch
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
Pipeline,
PipelineStorageCollection,
)
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.core import HomeAssistant
from tests.typing import WebSocketGenerator
@ -21,6 +18,7 @@ async def test_text_only_pipeline(
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with text input (no STT/TTS)."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
@ -40,20 +38,39 @@ async def test_text_only_pipeline(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] is None
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_audio_pipeline(
@ -63,6 +80,7 @@ async def test_audio_pipeline(
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
@ -84,11 +102,13 @@ async def test_audio_pipeline(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
@ -96,29 +116,50 @@ async def test_audio_pipeline(
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text to speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] is None
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_intent_timeout(
@ -128,6 +169,7 @@ async def test_intent_timeout(
snapshot: SnapshotAssertion,
) -> None:
"""Test partial pipeline run with conversation agent timeout."""
events = []
client = await hass_ws_client(hass)
async def sleepy_converse(*args, **kwargs):
@ -155,16 +197,34 @@ async def test_intent_timeout(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# timeout error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_text_pipeline_timeout(
@ -174,6 +234,7 @@ async def test_text_pipeline_timeout(
snapshot: SnapshotAssertion,
) -> None:
"""Test text-only pipeline run with immediate timeout."""
events = []
client = await hass_ws_client(hass)
async def sleepy_run(*args, **kwargs):
@ -201,6 +262,22 @@ async def test_text_pipeline_timeout(
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_intent_failed(
@ -210,6 +287,7 @@ async def test_intent_failed(
snapshot: SnapshotAssertion,
) -> None:
"""Test text-only pipeline run with conversation agent error."""
events = []
client = await hass_ws_client(hass)
with patch(
@ -233,16 +311,34 @@ async def test_intent_failed(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent start
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "intent-failed"
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_audio_pipeline_timeout(
@ -252,6 +348,7 @@ async def test_audio_pipeline_timeout(
snapshot: SnapshotAssertion,
) -> None:
"""Test audio pipeline run with immediate timeout."""
events = []
client = await hass_ws_client(hass)
async def sleepy_run(*args, **kwargs):
@ -281,6 +378,22 @@ async def test_audio_pipeline_timeout(
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "timeout"
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_stt_provider_missing(
@ -320,12 +433,13 @@ async def test_stt_stream_failed(
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
events = []
client = await hass_ws_client(hass)
with patch(
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
new=MagicMock(side_effect=RuntimeError),
):
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
@ -345,11 +459,13 @@ async def test_stt_stream_failed(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(b"1")
@ -358,6 +474,22 @@ async def test_stt_stream_failed(
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "stt-stream-failed"
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_tts_failed(
@ -367,15 +499,15 @@ async def test_tts_failed(
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
events = []
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.media_source.async_resolve_media",
new=MagicMock(return_value=RuntimeError),
):
await client.send_json(
await client.send_json_auto_id(
{
"id": 5,
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
@ -391,16 +523,34 @@ async def test_tts_failed(
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# tts start
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# tts error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"]["code"] == "tts-failed"
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_invalid_stage_order(
@ -428,7 +578,8 @@ async def test_add_pipeline(
) -> None:
"""Test we can add a pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id(
{
@ -468,7 +619,8 @@ async def test_delete_pipeline(
) -> None:
"""Test we can delete a pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id(
{
@ -542,7 +694,8 @@ async def test_list_pipelines(
) -> None:
"""Test we can list pipelines."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
@ -586,7 +739,8 @@ async def test_update_pipeline(
) -> None:
"""Test we can list pipelines."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id(
{
@ -660,7 +814,8 @@ async def test_set_preferred_pipeline(
) -> None:
"""Test updating the preferred pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id(
{
@ -715,3 +870,202 @@ async def test_set_preferred_pipeline_wrong_id(
)
msg = await client.receive_json()
assert msg["error"]["code"] == "not_found"
async def test_audio_pipeline_debug(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test debug listing events from a pipeline run with audio input/output."""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 44100,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text to speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] is None
events.append(msg["event"])
# Get the id of the pipeline
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert len(msg["result"]["pipelines"]) == 1
pipeline_id = msg["result"]["pipelines"][0]["id"]
# Get the id for the run
await client.send_json_auto_id(
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"pipeline_runs": [ANY]}
pipeline_run_id = msg["result"]["pipeline_runs"][0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_pipeline_debug_list_runs_wrong_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
) -> None:
"""Test debug listing events from a pipeline."""
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": "blah"}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"pipeline_runs": []}
async def test_pipeline_debug_get_run_wrong_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
) -> None:
"""Test debug listing events from a pipeline."""
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": "blah",
"pipeline_run_id": "blah",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "pipeline_id blah not found",
}
async def test_pipeline_debug_get_run_wrong_pipeline_run(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
) -> None:
"""Test debug listing events from a pipeline."""
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "intent",
"end_stage": "intent",
"input": {"text": "Are the lights on?"},
}
)
# result
msg = await client.receive_json()
assert msg["success"]
# consume events
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
# Get the id of the pipeline
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert len(msg["result"]["pipelines"]) == 1
pipeline_id = msg["result"]["pipelines"][0]["id"]
# get debug data for the wrong run
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": "blah",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "pipeline_run_id blah not found",
}