Fix changed_variables in automation traces (#106665)

This commit is contained in:
Erik Montnemery 2023-12-30 08:34:21 +01:00 committed by Franck Nijhof
parent 494dd2ef07
commit 362e5ca09a
No known key found for this signature in database
GPG key ID: D62583BA8AB11CA3
3 changed files with 50 additions and 44 deletions

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar from contextvars import ContextVar
from copy import copy from copy import copy
@ -157,7 +157,12 @@ def action_trace_append(variables, path):
@asynccontextmanager @asynccontextmanager
async def trace_action(hass, script_run, stop, variables): async def trace_action(
hass: HomeAssistant,
script_run: _ScriptRun,
stop: asyncio.Event,
variables: dict[str, Any],
) -> AsyncGenerator[TraceElement, None]:
"""Trace action execution.""" """Trace action execution."""
path = trace_path_get() path = trace_path_get()
trace_element = action_trace_append(variables, path) trace_element = action_trace_append(variables, path)
@ -362,6 +367,8 @@ class _StopScript(_HaltScript):
class _ScriptRun: class _ScriptRun:
"""Manage Script sequence run.""" """Manage Script sequence run."""
_action: dict[str, Any]
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -376,7 +383,6 @@ class _ScriptRun:
self._context = context self._context = context
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self._step = -1 self._step = -1
self._action: dict[str, Any] | None = None
self._stop = asyncio.Event() self._stop = asyncio.Event()
self._stopped = asyncio.Event() self._stopped = asyncio.Event()
@ -446,11 +452,13 @@ class _ScriptRun:
return ScriptRunResult(response, self._variables) return ScriptRunResult(response, self._variables)
async def _async_step(self, log_exceptions): async def _async_step(self, log_exceptions: bool) -> None:
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False) continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
with trace_path(str(self._step)): with trace_path(str(self._step)):
async with trace_action(self._hass, self, self._stop, self._variables): async with trace_action(
self._hass, self, self._stop, self._variables
) as trace_element:
if self._stop.is_set(): if self._stop.is_set():
return return
@ -466,6 +474,7 @@ class _ScriptRun:
try: try:
handler = f"_async_{action}_step" handler = f"_async_{action}_step"
await getattr(self, handler)() await getattr(self, handler)()
trace_element.update_variables(self._variables)
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
self._handle_exception( self._handle_exception(
ex, continue_on_error, self._log_exceptions or log_exceptions ex, continue_on_error, self._log_exceptions or log_exceptions

View file

@ -21,6 +21,7 @@ class TraceElement:
"_child_key", "_child_key",
"_child_run_id", "_child_run_id",
"_error", "_error",
"_last_variables",
"path", "path",
"_result", "_result",
"reuse_by_child", "reuse_by_child",
@ -38,16 +39,8 @@ class TraceElement:
self.reuse_by_child = False self.reuse_by_child = False
self._timestamp = dt_util.utcnow() self._timestamp = dt_util.utcnow()
if variables is None: self._last_variables = variables_cv.get() or {}
variables = {} self.update_variables(variables)
last_variables = variables_cv.get() or {}
variables_cv.set(dict(variables))
changed_variables = {
key: value
for key, value in variables.items()
if key not in last_variables or last_variables[key] != value
}
self._variables = changed_variables
def __repr__(self) -> str: def __repr__(self) -> str:
"""Container for trace data.""" """Container for trace data."""
@ -71,6 +64,19 @@ class TraceElement:
old_result = self._result or {} old_result = self._result or {}
self._result = {**old_result, **kwargs} self._result = {**old_result, **kwargs}
def update_variables(self, variables: TemplateVarsType) -> None:
"""Update variables."""
if variables is None:
variables = {}
last_variables = self._last_variables
variables_cv.set(dict(variables))
changed_variables = {
key: value
for key, value in variables.items()
if key not in last_variables or last_variables[key] != value
}
self._variables = changed_variables
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this TraceElement.""" """Return dictionary version of this TraceElement."""
result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp} result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp}

View file

@ -386,7 +386,10 @@ async def test_calling_service_response_data(
"target": {}, "target": {},
}, },
"running_script": False, "running_script": False,
} },
"variables": {
"my_response": {"data": "value-12345"},
},
} }
], ],
"1": [ "1": [
@ -399,10 +402,7 @@ async def test_calling_service_response_data(
"target": {}, "target": {},
}, },
"running_script": False, "running_script": False,
}, }
"variables": {
"my_response": {"data": "value-12345"},
},
} }
], ],
} }
@ -1163,13 +1163,13 @@ async def test_wait_template_not_schedule(hass: HomeAssistant) -> None:
assert_action_trace( assert_action_trace(
{ {
"0": [{"result": {"event": "test_event", "event_data": {}}}], "0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {"wait": {"completed": True, "remaining": None}}}], "1": [
"2": [
{ {
"result": {"event": "test_event", "event_data": {}}, "result": {"wait": {"completed": True, "remaining": None}},
"variables": {"wait": {"completed": True, "remaining": None}}, "variables": {"wait": {"completed": True, "remaining": None}},
} }
], ],
"2": [{"result": {"event": "test_event", "event_data": {}}}],
} }
) )
@ -1230,13 +1230,13 @@ async def test_wait_timeout(
else: else:
variable_wait = {"wait": {"trigger": None, "remaining": 0.0}} variable_wait = {"wait": {"trigger": None, "remaining": 0.0}}
expected_trace = { expected_trace = {
"0": [{"result": variable_wait}], "0": [
"1": [
{ {
"result": {"event": "test_event", "event_data": {}}, "result": variable_wait,
"variables": variable_wait, "variables": variable_wait,
} }
], ],
"1": [{"result": {"event": "test_event", "event_data": {}}}],
} }
assert_action_trace(expected_trace) assert_action_trace(expected_trace)
@ -1291,19 +1291,14 @@ async def test_wait_continue_on_timeout(
else: else:
variable_wait = {"wait": {"trigger": None, "remaining": 0.0}} variable_wait = {"wait": {"trigger": None, "remaining": 0.0}}
expected_trace = { expected_trace = {
"0": [{"result": variable_wait}], "0": [{"result": variable_wait, "variables": variable_wait}],
} }
if continue_on_timeout is False: if continue_on_timeout is False:
expected_trace["0"][0]["result"]["timeout"] = True expected_trace["0"][0]["result"]["timeout"] = True
expected_trace["0"][0]["error_type"] = asyncio.TimeoutError expected_trace["0"][0]["error_type"] = asyncio.TimeoutError
expected_script_execution = "aborted" expected_script_execution = "aborted"
else: else:
expected_trace["1"] = [ expected_trace["1"] = [{"result": {"event": "test_event", "event_data": {}}}]
{
"result": {"event": "test_event", "event_data": {}},
"variables": variable_wait,
}
]
expected_script_execution = "finished" expected_script_execution = "finished"
assert_action_trace(expected_trace, expected_script_execution) assert_action_trace(expected_trace, expected_script_execution)
@ -3269,12 +3264,12 @@ async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -
"description": "state of switch.trigger", "description": "state of switch.trigger",
}, },
} }
} },
"variables": {"wait": {"remaining": None}},
} }
], ],
"0/parallel/1/sequence/0": [ "0/parallel/1/sequence/0": [
{ {
"variables": {},
"result": { "result": {
"event": "test_event", "event": "test_event",
"event_data": {"hello": "from action 2", "what": "world"}, "event_data": {"hello": "from action 2", "what": "world"},
@ -3283,7 +3278,6 @@ async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -
], ],
"0/parallel/0/sequence/1": [ "0/parallel/0/sequence/1": [
{ {
"variables": {"wait": {"remaining": None}},
"result": { "result": {
"event": "test_event", "event": "test_event",
"event_data": {"hello": "from action 1", "what": "world"}, "event_data": {"hello": "from action 1", "what": "world"},
@ -4462,7 +4456,7 @@ async def test_set_variable(
assert f"Executing step {alias}" in caplog.text assert f"Executing step {alias}" in caplog.text
expected_trace = { expected_trace = {
"0": [{}], "0": [{"variables": {"variable": "value"}}],
"1": [ "1": [
{ {
"result": { "result": {
@ -4474,7 +4468,6 @@ async def test_set_variable(
}, },
"running_script": False, "running_script": False,
}, },
"variables": {"variable": "value"},
} }
], ],
} }
@ -4504,7 +4497,7 @@ async def test_set_redefines_variable(
assert mock_calls[1].data["value"] == 2 assert mock_calls[1].data["value"] == 2
expected_trace = { expected_trace = {
"0": [{}], "0": [{"variables": {"variable": "1"}}],
"1": [ "1": [
{ {
"result": { "result": {
@ -4515,11 +4508,10 @@ async def test_set_redefines_variable(
"target": {}, "target": {},
}, },
"running_script": False, "running_script": False,
}, }
"variables": {"variable": "1"},
} }
], ],
"2": [{}], "2": [{"variables": {"variable": 2}}],
"3": [ "3": [
{ {
"result": { "result": {
@ -4530,8 +4522,7 @@ async def test_set_redefines_variable(
"target": {}, "target": {},
}, },
"running_script": False, "running_script": False,
}, }
"variables": {"variable": 2},
} }
], ],
} }