Add new repeat loop for scripts and automations (#37589)
This commit is contained in:
parent
b187b17a4f
commit
91271f388c
8 changed files with 337 additions and 62 deletions
|
@ -15,6 +15,7 @@ from homeassistant.helpers import condition, config_per_platform
|
||||||
from homeassistant.helpers.script import (
|
from homeassistant.helpers.script import (
|
||||||
SCRIPT_MODE_LEGACY,
|
SCRIPT_MODE_LEGACY,
|
||||||
async_validate_action_config,
|
async_validate_action_config,
|
||||||
|
validate_legacy_mode_actions,
|
||||||
warn_deprecated_legacy,
|
warn_deprecated_legacy,
|
||||||
)
|
)
|
||||||
from homeassistant.loader import IntegrationNotFound
|
from homeassistant.loader import IntegrationNotFound
|
||||||
|
@ -55,6 +56,9 @@ async def async_validate_config_item(hass, config, full_config=None):
|
||||||
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
|
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||||
|
validate_legacy_mode_actions(config[CONF_ACTION])
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from homeassistant.const import (
|
||||||
CONF_ICON,
|
CONF_ICON,
|
||||||
CONF_MODE,
|
CONF_MODE,
|
||||||
CONF_QUEUE_SIZE,
|
CONF_QUEUE_SIZE,
|
||||||
|
CONF_SEQUENCE,
|
||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
SERVICE_TOGGLE,
|
SERVICE_TOGGLE,
|
||||||
SERVICE_TURN_OFF,
|
SERVICE_TURN_OFF,
|
||||||
|
@ -27,6 +28,7 @@ from homeassistant.helpers.script import (
|
||||||
SCRIPT_BASE_SCHEMA,
|
SCRIPT_BASE_SCHEMA,
|
||||||
SCRIPT_MODE_LEGACY,
|
SCRIPT_MODE_LEGACY,
|
||||||
Script,
|
Script,
|
||||||
|
validate_legacy_mode_actions,
|
||||||
validate_queue_size,
|
validate_queue_size,
|
||||||
warn_deprecated_legacy,
|
warn_deprecated_legacy,
|
||||||
)
|
)
|
||||||
|
@ -44,7 +46,6 @@ ATTR_VARIABLES = "variables"
|
||||||
CONF_DESCRIPTION = "description"
|
CONF_DESCRIPTION = "description"
|
||||||
CONF_EXAMPLE = "example"
|
CONF_EXAMPLE = "example"
|
||||||
CONF_FIELDS = "fields"
|
CONF_FIELDS = "fields"
|
||||||
CONF_SEQUENCE = "sequence"
|
|
||||||
|
|
||||||
ENTITY_ID_FORMAT = DOMAIN + ".{}"
|
ENTITY_ID_FORMAT = DOMAIN + ".{}"
|
||||||
|
|
||||||
|
@ -63,6 +64,13 @@ def _deprecated_legacy_mode(config):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _not_supported_in_legacy_mode(config):
|
||||||
|
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||||
|
validate_legacy_mode_actions(config[CONF_SEQUENCE])
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
SCRIPT_ENTRY_SCHEMA = vol.All(
|
SCRIPT_ENTRY_SCHEMA = vol.All(
|
||||||
SCRIPT_BASE_SCHEMA.extend(
|
SCRIPT_BASE_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
|
@ -79,6 +87,7 @@ SCRIPT_ENTRY_SCHEMA = vol.All(
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
validate_queue_size,
|
validate_queue_size,
|
||||||
|
_not_supported_in_legacy_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema(
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
|
|
|
@ -61,6 +61,7 @@ CONF_COMMAND_STATE = "command_state"
|
||||||
CONF_COMMAND_STOP = "command_stop"
|
CONF_COMMAND_STOP = "command_stop"
|
||||||
CONF_CONDITION = "condition"
|
CONF_CONDITION = "condition"
|
||||||
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
|
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
|
||||||
|
CONF_COUNT = "count"
|
||||||
CONF_COVERS = "covers"
|
CONF_COVERS = "covers"
|
||||||
CONF_CURRENCY = "currency"
|
CONF_CURRENCY = "currency"
|
||||||
CONF_CUSTOMIZE = "customize"
|
CONF_CUSTOMIZE = "customize"
|
||||||
|
@ -139,6 +140,7 @@ CONF_QUOTE = "quote"
|
||||||
CONF_RADIUS = "radius"
|
CONF_RADIUS = "radius"
|
||||||
CONF_RECIPIENT = "recipient"
|
CONF_RECIPIENT = "recipient"
|
||||||
CONF_REGION = "region"
|
CONF_REGION = "region"
|
||||||
|
CONF_REPEAT = "repeat"
|
||||||
CONF_RESOURCE = "resource"
|
CONF_RESOURCE = "resource"
|
||||||
CONF_RESOURCES = "resources"
|
CONF_RESOURCES = "resources"
|
||||||
CONF_RESOURCE_TEMPLATE = "resource_template"
|
CONF_RESOURCE_TEMPLATE = "resource_template"
|
||||||
|
@ -149,6 +151,7 @@ CONF_SCENE = "scene"
|
||||||
CONF_SENDER = "sender"
|
CONF_SENDER = "sender"
|
||||||
CONF_SENSORS = "sensors"
|
CONF_SENSORS = "sensors"
|
||||||
CONF_SENSOR_TYPE = "sensor_type"
|
CONF_SENSOR_TYPE = "sensor_type"
|
||||||
|
CONF_SEQUENCE = "sequence"
|
||||||
CONF_SERVICE = "service"
|
CONF_SERVICE = "service"
|
||||||
CONF_SERVICE_DATA = "data"
|
CONF_SERVICE_DATA = "data"
|
||||||
CONF_SERVICE_TEMPLATE = "service_template"
|
CONF_SERVICE_TEMPLATE = "service_template"
|
||||||
|
@ -169,6 +172,7 @@ CONF_TTL = "ttl"
|
||||||
CONF_TYPE = "type"
|
CONF_TYPE = "type"
|
||||||
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
|
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
|
||||||
CONF_UNIT_SYSTEM = "unit_system"
|
CONF_UNIT_SYSTEM = "unit_system"
|
||||||
|
CONF_UNTIL = "until"
|
||||||
CONF_URL = "url"
|
CONF_URL = "url"
|
||||||
CONF_USERNAME = "username"
|
CONF_USERNAME = "username"
|
||||||
CONF_VALUE_TEMPLATE = "value_template"
|
CONF_VALUE_TEMPLATE = "value_template"
|
||||||
|
@ -176,6 +180,7 @@ CONF_VERIFY_SSL = "verify_ssl"
|
||||||
CONF_WAIT_TEMPLATE = "wait_template"
|
CONF_WAIT_TEMPLATE = "wait_template"
|
||||||
CONF_WEBHOOK_ID = "webhook_id"
|
CONF_WEBHOOK_ID = "webhook_id"
|
||||||
CONF_WEEKDAY = "weekday"
|
CONF_WEEKDAY = "weekday"
|
||||||
|
CONF_WHILE = "while"
|
||||||
CONF_WHITELIST = "whitelist"
|
CONF_WHITELIST = "whitelist"
|
||||||
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
|
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
|
||||||
CONF_WHITE_VALUE = "white_value"
|
CONF_WHITE_VALUE = "white_value"
|
||||||
|
|
|
@ -40,6 +40,7 @@ from homeassistant.const import (
|
||||||
CONF_BELOW,
|
CONF_BELOW,
|
||||||
CONF_CONDITION,
|
CONF_CONDITION,
|
||||||
CONF_CONTINUE_ON_TIMEOUT,
|
CONF_CONTINUE_ON_TIMEOUT,
|
||||||
|
CONF_COUNT,
|
||||||
CONF_DELAY,
|
CONF_DELAY,
|
||||||
CONF_DEVICE_ID,
|
CONF_DEVICE_ID,
|
||||||
CONF_DOMAIN,
|
CONF_DOMAIN,
|
||||||
|
@ -50,16 +51,20 @@ from homeassistant.const import (
|
||||||
CONF_EVENT_DATA_TEMPLATE,
|
CONF_EVENT_DATA_TEMPLATE,
|
||||||
CONF_FOR,
|
CONF_FOR,
|
||||||
CONF_PLATFORM,
|
CONF_PLATFORM,
|
||||||
|
CONF_REPEAT,
|
||||||
CONF_SCAN_INTERVAL,
|
CONF_SCAN_INTERVAL,
|
||||||
CONF_SCENE,
|
CONF_SCENE,
|
||||||
|
CONF_SEQUENCE,
|
||||||
CONF_SERVICE,
|
CONF_SERVICE,
|
||||||
CONF_SERVICE_TEMPLATE,
|
CONF_SERVICE_TEMPLATE,
|
||||||
CONF_STATE,
|
CONF_STATE,
|
||||||
CONF_TIMEOUT,
|
CONF_TIMEOUT,
|
||||||
CONF_UNIT_SYSTEM_IMPERIAL,
|
CONF_UNIT_SYSTEM_IMPERIAL,
|
||||||
CONF_UNIT_SYSTEM_METRIC,
|
CONF_UNIT_SYSTEM_METRIC,
|
||||||
|
CONF_UNTIL,
|
||||||
CONF_VALUE_TEMPLATE,
|
CONF_VALUE_TEMPLATE,
|
||||||
CONF_WAIT_TEMPLATE,
|
CONF_WAIT_TEMPLATE,
|
||||||
|
CONF_WHILE,
|
||||||
ENTITY_MATCH_ALL,
|
ENTITY_MATCH_ALL,
|
||||||
ENTITY_MATCH_NONE,
|
ENTITY_MATCH_NONE,
|
||||||
SUN_EVENT_SUNRISE,
|
SUN_EVENT_SUNRISE,
|
||||||
|
@ -80,7 +85,6 @@ import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
|
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
|
||||||
|
|
||||||
|
|
||||||
# Home Assistant types
|
# Home Assistant types
|
||||||
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
|
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
|
||||||
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
|
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
|
||||||
|
@ -817,6 +821,16 @@ def make_entity_service_schema(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def script_action(value: Any) -> dict:
|
||||||
|
"""Validate a script action."""
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise vol.Invalid("expected dictionary")
|
||||||
|
|
||||||
|
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
|
||||||
|
|
||||||
|
|
||||||
|
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
|
||||||
|
|
||||||
EVENT_SCHEMA = vol.Schema(
|
EVENT_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Optional(CONF_ALIAS): string,
|
vol.Optional(CONF_ALIAS): string,
|
||||||
|
@ -998,6 +1012,25 @@ DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTR
|
||||||
|
|
||||||
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
|
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
|
||||||
|
|
||||||
|
_SCRIPT_REPEAT_SCHEMA = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional(CONF_ALIAS): string,
|
||||||
|
vol.Required(CONF_REPEAT): vol.All(
|
||||||
|
{
|
||||||
|
vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template),
|
||||||
|
vol.Exclusive(CONF_WHILE, "repeat"): vol.All(
|
||||||
|
ensure_list, [CONDITION_SCHEMA]
|
||||||
|
),
|
||||||
|
vol.Exclusive(CONF_UNTIL, "repeat"): vol.All(
|
||||||
|
ensure_list, [CONDITION_SCHEMA]
|
||||||
|
),
|
||||||
|
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
|
||||||
|
},
|
||||||
|
has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
SCRIPT_ACTION_DELAY = "delay"
|
SCRIPT_ACTION_DELAY = "delay"
|
||||||
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
|
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
|
||||||
SCRIPT_ACTION_CHECK_CONDITION = "condition"
|
SCRIPT_ACTION_CHECK_CONDITION = "condition"
|
||||||
|
@ -1005,6 +1038,7 @@ SCRIPT_ACTION_FIRE_EVENT = "event"
|
||||||
SCRIPT_ACTION_CALL_SERVICE = "call_service"
|
SCRIPT_ACTION_CALL_SERVICE = "call_service"
|
||||||
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
|
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
|
||||||
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
|
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
|
||||||
|
SCRIPT_ACTION_REPEAT = "repeat"
|
||||||
|
|
||||||
|
|
||||||
def determine_script_action(action: dict) -> str:
|
def determine_script_action(action: dict) -> str:
|
||||||
|
@ -1027,6 +1061,9 @@ def determine_script_action(action: dict) -> str:
|
||||||
if CONF_SCENE in action:
|
if CONF_SCENE in action:
|
||||||
return SCRIPT_ACTION_ACTIVATE_SCENE
|
return SCRIPT_ACTION_ACTIVATE_SCENE
|
||||||
|
|
||||||
|
if CONF_REPEAT in action:
|
||||||
|
return SCRIPT_ACTION_REPEAT
|
||||||
|
|
||||||
return SCRIPT_ACTION_CALL_SERVICE
|
return SCRIPT_ACTION_CALL_SERVICE
|
||||||
|
|
||||||
|
|
||||||
|
@ -1038,15 +1075,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
|
||||||
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
|
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
|
||||||
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
|
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
|
||||||
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
|
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
|
||||||
|
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def script_action(value: Any) -> dict:
|
|
||||||
"""Validate a script action."""
|
|
||||||
if not isinstance(value, dict):
|
|
||||||
raise vol.Invalid("expected dictionary")
|
|
||||||
|
|
||||||
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
|
|
||||||
|
|
||||||
|
|
||||||
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
|
|
||||||
|
|
|
@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import islice
|
from functools import partial
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||||
|
|
||||||
|
@ -18,6 +19,7 @@ from homeassistant.const import (
|
||||||
CONF_ALIAS,
|
CONF_ALIAS,
|
||||||
CONF_CONDITION,
|
CONF_CONDITION,
|
||||||
CONF_CONTINUE_ON_TIMEOUT,
|
CONF_CONTINUE_ON_TIMEOUT,
|
||||||
|
CONF_COUNT,
|
||||||
CONF_DELAY,
|
CONF_DELAY,
|
||||||
CONF_DEVICE_ID,
|
CONF_DEVICE_ID,
|
||||||
CONF_DOMAIN,
|
CONF_DOMAIN,
|
||||||
|
@ -26,9 +28,13 @@ from homeassistant.const import (
|
||||||
CONF_EVENT_DATA_TEMPLATE,
|
CONF_EVENT_DATA_TEMPLATE,
|
||||||
CONF_MODE,
|
CONF_MODE,
|
||||||
CONF_QUEUE_SIZE,
|
CONF_QUEUE_SIZE,
|
||||||
|
CONF_REPEAT,
|
||||||
CONF_SCENE,
|
CONF_SCENE,
|
||||||
|
CONF_SEQUENCE,
|
||||||
CONF_TIMEOUT,
|
CONF_TIMEOUT,
|
||||||
|
CONF_UNTIL,
|
||||||
CONF_WAIT_TEMPLATE,
|
CONF_WAIT_TEMPLATE,
|
||||||
|
CONF_WHILE,
|
||||||
SERVICE_TURN_OFF,
|
SERVICE_TURN_OFF,
|
||||||
SERVICE_TURN_ON,
|
SERVICE_TURN_ON,
|
||||||
)
|
)
|
||||||
|
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_UNSUPPORTED_IN_LEGACY = {
|
||||||
|
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def warn_deprecated_legacy(logger, msg):
|
def warn_deprecated_legacy(logger, msg):
|
||||||
"""Warn about deprecated legacy mode."""
|
"""Warn about deprecated legacy mode."""
|
||||||
|
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_legacy_mode_actions(sequence):
|
||||||
|
"""Check for actions not supported in legacy mode."""
|
||||||
|
for action in sequence:
|
||||||
|
script_action = cv.determine_script_action(action)
|
||||||
|
if script_action in _UNSUPPORTED_IN_LEGACY:
|
||||||
|
raise vol.Invalid(
|
||||||
|
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
|
||||||
|
f"{SCRIPT_MODE_LEGACY} mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_queue_size(config):
|
def validate_queue_size(config):
|
||||||
"""Validate queue_size option."""
|
"""Validate queue_size option."""
|
||||||
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
|
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
|
||||||
|
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
|
||||||
self._action[CONF_EVENT], event_data, context=self._context
|
self._action[CONF_EVENT], event_data, context=self._context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_get_condition(self, config):
|
||||||
|
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||||
|
cond = self._config_cache.get(config_cache_key)
|
||||||
|
if not cond:
|
||||||
|
cond = await condition.async_from_config(self._hass, config, False)
|
||||||
|
self._config_cache[config_cache_key] = cond
|
||||||
|
return cond
|
||||||
|
|
||||||
async def _async_condition_step(self):
|
async def _async_condition_step(self):
|
||||||
"""Test if condition is matching."""
|
"""Test if condition is matching."""
|
||||||
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
|
|
||||||
config = self._config_cache.get(config_cache_key)
|
|
||||||
if not config:
|
|
||||||
config = await condition.async_from_config(self._hass, self._action, False)
|
|
||||||
self._config_cache[config_cache_key] = config
|
|
||||||
|
|
||||||
self._script.last_action = self._action.get(
|
self._script.last_action = self._action.get(
|
||||||
CONF_ALIAS, self._action[CONF_CONDITION]
|
CONF_ALIAS, self._action[CONF_CONDITION]
|
||||||
)
|
)
|
||||||
check = config(self._hass, self._variables)
|
cond = await self._async_get_condition(self._action)
|
||||||
|
check = cond(self._hass, self._variables)
|
||||||
self._log("Test condition %s: %s", self._script.last_action, check)
|
self._log("Test condition %s: %s", self._script.last_action, check)
|
||||||
if not check:
|
if not check:
|
||||||
raise _StopScript
|
raise _StopScript
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _async_repeat_step(self):
|
||||||
|
"""Repeat a sequence."""
|
||||||
|
|
||||||
def _log(self, msg, *args, level=logging.INFO):
|
def _log(self, msg, *args, level=logging.INFO):
|
||||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
|
async def _async_run_long_action(self, long_task):
|
||||||
|
"""Run a long task while monitoring for stop request."""
|
||||||
|
|
||||||
|
async def async_cancel_long_task():
|
||||||
|
# Stop long task and wait for it to finish.
|
||||||
|
long_task.cancel()
|
||||||
|
try:
|
||||||
|
await long_task
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for long task while monitoring for a stop request.
|
||||||
|
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||||
|
try:
|
||||||
|
await asyncio.wait(
|
||||||
|
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||||
|
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||||
|
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await async_cancel_long_task()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
stop_task.cancel()
|
||||||
|
|
||||||
|
if long_task.cancelled():
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
if long_task.done():
|
||||||
|
# Propagate any exceptions that occurred.
|
||||||
|
long_task.result()
|
||||||
|
else:
|
||||||
|
# Stopped before long task completed, so cancel it.
|
||||||
|
await async_cancel_long_task()
|
||||||
|
|
||||||
async def _async_call_service_step(self):
|
async def _async_call_service_step(self):
|
||||||
"""Call the service specified in the action."""
|
"""Call the service specified in the action."""
|
||||||
domain, service, service_data = self._prep_call_service_step()
|
domain, service, service_data = self._prep_call_service_step()
|
||||||
|
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
|
||||||
await service_task
|
await service_task
|
||||||
return
|
return
|
||||||
|
|
||||||
async def async_cancel_service_task():
|
await self._async_run_long_action(service_task)
|
||||||
# Stop service task and wait for it to finish.
|
|
||||||
service_task.cancel()
|
|
||||||
try:
|
|
||||||
await service_task
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
pass
|
|
||||||
|
|
||||||
# No call limit so watch for a stop request.
|
async def _async_repeat_step(self):
|
||||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
"""Repeat a sequence."""
|
||||||
try:
|
|
||||||
await asyncio.wait(
|
description = self._action.get(CONF_ALIAS, "sequence")
|
||||||
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
repeat = self._action[CONF_REPEAT]
|
||||||
|
|
||||||
|
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
|
||||||
|
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
||||||
|
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
||||||
|
if extra_vars:
|
||||||
|
repeat_vars["repeat"].update(extra_vars)
|
||||||
|
task = self._hass.async_create_task(
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self._script._repeat_script[self._step].async_run(
|
||||||
|
# Add repeat to variables. Override if it already exists in case of
|
||||||
|
# nested calls.
|
||||||
|
{**(self._variables or {}), **repeat_vars},
|
||||||
|
self._context,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# If our task is cancelled, then cancel service task, too. Note that if service
|
await self._async_run_long_action(task)
|
||||||
# task is cancelled otherwise the CancelledError exception will not be raised to
|
|
||||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
await async_cancel_service_task()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
stop_task.cancel()
|
|
||||||
|
|
||||||
if service_task.cancelled():
|
if CONF_COUNT in repeat:
|
||||||
raise asyncio.CancelledError
|
count = repeat[CONF_COUNT]
|
||||||
if service_task.done():
|
if isinstance(count, template.Template):
|
||||||
# Propagate any exceptions that occurred.
|
try:
|
||||||
service_task.result()
|
count = int(count.async_render(self._variables))
|
||||||
elif running_script:
|
except (exceptions.TemplateError, ValueError) as ex:
|
||||||
# Stopped before service completed, so cancel service.
|
self._log(
|
||||||
await async_cancel_service_task()
|
"Error rendering %s repeat count template: %s",
|
||||||
|
self._script.name,
|
||||||
|
ex,
|
||||||
|
level=logging.ERROR,
|
||||||
|
)
|
||||||
|
raise _StopScript
|
||||||
|
for iteration in range(1, count + 1):
|
||||||
|
await async_run_sequence(
|
||||||
|
iteration, f" of {count}", {"last": iteration == count}
|
||||||
|
)
|
||||||
|
if self._stop.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
elif CONF_WHILE in repeat:
|
||||||
|
conditions = [
|
||||||
|
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
|
||||||
|
]
|
||||||
|
for iteration in itertools.count(1):
|
||||||
|
if self._stop.is_set() or not all(
|
||||||
|
cond(self._hass, self._variables) for cond in conditions
|
||||||
|
):
|
||||||
|
break
|
||||||
|
await async_run_sequence(iteration)
|
||||||
|
|
||||||
|
elif CONF_UNTIL in repeat:
|
||||||
|
conditions = [
|
||||||
|
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
|
||||||
|
]
|
||||||
|
for iteration in itertools.count(1):
|
||||||
|
await async_run_sequence(iteration)
|
||||||
|
if self._stop.is_set() or all(
|
||||||
|
cond(self._hass, self._variables) for cond in conditions
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
class _QueuedScriptRun(_ScriptRun):
|
class _QueuedScriptRun(_ScriptRun):
|
||||||
|
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||||
|
|
||||||
suspended = False
|
suspended = False
|
||||||
try:
|
try:
|
||||||
for self._step, self._action in islice(
|
for self._step, self._action in itertools.islice(
|
||||||
enumerate(self._script.sequence), self._cur, None
|
enumerate(self._script.sequence), self._cur, None
|
||||||
):
|
):
|
||||||
await self._async_step(log_exceptions=not propagate_exceptions)
|
await self._async_step(log_exceptions=not propagate_exceptions)
|
||||||
|
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
|
||||||
*self._prep_call_service_step(), blocking=True, context=self._context
|
*self._prep_call_service_step(), blocking=True, context=self._context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_repeat_step(self):
|
||||||
|
"""Repeat a sequence."""
|
||||||
|
# Not supported in legacy mode.
|
||||||
|
|
||||||
def _async_remove_listener(self):
|
def _async_remove_listener(self):
|
||||||
"""Remove listeners, if any."""
|
"""Remove listeners, if any."""
|
||||||
for unsub in self._async_listener:
|
for unsub in self._async_listener:
|
||||||
|
@ -733,6 +834,22 @@ class Script:
|
||||||
for action in self.sequence
|
for action in self.sequence
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._repeat_script = {}
|
||||||
|
for step, action in enumerate(sequence):
|
||||||
|
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
|
||||||
|
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
|
||||||
|
sub_script = Script(
|
||||||
|
hass,
|
||||||
|
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||||
|
f"{name}: {step_name}",
|
||||||
|
script_mode=SCRIPT_MODE_PARALLEL,
|
||||||
|
logger=self._logger,
|
||||||
|
)
|
||||||
|
sub_script.change_listener = partial(
|
||||||
|
self._chain_change_listener, sub_script
|
||||||
|
)
|
||||||
|
self._repeat_script[step] = sub_script
|
||||||
|
|
||||||
self._runs: List[_ScriptRunBase] = []
|
self._runs: List[_ScriptRunBase] = []
|
||||||
if script_mode == SCRIPT_MODE_QUEUE:
|
if script_mode == SCRIPT_MODE_QUEUE:
|
||||||
self._queue_size = queue_size
|
self._queue_size = queue_size
|
||||||
|
@ -746,6 +863,11 @@ class Script:
|
||||||
if self.change_listener:
|
if self.change_listener:
|
||||||
self._hass.async_run_job(self.change_listener)
|
self._hass.async_run_job(self.change_listener)
|
||||||
|
|
||||||
|
def _chain_change_listener(self, sub_script):
|
||||||
|
if sub_script.is_running:
|
||||||
|
self.last_action = sub_script.last_action
|
||||||
|
self._changed()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""Return true if script is on."""
|
"""Return true if script is on."""
|
||||||
|
|
|
@ -1072,20 +1072,27 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
||||||
assert event2["entity_id"] == "automation.bye"
|
assert event2["entity_id"] == "automation.bye"
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_config(hass):
|
invalid_configs = [
|
||||||
"""Test invalid config."""
|
{
|
||||||
|
"mode": "parallel",
|
||||||
|
"queue_size": 5,
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
|
"action": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"mode": "legacy",
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||||
|
"action": [{"repeat": {"count": 5, "sequence": []}}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", invalid_configs)
|
||||||
|
async def test_invalid_configs(hass, value):
|
||||||
|
"""Test invalid configurations."""
|
||||||
with assert_setup_component(0, automation.DOMAIN):
|
with assert_setup_component(0, automation.DOMAIN):
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
hass,
|
hass, automation.DOMAIN, {automation.DOMAIN: value}
|
||||||
automation.DOMAIN,
|
|
||||||
{
|
|
||||||
automation.DOMAIN: {
|
|
||||||
"mode": "parallel",
|
|
||||||
"queue_size": 5,
|
|
||||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
|
||||||
"action": [],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -200,6 +200,12 @@ invalid_configs = [
|
||||||
{"test hello world": {"sequence": [{"event": "bla"}]}},
|
{"test hello world": {"sequence": [{"event": "bla"}]}},
|
||||||
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
|
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
|
||||||
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
|
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
|
||||||
|
{
|
||||||
|
"test": {
|
||||||
|
"mode": "legacy",
|
||||||
|
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
|
||||||
|
}
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -822,6 +822,101 @@ async def test_condition_all_cached(hass, script_mode):
|
||||||
assert len(script_obj._config_cache) == 2
|
assert len(script_obj._config_cache) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_repeat_count(hass):
|
||||||
|
"""Test repeat action w/ count option."""
|
||||||
|
event = "test_event"
|
||||||
|
events = async_capture_events(hass, event)
|
||||||
|
count = 3
|
||||||
|
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(
|
||||||
|
{
|
||||||
|
"repeat": {
|
||||||
|
"count": count,
|
||||||
|
"sequence": {
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"first": "{{ repeat.first }}",
|
||||||
|
"index": "{{ repeat.index }}",
|
||||||
|
"last": "{{ repeat.last }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
script_obj = script.Script(hass, sequence, script_mode="ignore")
|
||||||
|
|
||||||
|
await script_obj.async_run()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(events) == count
|
||||||
|
for index, event in enumerate(events):
|
||||||
|
assert event.data.get("first") == str(index == 0)
|
||||||
|
assert event.data.get("index") == str(index + 1)
|
||||||
|
assert event.data.get("last") == str(index == count - 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("condition", ["while", "until"])
|
||||||
|
async def test_repeat_conditional(hass, condition):
|
||||||
|
"""Test repeat action w/ while option."""
|
||||||
|
event = "test_event"
|
||||||
|
events = async_capture_events(hass, event)
|
||||||
|
count = 3
|
||||||
|
|
||||||
|
sequence = {
|
||||||
|
"repeat": {
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"event": event,
|
||||||
|
"event_data_template": {
|
||||||
|
"first": "{{ repeat.first }}",
|
||||||
|
"index": "{{ repeat.index }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"wait_template": "{{ is_state('sensor.test', 'next') }}"},
|
||||||
|
{"wait_template": "{{ not is_state('sensor.test', 'next') }}"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if condition == "while":
|
||||||
|
sequence["repeat"]["while"] = {
|
||||||
|
"condition": "template",
|
||||||
|
"value_template": "{{ not is_state('sensor.test', 'done') }}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
sequence["repeat"]["until"] = {
|
||||||
|
"condition": "template",
|
||||||
|
"value_template": "{{ is_state('sensor.test', 'done') }}",
|
||||||
|
}
|
||||||
|
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
|
||||||
|
|
||||||
|
wait_started = async_watch_for_action(script_obj, "wait")
|
||||||
|
hass.states.async_set("sensor.test", "1")
|
||||||
|
|
||||||
|
hass.async_create_task(script_obj.async_run())
|
||||||
|
try:
|
||||||
|
for index in range(2, count + 1):
|
||||||
|
await asyncio.wait_for(wait_started.wait(), 1)
|
||||||
|
wait_started.clear()
|
||||||
|
hass.states.async_set("sensor.test", "next")
|
||||||
|
await asyncio.wait_for(wait_started.wait(), 1)
|
||||||
|
wait_started.clear()
|
||||||
|
hass.states.async_set("sensor.test", index)
|
||||||
|
await asyncio.wait_for(wait_started.wait(), 1)
|
||||||
|
hass.states.async_set("sensor.test", "next")
|
||||||
|
await asyncio.wait_for(wait_started.wait(), 1)
|
||||||
|
wait_started.clear()
|
||||||
|
hass.states.async_set("sensor.test", "done")
|
||||||
|
await asyncio.wait_for(hass.async_block_till_done(), 1)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await script_obj.async_stop()
|
||||||
|
raise
|
||||||
|
|
||||||
|
assert len(events) == count
|
||||||
|
for index, event in enumerate(events):
|
||||||
|
assert event.data.get("first") == str(index == 0)
|
||||||
|
assert event.data.get("index") == str(index + 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||||
async def test_last_triggered(hass, script_mode):
|
async def test_last_triggered(hass, script_mode):
|
||||||
"""Test the last_triggered."""
|
"""Test the last_triggered."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue