Store automation and script traces (#56894)

* Store automation and script traces

* Pylint

* Deduplicate code

* Fix issues when no stored traces are available

* Store serialized data for restored traces

* Update WS API

* Update test

* Restore context

* Improve tests

* Add new test files

* Rename restore_traces to async_restore_traces

* Refactor trace.websocket_api

* Defer loading stored traces

* Lint

* Revert refactoring which is no longer needed

* Correct order when restoring traces

* Apply suggestion from code review

* Improve test coverage

* Apply suggestions from code review
This commit is contained in:
Erik Montnemery 2021-10-19 10:23:23 +02:00 committed by GitHub
parent 29c062fcc4
commit 961ee717ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1256 additions and 191 deletions

View file

@ -3,7 +3,7 @@ import json
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components import trace, websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import (
@ -24,8 +24,6 @@ from homeassistant.helpers.script import (
debug_stop,
)
from .const import DATA_TRACE
# mypy: allow-untyped-calls, allow-untyped-defs
TRACE_DOMAINS = ("automation", "script")
@ -46,7 +44,6 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_subscribe_breakpoint_events)
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
@ -56,37 +53,27 @@ def async_setup(hass: HomeAssistant) -> None:
vol.Required("run_id"): str,
}
)
def websocket_trace_get(hass, connection, msg):
@websocket_api.async_response
async def websocket_trace_get(hass, connection, msg):
"""Get a script or automation trace."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
run_id = msg["run_id"]
try:
trace = hass.data[DATA_TRACE][key][run_id]
requested_trace = await trace.async_get_trace(hass, key, run_id)
except KeyError:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "The trace could not be found"
)
return
message = websocket_api.messages.result_message(msg["id"], trace)
message = websocket_api.messages.result_message(msg["id"], requested_trace)
connection.send_message(
json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False)
)
def get_debug_traces(hass, key):
"""Return a serializable list of debug traces for a script or automation."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():
traces.append(trace.as_short_dict())
return traces
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
@ -95,23 +82,17 @@ def get_debug_traces(hass, key):
vol.Optional("item_id", "id"): str,
}
)
def websocket_trace_list(hass, connection, msg):
@websocket_api.async_response
async def websocket_trace_list(hass, connection, msg):
"""Summarize script and automation traces."""
domain = msg["domain"]
key = (domain, msg["item_id"]) if "item_id" in msg else None
wanted_domain = msg["domain"]
key = f"{msg['domain']}.{msg['item_id']}" if "item_id" in msg else None
if not key:
traces = []
for key in hass.data[DATA_TRACE]:
if key[0] == domain:
traces.extend(get_debug_traces(hass, key))
else:
traces = get_debug_traces(hass, key)
traces = await trace.async_list_traces(hass, wanted_domain, key)
connection.send_result(msg["id"], traces)
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
@ -120,20 +101,12 @@ def websocket_trace_list(hass, connection, msg):
vol.Inclusive("item_id", "id"): str,
}
)
def websocket_trace_contexts(hass, connection, msg):
@websocket_api.async_response
async def websocket_trace_contexts(hass, connection, msg):
"""Retrieve contexts we have traces for."""
key = (msg["domain"], msg["item_id"]) if "item_id" in msg else None
key = f"{msg['domain']}.{msg['item_id']}" if "item_id" in msg else None
if key is not None:
values = {key: hass.data[DATA_TRACE].get(key, {})}
else:
values = hass.data[DATA_TRACE]
contexts = {
trace.context.id: {"run_id": trace.run_id, "domain": key[0], "item_id": key[1]}
for key, traces in values.items()
for trace in traces.values()
}
contexts = await trace.async_list_contexts(hass, key)
connection.send_result(msg["id"], contexts)
@ -151,7 +124,7 @@ def websocket_trace_contexts(hass, connection, msg):
)
def websocket_breakpoint_set(hass, connection, msg):
"""Set breakpoint."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
node = msg["node"]
run_id = msg.get("run_id")
@ -178,7 +151,7 @@ def websocket_breakpoint_set(hass, connection, msg):
)
def websocket_breakpoint_clear(hass, connection, msg):
"""Clear breakpoint."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
node = msg["node"]
run_id = msg.get("run_id")
@ -194,7 +167,8 @@ def websocket_breakpoint_list(hass, connection, msg):
"""List breakpoints."""
breakpoints = breakpoint_list(hass)
for _breakpoint in breakpoints:
_breakpoint["domain"], _breakpoint["item_id"] = _breakpoint.pop("key")
key = _breakpoint.pop("key")
_breakpoint["domain"], _breakpoint["item_id"] = key.split(".", 1)
connection.send_result(msg["id"], breakpoints)
@ -210,12 +184,13 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg):
@callback
def breakpoint_hit(key, run_id, node):
"""Forward events to websocket."""
domain, item_id = key.split(".", 1)
connection.send_message(
websocket_api.event_message(
msg["id"],
{
"domain": key[0],
"item_id": key[1],
"domain": domain,
"item_id": item_id,
"run_id": run_id,
"node": node,
},
@ -254,7 +229,7 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg):
)
def websocket_debug_continue(hass, connection, msg):
"""Resume execution of halted script or automation."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
run_id = msg["run_id"]
result = debug_continue(hass, key, run_id)
@ -274,7 +249,7 @@ def websocket_debug_continue(hass, connection, msg):
)
def websocket_debug_step(hass, connection, msg):
"""Single step a halted script or automation."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
run_id = msg["run_id"]
result = debug_step(hass, key, run_id)
@ -294,7 +269,7 @@ def websocket_debug_step(hass, connection, msg):
)
def websocket_debug_stop(hass, connection, msg):
"""Stop a halted script or automation."""
key = (msg["domain"], msg["item_id"])
key = f"{msg['domain']}.{msg['item_id']}"
run_id = msg["run_id"]
result = debug_stop(hass, key, run_id)