Add 'wait_template' to script commands / Refactory track_template (#5827)
* Add 'wait' to script commands. * Add track_template + unittest / rename wait_template * fix lint & test * Fix handling / change automation-template / add tests * address paulus comments
This commit is contained in:
parent
5f0b2a7d15
commit
9aac2113b6
6 changed files with 326 additions and 50 deletions
|
@ -10,8 +10,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM
|
||||
from homeassistant.helpers import condition
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
from homeassistant.helpers.event import async_track_template
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
|
||||
|
@ -28,28 +27,16 @@ def async_trigger(hass, config, action):
|
|||
value_template = config.get(CONF_VALUE_TEMPLATE)
|
||||
value_template.hass = hass
|
||||
|
||||
# Local variable to keep track of if the action has already been triggered
|
||||
already_triggered = False
|
||||
|
||||
@callback
|
||||
def state_changed_listener(entity_id, from_s, to_s):
|
||||
def template_listener(entity_id, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
nonlocal already_triggered
|
||||
template_result = condition.async_template(hass, value_template)
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'template',
|
||||
'entity_id': entity_id,
|
||||
'from_state': from_s,
|
||||
'to_state': to_s,
|
||||
},
|
||||
})
|
||||
|
||||
# Check to see if template returns true
|
||||
if template_result and not already_triggered:
|
||||
already_triggered = True
|
||||
hass.async_run_job(action, {
|
||||
'trigger': {
|
||||
'platform': 'template',
|
||||
'entity_id': entity_id,
|
||||
'from_state': from_s,
|
||||
'to_state': to_s,
|
||||
},
|
||||
})
|
||||
elif not template_result:
|
||||
already_triggered = False
|
||||
|
||||
return async_track_state_change(hass, value_template.extract_entities(),
|
||||
state_changed_listener)
|
||||
return async_track_template(hass, value_template, template_listener)
|
||||
|
|
|
@ -14,7 +14,7 @@ from homeassistant.loader import get_platform
|
|||
from homeassistant.const import (
|
||||
CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELSIUS, TEMP_FAHRENHEIT,
|
||||
CONF_ALIAS, CONF_ENTITY_ID, CONF_VALUE_TEMPLATE, WEEKDAYS,
|
||||
CONF_CONDITION, CONF_BELOW, CONF_ABOVE, SUN_EVENT_SUNSET,
|
||||
CONF_CONDITION, CONF_BELOW, CONF_ABOVE, CONF_TIMEOUT, SUN_EVENT_SUNSET,
|
||||
SUN_EVENT_SUNRISE, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC)
|
||||
from homeassistant.core import valid_entity_id
|
||||
from homeassistant.exceptions import TemplateError
|
||||
|
@ -524,8 +524,14 @@ _SCRIPT_DELAY_SCHEMA = vol.Schema({
|
|||
template)
|
||||
})
|
||||
|
||||
_SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema({
|
||||
vol.Optional(CONF_ALIAS): string,
|
||||
vol.Required("wait_template"): template,
|
||||
vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta),
|
||||
})
|
||||
|
||||
SCRIPT_SCHEMA = vol.All(
|
||||
ensure_list,
|
||||
[vol.Any(SERVICE_SCHEMA, _SCRIPT_DELAY_SCHEMA, EVENT_SCHEMA,
|
||||
CONDITION_SCHEMA)],
|
||||
[vol.Any(SERVICE_SCHEMA, _SCRIPT_DELAY_SCHEMA,
|
||||
_SCRIPT_WAIT_TEMPLATE_SCHEMA, EVENT_SCHEMA, CONDITION_SCHEMA)],
|
||||
)
|
||||
|
|
|
@ -84,6 +84,33 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
|
|||
track_state_change = threaded_listener_factory(async_track_state_change)
|
||||
|
||||
|
||||
def async_track_template(hass, template, action, variables=None):
|
||||
"""Add a listener that track state changes with template condition."""
|
||||
from . import condition
|
||||
|
||||
# Local variable to keep track of if the action has already been triggered
|
||||
already_triggered = False
|
||||
|
||||
@callback
|
||||
def template_condition_listener(entity_id, from_s, to_s):
|
||||
"""Check if condition is correct and run action."""
|
||||
nonlocal already_triggered
|
||||
template_result = condition.async_template(hass, template, variables)
|
||||
|
||||
# Check to see if template returns true
|
||||
if template_result and not already_triggered:
|
||||
already_triggered = True
|
||||
hass.async_run_job(action, entity_id, from_s, to_s)
|
||||
elif not template_result:
|
||||
already_triggered = False
|
||||
|
||||
return async_track_state_change(
|
||||
hass, template.extract_entities(), template_condition_listener)
|
||||
|
||||
|
||||
track_template = threaded_listener_factory(async_track_template)
|
||||
|
||||
|
||||
def async_track_point_in_time(hass, action, point_in_time):
|
||||
"""Add a listener that fires once after a specific point in time."""
|
||||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
|
|
@ -6,11 +6,12 @@ from typing import Optional, Sequence
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import CONF_CONDITION
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
||||
from homeassistant.helpers import (
|
||||
service, condition, template, config_validation as cv)
|
||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_point_in_utc_time, async_track_template)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
import homeassistant.util.dt as date_util
|
||||
from homeassistant.util.async import (
|
||||
|
@ -25,6 +26,7 @@ CONF_SEQUENCE = "sequence"
|
|||
CONF_EVENT = "event"
|
||||
CONF_EVENT_DATA = "event_data"
|
||||
CONF_DELAY = "delay"
|
||||
CONF_WAIT_TEMPLATE = "wait_template"
|
||||
|
||||
|
||||
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||
|
@ -47,9 +49,9 @@ class Script():
|
|||
self._cur = -1
|
||||
self.last_action = None
|
||||
self.last_triggered = None
|
||||
self.can_cancel = any(CONF_DELAY in action for action
|
||||
in self.sequence)
|
||||
self._async_unsub_delay_listener = None
|
||||
self.can_cancel = any(CONF_DELAY in action or CONF_WAIT_TEMPLATE
|
||||
in action for action in self.sequence)
|
||||
self._async_listener = []
|
||||
self._template_cache = {}
|
||||
self._config_cache = {}
|
||||
|
||||
|
@ -74,19 +76,21 @@ class Script():
|
|||
self._log('Running script')
|
||||
self._cur = 0
|
||||
|
||||
# Unregister callback if we were in a delay but turn on is called
|
||||
# again. In that case we just continue execution.
|
||||
# Unregister callback if we were in a delay or wait but turn on is
|
||||
# called again. In that case we just continue execution.
|
||||
self._async_remove_listener()
|
||||
|
||||
for cur, action in islice(enumerate(self.sequence), self._cur,
|
||||
None):
|
||||
for cur, action in islice(enumerate(self.sequence), self._cur, None):
|
||||
|
||||
if CONF_DELAY in action:
|
||||
# Call ourselves in the future to continue work
|
||||
@asyncio.coroutine
|
||||
def script_delay(now):
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Called after delay is done."""
|
||||
self._async_unsub_delay_listener = None
|
||||
# pylint: disable=cell-var-from-loop
|
||||
self._async_listener.remove(unsub)
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
@ -97,15 +101,45 @@ class Script():
|
|||
cv.positive_timedelta)(
|
||||
delay.async_render(variables))
|
||||
|
||||
self._async_unsub_delay_listener = \
|
||||
async_track_point_in_utc_time(
|
||||
self.hass, script_delay,
|
||||
date_util.utcnow() + delay)
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_delay,
|
||||
date_util.utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
|
||||
self._cur = cur + 1
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
return
|
||||
|
||||
elif CONF_WAIT_TEMPLATE in action:
|
||||
# Call ourselves in the future to continue work
|
||||
wait_template = action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self.hass
|
||||
|
||||
# check if condition allready okay
|
||||
if condition.async_template(
|
||||
self.hass, wait_template, variables):
|
||||
continue
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Called after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
|
||||
self._async_listener.append(async_track_template(
|
||||
self.hass, wait_template, async_script_wait))
|
||||
|
||||
self._cur = cur + 1
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
if CONF_TIMEOUT in action:
|
||||
self._async_set_timeout(action, variables)
|
||||
|
||||
return
|
||||
|
||||
elif CONF_CONDITION in action:
|
||||
if not self._async_check_condition(action, variables):
|
||||
break
|
||||
|
@ -166,11 +200,29 @@ class Script():
|
|||
self._log("Test condition {}: {}".format(self.last_action, check))
|
||||
return check
|
||||
|
||||
def _async_set_timeout(self, action, variables):
|
||||
"""Schedule a timeout to abort script."""
|
||||
timeout = action[CONF_TIMEOUT]
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_timeout(now):
|
||||
"""Call after timeout is retrieve stop script."""
|
||||
self._async_listener.remove(unsub)
|
||||
self._log("Timout reach, abort script.")
|
||||
self.async_stop()
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_timeout,
|
||||
date_util.utcnow() + timeout
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove point in time listener, if any."""
|
||||
if self._async_unsub_delay_listener:
|
||||
self._async_unsub_delay_listener()
|
||||
self._async_unsub_delay_listener = None
|
||||
for unsub in self._async_listener:
|
||||
unsub()
|
||||
self._async_listener.clear()
|
||||
|
||||
def _log(self, msg):
|
||||
"""Logger helper."""
|
||||
|
|
|
@ -16,9 +16,11 @@ from homeassistant.helpers.event import (
|
|||
track_time_change,
|
||||
track_state_change,
|
||||
track_time_interval,
|
||||
track_template,
|
||||
track_sunrise,
|
||||
track_sunset,
|
||||
)
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.components import sun
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
|
@ -188,6 +190,77 @@ class TestEventHelpers(unittest.TestCase):
|
|||
self.assertEqual(5, len(wildcard_runs))
|
||||
self.assertEqual(6, len(wildercard_runs))
|
||||
|
||||
def test_track_template(self):
|
||||
"""Test tracking template."""
|
||||
specific_runs = []
|
||||
wildcard_runs = []
|
||||
wildercard_runs = []
|
||||
|
||||
template_condition = Template(
|
||||
"{{states.switch.test.state == 'on'}}",
|
||||
self.hass
|
||||
)
|
||||
template_condition_var = Template(
|
||||
"{{states.switch.test.state == 'on' and test == 5}}",
|
||||
self.hass
|
||||
)
|
||||
|
||||
self.hass.states.set('switch.test', 'off')
|
||||
|
||||
def specific_run_callback(entity_id, old_state, new_state):
|
||||
specific_runs.append(1)
|
||||
|
||||
track_template(self.hass, template_condition, specific_run_callback)
|
||||
|
||||
@ha.callback
|
||||
def wildcard_run_callback(entity_id, old_state, new_state):
|
||||
wildcard_runs.append((old_state, new_state))
|
||||
|
||||
track_template(self.hass, template_condition, wildcard_run_callback)
|
||||
|
||||
@asyncio.coroutine
|
||||
def wildercard_run_callback(entity_id, old_state, new_state):
|
||||
wildercard_runs.append((old_state, new_state))
|
||||
|
||||
track_template(
|
||||
self.hass, template_condition_var, wildercard_run_callback,
|
||||
{'test': 5})
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(specific_runs))
|
||||
self.assertEqual(1, len(wildcard_runs))
|
||||
self.assertEqual(1, len(wildercard_runs))
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(specific_runs))
|
||||
self.assertEqual(1, len(wildcard_runs))
|
||||
self.assertEqual(1, len(wildercard_runs))
|
||||
|
||||
self.hass.states.set('switch.test', 'off')
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(specific_runs))
|
||||
self.assertEqual(1, len(wildcard_runs))
|
||||
self.assertEqual(1, len(wildercard_runs))
|
||||
|
||||
self.hass.states.set('switch.test', 'off')
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(specific_runs))
|
||||
self.assertEqual(1, len(wildcard_runs))
|
||||
self.assertEqual(1, len(wildercard_runs))
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(2, len(specific_runs))
|
||||
self.assertEqual(2, len(wildcard_runs))
|
||||
self.assertEqual(2, len(wildercard_runs))
|
||||
|
||||
def test_track_time_interval(self):
|
||||
"""Test tracking time interval."""
|
||||
specific_runs = []
|
||||
|
|
|
@ -131,7 +131,6 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
|
@ -164,7 +163,6 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
|
@ -196,7 +194,6 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
|
@ -214,6 +211,140 @@ class TestScriptHelper(unittest.TestCase):
|
|||
assert not script_obj.is_running
|
||||
assert len(events) == 0
|
||||
|
||||
def test_wait_template(self):
|
||||
"""Test the wait template."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
|
||||
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([
|
||||
{'event': event},
|
||||
{'wait_template': "{{states.switch.test.state == 'off'}}"},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
assert script_obj.can_cancel
|
||||
assert script_obj.last_action == event
|
||||
assert len(events) == 1
|
||||
|
||||
self.hass.states.set('switch.test', 'off')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 2
|
||||
|
||||
def test_wait_template_cancel(self):
|
||||
"""Test the wait template cancel action."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
|
||||
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([
|
||||
{'event': event},
|
||||
{'wait_template': "{{states.switch.test.state == 'off'}}"},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
assert script_obj.can_cancel
|
||||
assert script_obj.last_action == event
|
||||
assert len(events) == 1
|
||||
|
||||
script_obj.stop()
|
||||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 1
|
||||
|
||||
self.hass.states.set('switch.test', 'off')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 1
|
||||
|
||||
def test_wait_template_not_schedule(self):
|
||||
"""Test the wait template with correct condition."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
|
||||
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([
|
||||
{'event': event},
|
||||
{'wait_template': "{{states.switch.test.state == 'on'}}"},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
assert script_obj.can_cancel
|
||||
assert len(events) == 2
|
||||
|
||||
def test_wait_template_timeout(self):
|
||||
"""Test the wait template."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
"""Add recorded event to set."""
|
||||
events.append(event)
|
||||
|
||||
self.hass.bus.listen(event, record_event)
|
||||
|
||||
self.hass.states.set('switch.test', 'on')
|
||||
|
||||
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([
|
||||
{'event': event},
|
||||
{
|
||||
'wait_template': "{{states.switch.test.state == 'off'}}",
|
||||
'timeout': 5
|
||||
},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
assert script_obj.can_cancel
|
||||
assert script_obj.last_action == event
|
||||
assert len(events) == 1
|
||||
|
||||
future = dt_util.utcnow() + timedelta(seconds=5)
|
||||
fire_time_changed(self.hass, future)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 1
|
||||
|
||||
def test_passing_variables_to_script(self):
|
||||
"""Test if we can pass variables to script."""
|
||||
calls = []
|
||||
|
|
Loading…
Add table
Reference in a new issue