From 61dabae6abc96b44bd6fc26ff93e1745ab86c156 Mon Sep 17 00:00:00 2001 From: Phil Bruckner Date: Fri, 7 Jun 2019 23:45:37 -0500 Subject: [PATCH] Add for option for template triggers (#24330) --- .../components/automation/template.py | 52 ++++++++--- tests/components/automation/test_template.py | 93 ++++++++++++++++++- 2 files changed, 133 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/automation/template.py b/homeassistant/components/automation/template.py index 6371be28021..96075e9bd1c 100644 --- a/homeassistant/components/automation/template.py +++ b/homeassistant/components/automation/template.py @@ -4,8 +4,10 @@ import logging import voluptuous as vol from homeassistant.core import callback -from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM -from homeassistant.helpers.event import async_track_template +from homeassistant.const import CONF_VALUE_TEMPLATE, CONF_PLATFORM, CONF_FOR +from homeassistant.helpers import condition +from homeassistant.helpers.event import ( + async_track_same_state, async_track_template) import homeassistant.helpers.config_validation as cv _LOGGER = logging.getLogger(__name__) @@ -13,6 +15,7 @@ _LOGGER = logging.getLogger(__name__) TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({ vol.Required(CONF_PLATFORM): 'template', vol.Required(CONF_VALUE_TEMPLATE): cv.template, + vol.Optional(CONF_FOR): vol.All(cv.time_period, cv.positive_timedelta), }) @@ -20,17 +23,44 @@ async def async_trigger(hass, config, action, automation_info): """Listen for state changes based on configuration.""" value_template = config.get(CONF_VALUE_TEMPLATE) value_template.hass = hass + time_delta = config.get(CONF_FOR) + unsub_track_same = None @callback def template_listener(entity_id, from_s, to_s): """Listen for state changes and calls action.""" - hass.async_run_job(action({ - 'trigger': { - 'platform': 'template', - 'entity_id': entity_id, - 'from_state': from_s, - 'to_state': to_s, - }, - }, context=(to_s.context if to_s else None))) + nonlocal unsub_track_same - return async_track_template(hass, value_template, template_listener) + @callback + def call_action(): + """Call action with right context.""" + hass.async_run_job(action({ + 'trigger': { + 'platform': 'template', + 'entity_id': entity_id, + 'from_state': from_s, + 'to_state': to_s, + }, + }, context=(to_s.context if to_s else None))) + + if not time_delta: + call_action() + return + + unsub_track_same = async_track_same_state( + hass, time_delta, call_action, + lambda _, _2, _3: condition.async_template(hass, value_template), + value_template.extract_entities()) + + unsub = async_track_template( + hass, value_template, template_listener) + + @callback + def async_remove(): + """Remove state listeners async.""" + unsub() + if unsub_track_same: + # pylint: disable=not-callable + unsub_track_same() + + return async_remove diff --git a/tests/components/automation/test_template.py b/tests/components/automation/test_template.py index 25f32ac1939..815c5e440b4 100644 --- a/tests/components/automation/test_template.py +++ b/tests/components/automation/test_template.py @@ -1,11 +1,15 @@ """The tests for the Template automation.""" +from datetime import timedelta + import pytest from homeassistant.core import Context from homeassistant.setup import async_setup_component +import homeassistant.util.dt as dt_util import homeassistant.components.automation as automation -from tests.common import (assert_setup_component, mock_component) +from tests.common import ( + async_fire_time_changed, assert_setup_component, mock_component) from tests.components.automation import common from tests.common import async_mock_service @@ -434,3 +438,90 @@ async def test_wait_template_with_trigger(hass, calls): assert 1 == len(calls) assert 'template - test.entity - hello - world' == \ calls[0].data['some'] + + +async def test_if_fires_on_change_with_for(hass, calls): + """Test for firing on change with for.""" + assert await async_setup_component(hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'template', + 'value_template': "{{ is_state('test.entity', 'world') }}", + 'for': { + 'seconds': 5 + }, + }, + 'action': { + 'service': 'test.automation' + } + } + }) + + hass.states.async_set('test.entity', 'world') + await hass.async_block_till_done() + assert 0 == len(calls) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10)) + await hass.async_block_till_done() + assert 1 == len(calls) + + +async def test_if_not_fires_on_change_with_for(hass, calls): + """Test for firing on change with for.""" + assert await async_setup_component(hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'template', + 'value_template': "{{ is_state('test.entity', 'world') }}", + 'for': { + 'seconds': 5 + }, + }, + 'action': { + 'service': 'test.automation' + } + } + }) + + hass.states.async_set('test.entity', 'world') + await hass.async_block_till_done() + assert 0 == len(calls) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=4)) + await hass.async_block_till_done() + assert 0 == len(calls) + hass.states.async_set('test.entity', 'hello') + await hass.async_block_till_done() + assert 0 == len(calls) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=6)) + await hass.async_block_till_done() + assert 0 == len(calls) + + +async def test_if_not_fires_when_turned_off_with_for(hass, calls): + """Test for firing on change with for.""" + assert await async_setup_component(hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'template', + 'value_template': "{{ is_state('test.entity', 'world') }}", + 'for': { + 'seconds': 5 + }, + }, + 'action': { + 'service': 'test.automation' + } + } + }) + + hass.states.async_set('test.entity', 'world') + await hass.async_block_till_done() + assert 0 == len(calls) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=4)) + await hass.async_block_till_done() + assert 0 == len(calls) + await common.async_turn_off(hass) + await hass.async_block_till_done() + assert 0 == len(calls) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=6)) + await hass.async_block_till_done() + assert 0 == len(calls)