Make number of stored traces configurable (#49728)

This commit is contained in:
Erik Montnemery 2021-04-27 19:27:12 +02:00 committed by GitHub
parent b10534359b
commit ce64690817
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 104 additions and 23 deletions

View file

@ -74,6 +74,7 @@ from .config import PLATFORM_SCHEMA # noqa: F401
from .const import ( from .const import (
CONF_ACTION, CONF_ACTION,
CONF_INITIAL_STATE, CONF_INITIAL_STATE,
CONF_TRACE,
CONF_TRIGGER, CONF_TRIGGER,
CONF_TRIGGER_VARIABLES, CONF_TRIGGER_VARIABLES,
DEFAULT_INITIAL_STATE, DEFAULT_INITIAL_STATE,
@ -274,6 +275,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
trigger_variables, trigger_variables,
raw_config, raw_config,
blueprint_inputs, blueprint_inputs,
trace_config,
): ):
"""Initialize an automation entity.""" """Initialize an automation entity."""
self._id = automation_id self._id = automation_id
@ -292,6 +294,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._trigger_variables: ScriptVariables = trigger_variables self._trigger_variables: ScriptVariables = trigger_variables
self._raw_config = raw_config self._raw_config = raw_config
self._blueprint_inputs = blueprint_inputs self._blueprint_inputs = blueprint_inputs
self._trace_config = trace_config
@property @property
def name(self): def name(self):
@ -444,6 +447,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._raw_config, self._raw_config,
self._blueprint_inputs, self._blueprint_inputs,
trigger_context, trigger_context,
self._trace_config,
) as automation_trace: ) as automation_trace:
if self._variables: if self._variables:
try: try:
@ -682,6 +686,7 @@ async def _async_process_config(
config_block.get(CONF_TRIGGER_VARIABLES), config_block.get(CONF_TRIGGER_VARIABLES),
raw_config, raw_config,
raw_blueprint_inputs, raw_blueprint_inputs,
config_block[CONF_TRACE],
) )
entities.append(entity) entities.append(entity)

View file

@ -8,6 +8,7 @@ from homeassistant.components import blueprint
from homeassistant.components.device_automation.exceptions import ( from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig, InvalidDeviceAutomationConfig,
) )
from homeassistant.components.trace import TRACE_CONFIG_SCHEMA
from homeassistant.config import async_log_exception, config_without_domain from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.const import ( from homeassistant.const import (
CONF_ALIAS, CONF_ALIAS,
@ -26,6 +27,7 @@ from .const import (
CONF_ACTION, CONF_ACTION,
CONF_HIDE_ENTITY, CONF_HIDE_ENTITY,
CONF_INITIAL_STATE, CONF_INITIAL_STATE,
CONF_TRACE,
CONF_TRIGGER, CONF_TRIGGER,
CONF_TRIGGER_VARIABLES, CONF_TRIGGER_VARIABLES,
DOMAIN, DOMAIN,
@ -45,6 +47,7 @@ PLATFORM_SCHEMA = vol.All(
CONF_ID: str, CONF_ID: str,
CONF_ALIAS: cv.string, CONF_ALIAS: cv.string,
vol.Optional(CONF_DESCRIPTION): cv.string, vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_TRACE, default={}): TRACE_CONFIG_SCHEMA,
vol.Optional(CONF_INITIAL_STATE): cv.boolean, vol.Optional(CONF_INITIAL_STATE): cv.boolean,
vol.Optional(CONF_HIDE_ENTITY): cv.boolean, vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,

View file

@ -12,6 +12,7 @@ CONF_CONDITION_TYPE = "condition_type"
CONF_INITIAL_STATE = "initial_state" CONF_INITIAL_STATE = "initial_state"
CONF_BLUEPRINT = "blueprint" CONF_BLUEPRINT = "blueprint"
CONF_INPUT = "input" CONF_INPUT = "input"
CONF_TRACE = "trace"
DEFAULT_INITIAL_STATE = True DEFAULT_INITIAL_STATE = True

View file

@ -5,6 +5,7 @@ from contextlib import contextmanager
from typing import Any from typing import Any
from homeassistant.components.trace import ActionTrace, async_store_trace from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.components.trace.const import CONF_STORED_TRACES
from homeassistant.core import Context from homeassistant.core import Context
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
@ -38,10 +39,12 @@ class AutomationTrace(ActionTrace):
@contextmanager @contextmanager
def trace_automation(hass, automation_id, config, blueprint_inputs, context): def trace_automation(
hass, automation_id, config, blueprint_inputs, context, trace_config
):
"""Trace action execution of automation with automation_id.""" """Trace action execution of automation with automation_id."""
trace = AutomationTrace(automation_id, config, blueprint_inputs, context) trace = AutomationTrace(automation_id, config, blueprint_inputs, context)
async_store_trace(hass, trace) async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])
try: try:
yield trace yield trace

View file

@ -6,6 +6,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.trace import TRACE_CONFIG_SCHEMA
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_MODE, ATTR_MODE,
@ -58,6 +59,7 @@ CONF_ADVANCED = "advanced"
CONF_EXAMPLE = "example" CONF_EXAMPLE = "example"
CONF_FIELDS = "fields" CONF_FIELDS = "fields"
CONF_REQUIRED = "required" CONF_REQUIRED = "required"
CONF_TRACE = "trace"
ENTITY_ID_FORMAT = DOMAIN + ".{}" ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -67,6 +69,7 @@ EVENT_SCRIPT_STARTED = "script_started"
SCRIPT_ENTRY_SCHEMA = make_script_schema( SCRIPT_ENTRY_SCHEMA = make_script_schema(
{ {
vol.Optional(CONF_ALIAS): cv.string, vol.Optional(CONF_ALIAS): cv.string,
vol.Optional(CONF_TRACE, default={}): TRACE_CONFIG_SCHEMA,
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string, vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
@ -319,6 +322,7 @@ class ScriptEntity(ToggleEntity):
) )
self._changed = asyncio.Event() self._changed = asyncio.Event()
self._raw_config = raw_config self._raw_config = raw_config
self._trace_config = cfg[CONF_TRACE]
@property @property
def should_poll(self): def should_poll(self):
@ -384,7 +388,7 @@ class ScriptEntity(ToggleEntity):
async def _async_run(self, variables, context): async def _async_run(self, variables, context):
with trace_script( with trace_script(
self.hass, self.object_id, self._raw_config, context self.hass, self.object_id, self._raw_config, context, self._trace_config
) as script_trace: ) as script_trace:
# Prepare tracing the execution of the script's sequence # Prepare tracing the execution of the script's sequence
script_trace.set_trace(trace_get()) script_trace.set_trace(trace_get())

View file

@ -5,6 +5,7 @@ from contextlib import contextmanager
from typing import Any from typing import Any
from homeassistant.components.trace import ActionTrace, async_store_trace from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.components.trace.const import CONF_STORED_TRACES
from homeassistant.core import Context from homeassistant.core import Context
@ -23,10 +24,10 @@ class ScriptTrace(ActionTrace):
@contextmanager @contextmanager
def trace_script(hass, item_id, config, context): def trace_script(hass, item_id, config, context, trace_config):
"""Trace execution of a script.""" """Trace execution of a script."""
trace = ScriptTrace(item_id, config, context) trace = ScriptTrace(item_id, config, context)
async_store_trace(hass, trace) async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])
try: try:
yield trace yield trace

View file

@ -6,7 +6,10 @@ import datetime as dt
from itertools import count from itertools import count
from typing import Any from typing import Any
import voluptuous as vol
from homeassistant.core import Context from homeassistant.core import Context
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.trace import ( from homeassistant.helpers.trace import (
TraceElement, TraceElement,
script_execution_get, script_execution_get,
@ -17,11 +20,15 @@ from homeassistant.helpers.trace import (
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from . import websocket_api from . import websocket_api
from .const import DATA_TRACE, STORED_TRACES from .const import CONF_STORED_TRACES, DATA_TRACE, DEFAULT_STORED_TRACES
from .utils import LimitedSizeDict from .utils import LimitedSizeDict
DOMAIN = "trace" DOMAIN = "trace"
TRACE_CONFIG_SCHEMA = {
vol.Optional(CONF_STORED_TRACES, default=DEFAULT_STORED_TRACES): cv.positive_int
}
async def async_setup(hass, config): async def async_setup(hass, config):
"""Initialize the trace integration.""" """Initialize the trace integration."""
@ -30,18 +37,20 @@ async def async_setup(hass, config):
return True return True
def async_store_trace(hass, trace): def async_store_trace(hass, trace, stored_traces):
"""Store a trace if its item_id is valid.""" """Store a trace if its item_id is valid."""
key = trace.key key = trace.key
if key[1]: if key[1]:
traces = hass.data[DATA_TRACE] traces = hass.data[DATA_TRACE]
if key not in traces: if key not in traces:
traces[key] = LimitedSizeDict(size_limit=STORED_TRACES) traces[key] = LimitedSizeDict(size_limit=stored_traces)
else:
traces[key].size_limit = stored_traces
traces[key][trace.run_id] = trace traces[key][trace.run_id] = trace
class ActionTrace: class ActionTrace:
"""Base container for an script or automation trace.""" """Base container for a script or automation trace."""
_run_ids = count(0) _run_ids = count(0)

View file

@ -1,4 +1,5 @@
"""Shared constants for script and automation tracing and debugging.""" """Shared constants for script and automation tracing and debugging."""
CONF_STORED_TRACES = "stored_traces"
DATA_TRACE = "trace" DATA_TRACE = "trace"
STORED_TRACES = 5 # Stored traces per script or automation DEFAULT_STORED_TRACES = 5 # Stored traces per script or automation

View file

@ -57,7 +57,7 @@ def async_setup(hass: HomeAssistant) -> None:
} }
) )
def websocket_trace_get(hass, connection, msg): def websocket_trace_get(hass, connection, msg):
"""Get an script or automation trace.""" """Get a script or automation trace."""
key = (msg["domain"], msg["item_id"]) key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"] run_id = msg["run_id"]
@ -77,7 +77,7 @@ def websocket_trace_get(hass, connection, msg):
def get_debug_traces(hass, key): def get_debug_traces(hass, key):
"""Return a serializable list of debug traces for an script or automation.""" """Return a serializable list of debug traces for a script or automation."""
traces = [] traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values(): for trace in hass.data[DATA_TRACE].get(key, {}).values():

View file

@ -4,7 +4,7 @@ import asyncio
import pytest import pytest
from homeassistant.bootstrap import async_setup_component from homeassistant.bootstrap import async_setup_component
from homeassistant.components.trace.const import STORED_TRACES from homeassistant.components.trace.const import DEFAULT_STORED_TRACES
from homeassistant.core import Context, callback from homeassistant.core import Context, callback
from homeassistant.helpers.typing import UNDEFINED from homeassistant.helpers.typing import UNDEFINED
@ -12,7 +12,7 @@ from tests.common import assert_lists_same
def _find_run_id(traces, trace_type, item_id): def _find_run_id(traces, trace_type, item_id):
"""Find newest run_id for an script or automation.""" """Find newest run_id for a script or automation."""
for trace in reversed(traces): for trace in reversed(traces):
if trace["domain"] == trace_type and trace["item_id"] == item_id: if trace["domain"] == trace_type and trace["item_id"] == item_id:
return trace["run_id"] return trace["run_id"]
@ -21,7 +21,7 @@ def _find_run_id(traces, trace_type, item_id):
def _find_traces(traces, trace_type, item_id): def _find_traces(traces, trace_type, item_id):
"""Find traces for an script or automation.""" """Find traces for a script or automation."""
return [ return [
trace trace
for trace in traces for trace in traces
@ -29,7 +29,9 @@ def _find_traces(traces, trace_type, item_id):
] ]
async def _setup_automation_or_script(hass, domain, configs, script_config=None): async def _setup_automation_or_script(
hass, domain, configs, script_config=None, stored_traces=None
):
"""Set up automations or scripts from automation config.""" """Set up automations or scripts from automation config."""
if domain == "script": if domain == "script":
configs = {config["id"]: {"sequence": config["action"]} for config in configs} configs = {config["id"]: {"sequence": config["action"]} for config in configs}
@ -42,6 +44,16 @@ async def _setup_automation_or_script(hass, domain, configs, script_config=None)
else: else:
configs = {**configs, **script_config} configs = {**configs, **script_config}
if stored_traces is not None:
if domain == "script":
for config in configs.values():
config["trace"] = {}
config["trace"]["stored_traces"] = stored_traces
else:
for config in configs:
config["trace"] = {}
config["trace"]["stored_traces"] = stored_traces
assert await async_setup_component(hass, domain, {domain: configs}) assert await async_setup_component(hass, domain, {domain: configs})
@ -97,7 +109,7 @@ async def test_get_trace(
context_key, context_key,
condition_results, condition_results,
): ):
"""Test tracing an script or automation.""" """Test tracing a script or automation."""
id = 1 id = 1
def next_id(): def next_id():
@ -347,8 +359,11 @@ async def test_get_invalid_trace(hass, hass_ws_client, domain):
assert response["error"]["code"] == "not_found" assert response["error"]["code"] == "not_found"
@pytest.mark.parametrize("domain", ["automation", "script"]) @pytest.mark.parametrize(
async def test_trace_overflow(hass, hass_ws_client, domain): "domain,stored_traces",
[("automation", None), ("automation", 10), ("script", None), ("script", 10)],
)
async def test_trace_overflow(hass, hass_ws_client, domain, stored_traces):
"""Test the number of stored traces per script or automation is limited.""" """Test the number of stored traces per script or automation is limited."""
id = 1 id = 1
@ -367,7 +382,9 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
"trigger": {"platform": "event", "event_type": "test_event2"}, "trigger": {"platform": "event", "event_type": "test_event2"},
"action": {"event": "another_event"}, "action": {"event": "another_event"},
} }
await _setup_automation_or_script(hass, domain, [sun_config, moon_config]) await _setup_automation_or_script(
hass, domain, [sun_config, moon_config], stored_traces=stored_traces
)
client = await hass_ws_client() client = await hass_ws_client()
@ -390,7 +407,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
assert len(_find_traces(response["result"], domain, "sun")) == 1 assert len(_find_traces(response["result"], domain, "sun")) == 1
# Trigger "moon" enough times to overflow the max number of stored traces # Trigger "moon" enough times to overflow the max number of stored traces
for _ in range(STORED_TRACES): for _ in range(stored_traces or DEFAULT_STORED_TRACES):
await _run_automation_or_script(hass, domain, moon_config, "test_event2") await _run_automation_or_script(hass, domain, moon_config, "test_event2")
await hass.async_block_till_done() await hass.async_block_till_done()
@ -398,13 +415,50 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon") moon_traces = _find_traces(response["result"], domain, "moon")
assert len(moon_traces) == STORED_TRACES assert len(moon_traces) == stored_traces or DEFAULT_STORED_TRACES
assert moon_traces[0] assert moon_traces[0]
assert int(moon_traces[0]["run_id"]) == int(moon_run_id) + 1 assert int(moon_traces[0]["run_id"]) == int(moon_run_id) + 1
assert int(moon_traces[-1]["run_id"]) == int(moon_run_id) + STORED_TRACES assert int(moon_traces[-1]["run_id"]) == int(moon_run_id) + (
stored_traces or DEFAULT_STORED_TRACES
)
assert len(_find_traces(response["result"], domain, "sun")) == 1 assert len(_find_traces(response["result"], domain, "sun")) == 1
@pytest.mark.parametrize("domain", ["automation", "script"])
async def test_trace_no_traces(hass, hass_ws_client, domain):
"""Test the storing traces for a script or automation can be disabled."""
id = 1
def next_id():
nonlocal id
id += 1
return id
sun_config = {
"id": "sun",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {"event": "some_event"},
}
await _setup_automation_or_script(hass, domain, [sun_config], stored_traces=0)
client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert response["result"] == []
# Trigger "sun" automation / script once
await _run_automation_or_script(hass, domain, sun_config, "test_event")
await hass.async_block_till_done()
# List traces
await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json()
assert response["success"]
assert len(_find_traces(response["result"], domain, "sun")) == 0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"domain, prefix, trigger, last_step, script_execution", "domain, prefix, trigger, last_step, script_execution",
[ [