Refactor script helper actions into their own methods (#18962)

* Refactor script helper actions into their own methods

* Lint

* Lint
This commit is contained in:
Paulus Schoutsen 2018-12-03 15:46:25 +01:00 committed by GitHub
parent d0751ffd91
commit d028236bf2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 227 additions and 92 deletions

View file

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, Context, callback from homeassistant.core import HomeAssistant, Context, callback
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
from homeassistant.exceptions import TemplateError from homeassistant import exceptions
from homeassistant.helpers import ( from homeassistant.helpers import (
service, condition, template as template, service, condition, template as template,
config_validation as cv) config_validation as cv)
@ -34,6 +34,30 @@ CONF_WAIT_TEMPLATE = 'wait_template'
CONF_CONTINUE = 'continue_on_timeout' CONF_CONTINUE = 'continue_on_timeout'
ACTION_DELAY = 'delay'
ACTION_WAIT_TEMPLATE = 'wait_template'
ACTION_CHECK_CONDITION = 'condition'
ACTION_FIRE_EVENT = 'event'
ACTION_CALL_SERVICE = 'call_service'
def _determine_action(action):
"""Determine action type."""
if CONF_DELAY in action:
return ACTION_DELAY
if CONF_WAIT_TEMPLATE in action:
return ACTION_WAIT_TEMPLATE
if CONF_CONDITION in action:
return ACTION_CHECK_CONDITION
if CONF_EVENT in action:
return ACTION_FIRE_EVENT
return ACTION_CALL_SERVICE
def call_from_config(hass: HomeAssistant, config: ConfigType, def call_from_config(hass: HomeAssistant, config: ConfigType,
variables: Optional[Sequence] = None, variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None: context: Optional[Context] = None) -> None:
@ -41,6 +65,14 @@ def call_from_config(hass: HomeAssistant, config: ConfigType,
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context) Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
class _StopScript(Exception):
"""Throw if script needs to stop."""
class _SuspendScript(Exception):
"""Throw if script needs to suspend."""
class Script(): class Script():
"""Representation of a script.""" """Representation of a script."""
@ -60,6 +92,13 @@ class Script():
self._async_listener = [] self._async_listener = []
self._template_cache = {} self._template_cache = {}
self._config_cache = {} self._config_cache = {}
self._actions = {
ACTION_DELAY: self._async_delay,
ACTION_WAIT_TEMPLATE: self._async_wait_template,
ACTION_CHECK_CONDITION: self._async_check_condition,
ACTION_FIRE_EVENT: self._async_fire_event,
ACTION_CALL_SERVICE: self._async_call_service,
}
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@ -87,98 +126,27 @@ class Script():
self._async_remove_listener() self._async_remove_listener()
for cur, action in islice(enumerate(self.sequence), self._cur, None): for cur, action in islice(enumerate(self.sequence), self._cur, None):
try:
if CONF_DELAY in action: await self._handle_action(action, variables, context)
# Call ourselves in the future to continue work except _SuspendScript:
unsub = None # Store next step to take and notify change listeners
@callback
def async_script_delay(now):
"""Handle delay."""
# pylint: disable=cell-var-from-loop
with suppress(ValueError):
self._async_listener.remove(unsub)
self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY]
try:
if isinstance(delay, template.Template):
delay = vol.All(
cv.time_period,
cv.positive_timedelta)(
delay.async_render(variables))
elif isinstance(delay, dict):
delay_data = {}
delay_data.update(
template.render_complex(delay, variables))
delay = cv.time_period(delay_data)
except (TemplateError, vol.Invalid) as ex:
_LOGGER.error("Error rendering '%s' delay template: %s",
self.name, ex)
break
self.last_action = action.get(
CONF_ALIAS, 'delay {}'.format(delay))
self._log("Executing step %s" % self.last_action)
unsub = async_track_point_in_utc_time(
self.hass, async_script_delay,
date_util.utcnow() + delay
)
self._async_listener.append(unsub)
self._cur = cur + 1 self._cur = cur + 1
if self._change_listener: if self._change_listener:
self.hass.async_add_job(self._change_listener) self.hass.async_add_job(self._change_listener)
return return
except _StopScript:
break
except Exception as err:
# Store the step that had an exception
# pylint: disable=protected-access
err._script_step = cur
# Set script to not running
self._cur = -1
self.last_action = None
# Pass exception on.
raise
if CONF_WAIT_TEMPLATE in action: # Set script to not-running.
# Call ourselves in the future to continue work
wait_template = action[CONF_WAIT_TEMPLATE]
wait_template.hass = self.hass
self.last_action = action.get(CONF_ALIAS, 'wait template')
self._log("Executing step %s" % self.last_action)
# check if condition already okay
if condition.async_template(
self.hass, wait_template, variables):
continue
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables))
self._cur = cur + 1
if self._change_listener:
self.hass.async_add_job(self._change_listener)
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, context,
action.get(CONF_CONTINUE, True))
return
if CONF_CONDITION in action:
if not self._async_check_condition(action, variables):
break
elif CONF_EVENT in action:
self._async_fire_event(action, variables, context)
else:
await self._async_call_service(action, variables, context)
self._cur = -1 self._cur = -1
self.last_action = None self.last_action = None
if self._change_listener: if self._change_listener:
@ -198,6 +166,86 @@ class Script():
if self._change_listener: if self._change_listener:
self.hass.async_add_job(self._change_listener) self.hass.async_add_job(self._change_listener)
async def _handle_action(self, action, variables, context):
"""Handle an action."""
await self._actions[_determine_action(action)](
action, variables, context)
async def _async_delay(self, action, variables, context):
"""Handle delay."""
# Call ourselves in the future to continue work
unsub = None
@callback
def async_script_delay(now):
"""Handle delay."""
# pylint: disable=cell-var-from-loop
with suppress(ValueError):
self._async_listener.remove(unsub)
self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY]
try:
if isinstance(delay, template.Template):
delay = vol.All(
cv.time_period,
cv.positive_timedelta)(
delay.async_render(variables))
elif isinstance(delay, dict):
delay_data = {}
delay_data.update(
template.render_complex(delay, variables))
delay = cv.time_period(delay_data)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error("Error rendering '%s' delay template: %s",
self.name, ex)
raise _StopScript
self.last_action = action.get(
CONF_ALIAS, 'delay {}'.format(delay))
self._log("Executing step %s" % self.last_action)
unsub = async_track_point_in_utc_time(
self.hass, async_script_delay,
date_util.utcnow() + delay
)
self._async_listener.append(unsub)
raise _SuspendScript
async def _async_wait_template(self, action, variables, context):
"""Handle a wait template."""
# Call ourselves in the future to continue work
wait_template = action[CONF_WAIT_TEMPLATE]
wait_template.hass = self.hass
self.last_action = action.get(CONF_ALIAS, 'wait template')
self._log("Executing step %s" % self.last_action)
# check if condition already okay
if condition.async_template(
self.hass, wait_template, variables):
return
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables))
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, context,
action.get(CONF_CONTINUE, True))
raise _SuspendScript
async def _async_call_service(self, action, variables, context): async def _async_call_service(self, action, variables, context):
"""Call the service specified in the action. """Call the service specified in the action.
@ -213,7 +261,7 @@ class Script():
context=context context=context
) )
def _async_fire_event(self, action, variables, context): async def _async_fire_event(self, action, variables, context):
"""Fire an event.""" """Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
@ -222,13 +270,13 @@ class Script():
try: try:
event_data.update(template.render_complex( event_data.update(template.render_complex(
action[CONF_EVENT_DATA_TEMPLATE], variables)) action[CONF_EVENT_DATA_TEMPLATE], variables))
except TemplateError as ex: except exceptions.TemplateError as ex:
_LOGGER.error('Error rendering event data template: %s', ex) _LOGGER.error('Error rendering event data template: %s', ex)
self.hass.bus.async_fire(action[CONF_EVENT], self.hass.bus.async_fire(action[CONF_EVENT],
event_data, context=context) event_data, context=context)
def _async_check_condition(self, action, variables): async def _async_check_condition(self, action, variables, context):
"""Test if condition is matching.""" """Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in action.items()) config_cache_key = frozenset((k, str(v)) for k, v in action.items())
config = self._config_cache.get(config_cache_key) config = self._config_cache.get(config_cache_key)
@ -239,7 +287,9 @@ class Script():
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
check = config(self.hass, variables) check = config(self.hass, variables)
self._log("Test condition {}: {}".format(self.last_action, check)) self._log("Test condition {}: {}".format(self.last_action, check))
return check
if not check:
raise _StopScript
def _async_set_timeout(self, action, variables, context, def _async_set_timeout(self, action, variables, context,
continue_on_timeout): continue_on_timeout):

View file

@ -4,6 +4,10 @@ from datetime import timedelta
from unittest import mock from unittest import mock
import unittest import unittest
import voluptuous as vol
import pytest
from homeassistant import exceptions
from homeassistant.core import Context, callback from homeassistant.core import Context, callback
# Otherwise can't test just this file (import order issue) # Otherwise can't test just this file (import order issue)
import homeassistant.components # noqa import homeassistant.components # noqa
@ -774,3 +778,84 @@ class TestScriptHelper(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
assert script_obj.last_triggered == time assert script_obj.last_triggered == time
async def test_propagate_error_service_not_found(hass):
"""Test that a script aborts when a service is not found."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script'},
{'event': 'test_event'}]))
with pytest.raises(exceptions.ServiceNotFound):
await script_obj.async_run()
assert len(events) == 0
async def test_propagate_error_invalid_service_data(hass):
"""Test that a script aborts when we send invalid service data."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
calls = []
@callback
def record_call(service):
"""Add recorded event to set."""
calls.append(service)
hass.services.async_register('test', 'script', record_call,
schema=vol.Schema({'text': str}))
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script', 'data': {'text': 1}},
{'event': 'test_event'}]))
with pytest.raises(vol.Invalid):
await script_obj.async_run()
assert len(events) == 0
assert len(calls) == 0
async def test_propagate_error_service_exception(hass):
"""Test that a script aborts when a service throws an exception."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
calls = []
@callback
def record_call(service):
"""Add recorded event to set."""
raise ValueError("BROKEN")
hass.services.async_register('test', 'script', record_call)
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script'},
{'event': 'test_event'}]))
with pytest.raises(ValueError):
await script_obj.async_run()
assert len(events) == 0
assert len(calls) == 0