Add new repeat loop for scripts and automations (#37589)
This commit is contained in:
parent
b187b17a4f
commit
91271f388c
8 changed files with 337 additions and 62 deletions
|
@ -15,6 +15,7 @@ from homeassistant.helpers import condition, config_per_platform
|
|||
from homeassistant.helpers.script import (
|
||||
SCRIPT_MODE_LEGACY,
|
||||
async_validate_action_config,
|
||||
validate_legacy_mode_actions,
|
||||
warn_deprecated_legacy,
|
||||
)
|
||||
from homeassistant.loader import IntegrationNotFound
|
||||
|
@ -55,6 +56,9 @@ async def async_validate_config_item(hass, config, full_config=None):
|
|||
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
|
||||
)
|
||||
|
||||
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||
validate_legacy_mode_actions(config[CONF_ACTION])
|
||||
|
||||
return config
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from homeassistant.const import (
|
|||
CONF_ICON,
|
||||
CONF_MODE,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_SEQUENCE,
|
||||
SERVICE_RELOAD,
|
||||
SERVICE_TOGGLE,
|
||||
SERVICE_TURN_OFF,
|
||||
|
@ -27,6 +28,7 @@ from homeassistant.helpers.script import (
|
|||
SCRIPT_BASE_SCHEMA,
|
||||
SCRIPT_MODE_LEGACY,
|
||||
Script,
|
||||
validate_legacy_mode_actions,
|
||||
validate_queue_size,
|
||||
warn_deprecated_legacy,
|
||||
)
|
||||
|
@ -44,7 +46,6 @@ ATTR_VARIABLES = "variables"
|
|||
CONF_DESCRIPTION = "description"
|
||||
CONF_EXAMPLE = "example"
|
||||
CONF_FIELDS = "fields"
|
||||
CONF_SEQUENCE = "sequence"
|
||||
|
||||
ENTITY_ID_FORMAT = DOMAIN + ".{}"
|
||||
|
||||
|
@ -63,6 +64,13 @@ def _deprecated_legacy_mode(config):
|
|||
return config
|
||||
|
||||
|
||||
def _not_supported_in_legacy_mode(config):
|
||||
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
|
||||
validate_legacy_mode_actions(config[CONF_SEQUENCE])
|
||||
|
||||
return config
|
||||
|
||||
|
||||
SCRIPT_ENTRY_SCHEMA = vol.All(
|
||||
SCRIPT_BASE_SCHEMA.extend(
|
||||
{
|
||||
|
@ -79,6 +87,7 @@ SCRIPT_ENTRY_SCHEMA = vol.All(
|
|||
}
|
||||
),
|
||||
validate_queue_size,
|
||||
_not_supported_in_legacy_mode,
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
|
|
|
@ -61,6 +61,7 @@ CONF_COMMAND_STATE = "command_state"
|
|||
CONF_COMMAND_STOP = "command_stop"
|
||||
CONF_CONDITION = "condition"
|
||||
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
|
||||
CONF_COUNT = "count"
|
||||
CONF_COVERS = "covers"
|
||||
CONF_CURRENCY = "currency"
|
||||
CONF_CUSTOMIZE = "customize"
|
||||
|
@ -139,6 +140,7 @@ CONF_QUOTE = "quote"
|
|||
CONF_RADIUS = "radius"
|
||||
CONF_RECIPIENT = "recipient"
|
||||
CONF_REGION = "region"
|
||||
CONF_REPEAT = "repeat"
|
||||
CONF_RESOURCE = "resource"
|
||||
CONF_RESOURCES = "resources"
|
||||
CONF_RESOURCE_TEMPLATE = "resource_template"
|
||||
|
@ -149,6 +151,7 @@ CONF_SCENE = "scene"
|
|||
CONF_SENDER = "sender"
|
||||
CONF_SENSORS = "sensors"
|
||||
CONF_SENSOR_TYPE = "sensor_type"
|
||||
CONF_SEQUENCE = "sequence"
|
||||
CONF_SERVICE = "service"
|
||||
CONF_SERVICE_DATA = "data"
|
||||
CONF_SERVICE_TEMPLATE = "service_template"
|
||||
|
@ -169,6 +172,7 @@ CONF_TTL = "ttl"
|
|||
CONF_TYPE = "type"
|
||||
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
|
||||
CONF_UNIT_SYSTEM = "unit_system"
|
||||
CONF_UNTIL = "until"
|
||||
CONF_URL = "url"
|
||||
CONF_USERNAME = "username"
|
||||
CONF_VALUE_TEMPLATE = "value_template"
|
||||
|
@ -176,6 +180,7 @@ CONF_VERIFY_SSL = "verify_ssl"
|
|||
CONF_WAIT_TEMPLATE = "wait_template"
|
||||
CONF_WEBHOOK_ID = "webhook_id"
|
||||
CONF_WEEKDAY = "weekday"
|
||||
CONF_WHILE = "while"
|
||||
CONF_WHITELIST = "whitelist"
|
||||
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
|
||||
CONF_WHITE_VALUE = "white_value"
|
||||
|
|
|
@ -40,6 +40,7 @@ from homeassistant.const import (
|
|||
CONF_BELOW,
|
||||
CONF_CONDITION,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -50,16 +51,20 @@ from homeassistant.const import (
|
|||
CONF_EVENT_DATA_TEMPLATE,
|
||||
CONF_FOR,
|
||||
CONF_PLATFORM,
|
||||
CONF_REPEAT,
|
||||
CONF_SCAN_INTERVAL,
|
||||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
CONF_SERVICE,
|
||||
CONF_SERVICE_TEMPLATE,
|
||||
CONF_STATE,
|
||||
CONF_TIMEOUT,
|
||||
CONF_UNIT_SYSTEM_IMPERIAL,
|
||||
CONF_UNIT_SYSTEM_METRIC,
|
||||
CONF_UNTIL,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
CONF_WAIT_TEMPLATE,
|
||||
CONF_WHILE,
|
||||
ENTITY_MATCH_ALL,
|
||||
ENTITY_MATCH_NONE,
|
||||
SUN_EVENT_SUNRISE,
|
||||
|
@ -80,7 +85,6 @@ import homeassistant.util.dt as dt_util
|
|||
|
||||
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
|
||||
|
||||
|
||||
# Home Assistant types
|
||||
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
|
||||
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
|
||||
|
@ -817,6 +821,16 @@ def make_entity_service_schema(
|
|||
)
|
||||
|
||||
|
||||
def script_action(value: Any) -> dict:
|
||||
"""Validate a script action."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("expected dictionary")
|
||||
|
||||
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
|
||||
|
||||
|
||||
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
|
||||
|
||||
EVENT_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_ALIAS): string,
|
||||
|
@ -998,6 +1012,25 @@ DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTR
|
|||
|
||||
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
|
||||
|
||||
_SCRIPT_REPEAT_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_ALIAS): string,
|
||||
vol.Required(CONF_REPEAT): vol.All(
|
||||
{
|
||||
vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template),
|
||||
vol.Exclusive(CONF_WHILE, "repeat"): vol.All(
|
||||
ensure_list, [CONDITION_SCHEMA]
|
||||
),
|
||||
vol.Exclusive(CONF_UNTIL, "repeat"): vol.All(
|
||||
ensure_list, [CONDITION_SCHEMA]
|
||||
),
|
||||
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
|
||||
},
|
||||
has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL),
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
SCRIPT_ACTION_DELAY = "delay"
|
||||
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
|
||||
SCRIPT_ACTION_CHECK_CONDITION = "condition"
|
||||
|
@ -1005,6 +1038,7 @@ SCRIPT_ACTION_FIRE_EVENT = "event"
|
|||
SCRIPT_ACTION_CALL_SERVICE = "call_service"
|
||||
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
|
||||
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
|
||||
SCRIPT_ACTION_REPEAT = "repeat"
|
||||
|
||||
|
||||
def determine_script_action(action: dict) -> str:
|
||||
|
@ -1027,6 +1061,9 @@ def determine_script_action(action: dict) -> str:
|
|||
if CONF_SCENE in action:
|
||||
return SCRIPT_ACTION_ACTIVATE_SCENE
|
||||
|
||||
if CONF_REPEAT in action:
|
||||
return SCRIPT_ACTION_REPEAT
|
||||
|
||||
return SCRIPT_ACTION_CALL_SERVICE
|
||||
|
||||
|
||||
|
@ -1038,15 +1075,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
|
|||
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
|
||||
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
|
||||
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
|
||||
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
|
||||
}
|
||||
|
||||
|
||||
def script_action(value: Any) -> dict:
|
||||
"""Validate a script action."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("expected dictionary")
|
||||
|
||||
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
|
||||
|
||||
|
||||
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
|
||||
|
|
|
@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
|
|||
import asyncio
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from itertools import islice
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
|
||||
|
@ -18,6 +19,7 @@ from homeassistant.const import (
|
|||
CONF_ALIAS,
|
||||
CONF_CONDITION,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -26,9 +28,13 @@ from homeassistant.const import (
|
|||
CONF_EVENT_DATA_TEMPLATE,
|
||||
CONF_MODE,
|
||||
CONF_QUEUE_SIZE,
|
||||
CONF_REPEAT,
|
||||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
CONF_TIMEOUT,
|
||||
CONF_UNTIL,
|
||||
CONF_WAIT_TEMPLATE,
|
||||
CONF_WHILE,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
|
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
_UNSUPPORTED_IN_LEGACY = {
|
||||
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
|
||||
}
|
||||
|
||||
|
||||
def warn_deprecated_legacy(logger, msg):
|
||||
"""Warn about deprecated legacy mode."""
|
||||
|
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
|
|||
)
|
||||
|
||||
|
||||
def validate_legacy_mode_actions(sequence):
|
||||
"""Check for actions not supported in legacy mode."""
|
||||
for action in sequence:
|
||||
script_action = cv.determine_script_action(action)
|
||||
if script_action in _UNSUPPORTED_IN_LEGACY:
|
||||
raise vol.Invalid(
|
||||
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
|
||||
f"{SCRIPT_MODE_LEGACY} mode"
|
||||
)
|
||||
|
||||
|
||||
def validate_queue_size(config):
|
||||
"""Validate queue_size option."""
|
||||
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
|
||||
|
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
|
|||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
async def _async_condition_step(self):
|
||||
"""Test if condition is matching."""
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
|
||||
config = self._config_cache.get(config_cache_key)
|
||||
if not config:
|
||||
config = await condition.async_from_config(self._hass, self._action, False)
|
||||
self._config_cache[config_cache_key] = config
|
||||
|
||||
self._script.last_action = self._action.get(
|
||||
CONF_ALIAS, self._action[CONF_CONDITION]
|
||||
)
|
||||
check = config(self._hass, self._variables)
|
||||
cond = await self._async_get_condition(self._action)
|
||||
check = cond(self._hass, self._variables)
|
||||
self._log("Test condition %s: %s", self._script.last_action, check)
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
@abstractmethod
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
|
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
|
|||
task.cancel()
|
||||
unsub()
|
||||
|
||||
async def _async_run_long_action(self, long_task):
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task():
|
||||
# Stop long task and wait for it to finish.
|
||||
long_task.cancel()
|
||||
try:
|
||||
await long_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
# Wait for long task while monitoring for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
# If our task is cancelled, then cancel long task, too. Note that if long task
|
||||
# is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_long_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
|
||||
if long_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if long_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
long_task.result()
|
||||
else:
|
||||
# Stopped before long task completed, so cancel it.
|
||||
await async_cancel_long_task()
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
domain, service, service_data = self._prep_call_service_step()
|
||||
|
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
|
|||
await service_task
|
||||
return
|
||||
|
||||
async def async_cancel_service_task():
|
||||
# Stop service task and wait for it to finish.
|
||||
service_task.cancel()
|
||||
try:
|
||||
await service_task
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
await self._async_run_long_action(service_task)
|
||||
|
||||
# No call limit so watch for a stop request.
|
||||
stop_task = self._hass.async_create_task(self._stop.wait())
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
|
||||
description = self._action.get(CONF_ALIAS, "sequence")
|
||||
repeat = self._action[CONF_REPEAT]
|
||||
|
||||
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
|
||||
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
||||
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
||||
if extra_vars:
|
||||
repeat_vars["repeat"].update(extra_vars)
|
||||
task = self._hass.async_create_task(
|
||||
# pylint: disable=protected-access
|
||||
self._script._repeat_script[self._step].async_run(
|
||||
# Add repeat to variables. Override if it already exists in case of
|
||||
# nested calls.
|
||||
{**(self._variables or {}), **repeat_vars},
|
||||
self._context,
|
||||
)
|
||||
)
|
||||
# If our task is cancelled, then cancel service task, too. Note that if service
|
||||
# task is cancelled otherwise the CancelledError exception will not be raised to
|
||||
# here due to the call to asyncio.wait(). Rather we'll check for that below.
|
||||
except asyncio.CancelledError:
|
||||
await async_cancel_service_task()
|
||||
raise
|
||||
finally:
|
||||
stop_task.cancel()
|
||||
await self._async_run_long_action(task)
|
||||
|
||||
if service_task.cancelled():
|
||||
raise asyncio.CancelledError
|
||||
if service_task.done():
|
||||
# Propagate any exceptions that occurred.
|
||||
service_task.result()
|
||||
elif running_script:
|
||||
# Stopped before service completed, so cancel service.
|
||||
await async_cancel_service_task()
|
||||
if CONF_COUNT in repeat:
|
||||
count = repeat[CONF_COUNT]
|
||||
if isinstance(count, template.Template):
|
||||
try:
|
||||
count = int(count.async_render(self._variables))
|
||||
except (exceptions.TemplateError, ValueError) as ex:
|
||||
self._log(
|
||||
"Error rendering %s repeat count template: %s",
|
||||
self._script.name,
|
||||
ex,
|
||||
level=logging.ERROR,
|
||||
)
|
||||
raise _StopScript
|
||||
for iteration in range(1, count + 1):
|
||||
await async_run_sequence(
|
||||
iteration, f" of {count}", {"last": iteration == count}
|
||||
)
|
||||
if self._stop.is_set():
|
||||
break
|
||||
|
||||
elif CONF_WHILE in repeat:
|
||||
conditions = [
|
||||
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
|
||||
]
|
||||
for iteration in itertools.count(1):
|
||||
if self._stop.is_set() or not all(
|
||||
cond(self._hass, self._variables) for cond in conditions
|
||||
):
|
||||
break
|
||||
await async_run_sequence(iteration)
|
||||
|
||||
elif CONF_UNTIL in repeat:
|
||||
conditions = [
|
||||
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
|
||||
]
|
||||
for iteration in itertools.count(1):
|
||||
await async_run_sequence(iteration)
|
||||
if self._stop.is_set() or all(
|
||||
cond(self._hass, self._variables) for cond in conditions
|
||||
):
|
||||
break
|
||||
|
||||
|
||||
class _QueuedScriptRun(_ScriptRun):
|
||||
|
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
|
|||
|
||||
suspended = False
|
||||
try:
|
||||
for self._step, self._action in islice(
|
||||
for self._step, self._action in itertools.islice(
|
||||
enumerate(self._script.sequence), self._cur, None
|
||||
):
|
||||
await self._async_step(log_exceptions=not propagate_exceptions)
|
||||
|
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
|
|||
*self._prep_call_service_step(), blocking=True, context=self._context
|
||||
)
|
||||
|
||||
async def _async_repeat_step(self):
|
||||
"""Repeat a sequence."""
|
||||
# Not supported in legacy mode.
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove listeners, if any."""
|
||||
for unsub in self._async_listener:
|
||||
|
@ -733,6 +834,22 @@ class Script:
|
|||
for action in self.sequence
|
||||
)
|
||||
|
||||
self._repeat_script = {}
|
||||
for step, action in enumerate(sequence):
|
||||
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
|
||||
sub_script = Script(
|
||||
hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
self._repeat_script[step] = sub_script
|
||||
|
||||
self._runs: List[_ScriptRunBase] = []
|
||||
if script_mode == SCRIPT_MODE_QUEUE:
|
||||
self._queue_size = queue_size
|
||||
|
@ -746,6 +863,11 @@ class Script:
|
|||
if self.change_listener:
|
||||
self._hass.async_run_job(self.change_listener)
|
||||
|
||||
def _chain_change_listener(self, sub_script):
|
||||
if sub_script.is_running:
|
||||
self.last_action = sub_script.last_action
|
||||
self._changed()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Return true if script is on."""
|
||||
|
|
|
@ -1072,20 +1072,27 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
|||
assert event2["entity_id"] == "automation.bye"
|
||||
|
||||
|
||||
async def test_invalid_config(hass):
|
||||
"""Test invalid config."""
|
||||
invalid_configs = [
|
||||
{
|
||||
"mode": "parallel",
|
||||
"queue_size": 5,
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [],
|
||||
},
|
||||
{
|
||||
"mode": "legacy",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [{"repeat": {"count": 5, "sequence": []}}],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", invalid_configs)
|
||||
async def test_invalid_configs(hass, value):
|
||||
"""Test invalid configurations."""
|
||||
with assert_setup_component(0, automation.DOMAIN):
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"mode": "parallel",
|
||||
"queue_size": 5,
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [],
|
||||
}
|
||||
},
|
||||
hass, automation.DOMAIN, {automation.DOMAIN: value}
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -200,6 +200,12 @@ invalid_configs = [
|
|||
{"test hello world": {"sequence": [{"event": "bla"}]}},
|
||||
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
|
||||
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
|
||||
{
|
||||
"test": {
|
||||
"mode": "legacy",
|
||||
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -822,6 +822,101 @@ async def test_condition_all_cached(hass, script_mode):
|
|||
assert len(script_obj._config_cache) == 2
|
||||
|
||||
|
||||
async def test_repeat_count(hass):
|
||||
"""Test repeat action w/ count option."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
count = 3
|
||||
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{
|
||||
"repeat": {
|
||||
"count": count,
|
||||
"sequence": {
|
||||
"event": event,
|
||||
"event_data_template": {
|
||||
"first": "{{ repeat.first }}",
|
||||
"index": "{{ repeat.index }}",
|
||||
"last": "{{ repeat.last }}",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, script_mode="ignore")
|
||||
|
||||
await script_obj.async_run()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == count
|
||||
for index, event in enumerate(events):
|
||||
assert event.data.get("first") == str(index == 0)
|
||||
assert event.data.get("index") == str(index + 1)
|
||||
assert event.data.get("last") == str(index == count - 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("condition", ["while", "until"])
|
||||
async def test_repeat_conditional(hass, condition):
|
||||
"""Test repeat action w/ while option."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
count = 3
|
||||
|
||||
sequence = {
|
||||
"repeat": {
|
||||
"sequence": [
|
||||
{
|
||||
"event": event,
|
||||
"event_data_template": {
|
||||
"first": "{{ repeat.first }}",
|
||||
"index": "{{ repeat.index }}",
|
||||
},
|
||||
},
|
||||
{"wait_template": "{{ is_state('sensor.test', 'next') }}"},
|
||||
{"wait_template": "{{ not is_state('sensor.test', 'next') }}"},
|
||||
],
|
||||
}
|
||||
}
|
||||
if condition == "while":
|
||||
sequence["repeat"]["while"] = {
|
||||
"condition": "template",
|
||||
"value_template": "{{ not is_state('sensor.test', 'done') }}",
|
||||
}
|
||||
else:
|
||||
sequence["repeat"]["until"] = {
|
||||
"condition": "template",
|
||||
"value_template": "{{ is_state('sensor.test', 'done') }}",
|
||||
}
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
|
||||
|
||||
wait_started = async_watch_for_action(script_obj, "wait")
|
||||
hass.states.async_set("sensor.test", "1")
|
||||
|
||||
hass.async_create_task(script_obj.async_run())
|
||||
try:
|
||||
for index in range(2, count + 1):
|
||||
await asyncio.wait_for(wait_started.wait(), 1)
|
||||
wait_started.clear()
|
||||
hass.states.async_set("sensor.test", "next")
|
||||
await asyncio.wait_for(wait_started.wait(), 1)
|
||||
wait_started.clear()
|
||||
hass.states.async_set("sensor.test", index)
|
||||
await asyncio.wait_for(wait_started.wait(), 1)
|
||||
hass.states.async_set("sensor.test", "next")
|
||||
await asyncio.wait_for(wait_started.wait(), 1)
|
||||
wait_started.clear()
|
||||
hass.states.async_set("sensor.test", "done")
|
||||
await asyncio.wait_for(hass.async_block_till_done(), 1)
|
||||
except asyncio.TimeoutError:
|
||||
await script_obj.async_stop()
|
||||
raise
|
||||
|
||||
assert len(events) == count
|
||||
for index, event in enumerate(events):
|
||||
assert event.data.get("first") == str(index == 0)
|
||||
assert event.data.get("index") == str(index + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
|
||||
async def test_last_triggered(hass, script_mode):
|
||||
"""Test the last_triggered."""
|
||||
|
|
Loading…
Add table
Reference in a new issue