Add choose script action (#37818)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
515ad6164d
commit
7e280e2b27
4 changed files with 176 additions and 39 deletions
|
@ -47,6 +47,7 @@ CONF_BINARY_SENSORS = "binary_sensors"
|
|||
CONF_BRIGHTNESS = "brightness"
|
||||
CONF_BROADCAST_ADDRESS = "broadcast_address"
|
||||
CONF_BROADCAST_PORT = "broadcast_port"
|
||||
CONF_CHOOSE = "choose"
|
||||
CONF_CLIENT_ID = "client_id"
|
||||
CONF_CLIENT_SECRET = "client_secret"
|
||||
CONF_CODE = "code"
|
||||
|
@ -59,6 +60,7 @@ CONF_COMMAND_OPEN = "command_open"
|
|||
CONF_COMMAND_STATE = "command_state"
|
||||
CONF_COMMAND_STOP = "command_stop"
|
||||
CONF_CONDITION = "condition"
|
||||
CONF_CONDITIONS = "conditions"
|
||||
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
|
||||
CONF_COUNT = "count"
|
||||
CONF_COVERS = "covers"
|
||||
|
@ -66,6 +68,7 @@ CONF_CURRENCY = "currency"
|
|||
CONF_CUSTOMIZE = "customize"
|
||||
CONF_CUSTOMIZE_DOMAIN = "customize_domain"
|
||||
CONF_CUSTOMIZE_GLOB = "customize_glob"
|
||||
CONF_DEFAULT = "default"
|
||||
CONF_DELAY = "delay"
|
||||
CONF_DELAY_TIME = "delay_time"
|
||||
CONF_DEVICE = "device"
|
||||
|
|
|
@ -38,9 +38,12 @@ from homeassistant.const import (
|
|||
CONF_ABOVE,
|
||||
CONF_ALIAS,
|
||||
CONF_BELOW,
|
||||
CONF_CHOOSE,
|
||||
CONF_CONDITION,
|
||||
CONF_CONDITIONS,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DEFAULT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -930,7 +933,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema(
|
|||
AND_CONDITION_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_CONDITION): "and",
|
||||
vol.Required("conditions"): vol.All(
|
||||
vol.Required(CONF_CONDITIONS): vol.All(
|
||||
ensure_list,
|
||||
# pylint: disable=unnecessary-lambda
|
||||
[lambda value: CONDITION_SCHEMA(value)],
|
||||
|
@ -941,7 +944,7 @@ AND_CONDITION_SCHEMA = vol.Schema(
|
|||
OR_CONDITION_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_CONDITION): "or",
|
||||
vol.Required("conditions"): vol.All(
|
||||
vol.Required(CONF_CONDITIONS): vol.All(
|
||||
ensure_list,
|
||||
# pylint: disable=unnecessary-lambda
|
||||
[lambda value: CONDITION_SCHEMA(value)],
|
||||
|
@ -952,7 +955,7 @@ OR_CONDITION_SCHEMA = vol.Schema(
|
|||
NOT_CONDITION_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_CONDITION): "not",
|
||||
vol.Required("conditions"): vol.All(
|
||||
vol.Required(CONF_CONDITIONS): vol.All(
|
||||
ensure_list,
|
||||
# pylint: disable=unnecessary-lambda
|
||||
[lambda value: CONDITION_SCHEMA(value)],
|
||||
|
@ -1031,6 +1034,24 @@ _SCRIPT_REPEAT_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
_SCRIPT_CHOOSE_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_ALIAS): string,
|
||||
vol.Required(CONF_CHOOSE): vol.All(
|
||||
ensure_list,
|
||||
[
|
||||
{
|
||||
vol.Required(CONF_CONDITIONS): vol.All(
|
||||
ensure_list, [CONDITION_SCHEMA]
|
||||
),
|
||||
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
|
||||
}
|
||||
],
|
||||
),
|
||||
vol.Optional(CONF_DEFAULT): SCRIPT_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
SCRIPT_ACTION_DELAY = "delay"
|
||||
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
|
||||
SCRIPT_ACTION_CHECK_CONDITION = "condition"
|
||||
|
@ -1039,6 +1060,7 @@ SCRIPT_ACTION_CALL_SERVICE = "call_service"
|
|||
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
|
||||
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
|
||||
SCRIPT_ACTION_REPEAT = "repeat"
|
||||
SCRIPT_ACTION_CHOOSE = "choose"
|
||||
|
||||
|
||||
def determine_script_action(action: dict) -> str:
|
||||
|
@ -1064,6 +1086,9 @@ def determine_script_action(action: dict) -> str:
|
|||
if CONF_REPEAT in action:
|
||||
return SCRIPT_ACTION_REPEAT
|
||||
|
||||
if CONF_CHOOSE in action:
|
||||
return SCRIPT_ACTION_CHOOSE
|
||||
|
||||
return SCRIPT_ACTION_CALL_SERVICE
|
||||
|
||||
|
||||
|
@ -1076,4 +1101,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
|
|||
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
|
||||
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
|
||||
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
|
||||
SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA,
|
||||
}
|
||||
|
|
|
@ -15,9 +15,12 @@ import homeassistant.components.scene as scene
|
|||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_ALIAS,
|
||||
CONF_CHOOSE,
|
||||
CONF_CONDITION,
|
||||
CONF_CONDITIONS,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DEFAULT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -138,9 +141,9 @@ class _ScriptRun:
|
|||
if not self._stop.is_set():
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _config_cache(self):
|
||||
return self._script._config_cache # pylint: disable=protected-access
|
||||
async def _async_get_condition(self, config):
|
||||
# pylint: disable=protected-access
|
||||
return await self._script._async_get_condition(config)
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
@ -404,14 +407,6 @@ class _ScriptRun:
|
|||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
async def _async_condition_step(self):
|
||||
"""Test if condition is matching."""
|
||||
self._script.last_action = self._action.get(
|
||||
|
@ -434,16 +429,13 @@ class _ScriptRun:
|
|||
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
||||
if extra_vars:
|
||||
repeat_vars["repeat"].update(extra_vars)
|
||||
task = self._hass.async_create_task(
|
||||
# pylint: disable=protected-access
|
||||
self._script._repeat_script[self._step].async_run(
|
||||
await self._async_run_script(
|
||||
self._script._get_repeat_script(self._step),
|
||||
# Add repeat to variables. Override if it already exists in case of
|
||||
# nested calls.
|
||||
{**(self._variables or {}), **repeat_vars},
|
||||
self._context,
|
||||
)
|
||||
)
|
||||
await self._async_run_long_action(task)
|
||||
|
||||
if CONF_COUNT in repeat:
|
||||
count = repeat[CONF_COUNT]
|
||||
|
@ -487,6 +479,27 @@ class _ScriptRun:
|
|||
):
|
||||
break
|
||||
|
||||
async def _async_choose_step(self):
|
||||
"""Choose a sequence."""
|
||||
# pylint: disable=protected-access
|
||||
choose_data = await self._script._async_get_choose_data(self._step)
|
||||
|
||||
for conditions, script in choose_data["choices"]:
|
||||
if all(condition(self._hass, self._variables) for condition in conditions):
|
||||
await self._async_run_script(script)
|
||||
return
|
||||
|
||||
if choose_data["default"]:
|
||||
await self._async_run_script(choose_data["default"])
|
||||
|
||||
async def _async_run_script(self, script, variables=None):
|
||||
"""Execute a script."""
|
||||
await self._async_run_long_action(
|
||||
self._hass.async_create_task(
|
||||
script.async_run(variables or self._variables, self._context)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class _QueuedScriptRun(_ScriptRun):
|
||||
"""Manage queued Script sequence run."""
|
||||
|
@ -562,27 +575,15 @@ class Script:
|
|||
self.last_triggered: Optional[datetime] = None
|
||||
self.can_cancel = True
|
||||
|
||||
self._repeat_script = {}
|
||||
for step, action in enumerate(sequence):
|
||||
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
|
||||
sub_script = Script(
|
||||
hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
self._repeat_script[step] = sub_script
|
||||
|
||||
self._runs: List[_ScriptRun] = []
|
||||
self._max_runs = max_runs
|
||||
if script_mode == SCRIPT_MODE_QUEUED:
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
|
||||
self._repeat_script: Dict[int, Script] = {}
|
||||
self._choose_data: Dict[
|
||||
int, List[Tuple[List[Callable[[HomeAssistant, Dict], bool]], Script]]
|
||||
] = {}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
|
||||
|
@ -701,6 +702,78 @@ class Script:
|
|||
if self.is_running:
|
||||
await asyncio.shield(self._async_stop(update_state))
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
def _prep_repeat_script(self, step):
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
|
||||
sub_script = Script(
|
||||
self._hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{self.name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
|
||||
return sub_script
|
||||
|
||||
def _get_repeat_script(self, step):
|
||||
sub_script = self._repeat_script.get(step)
|
||||
if not sub_script:
|
||||
sub_script = self._prep_repeat_script(step)
|
||||
self._repeat_script[step] = sub_script
|
||||
return sub_script
|
||||
|
||||
async def _async_prep_choose_data(self, step):
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
|
||||
choices = []
|
||||
for idx, choice in enumerate(action[CONF_CHOOSE], start=1):
|
||||
conditions = [
|
||||
await self._async_get_condition(config)
|
||||
for config in choice.get(CONF_CONDITIONS, [])
|
||||
]
|
||||
sub_script = Script(
|
||||
self._hass,
|
||||
choice[CONF_SEQUENCE],
|
||||
f"{self.name}: {step_name}: choice {idx}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
choices.append((conditions, sub_script))
|
||||
|
||||
if CONF_DEFAULT in action:
|
||||
default_script = Script(
|
||||
self._hass,
|
||||
action[CONF_DEFAULT],
|
||||
f"{self.name}: {step_name}: default",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
default_script.change_listener = partial(
|
||||
self._chain_change_listener, default_script
|
||||
)
|
||||
else:
|
||||
default_script = None
|
||||
|
||||
return {"choices": choices, "default": default_script}
|
||||
|
||||
async def _async_get_choose_data(self, step):
|
||||
choose_data = self._choose_data.get(step)
|
||||
if not choose_data:
|
||||
choose_data = await self._async_prep_choose_data(step)
|
||||
self._choose_data[step] = choose_data
|
||||
return choose_data
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
if self.name:
|
||||
msg = f"%s: {msg}"
|
||||
|
|
|
@ -877,6 +877,41 @@ async def test_repeat_conditional(hass, condition):
|
|||
assert event.data.get("index") == str(index + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")])
|
||||
async def test_choose(hass, var, result):
|
||||
"""Test choose action."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{
|
||||
"choose": [
|
||||
{
|
||||
"conditions": {
|
||||
"condition": "template",
|
||||
"value_template": "{{ var == 1 }}",
|
||||
},
|
||||
"sequence": {"event": event, "event_data": {"choice": "first"}},
|
||||
},
|
||||
{
|
||||
"conditions": {
|
||||
"condition": "template",
|
||||
"value_template": "{{ var == 2 }}",
|
||||
},
|
||||
"sequence": {"event": event, "event_data": {"choice": "second"}},
|
||||
},
|
||||
],
|
||||
"default": {"event": event, "event_data": {"choice": "default"}},
|
||||
}
|
||||
)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
await script_obj.async_run({"var": var})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data["choice"] == result
|
||||
|
||||
|
||||
async def test_last_triggered(hass):
|
||||
"""Test the last_triggered."""
|
||||
event = "test_event"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue