Address review comments from trace refactoring PRs (#48288)

This commit is contained in:
Erik Montnemery 2021-03-29 08:09:14 +02:00 committed by GitHub
parent ee81869c05
commit 14ef0531f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 140 deletions

View file

@ -2,25 +2,79 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import Any, Deque
from homeassistant.components.trace import AutomationTrace, async_store_trace
from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.core import Context
from homeassistant.helpers.trace import TraceElement
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
class AutomationTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("automation", item_id)
super().__init__(key, config, context)
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this AutomationTrace."""
result = super().as_dict()
condition_traces = {}
if self._condition_trace:
for key, trace_list in self._condition_trace.items():
condition_traces[key] = [item.as_dict() for item in trace_list]
result["condition_trace"] = condition_traces
return result
def as_short_dict(self) -> dict[str, Any]:
"""Return a brief dictionary version of this AutomationTrace."""
result = super().as_short_dict()
last_condition = None
trigger = None
if self._condition_trace:
last_condition = list(self._condition_trace)[-1]
if self._variables:
trigger = self._variables.get("trigger", {}).get("description")
result["trigger"] = trigger
result["last_condition"] = last_condition
return result
@contextmanager
def trace_automation(hass, item_id, config, context):
"""Trace action execution of automation with item_id."""
trace = AutomationTrace(item_id, config, context)
def trace_automation(hass, automation_id, config, context):
"""Trace action execution of automation with automation_id."""
trace = AutomationTrace(automation_id, config, context)
async_store_trace(hass, trace)
try:
yield trace
except Exception as ex: # pylint: disable=broad-except
if item_id:
if automation_id:
trace.set_error(ex)
raise ex
finally:
if item_id:
if automation_id:
trace.finished()

View file

@ -2,8 +2,24 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import Any
from homeassistant.components.trace import ScriptTrace, async_store_trace
from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.core import Context
class ScriptTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("script", item_id)
super().__init__(key, config, context)
@contextmanager

View file

@ -128,68 +128,3 @@ class ActionTrace:
result["last_action"] = last_action
return result
class AutomationTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("automation", item_id)
super().__init__(key, config, context)
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this AutomationTrace."""
result = super().as_dict()
condition_traces = {}
if self._condition_trace:
for key, trace_list in self._condition_trace.items():
condition_traces[key] = [item.as_dict() for item in trace_list]
result["condition_trace"] = condition_traces
return result
def as_short_dict(self) -> dict[str, Any]:
"""Return a brief dictionary version of this AutomationTrace."""
result = super().as_short_dict()
last_condition = None
trigger = None
if self._condition_trace:
last_condition = list(self._condition_trace)[-1]
if self._variables:
trigger = self._variables.get("trigger", {}).get("description")
result["trigger"] = trigger
result["last_condition"] = last_condition
return result
class ScriptTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("script", item_id)
super().__init__(key, config, context)

View file

@ -1,35 +0,0 @@
"""Support for automation and script tracing and debugging."""
from homeassistant.core import callback
from .const import DATA_TRACE
@callback
def get_debug_trace(hass, key, run_id):
"""Return a serializable debug trace."""
return hass.data[DATA_TRACE][key][run_id]
@callback
def get_debug_traces(hass, key, summary=False):
"""Return a serializable list of debug traces for an automation or script."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():
if summary:
traces.append(trace.as_short_dict())
else:
traces.append(trace.as_dict())
return traces
@callback
def get_all_debug_traces(hass, summary=False):
"""Return a serializable list of debug traces for all automations and scripts."""
traces = []
for key in hass.data[DATA_TRACE]:
traces.extend(get_debug_traces(hass, key, summary))
return traces

View file

@ -1,4 +1,4 @@
"""Helpers for automation and script tracing and debugging."""
"""Helpers for script and automation tracing and debugging."""
from collections import OrderedDict
from datetime import timedelta
from typing import Any

View file

@ -23,12 +23,12 @@ from homeassistant.helpers.script import (
debug_stop,
)
from .trace import DATA_TRACE, get_all_debug_traces, get_debug_trace, get_debug_traces
from .const import DATA_TRACE
from .utils import TraceJSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs
TRACE_DOMAINS = ["automation", "script"]
TRACE_DOMAINS = ("automation", "script")
@callback
@ -57,33 +57,47 @@ def async_setup(hass: HomeAssistant) -> None:
}
)
def websocket_trace_get(hass, connection, msg):
"""Get an automation or script trace."""
"""Get an script or automation trace."""
key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"]
trace = get_debug_trace(hass, key, run_id)
trace = hass.data[DATA_TRACE][key][run_id]
message = websocket_api.messages.result_message(msg["id"], trace)
connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
def get_debug_traces(hass, key):
"""Return a serializable list of debug traces for an script or automation."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():
traces.append(trace.as_short_dict())
return traces
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "trace/list",
vol.Inclusive("domain", "id"): vol.In(TRACE_DOMAINS),
vol.Inclusive("item_id", "id"): str,
vol.Required("domain", "id"): vol.In(TRACE_DOMAINS),
vol.Optional("item_id", "id"): str,
}
)
def websocket_trace_list(hass, connection, msg):
"""Summarize automation and script traces."""
key = (msg["domain"], msg["item_id"]) if "item_id" in msg else None
"""Summarize script and automation traces."""
domain = msg["domain"]
key = (domain, msg["item_id"]) if "item_id" in msg else None
if not key:
traces = get_all_debug_traces(hass, summary=True)
traces = []
for key in hass.data[DATA_TRACE]:
if key[0] == domain:
traces.extend(get_debug_traces(hass, key))
else:
traces = get_debug_traces(hass, key, summary=True)
traces = get_debug_traces(hass, key)
connection.send_result(msg["id"], traces)
@ -230,7 +244,7 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg):
}
)
def websocket_debug_continue(hass, connection, msg):
"""Resume execution of halted automation or script."""
"""Resume execution of halted script or automation."""
key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"]
@ -250,7 +264,7 @@ def websocket_debug_continue(hass, connection, msg):
}
)
def websocket_debug_step(hass, connection, msg):
"""Single step a halted automation or script."""
"""Single step a halted script or automation."""
key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"]
@ -270,7 +284,7 @@ def websocket_debug_step(hass, connection, msg):
}
)
def websocket_debug_stop(hass, connection, msg):
"""Stop a halted automation or script."""
"""Stop a halted script or automation."""
key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"]

View file

@ -9,7 +9,7 @@ from tests.common import assert_lists_same
def _find_run_id(traces, trace_type, item_id):
"""Find newest run_id for an automation or script."""
"""Find newest run_id for an script or automation."""
for trace in reversed(traces):
if trace["domain"] == trace_type and trace["item_id"] == item_id:
return trace["run_id"]
@ -18,7 +18,7 @@ def _find_run_id(traces, trace_type, item_id):
def _find_traces(traces, trace_type, item_id):
"""Find traces for an automation or script."""
"""Find traces for an script or automation."""
return [
trace
for trace in traces
@ -30,7 +30,7 @@ def _find_traces(traces, trace_type, item_id):
"domain, prefix", [("automation", "action"), ("script", "sequence")]
)
async def test_get_trace(hass, hass_ws_client, domain, prefix):
"""Test tracing an automation or script."""
"""Test tracing an script or automation."""
id = 1
def next_id():
@ -92,7 +92,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
run_id = _find_run_id(response["result"], domain, "sun")
@ -140,7 +140,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
run_id = _find_run_id(response["result"], domain, "moon")
@ -193,7 +193,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
run_id = _find_run_id(response["result"], "automation", "moon")
@ -233,7 +233,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
run_id = _find_run_id(response["result"], "automation", "moon")
@ -280,7 +280,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
@pytest.mark.parametrize("domain", ["automation", "script"])
async def test_trace_overflow(hass, hass_ws_client, domain):
"""Test the number of stored traces per automation or script is limited."""
"""Test the number of stored traces per script or automation is limited."""
id = 1
def next_id():
@ -313,7 +313,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert response["result"] == []
@ -328,7 +328,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert len(_find_traces(response["result"], domain, "moon")) == 1
@ -343,7 +343,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
await hass.services.async_call("script", "moon")
await hass.async_block_till_done()
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon")
@ -358,7 +358,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
"domain, prefix", [("automation", "action"), ("script", "sequence")]
)
async def test_list_traces(hass, hass_ws_client, domain, prefix):
"""Test listing automation and script traces."""
"""Test listing script and automation traces."""
id = 1
def next_id():
@ -398,7 +398,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert response["result"] == []
@ -418,7 +418,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# Get trace
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert len(response["result"]) == 1
@ -461,7 +461,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done()
# Get trace
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert len(_find_traces(response["result"], domain, "moon")) == 3
@ -585,7 +585,7 @@ async def test_nested_traces(hass, hass_ws_client, domain, prefix):
"domain, prefix", [("automation", "action"), ("script", "sequence")]
)
async def test_breakpoints(hass, hass_ws_client, domain, prefix):
"""Test automation and script breakpoints."""
"""Test script and automation breakpoints."""
id = 1
def next_id():
@ -594,7 +594,9 @@ async def test_breakpoints(hass, hass_ws_client, domain, prefix):
return id
async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json()
assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1]
@ -770,7 +772,9 @@ async def test_breakpoints_2(hass, hass_ws_client, domain, prefix):
return id
async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json()
assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1]
@ -883,7 +887,9 @@ async def test_breakpoints_3(hass, hass_ws_client, domain, prefix):
return id
async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"})
await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json()
assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1]