From f8b2570cb3a8df1349f60ffa6393f65aa97d9f82 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 3 Jan 2016 02:32:09 -0800 Subject: [PATCH] Group entities when reproducing a state --- homeassistant/helpers/state.py | 35 +++++---- tests/helpers/test_state.py | 129 +++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 13 deletions(-) create mode 100644 tests/helpers/test_state.py diff --git a/homeassistant/helpers/state.py b/homeassistant/helpers/state.py index 24a37c5b5ea..019e7ce6ce9 100644 --- a/homeassistant/helpers/state.py +++ b/homeassistant/helpers/state.py @@ -1,9 +1,5 @@ -""" -homeassistant.helpers.state -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Helpers that help with state related things. -""" +"""Helpers that help with state related things.""" +from collections import defaultdict import logging from homeassistant.core import State @@ -25,32 +21,36 @@ class TrackStates(object): that have changed since the start time to the return list when with-block is exited. """ + def __init__(self, hass): + """Initialize a TrackStates block.""" self.hass = hass self.states = [] def __enter__(self): + """Record time from which to track changes.""" self.now = dt_util.utcnow() return self.states def __exit__(self, exc_type, exc_value, traceback): + """Add changes states to changes list.""" self.states.extend(get_changed_since(self.hass.states.all(), self.now)) def get_changed_since(states, utc_point_in_time): - """ - Returns all states that have been changed since utc_point_in_time. - """ + """List of states that have been changed since utc_point_in_time.""" point_in_time = dt_util.strip_microseconds(utc_point_in_time) return [state for state in states if state.last_updated >= point_in_time] def reproduce_state(hass, states, blocking=False): - """ Takes in a state and will try to have the entity reproduce it. """ + """Reproduce given state.""" if isinstance(states, State): states = [states] + to_call = defaultdict(list) + for state in states: current_state = hass.states.get(state.entity_id) @@ -76,7 +76,16 @@ def reproduce_state(hass, states, blocking=False): state) continue - service_data = dict(state.attributes) - service_data[ATTR_ENTITY_ID] = state.entity_id + if state.domain == 'group': + service_domain = 'homeassistant' + else: + service_domain = state.domain - hass.services.call(state.domain, service, service_data, blocking) + # We group service calls for entities by service call + key = (service_domain, service, tuple(state.attributes.items())) + to_call[key].append(state.entity_id) + + for (service_domain, service, service_data), entity_ids in to_call.items(): + data = dict(service_data) + data[ATTR_ENTITY_ID] = entity_ids + hass.services.call(service_domain, service, data, blocking) diff --git a/tests/helpers/test_state.py b/tests/helpers/test_state.py new file mode 100644 index 00000000000..32924b1d6d5 --- /dev/null +++ b/tests/helpers/test_state.py @@ -0,0 +1,129 @@ +""" +tests.helpers.test_state +~~~~~~~~~~~~~~~~~~~~~~~~ + +Test state helpers. +""" +from datetime import timedelta +import unittest +from unittest.mock import patch + +import homeassistant.core as ha +import homeassistant.components as core_components +from homeassistant.const import SERVICE_TURN_ON +from homeassistant.util import dt as dt_util +from homeassistant.helpers import state + +from tests.common import get_test_home_assistant, mock_service + + +class TestStateHelpers(unittest.TestCase): + """ + Tests the Home Assistant event helpers. + """ + + def setUp(self): # pylint: disable=invalid-name + """ things to be run when tests are started. """ + self.hass = get_test_home_assistant() + core_components.setup(self.hass, {}) + + def tearDown(self): # pylint: disable=invalid-name + """ Stop down stuff we started. """ + self.hass.stop() + + def test_get_changed_since(self): + point1 = dt_util.utcnow() + point2 = point1 + timedelta(seconds=5) + point3 = point2 + timedelta(seconds=5) + + with patch('homeassistant.core.dt_util.utcnow', return_value=point1): + self.hass.states.set('light.test', 'on') + state1 = self.hass.states.get('light.test') + + with patch('homeassistant.core.dt_util.utcnow', return_value=point2): + self.hass.states.set('light.test2', 'on') + state2 = self.hass.states.get('light.test2') + + with patch('homeassistant.core.dt_util.utcnow', return_value=point3): + self.hass.states.set('light.test3', 'on') + state3 = self.hass.states.get('light.test3') + + self.assertEqual( + [state2, state3], + state.get_changed_since([state1, state2, state3], point2)) + + def test_track_states(self): + point1 = dt_util.utcnow() + point2 = point1 + timedelta(seconds=5) + point3 = point2 + timedelta(seconds=5) + + with patch('homeassistant.core.dt_util.utcnow') as mock_utcnow: + mock_utcnow.return_value = point2 + + with state.TrackStates(self.hass) as states: + mock_utcnow.return_value = point1 + self.hass.states.set('light.test', 'on') + + mock_utcnow.return_value = point2 + self.hass.states.set('light.test2', 'on') + state2 = self.hass.states.get('light.test2') + + mock_utcnow.return_value = point3 + self.hass.states.set('light.test3', 'on') + state3 = self.hass.states.get('light.test3') + + self.assertEqual( + sorted([state2, state3], key=lambda state: state.entity_id), + sorted(states, key=lambda state: state.entity_id)) + + def test_reproduce_state_with_turn_on(self): + calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) + + self.hass.states.set('light.test', 'off') + + state.reproduce_state(self.hass, ha.State('light.test', 'on')) + + self.hass.pool.block_till_done() + + self.assertTrue(len(calls) > 0) + last_call = calls[-1] + self.assertEqual('light', last_call.domain) + self.assertEqual(SERVICE_TURN_ON, last_call.service) + self.assertEqual(['light.test'], last_call.data.get('entity_id')) + + def test_reproduce_state_with_group(self): + light_calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) + + self.hass.states.set('group.test', 'off', { + 'entity_id': ['light.test1', 'light.test2']}) + + state.reproduce_state(self.hass, ha.State('group.test', 'on')) + + self.hass.pool.block_till_done() + + self.assertEqual(1, len(light_calls)) + last_call = light_calls[-1] + self.assertEqual('light', last_call.domain) + self.assertEqual(SERVICE_TURN_ON, last_call.service) + self.assertEqual(['light.test1', 'light.test2'], + last_call.data.get('entity_id')) + + def test_reproduce_state_group_states_with_same_domain_and_data(self): + light_calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) + + self.hass.states.set('light.test1', 'off') + self.hass.states.set('light.test2', 'off') + + state.reproduce_state(self.hass, [ + ha.State('light.test1', 'on', {'brightness': 95}), + ha.State('light.test2', 'on', {'brightness': 95})]) + + self.hass.pool.block_till_done() + + self.assertEqual(1, len(light_calls)) + last_call = light_calls[-1] + self.assertEqual('light', last_call.domain) + self.assertEqual(SERVICE_TURN_ON, last_call.service) + self.assertEqual(['light.test1', 'light.test2'], + last_call.data.get('entity_id')) + self.assertEqual(95, last_call.data.get('brightness'))