Fix changed_variables in automation traces (#106665)

This commit is contained in:
Erik Montnemery 2023-12-30 08:34:21 +01:00 committed by Franck Nijhof
parent 494dd2ef07
commit 362e5ca09a
No known key found for this signature in database
GPG key ID: D62583BA8AB11CA3
3 changed files with 50 additions and 44 deletions

View file

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Mapping, Sequence
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar
from copy import copy
@ -157,7 +157,12 @@ def action_trace_append(variables, path):
@asynccontextmanager
async def trace_action(hass, script_run, stop, variables):
async def trace_action(
hass: HomeAssistant,
script_run: _ScriptRun,
stop: asyncio.Event,
variables: dict[str, Any],
) -> AsyncGenerator[TraceElement, None]:
"""Trace action execution."""
path = trace_path_get()
trace_element = action_trace_append(variables, path)
@ -362,6 +367,8 @@ class _StopScript(_HaltScript):
class _ScriptRun:
"""Manage Script sequence run."""
_action: dict[str, Any]
def __init__(
self,
hass: HomeAssistant,
@ -376,7 +383,6 @@ class _ScriptRun:
self._context = context
self._log_exceptions = log_exceptions
self._step = -1
self._action: dict[str, Any] | None = None
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
@ -446,11 +452,13 @@ class _ScriptRun:
return ScriptRunResult(response, self._variables)
async def _async_step(self, log_exceptions):
async def _async_step(self, log_exceptions: bool) -> None:
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
with trace_path(str(self._step)):
async with trace_action(self._hass, self, self._stop, self._variables):
async with trace_action(
self._hass, self, self._stop, self._variables
) as trace_element:
if self._stop.is_set():
return
@ -466,6 +474,7 @@ class _ScriptRun:
try:
handler = f"_async_{action}_step"
await getattr(self, handler)()
trace_element.update_variables(self._variables)
except Exception as ex: # pylint: disable=broad-except
self._handle_exception(
ex, continue_on_error, self._log_exceptions or log_exceptions

View file

@ -21,6 +21,7 @@ class TraceElement:
"_child_key",
"_child_run_id",
"_error",
"_last_variables",
"path",
"_result",
"reuse_by_child",
@ -38,16 +39,8 @@ class TraceElement:
self.reuse_by_child = False
self._timestamp = dt_util.utcnow()
if variables is None:
variables = {}
last_variables = variables_cv.get() or {}
variables_cv.set(dict(variables))
changed_variables = {
key: value
for key, value in variables.items()
if key not in last_variables or last_variables[key] != value
}
self._variables = changed_variables
self._last_variables = variables_cv.get() or {}
self.update_variables(variables)
def __repr__(self) -> str:
"""Container for trace data."""
@ -71,6 +64,19 @@ class TraceElement:
old_result = self._result or {}
self._result = {**old_result, **kwargs}
def update_variables(self, variables: TemplateVarsType) -> None:
"""Update variables."""
if variables is None:
variables = {}
last_variables = self._last_variables
variables_cv.set(dict(variables))
changed_variables = {
key: value
for key, value in variables.items()
if key not in last_variables or last_variables[key] != value
}
self._variables = changed_variables
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this TraceElement."""
result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp}

View file

@ -386,7 +386,10 @@ async def test_calling_service_response_data(
"target": {},
},
"running_script": False,
}
},
"variables": {
"my_response": {"data": "value-12345"},
},
}
],
"1": [
@ -399,10 +402,7 @@ async def test_calling_service_response_data(
"target": {},
},
"running_script": False,
},
"variables": {
"my_response": {"data": "value-12345"},
},
}
}
],
}
@ -1163,13 +1163,13 @@ async def test_wait_template_not_schedule(hass: HomeAssistant) -> None:
assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {"wait": {"completed": True, "remaining": None}}}],
"2": [
"1": [
{
"result": {"event": "test_event", "event_data": {}},
"result": {"wait": {"completed": True, "remaining": None}},
"variables": {"wait": {"completed": True, "remaining": None}},
}
],
"2": [{"result": {"event": "test_event", "event_data": {}}}],
}
)
@ -1230,13 +1230,13 @@ async def test_wait_timeout(
else:
variable_wait = {"wait": {"trigger": None, "remaining": 0.0}}
expected_trace = {
"0": [{"result": variable_wait}],
"1": [
"0": [
{
"result": {"event": "test_event", "event_data": {}},
"result": variable_wait,
"variables": variable_wait,
}
],
"1": [{"result": {"event": "test_event", "event_data": {}}}],
}
assert_action_trace(expected_trace)
@ -1291,19 +1291,14 @@ async def test_wait_continue_on_timeout(
else:
variable_wait = {"wait": {"trigger": None, "remaining": 0.0}}
expected_trace = {
"0": [{"result": variable_wait}],
"0": [{"result": variable_wait, "variables": variable_wait}],
}
if continue_on_timeout is False:
expected_trace["0"][0]["result"]["timeout"] = True
expected_trace["0"][0]["error_type"] = asyncio.TimeoutError
expected_script_execution = "aborted"
else:
expected_trace["1"] = [
{
"result": {"event": "test_event", "event_data": {}},
"variables": variable_wait,
}
]
expected_trace["1"] = [{"result": {"event": "test_event", "event_data": {}}}]
expected_script_execution = "finished"
assert_action_trace(expected_trace, expected_script_execution)
@ -3269,12 +3264,12 @@ async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -
"description": "state of switch.trigger",
},
}
}
},
"variables": {"wait": {"remaining": None}},
}
],
"0/parallel/1/sequence/0": [
{
"variables": {},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 2", "what": "world"},
@ -3283,7 +3278,6 @@ async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -
],
"0/parallel/0/sequence/1": [
{
"variables": {"wait": {"remaining": None}},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 1", "what": "world"},
@ -4462,7 +4456,7 @@ async def test_set_variable(
assert f"Executing step {alias}" in caplog.text
expected_trace = {
"0": [{}],
"0": [{"variables": {"variable": "value"}}],
"1": [
{
"result": {
@ -4474,7 +4468,6 @@ async def test_set_variable(
},
"running_script": False,
},
"variables": {"variable": "value"},
}
],
}
@ -4504,7 +4497,7 @@ async def test_set_redefines_variable(
assert mock_calls[1].data["value"] == 2
expected_trace = {
"0": [{}],
"0": [{"variables": {"variable": "1"}}],
"1": [
{
"result": {
@ -4515,11 +4508,10 @@ async def test_set_redefines_variable(
"target": {},
},
"running_script": False,
},
"variables": {"variable": "1"},
}
}
],
"2": [{}],
"2": [{"variables": {"variable": 2}}],
"3": [
{
"result": {
@ -4530,8 +4522,7 @@ async def test_set_redefines_variable(
"target": {},
},
"running_script": False,
},
"variables": {"variable": 2},
}
}
],
}