diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 22b0191ee0b..72259d96826 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import AsyncIterable, Callable from dataclasses import asdict, dataclass, field +import functools as ft import logging from typing import Any @@ -47,7 +48,8 @@ STORAGE_FIELDS = { vol.Optional("tts_engine", default=None): vol.Any(str, None), } -STORED_PIPELINE_RUNS = 10 +STORED_PIPELINE_SESSIONS = 10 +STORED_PIPELINE_RUNS_PER_SESSION = 100 SAVE_DELAY = 10 @@ -191,21 +193,29 @@ class PipelineRun: raise InvalidPipelineStagesError(self.start_stage, self.end_stage) pipeline_data: PipelineData = self.hass.data[DOMAIN] - if self.pipeline.id not in pipeline_data.pipeline_runs: - pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict( - size_limit=STORED_PIPELINE_RUNS + pipeline_runs = pipeline_data.pipeline_runs + if self.pipeline.id not in pipeline_runs: + pipeline_runs[self.pipeline.id] = LimitedSizeDict( + size_limit=STORED_PIPELINE_SESSIONS ) - pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug() + if self.context.id not in pipeline_runs[self.pipeline.id]: + pipeline_runs[self.pipeline.id][self.context.id] = PipelineSessionDebug() + + pipeline_runs[self.pipeline.id][self.context.id].runs[self.id] = [] @callback def process_event(self, event: PipelineEvent) -> None: """Log an event and call listener.""" self.event_callback(event) pipeline_data: PipelineData = self.hass.data[DOMAIN] - if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]: + pipeline_runs = pipeline_data.pipeline_runs + if self.context.id not in pipeline_runs[self.pipeline.id]: + # This session has been evicted from the logged pipeline sessions already + return + if self.id not in pipeline_runs[self.pipeline.id][self.context.id].runs: # This run has been evicted from the logged pipeline runs already return - pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event) + pipeline_runs[self.pipeline.id][self.context.id].runs[self.id].append(event) def start(self) -> None: """Emit run start event.""" @@ -735,15 +745,20 @@ class PipelineStorageCollectionWebsocket( class PipelineData: """Store and debug data stored in hass.data.""" - pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]] + pipeline_runs: dict[str, LimitedSizeDict[str, PipelineSessionDebug]] pipeline_store: PipelineStorageCollection @dataclass -class PipelineRunDebug: - """Debug data for a pipelinerun.""" +class PipelineSessionDebug: + """Debug data for a pipeline session.""" - events: list[PipelineEvent] = field(default_factory=list, init=False) + runs: LimitedSizeDict[str, list[PipelineEvent]] = field( + default_factory=ft.partial( # type: ignore[arg-type] + LimitedSizeDict, size_limit=STORED_PIPELINE_RUNS_PER_SESSION + ), + init=False, + ) timestamp: str = field( default_factory=lambda: dt_util.utcnow().isoformat(), init=False, diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index b672e0c6b25..01680d73fb4 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -70,8 +70,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: ), ), ) - websocket_api.async_register_command(hass, websocket_list_runs) - websocket_api.async_register_command(hass, websocket_get_run) + websocket_api.async_register_command(hass, websocket_list_sessions) + websocket_api.async_register_command(hass, websocket_get_session) @websocket_api.async_response @@ -206,12 +206,12 @@ async def websocket_run( vol.Required("pipeline_id"): str, } ) -def websocket_list_runs( +def websocket_list_sessions( hass: HomeAssistant, connection: websocket_api.connection.ActiveConnection, msg: dict[str, Any], ) -> None: - """List pipeline runs for which debug data is available.""" + """List pipeline sessions for which debug data is available.""" pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_id = msg["pipeline_id"] @@ -219,14 +219,14 @@ def websocket_list_runs( connection.send_result(msg["id"], {"pipeline_runs": []}) return - pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] + pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id] connection.send_result( msg["id"], { - "pipeline_runs": [ - {"pipeline_run_id": id, "timestamp": pipeline_run.timestamp} - for id, pipeline_run in pipeline_runs.items() + "pipeline_sessions": [ + {"pipeline_session_id": id, "timestamp": pipeline_run.timestamp} + for id, pipeline_run in pipeline_sessions.items() ] }, ) @@ -238,18 +238,18 @@ def websocket_list_runs( { vol.Required("type"): "assist_pipeline/pipeline_debug/get", vol.Required("pipeline_id"): str, - vol.Required("pipeline_run_id"): str, + vol.Required("pipeline_session_id"): str, } ) -def websocket_get_run( +def websocket_get_session( hass: HomeAssistant, connection: websocket_api.connection.ActiveConnection, msg: dict[str, Any], ) -> None: - """Get debug data for a pipeline run.""" + """Get debug data for a pipeline session.""" pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_id = msg["pipeline_id"] - pipeline_run_id = msg["pipeline_run_id"] + pipeline_session_id = msg["pipeline_session_id"] if pipeline_id not in pipeline_data.pipeline_runs: connection.send_error( @@ -259,17 +259,22 @@ def websocket_get_run( ) return - pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] + pipeline_sessions = pipeline_data.pipeline_runs[pipeline_id] - if pipeline_run_id not in pipeline_runs: + if pipeline_session_id not in pipeline_sessions: connection.send_error( msg["id"], websocket_api.const.ERR_NOT_FOUND, - f"pipeline_run_id {pipeline_run_id} not found", + f"pipeline_session_id {pipeline_session_id} not found", ) return connection.send_result( msg["id"], - {"events": pipeline_runs[pipeline_run_id].events}, + { + "runs": [ + {"pipeline_run_id": id, "events": events} + for id, events in pipeline_sessions[pipeline_session_id].runs.items() + ] + }, ) diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 0560b585eff..fd16ca6a743 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -58,19 +58,23 @@ async def test_text_only_pipeline( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_audio_pipeline( @@ -147,19 +151,23 @@ async def test_audio_pipeline( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_intent_timeout( @@ -212,19 +220,23 @@ async def test_intent_timeout( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_text_pipeline_timeout( @@ -265,19 +277,23 @@ async def test_text_pipeline_timeout( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_intent_failed( @@ -326,19 +342,23 @@ async def test_intent_failed( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_audio_pipeline_timeout( @@ -381,19 +401,23 @@ async def test_audio_pipeline_timeout( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_stt_provider_missing( @@ -477,19 +501,23 @@ async def test_stt_stream_failed( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_tts_failed( @@ -538,19 +566,23 @@ async def test_tts_failed( events.append(msg["event"]) pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_id = list(pipeline_data.pipeline_runs)[0] - pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + pipeline_runs = pipeline_data.pipeline_runs + pipeline_id = list(pipeline_runs)[0] + pipeline_session_id = list(pipeline_runs[pipeline_id])[0] + pipeline_run_id = list(pipeline_runs[pipeline_id][pipeline_session_id].runs)[0] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == { + "runs": [{"pipeline_run_id": pipeline_run_id, "events": events}] + } async def test_invalid_stage_order( @@ -1020,20 +1052,20 @@ async def test_audio_pipeline_debug( ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"pipeline_runs": [ANY]} + assert msg["result"] == {"pipeline_sessions": [ANY]} - pipeline_run_id = msg["result"]["pipeline_runs"][0]["pipeline_run_id"] + pipeline_session_id = msg["result"]["pipeline_sessions"][0]["pipeline_session_id"] await client.send_json_auto_id( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": pipeline_run_id, + "pipeline_session_id": pipeline_session_id, } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"events": events} + assert msg["result"] == {"runs": [{"pipeline_run_id": ANY, "events": events}]} async def test_pipeline_debug_list_runs_wrong_pipeline( @@ -1064,7 +1096,7 @@ async def test_pipeline_debug_get_run_wrong_pipeline( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": "blah", - "pipeline_run_id": "blah", + "pipeline_session_id": "blah", } ) msg = await client.receive_json() @@ -1121,12 +1153,12 @@ async def test_pipeline_debug_get_run_wrong_pipeline_run( { "type": "assist_pipeline/pipeline_debug/get", "pipeline_id": pipeline_id, - "pipeline_run_id": "blah", + "pipeline_session_id": "blah", } ) msg = await client.receive_json() assert not msg["success"] assert msg["error"] == { "code": "not_found", - "message": "pipeline_run_id blah not found", + "message": "pipeline_session_id blah not found", }