From b2ad8db86bd67886f9bc4a049b49a5af5961f0cb Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 14 Sep 2015 22:51:28 -0700 Subject: [PATCH] Add condition type to automation component --- .../components/automation/__init__.py | 33 ++++++- .../components/automation/numeric_state.py | 16 ++-- homeassistant/components/automation/state.py | 15 ++-- homeassistant/components/automation/time.py | 11 +-- tests/components/automation/test_init.py | 88 ++++++++++++++++++- 5 files changed, 137 insertions(+), 26 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 23126b2a3b3..73411ddd1db 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -23,6 +23,11 @@ CONF_SERVICE_DATA = "service_data" CONF_CONDITION = "condition" CONF_ACTION = 'action' CONF_TRIGGER = "trigger" +CONF_CONDITION_TYPE = "condition_type" + +CONDITION_TYPE_AND = "and" +CONDITION_TYPE_OR = "or" +DEFAULT_CONDITION_TYPE = CONDITION_TYPE_AND _LOGGER = logging.getLogger(__name__) @@ -44,8 +49,13 @@ def setup(hass, config): continue if CONF_CONDITION in p_config: + cond_type = p_config.get(CONF_CONDITION_TYPE, + DEFAULT_CONDITION_TYPE).lower() action = _process_if(hass, config, p_config[CONF_CONDITION], - action) + action, cond_type) + + if action is None: + continue _process_trigger(hass, config, p_config.get(CONF_TRIGGER, []), name, action) @@ -116,21 +126,36 @@ def _migrate_old_config(config): return new_conf -def _process_if(hass, config, if_configs, action): +def _process_if(hass, config, if_configs, action, cond_type): """ Processes if checks. """ if isinstance(if_configs, dict): if_configs = [if_configs] + checks = [] for if_config in if_configs: platform = _resolve_platform('condition', hass, config, if_config.get(CONF_PLATFORM)) if platform is None: continue - action = platform.if_action(hass, if_config, action) + check = platform.if_action(hass, if_config) - return action + if check is None: + return None + + checks.append(check) + + if cond_type == CONDITION_TYPE_AND: + def if_action(): + if all(check() for check in checks): + action() + else: + def if_action(): + if any(check() for check in checks): + action() + + return if_action def _process_trigger(hass, config, trigger_configs, name, action): diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 95691d0ebcc..7e014213d62 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -48,14 +48,14 @@ def trigger(hass, config, action): return True -def if_action(hass, config, action): +def if_action(hass, config): """ Wraps action method with state based condition. """ entity_id = config.get(CONF_ENTITY_ID) if entity_id is None: _LOGGER.error("Missing configuration key %s", CONF_ENTITY_ID) - return action + return None below = config.get(CONF_BELOW) above = config.get(CONF_ABOVE) @@ -64,16 +64,14 @@ def if_action(hass, config, action): _LOGGER.error("Missing configuration key." " One of %s or %s is required", CONF_BELOW, CONF_ABOVE) - return action - - def state_if(): - """ Execute action if state matches. """ + return None + def if_numeric_state(): + """ Test numeric state condition. """ state = hass.states.get(entity_id) - if state is not None and _in_range(state.state, above, below): - action() + return state is not None and _in_range(state.state, above, below) - return state_if + return if_numeric_state def _in_range(value, range_start, range_end): diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 7bd0542855c..bb936d36a1b 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -38,7 +38,7 @@ def trigger(hass, config, action): return True -def if_action(hass, config, action): +def if_action(hass, config): """ Wraps action method with state based condition. """ entity_id = config.get(CONF_ENTITY_ID) state = config.get(CONF_STATE) @@ -47,11 +47,12 @@ def if_action(hass, config, action): logging.getLogger(__name__).error( "Missing if-condition configuration key %s or %s", CONF_ENTITY_ID, CONF_STATE) - return action + return None - def state_if(): - """ Execute action if state matches. """ - if hass.states.is_state(entity_id, state): - action() + state = str(state) - return state_if + def if_state(): + """ Test if condition. """ + return hass.states.is_state(entity_id, state) + + return if_state diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index b5bfcd274ee..a7afa183ba0 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -36,7 +36,7 @@ def trigger(hass, config, action): return True -def if_action(hass, config, action): +def if_action(hass, config): """ Wraps action method with time based condition. """ before = config.get(CONF_BEFORE) after = config.get(CONF_AFTER) @@ -46,6 +46,7 @@ def if_action(hass, config, action): logging.getLogger(__name__).error( "Missing if-condition configuration key %s, %s or %s", CONF_BEFORE, CONF_AFTER, CONF_WEEKDAY) + return None def time_if(): """ Validate time based if-condition """ @@ -59,7 +60,7 @@ def if_action(hass, config, action): minute=int(before_m)) if now > before_point: - return + return False if after is not None: # Strip seconds if given @@ -68,15 +69,15 @@ def if_action(hass, config, action): after_point = now.replace(hour=int(after_h), minute=int(after_m)) if now < after_point: - return + return False if weekday is not None: now_weekday = WEEKDAYS[now.weekday()] if isinstance(weekday, str) and weekday != now_weekday or \ now_weekday not in weekday: - return + return False - action() + return True return time_if diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index e2477972ead..8553a4472be 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -162,7 +162,6 @@ class TestAutomationEvent(unittest.TestCase): ], 'action': { 'execute_service': 'test.automation', - 'service_entity_id': ['hello.world', 'hello.world2'] } } }) @@ -173,3 +172,90 @@ class TestAutomationEvent(unittest.TestCase): self.hass.states.set('test.entity', 'hello') self.hass.pool.block_till_done() self.assertEqual(2, len(self.calls)) + + def test_two_conditions_with_and(self): + entity_id = 'test.entity' + automation.setup(self.hass, { + automation.DOMAIN: { + 'trigger': [ + { + 'platform': 'event', + 'event_type': 'test_event', + }, + ], + 'condition': [ + { + 'platform': 'state', + 'entity_id': entity_id, + 'state': 100 + }, + { + 'platform': 'numeric_state', + 'entity_id': entity_id, + 'below': 150 + } + ], + 'action': { + 'execute_service': 'test.automation', + } + } + }) + + self.hass.states.set(entity_id, 100) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + + self.hass.states.set(entity_id, 101) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + + self.hass.states.set(entity_id, 151) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + + def test_two_conditions_with_or(self): + entity_id = 'test.entity' + automation.setup(self.hass, { + automation.DOMAIN: { + 'trigger': [ + { + 'platform': 'event', + 'event_type': 'test_event', + }, + ], + 'condition_type': 'OR', + 'condition': [ + { + 'platform': 'state', + 'entity_id': entity_id, + 'state': 200 + }, + { + 'platform': 'numeric_state', + 'entity_id': entity_id, + 'below': 150 + } + ], + 'action': { + 'execute_service': 'test.automation', + } + } + }) + + self.hass.states.set(entity_id, 200) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + + self.hass.states.set(entity_id, 100) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(2, len(self.calls)) + + self.hass.states.set(entity_id, 250) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(2, len(self.calls))