From 91271f388cac8773895593faf4fff078094e5ed7 Mon Sep 17 00:00:00 2001 From: Phil Bruckner Date: Fri, 10 Jul 2020 13:37:19 -0500 Subject: [PATCH] Add new repeat loop for scripts and automations (#37589) --- homeassistant/components/automation/config.py | 4 + homeassistant/components/script/__init__.py | 11 +- homeassistant/const.py | 5 + homeassistant/helpers/config_validation.py | 51 +++-- homeassistant/helpers/script.py | 196 ++++++++++++++---- tests/components/automation/test_init.py | 31 +-- tests/components/script/test_init.py | 6 + tests/helpers/test_script.py | 95 +++++++++ 8 files changed, 337 insertions(+), 62 deletions(-) diff --git a/homeassistant/components/automation/config.py b/homeassistant/components/automation/config.py index 4fa913fc3ce..fe3066eb7b2 100644 --- a/homeassistant/components/automation/config.py +++ b/homeassistant/components/automation/config.py @@ -15,6 +15,7 @@ from homeassistant.helpers import condition, config_per_platform from homeassistant.helpers.script import ( SCRIPT_MODE_LEGACY, async_validate_action_config, + validate_legacy_mode_actions, warn_deprecated_legacy, ) from homeassistant.loader import IntegrationNotFound @@ -55,6 +56,9 @@ async def async_validate_config_item(hass, config, full_config=None): *[async_validate_action_config(hass, action) for action in config[CONF_ACTION]] ) + if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY: + validate_legacy_mode_actions(config[CONF_ACTION]) + return config diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 03dd6a54e8f..f444cc45d76 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -12,6 +12,7 @@ from homeassistant.const import ( CONF_ICON, CONF_MODE, CONF_QUEUE_SIZE, + CONF_SEQUENCE, SERVICE_RELOAD, SERVICE_TOGGLE, SERVICE_TURN_OFF, @@ -27,6 +28,7 @@ from homeassistant.helpers.script import ( SCRIPT_BASE_SCHEMA, SCRIPT_MODE_LEGACY, Script, + validate_legacy_mode_actions, validate_queue_size, warn_deprecated_legacy, ) @@ -44,7 +46,6 @@ ATTR_VARIABLES = "variables" CONF_DESCRIPTION = "description" CONF_EXAMPLE = "example" CONF_FIELDS = "fields" -CONF_SEQUENCE = "sequence" ENTITY_ID_FORMAT = DOMAIN + ".{}" @@ -63,6 +64,13 @@ def _deprecated_legacy_mode(config): return config +def _not_supported_in_legacy_mode(config): + if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY: + validate_legacy_mode_actions(config[CONF_SEQUENCE]) + + return config + + SCRIPT_ENTRY_SCHEMA = vol.All( SCRIPT_BASE_SCHEMA.extend( { @@ -79,6 +87,7 @@ SCRIPT_ENTRY_SCHEMA = vol.All( } ), validate_queue_size, + _not_supported_in_legacy_mode, ) CONFIG_SCHEMA = vol.Schema( diff --git a/homeassistant/const.py b/homeassistant/const.py index 658b3081d02..935557d9407 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -61,6 +61,7 @@ CONF_COMMAND_STATE = "command_state" CONF_COMMAND_STOP = "command_stop" CONF_CONDITION = "condition" CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout" +CONF_COUNT = "count" CONF_COVERS = "covers" CONF_CURRENCY = "currency" CONF_CUSTOMIZE = "customize" @@ -139,6 +140,7 @@ CONF_QUOTE = "quote" CONF_RADIUS = "radius" CONF_RECIPIENT = "recipient" CONF_REGION = "region" +CONF_REPEAT = "repeat" CONF_RESOURCE = "resource" CONF_RESOURCES = "resources" CONF_RESOURCE_TEMPLATE = "resource_template" @@ -149,6 +151,7 @@ CONF_SCENE = "scene" CONF_SENDER = "sender" CONF_SENSORS = "sensors" CONF_SENSOR_TYPE = "sensor_type" +CONF_SEQUENCE = "sequence" CONF_SERVICE = "service" CONF_SERVICE_DATA = "data" CONF_SERVICE_TEMPLATE = "service_template" @@ -169,6 +172,7 @@ CONF_TTL = "ttl" CONF_TYPE = "type" CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement" CONF_UNIT_SYSTEM = "unit_system" +CONF_UNTIL = "until" CONF_URL = "url" CONF_USERNAME = "username" CONF_VALUE_TEMPLATE = "value_template" @@ -176,6 +180,7 @@ CONF_VERIFY_SSL = "verify_ssl" CONF_WAIT_TEMPLATE = "wait_template" CONF_WEBHOOK_ID = "webhook_id" CONF_WEEKDAY = "weekday" +CONF_WHILE = "while" CONF_WHITELIST = "whitelist" CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs" CONF_WHITE_VALUE = "white_value" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 24ba0d3c0f0..0e20dea718b 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -40,6 +40,7 @@ from homeassistant.const import ( CONF_BELOW, CONF_CONDITION, CONF_CONTINUE_ON_TIMEOUT, + CONF_COUNT, CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, @@ -50,16 +51,20 @@ from homeassistant.const import ( CONF_EVENT_DATA_TEMPLATE, CONF_FOR, CONF_PLATFORM, + CONF_REPEAT, CONF_SCAN_INTERVAL, CONF_SCENE, + CONF_SEQUENCE, CONF_SERVICE, CONF_SERVICE_TEMPLATE, CONF_STATE, CONF_TIMEOUT, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC, + CONF_UNTIL, CONF_VALUE_TEMPLATE, CONF_WAIT_TEMPLATE, + CONF_WHILE, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE, SUN_EVENT_SUNRISE, @@ -80,7 +85,6 @@ import homeassistant.util.dt as dt_util TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'" - # Home Assistant types byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255)) small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1)) @@ -817,6 +821,16 @@ def make_entity_service_schema( ) +def script_action(value: Any) -> dict: + """Validate a script action.""" + if not isinstance(value, dict): + raise vol.Invalid("expected dictionary") + + return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value) + + +SCRIPT_SCHEMA = vol.All(ensure_list, [script_action]) + EVENT_SCHEMA = vol.Schema( { vol.Optional(CONF_ALIAS): string, @@ -998,6 +1012,25 @@ DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTR _SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")}) +_SCRIPT_REPEAT_SCHEMA = vol.Schema( + { + vol.Optional(CONF_ALIAS): string, + vol.Required(CONF_REPEAT): vol.All( + { + vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template), + vol.Exclusive(CONF_WHILE, "repeat"): vol.All( + ensure_list, [CONDITION_SCHEMA] + ), + vol.Exclusive(CONF_UNTIL, "repeat"): vol.All( + ensure_list, [CONDITION_SCHEMA] + ), + vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA, + }, + has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL), + ), + } +) + SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_CHECK_CONDITION = "condition" @@ -1005,6 +1038,7 @@ SCRIPT_ACTION_FIRE_EVENT = "event" SCRIPT_ACTION_CALL_SERVICE = "call_service" SCRIPT_ACTION_DEVICE_AUTOMATION = "device" SCRIPT_ACTION_ACTIVATE_SCENE = "scene" +SCRIPT_ACTION_REPEAT = "repeat" def determine_script_action(action: dict) -> str: @@ -1027,6 +1061,9 @@ def determine_script_action(action: dict) -> str: if CONF_SCENE in action: return SCRIPT_ACTION_ACTIVATE_SCENE + if CONF_REPEAT in action: + return SCRIPT_ACTION_REPEAT + return SCRIPT_ACTION_CALL_SERVICE @@ -1038,15 +1075,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA, SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA, SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, + SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA, } - - -def script_action(value: Any) -> dict: - """Validate a script action.""" - if not isinstance(value, dict): - raise vol.Invalid("expected dictionary") - - return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value) - - -SCRIPT_SCHEMA = vol.All(ensure_list, [script_action]) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index c235a65dc81..88d1779bf1b 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod import asyncio from contextlib import suppress from datetime import datetime -from itertools import islice +from functools import partial +import itertools import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast @@ -18,6 +19,7 @@ from homeassistant.const import ( CONF_ALIAS, CONF_CONDITION, CONF_CONTINUE_ON_TIMEOUT, + CONF_COUNT, CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, @@ -26,9 +28,13 @@ from homeassistant.const import ( CONF_EVENT_DATA_TEMPLATE, CONF_MODE, CONF_QUEUE_SIZE, + CONF_REPEAT, CONF_SCENE, + CONF_SEQUENCE, CONF_TIMEOUT, + CONF_UNTIL, CONF_WAIT_TEMPLATE, + CONF_WHILE, SERVICE_TURN_OFF, SERVICE_TURN_ON, ) @@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema( } ) +_UNSUPPORTED_IN_LEGACY = { + cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT, +} + def warn_deprecated_legacy(logger, msg): """Warn about deprecated legacy mode.""" @@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg): ) +def validate_legacy_mode_actions(sequence): + """Check for actions not supported in legacy mode.""" + for action in sequence: + script_action = cv.determine_script_action(action) + if script_action in _UNSUPPORTED_IN_LEGACY: + raise vol.Invalid( + f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in " + f"{SCRIPT_MODE_LEGACY} mode" + ) + + def validate_queue_size(config): """Validate queue_size option.""" mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE) @@ -330,22 +351,29 @@ class _ScriptRunBase(ABC): self._action[CONF_EVENT], event_data, context=self._context ) + async def _async_get_condition(self, config): + config_cache_key = frozenset((k, str(v)) for k, v in config.items()) + cond = self._config_cache.get(config_cache_key) + if not cond: + cond = await condition.async_from_config(self._hass, config, False) + self._config_cache[config_cache_key] = cond + return cond + async def _async_condition_step(self): """Test if condition is matching.""" - config_cache_key = frozenset((k, str(v)) for k, v in self._action.items()) - config = self._config_cache.get(config_cache_key) - if not config: - config = await condition.async_from_config(self._hass, self._action, False) - self._config_cache[config_cache_key] = config - self._script.last_action = self._action.get( CONF_ALIAS, self._action[CONF_CONDITION] ) - check = config(self._hass, self._variables) + cond = await self._async_get_condition(self._action) + check = cond(self._hass, self._variables) self._log("Test condition %s: %s", self._script.last_action, check) if not check: raise _StopScript + @abstractmethod + async def _async_repeat_step(self): + """Repeat a sequence.""" + def _log(self, msg, *args, level=logging.INFO): self._script._log(msg, *args, level=level) # pylint: disable=protected-access @@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase): task.cancel() unsub() + async def _async_run_long_action(self, long_task): + """Run a long task while monitoring for stop request.""" + + async def async_cancel_long_task(): + # Stop long task and wait for it to finish. + long_task.cancel() + try: + await long_task + except Exception: # pylint: disable=broad-except + pass + + # Wait for long task while monitoring for a stop request. + stop_task = self._hass.async_create_task(self._stop.wait()) + try: + await asyncio.wait( + {long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + # If our task is cancelled, then cancel long task, too. Note that if long task + # is cancelled otherwise the CancelledError exception will not be raised to + # here due to the call to asyncio.wait(). Rather we'll check for that below. + except asyncio.CancelledError: + await async_cancel_long_task() + raise + finally: + stop_task.cancel() + + if long_task.cancelled(): + raise asyncio.CancelledError + if long_task.done(): + # Propagate any exceptions that occurred. + long_task.result() + else: + # Stopped before long task completed, so cancel it. + await async_cancel_long_task() + async def _async_call_service_step(self): """Call the service specified in the action.""" domain, service, service_data = self._prep_call_service_step() @@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase): await service_task return - async def async_cancel_service_task(): - # Stop service task and wait for it to finish. - service_task.cancel() - try: - await service_task - except Exception: # pylint: disable=broad-except - pass + await self._async_run_long_action(service_task) - # No call limit so watch for a stop request. - stop_task = self._hass.async_create_task(self._stop.wait()) - try: - await asyncio.wait( - {service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + async def _async_repeat_step(self): + """Repeat a sequence.""" + + description = self._action.get(CONF_ALIAS, "sequence") + repeat = self._action[CONF_REPEAT] + + async def async_run_sequence(iteration, extra_msg="", extra_vars=None): + self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) + repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}} + if extra_vars: + repeat_vars["repeat"].update(extra_vars) + task = self._hass.async_create_task( + # pylint: disable=protected-access + self._script._repeat_script[self._step].async_run( + # Add repeat to variables. Override if it already exists in case of + # nested calls. + {**(self._variables or {}), **repeat_vars}, + self._context, + ) ) - # If our task is cancelled, then cancel service task, too. Note that if service - # task is cancelled otherwise the CancelledError exception will not be raised to - # here due to the call to asyncio.wait(). Rather we'll check for that below. - except asyncio.CancelledError: - await async_cancel_service_task() - raise - finally: - stop_task.cancel() + await self._async_run_long_action(task) - if service_task.cancelled(): - raise asyncio.CancelledError - if service_task.done(): - # Propagate any exceptions that occurred. - service_task.result() - elif running_script: - # Stopped before service completed, so cancel service. - await async_cancel_service_task() + if CONF_COUNT in repeat: + count = repeat[CONF_COUNT] + if isinstance(count, template.Template): + try: + count = int(count.async_render(self._variables)) + except (exceptions.TemplateError, ValueError) as ex: + self._log( + "Error rendering %s repeat count template: %s", + self._script.name, + ex, + level=logging.ERROR, + ) + raise _StopScript + for iteration in range(1, count + 1): + await async_run_sequence( + iteration, f" of {count}", {"last": iteration == count} + ) + if self._stop.is_set(): + break + + elif CONF_WHILE in repeat: + conditions = [ + await self._async_get_condition(config) for config in repeat[CONF_WHILE] + ] + for iteration in itertools.count(1): + if self._stop.is_set() or not all( + cond(self._hass, self._variables) for cond in conditions + ): + break + await async_run_sequence(iteration) + + elif CONF_UNTIL in repeat: + conditions = [ + await self._async_get_condition(config) for config in repeat[CONF_UNTIL] + ] + for iteration in itertools.count(1): + await async_run_sequence(iteration) + if self._stop.is_set() or all( + cond(self._hass, self._variables) for cond in conditions + ): + break class _QueuedScriptRun(_ScriptRun): @@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase): suspended = False try: - for self._step, self._action in islice( + for self._step, self._action in itertools.islice( enumerate(self._script.sequence), self._cur, None ): await self._async_step(log_exceptions=not propagate_exceptions) @@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase): *self._prep_call_service_step(), blocking=True, context=self._context ) + async def _async_repeat_step(self): + """Repeat a sequence.""" + # Not supported in legacy mode. + def _async_remove_listener(self): """Remove listeners, if any.""" for unsub in self._async_listener: @@ -733,6 +834,22 @@ class Script: for action in self.sequence ) + self._repeat_script = {} + for step, action in enumerate(sequence): + if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT: + step_name = action.get(CONF_ALIAS, f"Repeat at step {step}") + sub_script = Script( + hass, + action[CONF_REPEAT][CONF_SEQUENCE], + f"{name}: {step_name}", + script_mode=SCRIPT_MODE_PARALLEL, + logger=self._logger, + ) + sub_script.change_listener = partial( + self._chain_change_listener, sub_script + ) + self._repeat_script[step] = sub_script + self._runs: List[_ScriptRunBase] = [] if script_mode == SCRIPT_MODE_QUEUE: self._queue_size = queue_size @@ -746,6 +863,11 @@ class Script: if self.change_listener: self._hass.async_run_job(self.change_listener) + def _chain_change_listener(self, sub_script): + if sub_script.is_running: + self.last_action = sub_script.last_action + self._changed() + @property def is_running(self) -> bool: """Return true if script is on.""" diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 70feb2a7796..a2edd6ed07c 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1072,20 +1072,27 @@ async def test_logbook_humanify_automation_triggered_event(hass): assert event2["entity_id"] == "automation.bye" -async def test_invalid_config(hass): - """Test invalid config.""" +invalid_configs = [ + { + "mode": "parallel", + "queue_size": 5, + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [], + }, + { + "mode": "legacy", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [{"repeat": {"count": 5, "sequence": []}}], + }, +] + + +@pytest.mark.parametrize("value", invalid_configs) +async def test_invalid_configs(hass, value): + """Test invalid configurations.""" with assert_setup_component(0, automation.DOMAIN): assert await async_setup_component( - hass, - automation.DOMAIN, - { - automation.DOMAIN: { - "mode": "parallel", - "queue_size": 5, - "trigger": {"platform": "event", "event_type": "test_event"}, - "action": [], - } - }, + hass, automation.DOMAIN, {automation.DOMAIN: value} ) diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index 8faa2936352..10f2efc5e5c 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -200,6 +200,12 @@ invalid_configs = [ {"test hello world": {"sequence": [{"event": "bla"}]}}, {"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}}, {"test": {"sequence": [], "mode": "parallel", "queue_size": 5}}, + { + "test": { + "mode": "legacy", + "sequence": [{"repeat": {"count": 5, "sequence": []}}], + } + }, ] diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index c00dadc27e8..f6c9ec4ac5b 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -822,6 +822,101 @@ async def test_condition_all_cached(hass, script_mode): assert len(script_obj._config_cache) == 2 +async def test_repeat_count(hass): + """Test repeat action w/ count option.""" + event = "test_event" + events = async_capture_events(hass, event) + count = 3 + + sequence = cv.SCRIPT_SCHEMA( + { + "repeat": { + "count": count, + "sequence": { + "event": event, + "event_data_template": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + "last": "{{ repeat.last }}", + }, + }, + } + } + ) + script_obj = script.Script(hass, sequence, script_mode="ignore") + + await script_obj.async_run() + await hass.async_block_till_done() + + assert len(events) == count + for index, event in enumerate(events): + assert event.data.get("first") == str(index == 0) + assert event.data.get("index") == str(index + 1) + assert event.data.get("last") == str(index == count - 1) + + +@pytest.mark.parametrize("condition", ["while", "until"]) +async def test_repeat_conditional(hass, condition): + """Test repeat action w/ while option.""" + event = "test_event" + events = async_capture_events(hass, event) + count = 3 + + sequence = { + "repeat": { + "sequence": [ + { + "event": event, + "event_data_template": { + "first": "{{ repeat.first }}", + "index": "{{ repeat.index }}", + }, + }, + {"wait_template": "{{ is_state('sensor.test', 'next') }}"}, + {"wait_template": "{{ not is_state('sensor.test', 'next') }}"}, + ], + } + } + if condition == "while": + sequence["repeat"]["while"] = { + "condition": "template", + "value_template": "{{ not is_state('sensor.test', 'done') }}", + } + else: + sequence["repeat"]["until"] = { + "condition": "template", + "value_template": "{{ is_state('sensor.test', 'done') }}", + } + script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore") + + wait_started = async_watch_for_action(script_obj, "wait") + hass.states.async_set("sensor.test", "1") + + hass.async_create_task(script_obj.async_run()) + try: + for index in range(2, count + 1): + await asyncio.wait_for(wait_started.wait(), 1) + wait_started.clear() + hass.states.async_set("sensor.test", "next") + await asyncio.wait_for(wait_started.wait(), 1) + wait_started.clear() + hass.states.async_set("sensor.test", index) + await asyncio.wait_for(wait_started.wait(), 1) + hass.states.async_set("sensor.test", "next") + await asyncio.wait_for(wait_started.wait(), 1) + wait_started.clear() + hass.states.async_set("sensor.test", "done") + await asyncio.wait_for(hass.async_block_till_done(), 1) + except asyncio.TimeoutError: + await script_obj.async_stop() + raise + + assert len(events) == count + for index, event in enumerate(events): + assert event.data.get("first") == str(index == 0) + assert event.data.get("index") == str(index + 1) + + @pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) async def test_last_triggered(hass, script_mode): """Test the last_triggered."""