Address review comments from trace refactoring PRs (#48288)
This commit is contained in:
parent
ee81869c05
commit
14ef0531f0
7 changed files with 130 additions and 140 deletions
|
@ -2,25 +2,79 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Deque
|
||||
|
||||
from homeassistant.components.trace import AutomationTrace, async_store_trace
|
||||
from homeassistant.components.trace import ActionTrace, async_store_trace
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers.trace import TraceElement
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
||||
|
||||
class AutomationTrace(ActionTrace):
|
||||
"""Container for automation trace."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_id: str,
|
||||
config: dict[str, Any],
|
||||
context: Context,
|
||||
):
|
||||
"""Container for automation trace."""
|
||||
key = ("automation", item_id)
|
||||
super().__init__(key, config, context)
|
||||
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
|
||||
|
||||
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
|
||||
"""Set condition trace."""
|
||||
self._condition_trace = trace
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return dictionary version of this AutomationTrace."""
|
||||
|
||||
result = super().as_dict()
|
||||
|
||||
condition_traces = {}
|
||||
|
||||
if self._condition_trace:
|
||||
for key, trace_list in self._condition_trace.items():
|
||||
condition_traces[key] = [item.as_dict() for item in trace_list]
|
||||
result["condition_trace"] = condition_traces
|
||||
|
||||
return result
|
||||
|
||||
def as_short_dict(self) -> dict[str, Any]:
|
||||
"""Return a brief dictionary version of this AutomationTrace."""
|
||||
|
||||
result = super().as_short_dict()
|
||||
|
||||
last_condition = None
|
||||
trigger = None
|
||||
|
||||
if self._condition_trace:
|
||||
last_condition = list(self._condition_trace)[-1]
|
||||
if self._variables:
|
||||
trigger = self._variables.get("trigger", {}).get("description")
|
||||
|
||||
result["trigger"] = trigger
|
||||
result["last_condition"] = last_condition
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_automation(hass, item_id, config, context):
|
||||
"""Trace action execution of automation with item_id."""
|
||||
trace = AutomationTrace(item_id, config, context)
|
||||
def trace_automation(hass, automation_id, config, context):
|
||||
"""Trace action execution of automation with automation_id."""
|
||||
trace = AutomationTrace(automation_id, config, context)
|
||||
async_store_trace(hass, trace)
|
||||
|
||||
try:
|
||||
yield trace
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
if item_id:
|
||||
if automation_id:
|
||||
trace.set_error(ex)
|
||||
raise ex
|
||||
finally:
|
||||
if item_id:
|
||||
if automation_id:
|
||||
trace.finished()
|
||||
|
|
|
@ -2,8 +2,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.trace import ScriptTrace, async_store_trace
|
||||
from homeassistant.components.trace import ActionTrace, async_store_trace
|
||||
from homeassistant.core import Context
|
||||
|
||||
|
||||
class ScriptTrace(ActionTrace):
|
||||
"""Container for automation trace."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_id: str,
|
||||
config: dict[str, Any],
|
||||
context: Context,
|
||||
):
|
||||
"""Container for automation trace."""
|
||||
key = ("script", item_id)
|
||||
super().__init__(key, config, context)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
|
@ -128,68 +128,3 @@ class ActionTrace:
|
|||
result["last_action"] = last_action
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class AutomationTrace(ActionTrace):
|
||||
"""Container for automation trace."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_id: str,
|
||||
config: dict[str, Any],
|
||||
context: Context,
|
||||
):
|
||||
"""Container for automation trace."""
|
||||
key = ("automation", item_id)
|
||||
super().__init__(key, config, context)
|
||||
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
|
||||
|
||||
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
|
||||
"""Set condition trace."""
|
||||
self._condition_trace = trace
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return dictionary version of this AutomationTrace."""
|
||||
|
||||
result = super().as_dict()
|
||||
|
||||
condition_traces = {}
|
||||
|
||||
if self._condition_trace:
|
||||
for key, trace_list in self._condition_trace.items():
|
||||
condition_traces[key] = [item.as_dict() for item in trace_list]
|
||||
result["condition_trace"] = condition_traces
|
||||
|
||||
return result
|
||||
|
||||
def as_short_dict(self) -> dict[str, Any]:
|
||||
"""Return a brief dictionary version of this AutomationTrace."""
|
||||
|
||||
result = super().as_short_dict()
|
||||
|
||||
last_condition = None
|
||||
trigger = None
|
||||
|
||||
if self._condition_trace:
|
||||
last_condition = list(self._condition_trace)[-1]
|
||||
if self._variables:
|
||||
trigger = self._variables.get("trigger", {}).get("description")
|
||||
|
||||
result["trigger"] = trigger
|
||||
result["last_condition"] = last_condition
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ScriptTrace(ActionTrace):
|
||||
"""Container for automation trace."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_id: str,
|
||||
config: dict[str, Any],
|
||||
context: Context,
|
||||
):
|
||||
"""Container for automation trace."""
|
||||
key = ("script", item_id)
|
||||
super().__init__(key, config, context)
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
"""Support for automation and script tracing and debugging."""
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import DATA_TRACE
|
||||
|
||||
|
||||
@callback
|
||||
def get_debug_trace(hass, key, run_id):
|
||||
"""Return a serializable debug trace."""
|
||||
return hass.data[DATA_TRACE][key][run_id]
|
||||
|
||||
|
||||
@callback
|
||||
def get_debug_traces(hass, key, summary=False):
|
||||
"""Return a serializable list of debug traces for an automation or script."""
|
||||
traces = []
|
||||
|
||||
for trace in hass.data[DATA_TRACE].get(key, {}).values():
|
||||
if summary:
|
||||
traces.append(trace.as_short_dict())
|
||||
else:
|
||||
traces.append(trace.as_dict())
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
@callback
|
||||
def get_all_debug_traces(hass, summary=False):
|
||||
"""Return a serializable list of debug traces for all automations and scripts."""
|
||||
traces = []
|
||||
|
||||
for key in hass.data[DATA_TRACE]:
|
||||
traces.extend(get_debug_traces(hass, key, summary))
|
||||
|
||||
return traces
|
|
@ -1,4 +1,4 @@
|
|||
"""Helpers for automation and script tracing and debugging."""
|
||||
"""Helpers for script and automation tracing and debugging."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
|
|
@ -23,12 +23,12 @@ from homeassistant.helpers.script import (
|
|||
debug_stop,
|
||||
)
|
||||
|
||||
from .trace import DATA_TRACE, get_all_debug_traces, get_debug_trace, get_debug_traces
|
||||
from .const import DATA_TRACE
|
||||
from .utils import TraceJSONEncoder
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
TRACE_DOMAINS = ["automation", "script"]
|
||||
TRACE_DOMAINS = ("automation", "script")
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -57,33 +57,47 @@ def async_setup(hass: HomeAssistant) -> None:
|
|||
}
|
||||
)
|
||||
def websocket_trace_get(hass, connection, msg):
|
||||
"""Get an automation or script trace."""
|
||||
"""Get an script or automation trace."""
|
||||
key = (msg["domain"], msg["item_id"])
|
||||
run_id = msg["run_id"]
|
||||
|
||||
trace = get_debug_trace(hass, key, run_id)
|
||||
trace = hass.data[DATA_TRACE][key][run_id]
|
||||
message = websocket_api.messages.result_message(msg["id"], trace)
|
||||
|
||||
connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
|
||||
|
||||
|
||||
def get_debug_traces(hass, key):
|
||||
"""Return a serializable list of debug traces for an script or automation."""
|
||||
traces = []
|
||||
|
||||
for trace in hass.data[DATA_TRACE].get(key, {}).values():
|
||||
traces.append(trace.as_short_dict())
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "trace/list",
|
||||
vol.Inclusive("domain", "id"): vol.In(TRACE_DOMAINS),
|
||||
vol.Inclusive("item_id", "id"): str,
|
||||
vol.Required("domain", "id"): vol.In(TRACE_DOMAINS),
|
||||
vol.Optional("item_id", "id"): str,
|
||||
}
|
||||
)
|
||||
def websocket_trace_list(hass, connection, msg):
|
||||
"""Summarize automation and script traces."""
|
||||
key = (msg["domain"], msg["item_id"]) if "item_id" in msg else None
|
||||
"""Summarize script and automation traces."""
|
||||
domain = msg["domain"]
|
||||
key = (domain, msg["item_id"]) if "item_id" in msg else None
|
||||
|
||||
if not key:
|
||||
traces = get_all_debug_traces(hass, summary=True)
|
||||
traces = []
|
||||
for key in hass.data[DATA_TRACE]:
|
||||
if key[0] == domain:
|
||||
traces.extend(get_debug_traces(hass, key))
|
||||
else:
|
||||
traces = get_debug_traces(hass, key, summary=True)
|
||||
traces = get_debug_traces(hass, key)
|
||||
|
||||
connection.send_result(msg["id"], traces)
|
||||
|
||||
|
@ -230,7 +244,7 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg):
|
|||
}
|
||||
)
|
||||
def websocket_debug_continue(hass, connection, msg):
|
||||
"""Resume execution of halted automation or script."""
|
||||
"""Resume execution of halted script or automation."""
|
||||
key = (msg["domain"], msg["item_id"])
|
||||
run_id = msg["run_id"]
|
||||
|
||||
|
@ -250,7 +264,7 @@ def websocket_debug_continue(hass, connection, msg):
|
|||
}
|
||||
)
|
||||
def websocket_debug_step(hass, connection, msg):
|
||||
"""Single step a halted automation or script."""
|
||||
"""Single step a halted script or automation."""
|
||||
key = (msg["domain"], msg["item_id"])
|
||||
run_id = msg["run_id"]
|
||||
|
||||
|
@ -270,7 +284,7 @@ def websocket_debug_step(hass, connection, msg):
|
|||
}
|
||||
)
|
||||
def websocket_debug_stop(hass, connection, msg):
|
||||
"""Stop a halted automation or script."""
|
||||
"""Stop a halted script or automation."""
|
||||
key = (msg["domain"], msg["item_id"])
|
||||
run_id = msg["run_id"]
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from tests.common import assert_lists_same
|
|||
|
||||
|
||||
def _find_run_id(traces, trace_type, item_id):
|
||||
"""Find newest run_id for an automation or script."""
|
||||
"""Find newest run_id for an script or automation."""
|
||||
for trace in reversed(traces):
|
||||
if trace["domain"] == trace_type and trace["item_id"] == item_id:
|
||||
return trace["run_id"]
|
||||
|
@ -18,7 +18,7 @@ def _find_run_id(traces, trace_type, item_id):
|
|||
|
||||
|
||||
def _find_traces(traces, trace_type, item_id):
|
||||
"""Find traces for an automation or script."""
|
||||
"""Find traces for an script or automation."""
|
||||
return [
|
||||
trace
|
||||
for trace in traces
|
||||
|
@ -30,7 +30,7 @@ def _find_traces(traces, trace_type, item_id):
|
|||
"domain, prefix", [("automation", "action"), ("script", "sequence")]
|
||||
)
|
||||
async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
||||
"""Test tracing an automation or script."""
|
||||
"""Test tracing an script or automation."""
|
||||
id = 1
|
||||
|
||||
def next_id():
|
||||
|
@ -92,7 +92,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# List traces
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
run_id = _find_run_id(response["result"], domain, "sun")
|
||||
|
@ -140,7 +140,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# List traces
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
run_id = _find_run_id(response["result"], domain, "moon")
|
||||
|
@ -193,7 +193,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# List traces
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
run_id = _find_run_id(response["result"], "automation", "moon")
|
||||
|
@ -233,7 +233,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# List traces
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
run_id = _find_run_id(response["result"], "automation", "moon")
|
||||
|
@ -280,7 +280,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
|
|||
|
||||
@pytest.mark.parametrize("domain", ["automation", "script"])
|
||||
async def test_trace_overflow(hass, hass_ws_client, domain):
|
||||
"""Test the number of stored traces per automation or script is limited."""
|
||||
"""Test the number of stored traces per script or automation is limited."""
|
||||
id = 1
|
||||
|
||||
def next_id():
|
||||
|
@ -313,7 +313,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
|
|||
|
||||
client = await hass_ws_client()
|
||||
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == []
|
||||
|
@ -328,7 +328,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# List traces
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert len(_find_traces(response["result"], domain, "moon")) == 1
|
||||
|
@ -343,7 +343,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
|
|||
await hass.services.async_call("script", "moon")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
moon_traces = _find_traces(response["result"], domain, "moon")
|
||||
|
@ -358,7 +358,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
|
|||
"domain, prefix", [("automation", "action"), ("script", "sequence")]
|
||||
)
|
||||
async def test_list_traces(hass, hass_ws_client, domain, prefix):
|
||||
"""Test listing automation and script traces."""
|
||||
"""Test listing script and automation traces."""
|
||||
id = 1
|
||||
|
||||
def next_id():
|
||||
|
@ -398,7 +398,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
|
|||
|
||||
client = await hass_ws_client()
|
||||
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == []
|
||||
|
@ -418,7 +418,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# Get trace
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert len(response["result"]) == 1
|
||||
|
@ -461,7 +461,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
# Get trace
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert len(_find_traces(response["result"], domain, "moon")) == 3
|
||||
|
@ -585,7 +585,7 @@ async def test_nested_traces(hass, hass_ws_client, domain, prefix):
|
|||
"domain, prefix", [("automation", "action"), ("script", "sequence")]
|
||||
)
|
||||
async def test_breakpoints(hass, hass_ws_client, domain, prefix):
|
||||
"""Test automation and script breakpoints."""
|
||||
"""Test script and automation breakpoints."""
|
||||
id = 1
|
||||
|
||||
def next_id():
|
||||
|
@ -594,7 +594,9 @@ async def test_breakpoints(hass, hass_ws_client, domain, prefix):
|
|||
return id
|
||||
|
||||
async def assert_last_action(item_id, expected_action, expected_state):
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json(
|
||||
{"id": next_id(), "type": "trace/list", "domain": domain}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
trace = _find_traces(response["result"], domain, item_id)[-1]
|
||||
|
@ -770,7 +772,9 @@ async def test_breakpoints_2(hass, hass_ws_client, domain, prefix):
|
|||
return id
|
||||
|
||||
async def assert_last_action(item_id, expected_action, expected_state):
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json(
|
||||
{"id": next_id(), "type": "trace/list", "domain": domain}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
trace = _find_traces(response["result"], domain, item_id)[-1]
|
||||
|
@ -883,7 +887,9 @@ async def test_breakpoints_3(hass, hass_ws_client, domain, prefix):
|
|||
return id
|
||||
|
||||
async def assert_last_action(item_id, expected_action, expected_state):
|
||||
await client.send_json({"id": next_id(), "type": "trace/list"})
|
||||
await client.send_json(
|
||||
{"id": next_id(), "type": "trace/list", "domain": domain}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
trace = _find_traces(response["result"], domain, item_id)[-1]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue