Refactor script helper actions into their own methods (#18962)
* Refactor script helper actions into their own methods * Lint * Lint
This commit is contained in:
parent
d0751ffd91
commit
d028236bf2
2 changed files with 227 additions and 92 deletions
|
@ -9,7 +9,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, Context, callback
|
from homeassistant.core import HomeAssistant, Context, callback
|
||||||
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant import exceptions
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
service, condition, template as template,
|
service, condition, template as template,
|
||||||
config_validation as cv)
|
config_validation as cv)
|
||||||
|
@ -34,6 +34,30 @@ CONF_WAIT_TEMPLATE = 'wait_template'
|
||||||
CONF_CONTINUE = 'continue_on_timeout'
|
CONF_CONTINUE = 'continue_on_timeout'
|
||||||
|
|
||||||
|
|
||||||
|
ACTION_DELAY = 'delay'
|
||||||
|
ACTION_WAIT_TEMPLATE = 'wait_template'
|
||||||
|
ACTION_CHECK_CONDITION = 'condition'
|
||||||
|
ACTION_FIRE_EVENT = 'event'
|
||||||
|
ACTION_CALL_SERVICE = 'call_service'
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_action(action):
|
||||||
|
"""Determine action type."""
|
||||||
|
if CONF_DELAY in action:
|
||||||
|
return ACTION_DELAY
|
||||||
|
|
||||||
|
if CONF_WAIT_TEMPLATE in action:
|
||||||
|
return ACTION_WAIT_TEMPLATE
|
||||||
|
|
||||||
|
if CONF_CONDITION in action:
|
||||||
|
return ACTION_CHECK_CONDITION
|
||||||
|
|
||||||
|
if CONF_EVENT in action:
|
||||||
|
return ACTION_FIRE_EVENT
|
||||||
|
|
||||||
|
return ACTION_CALL_SERVICE
|
||||||
|
|
||||||
|
|
||||||
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||||
variables: Optional[Sequence] = None,
|
variables: Optional[Sequence] = None,
|
||||||
context: Optional[Context] = None) -> None:
|
context: Optional[Context] = None) -> None:
|
||||||
|
@ -41,6 +65,14 @@ def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
||||||
|
|
||||||
|
|
||||||
|
class _StopScript(Exception):
|
||||||
|
"""Throw if script needs to stop."""
|
||||||
|
|
||||||
|
|
||||||
|
class _SuspendScript(Exception):
|
||||||
|
"""Throw if script needs to suspend."""
|
||||||
|
|
||||||
|
|
||||||
class Script():
|
class Script():
|
||||||
"""Representation of a script."""
|
"""Representation of a script."""
|
||||||
|
|
||||||
|
@ -60,6 +92,13 @@ class Script():
|
||||||
self._async_listener = []
|
self._async_listener = []
|
||||||
self._template_cache = {}
|
self._template_cache = {}
|
||||||
self._config_cache = {}
|
self._config_cache = {}
|
||||||
|
self._actions = {
|
||||||
|
ACTION_DELAY: self._async_delay,
|
||||||
|
ACTION_WAIT_TEMPLATE: self._async_wait_template,
|
||||||
|
ACTION_CHECK_CONDITION: self._async_check_condition,
|
||||||
|
ACTION_FIRE_EVENT: self._async_fire_event,
|
||||||
|
ACTION_CALL_SERVICE: self._async_call_service,
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
|
@ -87,98 +126,27 @@ class Script():
|
||||||
self._async_remove_listener()
|
self._async_remove_listener()
|
||||||
|
|
||||||
for cur, action in islice(enumerate(self.sequence), self._cur, None):
|
for cur, action in islice(enumerate(self.sequence), self._cur, None):
|
||||||
|
try:
|
||||||
if CONF_DELAY in action:
|
await self._handle_action(action, variables, context)
|
||||||
# Call ourselves in the future to continue work
|
except _SuspendScript:
|
||||||
unsub = None
|
# Store next step to take and notify change listeners
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_script_delay(now):
|
|
||||||
"""Handle delay."""
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
with suppress(ValueError):
|
|
||||||
self._async_listener.remove(unsub)
|
|
||||||
|
|
||||||
self.hass.async_create_task(
|
|
||||||
self.async_run(variables, context))
|
|
||||||
|
|
||||||
delay = action[CONF_DELAY]
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(delay, template.Template):
|
|
||||||
delay = vol.All(
|
|
||||||
cv.time_period,
|
|
||||||
cv.positive_timedelta)(
|
|
||||||
delay.async_render(variables))
|
|
||||||
elif isinstance(delay, dict):
|
|
||||||
delay_data = {}
|
|
||||||
delay_data.update(
|
|
||||||
template.render_complex(delay, variables))
|
|
||||||
delay = cv.time_period(delay_data)
|
|
||||||
except (TemplateError, vol.Invalid) as ex:
|
|
||||||
_LOGGER.error("Error rendering '%s' delay template: %s",
|
|
||||||
self.name, ex)
|
|
||||||
break
|
|
||||||
|
|
||||||
self.last_action = action.get(
|
|
||||||
CONF_ALIAS, 'delay {}'.format(delay))
|
|
||||||
self._log("Executing step %s" % self.last_action)
|
|
||||||
|
|
||||||
unsub = async_track_point_in_utc_time(
|
|
||||||
self.hass, async_script_delay,
|
|
||||||
date_util.utcnow() + delay
|
|
||||||
)
|
|
||||||
self._async_listener.append(unsub)
|
|
||||||
|
|
||||||
self._cur = cur + 1
|
self._cur = cur + 1
|
||||||
if self._change_listener:
|
if self._change_listener:
|
||||||
self.hass.async_add_job(self._change_listener)
|
self.hass.async_add_job(self._change_listener)
|
||||||
return
|
return
|
||||||
|
except _StopScript:
|
||||||
|
break
|
||||||
|
except Exception as err:
|
||||||
|
# Store the step that had an exception
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
err._script_step = cur
|
||||||
|
# Set script to not running
|
||||||
|
self._cur = -1
|
||||||
|
self.last_action = None
|
||||||
|
# Pass exception on.
|
||||||
|
raise
|
||||||
|
|
||||||
if CONF_WAIT_TEMPLATE in action:
|
# Set script to not-running.
|
||||||
# Call ourselves in the future to continue work
|
|
||||||
wait_template = action[CONF_WAIT_TEMPLATE]
|
|
||||||
wait_template.hass = self.hass
|
|
||||||
|
|
||||||
self.last_action = action.get(CONF_ALIAS, 'wait template')
|
|
||||||
self._log("Executing step %s" % self.last_action)
|
|
||||||
|
|
||||||
# check if condition already okay
|
|
||||||
if condition.async_template(
|
|
||||||
self.hass, wait_template, variables):
|
|
||||||
continue
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_script_wait(entity_id, from_s, to_s):
|
|
||||||
"""Handle script after template condition is true."""
|
|
||||||
self._async_remove_listener()
|
|
||||||
self.hass.async_create_task(
|
|
||||||
self.async_run(variables, context))
|
|
||||||
|
|
||||||
self._async_listener.append(async_track_template(
|
|
||||||
self.hass, wait_template, async_script_wait, variables))
|
|
||||||
|
|
||||||
self._cur = cur + 1
|
|
||||||
if self._change_listener:
|
|
||||||
self.hass.async_add_job(self._change_listener)
|
|
||||||
|
|
||||||
if CONF_TIMEOUT in action:
|
|
||||||
self._async_set_timeout(
|
|
||||||
action, variables, context,
|
|
||||||
action.get(CONF_CONTINUE, True))
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
if CONF_CONDITION in action:
|
|
||||||
if not self._async_check_condition(action, variables):
|
|
||||||
break
|
|
||||||
|
|
||||||
elif CONF_EVENT in action:
|
|
||||||
self._async_fire_event(action, variables, context)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await self._async_call_service(action, variables, context)
|
|
||||||
|
|
||||||
self._cur = -1
|
self._cur = -1
|
||||||
self.last_action = None
|
self.last_action = None
|
||||||
if self._change_listener:
|
if self._change_listener:
|
||||||
|
@ -198,6 +166,86 @@ class Script():
|
||||||
if self._change_listener:
|
if self._change_listener:
|
||||||
self.hass.async_add_job(self._change_listener)
|
self.hass.async_add_job(self._change_listener)
|
||||||
|
|
||||||
|
async def _handle_action(self, action, variables, context):
|
||||||
|
"""Handle an action."""
|
||||||
|
await self._actions[_determine_action(action)](
|
||||||
|
action, variables, context)
|
||||||
|
|
||||||
|
async def _async_delay(self, action, variables, context):
|
||||||
|
"""Handle delay."""
|
||||||
|
# Call ourselves in the future to continue work
|
||||||
|
unsub = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_script_delay(now):
|
||||||
|
"""Handle delay."""
|
||||||
|
# pylint: disable=cell-var-from-loop
|
||||||
|
with suppress(ValueError):
|
||||||
|
self._async_listener.remove(unsub)
|
||||||
|
|
||||||
|
self.hass.async_create_task(
|
||||||
|
self.async_run(variables, context))
|
||||||
|
|
||||||
|
delay = action[CONF_DELAY]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(delay, template.Template):
|
||||||
|
delay = vol.All(
|
||||||
|
cv.time_period,
|
||||||
|
cv.positive_timedelta)(
|
||||||
|
delay.async_render(variables))
|
||||||
|
elif isinstance(delay, dict):
|
||||||
|
delay_data = {}
|
||||||
|
delay_data.update(
|
||||||
|
template.render_complex(delay, variables))
|
||||||
|
delay = cv.time_period(delay_data)
|
||||||
|
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||||
|
_LOGGER.error("Error rendering '%s' delay template: %s",
|
||||||
|
self.name, ex)
|
||||||
|
raise _StopScript
|
||||||
|
|
||||||
|
self.last_action = action.get(
|
||||||
|
CONF_ALIAS, 'delay {}'.format(delay))
|
||||||
|
self._log("Executing step %s" % self.last_action)
|
||||||
|
|
||||||
|
unsub = async_track_point_in_utc_time(
|
||||||
|
self.hass, async_script_delay,
|
||||||
|
date_util.utcnow() + delay
|
||||||
|
)
|
||||||
|
self._async_listener.append(unsub)
|
||||||
|
raise _SuspendScript
|
||||||
|
|
||||||
|
async def _async_wait_template(self, action, variables, context):
|
||||||
|
"""Handle a wait template."""
|
||||||
|
# Call ourselves in the future to continue work
|
||||||
|
wait_template = action[CONF_WAIT_TEMPLATE]
|
||||||
|
wait_template.hass = self.hass
|
||||||
|
|
||||||
|
self.last_action = action.get(CONF_ALIAS, 'wait template')
|
||||||
|
self._log("Executing step %s" % self.last_action)
|
||||||
|
|
||||||
|
# check if condition already okay
|
||||||
|
if condition.async_template(
|
||||||
|
self.hass, wait_template, variables):
|
||||||
|
return
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_script_wait(entity_id, from_s, to_s):
|
||||||
|
"""Handle script after template condition is true."""
|
||||||
|
self._async_remove_listener()
|
||||||
|
self.hass.async_create_task(
|
||||||
|
self.async_run(variables, context))
|
||||||
|
|
||||||
|
self._async_listener.append(async_track_template(
|
||||||
|
self.hass, wait_template, async_script_wait, variables))
|
||||||
|
|
||||||
|
if CONF_TIMEOUT in action:
|
||||||
|
self._async_set_timeout(
|
||||||
|
action, variables, context,
|
||||||
|
action.get(CONF_CONTINUE, True))
|
||||||
|
|
||||||
|
raise _SuspendScript
|
||||||
|
|
||||||
async def _async_call_service(self, action, variables, context):
|
async def _async_call_service(self, action, variables, context):
|
||||||
"""Call the service specified in the action.
|
"""Call the service specified in the action.
|
||||||
|
|
||||||
|
@ -213,7 +261,7 @@ class Script():
|
||||||
context=context
|
context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
def _async_fire_event(self, action, variables, context):
|
async def _async_fire_event(self, action, variables, context):
|
||||||
"""Fire an event."""
|
"""Fire an event."""
|
||||||
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
||||||
self._log("Executing step %s" % self.last_action)
|
self._log("Executing step %s" % self.last_action)
|
||||||
|
@ -222,13 +270,13 @@ class Script():
|
||||||
try:
|
try:
|
||||||
event_data.update(template.render_complex(
|
event_data.update(template.render_complex(
|
||||||
action[CONF_EVENT_DATA_TEMPLATE], variables))
|
action[CONF_EVENT_DATA_TEMPLATE], variables))
|
||||||
except TemplateError as ex:
|
except exceptions.TemplateError as ex:
|
||||||
_LOGGER.error('Error rendering event data template: %s', ex)
|
_LOGGER.error('Error rendering event data template: %s', ex)
|
||||||
|
|
||||||
self.hass.bus.async_fire(action[CONF_EVENT],
|
self.hass.bus.async_fire(action[CONF_EVENT],
|
||||||
event_data, context=context)
|
event_data, context=context)
|
||||||
|
|
||||||
def _async_check_condition(self, action, variables):
|
async def _async_check_condition(self, action, variables, context):
|
||||||
"""Test if condition is matching."""
|
"""Test if condition is matching."""
|
||||||
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
|
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
|
||||||
config = self._config_cache.get(config_cache_key)
|
config = self._config_cache.get(config_cache_key)
|
||||||
|
@ -239,7 +287,9 @@ class Script():
|
||||||
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
|
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
|
||||||
check = config(self.hass, variables)
|
check = config(self.hass, variables)
|
||||||
self._log("Test condition {}: {}".format(self.last_action, check))
|
self._log("Test condition {}: {}".format(self.last_action, check))
|
||||||
return check
|
|
||||||
|
if not check:
|
||||||
|
raise _StopScript
|
||||||
|
|
||||||
def _async_set_timeout(self, action, variables, context,
|
def _async_set_timeout(self, action, variables, context,
|
||||||
continue_on_timeout):
|
continue_on_timeout):
|
||||||
|
|
|
@ -4,6 +4,10 @@ from datetime import timedelta
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant import exceptions
|
||||||
from homeassistant.core import Context, callback
|
from homeassistant.core import Context, callback
|
||||||
# Otherwise can't test just this file (import order issue)
|
# Otherwise can't test just this file (import order issue)
|
||||||
import homeassistant.components # noqa
|
import homeassistant.components # noqa
|
||||||
|
@ -774,3 +778,84 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert script_obj.last_triggered == time
|
assert script_obj.last_triggered == time
|
||||||
|
|
||||||
|
|
||||||
|
async def test_propagate_error_service_not_found(hass):
|
||||||
|
"""Test that a script aborts when a service is not found."""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def record_event(event):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
hass.bus.async_listen('test_event', record_event)
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||||
|
{'service': 'test.script'},
|
||||||
|
{'event': 'test_event'}]))
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.ServiceNotFound):
|
||||||
|
await script_obj.async_run()
|
||||||
|
|
||||||
|
assert len(events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_propagate_error_invalid_service_data(hass):
|
||||||
|
"""Test that a script aborts when we send invalid service data."""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def record_event(event):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
hass.bus.async_listen('test_event', record_event)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def record_call(service):
|
||||||
|
"""Add recorded event to set."""
|
||||||
|
calls.append(service)
|
||||||
|
|
||||||
|
hass.services.async_register('test', 'script', record_call,
|
||||||
|
schema=vol.Schema({'text': str}))
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||||
|
{'service': 'test.script', 'data': {'text': 1}},
|
||||||
|
{'event': 'test_event'}]))
|
||||||
|
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
await script_obj.async_run()
|
||||||
|
|
||||||
|
assert len(events) == 0
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_propagate_error_service_exception(hass):
|
||||||
|
"""Test that a script aborts when a service throws an exception."""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def record_event(event):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
hass.bus.async_listen('test_event', record_event)
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def record_call(service):
|
||||||
|
"""Add recorded event to set."""
|
||||||
|
raise ValueError("BROKEN")
|
||||||
|
|
||||||
|
hass.services.async_register('test', 'script', record_call)
|
||||||
|
|
||||||
|
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||||
|
{'service': 'test.script'},
|
||||||
|
{'event': 'test_event'}]))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await script_obj.async_run()
|
||||||
|
|
||||||
|
assert len(events) == 0
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
Loading…
Add table
Reference in a new issue