diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 571888038a6..d5cdc9ffd83 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -38,13 +38,14 @@ def async_trigger(hass, config, action): time_delta = config.get(CONF_FOR) value_template = config.get(CONF_VALUE_TEMPLATE) async_remove_track_same = None + already_triggered = False if value_template is not None: value_template.hass = hass @callback def check_numeric_state(entity, from_s, to_s): - """Return True if they should trigger.""" + """Return True if criteria are now met.""" if to_s is None: return False @@ -56,51 +57,39 @@ def async_trigger(hass, config, action): 'above': above, } } - - # If new one doesn't match, nothing to do - if not condition.async_numeric_state( - hass, to_s, below, above, value_template, variables): - return False - - return True + return condition.async_numeric_state( + hass, to_s, below, above, value_template, variables) @callback def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" - nonlocal async_remove_track_same - - if not check_numeric_state(entity, from_s, to_s): - return - - variables = { - 'trigger': { - 'platform': 'numeric_state', - 'entity_id': entity, - 'below': below, - 'above': above, - 'from_state': from_s, - 'to_state': to_s, - } - } - - # Only match if old didn't exist or existed but didn't match - # Written as: skip if old one did exist and matched - if from_s is not None and condition.async_numeric_state( - hass, from_s, below, above, value_template, variables): - return + nonlocal already_triggered, async_remove_track_same @callback def call_action(): """Call action with right context.""" - hass.async_run_job(action, variables) + hass.async_run_job(action, { + 'trigger': { + 'platform': 'numeric_state', + 'entity_id': entity, + 'below': below, + 'above': above, + 'from_state': from_s, + 'to_state': to_s, + } + }) - if not time_delta: - call_action() - return + matching = check_numeric_state(entity, from_s, to_s) - async_remove_track_same = async_track_same_state( - hass, time_delta, call_action, entity_ids=entity_id, - async_check_same_func=check_numeric_state) + if matching and not already_triggered: + if time_delta: + async_remove_track_same = async_track_same_state( + hass, time_delta, call_action, entity_ids=entity_id, + async_check_same_func=check_numeric_state) + else: + call_action() + + already_triggered = matching unsub = async_track_state_change( hass, entity_id, state_automation_listener) diff --git a/tests/components/automation/test_numeric_state.py b/tests/components/automation/test_numeric_state.py index cb36a91dddb..35841baa930 100644 --- a/tests/components/automation/test_numeric_state.py +++ b/tests/components/automation/test_numeric_state.py @@ -86,7 +86,7 @@ class TestAutomationNumericState(unittest.TestCase): def test_if_not_fires_on_entity_change_below_to_below(self): """"Test the firing with changed entity.""" - self.hass.states.set('test.entity', 9) + self.hass.states.set('test.entity', 11) self.hass.block_till_done() assert setup_component(self.hass, automation.DOMAIN, { @@ -102,10 +102,15 @@ class TestAutomationNumericState(unittest.TestCase): } }) - # 9 is below 10 so this should not fire again - self.hass.states.set('test.entity', 8) + # 9 is below 10 so this should fire + self.hass.states.set('test.entity', 9) self.hass.block_till_done() - self.assertEqual(0, len(self.calls)) + self.assertEqual(1, len(self.calls)) + + # already below so should not fire again + self.hass.states.set('test.entity', 5) + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) def test_if_not_below_fires_on_entity_change_to_equal(self): """"Test the firing with changed entity.""" @@ -130,6 +135,52 @@ class TestAutomationNumericState(unittest.TestCase): self.hass.block_till_done() self.assertEqual(0, len(self.calls)) + def test_if_fires_on_initial_entity_below(self): + """"Test the firing when starting with a match.""" + self.hass.states.set('test.entity', 9) + self.hass.block_till_done() + + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'numeric_state', + 'entity_id': 'test.entity', + 'below': 10, + }, + 'action': { + 'service': 'test.automation' + } + } + }) + + # Fire on first update even if initial state was already below + self.hass.states.set('test.entity', 8) + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + + def test_if_fires_on_initial_entity_above(self): + """"Test the firing when starting with a match.""" + self.hass.states.set('test.entity', 11) + self.hass.block_till_done() + + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'numeric_state', + 'entity_id': 'test.entity', + 'above': 10, + }, + 'action': { + 'service': 'test.automation' + } + } + }) + + # Fire on first update even if initial state was already above + self.hass.states.set('test.entity', 12) + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_entity_change_above(self): """"Test the firing with changed entity.""" assert setup_component(self.hass, automation.DOMAIN, { @@ -176,7 +227,7 @@ class TestAutomationNumericState(unittest.TestCase): def test_if_not_fires_on_entity_change_above_to_above(self): """"Test the firing with changed entity.""" # set initial state - self.hass.states.set('test.entity', 11) + self.hass.states.set('test.entity', 9) self.hass.block_till_done() assert setup_component(self.hass, automation.DOMAIN, { @@ -192,10 +243,15 @@ class TestAutomationNumericState(unittest.TestCase): } }) - # 11 is above 10 so this should fire again + # 12 is above 10 so this should fire self.hass.states.set('test.entity', 12) self.hass.block_till_done() - self.assertEqual(0, len(self.calls)) + self.assertEqual(1, len(self.calls)) + + # already above, should not fire again + self.hass.states.set('test.entity', 15) + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) def test_if_not_above_fires_on_entity_change_to_equal(self): """"Test the firing with changed entity."""