From 81d90b1bc7ea52ba4bef45a7cf12d3adb7a7f141 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Mon, 11 Apr 2022 23:22:22 +0200 Subject: [PATCH] Add stop/error script/automation action (#67340) --- homeassistant/const.py | 2 + homeassistant/helpers/config_validation.py | 26 +++++++ homeassistant/helpers/script.py | 40 +++++++++-- tests/helpers/test_script.py | 79 ++++++++++++++++++++-- 4 files changed, 136 insertions(+), 11 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index 0ab11a2971f..7a58ce111f8 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -146,6 +146,7 @@ CONF_ENTITY_CATEGORY: Final = "entity_category" CONF_ENTITY_ID: Final = "entity_id" CONF_ENTITY_NAMESPACE: Final = "entity_namespace" CONF_ENTITY_PICTURE_TEMPLATE: Final = "entity_picture_template" +CONF_ERROR: Final = "error" CONF_EVENT: Final = "event" CONF_EVENT_DATA: Final = "event_data" CONF_EVENT_DATA_TEMPLATE: Final = "event_data_template" @@ -226,6 +227,7 @@ CONF_SOURCE: Final = "source" CONF_SSL: Final = "ssl" CONF_STATE: Final = "state" CONF_STATE_TEMPLATE: Final = "state_template" +CONF_STOP: Final = "stop" CONF_STRUCTURE: Final = "structure" CONF_SWITCHES: Final = "switches" CONF_TARGET: Final = "target" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index c7c7e9aae4a..899d038cbe3 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -44,6 +44,7 @@ from homeassistant.const import ( CONF_DOMAIN, CONF_ENTITY_ID, CONF_ENTITY_NAMESPACE, + CONF_ERROR, CONF_EVENT, CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, @@ -58,6 +59,7 @@ from homeassistant.const import ( CONF_SERVICE, CONF_SERVICE_TEMPLATE, CONF_STATE, + CONF_STOP, CONF_TARGET, CONF_TIMEOUT, CONF_UNIT_SYSTEM_IMPERIAL, @@ -1425,6 +1427,20 @@ _SCRIPT_SET_SCHEMA = vol.Schema( } ) +_SCRIPT_STOP_SCHEMA = vol.Schema( + { + **SCRIPT_ACTION_BASE_SCHEMA, + vol.Required(CONF_STOP): vol.Any(None, string), + } +) + +_SCRIPT_ERROR_SCHEMA = vol.Schema( + { + **SCRIPT_ACTION_BASE_SCHEMA, + vol.Optional(CONF_ERROR): vol.Any(None, string), + } +) + SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_CHECK_CONDITION = "condition" @@ -1436,6 +1452,8 @@ SCRIPT_ACTION_REPEAT = "repeat" SCRIPT_ACTION_CHOOSE = "choose" SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger" SCRIPT_ACTION_VARIABLES = "variables" +SCRIPT_ACTION_STOP = "stop" +SCRIPT_ACTION_ERROR = "error" def determine_script_action(action: dict[str, Any]) -> str: @@ -1473,6 +1491,12 @@ def determine_script_action(action: dict[str, Any]) -> str: if CONF_SERVICE in action or CONF_SERVICE_TEMPLATE in action: return SCRIPT_ACTION_CALL_SERVICE + if CONF_STOP in action: + return SCRIPT_ACTION_STOP + + if CONF_ERROR in action: + return SCRIPT_ACTION_ERROR + raise ValueError("Unable to determine action") @@ -1488,6 +1512,8 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA, SCRIPT_ACTION_WAIT_FOR_TRIGGER: _SCRIPT_WAIT_FOR_TRIGGER_SCHEMA, SCRIPT_ACTION_VARIABLES: _SCRIPT_SET_SCHEMA, + SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA, + SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1ede1d10d89..f6109e9d6f8 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -33,6 +33,7 @@ from homeassistant.const import ( CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, + CONF_ERROR, CONF_EVENT, CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, @@ -41,6 +42,7 @@ from homeassistant.const import ( CONF_SCENE, CONF_SEQUENCE, CONF_SERVICE, + CONF_STOP, CONF_TARGET, CONF_TIMEOUT, CONF_UNTIL, @@ -191,9 +193,11 @@ async def trace_action(hass, script_run, stop, variables): try: yield trace_element - except _StopScript as ex: + except _AbortScript as ex: trace_element.set_error(ex.__cause__ or ex) raise ex + except _StopScript as ex: + raise ex except Exception as ex: trace_element.set_error(ex) raise ex @@ -227,6 +231,8 @@ STATIC_VALIDATION_ACTION_TYPES = ( cv.SCRIPT_ACTION_FIRE_EVENT, cv.SCRIPT_ACTION_ACTIVATE_SCENE, cv.SCRIPT_ACTION_VARIABLES, + cv.SCRIPT_ACTION_ERROR, + cv.SCRIPT_ACTION_STOP, ) @@ -295,6 +301,10 @@ async def async_validate_action_config( return config +class _AbortScript(Exception): + """Throw if script needs to abort because of an unexpected error.""" + + class _StopScript(Exception): """Throw if script needs to stop.""" @@ -360,6 +370,8 @@ class _ScriptRun: else: script_execution_set("finished") except _StopScript: + script_execution_set("finished") + except _AbortScript: script_execution_set("aborted") except Exception: script_execution_set("error") @@ -378,7 +390,7 @@ class _ScriptRun: handler = f"_async_{cv.determine_script_action(self._action)}_step" await getattr(self, handler)() except Exception as ex: - if not isinstance(ex, _StopScript) and ( + if not isinstance(ex, (_AbortScript, _StopScript)) and ( self._log_exceptions or log_exceptions ): self._log_exception(ex) @@ -443,7 +455,7 @@ class _ScriptRun: ex, level=logging.ERROR, ) - raise _StopScript from ex + raise _AbortScript from ex async def _async_delay_step(self): """Handle delay.""" @@ -509,7 +521,7 @@ class _ScriptRun: if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) trace_set_result(wait=self._variables["wait"], timeout=True) - raise _StopScript from ex + raise _AbortScript from ex finally: for task in tasks: task.cancel() @@ -643,7 +655,7 @@ class _ScriptRun: self._log("Test condition %s: %s", self._script.last_action, check) trace_update_result(result=check) if not check: - raise _StopScript + raise _AbortScript def _test_conditions(self, conditions, name, condition_path=None): if condition_path is None: @@ -700,7 +712,7 @@ class _ScriptRun: ex, level=logging.ERROR, ) - raise _StopScript from ex + raise _AbortScript from ex extra_msg = f" of {count}" for iteration in range(1, count + 1): set_repeat_var(iteration, count) @@ -820,7 +832,7 @@ class _ScriptRun: if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) trace_set_result(wait=self._variables["wait"], timeout=True) - raise _StopScript from ex + raise _AbortScript from ex finally: for task in tasks: task.cancel() @@ -833,6 +845,20 @@ class _ScriptRun: self._hass, self._variables, render_as_defaults=False ) + async def _async_stop_step(self): + """Stop script execution.""" + stop = self._action[CONF_STOP] + self._log("Stop script sequence: %s", stop) + trace_set_result(stop=stop) + raise _StopScript(stop) + + async def _async_error_step(self): + """Abort and error script execution.""" + error = self._action[CONF_ERROR] + self._log("Error script sequence: %s", error) + trace_set_result(error=error) + raise _AbortScript(error) + async def _async_run_script(self, script: Script) -> None: """Execute a script.""" await self._async_run_long_action( diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index dc1a498e465..d5b7f8048f4 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1438,7 +1438,7 @@ async def test_condition_warning(hass, caplog): assert_action_trace( { "0": [{"result": {"event": "test_event", "event_data": {}}}], - "1": [{"error_type": script._StopScript, "result": {"result": False}}], + "1": [{"error_type": script._AbortScript, "result": {"result": False}}], "1/entity_id/0": [{"error_type": ConditionError}], }, expected_script_execution="aborted", @@ -1492,7 +1492,7 @@ async def test_condition_basic(hass, caplog): "0": [{"result": {"event": "test_event", "event_data": {}}}], "1": [ { - "error_type": script._StopScript, + "error_type": script._AbortScript, "result": {"entities": ["test.entity"], "result": False}, } ], @@ -1547,7 +1547,7 @@ async def test_shorthand_template_condition(hass, caplog): "0": [{"result": {"event": "test_event", "event_data": {}}}], "1": [ { - "error_type": script._StopScript, + "error_type": script._AbortScript, "result": {"entities": ["test.entity"], "result": False}, } ], @@ -1613,7 +1613,7 @@ async def test_condition_validation(hass, caplog): "0": [{"result": {"event": "test_event", "event_data": {}}}], "1": [ { - "error_type": script._StopScript, + "error_type": script._AbortScript, "result": {"result": False}, } ], @@ -3508,6 +3508,8 @@ async def test_validate_action_config(hass): ] }, cv.SCRIPT_ACTION_VARIABLES: {"variables": {"hello": "world"}}, + cv.SCRIPT_ACTION_STOP: {"stop": "Stop it right there buddy..."}, + cv.SCRIPT_ACTION_ERROR: {"error": "Stand up, and try again!"}, } expected_templates = { cv.SCRIPT_ACTION_CHECK_CONDITION: None, @@ -3778,3 +3780,72 @@ async def test_platform_async_validate_action_config(hass): platform.async_validate_action_config.return_value = config await script.async_validate_action_config(hass, config) platform.async_validate_action_config.assert_awaited() + + +async def test_stop_action(hass, caplog): + """Test if automation stops on calling the stop action.""" + event = "test_event" + events = async_capture_events(hass, event) + + alias = "stop step" + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + { + "alias": alias, + "stop": "In the name of love", + }, + {"event": event}, + ] + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert "Stop script sequence: In the name of love" in caplog.text + caplog.clear() + assert len(events) == 1 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{"result": {"stop": "In the name of love"}}], + } + ) + + +async def test_error_action(hass, caplog): + """Test if automation fails on calling the error action.""" + event = "test_event" + events = async_capture_events(hass, event) + + alias = "stop step" + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": event}, + { + "alias": alias, + "error": "Epic one...", + }, + {"event": event}, + ] + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert "Test Name: Error script sequence: Epic one..." in caplog.text + caplog.clear() + assert len(events) == 1 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [ + {"error_type": script._AbortScript, "result": {"error": "Epic one..."}} + ], + }, + expected_script_execution="aborted", + )