Add choose script action (#37818)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Phil Bruckner 2020-07-14 12:22:54 -05:00 committed by GitHub
parent 515ad6164d
commit 7e280e2b27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 176 additions and 39 deletions

View file

@ -47,6 +47,7 @@ CONF_BINARY_SENSORS = "binary_sensors"
CONF_BRIGHTNESS = "brightness"
CONF_BROADCAST_ADDRESS = "broadcast_address"
CONF_BROADCAST_PORT = "broadcast_port"
CONF_CHOOSE = "choose"
CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = "client_secret"
CONF_CODE = "code"
@ -59,6 +60,7 @@ CONF_COMMAND_OPEN = "command_open"
CONF_COMMAND_STATE = "command_state"
CONF_COMMAND_STOP = "command_stop"
CONF_CONDITION = "condition"
CONF_CONDITIONS = "conditions"
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
CONF_COUNT = "count"
CONF_COVERS = "covers"
@ -66,6 +68,7 @@ CONF_CURRENCY = "currency"
CONF_CUSTOMIZE = "customize"
CONF_CUSTOMIZE_DOMAIN = "customize_domain"
CONF_CUSTOMIZE_GLOB = "customize_glob"
CONF_DEFAULT = "default"
CONF_DELAY = "delay"
CONF_DELAY_TIME = "delay_time"
CONF_DEVICE = "device"

View file

@ -38,9 +38,12 @@ from homeassistant.const import (
CONF_ABOVE,
CONF_ALIAS,
CONF_BELOW,
CONF_CHOOSE,
CONF_CONDITION,
CONF_CONDITIONS,
CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DEFAULT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -930,7 +933,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema(
AND_CONDITION_SCHEMA = vol.Schema(
{
vol.Required(CONF_CONDITION): "and",
vol.Required("conditions"): vol.All(
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
# pylint: disable=unnecessary-lambda
[lambda value: CONDITION_SCHEMA(value)],
@ -941,7 +944,7 @@ AND_CONDITION_SCHEMA = vol.Schema(
OR_CONDITION_SCHEMA = vol.Schema(
{
vol.Required(CONF_CONDITION): "or",
vol.Required("conditions"): vol.All(
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
# pylint: disable=unnecessary-lambda
[lambda value: CONDITION_SCHEMA(value)],
@ -952,7 +955,7 @@ OR_CONDITION_SCHEMA = vol.Schema(
NOT_CONDITION_SCHEMA = vol.Schema(
{
vol.Required(CONF_CONDITION): "not",
vol.Required("conditions"): vol.All(
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
# pylint: disable=unnecessary-lambda
[lambda value: CONDITION_SCHEMA(value)],
@ -1031,6 +1034,24 @@ _SCRIPT_REPEAT_SCHEMA = vol.Schema(
}
)
_SCRIPT_CHOOSE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_CHOOSE): vol.All(
ensure_list,
[
{
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
}
],
),
vol.Optional(CONF_DEFAULT): SCRIPT_SCHEMA,
}
)
SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1039,6 +1060,7 @@ SCRIPT_ACTION_CALL_SERVICE = "call_service"
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
SCRIPT_ACTION_REPEAT = "repeat"
SCRIPT_ACTION_CHOOSE = "choose"
def determine_script_action(action: dict) -> str:
@ -1064,6 +1086,9 @@ def determine_script_action(action: dict) -> str:
if CONF_REPEAT in action:
return SCRIPT_ACTION_REPEAT
if CONF_CHOOSE in action:
return SCRIPT_ACTION_CHOOSE
return SCRIPT_ACTION_CALL_SERVICE
@ -1076,4 +1101,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA,
}

View file

@ -15,9 +15,12 @@ import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_ALIAS,
CONF_CHOOSE,
CONF_CONDITION,
CONF_CONDITIONS,
CONF_CONTINUE_ON_TIMEOUT,
CONF_COUNT,
CONF_DEFAULT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -138,9 +141,9 @@ class _ScriptRun:
if not self._stop.is_set():
self._script._changed() # pylint: disable=protected-access
@property
def _config_cache(self):
return self._script._config_cache # pylint: disable=protected-access
async def _async_get_condition(self, config):
# pylint: disable=protected-access
return await self._script._async_get_condition(config)
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
@ -404,14 +407,6 @@ class _ScriptRun:
self._action[CONF_EVENT], event_data, context=self._context
)
async def _async_get_condition(self, config):
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
async def _async_condition_step(self):
"""Test if condition is matching."""
self._script.last_action = self._action.get(
@ -434,16 +429,13 @@ class _ScriptRun:
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
if extra_vars:
repeat_vars["repeat"].update(extra_vars)
task = self._hass.async_create_task(
# pylint: disable=protected-access
self._script._repeat_script[self._step].async_run(
await self._async_run_script(
self._script._get_repeat_script(self._step),
# Add repeat to variables. Override if it already exists in case of
# nested calls.
{**(self._variables or {}), **repeat_vars},
self._context,
)
)
await self._async_run_long_action(task)
if CONF_COUNT in repeat:
count = repeat[CONF_COUNT]
@ -487,6 +479,27 @@ class _ScriptRun:
):
break
async def _async_choose_step(self):
"""Choose a sequence."""
# pylint: disable=protected-access
choose_data = await self._script._async_get_choose_data(self._step)
for conditions, script in choose_data["choices"]:
if all(condition(self._hass, self._variables) for condition in conditions):
await self._async_run_script(script)
return
if choose_data["default"]:
await self._async_run_script(choose_data["default"])
async def _async_run_script(self, script, variables=None):
"""Execute a script."""
await self._async_run_long_action(
self._hass.async_create_task(
script.async_run(variables or self._variables, self._context)
)
)
class _QueuedScriptRun(_ScriptRun):
"""Manage queued Script sequence run."""
@ -562,27 +575,15 @@ class Script:
self.last_triggered: Optional[datetime] = None
self.can_cancel = True
self._repeat_script = {}
for step, action in enumerate(sequence):
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
sub_script = Script(
hass,
action[CONF_REPEAT][CONF_SEQUENCE],
f"{name}: {step_name}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(
self._chain_change_listener, sub_script
)
self._repeat_script[step] = sub_script
self._runs: List[_ScriptRun] = []
self._max_runs = max_runs
if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._repeat_script: Dict[int, Script] = {}
self._choose_data: Dict[
int, List[Tuple[List[Callable[[HomeAssistant, Dict], bool]], Script]]
] = {}
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
@ -701,6 +702,78 @@ class Script:
if self.is_running:
await asyncio.shield(self._async_stop(update_state))
async def _async_get_condition(self, config):
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
def _prep_repeat_script(self, step):
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
sub_script = Script(
self._hass,
action[CONF_REPEAT][CONF_SEQUENCE],
f"{self.name}: {step_name}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
return sub_script
def _get_repeat_script(self, step):
sub_script = self._repeat_script.get(step)
if not sub_script:
sub_script = self._prep_repeat_script(step)
self._repeat_script[step] = sub_script
return sub_script
async def _async_prep_choose_data(self, step):
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
choices = []
for idx, choice in enumerate(action[CONF_CHOOSE], start=1):
conditions = [
await self._async_get_condition(config)
for config in choice.get(CONF_CONDITIONS, [])
]
sub_script = Script(
self._hass,
choice[CONF_SEQUENCE],
f"{self.name}: {step_name}: choice {idx}",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
sub_script.change_listener = partial(
self._chain_change_listener, sub_script
)
choices.append((conditions, sub_script))
if CONF_DEFAULT in action:
default_script = Script(
self._hass,
action[CONF_DEFAULT],
f"{self.name}: {step_name}: default",
script_mode=SCRIPT_MODE_PARALLEL,
logger=self._logger,
)
default_script.change_listener = partial(
self._chain_change_listener, default_script
)
else:
default_script = None
return {"choices": choices, "default": default_script}
async def _async_get_choose_data(self, step):
choose_data = self._choose_data.get(step)
if not choose_data:
choose_data = await self._async_prep_choose_data(step)
self._choose_data[step] = choose_data
return choose_data
def _log(self, msg, *args, level=logging.INFO):
if self.name:
msg = f"%s: {msg}"

View file

@ -877,6 +877,41 @@ async def test_repeat_conditional(hass, condition):
assert event.data.get("index") == str(index + 1)
@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")])
async def test_choose(hass, var, result):
"""Test choose action."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
{
"choose": [
{
"conditions": {
"condition": "template",
"value_template": "{{ var == 1 }}",
},
"sequence": {"event": event, "event_data": {"choice": "first"}},
},
{
"conditions": {
"condition": "template",
"value_template": "{{ var == 2 }}",
},
"sequence": {"event": event, "event_data": {"choice": "second"}},
},
],
"default": {"event": event, "event_data": {"choice": "default"}},
}
)
script_obj = script.Script(hass, sequence)
await script_obj.async_run({"var": var})
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].data["choice"] == result
async def test_last_triggered(hass):
"""Test the last_triggered."""
event = "test_event"