Add parallel automation/script actions (#69903)

This commit is contained in:
Franck Nijhof 2022-04-13 22:07:44 +02:00 committed by GitHub
parent 3df6d26712
commit d704d4f853
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 282 additions and 0 deletions

View file

@ -189,6 +189,7 @@ CONF_NAME: Final = "name"
CONF_OFFSET: Final = "offset" CONF_OFFSET: Final = "offset"
CONF_OPTIMISTIC: Final = "optimistic" CONF_OPTIMISTIC: Final = "optimistic"
CONF_PACKAGES: Final = "packages" CONF_PACKAGES: Final = "packages"
CONF_PARALLEL: Final = "parallel"
CONF_PARAMS: Final = "params" CONF_PARAMS: Final = "params"
CONF_PASSWORD: Final = "password" CONF_PASSWORD: Final = "password"
CONF_PATH: Final = "path" CONF_PATH: Final = "path"

View file

@ -53,6 +53,7 @@ from homeassistant.const import (
CONF_ID, CONF_ID,
CONF_IF, CONF_IF,
CONF_MATCH, CONF_MATCH,
CONF_PARALLEL,
CONF_PLATFORM, CONF_PLATFORM,
CONF_REPEAT, CONF_REPEAT,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
@ -1455,6 +1456,32 @@ _SCRIPT_ERROR_SCHEMA = vol.Schema(
} }
) )
_SCRIPT_PARALLEL_SEQUENCE = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA,
}
)
_parallel_sequence_action = vol.All(
# Wrap a shorthand sequences in a parallel action
SCRIPT_SCHEMA,
lambda config: {
CONF_SEQUENCE: config,
},
)
_SCRIPT_PARALLEL_SCHEMA = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_PARALLEL): vol.All(
ensure_list, [vol.Any(_SCRIPT_PARALLEL_SEQUENCE, _parallel_sequence_action)]
),
}
)
SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition" SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1469,6 +1496,7 @@ SCRIPT_ACTION_VARIABLES = "variables"
SCRIPT_ACTION_STOP = "stop" SCRIPT_ACTION_STOP = "stop"
SCRIPT_ACTION_ERROR = "error" SCRIPT_ACTION_ERROR = "error"
SCRIPT_ACTION_IF = "if" SCRIPT_ACTION_IF = "if"
SCRIPT_ACTION_PARALLEL = "parallel"
def determine_script_action(action: dict[str, Any]) -> str: def determine_script_action(action: dict[str, Any]) -> str:
@ -1515,6 +1543,9 @@ def determine_script_action(action: dict[str, Any]) -> str:
if CONF_ERROR in action: if CONF_ERROR in action:
return SCRIPT_ACTION_ERROR return SCRIPT_ACTION_ERROR
if CONF_PARALLEL in action:
return SCRIPT_ACTION_PARALLEL
raise ValueError("Unable to determine action") raise ValueError("Unable to determine action")
@ -1533,6 +1564,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA, SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA,
SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA, SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA,
SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA, SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA,
SCRIPT_ACTION_PARALLEL: _SCRIPT_PARALLEL_SCHEMA,
} }

View file

@ -5,6 +5,7 @@ import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar from contextvars import ContextVar
from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import itertools import itertools
@ -40,6 +41,7 @@ from homeassistant.const import (
CONF_EVENT_DATA_TEMPLATE, CONF_EVENT_DATA_TEMPLATE,
CONF_IF, CONF_IF,
CONF_MODE, CONF_MODE,
CONF_PARALLEL,
CONF_REPEAT, CONF_REPEAT,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
@ -79,6 +81,7 @@ from .trace import (
trace_id_get, trace_id_get,
trace_path, trace_path,
trace_path_get, trace_path_get,
trace_path_stack_cv,
trace_set_result, trace_set_result,
trace_stack_cv, trace_stack_cv,
trace_stack_pop, trace_stack_pop,
@ -307,6 +310,13 @@ async def async_validate_action_config(
config[CONF_ELSE] = await async_validate_actions_config( config[CONF_ELSE] = await async_validate_actions_config(
hass, config[CONF_ELSE] hass, config[CONF_ELSE]
) )
elif action_type == cv.SCRIPT_ACTION_PARALLEL:
for parallel_conf in config[CONF_PARALLEL]:
parallel_conf[CONF_SEQUENCE] = await async_validate_actions_config(
hass, parallel_conf[CONF_SEQUENCE]
)
else: else:
raise ValueError(f"No validation for {action_type}") raise ValueError(f"No validation for {action_type}")
@ -896,6 +906,26 @@ class _ScriptRun:
trace_set_result(error=error) trace_set_result(error=error)
raise _AbortScript(error) raise _AbortScript(error)
@async_trace_path("parallel")
async def _async_parallel_step(self) -> None:
"""Run a sequence in parallel."""
# pylint: disable=protected-access
scripts = await self._script._async_get_parallel_scripts(self._step)
async def async_run_with_trace(idx: int, script: Script) -> None:
"""Run a script with a trace path."""
trace_path_stack_cv.set(copy(trace_path_stack_cv.get()))
with trace_path([str(idx), "sequence"]):
await self._async_run_script(script)
results = await asyncio.gather(
*(async_run_with_trace(idx, script) for idx, script in enumerate(scripts)),
return_exceptions=True,
)
for result in results:
if isinstance(result, Exception):
raise result
async def _async_run_script(self, script: Script) -> None: async def _async_run_script(self, script: Script) -> None:
"""Execute a script.""" """Execute a script."""
await self._async_run_long_action( await self._async_run_long_action(
@ -1075,6 +1105,7 @@ class Script:
self._repeat_script: dict[int, Script] = {} self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {} self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {} self._if_data: dict[int, _IfData] = {}
self._parallel_scripts: dict[int, list[Script]] = {}
self._referenced_entities: set[str] | None = None self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None self._referenced_areas: set[str] | None = None
@ -1109,6 +1140,9 @@ class Script:
self._set_logger(logger) self._set_logger(logger)
for script in self._repeat_script.values(): for script in self._repeat_script.values():
script.update_logger(self._logger) script.update_logger(self._logger)
for parallel_scripts in self._parallel_scripts.values():
for parallel_script in parallel_scripts:
parallel_script.update_logger(self._logger)
for choose_data in self._choose_data.values(): for choose_data in self._choose_data.values():
for _, script in choose_data["choices"]: for _, script in choose_data["choices"]:
script.update_logger(self._logger) script.update_logger(self._logger)
@ -1178,6 +1212,10 @@ class Script:
if CONF_ELSE in step: if CONF_ELSE in step:
Script._find_referenced_areas(referenced, step[CONF_ELSE]) Script._find_referenced_areas(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
@property @property
def referenced_devices(self): def referenced_devices(self):
"""Return a set of referenced devices.""" """Return a set of referenced devices."""
@ -1222,6 +1260,10 @@ class Script:
if CONF_ELSE in step: if CONF_ELSE in step:
Script._find_referenced_devices(referenced, step[CONF_ELSE]) Script._find_referenced_devices(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
@property @property
def referenced_entities(self): def referenced_entities(self):
"""Return a set of referenced entities.""" """Return a set of referenced entities."""
@ -1267,6 +1309,10 @@ class Script:
if CONF_ELSE in step: if CONF_ELSE in step:
Script._find_referenced_entities(referenced, step[CONF_ELSE]) Script._find_referenced_entities(referenced, step[CONF_ELSE])
elif action == cv.SCRIPT_ACTION_PARALLEL:
for script in step[CONF_PARALLEL]:
Script._find_referenced_entities(referenced, script[CONF_SEQUENCE])
def run( def run(
self, variables: _VarsType | None = None, context: Context | None = None self, variables: _VarsType | None = None, context: Context | None = None
) -> None: ) -> None:
@ -1530,6 +1576,36 @@ class Script:
self._if_data[step] = if_data self._if_data[step] = if_data
return if_data return if_data
async def _async_prep_parallel_scripts(self, step: int) -> list[Script]:
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Parallel action at step {step+1}")
parallel_scripts: list[Script] = []
for idx, parallel_script in enumerate(action[CONF_PARALLEL], start=1):
parallel_name = parallel_script.get(CONF_ALIAS, f"parallel {idx}")
parallel_script = Script(
self._hass,
parallel_script[CONF_SEQUENCE],
f"{self.name}: {step_name}: {parallel_name}",
self.domain,
running_description=self.running_description,
script_mode=SCRIPT_MODE_PARALLEL,
max_runs=self.max_runs,
logger=self._logger,
top_level=False,
)
parallel_script.change_listener = partial(
self._chain_change_listener, parallel_script
)
parallel_scripts.append(parallel_script)
return parallel_scripts
async def _async_get_parallel_scripts(self, step: int) -> list[Script]:
if not (parallel_scripts := self._parallel_scripts.get(step)):
parallel_scripts = await self._async_prep_parallel_scripts(step)
self._parallel_scripts[step] = parallel_scripts
return parallel_scripts
def _log( def _log(
self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any
) -> None: ) -> None:

View file

@ -2682,6 +2682,148 @@ async def test_if_condition_validation(
) )
async def test_parallel(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
"""Test parallel action."""
events = async_capture_events(hass, "test_event")
hass.states.async_set("switch.trigger", "off")
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{
"alias": "Sequential group",
"sequence": [
{
"alias": "Waiting for trigger",
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.trigger",
"to": "on",
},
},
{
"event": "test_event",
"event_data": {
"hello": "from action 1",
"what": "{{ what }}",
},
},
],
},
{
"alias": "Don't wait at all",
"event": "test_event",
"event_data": {"hello": "from action 2", "what": "{{ what }}"},
},
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "Waiting for trigger")
hass.async_create_task(
script_obj.async_run(MappingProxyType({"what": "world"}), Context())
)
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
hass.states.async_set("switch.trigger", "on")
await hass.async_block_till_done()
assert len(events) == 2
assert events[0].data["hello"] == "from action 2"
assert events[0].data["what"] == "world"
assert events[1].data["hello"] == "from action 1"
assert events[1].data["what"] == "world"
assert (
"Test Name: Parallel action at step 1: Sequential group: Executing step Waiting for trigger"
in caplog.text
)
assert (
"Parallel action at step 1: parallel 2: Executing step Don't wait at all"
in caplog.text
)
expected_trace = {
"0": [{"result": {}}],
"0/parallel/0/sequence/0": [
{
"result": {
"wait": {
"remaining": None,
"trigger": {
"entity_id": "switch.trigger",
"description": "state of switch.trigger",
},
}
}
}
],
"0/parallel/1/sequence/0": [
{
"variables": {"wait": {"remaining": None}},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 2", "what": "world"},
},
}
],
"0/parallel/0/sequence/1": [
{
"variables": {"wait": {"remaining": None}},
"result": {
"event": "test_event",
"event_data": {"hello": "from action 1", "what": "world"},
},
}
],
}
assert_action_trace(expected_trace)
async def test_parallel_error(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test parallel action failure handling."""
events = async_capture_events(hass, "test_event")
sequence = cv.SCRIPT_SCHEMA(
{
"parallel": [
{"service": "epic.failure"},
]
}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
with pytest.raises(exceptions.ServiceNotFound):
await script_obj.async_run(context=Context())
assert len(events) == 0
expected_trace = {
"0": [{"error_type": ServiceNotFound, "result": {}}],
"0/parallel/0/sequence/0": [
{
"error_type": ServiceNotFound,
"result": {
"limit": 10,
"params": {
"domain": "epic",
"service": "failure",
"service_data": {},
"target": {},
},
"running_script": False,
},
}
],
}
assert_action_trace(expected_trace, expected_script_execution="error")
async def test_last_triggered(hass): async def test_last_triggered(hass):
"""Test the last_triggered.""" """Test the last_triggered."""
event = "test_event" event = "test_event"
@ -2881,6 +3023,14 @@ async def test_referenced_areas(hass):
} }
], ],
}, },
{
"parallel": [
{
"service": "test.script",
"data": {"area_id": "area_parallel"},
}
],
},
] ]
), ),
"Test Name", "Test Name",
@ -2896,6 +3046,7 @@ async def test_referenced_areas(hass):
"area_service_not_list", "area_service_not_list",
"area_if_then", "area_if_then",
"area_if_else", "area_if_else",
"area_parallel",
# 'area_service_template', # no area extraction from template # 'area_service_template', # no area extraction from template
} }
# Test we cache results. # Test we cache results.
@ -2988,6 +3139,14 @@ async def test_referenced_entities(hass):
} }
], ],
}, },
{
"parallel": [
{
"service": "test.script",
"data": {"entity_id": "light.parallel"},
}
],
},
] ]
), ),
"Test Name", "Test Name",
@ -3006,6 +3165,7 @@ async def test_referenced_entities(hass):
"light.service_not_list", "light.service_not_list",
"light.if_then", "light.if_then",
"light.if_else", "light.if_else",
"light.parallel",
# "light.service_template", # no entity extraction from template # "light.service_template", # no entity extraction from template
"scene.hello", "scene.hello",
"sensor.condition", "sensor.condition",
@ -3093,6 +3253,14 @@ async def test_referenced_devices(hass):
} }
], ],
}, },
{
"parallel": [
{
"service": "test.script",
"target": {"device_id": "parallel-device"},
}
],
},
] ]
), ),
"Test Name", "Test Name",
@ -3113,6 +3281,7 @@ async def test_referenced_devices(hass):
"target-string-id", "target-string-id",
"if-then", "if-then",
"if-else", "if-else",
"parallel-device",
} }
# Test we cache results. # Test we cache results.
assert script_obj.referenced_devices is script_obj.referenced_devices assert script_obj.referenced_devices is script_obj.referenced_devices
@ -3744,6 +3913,9 @@ async def test_validate_action_config(hass):
"then": [templated_device_action("if_then_event")], "then": [templated_device_action("if_then_event")],
"else": [templated_device_action("if_else_event")], "else": [templated_device_action("if_else_event")],
}, },
cv.SCRIPT_ACTION_PARALLEL: {
"parallel": [templated_device_action("parallel_event")],
},
} }
expected_templates = { expected_templates = {
cv.SCRIPT_ACTION_CHECK_CONDITION: None, cv.SCRIPT_ACTION_CHECK_CONDITION: None,
@ -3752,6 +3924,7 @@ async def test_validate_action_config(hass):
cv.SCRIPT_ACTION_CHOOSE: [["choose", 0, "sequence", 0], ["default", 0]], cv.SCRIPT_ACTION_CHOOSE: [["choose", 0, "sequence", 0], ["default", 0]],
cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: None, cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: None,
cv.SCRIPT_ACTION_IF: None, cv.SCRIPT_ACTION_IF: None,
cv.SCRIPT_ACTION_PARALLEL: None,
} }
for key in cv.ACTION_TYPE_SCHEMAS: for key in cv.ACTION_TYPE_SCHEMAS: