Add new repeat loop for scripts and automations (#37589)
This commit is contained in:
parent
b187b17a4f
commit
91271f388c
8 changed files with 337 additions and 62 deletions
|
@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
|
|||
import asyncio
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from itertools import islice
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
|
||||
|
@ -18,6 +19,7 @@ from homeassistant.const import (
|
|||
CONF_ALIAS,
|
||||
CONF_CONDITION,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -26,9 +28,13 @@ from homeassistant.const import (
|
|||
CONF_EVENT_DATA_TEMPLATE,
|
||||
CONF_MODE,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_REPEAT,
|
||||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
CONF_TIMEOUT,
|
||||
CONF_UNTIL,
|
||||
CONF_WAIT_TEMPLATE,
|
||||
CONF_WHILE,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
|
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
_UNSUPPORTED_IN_LEGACY = {
|
||||
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
|
||||
}
|
||||
|
||||
|
||||
def warn_deprecated_legacy(logger, msg):
|
||||
"""Warn about deprecated legacy mode."""
|
||||
|
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
|
|||
)
|
||||
|
||||
|
||||
def validate_legacy_mode_actions(sequence):
|
||||
"""Check for actions not supported in legacy mode."""
|
||||
for action in sequence:
|
||||
script_action = cv.determine_script_action(action)
|
||||
if script_action in _UNSUPPORTED_IN_LEGACY:
|
||||
raise vol.Invalid(
|
||||
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
|
||||
f"{SCRIPT_MODE_LEGACY} mode"
|
||||
)
|
||||
|
||||
|
||||
def validate_queue_size(config):
|
||||
"""Validate queue_size option."""
|
||||
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
|
||||
|
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
|
|||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
async def _async_condition_step(self):
|
||||
"""Test if condition is matching."""
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
|
||||
config = self._config_cache.get(config_cache_key)
|
||||
if not config:
|
||||
config = await condition.async_from_config(self._hass, self._action, False)
|
||||
self._config_cache[config_cache_key] = config
|
||||
|
||||
self._script.last_action = self._action.get(
|
||||
CONF_ALIAS, self._action[CONF_CONDITION]
|
||||
)
|
||||
check = config(self._hass, self._variables)
|
||||
cond = await self._async_get_condition(self._action)
|
||||
check = cond(self._hass, self._variables)
|
||||
self._log("Test condition %s: %s", self._script.last_action, check)
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
@abstractmethod
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
|
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
|
|||
task.cancel()
|
||||
unsub()
|
||||
|
||||
async def _async_run_long_action(self, long_task):
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task():
|
||||
# Stop long task and wait for it to finish.
|
||||
long_task.cancel()
|
||||
try:
|
||||
await long_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
# Wait for long task while monitoring for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_long_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
|
||||
if long_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if long_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
long_task.result()
|
||||
else:
|
||||
# Stopped before long task completed, so cancel it.
|
||||
await async_cancel_long_task()
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
domain, service, service_data = self._prep_call_service_step()
|
||||
|
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
|
|||
await service_task
|
||||
return
|
||||
|
||||
async def async_cancel_service_task():
|
||||
# Stop service task and wait for it to finish.
|
||||
service_task.cancel()
|
||||
try:
|
||||
await service_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
await self._async_run_long_action(service_task)
|
||||
|
||||
# No call limit so watch for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
description = self._action.get(CONF_ALIAS, "sequence")
|
||||
repeat = self._action[CONF_REPEAT]
|
||||
|
||||
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
|
||||
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
||||
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
||||
if extra_vars:
|
||||
repeat_vars["repeat"].update(extra_vars)
|
||||
task = self._hass.async_create_task(
|
||||
# pylint: disable=protected-access
|
||||
self._script._repeat_script[self._step].async_run(
|
||||
# Add repeat to variables. Override if it already exists in case of
|
||||
# nested calls.
|
||||
{**(self._variables or {}), **repeat_vars},
|
||||
self._context,
|
||||
)
|
||||
)
|
||||
# If our task is cancelled, then cancel service task, too. Note that if service
|
||||
# task is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_service_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
await self._async_run_long_action(task)
|
||||
|
||||
if service_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if service_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
service_task.result()
|
||||
elif running_script:
|
||||
# Stopped before service completed, so cancel service.
|
||||
await async_cancel_service_task()
|
||||
if CONF_COUNT in repeat:
|
||||
count = repeat[CONF_COUNT]
|
||||
if isinstance(count, template.Template):
|
||||
try:
|
||||
count = int(count.async_render(self._variables))
|
||||
except (exceptions.TemplateError, ValueError) as ex:
|
||||
self._log(
|
||||
"Error rendering %s repeat count template: %s",
|
||||
self._script.name,
|
||||
ex,
|
||||
level=logging.ERROR,
|
||||
)
|
||||
raise _StopScript
|
||||
for iteration in range(1, count + 1):
|
||||
await async_run_sequence(
|
||||
iteration, f" of {count}", {"last": iteration == count}
|
||||
)
|
||||
if self._stop.is_set():
|
||||
break
|
||||
|
||||
elif CONF_WHILE in repeat:
|
||||
conditions = [
|
||||
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
|
||||
]
|
||||
for iteration in itertools.count(1):
|
||||
if self._stop.is_set() or not all(
|
||||
cond(self._hass, self._variables) for cond in conditions
|
||||
):
|
||||
break
|
||||
await async_run_sequence(iteration)
|
||||
|
||||
elif CONF_UNTIL in repeat:
|
||||
conditions = [
|
||||
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
|
||||
]
|
||||
for iteration in itertools.count(1):
|
||||
await async_run_sequence(iteration)
|
||||
if self._stop.is_set() or all(
|
||||
cond(self._hass, self._variables) for cond in conditions
|
||||
):
|
||||
break
|
||||
|
||||
|
||||
class _QueuedScriptRun(_ScriptRun):
|
||||
|
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
|
|||
|
||||
suspended = False
|
||||
try:
|
||||
for self._step, self._action in islice(
|
||||
for self._step, self._action in itertools.islice(
|
||||
enumerate(self._script.sequence), self._cur, None
|
||||
):
|
||||
await self._async_step(log_exceptions=not propagate_exceptions)
|
||||
|
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
|
|||
*self._prep_call_service_step(), blocking=True, context=self._context
|
||||
)
|
||||
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
# Not supported in legacy mode.
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove listeners, if any."""
|
||||
for unsub in self._async_listener:
|
||||
|
@ -733,6 +834,22 @@ class Script:
|
|||
for action in self.sequence
|
||||
)
|
||||
|
||||
self._repeat_script = {}
|
||||
for step, action in enumerate(sequence):
|
||||
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
|
||||
sub_script = Script(
|
||||
hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
self._repeat_script[step] = sub_script
|
||||
|
||||
self._runs: List[_ScriptRunBase] = []
|
||||
if script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._queue_size = queue_size
|
||||
|
@ -746,6 +863,11 @@ class Script:
|
|||
if self.change_listener:
|
||||
self._hass.async_run_job(self.change_listener)
|
||||
|
||||
def _chain_change_listener(self, sub_script):
|
||||
if sub_script.is_running:
|
||||
self.last_action = sub_script.last_action
|
||||
self._changed()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Return true if script is on."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue