Fix script repeat variable lifetime (#38124)

This commit is contained in:
Phil Bruckner 2020-07-24 01:11:21 -05:00 committed by GitHub
parent a7459b3126
commit 2f87da8aa9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 142 additions and 18 deletions

View file

@ -140,7 +140,7 @@ class _ScriptRun:
) -> None: ) -> None:
self._hass = hass self._hass = hass
self._script = script self._script = script
self._variables = variables self._variables = variables or {}
self._context = context self._context = context
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self._step = -1 self._step = -1
@ -431,22 +431,23 @@ class _ScriptRun:
async def _async_repeat_step(self): async def _async_repeat_step(self):
"""Repeat a sequence.""" """Repeat a sequence."""
description = self._action.get(CONF_ALIAS, "sequence") description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT] repeat = self._action[CONF_REPEAT]
async def async_run_sequence(iteration, extra_msg="", extra_vars=None): saved_repeat_vars = self._variables.get("repeat")
def set_repeat_var(iteration, count=None):
repeat_vars = {"first": iteration == 1, "index": iteration}
if count:
repeat_vars["last"] = iteration == count
self._variables["repeat"] = repeat_vars
# pylint: disable=protected-access
script = self._script._get_repeat_script(self._step)
async def async_run_sequence(iteration, extra_msg=""):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}} await self._async_run_script(script)
if extra_vars:
repeat_vars["repeat"].update(extra_vars)
# pylint: disable=protected-access
await self._async_run_script(
self._script._get_repeat_script(self._step),
# Add repeat to variables. Override if it already exists in case of
# nested calls.
{**(self._variables or {}), **repeat_vars},
)
if CONF_COUNT in repeat: if CONF_COUNT in repeat:
count = repeat[CONF_COUNT] count = repeat[CONF_COUNT]
@ -461,10 +462,10 @@ class _ScriptRun:
level=logging.ERROR, level=logging.ERROR,
) )
raise _StopScript raise _StopScript
extra_msg = f" of {count}"
for iteration in range(1, count + 1): for iteration in range(1, count + 1):
await async_run_sequence( set_repeat_var(iteration, count)
iteration, f" of {count}", {"last": iteration == count} await async_run_sequence(iteration, extra_msg)
)
if self._stop.is_set(): if self._stop.is_set():
break break
@ -473,6 +474,7 @@ class _ScriptRun:
await self._async_get_condition(config) for config in repeat[CONF_WHILE] await self._async_get_condition(config) for config in repeat[CONF_WHILE]
] ]
for iteration in itertools.count(1): for iteration in itertools.count(1):
set_repeat_var(iteration)
if self._stop.is_set() or not all( if self._stop.is_set() or not all(
cond(self._hass, self._variables) for cond in conditions cond(self._hass, self._variables) for cond in conditions
): ):
@ -484,12 +486,18 @@ class _ScriptRun:
await self._async_get_condition(config) for config in repeat[CONF_UNTIL] await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
] ]
for iteration in itertools.count(1): for iteration in itertools.count(1):
set_repeat_var(iteration)
await async_run_sequence(iteration) await async_run_sequence(iteration)
if self._stop.is_set() or all( if self._stop.is_set() or all(
cond(self._hass, self._variables) for cond in conditions cond(self._hass, self._variables) for cond in conditions
): ):
break break
if saved_repeat_vars:
self._variables["repeat"] = saved_repeat_vars
else:
del self._variables["repeat"]
async def _async_choose_step(self): async def _async_choose_step(self):
"""Choose a sequence.""" """Choose a sequence."""
# pylint: disable=protected-access # pylint: disable=protected-access
@ -503,11 +511,11 @@ class _ScriptRun:
if choose_data["default"]: if choose_data["default"]:
await self._async_run_script(choose_data["default"]) await self._async_run_script(choose_data["default"])
async def _async_run_script(self, script, variables=None): async def _async_run_script(self, script):
"""Execute a script.""" """Execute a script."""
await self._async_run_long_action( await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task(
script.async_run(variables or self._variables, self._context) script.async_run(self._variables, self._context)
) )
) )

View file

@ -854,6 +854,122 @@ async def test_repeat_conditional(hass, condition):
assert event.data.get("index") == str(index + 1) assert event.data.get("index") == str(index + 1)
@pytest.mark.parametrize("condition", ["while", "until"])
async def test_repeat_var_in_condition(hass, condition):
"""Test repeat action w/ while option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = {"repeat": {"sequence": {"event": event}}}
if condition == "while":
sequence["repeat"]["while"] = {
"condition": "template",
"value_template": "{{ repeat.index <= 2 }}",
}
else:
sequence["repeat"]["until"] = {
"condition": "template",
"value_template": "{{ repeat.index == 2 }}",
}
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence))
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run()
assert len(events) == 2
async def test_repeat_nested(hass):
"""Test nested repeats."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"event": event,
"event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}"
},
},
{
"repeat": {
"count": 2,
"sequence": [
{
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
},
},
{
"repeat": {
"count": 2,
"sequence": {
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
},
},
}
},
{
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
},
},
],
}
},
{
"event": event,
"event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}"
},
},
]
)
script_obj = script.Script(hass, sequence, "test script")
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run()
assert len(events) == 10
assert events[0].data == {"repeat": "None"}
assert events[-1].data == {"repeat": "None"}
for index, result in enumerate(
(
("True", "1", "False"),
("True", "1", "False"),
("False", "2", "True"),
("True", "1", "False"),
("False", "2", "True"),
("True", "1", "False"),
("False", "2", "True"),
("False", "2", "True"),
),
1,
):
assert events[index].data == {
"first": result[0],
"index": result[1],
"last": result[2],
}
@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")]) @pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")])
async def test_choose(hass, var, result): async def test_choose(hass, var, result):
"""Test choose action.""" """Test choose action."""