Add new repeat loop for scripts and automations (#37589)

This commit is contained in:
Phil Bruckner 2020-07-10 13:37:19 -05:00 committed by GitHub
parent b187b17a4f
commit 91271f388c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 337 additions and 62 deletions

View file

@ -15,6 +15,7 @@ from homeassistant.helpers import condition, config_per_platform
from homeassistant.helpers.script import (
SCRIPT_MODE_LEGACY,
async_validate_action_config,
validate_legacy_mode_actions,
warn_deprecated_legacy,
)
from homeassistant.loader import IntegrationNotFound
@ -55,6 +56,9 @@ async def async_validate_config_item(hass, config, full_config=None):
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
)
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_ACTION])
return config

View file

@ -12,6 +12,7 @@ from homeassistant.const import (
CONF_ICON,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_SEQUENCE,
SERVICE_RELOAD,
SERVICE_TOGGLE,
SERVICE_TURN_OFF,
@ -27,6 +28,7 @@ from homeassistant.helpers.script import (
SCRIPT_BASE_SCHEMA,
SCRIPT_MODE_LEGACY,
Script,
validate_legacy_mode_actions,
validate_queue_size,
warn_deprecated_legacy,
)
@ -44,7 +46,6 @@ ATTR_VARIABLES = "variables"
CONF_DESCRIPTION = "description"
CONF_EXAMPLE = "example"
CONF_FIELDS = "fields"
CONF_SEQUENCE = "sequence"
ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -63,6 +64,13 @@ def _deprecated_legacy_mode(config):
return config
def _not_supported_in_legacy_mode(config):
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_SEQUENCE])
return config
SCRIPT_ENTRY_SCHEMA = vol.All(
SCRIPT_BASE_SCHEMA.extend(
{
@ -79,6 +87,7 @@ SCRIPT_ENTRY_SCHEMA = vol.All(
}
),
validate_queue_size,
_not_supported_in_legacy_mode,
)
CONFIG_SCHEMA = vol.Schema(

View file

@ -61,6 +61,7 @@ CONF_COMMAND_STATE = "command_state"
CONF_COMMAND_STOP = "command_stop"
CONF_CONDITION = "condition"
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
CONF_COUNT = "count"
CONF_COVERS = "covers"
CONF_CURRENCY = "currency"
CONF_CUSTOMIZE = "customize"
@ -139,6 +140,7 @@ CONF_QUOTE = "quote"
CONF_RADIUS = "radius"
CONF_RECIPIENT = "recipient"
CONF_REGION = "region"
CONF_REPEAT = "repeat"
CONF_RESOURCE = "resource"
CONF_RESOURCES = "resources"
CONF_RESOURCE_TEMPLATE = "resource_template"
@ -149,6 +151,7 @@ CONF_SCENE = "scene"
CONF_SENDER = "sender"
CONF_SENSORS = "sensors"
CONF_SENSOR_TYPE = "sensor_type"
CONF_SEQUENCE = "sequence"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_TEMPLATE = "service_template"
@ -169,6 +172,7 @@ CONF_TTL = "ttl"
CONF_TYPE = "type"
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
CONF_UNIT_SYSTEM = "unit_system"
CONF_UNTIL = "until"
CONF_URL = "url"
CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template"
@ -176,6 +180,7 @@ CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_WEBHOOK_ID = "webhook_id"
CONF_WEEKDAY = "weekday"
CONF_WHILE = "while"
CONF_WHITELIST = "whitelist"
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
CONF_WHITE_VALUE = "white_value"

View file

@ -40,6 +40,7 @@ from homeassistant.const import (
CONF_BELOW,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -50,16 +51,20 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE,
CONF_FOR,
CONF_PLATFORM,
CONF_REPEAT,
CONF_SCAN_INTERVAL,
CONF_SCENE,
CONF_SEQUENCE,
CONF_SERVICE,
CONF_SERVICE_TEMPLATE,
CONF_STATE,
CONF_TIMEOUT,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
CONF_UNTIL,
CONF_VALUE_TEMPLATE,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
SUN_EVENT_SUNRISE,
@ -80,7 +85,6 @@ import homeassistant.util.dt as dt_util
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
# Home Assistant types
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
@ -817,6 +821,16 @@ def make_entity_service_schema(
)
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
EVENT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
@ -998,6 +1012,25 @@ DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTR
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
_SCRIPT_REPEAT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_REPEAT): vol.All(
{
vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template),
vol.Exclusive(CONF_WHILE, "repeat"): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
vol.Exclusive(CONF_UNTIL, "repeat"): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
},
has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL),
),
}
)
SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1005,6 +1038,7 @@ SCRIPT_ACTION_FIRE_EVENT = "event"
SCRIPT_ACTION_CALL_SERVICE = "call_service"
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
SCRIPT_ACTION_REPEAT = "repeat"
def determine_script_action(action: dict) -> str:
@ -1027,6 +1061,9 @@ def determine_script_action(action: dict) -> str:
if CONF_SCENE in action:
return SCRIPT_ACTION_ACTIVATE_SCENE
if CONF_REPEAT in action:
return SCRIPT_ACTION_REPEAT
return SCRIPT_ACTION_CALL_SERVICE
@ -1038,15 +1075,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
}
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])

View file

@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from itertools import islice
from functools import partial
import itertools
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
@ -18,6 +19,7 @@ from homeassistant.const import (
CONF_ALIAS,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -26,9 +28,13 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_REPEAT,
CONF_SCENE,
CONF_SEQUENCE,
CONF_TIMEOUT,
CONF_UNTIL,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
}
)
_UNSUPPORTED_IN_LEGACY = {
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
}
def warn_deprecated_legacy(logger, msg):
"""Warn about deprecated legacy mode."""
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
)
def validate_legacy_mode_actions(sequence):
"""Check for actions not supported in legacy mode."""
for action in sequence:
script_action = cv.determine_script_action(action)
if script_action in _UNSUPPORTED_IN_LEGACY:
raise vol.Invalid(
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
f"{SCRIPT_MODE_LEGACY} mode"
)
def validate_queue_size(config):
"""Validate queue_size option."""
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
self._action[CONF_EVENT], event_data, context=self._context
)
async def _async_get_condition(self, config):
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
async def _async_condition_step(self):
"""Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
config = self._config_cache.get(config_cache_key)
if not config:
config = await condition.async_from_config(self._hass, self._action, False)
self._config_cache[config_cache_key] = config
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION]
)
check = config(self._hass, self._variables)
cond = await self._async_get_condition(self._action)
check = cond(self._hass, self._variables)
self._log("Test condition %s: %s", self._script.last_action, check)
if not check:
raise _StopScript
@abstractmethod
async def _async_repeat_step(self):
"""Repeat a sequence."""
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
task.cancel()
unsub()
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
# Stop long task and wait for it to finish.
long_task.cancel()
try:
await long_task
except Exception: # pylint: disable=broad-except
pass
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
long_task.result()
else:
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step()
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
await service_task
return
async def async_cancel_service_task():
# Stop service task and wait for it to finish.
service_task.cancel()
try:
await service_task
except Exception: # pylint: disable=broad-except
pass
await self._async_run_long_action(service_task)
# No call limit so watch for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
async def _async_repeat_step(self):
"""Repeat a sequence."""
description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT]
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
if extra_vars:
repeat_vars["repeat"].update(extra_vars)
task = self._hass.async_create_task(
# pylint: disable=protected-access
self._script._repeat_script[self._step].async_run(
# Add repeat to variables. Override if it already exists in case of
# nested calls.
{**(self._variables or {}), **repeat_vars},
self._context,
)
)
# If our task is cancelled, then cancel service task, too. Note that if service
# task is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_service_task()
raise
finally:
stop_task.cancel()
await self._async_run_long_action(task)
if service_task.cancelled():
raise asyncio.CancelledError
if service_task.done():
# Propagate any exceptions that occurred.
service_task.result()
elif running_script:
# Stopped before service completed, so cancel service.
await async_cancel_service_task()
if CONF_COUNT in repeat:
count = repeat[CONF_COUNT]
if isinstance(count, template.Template):
try:
count = int(count.async_render(self._variables))
except (exceptions.TemplateError, ValueError) as ex:
self._log(
"Error rendering %s repeat count template: %s",
self._script.name,
ex,
level=logging.ERROR,
)
raise _StopScript
for iteration in range(1, count + 1):
await async_run_sequence(
iteration, f" of {count}", {"last": iteration == count}
)
if self._stop.is_set():
break
elif CONF_WHILE in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
]
for iteration in itertools.count(1):
if self._stop.is_set() or not all(
cond(self._hass, self._variables) for cond in conditions
):
break
await async_run_sequence(iteration)
elif CONF_UNTIL in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
]
for iteration in itertools.count(1):
await async_run_sequence(iteration)
if self._stop.is_set() or all(
cond(self._hass, self._variables) for cond in conditions
):
break
class _QueuedScriptRun(_ScriptRun):
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
suspended = False
try:
for self._step, self._action in islice(
for self._step, self._action in itertools.islice(
enumerate(self._script.sequence), self._cur, None
):
await self._async_step(log_exceptions=not propagate_exceptions)
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
*self._prep_call_service_step(), blocking=True, context=self._context
)
async def _async_repeat_step(self):
"""Repeat a sequence."""
# Not supported in legacy mode.
def _async_remove_listener(self):
"""Remove listeners, if any."""
for unsub in self._async_listener:
@ -733,6 +834,22 @@ class Script:
for action in self.sequence
)
self._repeat_script = {}
for step, action in enumerate(sequence):
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
sub_script = Script(
hass,
action[CONF_REPEAT][CONF_SEQUENCE],
f"{name}: {step_name}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(
self._chain_change_listener, sub_script
)
self._repeat_script[step] = sub_script
self._runs: List[_ScriptRunBase] = []
if script_mode == SCRIPT_MODE_QUEUE:
self._queue_size = queue_size
@ -746,6 +863,11 @@ class Script:
if self.change_listener:
self._hass.async_run_job(self.change_listener)
def _chain_change_listener(self, sub_script):
if sub_script.is_running:
self.last_action = sub_script.last_action
self._changed()
@property
def is_running(self) -> bool:
"""Return true if script is on."""

View file

@ -1072,20 +1072,27 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["entity_id"] == "automation.bye"
async def test_invalid_config(hass):
"""Test invalid config."""
invalid_configs = [
{
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
},
{
"mode": "legacy",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [{"repeat": {"count": 5, "sequence": []}}],
},
]
@pytest.mark.parametrize("value", invalid_configs)
async def test_invalid_configs(hass, value):
"""Test invalid configurations."""
with assert_setup_component(0, automation.DOMAIN):
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
}
},
hass, automation.DOMAIN, {automation.DOMAIN: value}
)

View file

@ -200,6 +200,12 @@ invalid_configs = [
{"test hello world": {"sequence": [{"event": "bla"}]}},
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
{
"test": {
"mode": "legacy",
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
}
},
]

View file

@ -822,6 +822,101 @@ async def test_condition_all_cached(hass, script_mode):
assert len(script_obj._config_cache) == 2
async def test_repeat_count(hass):
"""Test repeat action w/ count option."""
event = "test_event"
events = async_capture_events(hass, event)
count = 3
sequence = cv.SCRIPT_SCHEMA(
{
"repeat": {
"count": count,
"sequence": {
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
},
},
}
}
)
script_obj = script.Script(hass, sequence, script_mode="ignore")
await script_obj.async_run()
await hass.async_block_till_done()
assert len(events) == count
for index, event in enumerate(events):
assert event.data.get("first") == str(index == 0)
assert event.data.get("index") == str(index + 1)
assert event.data.get("last") == str(index == count - 1)
@pytest.mark.parametrize("condition", ["while", "until"])
async def test_repeat_conditional(hass, condition):
"""Test repeat action w/ while option."""
event = "test_event"
events = async_capture_events(hass, event)
count = 3
sequence = {
"repeat": {
"sequence": [
{
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
},
},
{"wait_template": "{{ is_state('sensor.test', 'next') }}"},
{"wait_template": "{{ not is_state('sensor.test', 'next') }}"},
],
}
}
if condition == "while":
sequence["repeat"]["while"] = {
"condition": "template",
"value_template": "{{ not is_state('sensor.test', 'done') }}",
}
else:
sequence["repeat"]["until"] = {
"condition": "template",
"value_template": "{{ is_state('sensor.test', 'done') }}",
}
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
wait_started = async_watch_for_action(script_obj, "wait")
hass.states.async_set("sensor.test", "1")
hass.async_create_task(script_obj.async_run())
try:
for index in range(2, count + 1):
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", "next")
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", index)
await asyncio.wait_for(wait_started.wait(), 1)
hass.states.async_set("sensor.test", "next")
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", "done")
await asyncio.wait_for(hass.async_block_till_done(), 1)
except asyncio.TimeoutError:
await script_obj.async_stop()
raise
assert len(events) == count
for index, event in enumerate(events):
assert event.data.get("first") == str(index == 0)
assert event.data.get("index") == str(index + 1)
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_last_triggered(hass, script_mode):
"""Test the last_triggered."""