Add new repeat loop for scripts and automations (#37589)

This commit is contained in:
Phil Bruckner 2020-07-10 13:37:19 -05:00 committed by GitHub
parent b187b17a4f
commit 91271f388c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 337 additions and 62 deletions

View file

@ -15,6 +15,7 @@ from homeassistant.helpers import condition, config_per_platform
from homeassistant.helpers.script import ( from homeassistant.helpers.script import (
SCRIPT_MODE_LEGACY, SCRIPT_MODE_LEGACY,
async_validate_action_config, async_validate_action_config,
validate_legacy_mode_actions,
warn_deprecated_legacy, warn_deprecated_legacy,
) )
from homeassistant.loader import IntegrationNotFound from homeassistant.loader import IntegrationNotFound
@ -55,6 +56,9 @@ async def async_validate_config_item(hass, config, full_config=None):
*[async_validate_action_config(hass, action) for action in config[CONF_ACTION]] *[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
) )
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_ACTION])
return config return config

View file

@ -12,6 +12,7 @@ from homeassistant.const import (
CONF_ICON, CONF_ICON,
CONF_MODE, CONF_MODE,
CONF_QUEUE_SIZE, CONF_QUEUE_SIZE,
CONF_SEQUENCE,
SERVICE_RELOAD, SERVICE_RELOAD,
SERVICE_TOGGLE, SERVICE_TOGGLE,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
@ -27,6 +28,7 @@ from homeassistant.helpers.script import (
SCRIPT_BASE_SCHEMA, SCRIPT_BASE_SCHEMA,
SCRIPT_MODE_LEGACY, SCRIPT_MODE_LEGACY,
Script, Script,
validate_legacy_mode_actions,
validate_queue_size, validate_queue_size,
warn_deprecated_legacy, warn_deprecated_legacy,
) )
@ -44,7 +46,6 @@ ATTR_VARIABLES = "variables"
CONF_DESCRIPTION = "description" CONF_DESCRIPTION = "description"
CONF_EXAMPLE = "example" CONF_EXAMPLE = "example"
CONF_FIELDS = "fields" CONF_FIELDS = "fields"
CONF_SEQUENCE = "sequence"
ENTITY_ID_FORMAT = DOMAIN + ".{}" ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -63,6 +64,13 @@ def _deprecated_legacy_mode(config):
return config return config
def _not_supported_in_legacy_mode(config):
if config.get(CONF_MODE, SCRIPT_MODE_LEGACY) == SCRIPT_MODE_LEGACY:
validate_legacy_mode_actions(config[CONF_SEQUENCE])
return config
SCRIPT_ENTRY_SCHEMA = vol.All( SCRIPT_ENTRY_SCHEMA = vol.All(
SCRIPT_BASE_SCHEMA.extend( SCRIPT_BASE_SCHEMA.extend(
{ {
@ -79,6 +87,7 @@ SCRIPT_ENTRY_SCHEMA = vol.All(
} }
), ),
validate_queue_size, validate_queue_size,
_not_supported_in_legacy_mode,
) )
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(

View file

@ -61,6 +61,7 @@ CONF_COMMAND_STATE = "command_state"
CONF_COMMAND_STOP = "command_stop" CONF_COMMAND_STOP = "command_stop"
CONF_CONDITION = "condition" CONF_CONDITION = "condition"
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout" CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
CONF_COUNT = "count"
CONF_COVERS = "covers" CONF_COVERS = "covers"
CONF_CURRENCY = "currency" CONF_CURRENCY = "currency"
CONF_CUSTOMIZE = "customize" CONF_CUSTOMIZE = "customize"
@ -139,6 +140,7 @@ CONF_QUOTE = "quote"
CONF_RADIUS = "radius" CONF_RADIUS = "radius"
CONF_RECIPIENT = "recipient" CONF_RECIPIENT = "recipient"
CONF_REGION = "region" CONF_REGION = "region"
CONF_REPEAT = "repeat"
CONF_RESOURCE = "resource" CONF_RESOURCE = "resource"
CONF_RESOURCES = "resources" CONF_RESOURCES = "resources"
CONF_RESOURCE_TEMPLATE = "resource_template" CONF_RESOURCE_TEMPLATE = "resource_template"
@ -149,6 +151,7 @@ CONF_SCENE = "scene"
CONF_SENDER = "sender" CONF_SENDER = "sender"
CONF_SENSORS = "sensors" CONF_SENSORS = "sensors"
CONF_SENSOR_TYPE = "sensor_type" CONF_SENSOR_TYPE = "sensor_type"
CONF_SEQUENCE = "sequence"
CONF_SERVICE = "service" CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data" CONF_SERVICE_DATA = "data"
CONF_SERVICE_TEMPLATE = "service_template" CONF_SERVICE_TEMPLATE = "service_template"
@ -169,6 +172,7 @@ CONF_TTL = "ttl"
CONF_TYPE = "type" CONF_TYPE = "type"
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement" CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
CONF_UNIT_SYSTEM = "unit_system" CONF_UNIT_SYSTEM = "unit_system"
CONF_UNTIL = "until"
CONF_URL = "url" CONF_URL = "url"
CONF_USERNAME = "username" CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template" CONF_VALUE_TEMPLATE = "value_template"
@ -176,6 +180,7 @@ CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_TEMPLATE = "wait_template" CONF_WAIT_TEMPLATE = "wait_template"
CONF_WEBHOOK_ID = "webhook_id" CONF_WEBHOOK_ID = "webhook_id"
CONF_WEEKDAY = "weekday" CONF_WEEKDAY = "weekday"
CONF_WHILE = "while"
CONF_WHITELIST = "whitelist" CONF_WHITELIST = "whitelist"
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs" CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
CONF_WHITE_VALUE = "white_value" CONF_WHITE_VALUE = "white_value"

View file

@ -40,6 +40,7 @@ from homeassistant.const import (
CONF_BELOW, CONF_BELOW,
CONF_CONDITION, CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT, CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DELAY, CONF_DELAY,
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_DOMAIN, CONF_DOMAIN,
@ -50,16 +51,20 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE, CONF_EVENT_DATA_TEMPLATE,
CONF_FOR, CONF_FOR,
CONF_PLATFORM, CONF_PLATFORM,
CONF_REPEAT,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE,
CONF_SERVICE, CONF_SERVICE,
CONF_SERVICE_TEMPLATE, CONF_SERVICE_TEMPLATE,
CONF_STATE, CONF_STATE,
CONF_TIMEOUT, CONF_TIMEOUT,
CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC, CONF_UNIT_SYSTEM_METRIC,
CONF_UNTIL,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
CONF_WAIT_TEMPLATE, CONF_WAIT_TEMPLATE,
CONF_WHILE,
ENTITY_MATCH_ALL, ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE, ENTITY_MATCH_NONE,
SUN_EVENT_SUNRISE, SUN_EVENT_SUNRISE,
@ -80,7 +85,6 @@ import homeassistant.util.dt as dt_util
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'" TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
# Home Assistant types # Home Assistant types
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255)) byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1)) small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
@ -817,6 +821,16 @@ def make_entity_service_schema(
) )
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
EVENT_SCHEMA = vol.Schema( EVENT_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
@ -998,6 +1012,25 @@ DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTR
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")}) _SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
_SCRIPT_REPEAT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_REPEAT): vol.All(
{
vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template),
vol.Exclusive(CONF_WHILE, "repeat"): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
vol.Exclusive(CONF_UNTIL, "repeat"): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
},
has_at_least_one_key(CONF_COUNT, CONF_WHILE, CONF_UNTIL),
),
}
)
SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition" SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1005,6 +1038,7 @@ SCRIPT_ACTION_FIRE_EVENT = "event"
SCRIPT_ACTION_CALL_SERVICE = "call_service" SCRIPT_ACTION_CALL_SERVICE = "call_service"
SCRIPT_ACTION_DEVICE_AUTOMATION = "device" SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene" SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
SCRIPT_ACTION_REPEAT = "repeat"
def determine_script_action(action: dict) -> str: def determine_script_action(action: dict) -> str:
@ -1027,6 +1061,9 @@ def determine_script_action(action: dict) -> str:
if CONF_SCENE in action: if CONF_SCENE in action:
return SCRIPT_ACTION_ACTIVATE_SCENE return SCRIPT_ACTION_ACTIVATE_SCENE
if CONF_REPEAT in action:
return SCRIPT_ACTION_REPEAT
return SCRIPT_ACTION_CALL_SERVICE return SCRIPT_ACTION_CALL_SERVICE
@ -1038,15 +1075,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA, SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA, SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
} }
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])

View file

@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
import asyncio import asyncio
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from itertools import islice from functools import partial
import itertools
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
@ -18,6 +19,7 @@ from homeassistant.const import (
CONF_ALIAS, CONF_ALIAS,
CONF_CONDITION, CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT, CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DELAY, CONF_DELAY,
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_DOMAIN, CONF_DOMAIN,
@ -26,9 +28,13 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE, CONF_EVENT_DATA_TEMPLATE,
CONF_MODE, CONF_MODE,
CONF_QUEUE_SIZE, CONF_QUEUE_SIZE,
CONF_REPEAT,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE,
CONF_TIMEOUT, CONF_TIMEOUT,
CONF_UNTIL,
CONF_WAIT_TEMPLATE, CONF_WAIT_TEMPLATE,
CONF_WHILE,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
SERVICE_TURN_ON, SERVICE_TURN_ON,
) )
@ -86,6 +92,10 @@ SCRIPT_BASE_SCHEMA = vol.Schema(
} }
) )
_UNSUPPORTED_IN_LEGACY = {
cv.SCRIPT_ACTION_REPEAT: CONF_REPEAT,
}
def warn_deprecated_legacy(logger, msg): def warn_deprecated_legacy(logger, msg):
"""Warn about deprecated legacy mode.""" """Warn about deprecated legacy mode."""
@ -99,6 +109,17 @@ def warn_deprecated_legacy(logger, msg):
) )
def validate_legacy_mode_actions(sequence):
"""Check for actions not supported in legacy mode."""
for action in sequence:
script_action = cv.determine_script_action(action)
if script_action in _UNSUPPORTED_IN_LEGACY:
raise vol.Invalid(
f"{_UNSUPPORTED_IN_LEGACY[script_action]} action not supported in "
f"{SCRIPT_MODE_LEGACY} mode"
)
def validate_queue_size(config): def validate_queue_size(config):
"""Validate queue_size option.""" """Validate queue_size option."""
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE) mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
@ -330,22 +351,29 @@ class _ScriptRunBase(ABC):
self._action[CONF_EVENT], event_data, context=self._context self._action[CONF_EVENT], event_data, context=self._context
) )
async def _async_get_condition(self, config):
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
async def _async_condition_step(self): async def _async_condition_step(self):
"""Test if condition is matching.""" """Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
config = self._config_cache.get(config_cache_key)
if not config:
config = await condition.async_from_config(self._hass, self._action, False)
self._config_cache[config_cache_key] = config
self._script.last_action = self._action.get( self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION] CONF_ALIAS, self._action[CONF_CONDITION]
) )
check = config(self._hass, self._variables) cond = await self._async_get_condition(self._action)
check = cond(self._hass, self._variables)
self._log("Test condition %s: %s", self._script.last_action, check) self._log("Test condition %s: %s", self._script.last_action, check)
if not check: if not check:
raise _StopScript raise _StopScript
@abstractmethod
async def _async_repeat_step(self):
"""Repeat a sequence."""
def _log(self, msg, *args, level=logging.INFO): def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access self._script._log(msg, *args, level=level) # pylint: disable=protected-access
@ -441,6 +469,41 @@ class _ScriptRun(_ScriptRunBase):
task.cancel() task.cancel()
unsub() unsub()
async def _async_run_long_action(self, long_task):
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task():
# Stop long task and wait for it to finish.
long_task.cancel()
try:
await long_task
except Exception: # pylint: disable=broad-except
pass
# Wait for long task while monitoring for a stop request.
stop_task = self._hass.async_create_task(self._stop.wait())
try:
await asyncio.wait(
{long_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel long task, too. Note that if long task
# is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_long_task()
raise
finally:
stop_task.cancel()
if long_task.cancelled():
raise asyncio.CancelledError
if long_task.done():
# Propagate any exceptions that occurred.
long_task.result()
else:
# Stopped before long task completed, so cancel it.
await async_cancel_long_task()
async def _async_call_service_step(self): async def _async_call_service_step(self):
"""Call the service specified in the action.""" """Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step() domain, service, service_data = self._prep_call_service_step()
@ -474,37 +537,71 @@ class _ScriptRun(_ScriptRunBase):
await service_task await service_task
return return
async def async_cancel_service_task(): await self._async_run_long_action(service_task)
# Stop service task and wait for it to finish.
service_task.cancel()
try:
await service_task
except Exception: # pylint: disable=broad-except
pass
# No call limit so watch for a stop request. async def _async_repeat_step(self):
stop_task = self._hass.async_create_task(self._stop.wait()) """Repeat a sequence."""
try:
await asyncio.wait( description = self._action.get(CONF_ALIAS, "sequence")
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED repeat = self._action[CONF_REPEAT]
async def async_run_sequence(iteration, extra_msg="", extra_vars=None):
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
if extra_vars:
repeat_vars["repeat"].update(extra_vars)
task = self._hass.async_create_task(
# pylint: disable=protected-access
self._script._repeat_script[self._step].async_run(
# Add repeat to variables. Override if it already exists in case of
# nested calls.
{**(self._variables or {}), **repeat_vars},
self._context,
)
) )
# If our task is cancelled, then cancel service task, too. Note that if service await self._async_run_long_action(task)
# task is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_service_task()
raise
finally:
stop_task.cancel()
if service_task.cancelled(): if CONF_COUNT in repeat:
raise asyncio.CancelledError count = repeat[CONF_COUNT]
if service_task.done(): if isinstance(count, template.Template):
# Propagate any exceptions that occurred. try:
service_task.result() count = int(count.async_render(self._variables))
elif running_script: except (exceptions.TemplateError, ValueError) as ex:
# Stopped before service completed, so cancel service. self._log(
await async_cancel_service_task() "Error rendering %s repeat count template: %s",
self._script.name,
ex,
level=logging.ERROR,
)
raise _StopScript
for iteration in range(1, count + 1):
await async_run_sequence(
iteration, f" of {count}", {"last": iteration == count}
)
if self._stop.is_set():
break
elif CONF_WHILE in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_WHILE]
]
for iteration in itertools.count(1):
if self._stop.is_set() or not all(
cond(self._hass, self._variables) for cond in conditions
):
break
await async_run_sequence(iteration)
elif CONF_UNTIL in repeat:
conditions = [
await self._async_get_condition(config) for config in repeat[CONF_UNTIL]
]
for iteration in itertools.count(1):
await async_run_sequence(iteration)
if self._stop.is_set() or all(
cond(self._hass, self._variables) for cond in conditions
):
break
class _QueuedScriptRun(_ScriptRun): class _QueuedScriptRun(_ScriptRun):
@ -599,7 +696,7 @@ class _LegacyScriptRun(_ScriptRunBase):
suspended = False suspended = False
try: try:
for self._step, self._action in islice( for self._step, self._action in itertools.islice(
enumerate(self._script.sequence), self._cur, None enumerate(self._script.sequence), self._cur, None
): ):
await self._async_step(log_exceptions=not propagate_exceptions) await self._async_step(log_exceptions=not propagate_exceptions)
@ -689,6 +786,10 @@ class _LegacyScriptRun(_ScriptRunBase):
*self._prep_call_service_step(), blocking=True, context=self._context *self._prep_call_service_step(), blocking=True, context=self._context
) )
async def _async_repeat_step(self):
"""Repeat a sequence."""
# Not supported in legacy mode.
def _async_remove_listener(self): def _async_remove_listener(self):
"""Remove listeners, if any.""" """Remove listeners, if any."""
for unsub in self._async_listener: for unsub in self._async_listener:
@ -733,6 +834,22 @@ class Script:
for action in self.sequence for action in self.sequence
) )
self._repeat_script = {}
for step, action in enumerate(sequence):
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
sub_script = Script(
hass,
action[CONF_REPEAT][CONF_SEQUENCE],
f"{name}: {step_name}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(
self._chain_change_listener, sub_script
)
self._repeat_script[step] = sub_script
self._runs: List[_ScriptRunBase] = [] self._runs: List[_ScriptRunBase] = []
if script_mode == SCRIPT_MODE_QUEUE: if script_mode == SCRIPT_MODE_QUEUE:
self._queue_size = queue_size self._queue_size = queue_size
@ -746,6 +863,11 @@ class Script:
if self.change_listener: if self.change_listener:
self._hass.async_run_job(self.change_listener) self._hass.async_run_job(self.change_listener)
def _chain_change_listener(self, sub_script):
if sub_script.is_running:
self.last_action = sub_script.last_action
self._changed()
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""Return true if script is on.""" """Return true if script is on."""

View file

@ -1072,20 +1072,27 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["entity_id"] == "automation.bye" assert event2["entity_id"] == "automation.bye"
async def test_invalid_config(hass): invalid_configs = [
"""Test invalid config.""" {
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
},
{
"mode": "legacy",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [{"repeat": {"count": 5, "sequence": []}}],
},
]
@pytest.mark.parametrize("value", invalid_configs)
async def test_invalid_configs(hass, value):
"""Test invalid configurations."""
with assert_setup_component(0, automation.DOMAIN): with assert_setup_component(0, automation.DOMAIN):
assert await async_setup_component( assert await async_setup_component(
hass, hass, automation.DOMAIN, {automation.DOMAIN: value}
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
}
},
) )

View file

@ -200,6 +200,12 @@ invalid_configs = [
{"test hello world": {"sequence": [{"event": "bla"}]}}, {"test hello world": {"sequence": [{"event": "bla"}]}},
{"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}}, {"test": {"sequence": {"event": "test_event", "service": "homeassistant.turn_on"}}},
{"test": {"sequence": [], "mode": "parallel", "queue_size": 5}}, {"test": {"sequence": [], "mode": "parallel", "queue_size": 5}},
{
"test": {
"mode": "legacy",
"sequence": [{"repeat": {"count": 5, "sequence": []}}],
}
},
] ]

View file

@ -822,6 +822,101 @@ async def test_condition_all_cached(hass, script_mode):
assert len(script_obj._config_cache) == 2 assert len(script_obj._config_cache) == 2
async def test_repeat_count(hass):
"""Test repeat action w/ count option."""
event = "test_event"
events = async_capture_events(hass, event)
count = 3
sequence = cv.SCRIPT_SCHEMA(
{
"repeat": {
"count": count,
"sequence": {
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
},
},
}
}
)
script_obj = script.Script(hass, sequence, script_mode="ignore")
await script_obj.async_run()
await hass.async_block_till_done()
assert len(events) == count
for index, event in enumerate(events):
assert event.data.get("first") == str(index == 0)
assert event.data.get("index") == str(index + 1)
assert event.data.get("last") == str(index == count - 1)
@pytest.mark.parametrize("condition", ["while", "until"])
async def test_repeat_conditional(hass, condition):
"""Test repeat action w/ while option."""
event = "test_event"
events = async_capture_events(hass, event)
count = 3
sequence = {
"repeat": {
"sequence": [
{
"event": event,
"event_data_template": {
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
},
},
{"wait_template": "{{ is_state('sensor.test', 'next') }}"},
{"wait_template": "{{ not is_state('sensor.test', 'next') }}"},
],
}
}
if condition == "while":
sequence["repeat"]["while"] = {
"condition": "template",
"value_template": "{{ not is_state('sensor.test', 'done') }}",
}
else:
sequence["repeat"]["until"] = {
"condition": "template",
"value_template": "{{ is_state('sensor.test', 'done') }}",
}
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA(sequence), script_mode="ignore")
wait_started = async_watch_for_action(script_obj, "wait")
hass.states.async_set("sensor.test", "1")
hass.async_create_task(script_obj.async_run())
try:
for index in range(2, count + 1):
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", "next")
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", index)
await asyncio.wait_for(wait_started.wait(), 1)
hass.states.async_set("sensor.test", "next")
await asyncio.wait_for(wait_started.wait(), 1)
wait_started.clear()
hass.states.async_set("sensor.test", "done")
await asyncio.wait_for(hass.async_block_till_done(), 1)
except asyncio.TimeoutError:
await script_obj.async_stop()
raise
assert len(events) == count
for index, event in enumerate(events):
assert event.data.get("first") == str(index == 0)
assert event.data.get("index") == str(index + 1)
@pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES) @pytest.mark.parametrize("script_mode", _BASIC_SCRIPT_MODES)
async def test_last_triggered(hass, script_mode): async def test_last_triggered(hass, script_mode):
"""Test the last_triggered.""" """Test the last_triggered."""