From 3374169c741f8890a24bd9c4996ebd143c4e3b34 Mon Sep 17 00:00:00 2001 From: Jan Losinski Date: Wed, 26 Apr 2017 21:14:52 +0200 Subject: [PATCH] Allow InfluxDB to blacklist domains (#7264) * Allow InfluxDB to blacklist domains This adds an option to InfluxDB to blacklist whole domains. This is useful for domains like automation or script, where no statistic data is needed. Signed-off-by: Jan Losinski * Add unittest for InfluxDB domain blacklist Signed-off-by: Jan Losinski * Use common include/exclude config for InfluxDB. Its now the same syntax as it is for recorder. Signed-off-by: Jan Losinski * Add unittests for InfluxDB include whitelist. There where no tests for that feature before. Signed-off-by: Jan Losinski --- homeassistant/components/influxdb.py | 33 +++++-- tests/components/test_influxdb.py | 132 ++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/influxdb.py b/homeassistant/components/influxdb.py index 430c5cbe4c6..58479b6c14e 100644 --- a/homeassistant/components/influxdb.py +++ b/homeassistant/components/influxdb.py @@ -10,8 +10,8 @@ import voluptuous as vol from homeassistant.const import ( EVENT_STATE_CHANGED, STATE_UNAVAILABLE, STATE_UNKNOWN, CONF_HOST, - CONF_PORT, CONF_SSL, CONF_VERIFY_SSL, CONF_USERNAME, CONF_BLACKLIST, - CONF_PASSWORD, CONF_WHITELIST) + CONF_PORT, CONF_SSL, CONF_VERIFY_SSL, CONF_USERNAME, CONF_PASSWORD, + CONF_EXCLUDE, CONF_INCLUDE, CONF_DOMAINS, CONF_ENTITIES) from homeassistant.helpers import state as state_helper import homeassistant.helpers.config_validation as cv @@ -23,6 +23,7 @@ CONF_DB_NAME = 'database' CONF_TAGS = 'tags' CONF_DEFAULT_MEASUREMENT = 'default_measurement' CONF_OVERRIDE_MEASUREMENT = 'override_measurement' +CONF_BLACKLIST_DOMAINS = "blacklist_domains" DEFAULT_DATABASE = 'home_assistant' DEFAULT_VERIFY_SSL = True @@ -34,8 +35,16 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_HOST): cv.string, vol.Inclusive(CONF_USERNAME, 'authentication'): cv.string, vol.Inclusive(CONF_PASSWORD, 'authentication'): cv.string, - vol.Optional(CONF_BLACKLIST, default=[]): - vol.All(cv.ensure_list, [cv.entity_id]), + vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }), + vol.Optional(CONF_INCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }), vol.Optional(CONF_DB_NAME, default=DEFAULT_DATABASE): cv.string, vol.Optional(CONF_PORT): cv.port, vol.Optional(CONF_SSL): cv.boolean, @@ -43,8 +52,6 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, vol.Optional(CONF_TAGS, default={}): vol.Schema({cv.string: cv.string}), - vol.Optional(CONF_WHITELIST, default=[]): - vol.All(cv.ensure_list, [cv.entity_id]), vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, }), }, extra=vol.ALLOW_EXTRA) @@ -77,8 +84,12 @@ def setup(hass, config): if CONF_SSL in conf: kwargs['ssl'] = conf[CONF_SSL] - blacklist = conf.get(CONF_BLACKLIST) - whitelist = conf.get(CONF_WHITELIST) + include = conf.get(CONF_INCLUDE, {}) + exclude = conf.get(CONF_EXCLUDE, {}) + whitelist_e = set(include.get(CONF_ENTITIES, [])) + whitelist_d = set(include.get(CONF_DOMAINS, [])) + blacklist_e = set(exclude.get(CONF_ENTITIES, [])) + blacklist_d = set(exclude.get(CONF_DOMAINS, [])) tags = conf.get(CONF_TAGS) default_measurement = conf.get(CONF_DEFAULT_MEASUREMENT) override_measurement = conf.get(CONF_OVERRIDE_MEASUREMENT) @@ -97,11 +108,13 @@ def setup(hass, config): state = event.data.get('new_state') if state is None or state.state in ( STATE_UNKNOWN, '', STATE_UNAVAILABLE) or \ - state.entity_id in blacklist: + state.entity_id in blacklist_e or \ + state.domain in blacklist_d: return try: - if whitelist and state.entity_id not in whitelist: + if (whitelist_e and state.entity_id not in whitelist_e) or \ + (whitelist_d and state.domain not in whitelist_d): return _state = float(state_helper.state_as_number(state)) diff --git a/tests/components/test_influxdb.py b/tests/components/test_influxdb.py index c1ad2672365..ab1f8916c37 100644 --- a/tests/components/test_influxdb.py +++ b/tests/components/test_influxdb.py @@ -96,7 +96,10 @@ class TestInfluxDB(unittest.TestCase): 'host': 'host', 'username': 'user', 'password': 'pass', - 'blacklist': ['fake.blacklisted'] + 'exclude': { + 'entities': ['fake.blacklisted'], + 'domains': ['another_fake'] + } } } assert setup_component(self.hass, influxdb.DOMAIN, config) @@ -273,6 +276,129 @@ class TestInfluxDB(unittest.TestCase): self.assertFalse(mock_client.return_value.write_points.called) mock_client.return_value.write_points.reset_mock() + def test_event_listener_blacklist_domain(self, mock_client): + """Test the event listener against a blacklist.""" + self._setup() + + for domain in ('ok', 'another_fake'): + state = mock.MagicMock( + state=1, domain=domain, + entity_id='{}.something'.format(domain), + object_id='something', attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': '{}.something'.format(domain), + 'tags': { + 'domain': domain, + 'entity_id': 'something', + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if domain == 'ok': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + + def test_event_listener_whitelist(self, mock_client): + """Test the event listener against a whitelist.""" + config = { + 'influxdb': { + 'host': 'host', + 'username': 'user', + 'password': 'pass', + 'include': { + 'entities': ['fake.included'], + } + } + } + assert setup_component(self.hass, influxdb.DOMAIN, config) + self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] + + for entity_id in ('included', 'default'): + state = mock.MagicMock( + state=1, domain='fake', entity_id='fake.{}'.format(entity_id), + object_id=entity_id, attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': 'fake.{}'.format(entity_id), + 'tags': { + 'domain': 'fake', + 'entity_id': entity_id, + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if entity_id == 'included': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + + def test_event_listener_whitelist_domain(self, mock_client): + """Test the event listener against a whitelist.""" + config = { + 'influxdb': { + 'host': 'host', + 'username': 'user', + 'password': 'pass', + 'include': { + 'domains': ['fake'], + } + } + } + assert setup_component(self.hass, influxdb.DOMAIN, config) + self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] + + for domain in ('fake', 'another_fake'): + state = mock.MagicMock( + state=1, domain=domain, + entity_id='{}.something'.format(domain), + object_id='something', attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': '{}.something'.format(domain), + 'tags': { + 'domain': domain, + 'entity_id': 'something', + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if domain == 'fake': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + def test_event_listener_invalid_type(self, mock_client): """Test the event listener when an attirbute has an invalid type.""" self._setup() @@ -343,7 +469,9 @@ class TestInfluxDB(unittest.TestCase): 'username': 'user', 'password': 'pass', 'default_measurement': 'state', - 'blacklist': ['fake.blacklisted'] + 'exclude': { + 'entities': ['fake.blacklisted'] + } } } assert setup_component(self.hass, influxdb.DOMAIN, config)