From d028236bf210e668f4151ba2dc9514833e52cda8 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Dec 2018 15:46:25 +0100 Subject: [PATCH] Refactor script helper actions into their own methods (#18962) * Refactor script helper actions into their own methods * Lint * Lint --- homeassistant/helpers/script.py | 234 +++++++++++++++++++------------- tests/helpers/test_script.py | 85 ++++++++++++ 2 files changed, 227 insertions(+), 92 deletions(-) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 80d66f4fac8..088882df608 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -9,7 +9,7 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, Context, callback from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT -from homeassistant.exceptions import TemplateError +from homeassistant import exceptions from homeassistant.helpers import ( service, condition, template as template, config_validation as cv) @@ -34,6 +34,30 @@ CONF_WAIT_TEMPLATE = 'wait_template' CONF_CONTINUE = 'continue_on_timeout' +ACTION_DELAY = 'delay' +ACTION_WAIT_TEMPLATE = 'wait_template' +ACTION_CHECK_CONDITION = 'condition' +ACTION_FIRE_EVENT = 'event' +ACTION_CALL_SERVICE = 'call_service' + + +def _determine_action(action): + """Determine action type.""" + if CONF_DELAY in action: + return ACTION_DELAY + + if CONF_WAIT_TEMPLATE in action: + return ACTION_WAIT_TEMPLATE + + if CONF_CONDITION in action: + return ACTION_CHECK_CONDITION + + if CONF_EVENT in action: + return ACTION_FIRE_EVENT + + return ACTION_CALL_SERVICE + + def call_from_config(hass: HomeAssistant, config: ConfigType, variables: Optional[Sequence] = None, context: Optional[Context] = None) -> None: @@ -41,6 +65,14 @@ def call_from_config(hass: HomeAssistant, config: ConfigType, Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context) +class _StopScript(Exception): + """Throw if script needs to stop.""" + + +class _SuspendScript(Exception): + """Throw if script needs to suspend.""" + + class Script(): """Representation of a script.""" @@ -60,6 +92,13 @@ class Script(): self._async_listener = [] self._template_cache = {} self._config_cache = {} + self._actions = { + ACTION_DELAY: self._async_delay, + ACTION_WAIT_TEMPLATE: self._async_wait_template, + ACTION_CHECK_CONDITION: self._async_check_condition, + ACTION_FIRE_EVENT: self._async_fire_event, + ACTION_CALL_SERVICE: self._async_call_service, + } @property def is_running(self) -> bool: @@ -87,98 +126,27 @@ class Script(): self._async_remove_listener() for cur, action in islice(enumerate(self.sequence), self._cur, None): - - if CONF_DELAY in action: - # Call ourselves in the future to continue work - unsub = None - - @callback - def async_script_delay(now): - """Handle delay.""" - # pylint: disable=cell-var-from-loop - with suppress(ValueError): - self._async_listener.remove(unsub) - - self.hass.async_create_task( - self.async_run(variables, context)) - - delay = action[CONF_DELAY] - - try: - if isinstance(delay, template.Template): - delay = vol.All( - cv.time_period, - cv.positive_timedelta)( - delay.async_render(variables)) - elif isinstance(delay, dict): - delay_data = {} - delay_data.update( - template.render_complex(delay, variables)) - delay = cv.time_period(delay_data) - except (TemplateError, vol.Invalid) as ex: - _LOGGER.error("Error rendering '%s' delay template: %s", - self.name, ex) - break - - self.last_action = action.get( - CONF_ALIAS, 'delay {}'.format(delay)) - self._log("Executing step %s" % self.last_action) - - unsub = async_track_point_in_utc_time( - self.hass, async_script_delay, - date_util.utcnow() + delay - ) - self._async_listener.append(unsub) - + try: + await self._handle_action(action, variables, context) + except _SuspendScript: + # Store next step to take and notify change listeners self._cur = cur + 1 if self._change_listener: self.hass.async_add_job(self._change_listener) return + except _StopScript: + break + except Exception as err: + # Store the step that had an exception + # pylint: disable=protected-access + err._script_step = cur + # Set script to not running + self._cur = -1 + self.last_action = None + # Pass exception on. + raise - if CONF_WAIT_TEMPLATE in action: - # Call ourselves in the future to continue work - wait_template = action[CONF_WAIT_TEMPLATE] - wait_template.hass = self.hass - - self.last_action = action.get(CONF_ALIAS, 'wait template') - self._log("Executing step %s" % self.last_action) - - # check if condition already okay - if condition.async_template( - self.hass, wait_template, variables): - continue - - @callback - def async_script_wait(entity_id, from_s, to_s): - """Handle script after template condition is true.""" - self._async_remove_listener() - self.hass.async_create_task( - self.async_run(variables, context)) - - self._async_listener.append(async_track_template( - self.hass, wait_template, async_script_wait, variables)) - - self._cur = cur + 1 - if self._change_listener: - self.hass.async_add_job(self._change_listener) - - if CONF_TIMEOUT in action: - self._async_set_timeout( - action, variables, context, - action.get(CONF_CONTINUE, True)) - - return - - if CONF_CONDITION in action: - if not self._async_check_condition(action, variables): - break - - elif CONF_EVENT in action: - self._async_fire_event(action, variables, context) - - else: - await self._async_call_service(action, variables, context) - + # Set script to not-running. self._cur = -1 self.last_action = None if self._change_listener: @@ -198,6 +166,86 @@ class Script(): if self._change_listener: self.hass.async_add_job(self._change_listener) + async def _handle_action(self, action, variables, context): + """Handle an action.""" + await self._actions[_determine_action(action)]( + action, variables, context) + + async def _async_delay(self, action, variables, context): + """Handle delay.""" + # Call ourselves in the future to continue work + unsub = None + + @callback + def async_script_delay(now): + """Handle delay.""" + # pylint: disable=cell-var-from-loop + with suppress(ValueError): + self._async_listener.remove(unsub) + + self.hass.async_create_task( + self.async_run(variables, context)) + + delay = action[CONF_DELAY] + + try: + if isinstance(delay, template.Template): + delay = vol.All( + cv.time_period, + cv.positive_timedelta)( + delay.async_render(variables)) + elif isinstance(delay, dict): + delay_data = {} + delay_data.update( + template.render_complex(delay, variables)) + delay = cv.time_period(delay_data) + except (exceptions.TemplateError, vol.Invalid) as ex: + _LOGGER.error("Error rendering '%s' delay template: %s", + self.name, ex) + raise _StopScript + + self.last_action = action.get( + CONF_ALIAS, 'delay {}'.format(delay)) + self._log("Executing step %s" % self.last_action) + + unsub = async_track_point_in_utc_time( + self.hass, async_script_delay, + date_util.utcnow() + delay + ) + self._async_listener.append(unsub) + raise _SuspendScript + + async def _async_wait_template(self, action, variables, context): + """Handle a wait template.""" + # Call ourselves in the future to continue work + wait_template = action[CONF_WAIT_TEMPLATE] + wait_template.hass = self.hass + + self.last_action = action.get(CONF_ALIAS, 'wait template') + self._log("Executing step %s" % self.last_action) + + # check if condition already okay + if condition.async_template( + self.hass, wait_template, variables): + return + + @callback + def async_script_wait(entity_id, from_s, to_s): + """Handle script after template condition is true.""" + self._async_remove_listener() + self.hass.async_create_task( + self.async_run(variables, context)) + + self._async_listener.append(async_track_template( + self.hass, wait_template, async_script_wait, variables)) + + if CONF_TIMEOUT in action: + self._async_set_timeout( + action, variables, context, + action.get(CONF_CONTINUE, True)) + + raise _SuspendScript + async def _async_call_service(self, action, variables, context): """Call the service specified in the action. @@ -213,7 +261,7 @@ class Script(): context=context ) - def _async_fire_event(self, action, variables, context): + async def _async_fire_event(self, action, variables, context): """Fire an event.""" self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) self._log("Executing step %s" % self.last_action) @@ -222,13 +270,13 @@ class Script(): try: event_data.update(template.render_complex( action[CONF_EVENT_DATA_TEMPLATE], variables)) - except TemplateError as ex: + except exceptions.TemplateError as ex: _LOGGER.error('Error rendering event data template: %s', ex) self.hass.bus.async_fire(action[CONF_EVENT], event_data, context=context) - def _async_check_condition(self, action, variables): + async def _async_check_condition(self, action, variables, context): """Test if condition is matching.""" config_cache_key = frozenset((k, str(v)) for k, v in action.items()) config = self._config_cache.get(config_cache_key) @@ -239,7 +287,9 @@ class Script(): self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) check = config(self.hass, variables) self._log("Test condition {}: {}".format(self.last_action, check)) - return check + + if not check: + raise _StopScript def _async_set_timeout(self, action, variables, context, continue_on_timeout): diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index e5e62d2aed3..887a147c417 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -4,6 +4,10 @@ from datetime import timedelta from unittest import mock import unittest +import voluptuous as vol +import pytest + +from homeassistant import exceptions from homeassistant.core import Context, callback # Otherwise can't test just this file (import order issue) import homeassistant.components # noqa @@ -774,3 +778,84 @@ class TestScriptHelper(unittest.TestCase): self.hass.block_till_done() assert script_obj.last_triggered == time + + +async def test_propagate_error_service_not_found(hass): + """Test that a script aborts when a service is not found.""" + events = [] + + @callback + def record_event(event): + events.append(event) + + hass.bus.async_listen('test_event', record_event) + + script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([ + {'service': 'test.script'}, + {'event': 'test_event'}])) + + with pytest.raises(exceptions.ServiceNotFound): + await script_obj.async_run() + + assert len(events) == 0 + + +async def test_propagate_error_invalid_service_data(hass): + """Test that a script aborts when we send invalid service data.""" + events = [] + + @callback + def record_event(event): + events.append(event) + + hass.bus.async_listen('test_event', record_event) + + calls = [] + + @callback + def record_call(service): + """Add recorded event to set.""" + calls.append(service) + + hass.services.async_register('test', 'script', record_call, + schema=vol.Schema({'text': str})) + + script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([ + {'service': 'test.script', 'data': {'text': 1}}, + {'event': 'test_event'}])) + + with pytest.raises(vol.Invalid): + await script_obj.async_run() + + assert len(events) == 0 + assert len(calls) == 0 + + +async def test_propagate_error_service_exception(hass): + """Test that a script aborts when a service throws an exception.""" + events = [] + + @callback + def record_event(event): + events.append(event) + + hass.bus.async_listen('test_event', record_event) + + calls = [] + + @callback + def record_call(service): + """Add recorded event to set.""" + raise ValueError("BROKEN") + + hass.services.async_register('test', 'script', record_call) + + script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([ + {'service': 'test.script'}, + {'event': 'test_event'}])) + + with pytest.raises(ValueError): + await script_obj.async_run() + + assert len(events) == 0 + assert len(calls) == 0