Add parallel automation/script actions (#69903)

This commit is contained in:
Franck Nijhof 2022-04-13 22:07:44 +02:00 committed by GitHub
parent 3df6d26712
commit d704d4f853
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 282 additions and 0 deletions

View file

@ -189,6 +189,7 @@ CONF_NAME: Final = "name"
CONF_OFFSET: Final = "offset"
CONF_OPTIMISTIC: Final = "optimistic"
CONF_PACKAGES: Final = "packages"
CONF_PARALLEL: Final = "parallel"
CONF_PARAMS: Final = "params"
CONF_PASSWORD: Final = "password"
CONF_PATH: Final = "path"

View file

@ -53,6 +53,7 @@ from homeassistant.const import (
CONF_ID,
CONF_IF,
CONF_MATCH,
CONF_PARALLEL,
CONF_PLATFORM,
CONF_REPEAT,
CONF_SCAN_INTERVAL,
@ -1455,6 +1456,32 @@ _SCRIPT_ERROR_SCHEMA = vol.Schema(
}
)
_SCRIPT_PARALLEL_SEQUENCE = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
}
)
_parallel_sequence_action = vol.All(
# Wrap a shorthand sequences in a parallel action
SCRIPT_SCHEMA,
lambda config: {
CONF_SEQUENCE: config,
},
)
_SCRIPT_PARALLEL_SCHEMA = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_PARALLEL): vol.All(
ensure_list, [vol.Any(_SCRIPT_PARALLEL_SEQUENCE, _parallel_sequence_action)]
),
}
)
SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1469,6 +1496,7 @@ SCRIPT_ACTION_VARIABLES = "variables"
SCRIPT_ACTION_STOP = "stop"
SCRIPT_ACTION_ERROR = "error"
SCRIPT_ACTION_IF = "if"
SCRIPT_ACTION_PARALLEL = "parallel"
def determine_script_action(action: dict[str, Any]) -> str:
@ -1515,6 +1543,9 @@ def determine_script_action(action: dict[str, Any]) -> str:
if CONF_ERROR in action:
return SCRIPT_ACTION_ERROR
if CONF_PARALLEL in action:
return SCRIPT_ACTION_PARALLEL
raise ValueError("Unable to determine action")
@ -1533,6 +1564,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA,
SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA,
SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA,
SCRIPT_ACTION_PARALLEL: _SCRIPT_PARALLEL_SCHEMA,
}

View file

@ -5,6 +5,7 @@ import asyncio
from collections.abc import Callable, Sequence
from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar
from copy import copy
from datetime import datetime, timedelta
from functools import partial
import itertools
@ -40,6 +41,7 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE,
CONF_IF,
CONF_MODE,
CONF_PARALLEL,
CONF_REPEAT,
CONF_SCENE,
CONF_SEQUENCE,
@ -79,6 +81,7 @@ from .trace import (
trace_id_get,
trace_path,
trace_path_get,
trace_path_stack_cv,
trace_set_result,
trace_stack_cv,
trace_stack_pop,
@ -307,6 +310,13 @@ async def async_validate_action_config(
config[CONF_ELSE] = await async_validate_actions_config(
hass, config[CONF_ELSE]
)
elif action_type == cv.SCRIPT_ACTION_PARALLEL:
for parallel_conf in config[CONF_PARALLEL]:
parallel_conf[CONF_SEQUENCE] = await async_validate_actions_config(
hass, parallel_conf[CONF_SEQUENCE]
)
else:
raise ValueError(f"No validation for {action_type}")
@ -896,6 +906,26 @@ class _ScriptRun:
trace_set_result(error=error)
raise _AbortScript(error)
@async_trace_path("parallel")
async def _async_parallel_step(self) -> None:
"""Run a sequence in parallel."""
# pylint: disable=protected-access
scripts = await self._script._async_get_parallel_scripts(self._step)
async def async_run_with_trace(idx: int, script: Script) -> None:
"""Run a script with a trace path."""
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
with trace_path([str(idx), "sequence"]):
await self._async_run_script(script)
results = await asyncio.gather(
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
return_exceptions=True,
)
for result in results:
if isinstance(result, Exception):
raise result
async def _async_run_script(self, script: Script) -> None:
"""Execute a script."""
await self._async_run_long_action(
@ -1075,6 +1105,7 @@ class Script:
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {}
self._parallel_scripts: dict[int, list[Script]] = {}
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None
@ -1109,6 +1140,9 @@ class Script:
self._set_logger(logger)
for script in self._repeat_script.values():
script.update_logger(self._logger)
for parallel_scripts in self._parallel_scripts.values():
for parallel_script in parallel_scripts:
parallel_script.update_logger(self._logger)
for choose_data in self._choose_data.values():
for _, script in choose_data["choices"]:
script.update_logger(self._logger)
@ -1178,6 +1212,10 @@ class Script:
if CONF_ELSE in step:
Script._find_referenced_areas(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
@ -1222,6 +1260,10 @@ class Script:
if CONF_ELSE in step:
Script._find_referenced_devices(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
@property
def referenced_entities(self):
"""Return a set of referenced entities."""
@ -1267,6 +1309,10 @@ class Script:
if CONF_ELSE in step:
Script._find_referenced_entities(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_entities(referenced, script[CONF_SEQUENCE])
def run(
self, variables: _VarsType | None = None, context: Context | None = None
) -> None:
@ -1530,6 +1576,36 @@ class Script:
self._if_data[step] = if_data
return if_data
async def _async_prep_parallel_scripts(self, step: int) -> list[Script]:
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Parallel action at step {step+1}")
parallel_scripts: list[Script] = []
for idx, parallel_script in enumerate(action[CONF_PARALLEL], start=1):
parallel_name = parallel_script.get(CONF_ALIAS, f"parallel {idx}")
parallel_script = Script(
self._hass,
parallel_script[CONF_SEQUENCE],
f"{self.name}: {step_name}: {parallel_name}",
self.domain,
running_description=self.running_description,
script_mode=SCRIPT_MODE_PARALLEL,
max_runs=self.max_runs,
logger=self._logger,
top_level=False,
)
parallel_script.change_listener = partial(
self._chain_change_listener, parallel_script
)
parallel_scripts.append(parallel_script)
return parallel_scripts
async def _async_get_parallel_scripts(self, step: int) -> list[Script]:
if not (parallel_scripts := self._parallel_scripts.get(step)):
parallel_scripts = await self._async_prep_parallel_scripts(step)
self._parallel_scripts[step] = parallel_scripts
return parallel_scripts
def _log(
self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any
) -> None:

View file

@ -2682,6 +2682,148 @@ async def test_if_condition_validation(
)
async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
"""Test parallel action."""
events = async_capture_events(hass, "test_event")
hass.states.async_set("switch.trigger", "off")
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{
"alias": "Sequential group",
"sequence": [
{
"alias": "Waiting for trigger",
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.trigger",
"to": "on",
},
},
{
"event": "test_event",
"event_data": {
"hello": "from action 1",
"what": "{{ what }}",
},
},
],
},
{
"alias": "Don't wait at all",
"event": "test_event",
"event_data": {"hello": "from action 2", "what": "{{ what }}"},
},
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "Waiting for trigger")
hass.async_create_task(
script_obj.async_run(MappingProxyType({"what": "world"}), Context())
)
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
hass.states.async_set("switch.trigger", "on")
await hass.async_block_till_done()
assert len(events) == 2
assert events[0].data["hello"] == "from action 2"
assert events[0].data["what"] == "world"
assert events[1].data["hello"] == "from action 1"
assert events[1].data["what"] == "world"
assert (
"Test Name: Parallel action at step 1: Sequential group: Executing step Waiting for trigger"
in caplog.text
)
assert (
"Parallel action at step 1: parallel 2: Executing step Don't wait at all"
in caplog.text
)
expected_trace = {
"0": [{"result": {}}],
"0/parallel/0/sequence/0": [
{
"result": {
"wait": {
"remaining": None,
"trigger": {
"entity_id": "switch.trigger",
"description": "state of switch.trigger",
},
}
}
}
],
"0/parallel/1/sequence/0": [
{
"variables": {"wait": {"remaining": None}},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 2", "what": "world"},
},
}
],
"0/parallel/0/sequence/1": [
{
"variables": {"wait": {"remaining": None}},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 1", "what": "world"},
},
}
],
}
assert_action_trace(expected_trace)
async def test_parallel_error(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test parallel action failure handling."""
events = async_capture_events(hass, "test_event")
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{"service": "epic.failure"},
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
with pytest.raises(exceptions.ServiceNotFound):
await script_obj.async_run(context=Context())
assert len(events) == 0
expected_trace = {
"0": [{"error_type": ServiceNotFound, "result": {}}],
"0/parallel/0/sequence/0": [
{
"error_type": ServiceNotFound,
"result": {
"limit": 10,
"params": {
"domain": "epic",
"service": "failure",
"service_data": {},
"target": {},
},
"running_script": False,
},
}
],
}
assert_action_trace(expected_trace, expected_script_execution="error")
async def test_last_triggered(hass):
"""Test the last_triggered."""
event = "test_event"
@ -2881,6 +3023,14 @@ async def test_referenced_areas(hass):
}
],
},
{
"parallel": [
{
"service": "test.script",
"data": {"area_id": "area_parallel"},
}
],
},
]
),
"Test Name",
@ -2896,6 +3046,7 @@ async def test_referenced_areas(hass):
"area_service_not_list",
"area_if_then",
"area_if_else",
"area_parallel",
# 'area_service_template', # no area extraction from template
}
# Test we cache results.
@ -2988,6 +3139,14 @@ async def test_referenced_entities(hass):
}
],
},
{
"parallel": [
{
"service": "test.script",
"data": {"entity_id": "light.parallel"},
}
],
},
]
),
"Test Name",
@ -3006,6 +3165,7 @@ async def test_referenced_entities(hass):
"light.service_not_list",
"light.if_then",
"light.if_else",
"light.parallel",
# "light.service_template", # no entity extraction from template
"scene.hello",
"sensor.condition",
@ -3093,6 +3253,14 @@ async def test_referenced_devices(hass):
}
],
},
{
"parallel": [
{
"service": "test.script",
"target": {"device_id": "parallel-device"},
}
],
},
]
),
"Test Name",
@ -3113,6 +3281,7 @@ async def test_referenced_devices(hass):
"target-string-id",
"if-then",
"if-else",
"parallel-device",
}
# Test we cache results.
assert script_obj.referenced_devices is script_obj.referenced_devices
@ -3744,6 +3913,9 @@ async def test_validate_action_config(hass):
"then": [templated_device_action("if_then_event")],
"else": [templated_device_action("if_else_event")],
},
cv.SCRIPT_ACTION_PARALLEL: {
"parallel": [templated_device_action("parallel_event")],
},
}
expected_templates = {
cv.SCRIPT_ACTION_CHECK_CONDITION: None,
@ -3752,6 +3924,7 @@ async def test_validate_action_config(hass):
cv.SCRIPT_ACTION_CHOOSE: [["choose", 0, "sequence", 0], ["default", 0]],
cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: None,
cv.SCRIPT_ACTION_IF: None,
cv.SCRIPT_ACTION_PARALLEL: None,
}
for key in cv.ACTION_TYPE_SCHEMAS: