Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Erik
3115dbed1f Group assist pipeline debug data by context ID 2023-04-19 13:34:01 +02:00
3 changed files with 118 additions and 66 deletions

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
import functools as ft
import logging
from typing import Any
@ -47,7 +48,8 @@ STORAGE_FIELDS = {
vol.Optional("tts_engine", default=None): vol.Any(str, None),
}
STORED_PIPELINE_RUNS = 10
STORED_PIPELINE_SESSIONS = 10
STORED_PIPELINE_RUNS_PER_SESSION = 100
SAVE_DELAY = 10
@ -191,21 +193,29 @@ class PipelineRun:
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
if self.pipeline.id not in pipeline_data.pipeline_runs:
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_RUNS
pipeline_runs = pipeline_data.pipeline_runs
if self.pipeline.id not in pipeline_runs:
pipeline_runs[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_SESSIONS
)
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
if self.context.id not in pipeline_runs[self.pipeline.id]:
pipeline_runs[self.pipeline.id][self.context.id] = PipelineSessionDebug()
pipeline_runs[self.pipeline.id][self.context.id].runs[self.id] = []
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
self.event_callback(event)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
pipeline_runs = pipeline_data.pipeline_runs
if self.context.id not in pipeline_runs[self.pipeline.id]:
# This session has been evicted from the logged pipeline sessions already
return
if self.id not in pipeline_runs[self.pipeline.id][self.context.id].runs:
# This run has been evicted from the logged pipeline runs already
return
pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event)
pipeline_runs[self.pipeline.id][self.context.id].runs[self.id].append(event)
def start(self) -> None:
"""Emit run start event."""
@ -735,15 +745,20 @@ class PipelineStorageCollectionWebsocket(
class PipelineData:
"""Store and debug data stored in hass.data."""
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineSessionDebug]]
pipeline_store: PipelineStorageCollection
@dataclass
class PipelineRunDebug:
"""Debug data for a pipelinerun."""
class PipelineSessionDebug:
"""Debug data for a pipeline session."""
events: list[PipelineEvent] = field(default_factory=list, init=False)
runs: LimitedSizeDict[str, list[PipelineEvent]] = field(
default_factory=ft.partial( # type: ignore[arg-type]
LimitedSizeDict, size_limit=STORED_PIPELINE_RUNS_PER_SESSION
),
init=False,
)
timestamp: str = field(
default_factory=lambda: dt_util.utcnow().isoformat(),
init=False,

View file

@ -70,8 +70,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
),
),
)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run)
websocket_api.async_register_command(hass, websocket_list_sessions)
websocket_api.async_register_command(hass, websocket_get_session)
@websocket_api.async_response
@ -206,12 +206,12 @@ async def websocket_run(
vol.Required("pipeline_id"): str,
}
)
def websocket_list_runs(
def websocket_list_sessions(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List pipeline runs for which debug data is available."""
"""List pipeline sessions for which debug data is available."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
@ -219,14 +219,14 @@ def websocket_list_runs(
connection.send_result(msg["id"], {"pipeline_runs": []})
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id]
connection.send_result(
msg["id"],
{
"pipeline_runs": [
{"pipeline_run_id": id, "timestamp": pipeline_run.timestamp}
for id, pipeline_run in pipeline_runs.items()
"pipeline_sessions": [
{"pipeline_session_id": id, "timestamp": pipeline_run.timestamp}
for id, pipeline_run in pipeline_sessions.items()
]
},
)
@ -238,18 +238,18 @@ def websocket_list_runs(
{
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
vol.Required("pipeline_id"): str,
vol.Required("pipeline_run_id"): str,
vol.Required("pipeline_session_id"): str,
}
)
def websocket_get_run(
def websocket_get_session(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get debug data for a pipeline run."""
"""Get debug data for a pipeline session."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
pipeline_run_id = msg["pipeline_run_id"]
pipeline_session_id = msg["pipeline_session_id"]
if pipeline_id not in pipeline_data.pipeline_runs:
connection.send_error(
@ -259,17 +259,22 @@ def websocket_get_run(
)
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id]
if pipeline_run_id not in pipeline_runs:
if pipeline_session_id not in pipeline_sessions:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"pipeline_run_id {pipeline_run_id} not found",
f"pipeline_session_id {pipeline_session_id} not found",
)
return
connection.send_result(
msg["id"],
{"events": pipeline_runs[pipeline_run_id].events},
{
"runs": [
{"pipeline_run_id": id, "events": events}
for id, events in pipeline_sessions[pipeline_session_id].runs.items()
]
},
)

View file

@ -58,19 +58,23 @@ async def test_text_only_pipeline(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_audio_pipeline(
@ -147,19 +151,23 @@ async def test_audio_pipeline(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_intent_timeout(
@ -212,19 +220,23 @@ async def test_intent_timeout(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_text_pipeline_timeout(
@ -265,19 +277,23 @@ async def test_text_pipeline_timeout(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_intent_failed(
@ -326,19 +342,23 @@ async def test_intent_failed(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_audio_pipeline_timeout(
@ -381,19 +401,23 @@ async def test_audio_pipeline_timeout(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_stt_provider_missing(
@ -477,19 +501,23 @@ async def test_stt_stream_failed(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_tts_failed(
@ -538,19 +566,23 @@ async def test_tts_failed(
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
pipeline_runs = pipeline_data.pipeline_runs
pipeline_id = list(pipeline_runs)[0]
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
}
async def test_invalid_stage_order(
@ -1020,20 +1052,20 @@ async def test_audio_pipeline_debug(
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"pipeline_runs": [ANY]}
assert msg["result"] == {"pipeline_sessions": [ANY]}
pipeline_run_id = msg["result"]["pipeline_runs"][0]["pipeline_run_id"]
pipeline_session_id = msg["result"]["pipeline_sessions"][0]["pipeline_session_id"]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
"pipeline_session_id": pipeline_session_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
assert msg["result"] == {"runs": [{"pipeline_run_id": ANY, "events": events}]}
async def test_pipeline_debug_list_runs_wrong_pipeline(
@ -1064,7 +1096,7 @@ async def test_pipeline_debug_get_run_wrong_pipeline(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": "blah",
"pipeline_run_id": "blah",
"pipeline_session_id": "blah",
}
)
msg = await client.receive_json()
@ -1121,12 +1153,12 @@ async def test_pipeline_debug_get_run_wrong_pipeline_run(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": "blah",
"pipeline_session_id": "blah",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "pipeline_run_id blah not found",
"message": "pipeline_session_id blah not found",
}