Fix script repeat variable lifetime (#38124)
This commit is contained in:
parent
a7459b3126
commit
2f87da8aa9
2 changed files with 142 additions and 18 deletions
|
@ -140,7 +140,7 @@ class _ScriptRun:
|
||||||
) -> None:
|
) -> None:
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._script = script
|
self._script = script
|
||||||
self._variables = variables
|
self._variables = variables or {}
|
||||||
self._context = context
|
self._context = context
|
||||||
self._log_exceptions = log_exceptions
|
self._log_exceptions = log_exceptions
|
||||||
self._step = -1
|
self._step = -1
|
||||||
|
@ -431,22 +431,23 @@ class _ScriptRun:
|
||||||
|
|
||||||
async def _async_repeat_step(self):
|
async def _async_repeat_step(self):
|
||||||
"""Repeat a sequence."""
|
"""Repeat a sequence."""
|
||||||
|
|
||||||
description = self._action.get(CONF_ALIAS, "sequence")
|
description = self._action.get(CONF_ALIAS, "sequence")
|
||||||
repeat = self._action[CONF_REPEAT]
|
repeat = self._action[CONF_REPEAT]
|
||||||
|
|
||||||
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
|
saved_repeat_vars = self._variables.get("repeat")
|
||||||
|
|
||||||
|
def set_repeat_var(iteration, count=None):
|
||||||
|
repeat_vars = {"first": iteration == 1, "index": iteration}
|
||||||
|
if count:
|
||||||
|
repeat_vars["last"] = iteration == count
|
||||||
|
self._variables["repeat"] = repeat_vars
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
script = self._script._get_repeat_script(self._step)
|
||||||
|
|
||||||
|
async def async_run_sequence(iteration, extra_msg=""):
|
||||||
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
||||||
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
await self._async_run_script(script)
|
||||||
if extra_vars:
|
|
||||||
repeat_vars["repeat"].update(extra_vars)
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
await self._async_run_script(
|
|
||||||
self._script._get_repeat_script(self._step),
|
|
||||||
# Add repeat to variables. Override if it already exists in case of
|
|
||||||
# nested calls.
|
|
||||||
{**(self._variables or {}), **repeat_vars},
|
|
||||||
)
|
|
||||||
|
|
||||||
if CONF_COUNT in repeat:
|
if CONF_COUNT in repeat:
|
||||||
count = repeat[CONF_COUNT]
|
count = repeat[CONF_COUNT]
|
||||||
|
@ -461,10 +462,10 @@ class _ScriptRun:
|
||||||
level=logging.ERROR,
|
level=logging.ERROR,
|
||||||
)
|
)
|
||||||
raise _StopScript
|
raise _StopScript
|
||||||
|
extra_msg = f" of {count}"
|
||||||
for iteration in range(1, count + 1):
|
for iteration in range(1, count + 1):
|
||||||
await async_run_sequence(
|
set_repeat_var(iteration, count)
|
||||||
iteration, f" of {count}", {"last": iteration == count}
|
await async_run_sequence(iteration, extra_msg)
|
||||||
)
|
|
||||||
if self._stop.is_set():
|
if self._stop.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -473,6 +474,7 @@ class _ScriptRun:
|
||||||
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
|
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
|
||||||
]
|
]
|
||||||
for iteration in itertools.count(1):
|
for iteration in itertools.count(1):
|
||||||
|
set_repeat_var(iteration)
|
||||||
if self._stop.is_set() or not all(
|
if self._stop.is_set() or not all(
|
||||||
cond(self._hass, self._variables) for cond in conditions
|
cond(self._hass, self._variables) for cond in conditions
|
||||||
):
|
):
|
||||||
|
@ -484,12 +486,18 @@ class _ScriptRun:
|
||||||
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
|
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
|
||||||
]
|
]
|
||||||
for iteration in itertools.count(1):
|
for iteration in itertools.count(1):
|
||||||
|
set_repeat_var(iteration)
|
||||||
await async_run_sequence(iteration)
|
await async_run_sequence(iteration)
|
||||||
if self._stop.is_set() or all(
|
if self._stop.is_set() or all(
|
||||||
cond(self._hass, self._variables) for cond in conditions
|
cond(self._hass, self._variables) for cond in conditions
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if saved_repeat_vars:
|
||||||
|
self._variables["repeat"] = saved_repeat_vars
|
||||||
|
else:
|
||||||
|
del self._variables["repeat"]
|
||||||
|
|
||||||
async def _async_choose_step(self):
|
async def _async_choose_step(self):
|
||||||
"""Choose a sequence."""
|
"""Choose a sequence."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
@ -503,11 +511,11 @@ class _ScriptRun:
|
||||||
if choose_data["default"]:
|
if choose_data["default"]:
|
||||||
await self._async_run_script(choose_data["default"])
|
await self._async_run_script(choose_data["default"])
|
||||||
|
|
||||||
async def _async_run_script(self, script, variables=None):
|
async def _async_run_script(self, script):
|
||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
await self._async_run_long_action(
|
await self._async_run_long_action(
|
||||||
self._hass.async_create_task(
|
self._hass.async_create_task(
|
||||||
script.async_run(variables or self._variables, self._context)
|
script.async_run(self._variables, self._context)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -854,6 +854,122 @@ async def test_repeat_conditional(hass, condition):
|
||||||
assert event.data.get("index") == str(index + 1)
|
assert event.data.get("index") == str(index + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("condition", ["while", "until"])
|
||||||
|
async def test_repeat_var_in_condition(hass, condition):
|
||||||
|
"""Test repeat action w/ while option."""
|
||||||
|
event = "test_event"
|
||||||
|
events = async_capture_events(hass, event)
|
||||||
|
|
||||||
|
sequence = {"repeat": {"sequence": {"event": event}}}
|
||||||
|
if condition == "while":
|
||||||
|
sequence["repeat"]["while"] = {
|
||||||
|
"condition": "template",
|
||||||
|
"value_template": "{{ repeat.index <= 2 }}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
sequence["repeat"]["until"] = {
|
||||||
|
"condition": "template",
|
||||||
|
"value_template": "{{ repeat.index == 2 }}",
|
||||||
|
}
|
||||||
|
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence))
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"homeassistant.helpers.condition._LOGGER.error",
|
||||||
|
side_effect=AssertionError("Template Error"),
|
||||||
|
):
|
||||||
|
await script_obj.async_run()
|
||||||
|
|
||||||
|
assert len(events) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_repeat_nested(hass):
|
||||||
|
"""Test nested repeats."""
|
||||||
|
event = "test_event"
|
||||||
|
events = async_capture_events(hass, event)
|
||||||
|
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"repeat": "{{ None if repeat is not defined else repeat }}"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"repeat": {
|
||||||
|
"count": 2,
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"first": "{{ repeat.first }}",
|
||||||
|
"index": "{{ repeat.index }}",
|
||||||
|
"last": "{{ repeat.last }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"repeat": {
|
||||||
|
"count": 2,
|
||||||
|
"sequence": {
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"first": "{{ repeat.first }}",
|
||||||
|
"index": "{{ repeat.index }}",
|
||||||
|
"last": "{{ repeat.last }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"first": "{{ repeat.first }}",
|
||||||
|
"index": "{{ repeat.index }}",
|
||||||
|
"last": "{{ repeat.last }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"repeat": "{{ None if repeat is not defined else repeat }}"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
script_obj = script.Script(hass, sequence, "test script")
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"homeassistant.helpers.condition._LOGGER.error",
|
||||||
|
side_effect=AssertionError("Template Error"),
|
||||||
|
):
|
||||||
|
await script_obj.async_run()
|
||||||
|
|
||||||
|
assert len(events) == 10
|
||||||
|
assert events[0].data == {"repeat": "None"}
|
||||||
|
assert events[-1].data == {"repeat": "None"}
|
||||||
|
for index, result in enumerate(
|
||||||
|
(
|
||||||
|
("True", "1", "False"),
|
||||||
|
("True", "1", "False"),
|
||||||
|
("False", "2", "True"),
|
||||||
|
("True", "1", "False"),
|
||||||
|
("False", "2", "True"),
|
||||||
|
("True", "1", "False"),
|
||||||
|
("False", "2", "True"),
|
||||||
|
("False", "2", "True"),
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
):
|
||||||
|
assert events[index].data == {
|
||||||
|
"first": result[0],
|
||||||
|
"index": result[1],
|
||||||
|
"last": result[2],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")])
|
@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")])
|
||||||
async def test_choose(hass, var, result):
|
async def test_choose(hass, var, result):
|
||||||
"""Test choose action."""
|
"""Test choose action."""
|
||||||
|
|
Loading…
Add table
Reference in a new issue