Make number of stored traces configurable (#49728)

This commit is contained in:
Erik Montnemery 2021-04-27 19:27:12 +02:00 committed by GitHub
parent b10534359b
commit ce64690817
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 104 additions and 23 deletions

View file

@ -74,6 +74,7 @@ from .config import PLATFORM_SCHEMA # noqa: F401
from .const import (
CONF_ACTION,
CONF_INITIAL_STATE,
CONF_TRACE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DEFAULT_INITIAL_STATE,
@ -274,6 +275,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
trigger_variables,
raw_config,
blueprint_inputs,
trace_config,
):
"""Initialize an automation entity."""
self._id = automation_id
@ -292,6 +294,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._trigger_variables: ScriptVariables = trigger_variables
self._raw_config = raw_config
self._blueprint_inputs = blueprint_inputs
self._trace_config = trace_config
@property
def name(self):
@ -444,6 +447,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._raw_config,
self._blueprint_inputs,
trigger_context,
self._trace_config,
) as automation_trace:
if self._variables:
try:
@ -682,6 +686,7 @@ async def _async_process_config(
config_block.get(CONF_TRIGGER_VARIABLES),
raw_config,
raw_blueprint_inputs,
config_block[CONF_TRACE],
)
entities.append(entity)

View file

@ -8,6 +8,7 @@ from homeassistant.components import blueprint
from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig,
)
from homeassistant.components.trace import TRACE_CONFIG_SCHEMA
from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.const import (
CONF_ALIAS,
@ -26,6 +27,7 @@ from .const import (
CONF_ACTION,
CONF_HIDE_ENTITY,
CONF_INITIAL_STATE,
CONF_TRACE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DOMAIN,
@ -45,6 +47,7 @@ PLATFORM_SCHEMA = vol.All(
CONF_ID: str,
CONF_ALIAS: cv.string,
vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_TRACE, default={}): TRACE_CONFIG_SCHEMA,
vol.Optional(CONF_INITIAL_STATE): cv.boolean,
vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,

View file

@ -12,6 +12,7 @@ CONF_CONDITION_TYPE = "condition_type"
CONF_INITIAL_STATE = "initial_state"
CONF_BLUEPRINT = "blueprint"
CONF_INPUT = "input"
CONF_TRACE = "trace"
DEFAULT_INITIAL_STATE = True

View file

@ -5,6 +5,7 @@ from contextlib import contextmanager
from typing import Any
from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.components.trace.const import CONF_STORED_TRACES
from homeassistant.core import Context
# mypy: allow-untyped-calls, allow-untyped-defs
@ -38,10 +39,12 @@ class AutomationTrace(ActionTrace):
@contextmanager
def trace_automation(hass, automation_id, config, blueprint_inputs, context):
def trace_automation(
hass, automation_id, config, blueprint_inputs, context, trace_config
):
"""Trace action execution of automation with automation_id."""
trace = AutomationTrace(automation_id, config, blueprint_inputs, context)
async_store_trace(hass, trace)
async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])
try:
yield trace

View file

@ -6,6 +6,7 @@ import logging
import voluptuous as vol
from homeassistant.components.trace import TRACE_CONFIG_SCHEMA
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_MODE,
@ -58,6 +59,7 @@ CONF_ADVANCED = "advanced"
CONF_EXAMPLE = "example"
CONF_FIELDS = "fields"
CONF_REQUIRED = "required"
CONF_TRACE = "trace"
ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -67,6 +69,7 @@ EVENT_SCRIPT_STARTED = "script_started"
SCRIPT_ENTRY_SCHEMA = make_script_schema(
{
vol.Optional(CONF_ALIAS): cv.string,
vol.Optional(CONF_TRACE, default={}): TRACE_CONFIG_SCHEMA,
vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
@ -319,6 +322,7 @@ class ScriptEntity(ToggleEntity):
)
self._changed = asyncio.Event()
self._raw_config = raw_config
self._trace_config = cfg[CONF_TRACE]
@property
def should_poll(self):
@ -384,7 +388,7 @@ class ScriptEntity(ToggleEntity):
async def _async_run(self, variables, context):
with trace_script(
self.hass, self.object_id, self._raw_config, context
self.hass, self.object_id, self._raw_config, context, self._trace_config
) as script_trace:
# Prepare tracing the execution of the script's sequence
script_trace.set_trace(trace_get())

View file

@ -5,6 +5,7 @@ from contextlib import contextmanager
from typing import Any
from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.components.trace.const import CONF_STORED_TRACES
from homeassistant.core import Context
@ -23,10 +24,10 @@ class ScriptTrace(ActionTrace):
@contextmanager
def trace_script(hass, item_id, config, context):
def trace_script(hass, item_id, config, context, trace_config):
"""Trace execution of a script."""
trace = ScriptTrace(item_id, config, context)
async_store_trace(hass, trace)
async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])
try:
yield trace

View file

@ -6,7 +6,10 @@ import datetime as dt
from itertools import count
from typing import Any
import voluptuous as vol
from homeassistant.core import Context
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.trace import (
TraceElement,
script_execution_get,
@ -17,11 +20,15 @@ from homeassistant.helpers.trace import (
import homeassistant.util.dt as dt_util
from . import websocket_api
from .const import DATA_TRACE, STORED_TRACES
from .const import CONF_STORED_TRACES, DATA_TRACE, DEFAULT_STORED_TRACES
from .utils import LimitedSizeDict
DOMAIN = "trace"
TRACE_CONFIG_SCHEMA = {
vol.Optional(CONF_STORED_TRACES, default=DEFAULT_STORED_TRACES): cv.positive_int
}
async def async_setup(hass, config):
"""Initialize the trace integration."""
@ -30,18 +37,20 @@ async def async_setup(hass, config):
return True
def async_store_trace(hass, trace):
def async_store_trace(hass, trace, stored_traces):
"""Store a trace if its item_id is valid."""
key = trace.key
if key[1]:
traces = hass.data[DATA_TRACE]
if key not in traces:
traces[key] = LimitedSizeDict(size_limit=STORED_TRACES)
traces[key] = LimitedSizeDict(size_limit=stored_traces)
else:
traces[key].size_limit = stored_traces
traces[key][trace.run_id] = trace
class ActionTrace:
"""Base container for an script or automation trace."""
"""Base container for a script or automation trace."""
_run_ids = count(0)

View file

@ -1,4 +1,5 @@
"""Shared constants for script and automation tracing and debugging."""
CONF_STORED_TRACES = "stored_traces"
DATA_TRACE = "trace"
STORED_TRACES = 5 # Stored traces per script or automation
DEFAULT_STORED_TRACES = 5 # Stored traces per script or automation

View file

@ -57,7 +57,7 @@ def async_setup(hass: HomeAssistant) -> None:
}
)
def websocket_trace_get(hass, connection, msg):
"""Get an script or automation trace."""
"""Get a script or automation trace."""
key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"]
@ -77,7 +77,7 @@ def websocket_trace_get(hass, connection, msg):
def get_debug_traces(hass, key):
"""Return a serializable list of debug traces for an script or automation."""
"""Return a serializable list of debug traces for a script or automation."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():

View file

@ -4,7 +4,7 @@ import asyncio
import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components.trace.const import STORED_TRACES
from homeassistant.components.trace.const import DEFAULT_STORED_TRACES
from homeassistant.core import Context, callback
from homeassistant.helpers.typing import UNDEFINED
@ -12,7 +12,7 @@ from tests.common import assert_lists_same
def _find_run_id(traces, trace_type, item_id):
"""Find newest run_id for an script or automation."""
"""Find newest run_id for a script or automation."""
for trace in reversed(traces):
if trace["domain"] == trace_type and trace["item_id"] == item_id:
return trace["run_id"]
@ -21,7 +21,7 @@ def _find_run_id(traces, trace_type, item_id):
def _find_traces(traces, trace_type, item_id):
"""Find traces for an script or automation."""
"""Find traces for a script or automation."""
return [
trace
for trace in traces
@ -29,7 +29,9 @@ def _find_traces(traces, trace_type, item_id):
]
async def _setup_automation_or_script(hass, domain, configs, script_config=None):
async def _setup_automation_or_script(
hass, domain, configs, script_config=None, stored_traces=None
):
"""Set up automations or scripts from automation config."""
if domain == "script":
configs = {config["id"]: {"sequence": config["action"]} for config in configs}
@ -42,6 +44,16 @@ async def _setup_automation_or_script(hass, domain, configs, script_config=None)
else:
configs = {**configs, **script_config}
if stored_traces is not None:
if domain == "script":
for config in configs.values():
config["trace"] = {}
config["trace"]["stored_traces"] = stored_traces
else:
for config in configs:
config["trace"] = {}
config["trace"]["stored_traces"] = stored_traces
assert await async_setup_component(hass, domain, {domain: configs})
@ -97,7 +109,7 @@ async def test_get_trace(
context_key,
condition_results,
):
"""Test tracing an script or automation."""
"""Test tracing a script or automation."""
id = 1
def next_id():
@ -347,8 +359,11 @@ async def test_get_invalid_trace(hass, hass_ws_client, domain):
assert response["error"]["code"] == "not_found"
@pytest.mark.parametrize("domain", ["automation", "script"])
async def test_trace_overflow(hass, hass_ws_client, domain):
@pytest.mark.parametrize(
"domain,stored_traces",
[("automation", None), ("automation", 10), ("script", None), ("script", 10)],
)
async def test_trace_overflow(hass, hass_ws_client, domain, stored_traces):
"""Test the number of stored traces per script or automation is limited."""
id = 1
@ -367,7 +382,9 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
"trigger": {"platform": "event", "event_type": "test_event2"},
"action": {"event": "another_event"},
}
await _setup_automation_or_script(hass, domain, [sun_config, moon_config])
await _setup_automation_or_script(
hass, domain, [sun_config, moon_config], stored_traces=stored_traces
)
client = await hass_ws_client()
@ -390,7 +407,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
assert len(_find_traces(response["result"], domain, "sun")) == 1
# Trigger "moon" enough times to overflow the max number of stored traces
for _ in range(STORED_TRACES):
for _ in range(stored_traces or DEFAULT_STORED_TRACES):
await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done()
@ -398,13 +415,50 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
response = await client.receive_json()
assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon")
assert len(moon_traces) == STORED_TRACES
assert len(moon_traces) == stored_traces or DEFAULT_STORED_TRACES
assert moon_traces[0]
assert int(moon_traces[0]["run_id"]) == int(moon_run_id) + 1
assert int(moon_traces[-1]["run_id"]) == int(moon_run_id) + STORED_TRACES
assert int(moon_traces[-1]["run_id"]) == int(moon_run_id) + (
stored_traces or DEFAULT_STORED_TRACES
)
assert len(_find_traces(response["result"], domain, "sun")) == 1
@pytest.mark.parametrize("domain", ["automation", "script"])
async def test_trace_no_traces(hass, hass_ws_client, domain):
"""Test the storing traces for a script or automation can be disabled."""
id = 1
def next_id():
nonlocal id
id += 1
return id
sun_config = {
"id": "sun",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {"event": "some_event"},
}
await _setup_automation_or_script(hass, domain, [sun_config], stored_traces=0)
client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert response["result"] == []
# Trigger "sun" automation / script once
await _run_automation_or_script(hass, domain, sun_config, "test_event")
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert len(_find_traces(response["result"], domain, "sun")) == 0
@pytest.mark.parametrize(
"domain, prefix, trigger, last_step, script_execution",
[