Compare commits
1 commit
dev
...
assist_pip
Author | SHA1 | Date | |
---|---|---|---|
|
3115dbed1f |
3 changed files with 118 additions and 66 deletions
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
import functools as ft
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
@ -47,7 +48,8 @@ STORAGE_FIELDS = {
|
|||
vol.Optional("tts_engine", default=None): vol.Any(str, None),
|
||||
}
|
||||
|
||||
STORED_PIPELINE_RUNS = 10
|
||||
STORED_PIPELINE_SESSIONS = 10
|
||||
STORED_PIPELINE_RUNS_PER_SESSION = 100
|
||||
|
||||
SAVE_DELAY = 10
|
||||
|
||||
|
@ -191,21 +193,29 @@ class PipelineRun:
|
|||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
if self.pipeline.id not in pipeline_data.pipeline_runs:
|
||||
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
|
||||
size_limit=STORED_PIPELINE_RUNS
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
if self.pipeline.id not in pipeline_runs:
|
||||
pipeline_runs[self.pipeline.id] = LimitedSizeDict(
|
||||
size_limit=STORED_PIPELINE_SESSIONS
|
||||
)
|
||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
|
||||
if self.context.id not in pipeline_runs[self.pipeline.id]:
|
||||
pipeline_runs[self.pipeline.id][self.context.id] = PipelineSessionDebug()
|
||||
|
||||
pipeline_runs[self.pipeline.id][self.context.id].runs[self.id] = []
|
||||
|
||||
@callback
|
||||
def process_event(self, event: PipelineEvent) -> None:
|
||||
"""Log an event and call listener."""
|
||||
self.event_callback(event)
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
if self.context.id not in pipeline_runs[self.pipeline.id]:
|
||||
# This session has been evicted from the logged pipeline sessions already
|
||||
return
|
||||
if self.id not in pipeline_runs[self.pipeline.id][self.context.id].runs:
|
||||
# This run has been evicted from the logged pipeline runs already
|
||||
return
|
||||
pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event)
|
||||
pipeline_runs[self.pipeline.id][self.context.id].runs[self.id].append(event)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Emit run start event."""
|
||||
|
@ -735,15 +745,20 @@ class PipelineStorageCollectionWebsocket(
|
|||
class PipelineData:
|
||||
"""Store and debug data stored in hass.data."""
|
||||
|
||||
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
|
||||
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineSessionDebug]]
|
||||
pipeline_store: PipelineStorageCollection
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunDebug:
|
||||
"""Debug data for a pipelinerun."""
|
||||
class PipelineSessionDebug:
|
||||
"""Debug data for a pipeline session."""
|
||||
|
||||
events: list[PipelineEvent] = field(default_factory=list, init=False)
|
||||
runs: LimitedSizeDict[str, list[PipelineEvent]] = field(
|
||||
default_factory=ft.partial( # type: ignore[arg-type]
|
||||
LimitedSizeDict, size_limit=STORED_PIPELINE_RUNS_PER_SESSION
|
||||
),
|
||||
init=False,
|
||||
)
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: dt_util.utcnow().isoformat(),
|
||||
init=False,
|
||||
|
|
|
@ -70,8 +70,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||
),
|
||||
),
|
||||
)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
websocket_api.async_register_command(hass, websocket_list_sessions)
|
||||
websocket_api.async_register_command(hass, websocket_get_session)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
|
@ -206,12 +206,12 @@ async def websocket_run(
|
|||
vol.Required("pipeline_id"): str,
|
||||
}
|
||||
)
|
||||
def websocket_list_runs(
|
||||
def websocket_list_sessions(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List pipeline runs for which debug data is available."""
|
||||
"""List pipeline sessions for which debug data is available."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
|
||||
|
@ -219,14 +219,14 @@ def websocket_list_runs(
|
|||
connection.send_result(msg["id"], {"pipeline_runs": []})
|
||||
return
|
||||
|
||||
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||
pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id]
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"pipeline_runs": [
|
||||
{"pipeline_run_id": id, "timestamp": pipeline_run.timestamp}
|
||||
for id, pipeline_run in pipeline_runs.items()
|
||||
"pipeline_sessions": [
|
||||
{"pipeline_session_id": id, "timestamp": pipeline_run.timestamp}
|
||||
for id, pipeline_run in pipeline_sessions.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
|
@ -238,18 +238,18 @@ def websocket_list_runs(
|
|||
{
|
||||
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
|
||||
vol.Required("pipeline_id"): str,
|
||||
vol.Required("pipeline_run_id"): str,
|
||||
vol.Required("pipeline_session_id"): str,
|
||||
}
|
||||
)
|
||||
def websocket_get_run(
|
||||
def websocket_get_session(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Get debug data for a pipeline run."""
|
||||
"""Get debug data for a pipeline session."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
pipeline_run_id = msg["pipeline_run_id"]
|
||||
pipeline_session_id = msg["pipeline_session_id"]
|
||||
|
||||
if pipeline_id not in pipeline_data.pipeline_runs:
|
||||
connection.send_error(
|
||||
|
@ -259,17 +259,22 @@ def websocket_get_run(
|
|||
)
|
||||
return
|
||||
|
||||
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
|
||||
pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id]
|
||||
|
||||
if pipeline_run_id not in pipeline_runs:
|
||||
if pipeline_session_id not in pipeline_sessions:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"pipeline_run_id {pipeline_run_id} not found",
|
||||
f"pipeline_session_id {pipeline_session_id} not found",
|
||||
)
|
||||
return
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"events": pipeline_runs[pipeline_run_id].events},
|
||||
{
|
||||
"runs": [
|
||||
{"pipeline_run_id": id, "events": events}
|
||||
for id, events in pipeline_sessions[pipeline_session_id].runs.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
|
|
|
@ -58,19 +58,23 @@ async def test_text_only_pipeline(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_audio_pipeline(
|
||||
|
@ -147,19 +151,23 @@ async def test_audio_pipeline(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_intent_timeout(
|
||||
|
@ -212,19 +220,23 @@ async def test_intent_timeout(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_text_pipeline_timeout(
|
||||
|
@ -265,19 +277,23 @@ async def test_text_pipeline_timeout(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_intent_failed(
|
||||
|
@ -326,19 +342,23 @@ async def test_intent_failed(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_audio_pipeline_timeout(
|
||||
|
@ -381,19 +401,23 @@ async def test_audio_pipeline_timeout(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_stt_provider_missing(
|
||||
|
@ -477,19 +501,23 @@ async def test_stt_stream_failed(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_tts_failed(
|
||||
|
@ -538,19 +566,23 @@ async def test_tts_failed(
|
|||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_runs)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
|
||||
pipeline_runs = pipeline_data.pipeline_runs
|
||||
pipeline_id = list(pipeline_runs)[0]
|
||||
pipeline_session_id = list(pipeline_runs[pipeline_id])[0]
|
||||
pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {
|
||||
"runs": [{"pipeline_run_id": pipeline_run_id, "events": events}]
|
||||
}
|
||||
|
||||
|
||||
async def test_invalid_stage_order(
|
||||
|
@ -1020,20 +1052,20 @@ async def test_audio_pipeline_debug(
|
|||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"pipeline_runs": [ANY]}
|
||||
assert msg["result"] == {"pipeline_sessions": [ANY]}
|
||||
|
||||
pipeline_run_id = msg["result"]["pipeline_runs"][0]["pipeline_run_id"]
|
||||
pipeline_session_id = msg["result"]["pipeline_sessions"][0]["pipeline_session_id"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
"pipeline_session_id": pipeline_session_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
assert msg["result"] == {"runs": [{"pipeline_run_id": ANY, "events": events}]}
|
||||
|
||||
|
||||
async def test_pipeline_debug_list_runs_wrong_pipeline(
|
||||
|
@ -1064,7 +1096,7 @@ async def test_pipeline_debug_get_run_wrong_pipeline(
|
|||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": "blah",
|
||||
"pipeline_run_id": "blah",
|
||||
"pipeline_session_id": "blah",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
|
@ -1121,12 +1153,12 @@ async def test_pipeline_debug_get_run_wrong_pipeline_run(
|
|||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": "blah",
|
||||
"pipeline_session_id": "blah",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "pipeline_run_id blah not found",
|
||||
"message": "pipeline_session_id blah not found",
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue