Add new repeat loop for scripts and automations (#37589)

This commit is contained in:
Phil Bruckner 2020-07-10 13:37:19 -05:00 committed by GitHub
parent b187b17a4f
commit 91271f388c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 337 additions and 62 deletions

View file

@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from itertools import islice
from functools import partial
import itertools
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
@ -18,6 +19,7 @@ from homeassistant.const import (
CONF_ALIAS,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -26,9 +28,13 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_REPEAT,
CONF_SCENE,
CONF_SEQUENCE,
CONF_TIMEOUT,
CONF_UNTIL,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
}
)
_UNSUPPORTED_IN_LEGACY = {
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
}
def warn_deprecated_legacy(logger, msg):
"""Warn about deprecated legacy mode."""
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
)
def validate_legacy_mode_actions(sequence):
"""Check for actions not supported in legacy mode."""
for action in sequence:
script_action = cv.determine_script_action(action)
if script_action in _UNSUPPORTED_IN_LEGACY:
raise vol.Invalid(
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
f"{SCRIPT_MODE_LEGACY} mode"
)
def validate_queue_size(config):
"""Validate queue_size option."""
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
self._action[CONF_EVENT], event_data, context=self._context
)
async def _async_get_condition(self, config):
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
async def _async_condition_step(self):
"""Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
config = self._config_cache.get(config_cache_key)
if not config:
config = await condition.async_from_config(self._hass, self._action, False)
self._config_cache[config_cache_key] = config
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION]
)
check = config(self._hass, self._variables)
cond = await self._async_get_condition(self._action)
check = cond(self._hass, self._variables)
self._log("Test condition %s: %s", self._script.last_action, check)
if not check:
raise _StopScript
@abstractmethod
async def _async_repeat_step(self):
"""Repeat a sequence."""
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
task.cancel()
unsub()
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
# Stop long task and wait for it to finish.
long_task.cancel()
try:
await long_task
except Exception: # pylint: disable=broad-except
pass
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
long_task.result()
else:
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step()
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
await service_task
return
async def async_cancel_service_task():
# Stop service task and wait for it to finish.
service_task.cancel()
try:
await service_task
except Exception: # pylint: disable=broad-except
pass
await self._async_run_long_action(service_task)
# No call limit so watch for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
async def _async_repeat_step(self):
"""Repeat a sequence."""
description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT]
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
if extra_vars:
repeat_vars["repeat"].update(extra_vars)
task = self._hass.async_create_task(
# pylint: disable=protected-access
self._script._repeat_script[self._step].async_run(
# Add repeat to variables. Override if it already exists in case of
# nested calls.
{**(self._variables or {}), **repeat_vars},
self._context,
)
)
# If our task is cancelled, then cancel service task, too. Note that if service
# task is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_service_task()
raise
finally:
stop_task.cancel()
await self._async_run_long_action(task)
if service_task.cancelled():
raise asyncio.CancelledError
if service_task.done():
# Propagate any exceptions that occurred.
service_task.result()
elif running_script:
# Stopped before service completed, so cancel service.
await async_cancel_service_task()
if CONF_COUNT in repeat:
count = repeat[CONF_COUNT]
if isinstance(count, template.Template):
try:
count = int(count.async_render(self._variables))
except (exceptions.TemplateError, ValueError) as ex:
self._log(
"Error rendering %s repeat count template: %s",
self._script.name,
ex,
level=logging.ERROR,
)
raise _StopScript
for iteration in range(1, count + 1):
await async_run_sequence(
iteration, f" of {count}", {"last": iteration == count}
)
if self._stop.is_set():
break
elif CONF_WHILE in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
]
for iteration in itertools.count(1):
if self._stop.is_set() or not all(
cond(self._hass, self._variables) for cond in conditions
):
break
await async_run_sequence(iteration)
elif CONF_UNTIL in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
]
for iteration in itertools.count(1):
await async_run_sequence(iteration)
if self._stop.is_set() or all(
cond(self._hass, self._variables) for cond in conditions
):
break
class _QueuedScriptRun(_ScriptRun):
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
suspended = False
try:
for self._step, self._action in islice(
for self._step, self._action in itertools.islice(
enumerate(self._script.sequence), self._cur, None
):
await self._async_step(log_exceptions=not propagate_exceptions)
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
*self._prep_call_service_step(), blocking=True, context=self._context
)
async def _async_repeat_step(self):
"""Repeat a sequence."""
# Not supported in legacy mode.
def _async_remove_listener(self):
"""Remove listeners, if any."""
for unsub in self._async_listener:
@ -733,6 +834,22 @@ class Script:
for action in self.sequence
)
self._repeat_script = {}
for step, action in enumerate(sequence):
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
sub_script = Script(
hass,
action[CONF_REPEAT][CONF_SEQUENCE],
f"{name}: {step_name}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(
self._chain_change_listener, sub_script
)
self._repeat_script[step] = sub_script
self._runs: List[_ScriptRunBase] = []
if script_mode == SCRIPT_MODE_QUEUE:
self._queue_size = queue_size
@ -746,6 +863,11 @@ class Script:
if self.change_listener:
self._hass.async_run_job(self.change_listener)
def _chain_change_listener(self, sub_script):
if sub_script.is_running:
self.last_action = sub_script.last_action
self._changed()
@property
def is_running(self) -> bool:
"""Return true if script is on."""