From 1f31dfe5d341d6cda59344110bcf61b2ffc0d684 Mon Sep 17 00:00:00 2001 From: Johann Kellerman Date: Mon, 9 Jan 2017 22:53:30 +0200 Subject: [PATCH] [recorder] Include & Exclude domain fix & unit tests (#5213) * Tests & domain fix * incl/excl combined --- homeassistant/components/recorder/__init__.py | 34 ++++--- tests/components/recorder/test_init.py | 89 +++++++++++++++++++ 2 files changed, 110 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 41a7991c32f..4f02fe2873d 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -16,9 +16,9 @@ from typing import Any, Union, Optional, List, Dict import voluptuous as vol -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.const import ( - ATTR_ENTITY_ID, ATTR_DOMAIN, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS, + ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS, CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) import homeassistant.helpers.config_validation as cv @@ -181,8 +181,8 @@ class Recorder(threading.Thread): self.engine = None # type: Any self._run = None # type: Any - self.include = include.get(CONF_ENTITIES, []) + \ - include.get(CONF_DOMAINS, []) + self.include_e = include.get(CONF_ENTITIES, []) + self.include_d = include.get(CONF_DOMAINS, []) self.exclude = exclude.get(CONF_ENTITIES, []) + \ exclude.get(CONF_DOMAINS, []) @@ -230,17 +230,25 @@ class Recorder(threading.Thread): self.queue.task_done() continue - entity_id = event.data.get(ATTR_ENTITY_ID) - domain = event.data.get(ATTR_DOMAIN) + if ATTR_ENTITY_ID in event.data: + entity_id = event.data[ATTR_ENTITY_ID] + domain = split_entity_id(entity_id)[0] - if entity_id in self.exclude or domain in self.exclude: - self.queue.task_done() - continue + # Exclude entities OR + # Exclude domains, but include specific entities + if (entity_id in self.exclude) or \ + (domain in self.exclude and + entity_id not in self.include_e): + self.queue.task_done() + continue - if (self.include and entity_id not in self.include and - domain not in self.include): - self.queue.task_done() - continue + # Included domains only (excluded entities above) OR + # Include entities only, but only if no excludes + if (self.include_d and domain not in self.include_d) or \ + (self.include_e and entity_id not in self.include_e + and not self.exclude): + self.queue.task_done() + continue dbevent = Events.from_event(event) self._commit(dbevent) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 03e782841a2..e8a73e347ff 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -4,6 +4,7 @@ import json from datetime import datetime, timedelta import unittest +import pytest from homeassistant.core import callback from homeassistant.const import MATCH_ALL from homeassistant.components import recorder @@ -188,3 +189,91 @@ class TestRecorder(unittest.TestCase): # we should have all of our states still self.assertEqual(states.count(), 5) self.assertEqual(events.count(), 5) + + +@pytest.fixture +def hass_recorder(): + """HASS fixture with in-memory recorder.""" + hass = get_test_home_assistant() + + def setup_recorder(config): + """Setup with params.""" + db_uri = 'sqlite://' # In memory DB + conf = {recorder.CONF_DB_URL: db_uri} + conf.update(config) + assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: conf}) + hass.start() + hass.block_till_done() + recorder._verify_instance() + recorder._INSTANCE.block_till_done() + return hass + + yield setup_recorder + hass.stop() + + +def _add_entities(hass, entity_ids): + """Add entities.""" + attributes = {'test_attr': 5, 'test_attr_10': 'nice'} + for idx, entity_id in enumerate(entity_ids): + hass.states.set(entity_id, 'state{}'.format(idx), attributes) + hass.block_till_done() + recorder._INSTANCE.block_till_done() + db_states = recorder.query('States') + states = recorder.execute(db_states) + assert db_states[0].event_id is not None + return states + + +# pylint: disable=redefined-outer-name,invalid-name +def test_saving_state_include_domains(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({'include': {'domains': 'test2'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder']) + assert len(states) == 1 + assert hass.states.get('test2.recorder') == states[0] + + +def test_saving_state_incl_entities(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({'include': {'entities': 'test2.recorder'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder']) + assert len(states) == 1 + assert hass.states.get('test2.recorder') == states[0] + + +def test_saving_state_exclude_domains(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({'exclude': {'domains': 'test'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder']) + assert len(states) == 1 + assert hass.states.get('test2.recorder') == states[0] + + +def test_saving_state_exclude_entities(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({'exclude': {'entities': 'test.recorder'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder']) + assert len(states) == 1 + assert hass.states.get('test2.recorder') == states[0] + + +def test_saving_state_exclude_domain_include_entity(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({ + 'include': {'entities': 'test.recorder'}, + 'exclude': {'domains': 'test'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder']) + assert len(states) == 2 + + +def test_saving_state_include_domain_exclude_entity(hass_recorder): + """Test saving and restoring a state.""" + hass = hass_recorder({ + 'exclude': {'entities': 'test.recorder'}, + 'include': {'domains': 'test'}}) + states = _add_entities(hass, ['test.recorder', 'test2.recorder', + 'test.ok']) + assert len(states) == 1 + assert hass.states.get('test.ok') == states[0] + assert hass.states.get('test.ok').state == 'state2'