Add if/else automation/script action (#69811)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Franck Nijhof 2022-04-12 15:02:17 +02:00 committed by GitHub
parent 5bb3d6487b
commit 67b200a532
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 375 additions and 1 deletions

View file

@ -33,10 +33,12 @@ from homeassistant.const import (
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ELSE,
CONF_ERROR,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_IF,
CONF_MODE,
CONF_REPEAT,
CONF_SCENE,
@ -44,6 +46,7 @@ from homeassistant.const import (
CONF_SERVICE,
CONF_STOP,
CONF_TARGET,
CONF_THEN,
CONF_TIMEOUT,
CONF_UNTIL,
CONF_VARIABLES,
@ -295,6 +298,15 @@ async def async_validate_action_config(
hass, choose_conf[CONF_SEQUENCE]
)
elif action_type == cv.SCRIPT_ACTION_IF:
config[CONF_IF] = await condition.async_validate_conditions_config(
hass, config[CONF_IF]
)
config[CONF_THEN] = await async_validate_actions_config(hass, config[CONF_THEN])
if CONF_ELSE in config:
config[CONF_ELSE] = await async_validate_actions_config(
hass, config[CONF_ELSE]
)
else:
raise ValueError(f"No validation for {action_type}")
@ -780,6 +792,31 @@ class _ScriptRun:
with trace_path(["default"]):
await self._async_run_script(choose_data["default"])
async def _async_if_step(self) -> None:
"""If sequence."""
# pylint: disable=protected-access
if_data = await self._script._async_get_if_data(self._step)
test_conditions = False
try:
with trace_path("if"):
test_conditions = self._test_conditions(
if_data["if_conditions"], "if", "condition"
)
except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'if' evaluation:\n%s", ex)
if test_conditions:
trace_set_result(choice="then")
with trace_path("then"):
await self._async_run_script(if_data["if_then"])
return
if if_data["if_else"] is not None:
trace_set_result(choice="else")
with trace_path("else"):
await self._async_run_script(if_data["if_else"])
async def _async_wait_for_trigger_step(self):
"""Wait for a trigger event."""
if CONF_TIMEOUT in self._action:
@ -970,6 +1007,12 @@ class _ChooseData(TypedDict):
default: Script | None
class _IfData(TypedDict):
if_conditions: list[ConditionCheckerType]
if_then: Script
if_else: Script | None
class Script:
"""Representation of a script."""
@ -1031,6 +1074,7 @@ class Script:
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {}
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None
@ -1070,6 +1114,10 @@ class Script:
script.update_logger(self._logger)
if choose_data["default"] is not None:
choose_data["default"].update_logger(self._logger)
for if_data in self._if_data.values():
if_data["if_then"].update_logger(self._logger)
if if_data["if_else"] is not None:
if_data["if_else"].update_logger(self._logger)
def _changed(self) -> None:
if self._change_listener_job:
@ -1125,6 +1173,11 @@ class Script:
if CONF_DEFAULT in step:
Script._find_referenced_areas(referenced, step[CONF_DEFAULT])
elif action == cv.SCRIPT_ACTION_IF:
Script._find_referenced_areas(referenced, step[CONF_THEN])
if CONF_ELSE in step:
Script._find_referenced_areas(referenced, step[CONF_ELSE])
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
@ -1162,6 +1215,13 @@ class Script:
if CONF_DEFAULT in step:
Script._find_referenced_devices(referenced, step[CONF_DEFAULT])
elif action == cv.SCRIPT_ACTION_IF:
for cond in step[CONF_IF]:
referenced |= condition.async_extract_devices(cond)
Script._find_referenced_devices(referenced, step[CONF_THEN])
if CONF_ELSE in step:
Script._find_referenced_devices(referenced, step[CONF_ELSE])
@property
def referenced_entities(self):
"""Return a set of referenced entities."""
@ -1200,6 +1260,13 @@ class Script:
if CONF_DEFAULT in step:
Script._find_referenced_entities(referenced, step[CONF_DEFAULT])
elif action == cv.SCRIPT_ACTION_IF:
for cond in step[CONF_IF]:
referenced |= condition.async_extract_entities(cond)
Script._find_referenced_entities(referenced, step[CONF_THEN])
if CONF_ELSE in step:
Script._find_referenced_entities(referenced, step[CONF_ELSE])
def run(
self, variables: _VarsType | None = None, context: Context | None = None
) -> None:
@ -1411,6 +1478,58 @@ class Script:
self._choose_data[step] = choose_data
return choose_data
async def _async_prep_if_data(self, step: int) -> _IfData:
"""Prepare data for an if statement."""
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"If at step {step+1}")
conditions = [
await self._async_get_condition(config) for config in action[CONF_IF]
]
then_script = Script(
self._hass,
action[CONF_THEN],
f"{self.name}: {step_name}",
self.domain,
running_description=self.running_description,
script_mode=SCRIPT_MODE_PARALLEL,
max_runs=self.max_runs,
logger=self._logger,
top_level=False,
)
then_script.change_listener = partial(self._chain_change_listener, then_script)
if CONF_ELSE in action:
else_script = Script(
self._hass,
action[CONF_ELSE],
f"{self.name}: {step_name}",
self.domain,
running_description=self.running_description,
script_mode=SCRIPT_MODE_PARALLEL,
max_runs=self.max_runs,
logger=self._logger,
top_level=False,
)
else_script.change_listener = partial(
self._chain_change_listener, else_script
)
else:
else_script = None
return _IfData(
if_conditions=conditions,
if_then=then_script,
if_else=else_script,
)
async def _async_get_if_data(self, step: int) -> _IfData:
if not (if_data := self._if_data.get(step)):
if_data = await self._async_prep_if_data(step)
self._if_data[step] = if_data
return if_data
def _log(
self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any
) -> None: