Add support for glob matching to entity filters (#36913)

* Added GLOB capability to entityfilter and every place that uses it. All existing tests are passing

* added tests for components affected by glob change

* fixed flake8 error

* mocking the correct listener

* mocking correct bus method in azure test

* tests passing in 3.7 and 3.8

* fixed formatting issue from rebase/conflict

* Checking against glob patterns in more performant way

* perf improvments and reverted unnecessarily adjusted tests

* added new benchmark test around filters

* no longer using get with default in entityfilter

* changed filter name and removed logbook from filter benchmark

* simplified benchmark tests from feedback

* fixed apache tests and returned include exclude schemas to normal

* fixed azure event hub tests to properly go through component logic

* fixed azure test and clean up for other tests

* renaming test files to match standard

* merged mqtt statestream test changes with base

* removed dependency on recorder filter schema from history

* fixed recorder tests after merge and a bunch of lint errors
This commit is contained in:
mdegat01 2020-06-23 21:02:29 -04:00 committed by GitHub
parent a1ac1fb091
commit 6c7355785a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1832 additions and 278 deletions

View file

@ -46,7 +46,6 @@ omit =
homeassistant/components/android_ip_webcam/* homeassistant/components/android_ip_webcam/*
homeassistant/components/anel_pwrctrl/switch.py homeassistant/components/anel_pwrctrl/switch.py
homeassistant/components/anthemav/media_player.py homeassistant/components/anthemav/media_player.py
homeassistant/components/apache_kafka/*
homeassistant/components/apcupsd/* homeassistant/components/apcupsd/*
homeassistant/components/apple_tv/* homeassistant/components/apple_tv/*
homeassistant/components/aqualogic/* homeassistant/components/aqualogic/*
@ -70,7 +69,6 @@ omit =
homeassistant/components/avion/light.py homeassistant/components/avion/light.py
homeassistant/components/avri/const.py homeassistant/components/avri/const.py
homeassistant/components/avri/sensor.py homeassistant/components/avri/sensor.py
homeassistant/components/azure_event_hub/*
homeassistant/components/azure_service_bus/* homeassistant/components/azure_service_bus/*
homeassistant/components/baidu/tts.py homeassistant/components/baidu/tts.py
homeassistant/components/beewi_smartclim/sensor.py homeassistant/components/beewi_smartclim/sensor.py

View file

@ -40,14 +40,26 @@ CONF_ORDER = "use_include_order"
STATE_KEY = "state" STATE_KEY = "state"
LAST_CHANGED_KEY = "last_changed" LAST_CHANGED_KEY = "last_changed"
CONFIG_SCHEMA = vol.Schema( # Not reusing from entityfilter because history does not support glob filtering
_FILTER_SCHEMA_INNER = vol.Schema(
{ {
DOMAIN: recorder.FILTER_SCHEMA.extend( vol.Optional(CONF_DOMAINS, default=[]): vol.All(cv.ensure_list, [cv.string]),
{vol.Optional(CONF_ORDER, default=False): cv.boolean} vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
) }
},
extra=vol.ALLOW_EXTRA,
) )
_FILTER_SCHEMA = vol.Schema(
{
vol.Optional(
CONF_INCLUDE, default=_FILTER_SCHEMA_INNER({})
): _FILTER_SCHEMA_INNER,
vol.Optional(
CONF_EXCLUDE, default=_FILTER_SCHEMA_INNER({})
): _FILTER_SCHEMA_INNER,
vol.Optional(CONF_ORDER, default=False): cv.boolean,
}
)
CONFIG_SCHEMA = vol.Schema({DOMAIN: _FILTER_SCHEMA}, extra=vol.ALLOW_EXTRA)
SIGNIFICANT_DOMAINS = ( SIGNIFICANT_DOMAINS = (
"climate", "climate",
@ -143,7 +155,6 @@ def _get_significant_states(
def state_changes_during_period(hass, start_time, end_time=None, entity_id=None): def state_changes_during_period(hass, start_time, end_time=None, entity_id=None):
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
query = session.query(*QUERY_STATES).filter( query = session.query(*QUERY_STATES).filter(
(States.last_changed == States.last_updated) (States.last_changed == States.last_updated)
@ -165,7 +176,6 @@ def state_changes_during_period(hass, start_time, end_time=None, entity_id=None)
def get_last_state_changes(hass, number_of_states, entity_id): def get_last_state_changes(hass, number_of_states, entity_id):
"""Return the last number_of_states.""" """Return the last number_of_states."""
start_time = dt_util.utcnow() start_time = dt_util.utcnow()
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
@ -196,7 +206,6 @@ def get_last_state_changes(hass, number_of_states, entity_id):
def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None): def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None):
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if run is None: if run is None:
run = recorder.run_information_from_instance(hass, utc_point_in_time) run = recorder.run_information_from_instance(hass, utc_point_in_time)
@ -542,7 +551,6 @@ class Filters:
* if include and exclude is defined - select the entities specified in * if include and exclude is defined - select the entities specified in
the include and filter out the ones from the exclude list. the include and filter out the ones from the exclude list.
""" """
# specific entities requested - do not in/exclude anything # specific entities requested - do not in/exclude anything
if entity_ids is not None: if entity_ids is not None:
return query.filter(States.entity_id.in_(entity_ids)) return query.filter(States.entity_id.in_(entity_ids))

View file

@ -33,14 +33,7 @@ from homeassistant.core import CoreState, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, Unauthorized from homeassistant.exceptions import ConfigEntryNotReady, Unauthorized
from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers import device_registry, entity_registry
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import ( from homeassistant.helpers.entityfilter import BASE_FILTER_SCHEMA, FILTER_SCHEMA
BASE_FILTER_SCHEMA,
CONF_EXCLUDE_DOMAINS,
CONF_EXCLUDE_ENTITIES,
CONF_INCLUDE_DOMAINS,
CONF_INCLUDE_ENTITIES,
convert_filter,
)
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.util import get_local_ip from homeassistant.util import get_local_ip
@ -144,7 +137,6 @@ RESET_ACCESSORY_SERVICE_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: dict): async def async_setup(hass: HomeAssistant, config: dict):
"""Set up the HomeKit from yaml.""" """Set up the HomeKit from yaml."""
hass.data.setdefault(DOMAIN, {}) hass.data.setdefault(DOMAIN, {})
_async_register_events_and_services(hass) _async_register_events_and_services(hass)
@ -221,17 +213,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
entity_config = options.get(CONF_ENTITY_CONFIG, {}).copy() entity_config = options.get(CONF_ENTITY_CONFIG, {}).copy()
auto_start = options.get(CONF_AUTO_START, DEFAULT_AUTO_START) auto_start = options.get(CONF_AUTO_START, DEFAULT_AUTO_START)
safe_mode = options.get(CONF_SAFE_MODE, DEFAULT_SAFE_MODE) safe_mode = options.get(CONF_SAFE_MODE, DEFAULT_SAFE_MODE)
entity_filter = convert_filter( entity_filter = FILTER_SCHEMA(options.get(CONF_FILTER, {}))
options.get(
CONF_FILTER,
{
CONF_INCLUDE_DOMAINS: [],
CONF_EXCLUDE_DOMAINS: [],
CONF_INCLUDE_ENTITIES: [],
CONF_EXCLUDE_ENTITIES: [],
},
)
)
homekit = HomeKit( homekit = HomeKit(
hass, hass,
@ -272,7 +254,6 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry):
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Unload a config entry.""" """Unload a config entry."""
dismiss_setup_message(hass, entry.entry_id) dismiss_setup_message(hass, entry.entry_id)
hass.data[DOMAIN][entry.entry_id][UNDO_UPDATE_LISTENER]() hass.data[DOMAIN][entry.entry_id][UNDO_UPDATE_LISTENER]()
@ -319,7 +300,6 @@ def _async_import_options_from_data_if_missing(hass: HomeAssistant, entry: Confi
@callback @callback
def _async_register_events_and_services(hass: HomeAssistant): def _async_register_events_and_services(hass: HomeAssistant):
"""Register events and services for HomeKit.""" """Register events and services for HomeKit."""
hass.http.register_view(HomeKitPairingQRView) hass.http.register_view(HomeKitPairingQRView)
def handle_homekit_reset_accessory(service): def handle_homekit_reset_accessory(service):
@ -504,7 +484,6 @@ class HomeKit:
async def async_start(self, *args): async def async_start(self, *args):
"""Start the accessory driver.""" """Start the accessory driver."""
if self.status != STATUS_READY: if self.status != STATUS_READY:
return return
self.status = STATUS_WAIT self.status = STATUS_WAIT

View file

@ -42,7 +42,11 @@ from homeassistant.const import (
) )
from homeassistant.core import DOMAIN as HA_DOMAIN, callback, split_entity_id from homeassistant.core import DOMAIN as HA_DOMAIN, callback, split_entity_id
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import generate_filter from homeassistant.helpers.entityfilter import (
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
convert_include_exclude_filter,
generate_filter,
)
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -59,31 +63,8 @@ DOMAIN = "logbook"
GROUP_BY_MINUTES = 15 GROUP_BY_MINUTES = 15
EMPTY_JSON_OBJECT = "{}" EMPTY_JSON_OBJECT = "{}"
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {DOMAIN: INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA}, extra=vol.ALLOW_EXTRA
DOMAIN: vol.Schema(
{
CONF_EXCLUDE: vol.Schema(
{
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
}
),
CONF_INCLUDE: vol.Schema(
{
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
}
),
}
)
},
extra=vol.ALLOW_EXTRA,
) )
HOMEASSISTANT_EVENTS = [ HOMEASSISTANT_EVENTS = [
@ -129,7 +110,6 @@ def async_describe_event(hass, domain, event_name, describe_callback):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Logbook setup.""" """Logbook setup."""
hass.data.setdefault(DOMAIN, {}) hass.data.setdefault(DOMAIN, {})
@callback @callback
@ -360,26 +340,6 @@ def _get_related_entity_ids(session, entity_filter):
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
def _generate_filter_from_config(config):
excluded_entities = []
excluded_domains = []
included_entities = []
included_domains = []
exclude = config.get(CONF_EXCLUDE)
if exclude:
excluded_entities = exclude.get(CONF_ENTITIES, [])
excluded_domains = exclude.get(CONF_DOMAINS, [])
include = config.get(CONF_INCLUDE)
if include:
included_entities = include.get(CONF_ENTITIES, [])
included_domains = include.get(CONF_DOMAINS, [])
return generate_filter(
included_domains, included_entities, excluded_domains, excluded_entities
)
def _all_entities_filter(_): def _all_entities_filter(_):
"""Filter that accepts all entities.""" """Filter that accepts all entities."""
return True return True
@ -387,7 +347,6 @@ def _all_entities_filter(_):
def _get_events(hass, config, start_day, end_day, entity_id=None): def _get_events(hass, config, start_day, end_day, entity_id=None):
"""Get events for a period of time.""" """Get events for a period of time."""
entity_attr_cache = EntityAttributeCache(hass) entity_attr_cache = EntityAttributeCache(hass)
def yield_events(query): def yield_events(query):
@ -402,7 +361,7 @@ def _get_events(hass, config, start_day, end_day, entity_id=None):
entity_ids = [entity_id.lower()] entity_ids = [entity_id.lower()]
entities_filter = generate_filter([], entity_ids, [], []) entities_filter = generate_filter([], entity_ids, [], [])
elif config.get(CONF_EXCLUDE) or config.get(CONF_INCLUDE): elif config.get(CONF_EXCLUDE) or config.get(CONF_INCLUDE):
entities_filter = _generate_filter_from_config(config) entities_filter = convert_include_exclude_filter(config)
entity_ids = _get_related_entity_ids(session, entities_filter) entity_ids = _get_related_entity_ids(session, entities_filter)
else: else:
entities_filter = _all_entities_filter entities_filter = _all_entities_filter
@ -642,7 +601,6 @@ class LazyEventPartialState:
@property @property
def data(self): def data(self):
"""Event data.""" """Event data."""
if not self._event_data: if not self._event_data:
if self._row.event_data == EMPTY_JSON_OBJECT: if self._row.event_data == EMPTY_JSON_OBJECT:
self._event_data = {} self._event_data = {}
@ -679,7 +637,6 @@ class LazyEventPartialState:
@property @property
def has_old_and_new_state(self): def has_old_and_new_state(self):
"""Check the json data to see if new_state and old_state is present without decoding.""" """Check the json data to see if new_state and old_state is present without decoding."""
# Delete this check once all states are saved in the v8 schema # Delete this check once all states are saved in the v8 schema
# format or later (they have the old_state_id column). # format or later (they have the old_state_id column).

View file

@ -4,16 +4,13 @@ import json
import voluptuous as vol import voluptuous as vol
from homeassistant.components.mqtt import valid_publish_topic from homeassistant.components.mqtt import valid_publish_topic
from homeassistant.const import ( from homeassistant.const import MATCH_ALL
CONF_DOMAINS,
CONF_ENTITIES,
CONF_EXCLUDE,
CONF_INCLUDE,
MATCH_ALL,
)
from homeassistant.core import callback from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import generate_filter from homeassistant.helpers.entityfilter import (
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
convert_include_exclude_filter,
)
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
@ -25,29 +22,13 @@ DOMAIN = "mqtt_statestream"
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
{ {
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema(
{
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
}
),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema(
{
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
}
),
vol.Required(CONF_BASE_TOPIC): valid_publish_topic, vol.Required(CONF_BASE_TOPIC): valid_publish_topic,
vol.Optional(CONF_PUBLISH_ATTRIBUTES, default=False): cv.boolean, vol.Optional(CONF_PUBLISH_ATTRIBUTES, default=False): cv.boolean,
vol.Optional(CONF_PUBLISH_TIMESTAMPS, default=False): cv.boolean, vol.Optional(CONF_PUBLISH_TIMESTAMPS, default=False): cv.boolean,
} }
) ),
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
@ -55,18 +36,11 @@ CONFIG_SCHEMA = vol.Schema(
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the MQTT state feed.""" """Set up the MQTT state feed."""
conf = config.get(DOMAIN, {}) conf = config.get(DOMAIN)
publish_filter = convert_include_exclude_filter(conf)
base_topic = conf.get(CONF_BASE_TOPIC) base_topic = conf.get(CONF_BASE_TOPIC)
pub_include = conf.get(CONF_INCLUDE, {})
pub_exclude = conf.get(CONF_EXCLUDE, {})
publish_attributes = conf.get(CONF_PUBLISH_ATTRIBUTES) publish_attributes = conf.get(CONF_PUBLISH_ATTRIBUTES)
publish_timestamps = conf.get(CONF_PUBLISH_TIMESTAMPS) publish_timestamps = conf.get(CONF_PUBLISH_TIMESTAMPS)
publish_filter = generate_filter(
pub_include.get(CONF_DOMAINS, []),
pub_include.get(CONF_ENTITIES, []),
pub_exclude.get(CONF_DOMAINS, []),
pub_exclude.get(CONF_ENTITIES, []),
)
if not base_topic.endswith("/"): if not base_topic.endswith("/"):
base_topic = f"{base_topic}/" base_topic = f"{base_topic}/"

View file

@ -7,7 +7,7 @@ import logging
import queue import queue
import threading import threading
import time import time
from typing import Any, Dict, Optional from typing import Any, Callable, List, Optional
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, select from sqlalchemy import create_engine, event as sqlalchemy_event, exc, select
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
@ -17,10 +17,7 @@ import voluptuous as vol
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_DOMAINS,
CONF_ENTITIES,
CONF_EXCLUDE, CONF_EXCLUDE,
CONF_INCLUDE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -29,7 +26,11 @@ from homeassistant.const import (
) )
from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.core import CoreState, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import generate_filter from homeassistant.helpers.entityfilter import (
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER,
convert_include_exclude_filter,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -69,22 +70,12 @@ CONF_PURGE_INTERVAL = "purge_interval"
CONF_EVENT_TYPES = "event_types" CONF_EVENT_TYPES = "event_types"
CONF_COMMIT_INTERVAL = "commit_interval" CONF_COMMIT_INTERVAL = "commit_interval"
FILTER_SCHEMA = vol.Schema( EXCLUDE_SCHEMA = INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER.extend(
{ {vol.Optional(CONF_EVENT_TYPES): vol.All(cv.ensure_list, [cv.string])}
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema( )
{
vol.Optional(CONF_DOMAINS): vol.All(cv.ensure_list, [cv.string]), FILTER_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
vol.Optional(CONF_ENTITIES): cv.entity_ids, {vol.Optional(CONF_EXCLUDE, default=EXCLUDE_SCHEMA({})): EXCLUDE_SCHEMA}
vol.Optional(CONF_EVENT_TYPES): vol.All(cv.ensure_list, [cv.string]),
}
),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema(
{
vol.Optional(CONF_DOMAINS): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_ENTITIES): cv.entity_ids,
}
),
}
) )
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
@ -161,6 +152,7 @@ def run_information_with_session(session, point_in_time: Optional[datetime] = No
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the recorder.""" """Set up the recorder."""
conf = config[DOMAIN] conf = config[DOMAIN]
entity_filter = convert_include_exclude_filter(conf)
auto_purge = conf[CONF_AUTO_PURGE] auto_purge = conf[CONF_AUTO_PURGE]
keep_days = conf[CONF_PURGE_KEEP_DAYS] keep_days = conf[CONF_PURGE_KEEP_DAYS]
commit_interval = conf[CONF_COMMIT_INTERVAL] commit_interval = conf[CONF_COMMIT_INTERVAL]
@ -170,9 +162,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
db_url = conf.get(CONF_DB_URL) db_url = conf.get(CONF_DB_URL)
if not db_url: if not db_url:
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
exclude = conf[CONF_EXCLUDE]
include = conf.get(CONF_INCLUDE, {}) exclude_t = exclude.get(CONF_EVENT_TYPES, [])
exclude = conf.get(CONF_EXCLUDE, {})
instance = hass.data[DATA_INSTANCE] = Recorder( instance = hass.data[DATA_INSTANCE] = Recorder(
hass=hass, hass=hass,
auto_purge=auto_purge, auto_purge=auto_purge,
@ -181,8 +172,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
uri=db_url, uri=db_url,
db_max_retries=db_max_retries, db_max_retries=db_max_retries,
db_retry_wait=db_retry_wait, db_retry_wait=db_retry_wait,
include=include, entity_filter=entity_filter,
exclude=exclude, exclude_t=exclude_t,
) )
instance.async_initialize() instance.async_initialize()
instance.start() instance.start()
@ -213,8 +204,8 @@ class Recorder(threading.Thread):
uri: str, uri: str,
db_max_retries: int, db_max_retries: int,
db_retry_wait: int, db_retry_wait: int,
include: Dict, entity_filter: Callable[[str], bool],
exclude: Dict, exclude_t: List[str],
) -> None: ) -> None:
"""Initialize the recorder.""" """Initialize the recorder."""
threading.Thread.__init__(self, name="Recorder") threading.Thread.__init__(self, name="Recorder")
@ -232,13 +223,8 @@ class Recorder(threading.Thread):
self.engine: Any = None self.engine: Any = None
self.run_info: Any = None self.run_info: Any = None
self.entity_filter = generate_filter( self.entity_filter = entity_filter
include.get(CONF_DOMAINS, []), self.exclude_t = exclude_t
include.get(CONF_ENTITIES, []),
exclude.get(CONF_DOMAINS, []),
exclude.get(CONF_ENTITIES, []),
)
self.exclude_t = exclude.get(CONF_EVENT_TYPES, [])
self._timechanges_seen = 0 self._timechanges_seen = 0
self._keepalive_count = 0 self._keepalive_count = 0
@ -513,7 +499,6 @@ class Recorder(threading.Thread):
def setup_recorder_connection(dbapi_connection, connection_record): def setup_recorder_connection(dbapi_connection, connection_record):
"""Dbapi specific connection settings.""" """Dbapi specific connection settings."""
if self._completed_database_setup: if self._completed_database_setup:
return return

View file

@ -1,16 +1,23 @@
"""Helper class to implement include/exclude of entities and domains.""" """Helper class to implement include/exclude of entities and domains."""
from typing import Callable, Dict, List import fnmatch
import re
from typing import Callable, Dict, List, Pattern
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.core import split_entity_id from homeassistant.core import split_entity_id
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
CONF_INCLUDE_DOMAINS = "include_domains" CONF_INCLUDE_DOMAINS = "include_domains"
CONF_INCLUDE_ENTITY_GLOBS = "include_entity_globs"
CONF_INCLUDE_ENTITIES = "include_entities" CONF_INCLUDE_ENTITIES = "include_entities"
CONF_EXCLUDE_DOMAINS = "exclude_domains" CONF_EXCLUDE_DOMAINS = "exclude_domains"
CONF_EXCLUDE_ENTITY_GLOBS = "exclude_entity_globs"
CONF_EXCLUDE_ENTITIES = "exclude_entities" CONF_EXCLUDE_ENTITIES = "exclude_entities"
CONF_ENTITY_GLOBS = "entity_globs"
def convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]: def convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]:
"""Convert the filter schema into a filter.""" """Convert the filter schema into a filter."""
@ -19,6 +26,8 @@ def convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]:
config[CONF_INCLUDE_ENTITIES], config[CONF_INCLUDE_ENTITIES],
config[CONF_EXCLUDE_DOMAINS], config[CONF_EXCLUDE_DOMAINS],
config[CONF_EXCLUDE_ENTITIES], config[CONF_EXCLUDE_ENTITIES],
config[CONF_INCLUDE_ENTITY_GLOBS],
config[CONF_EXCLUDE_ENTITY_GLOBS],
) )
setattr(filt, "config", config) setattr(filt, "config", config)
setattr(filt, "empty_filter", sum(len(val) for val in config.values()) == 0) setattr(filt, "empty_filter", sum(len(val) for val in config.values()) == 0)
@ -30,10 +39,16 @@ BASE_FILTER_SCHEMA = vol.Schema(
vol.Optional(CONF_EXCLUDE_DOMAINS, default=[]): vol.All( vol.Optional(CONF_EXCLUDE_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string] cv.ensure_list, [cv.string]
), ),
vol.Optional(CONF_EXCLUDE_ENTITY_GLOBS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(CONF_EXCLUDE_ENTITIES, default=[]): cv.entity_ids, vol.Optional(CONF_EXCLUDE_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_INCLUDE_DOMAINS, default=[]): vol.All( vol.Optional(CONF_INCLUDE_DOMAINS, default=[]): vol.All(
cv.ensure_list, [cv.string] cv.ensure_list, [cv.string]
), ),
vol.Optional(CONF_INCLUDE_ENTITY_GLOBS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(CONF_INCLUDE_ENTITIES, default=[]): cv.entity_ids, vol.Optional(CONF_INCLUDE_ENTITIES, default=[]): cv.entity_ids,
} }
) )
@ -41,20 +56,104 @@ BASE_FILTER_SCHEMA = vol.Schema(
FILTER_SCHEMA = vol.All(BASE_FILTER_SCHEMA, convert_filter) FILTER_SCHEMA = vol.All(BASE_FILTER_SCHEMA, convert_filter)
def convert_include_exclude_filter(
config: Dict[str, Dict[str, List[str]]]
) -> Callable[[str], bool]:
"""Convert the include exclude filter schema into a filter."""
include = config[CONF_INCLUDE]
exclude = config[CONF_EXCLUDE]
filt = convert_filter(
{
CONF_INCLUDE_DOMAINS: include[CONF_DOMAINS],
CONF_INCLUDE_ENTITY_GLOBS: include[CONF_ENTITY_GLOBS],
CONF_INCLUDE_ENTITIES: include[CONF_ENTITIES],
CONF_EXCLUDE_DOMAINS: exclude[CONF_DOMAINS],
CONF_EXCLUDE_ENTITY_GLOBS: exclude[CONF_ENTITY_GLOBS],
CONF_EXCLUDE_ENTITIES: exclude[CONF_ENTITIES],
}
)
setattr(filt, "config", config)
return filt
INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER = vol.Schema(
{
vol.Optional(CONF_DOMAINS, default=[]): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_ENTITY_GLOBS, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
}
)
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA = vol.Schema(
{
vol.Optional(
CONF_INCLUDE, default=INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER({})
): INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER,
vol.Optional(
CONF_EXCLUDE, default=INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER({})
): INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER,
}
)
INCLUDE_EXCLUDE_FILTER_SCHEMA = vol.All(
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA, convert_include_exclude_filter
)
def _glob_to_re(glob: str) -> Pattern:
"""Translate and compile glob string into pattern."""
return re.compile(fnmatch.translate(glob))
def _test_against_patterns(patterns: List[Pattern], entity_id: str) -> bool:
"""Test entity against list of patterns, true if any match."""
for pattern in patterns:
if pattern.match(entity_id):
return True
return False
# It's safe since we don't modify it. And None causes typing warnings
# pylint: disable=dangerous-default-value
def generate_filter( def generate_filter(
include_domains: List[str], include_domains: List[str],
include_entities: List[str], include_entities: List[str],
exclude_domains: List[str], exclude_domains: List[str],
exclude_entities: List[str], exclude_entities: List[str],
include_entity_globs: List[str] = [],
exclude_entity_globs: List[str] = [],
) -> Callable[[str], bool]: ) -> Callable[[str], bool]:
"""Return a function that will filter entities based on the args.""" """Return a function that will filter entities based on the args."""
include_d = set(include_domains) include_d = set(include_domains)
include_e = set(include_entities) include_e = set(include_entities)
exclude_d = set(exclude_domains) exclude_d = set(exclude_domains)
exclude_e = set(exclude_entities) exclude_e = set(exclude_entities)
include_eg_set = set(include_entity_globs)
exclude_eg_set = set(exclude_entity_globs)
include_eg = list(map(_glob_to_re, include_eg_set))
exclude_eg = list(map(_glob_to_re, exclude_eg_set))
have_exclude = bool(exclude_e or exclude_d) have_exclude = bool(exclude_e or exclude_d or exclude_eg)
have_include = bool(include_e or include_d) have_include = bool(include_e or include_d or include_eg)
def entity_included(domain: str, entity_id: str) -> bool:
"""Return true if entity matches inclusion filters."""
return (
entity_id in include_e
or domain in include_d
or bool(include_eg and _test_against_patterns(include_eg, entity_id))
)
def entity_excluded(domain: str, entity_id: str) -> bool:
"""Return true if entity matches exclusion filters."""
return (
entity_id in exclude_e
or domain in exclude_d
or bool(exclude_eg and _test_against_patterns(exclude_eg, entity_id))
)
# Case 1 - no includes or excludes - pass all entities # Case 1 - no includes or excludes - pass all entities
if not have_include and not have_exclude: if not have_include and not have_exclude:
@ -66,7 +165,7 @@ def generate_filter(
def entity_filter_2(entity_id: str) -> bool: def entity_filter_2(entity_id: str) -> bool:
"""Return filter function for case 2.""" """Return filter function for case 2."""
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]
return entity_id in include_e or domain in include_d return entity_included(domain, entity_id)
return entity_filter_2 return entity_filter_2
@ -76,36 +175,50 @@ def generate_filter(
def entity_filter_3(entity_id: str) -> bool: def entity_filter_3(entity_id: str) -> bool:
"""Return filter function for case 3.""" """Return filter function for case 3."""
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]
return entity_id not in exclude_e and domain not in exclude_d return not entity_excluded(domain, entity_id)
return entity_filter_3 return entity_filter_3
# Case 4 - both includes and excludes specified # Case 4 - both includes and excludes specified
# Case 4a - include domain specified # Case 4a - include domain or glob specified
# - if domain is included, pass if entity not excluded # - if domain is included, pass if entity not excluded
# - if domain is not included, pass if entity is included # - if glob is included, pass if entity and domain not excluded
# note: if both include and exclude domains specified, # - if domain and glob are not included, pass if entity is included
# the exclude domains are ignored # note: if both include domain matches then exclude domains ignored.
if include_d: # If glob matches then exclude domains and glob checked
if include_d or include_eg:
def entity_filter_4a(entity_id: str) -> bool: def entity_filter_4a(entity_id: str) -> bool:
"""Return filter function for case 4a.""" """Return filter function for case 4a."""
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]
if domain in include_d: if domain in include_d:
return entity_id not in exclude_e return not (
entity_id in exclude_e
or bool(
exclude_eg and _test_against_patterns(exclude_eg, entity_id)
)
)
if _test_against_patterns(include_eg, entity_id):
return not entity_excluded(domain, entity_id)
return entity_id in include_e return entity_id in include_e
return entity_filter_4a return entity_filter_4a
# Case 4b - exclude domain specified # Case 4b - exclude domain or glob specified, include has no domain or glob
# - if domain is excluded, pass if entity is included # In this one case the traditional include logic is inverted. Even though an
# - if domain is not excluded, pass if entity not excluded # include is specified since its only a list of entity IDs its used only to
if exclude_d: # expose specific entities excluded by domain or glob. Any entities not
# excluded are then presumed included. Logic is as follows
# - if domain or glob is excluded, pass if entity is included
# - if domain is not excluded, pass if entity not excluded by ID
if exclude_d or exclude_eg:
def entity_filter_4b(entity_id: str) -> bool: def entity_filter_4b(entity_id: str) -> bool:
"""Return filter function for case 4b.""" """Return filter function for case 4b."""
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]
if domain in exclude_d: if domain in exclude_d or (
exclude_eg and _test_against_patterns(exclude_eg, entity_id)
):
return entity_id in include_e return entity_id in include_e
return entity_id not in exclude_e return entity_id not in exclude_e
@ -113,8 +226,4 @@ def generate_filter(
# Case 4c - neither include or exclude domain specified # Case 4c - neither include or exclude domain specified
# - Only pass if entity is included. Ignore entity excludes. # - Only pass if entity is included. Ignore entity excludes.
def entity_filter_4c(entity_id: str) -> bool: return lambda entity_id: entity_id in include_e
"""Return filter function for case 4c."""
return entity_id in include_e
return entity_filter_4c

View file

@ -12,6 +12,7 @@ from typing import Callable, Dict, TypeVar
from homeassistant import core from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.const import ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED from homeassistant.const import ATTR_NOW, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED
from homeassistant.helpers.entityfilter import convert_include_exclude_filter
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -178,10 +179,13 @@ async def _logbook_filtering(hass, last_changed, last_updated):
entity_attr_cache = logbook.EntityAttributeCache(hass) entity_attr_cache = logbook.EntityAttributeCache(hass)
entities_filter = convert_include_exclude_filter(
logbook.INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA({})
)
def yield_events(event): def yield_events(event):
# pylint: disable=protected-access
entities_filter = logbook._generate_filter_from_config({})
for _ in range(10 ** 5): for _ in range(10 ** 5):
# pylint: disable=protected-access
if logbook._keep_event(hass, event, entities_filter, entity_attr_cache): if logbook._keep_event(hass, event, entities_filter, entity_attr_cache):
yield event yield event
@ -192,6 +196,71 @@ async def _logbook_filtering(hass, last_changed, last_updated):
return timer() - start return timer() - start
@benchmark
async def filtering_entity_id(hass):
"""Run a 100k state changes through entity filter."""
config = {
"include": {
"domains": [
"automation",
"script",
"group",
"media_player",
"custom_component",
],
"entity_globs": [
"binary_sensor.*_contact",
"binary_sensor.*_occupancy",
"binary_sensor.*_detected",
"binary_sensor.*_active",
"input_*",
"device_tracker.*_phone",
"switch.*_light",
"binary_sensor.*_charging",
"binary_sensor.*_lock",
"binary_sensor.*_connected",
],
"entities": [
"test.entity_1",
"test.entity_2",
"binary_sensor.garage_door_open",
"test.entity_3",
"test.entity_4",
],
},
"exclude": {
"domains": ["input_number"],
"entity_globs": ["media_player.google_*", "group.all_*"],
"entities": [],
},
}
entity_ids = [
"automation.home_arrival",
"script.shut_off_house",
"binary_sensor.garage_door_open",
"binary_sensor.front_door_lock",
"binary_sensor.kitchen_motion_sensor_occupancy",
"switch.desk_lamp",
"light.dining_room",
"input_boolean.guest_staying_over",
"person.eleanor_fant",
"alert.issue_at_home",
"calendar.eleanor_fant_s_calendar",
"sun.sun",
]
entities_filter = convert_include_exclude_filter(config)
size = len(entity_ids)
start = timer()
for i in range(10 ** 5):
entities_filter(entity_ids[i % size])
return timer() - start
@benchmark @benchmark
async def valid_entity_id(hass): async def valid_entity_id(hass):
"""Run valid entity ID a million times.""" """Run valid entity ID a million times."""

View file

@ -100,6 +100,9 @@ aiohttp_cors==0.7.0
# homeassistant.components.hue # homeassistant.components.hue
aiohue==2.1.0 aiohue==2.1.0
# homeassistant.components.apache_kafka
aiokafka==0.5.1
# homeassistant.components.notion # homeassistant.components.notion
aionotion==1.1.0 aionotion==1.1.0
@ -155,6 +158,9 @@ avri-api==0.1.7
# homeassistant.components.axis # homeassistant.components.axis
axis==33 axis==33
# homeassistant.components.azure_event_hub
azure-eventhub==5.1.0
# homeassistant.components.homekit # homeassistant.components.homekit
base36==0.1.1 base36==0.1.1

View file

@ -0,0 +1 @@
"""Tests for apache_kafka component."""

View file

@ -0,0 +1,181 @@
"""The tests for the Apache Kafka component."""
from collections import namedtuple
import pytest
import homeassistant.components.apache_kafka as apache_kafka
from homeassistant.const import STATE_ON
from homeassistant.setup import async_setup_component
from tests.async_mock import patch
APACHE_KAFKA_PATH = "homeassistant.components.apache_kafka"
PRODUCER_PATH = f"{APACHE_KAFKA_PATH}.AIOKafkaProducer"
MIN_CONFIG = {
"ip_address": "localhost",
"port": 8080,
"topic": "topic",
}
FilterTest = namedtuple("FilterTest", "id should_pass")
MockKafkaClient = namedtuple("MockKafkaClient", "init start send_and_wait")
@pytest.fixture(name="mock_client")
def mock_client_fixture():
"""Mock the apache kafka client."""
with patch(f"{PRODUCER_PATH}.start") as start, patch(
f"{PRODUCER_PATH}.send_and_wait"
) as send_and_wait, patch(f"{PRODUCER_PATH}.__init__", return_value=None) as init:
yield MockKafkaClient(init, start, send_and_wait)
@pytest.fixture(autouse=True, scope="module")
def mock_client_stop():
"""Mock client stop at module scope for teardown."""
with patch(f"{PRODUCER_PATH}.stop") as stop:
yield stop
async def test_minimal_config(hass, mock_client):
"""Test the minimal config and defaults of component."""
config = {apache_kafka.DOMAIN: MIN_CONFIG}
assert await async_setup_component(hass, apache_kafka.DOMAIN, config)
await hass.async_block_till_done()
assert mock_client.start.called_once
async def test_full_config(hass, mock_client):
"""Test the full config of component."""
config = {
apache_kafka.DOMAIN: {
"filter": {
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
"exclude_domains": ["light"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
}
}
config[apache_kafka.DOMAIN].update(MIN_CONFIG)
assert await async_setup_component(hass, apache_kafka.DOMAIN, config)
await hass.async_block_till_done()
assert mock_client.start.called_once
async def _setup(hass, filter_config):
"""Shared set up for filtering tests."""
config = {apache_kafka.DOMAIN: {"filter": filter_config}}
config[apache_kafka.DOMAIN].update(MIN_CONFIG)
assert await async_setup_component(hass, apache_kafka.DOMAIN, config)
await hass.async_block_till_done()
async def _run_filter_tests(hass, tests, mock_client):
"""Run a series of filter tests on apache kafka."""
for test in tests:
hass.states.async_set(test.id, STATE_ON)
await hass.async_block_till_done()
if test.should_pass:
mock_client.send_and_wait.assert_called_once()
mock_client.send_and_wait.reset_mock()
else:
mock_client.send_and_wait.assert_not_called()
async def test_allowlist(hass, mock_client):
"""Test an allowlist only config."""
await _setup(
hass,
{
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
await _run_filter_tests(hass, tests, mock_client)
async def test_denylist(hass, mock_client):
"""Test a denylist only config."""
await _setup(
hass,
{
"exclude_domains": ["climate"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
await _run_filter_tests(hass, tests, mock_client)
async def test_filtered_allowlist(hass, mock_client):
"""Test an allowlist config with a filtering denylist."""
await _setup(
hass,
{
"include_domains": ["light"],
"include_entity_globs": ["*.included_*"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
tests = [
FilterTest("light.included", True),
FilterTest("light.excluded_test", False),
FilterTest("light.excluded", False),
FilterTest("sensor.included_test", True),
FilterTest("climate.included_test", False),
]
await _run_filter_tests(hass, tests, mock_client)
async def test_filtered_denylist(hass, mock_client):
"""Test a denylist config with a filtering allowlist."""
await _setup(
hass,
{
"include_entities": ["climate.included", "sensor.excluded_test"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("climate.included", True),
FilterTest("switch.excluded_test", False),
FilterTest("sensor.excluded_test", True),
FilterTest("light.excluded", False),
FilterTest("light.included", True),
]
await _run_filter_tests(hass, tests, mock_client)

View file

@ -0,0 +1 @@
"""Tests for azure_event_hub component."""

View file

@ -0,0 +1,211 @@
"""The tests for the Azure Event Hub component."""
from collections import namedtuple
import pytest
import homeassistant.components.azure_event_hub as azure_event_hub
from homeassistant.const import STATE_ON
from homeassistant.setup import async_setup_component
from tests.async_mock import MagicMock, patch
AZURE_EVENT_HUB_PATH = "homeassistant.components.azure_event_hub"
PRODUCER_PATH = f"{AZURE_EVENT_HUB_PATH}.EventHubProducerClient"
MIN_CONFIG = {
"event_hub_namespace": "namespace",
"event_hub_instance_name": "name",
"event_hub_sas_policy": "policy",
"event_hub_sas_key": "key",
}
FilterTest = namedtuple("FilterTest", "id should_pass")
@pytest.fixture(autouse=True, name="mock_client", scope="module")
def mock_client_fixture():
"""Mock the azure event hub producer client."""
with patch(f"{PRODUCER_PATH}.send_batch") as mock_send_batch, patch(
f"{PRODUCER_PATH}.close"
) as mock_close, patch(f"{PRODUCER_PATH}.__init__", return_value=None) as mock_init:
yield (
mock_init,
mock_send_batch,
mock_close,
)
@pytest.fixture(autouse=True, name="mock_batch")
def mock_batch_fixture():
"""Mock batch creator and return mocked batch object."""
mock_batch = MagicMock()
with patch(f"{PRODUCER_PATH}.create_batch", return_value=mock_batch):
yield mock_batch
@pytest.fixture(autouse=True, name="mock_policy")
def mock_policy_fixture():
"""Mock azure shared key credential."""
with patch(f"{AZURE_EVENT_HUB_PATH}.EventHubSharedKeyCredential") as policy:
yield policy
@pytest.fixture(autouse=True, name="mock_event_data")
def mock_event_data_fixture():
"""Mock the azure event data component."""
with patch(f"{AZURE_EVENT_HUB_PATH}.EventData") as event_data:
yield event_data
@pytest.fixture(autouse=True, name="mock_call_later")
def mock_call_later_fixture():
"""Mock async_call_later to allow queue processing on demand."""
with patch(f"{AZURE_EVENT_HUB_PATH}.async_call_later") as mock_call_later:
yield mock_call_later
async def test_minimal_config(hass):
"""Test the minimal config and defaults of component."""
config = {azure_event_hub.DOMAIN: MIN_CONFIG}
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
async def test_full_config(hass):
"""Test the full config of component."""
config = {
azure_event_hub.DOMAIN: {
"send_interval": 10,
"max_delay": 10,
"filter": {
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
"exclude_domains": ["light"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
}
}
config[azure_event_hub.DOMAIN].update(MIN_CONFIG)
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
async def _setup(hass, mock_call_later, filter_config):
"""Shared set up for filtering tests."""
config = {azure_event_hub.DOMAIN: {"filter": filter_config}}
config[azure_event_hub.DOMAIN].update(MIN_CONFIG)
assert await async_setup_component(hass, azure_event_hub.DOMAIN, config)
await hass.async_block_till_done()
mock_call_later.assert_called_once()
return mock_call_later.call_args[0][2]
async def _run_filter_tests(hass, tests, process_queue, mock_batch):
"""Run a series of filter tests on azure event hub."""
for test in tests:
hass.states.async_set(test.id, STATE_ON)
await hass.async_block_till_done()
await process_queue(None)
if test.should_pass:
mock_batch.add.assert_called_once()
mock_batch.add.reset_mock()
else:
mock_batch.add.assert_not_called()
async def test_allowlist(hass, mock_batch, mock_call_later):
"""Test an allowlist only config."""
process_queue = await _setup(
hass,
mock_call_later,
{
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
await _run_filter_tests(hass, tests, process_queue, mock_batch)
async def test_denylist(hass, mock_batch, mock_call_later):
"""Test a denylist only config."""
process_queue = await _setup(
hass,
mock_call_later,
{
"exclude_domains": ["climate"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
await _run_filter_tests(hass, tests, process_queue, mock_batch)
async def test_filtered_allowlist(hass, mock_batch, mock_call_later):
"""Test an allowlist config with a filtering denylist."""
process_queue = await _setup(
hass,
mock_call_later,
{
"include_domains": ["light"],
"include_entity_globs": ["*.included_*"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
tests = [
FilterTest("light.included", True),
FilterTest("light.excluded_test", False),
FilterTest("light.excluded", False),
FilterTest("sensor.included_test", True),
FilterTest("climate.included_test", False),
]
await _run_filter_tests(hass, tests, process_queue, mock_batch)
async def test_filtered_denylist(hass, mock_batch, mock_call_later):
"""Test a denylist config with a filtering allowlist."""
process_queue = await _setup(
hass,
mock_call_later,
{
"include_entities": ["climate.included", "sensor.excluded_test"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("climate.included", True),
FilterTest("switch.excluded_test", False),
FilterTest("sensor.excluded_test", True),
FilterTest("light.excluded", False),
FilterTest("light.included", True),
]
await _run_filter_tests(hass, tests, process_queue, mock_batch)

View file

@ -25,15 +25,15 @@ from tests.components.google_assistant import MockConfig
SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/subscription_info" SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/subscription_info"
@pytest.fixture() @pytest.fixture(name="mock_auth")
def mock_auth(): def mock_auth_fixture():
"""Mock check token.""" """Mock check token."""
with patch("hass_nabucasa.auth.CognitoAuth.async_check_token"): with patch("hass_nabucasa.auth.CognitoAuth.async_check_token"):
yield yield
@pytest.fixture() @pytest.fixture(name="mock_cloud_login")
def mock_cloud_login(hass, setup_api): def mock_cloud_login_fixture(hass, setup_api):
"""Mock cloud is logged in.""" """Mock cloud is logged in."""
hass.data[DOMAIN].id_token = jwt.encode( hass.data[DOMAIN].id_token = jwt.encode(
{ {
@ -45,8 +45,8 @@ def mock_cloud_login(hass, setup_api):
) )
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True, name="setup_api")
def setup_api(hass, aioclient_mock): def setup_api_fixture(hass, aioclient_mock):
"""Initialize HTTP API.""" """Initialize HTTP API."""
hass.loop.run_until_complete( hass.loop.run_until_complete(
mock_cloud( mock_cloud(
@ -68,15 +68,15 @@ def setup_api(hass, aioclient_mock):
return mock_cloud_prefs(hass) return mock_cloud_prefs(hass)
@pytest.fixture @pytest.fixture(name="cloud_client")
def cloud_client(hass, hass_client): def cloud_client_fixture(hass, hass_client):
"""Fixture that can fetch from the cloud client.""" """Fixture that can fetch from the cloud client."""
with patch("hass_nabucasa.Cloud.write_user_info"): with patch("hass_nabucasa.Cloud.write_user_info"):
yield hass.loop.run_until_complete(hass_client()) yield hass.loop.run_until_complete(hass_client())
@pytest.fixture @pytest.fixture(name="mock_cognito")
def mock_cognito(): def mock_cognito_fixture():
"""Mock warrant.""" """Mock warrant."""
with patch("hass_nabucasa.auth.CognitoAuth._cognito") as mock_cog: with patch("hass_nabucasa.auth.CognitoAuth._cognito") as mock_cog:
yield mock_cog() yield mock_cog()
@ -362,14 +362,18 @@ async def test_websocket_status(
}, },
"alexa_entities": { "alexa_entities": {
"include_domains": [], "include_domains": [],
"include_entity_globs": [],
"include_entities": ["light.kitchen", "switch.ac"], "include_entities": ["light.kitchen", "switch.ac"],
"exclude_domains": [], "exclude_domains": [],
"exclude_entity_globs": [],
"exclude_entities": [], "exclude_entities": [],
}, },
"google_entities": { "google_entities": {
"include_domains": ["light"], "include_domains": ["light"],
"include_entity_globs": [],
"include_entities": [], "include_entities": [],
"exclude_domains": [], "exclude_domains": [],
"exclude_entity_globs": [],
"exclude_entities": [], "exclude_entities": [],
}, },
"remote_domain": None, "remote_domain": None,
@ -594,6 +598,7 @@ async def test_enabling_remote_trusted_networks_local4(
hass, hass_ws_client, setup_api, mock_cloud_login hass, hass_ws_client, setup_api, mock_cloud_login
): ):
"""Test we cannot enable remote UI when trusted networks active.""" """Test we cannot enable remote UI when trusted networks active."""
# pylint: disable=protected-access
hass.auth._providers[ hass.auth._providers[
("trusted_networks", None) ("trusted_networks", None)
] = tn_auth.TrustedNetworksAuthProvider( ] = tn_auth.TrustedNetworksAuthProvider(
@ -626,6 +631,7 @@ async def test_enabling_remote_trusted_networks_local6(
hass, hass_ws_client, setup_api, mock_cloud_login hass, hass_ws_client, setup_api, mock_cloud_login
): ):
"""Test we cannot enable remote UI when trusted networks active.""" """Test we cannot enable remote UI when trusted networks active."""
# pylint: disable=protected-access
hass.auth._providers[ hass.auth._providers[
("trusted_networks", None) ("trusted_networks", None)
] = tn_auth.TrustedNetworksAuthProvider( ] = tn_auth.TrustedNetworksAuthProvider(
@ -658,6 +664,7 @@ async def test_enabling_remote_trusted_networks_other(
hass, hass_ws_client, setup_api, mock_cloud_login hass, hass_ws_client, setup_api, mock_cloud_login
): ):
"""Test we can enable remote UI when trusted networks active.""" """Test we can enable remote UI when trusted networks active."""
# pylint: disable=protected-access
hass.auth._providers[ hass.auth._providers[
("trusted_networks", None) ("trusted_networks", None)
] = tn_auth.TrustedNetworksAuthProvider( ] = tn_auth.TrustedNetworksAuthProvider(

View file

@ -0,0 +1,262 @@
"""The tests for the Google Pub/Sub component."""
from collections import namedtuple
from datetime import datetime
import pytest
import homeassistant.components.google_pubsub as google_pubsub
from homeassistant.components.google_pubsub import DateTimeJSONEncoder as victim
from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.core import split_entity_id
from homeassistant.setup import async_setup_component
import tests.async_mock as mock
GOOGLE_PUBSUB_PATH = "homeassistant.components.google_pubsub"
async def test_datetime():
"""Test datetime encoding."""
time = datetime(2019, 1, 13, 12, 30, 5)
assert victim().encode(time) == '"2019-01-13T12:30:05"'
async def test_no_datetime():
"""Test integer encoding."""
assert victim().encode(42) == "42"
async def test_nested():
"""Test dictionary encoding."""
assert victim().encode({"foo": "bar"}) == '{"foo": "bar"}'
@pytest.fixture(autouse=True, name="mock_client")
def mock_client_fixture():
"""Mock the pubsub client."""
with mock.patch(f"{GOOGLE_PUBSUB_PATH}.pubsub_v1") as client:
client.PublisherClient = mock.MagicMock()
setattr(
client.PublisherClient,
"from_service_account_json",
mock.MagicMock(return_value=mock.MagicMock()),
)
yield client
@pytest.fixture(autouse=True, name="mock_os")
def mock_os_fixture():
"""Mock the OS cli."""
with mock.patch(f"{GOOGLE_PUBSUB_PATH}.os") as os_cli:
os_cli.path = mock.MagicMock()
setattr(os_cli.path, "join", mock.MagicMock(return_value="path"))
yield os_cli
@pytest.fixture(autouse=True)
def mock_bus_and_json(hass, monkeypatch):
"""Mock the event bus listener and os component."""
hass.bus.listen = mock.MagicMock()
monkeypatch.setattr(
f"{GOOGLE_PUBSUB_PATH}.json.dumps", mock.Mock(return_value=mock.MagicMock())
)
async def test_minimal_config(hass, mock_client):
"""Test the minimal config and defaults of component."""
config = {
google_pubsub.DOMAIN: {
"project_id": "proj",
"topic_name": "topic",
"credentials_json": "creds",
"filter": {},
}
}
assert await async_setup_component(hass, google_pubsub.DOMAIN, config)
await hass.async_block_till_done()
assert hass.bus.listen.called
assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0]
assert mock_client.PublisherClient.from_service_account_json.call_count == 1
assert (
mock_client.PublisherClient.from_service_account_json.call_args[0][0] == "path"
)
async def test_full_config(hass, mock_client):
"""Test the full config of the component."""
config = {
google_pubsub.DOMAIN: {
"project_id": "proj",
"topic_name": "topic",
"credentials_json": "creds",
"filter": {
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
"exclude_domains": ["light"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
}
}
assert await async_setup_component(hass, google_pubsub.DOMAIN, config)
await hass.async_block_till_done()
assert hass.bus.listen.called
assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0]
assert mock_client.PublisherClient.from_service_account_json.call_count == 1
assert (
mock_client.PublisherClient.from_service_account_json.call_args[0][0] == "path"
)
FilterTest = namedtuple("FilterTest", "id should_pass")
def make_event(entity_id):
"""Make a mock event for test."""
domain = split_entity_id(entity_id)[0]
state = mock.MagicMock(
state="not blank",
domain=domain,
entity_id=entity_id,
object_id="entity",
attributes={},
)
return mock.MagicMock(data={"new_state": state}, time_fired=12345)
async def _setup(hass, filter_config):
"""Shared set up for filtering tests."""
config = {
google_pubsub.DOMAIN: {
"project_id": "proj",
"topic_name": "topic",
"credentials_json": "creds",
"filter": filter_config,
}
}
assert await async_setup_component(hass, google_pubsub.DOMAIN, config)
await hass.async_block_till_done()
return hass.bus.listen.call_args_list[0][0][1]
async def test_allowlist(hass, mock_client):
"""Test an allowlist only config."""
handler_method = await _setup(
hass,
{
"include_domains": ["light"],
"include_entity_globs": ["sensor.included_*"],
"include_entities": ["binary_sensor.included"],
},
)
publish_client = mock_client.PublisherClient.from_service_account_json("path")
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = publish_client.publish.call_count == 1
assert test.should_pass == was_called
publish_client.publish.reset_mock()
async def test_denylist(hass, mock_client):
"""Test a denylist only config."""
handler_method = await _setup(
hass,
{
"exclude_domains": ["climate"],
"exclude_entity_globs": ["sensor.excluded_*"],
"exclude_entities": ["binary_sensor.excluded"],
},
)
publish_client = mock_client.PublisherClient.from_service_account_json("path")
tests = [
FilterTest("climate.excluded", False),
FilterTest("light.included", True),
FilterTest("sensor.excluded_test", False),
FilterTest("sensor.included_test", True),
FilterTest("binary_sensor.included", True),
FilterTest("binary_sensor.excluded", False),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = publish_client.publish.call_count == 1
assert test.should_pass == was_called
publish_client.publish.reset_mock()
async def test_filtered_allowlist(hass, mock_client):
"""Test an allowlist config with a filtering denylist."""
handler_method = await _setup(
hass,
{
"include_domains": ["light"],
"include_entity_globs": ["*.included_*"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
publish_client = mock_client.PublisherClient.from_service_account_json("path")
tests = [
FilterTest("light.included", True),
FilterTest("light.excluded_test", False),
FilterTest("light.excluded", False),
FilterTest("sensor.included_test", True),
FilterTest("climate.included_test", False),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = publish_client.publish.call_count == 1
assert test.should_pass == was_called
publish_client.publish.reset_mock()
async def test_filtered_denylist(hass, mock_client):
"""Test a denylist config with a filtering allowlist."""
handler_method = await _setup(
hass,
{
"include_entities": ["climate.included", "sensor.excluded_test"],
"exclude_domains": ["climate"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["light.excluded"],
},
)
publish_client = mock_client.PublisherClient.from_service_account_json("path")
tests = [
FilterTest("climate.excluded", False),
FilterTest("climate.included", True),
FilterTest("switch.excluded_test", False),
FilterTest("sensor.excluded_test", True),
FilterTest("light.excluded", False),
FilterTest("light.included", True),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = publish_client.publish.call_count == 1
assert test.should_pass == was_called
publish_client.publish.reset_mock()

View file

@ -1,21 +0,0 @@
"""The tests for the Google Pub/Sub component."""
from datetime import datetime
from homeassistant.components.google_pubsub import DateTimeJSONEncoder as victim
class TestDateTimeJSONEncoder:
"""Bundle for DateTimeJSONEncoder tests."""
def test_datetime(self):
"""Test datetime encoding."""
time = datetime(2019, 1, 13, 12, 30, 5)
assert victim().encode(time) == '"2019-01-13T12:30:05"'
def test_no_datetime(self):
"""Test integer encoding."""
assert victim().encode(42) == "42"
def test_nested(self):
"""Test dictionary encoding."""
assert victim().encode({"foo": "bar"}) == '{"foo": "bar"}'

View file

@ -66,20 +66,20 @@ from tests.components.homekit.common import patch_debounce
IP_ADDRESS = "127.0.0.1" IP_ADDRESS = "127.0.0.1"
@pytest.fixture @pytest.fixture(name="device_reg")
def device_reg(hass): def device_reg_fixture(hass):
"""Return an empty, loaded, registry.""" """Return an empty, loaded, registry."""
return mock_device_registry(hass) return mock_device_registry(hass)
@pytest.fixture @pytest.fixture(name="entity_reg")
def entity_reg(hass): def entity_reg_fixture(hass):
"""Return an empty, loaded, registry.""" """Return an empty, loaded, registry."""
return mock_registry(hass) return mock_registry(hass)
@pytest.fixture(scope="module") @pytest.fixture(name="debounce_patcher", scope="module")
def debounce_patcher(): def debounce_patcher_fixture():
"""Patch debounce method.""" """Patch debounce method."""
patcher = patch_debounce() patcher = patch_debounce()
yield patcher.start() yield patcher.start()
@ -88,7 +88,6 @@ def debounce_patcher():
async def test_setup_min(hass): async def test_setup_min(hass):
"""Test async_setup with min config options.""" """Test async_setup with min config options."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={CONF_NAME: BRIDGE_NAME, CONF_PORT: DEFAULT_PORT}, data={CONF_NAME: BRIDGE_NAME, CONF_PORT: DEFAULT_PORT},
@ -413,6 +412,47 @@ async def test_homekit_entity_filter(hass):
assert mock_get_acc.called is False assert mock_get_acc.called is False
async def test_homekit_entity_glob_filter(hass):
"""Test the entity filter."""
entry = await async_init_integration(hass)
entity_filter = generate_filter(
["cover"], ["demo.test"], [], [], ["*.included_*"], ["*.excluded_*"]
)
homekit = HomeKit(
hass,
None,
None,
None,
entity_filter,
{},
DEFAULT_SAFE_MODE,
advertise_ip=None,
entry_id=entry.entry_id,
)
homekit.bridge = Mock()
homekit.bridge.accessories = {}
with patch(f"{PATH_HOMEKIT}.get_accessory") as mock_get_acc:
mock_get_acc.return_value = None
homekit.add_bridge_accessory(State("cover.test", "open"))
assert mock_get_acc.called is True
mock_get_acc.reset_mock()
homekit.add_bridge_accessory(State("demo.test", "on"))
assert mock_get_acc.called is True
mock_get_acc.reset_mock()
homekit.add_bridge_accessory(State("cover.excluded_test", "open"))
assert mock_get_acc.called is False
mock_get_acc.reset_mock()
homekit.add_bridge_accessory(State("light.included_test", "light"))
assert mock_get_acc.called is True
mock_get_acc.reset_mock()
async def test_homekit_start(hass, hk_driver, device_reg, debounce_patcher): async def test_homekit_start(hass, hk_driver, device_reg, debounce_patcher):
"""Test HomeKit start method.""" """Test HomeKit start method."""
entry = await async_init_integration(hass) entry = await async_init_integration(hass)
@ -432,6 +472,7 @@ async def test_homekit_start(hass, hk_driver, device_reg, debounce_patcher):
homekit.bridge = Mock() homekit.bridge = Mock()
homekit.bridge.accessories = [] homekit.bridge.accessories = []
homekit.driver = hk_driver homekit.driver = hk_driver
# pylint: disable=protected-access
homekit._filter = Mock(return_value=True) homekit._filter = Mock(return_value=True)
connection = (device_registry.CONNECTION_NETWORK_MAC, "AA:BB:CC:DD:EE:FF") connection = (device_registry.CONNECTION_NETWORK_MAC, "AA:BB:CC:DD:EE:FF")
@ -587,7 +628,6 @@ async def test_homekit_stop(hass):
async def test_homekit_reset_accessories(hass): async def test_homekit_reset_accessories(hass):
"""Test adding too many accessories to HomeKit.""" """Test adding too many accessories to HomeKit."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, data={CONF_NAME: "mock_name", CONF_PORT: 12345} domain=DOMAIN, data={CONF_NAME: "mock_name", CONF_PORT: 12345}
) )
@ -629,7 +669,7 @@ async def test_homekit_reset_accessories(hass):
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert 2 == hk_driver_config_changed.call_count assert hk_driver_config_changed.call_count == 2
assert mock_add_accessory.called assert mock_add_accessory.called
homekit.status = STATUS_READY homekit.status = STATUS_READY
@ -686,6 +726,7 @@ async def test_homekit_finds_linked_batteries(
entry_id=entry.entry_id, entry_id=entry.entry_id,
) )
homekit.driver = hk_driver homekit.driver = hk_driver
# pylint: disable=protected-access
homekit._filter = Mock(return_value=True) homekit._filter = Mock(return_value=True)
homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge") homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge")
@ -818,7 +859,6 @@ async def test_setup_imported(hass):
async def test_yaml_updates_update_config_entry_for_name(hass): async def test_yaml_updates_update_config_entry_for_name(hass):
"""Test async_setup with imported config.""" """Test async_setup with imported config."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
source=SOURCE_IMPORT, source=SOURCE_IMPORT,
@ -858,7 +898,6 @@ async def test_yaml_updates_update_config_entry_for_name(hass):
async def test_raise_config_entry_not_ready(hass): async def test_raise_config_entry_not_ready(hass):
"""Test async_setup when the port is not available.""" """Test async_setup when the port is not available."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={CONF_NAME: BRIDGE_NAME, CONF_PORT: DEFAULT_PORT}, data={CONF_NAME: BRIDGE_NAME, CONF_PORT: DEFAULT_PORT},
@ -918,6 +957,7 @@ async def test_homekit_ignored_missing_devices(
entry_id=entry.entry_id, entry_id=entry.entry_id,
) )
homekit.driver = hk_driver homekit.driver = hk_driver
# pylint: disable=protected-access
homekit._filter = Mock(return_value=True) homekit._filter = Mock(return_value=True)
homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge") homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge")
@ -997,6 +1037,7 @@ async def test_homekit_finds_linked_motion_sensors(
entry_id=entry.entry_id, entry_id=entry.entry_id,
) )
homekit.driver = hk_driver homekit.driver = hk_driver
# pylint: disable=protected-access
homekit._filter = Mock(return_value=True) homekit._filter = Mock(return_value=True)
homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge") homekit.bridge = HomeBridge(hass, hk_driver, "mock_bridge")

View file

@ -18,6 +18,8 @@ from homeassistant.components.script import EVENT_SCRIPT_STARTED
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_NAME, ATTR_NAME,
CONF_DOMAINS,
CONF_ENTITIES,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -26,6 +28,10 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
) )
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.helpers.entityfilter import (
CONF_ENTITY_GLOBS,
convert_include_exclude_filter,
)
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component, setup_component from homeassistant.setup import async_setup_component, setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -151,7 +157,9 @@ class TestComponentLogbook(unittest.TestCase):
attributes = {"unit_of_measurement": "foo"} attributes = {"unit_of_measurement": "foo"}
eventA = self.create_state_changed_event(pointA, entity_id, 10, attributes) eventA = self.create_state_changed_event(pointA, entity_id, 10, attributes)
entities_filter = logbook._generate_filter_from_config({}) entities_filter = convert_include_exclude_filter(
logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})[logbook.DOMAIN]
)
assert ( assert (
logbook._keep_event(self.hass, eventA, entities_filter, entity_attr_cache) logbook._keep_event(self.hass, eventA, entities_filter, entity_attr_cache)
is False is False
@ -174,7 +182,9 @@ class TestComponentLogbook(unittest.TestCase):
) )
eventB = self.create_state_changed_event(pointB, entity_id2, 20) eventB = self.create_state_changed_event(pointB, entity_id2, 20)
entities_filter = logbook._generate_filter_from_config({}) entities_filter = convert_include_exclude_filter(
logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})[logbook.DOMAIN]
)
events = [ events = [
e e
for e in ( for e in (
@ -210,7 +220,9 @@ class TestComponentLogbook(unittest.TestCase):
) )
eventB = self.create_state_changed_event(pointB, entity_id2, 20) eventB = self.create_state_changed_event(pointB, entity_id2, 20)
entities_filter = logbook._generate_filter_from_config({}) entities_filter = convert_include_exclude_filter(
logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})[logbook.DOMAIN]
)
events = [ events = [
e e
for e in ( for e in (
@ -244,12 +256,10 @@ class TestComponentLogbook(unittest.TestCase):
config = logbook.CONFIG_SCHEMA( config = logbook.CONFIG_SCHEMA(
{ {
ha.DOMAIN: {}, ha.DOMAIN: {},
logbook.DOMAIN: { logbook.DOMAIN: {logbook.CONF_EXCLUDE: {CONF_ENTITIES: [entity_id]}},
logbook.CONF_EXCLUDE: {logbook.CONF_ENTITIES: [entity_id]}
},
} }
) )
entities_filter = logbook._generate_filter_from_config(config[logbook.DOMAIN]) entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [ events = [
e e
for e in ( for e in (
@ -284,11 +294,11 @@ class TestComponentLogbook(unittest.TestCase):
{ {
ha.DOMAIN: {}, ha.DOMAIN: {},
logbook.DOMAIN: { logbook.DOMAIN: {
logbook.CONF_EXCLUDE: {logbook.CONF_DOMAINS: ["switch", "alexa"]} logbook.CONF_EXCLUDE: {CONF_DOMAINS: ["switch", "alexa"]}
}, },
} }
) )
entities_filter = logbook._generate_filter_from_config(config[logbook.DOMAIN]) entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [ events = [
e e
for e in ( for e in (
@ -309,6 +319,53 @@ class TestComponentLogbook(unittest.TestCase):
entries[1], pointB, "blu", domain="sensor", entity_id=entity_id2 entries[1], pointB, "blu", domain="sensor", entity_id=entity_id2
) )
def test_exclude_events_domain_glob(self):
"""Test if events are filtered if domain or glob is excluded in config."""
entity_id = "switch.bla"
entity_id2 = "sensor.blu"
entity_id3 = "sensor.excluded"
pointA = dt_util.utcnow()
pointB = pointA + timedelta(minutes=logbook.GROUP_BY_MINUTES)
pointC = pointB + timedelta(minutes=logbook.GROUP_BY_MINUTES)
entity_attr_cache = logbook.EntityAttributeCache(self.hass)
eventA = self.create_state_changed_event(pointA, entity_id, 10)
eventB = self.create_state_changed_event(pointB, entity_id2, 20)
eventC = self.create_state_changed_event(pointC, entity_id3, 30)
config = logbook.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
logbook.DOMAIN: {
logbook.CONF_EXCLUDE: {
CONF_DOMAINS: ["switch", "alexa"],
CONF_ENTITY_GLOBS: "*.excluded",
}
},
}
)
entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [
e
for e in (
MockLazyEventPartialState(EVENT_HOMEASSISTANT_START),
MockLazyEventPartialState(EVENT_ALEXA_SMART_HOME),
eventA,
eventB,
eventC,
)
if logbook._keep_event(self.hass, e, entities_filter, entity_attr_cache)
]
entries = list(logbook.humanify(self.hass, events, entity_attr_cache))
assert len(entries) == 2
self.assert_entry(
entries[0], name="Home Assistant", message="started", domain=ha.DOMAIN
)
self.assert_entry(
entries[1], pointB, "blu", domain="sensor", entity_id=entity_id2
)
def test_include_events_entity(self): def test_include_events_entity(self):
"""Test if events are filtered if entity is included in config.""" """Test if events are filtered if entity is included in config."""
entity_id = "sensor.bla" entity_id = "sensor.bla"
@ -325,13 +382,13 @@ class TestComponentLogbook(unittest.TestCase):
ha.DOMAIN: {}, ha.DOMAIN: {},
logbook.DOMAIN: { logbook.DOMAIN: {
logbook.CONF_INCLUDE: { logbook.CONF_INCLUDE: {
logbook.CONF_DOMAINS: ["homeassistant"], CONF_DOMAINS: ["homeassistant"],
logbook.CONF_ENTITIES: [entity_id2], CONF_ENTITIES: [entity_id2],
} }
}, },
} }
) )
entities_filter = logbook._generate_filter_from_config(config[logbook.DOMAIN]) entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [ events = [
e e
for e in ( for e in (
@ -373,12 +430,12 @@ class TestComponentLogbook(unittest.TestCase):
ha.DOMAIN: {}, ha.DOMAIN: {},
logbook.DOMAIN: { logbook.DOMAIN: {
logbook.CONF_INCLUDE: { logbook.CONF_INCLUDE: {
logbook.CONF_DOMAINS: ["homeassistant", "sensor", "alexa"] CONF_DOMAINS: ["homeassistant", "sensor", "alexa"]
} }
}, },
} }
) )
entities_filter = logbook._generate_filter_from_config(config[logbook.DOMAIN]) entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [ events = [
e e
for e in ( for e in (
@ -400,6 +457,63 @@ class TestComponentLogbook(unittest.TestCase):
entries[2], pointB, "blu", domain="sensor", entity_id=entity_id2 entries[2], pointB, "blu", domain="sensor", entity_id=entity_id2
) )
def test_include_events_domain_glob(self):
"""Test if events are filtered if domain or glob is included in config."""
assert setup_component(self.hass, "alexa", {})
entity_id = "switch.bla"
entity_id2 = "sensor.blu"
entity_id3 = "switch.included"
pointA = dt_util.utcnow()
pointB = pointA + timedelta(minutes=logbook.GROUP_BY_MINUTES)
pointC = pointB + timedelta(minutes=logbook.GROUP_BY_MINUTES)
entity_attr_cache = logbook.EntityAttributeCache(self.hass)
event_alexa = MockLazyEventPartialState(
EVENT_ALEXA_SMART_HOME,
{"request": {"namespace": "Alexa.Discovery", "name": "Discover"}},
)
eventA = self.create_state_changed_event(pointA, entity_id, 10)
eventB = self.create_state_changed_event(pointB, entity_id2, 20)
eventC = self.create_state_changed_event(pointC, entity_id3, 30)
config = logbook.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
logbook.DOMAIN: {
logbook.CONF_INCLUDE: {
CONF_DOMAINS: ["homeassistant", "sensor", "alexa"],
CONF_ENTITY_GLOBS: ["*.included"],
}
},
}
)
entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [
e
for e in (
MockLazyEventPartialState(EVENT_HOMEASSISTANT_START),
event_alexa,
eventA,
eventB,
eventC,
)
if logbook._keep_event(self.hass, e, entities_filter, entity_attr_cache)
]
entries = list(logbook.humanify(self.hass, events, entity_attr_cache))
assert len(entries) == 4
self.assert_entry(
entries[0], name="Home Assistant", message="started", domain=ha.DOMAIN
)
self.assert_entry(entries[1], name="Amazon Alexa", domain="alexa")
self.assert_entry(
entries[2], pointB, "blu", domain="sensor", entity_id=entity_id2
)
self.assert_entry(
entries[3], pointC, "included", domain="switch", entity_id=entity_id3
)
def test_include_exclude_events(self): def test_include_exclude_events(self):
"""Test if events are filtered if include and exclude is configured.""" """Test if events are filtered if include and exclude is configured."""
entity_id = "switch.bla" entity_id = "switch.bla"
@ -420,17 +534,17 @@ class TestComponentLogbook(unittest.TestCase):
ha.DOMAIN: {}, ha.DOMAIN: {},
logbook.DOMAIN: { logbook.DOMAIN: {
logbook.CONF_INCLUDE: { logbook.CONF_INCLUDE: {
logbook.CONF_DOMAINS: ["sensor", "homeassistant"], CONF_DOMAINS: ["sensor", "homeassistant"],
logbook.CONF_ENTITIES: ["switch.bla"], CONF_ENTITIES: ["switch.bla"],
}, },
logbook.CONF_EXCLUDE: { logbook.CONF_EXCLUDE: {
logbook.CONF_DOMAINS: ["switch"], CONF_DOMAINS: ["switch"],
logbook.CONF_ENTITIES: ["sensor.bli"], CONF_ENTITIES: ["sensor.bli"],
}, },
}, },
} }
) )
entities_filter = logbook._generate_filter_from_config(config[logbook.DOMAIN]) entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [ events = [
e e
for e in ( for e in (
@ -462,6 +576,83 @@ class TestComponentLogbook(unittest.TestCase):
entries[4], pointB, "blu", domain="sensor", entity_id=entity_id2 entries[4], pointB, "blu", domain="sensor", entity_id=entity_id2
) )
def test_include_exclude_events_with_glob_filters(self):
"""Test if events are filtered if include and exclude is configured."""
entity_id = "switch.bla"
entity_id2 = "sensor.blu"
entity_id3 = "sensor.bli"
entity_id4 = "light.included"
entity_id5 = "switch.included"
entity_id6 = "sensor.excluded"
pointA = dt_util.utcnow()
pointB = pointA + timedelta(minutes=logbook.GROUP_BY_MINUTES)
pointC = pointB + timedelta(minutes=logbook.GROUP_BY_MINUTES)
entity_attr_cache = logbook.EntityAttributeCache(self.hass)
eventA1 = self.create_state_changed_event(pointA, entity_id, 10)
eventA2 = self.create_state_changed_event(pointA, entity_id2, 10)
eventA3 = self.create_state_changed_event(pointA, entity_id3, 10)
eventB1 = self.create_state_changed_event(pointB, entity_id, 20)
eventB2 = self.create_state_changed_event(pointB, entity_id2, 20)
eventC1 = self.create_state_changed_event(pointC, entity_id4, 30)
eventC2 = self.create_state_changed_event(pointC, entity_id5, 30)
eventC3 = self.create_state_changed_event(pointC, entity_id6, 30)
config = logbook.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
logbook.DOMAIN: {
logbook.CONF_INCLUDE: {
CONF_DOMAINS: ["sensor", "homeassistant"],
CONF_ENTITIES: ["switch.bla"],
CONF_ENTITY_GLOBS: ["*.included"],
},
logbook.CONF_EXCLUDE: {
CONF_DOMAINS: ["switch"],
CONF_ENTITY_GLOBS: ["*.excluded"],
CONF_ENTITIES: ["sensor.bli"],
},
},
}
)
entities_filter = convert_include_exclude_filter(config[logbook.DOMAIN])
events = [
e
for e in (
MockLazyEventPartialState(EVENT_HOMEASSISTANT_START),
eventA1,
eventA2,
eventA3,
eventB1,
eventB2,
eventC1,
eventC2,
eventC3,
)
if logbook._keep_event(self.hass, e, entities_filter, entity_attr_cache)
]
entries = list(logbook.humanify(self.hass, events, entity_attr_cache))
assert len(entries) == 6
self.assert_entry(
entries[0], name="Home Assistant", message="started", domain=ha.DOMAIN
)
self.assert_entry(
entries[1], pointA, "bla", domain="switch", entity_id=entity_id
)
self.assert_entry(
entries[2], pointA, "blu", domain="sensor", entity_id=entity_id2
)
self.assert_entry(
entries[3], pointB, "bla", domain="switch", entity_id=entity_id
)
self.assert_entry(
entries[4], pointB, "blu", domain="sensor", entity_id=entity_id2
)
self.assert_entry(
entries[5], pointC, "included", domain="light", entity_id=entity_id4
)
def test_exclude_attribute_changes(self): def test_exclude_attribute_changes(self):
"""Test if events of attribute changes are filtered.""" """Test if events of attribute changes are filtered."""
pointA = dt_util.utcnow() pointA = dt_util.utcnow()
@ -484,7 +675,9 @@ class TestComponentLogbook(unittest.TestCase):
"light.kitchen", pointC, state_100, state_200 "light.kitchen", pointC, state_100, state_200
) )
entities_filter = logbook._generate_filter_from_config({}) entities_filter = convert_include_exclude_filter(
logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})[logbook.DOMAIN]
)
events = [ events = [
e e
for e in (eventA, eventB) for e in (eventA, eventB)
@ -1192,6 +1385,7 @@ class TestComponentLogbook(unittest.TestCase):
entries[0], name=name, message=message, domain="sun", entity_id=entity_id entries[0], name=name, message=message, domain="sun", entity_id=entity_id
) )
# pylint: disable=no-self-use
def assert_entry( def assert_entry(
self, entry, when=None, name=None, message=None, domain=None, entity_id=None self, entry, when=None, name=None, message=None, domain=None, entity_id=None
): ):
@ -1232,6 +1426,7 @@ class TestComponentLogbook(unittest.TestCase):
entity_id, event_time_fired, old_state, new_state entity_id, event_time_fired, old_state, new_state
) )
# pylint: disable=no-self-use
def create_state_changed_event_from_old_new( def create_state_changed_event_from_old_new(
self, entity_id, event_time_fired, old_state, new_state self, entity_id, event_time_fired, old_state, new_state
): ):
@ -1306,36 +1501,36 @@ async def test_logbook_view_period_entity(hass, hass_client):
# Test today entries without filters # Test today entries without filters
response = await client.get(f"/api/logbook/{start_date.isoformat()}") response = await client.get(f"/api/logbook/{start_date.isoformat()}")
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 2 assert len(response_json) == 2
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
assert json[1]["entity_id"] == entity_id_second assert response_json[1]["entity_id"] == entity_id_second
# Test today entries with filter by period # Test today entries with filter by period
response = await client.get(f"/api/logbook/{start_date.isoformat()}?period=1") response = await client.get(f"/api/logbook/{start_date.isoformat()}?period=1")
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 2 assert len(response_json) == 2
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
assert json[1]["entity_id"] == entity_id_second assert response_json[1]["entity_id"] == entity_id_second
# Test today entries with filter by entity_id # Test today entries with filter by entity_id
response = await client.get( response = await client.get(
f"/api/logbook/{start_date.isoformat()}?entity=switch.test" f"/api/logbook/{start_date.isoformat()}?entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 1 assert len(response_json) == 1
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
# Test entries for 3 days with filter by entity_id # Test entries for 3 days with filter by entity_id
response = await client.get( response = await client.get(
f"/api/logbook/{start_date.isoformat()}?period=3&entity=switch.test" f"/api/logbook/{start_date.isoformat()}?period=3&entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 1 assert len(response_json) == 1
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
# Tomorrow time 00:00:00 # Tomorrow time 00:00:00
start = (dt_util.utcnow() + timedelta(days=1)).date() start = (dt_util.utcnow() + timedelta(days=1)).date()
@ -1344,25 +1539,25 @@ async def test_logbook_view_period_entity(hass, hass_client):
# Test tomorrow entries without filters # Test tomorrow entries without filters
response = await client.get(f"/api/logbook/{start_date.isoformat()}") response = await client.get(f"/api/logbook/{start_date.isoformat()}")
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 0 assert len(response_json) == 0
# Test tomorrow entries with filter by entity_id # Test tomorrow entries with filter by entity_id
response = await client.get( response = await client.get(
f"/api/logbook/{start_date.isoformat()}?entity=switch.test" f"/api/logbook/{start_date.isoformat()}?entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 0 assert len(response_json) == 0
# Test entries from tomorrow to 3 days ago with filter by entity_id # Test entries from tomorrow to 3 days ago with filter by entity_id
response = await client.get( response = await client.get(
f"/api/logbook/{start_date.isoformat()}?period=3&entity=switch.test" f"/api/logbook/{start_date.isoformat()}?period=3&entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 1 assert len(response_json) == 1
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
async def test_logbook_describe_event(hass, hass_client): async def test_logbook_describe_event(hass, hass_client):
@ -1409,8 +1604,8 @@ async def test_exclude_described_event(hass, hass_client):
{ {
logbook.DOMAIN: { logbook.DOMAIN: {
logbook.CONF_EXCLUDE: { logbook.CONF_EXCLUDE: {
logbook.CONF_DOMAINS: ["sensor"], CONF_DOMAINS: ["sensor"],
logbook.CONF_ENTITIES: [entity_id], CONF_ENTITIES: [entity_id],
} }
} }
}, },
@ -1488,10 +1683,10 @@ async def test_logbook_view_end_time_entity(hass, hass_client):
f"/api/logbook/{start_date.isoformat()}?end_time={end_time}" f"/api/logbook/{start_date.isoformat()}?end_time={end_time}"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 2 assert len(response_json) == 2
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
assert json[1]["entity_id"] == entity_id_second assert response_json[1]["entity_id"] == entity_id_second
# Test entries for 3 days with filter by entity_id # Test entries for 3 days with filter by entity_id
end_time = start + timedelta(hours=72) end_time = start + timedelta(hours=72)
@ -1499,9 +1694,9 @@ async def test_logbook_view_end_time_entity(hass, hass_client):
f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=switch.test" f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 1 assert len(response_json) == 1
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
# Tomorrow time 00:00:00 # Tomorrow time 00:00:00
start = dt_util.utcnow() start = dt_util.utcnow()
@ -1513,9 +1708,9 @@ async def test_logbook_view_end_time_entity(hass, hass_client):
f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=switch.test" f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=switch.test"
) )
assert response.status == 200 assert response.status == 200
json = await response.json() response_json = await response.json()
assert len(json) == 1 assert len(response_json) == 1
assert json[0]["entity_id"] == entity_id_test assert response_json[0]["entity_id"] == entity_id_test
async def test_logbook_entity_filter_with_automations(hass, hass_client): async def test_logbook_entity_filter_with_automations(hass, hass_client):

View file

@ -354,3 +354,189 @@ async def test_state_changed_event_include_domain_exclude_entity(hass, mqtt_mock
await hass.async_block_till_done() await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called assert not mqtt_mock.async_publish.called
async def test_state_changed_event_include_globs(hass, mqtt_mock):
"""Test that filtering on included glob works as expected."""
base_topic = "pub"
incl = {"entity_globs": ["*.included_*"]}
excl = {}
# Add the statestream component for publishing state updates
# Set the filter to allow *.included_* items
assert await add_statestream(
hass, base_topic=base_topic, publish_include=incl, publish_exclude=excl
)
await hass.async_block_till_done()
# Reset the mock because it will have already gotten calls for the
# mqtt_statestream state change on initialization, etc.
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity with included glob
mock_state_change_event(hass, State("fake2.included_entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake2/included_entity/state
mqtt_mock.async_publish.assert_called_with(
"pub/fake2/included_entity/state", "on", 1, True
)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that shouldn't be included
mock_state_change_event(hass, State("fake2.entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called
async def test_state_changed_event_exclude_globs(hass, mqtt_mock):
"""Test that filtering on excluded globs works as expected."""
base_topic = "pub"
incl = {}
excl = {"entity_globs": ["*.excluded_*"]}
# Add the statestream component for publishing state updates
# Set the filter to allow *.excluded_* items
assert await add_statestream(
hass, base_topic=base_topic, publish_include=incl, publish_exclude=excl
)
await hass.async_block_till_done()
# Reset the mock because it will have already gotten calls for the
# mqtt_statestream state change on initialization, etc.
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity
mock_state_change_event(hass, State("fake.entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake/entity/state
mqtt_mock.async_publish.assert_called_with("pub/fake/entity/state", "on", 1, True)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that shouldn't be included by glob
mock_state_change_event(hass, State("fake.excluded_entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called
async def test_state_changed_event_exclude_domain_globs_include_entity(hass, mqtt_mock):
"""Test filtering with excluded domain and glob and included entity."""
base_topic = "pub"
incl = {"entities": ["fake.entity"]}
excl = {"domains": ["fake"], "entity_globs": ["*.excluded_*"]}
# Add the statestream component for publishing state updates
# Set the filter to exclude with include filter
assert await add_statestream(
hass, base_topic=base_topic, publish_include=incl, publish_exclude=excl
)
await hass.async_block_till_done()
# Reset the mock because it will have already gotten calls for the
# mqtt_statestream state change on initialization, etc.
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity
mock_state_change_event(hass, State("fake.entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake/entity/state
mqtt_mock.async_publish.assert_called_with("pub/fake/entity/state", "on", 1, True)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that doesn't match any filters
mock_state_change_event(hass, State("fake2.included_entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake/entity/state
mqtt_mock.async_publish.assert_called_with(
"pub/fake2/included_entity/state", "on", 1, True
)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that shouldn't be included by domain
mock_state_change_event(hass, State("fake.entity2", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that shouldn't be included by glob
mock_state_change_event(hass, State("fake.excluded_entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called
async def test_state_changed_event_include_domain_globs_exclude_entity(hass, mqtt_mock):
"""Test filtering with included domain and glob and excluded entity."""
base_topic = "pub"
incl = {"domains": ["fake"], "entity_globs": ["*.included_*"]}
excl = {"entities": ["fake.entity2"]}
# Add the statestream component for publishing state updates
# Set the filter to include with exclude filter
assert await add_statestream(
hass, base_topic=base_topic, publish_include=incl, publish_exclude=excl
)
await hass.async_block_till_done()
# Reset the mock because it will have already gotten calls for the
# mqtt_statestream state change on initialization, etc.
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity included by domain
mock_state_change_event(hass, State("fake.entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake/entity/state
mqtt_mock.async_publish.assert_called_with("pub/fake/entity/state", "on", 1, True)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity included by glob
mock_state_change_event(hass, State("fake.included_entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
# Make sure 'on' was published to pub/fake/entity/state
mqtt_mock.async_publish.assert_called_with(
"pub/fake/included_entity/state", "on", 1, True
)
assert mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that shouldn't be included
mock_state_change_event(hass, State("fake.entity2", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called
mqtt_mock.async_publish.reset_mock()
# Set a state of an entity that doesn't match any filters
mock_state_change_event(hass, State("fake2.entity", "on"))
await hass.async_block_till_done()
await hass.async_block_till_done()
assert not mqtt_mock.async_publish.called

View file

@ -1,4 +1,6 @@
"""The tests for the Prometheus exporter.""" """The tests for the Prometheus exporter."""
from collections import namedtuple
import pytest import pytest
from homeassistant import setup from homeassistant import setup
@ -10,9 +12,15 @@ from homeassistant.const import (
DEGREE, DEGREE,
DEVICE_CLASS_POWER, DEVICE_CLASS_POWER,
ENERGY_KILO_WATT_HOUR, ENERGY_KILO_WATT_HOUR,
EVENT_STATE_CHANGED,
) )
from homeassistant.core import split_entity_id
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import tests.async_mock as mock
PROMETHEUS_PATH = "homeassistant.components.prometheus"
@pytest.fixture @pytest.fixture
async def prometheus_client(loop, hass, hass_client): async def prometheus_client(loop, hass, hass_client):
@ -139,3 +147,171 @@ async def test_view(prometheus_client): # pylint: disable=redefined-outer-name
'entity="sensor.sps30_pm_1um_weight_concentration",' 'entity="sensor.sps30_pm_1um_weight_concentration",'
'friendly_name="SPS30 PM <1µm Weight concentration"} 3.7069' in body 'friendly_name="SPS30 PM <1µm Weight concentration"} 3.7069' in body
) )
@pytest.fixture(name="mock_client")
def mock_client_fixture():
"""Mock the prometheus client."""
with mock.patch(f"{PROMETHEUS_PATH}.prometheus_client") as client:
counter_client = mock.MagicMock()
client.Counter = mock.MagicMock(return_value=counter_client)
setattr(counter_client, "labels", mock.MagicMock(return_value=mock.MagicMock()))
yield counter_client
@pytest.fixture
def mock_bus(hass):
"""Mock the event bus listener."""
hass.bus.listen = mock.MagicMock()
@pytest.mark.usefixtures("mock_bus")
async def test_minimal_config(hass, mock_client):
"""Test the minimal config and defaults of component."""
config = {prometheus.DOMAIN: {}}
assert await async_setup_component(hass, prometheus.DOMAIN, config)
await hass.async_block_till_done()
assert hass.bus.listen.called
assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0]
@pytest.mark.usefixtures("mock_bus")
async def test_full_config(hass, mock_client):
"""Test the full config of component."""
config = {
prometheus.DOMAIN: {
"namespace": "ns",
"default_metric": "m",
"override_metric": "m",
"component_config": {"fake.test": {"override_metric": "km"}},
"component_config_glob": {"fake.time_*": {"override_metric": "h"}},
"component_config_domain": {"climate": {"override_metric": "°C"}},
"filter": {
"include_domains": ["climate"],
"include_entity_globs": ["fake.time_*"],
"include_entities": ["fake.test"],
"exclude_domains": ["script"],
"exclude_entity_globs": ["climate.excluded_*"],
"exclude_entities": ["fake.time_excluded"],
},
}
}
assert await async_setup_component(hass, prometheus.DOMAIN, config)
await hass.async_block_till_done()
assert hass.bus.listen.called
assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0]
FilterTest = namedtuple("FilterTest", "id should_pass")
def make_event(entity_id):
"""Make a mock event for test."""
domain = split_entity_id(entity_id)[0]
state = mock.MagicMock(
state="not blank",
domain=domain,
entity_id=entity_id,
object_id="entity",
attributes={},
)
return mock.MagicMock(data={"new_state": state}, time_fired=12345)
async def _setup(hass, filter_config):
"""Shared set up for filtering tests."""
config = {prometheus.DOMAIN: {"filter": filter_config}}
assert await async_setup_component(hass, prometheus.DOMAIN, config)
await hass.async_block_till_done()
return hass.bus.listen.call_args_list[0][0][1]
@pytest.mark.usefixtures("mock_bus")
async def test_allowlist(hass, mock_client):
"""Test an allowlist only config."""
handler_method = await _setup(
hass,
{
"include_domains": ["fake"],
"include_entity_globs": ["test.included_*"],
"include_entities": ["not_real.included"],
},
)
tests = [
FilterTest("climate.excluded", False),
FilterTest("fake.included", True),
FilterTest("test.excluded_test", False),
FilterTest("test.included_test", True),
FilterTest("not_real.included", True),
FilterTest("not_real.excluded", False),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = mock_client.labels.call_count == 1
assert test.should_pass == was_called
mock_client.labels.reset_mock()
@pytest.mark.usefixtures("mock_bus")
async def test_denylist(hass, mock_client):
"""Test a denylist only config."""
handler_method = await _setup(
hass,
{
"exclude_domains": ["fake"],
"exclude_entity_globs": ["test.excluded_*"],
"exclude_entities": ["not_real.excluded"],
},
)
tests = [
FilterTest("fake.excluded", False),
FilterTest("light.included", True),
FilterTest("test.excluded_test", False),
FilterTest("test.included_test", True),
FilterTest("not_real.included", True),
FilterTest("not_real.excluded", False),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = mock_client.labels.call_count == 1
assert test.should_pass == was_called
mock_client.labels.reset_mock()
@pytest.mark.usefixtures("mock_bus")
async def test_filtered_denylist(hass, mock_client):
"""Test a denylist config with a filtering allowlist."""
handler_method = await _setup(
hass,
{
"include_entities": ["fake.included", "test.excluded_test"],
"exclude_domains": ["fake"],
"exclude_entity_globs": ["*.excluded_*"],
"exclude_entities": ["not_real.excluded"],
},
)
tests = [
FilterTest("fake.excluded", False),
FilterTest("fake.included", True),
FilterTest("alt_fake.excluded_test", False),
FilterTest("test.excluded_test", True),
FilterTest("not_real.excluded", False),
FilterTest("not_real.included", True),
]
for test in tests:
event = make_event(test.id)
handler_method(event)
was_called = mock_client.labels.call_count == 1
assert test.should_pass == was_called
mock_client.labels.reset_mock()

View file

@ -6,6 +6,8 @@ import unittest
import pytest import pytest
from homeassistant.components.recorder import ( from homeassistant.components.recorder import (
CONFIG_SCHEMA,
DOMAIN,
Recorder, Recorder,
run_information, run_information,
run_information_from_instance, run_information_from_instance,
@ -152,6 +154,19 @@ def test_saving_state_include_domains(hass_recorder):
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_empty_context(hass, "test2.recorder") == states[0]
def test_saving_state_include_domains_globs(hass_recorder):
"""Test saving and restoring a state."""
hass = hass_recorder(
{"include": {"domains": "test2", "entity_globs": "*.included_*"}}
)
states = _add_entities(
hass, ["test.recorder", "test2.recorder", "test3.included_entity"]
)
assert len(states) == 2
assert _state_empty_context(hass, "test2.recorder") == states[0]
assert _state_empty_context(hass, "test3.included_entity") == states[1]
def test_saving_state_incl_entities(hass_recorder): def test_saving_state_incl_entities(hass_recorder):
"""Test saving and restoring a state.""" """Test saving and restoring a state."""
hass = hass_recorder({"include": {"entities": "test2.recorder"}}) hass = hass_recorder({"include": {"entities": "test2.recorder"}})
@ -176,6 +191,18 @@ def test_saving_state_exclude_domains(hass_recorder):
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_empty_context(hass, "test2.recorder") == states[0]
def test_saving_state_exclude_domains_globs(hass_recorder):
"""Test saving and restoring a state."""
hass = hass_recorder(
{"exclude": {"domains": "test", "entity_globs": "*.excluded_*"}}
)
states = _add_entities(
hass, ["test.recorder", "test2.recorder", "test2.excluded_entity"]
)
assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0]
def test_saving_state_exclude_entities(hass_recorder): def test_saving_state_exclude_entities(hass_recorder):
"""Test saving and restoring a state.""" """Test saving and restoring a state."""
hass = hass_recorder({"exclude": {"entities": "test.recorder"}}) hass = hass_recorder({"exclude": {"entities": "test.recorder"}})
@ -193,6 +220,20 @@ def test_saving_state_exclude_domain_include_entity(hass_recorder):
assert len(states) == 2 assert len(states) == 2
def test_saving_state_exclude_domain_glob_include_entity(hass_recorder):
"""Test saving and restoring a state."""
hass = hass_recorder(
{
"include": {"entities": ["test.recorder", "test.excluded_entity"]},
"exclude": {"domains": "test", "entity_globs": "*._excluded_*"},
}
)
states = _add_entities(
hass, ["test.recorder", "test2.recorder", "test.excluded_entity"]
)
assert len(states) == 3
def test_saving_state_include_domain_exclude_entity(hass_recorder): def test_saving_state_include_domain_exclude_entity(hass_recorder):
"""Test saving and restoring a state.""" """Test saving and restoring a state."""
hass = hass_recorder( hass = hass_recorder(
@ -204,6 +245,22 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder):
assert _state_empty_context(hass, "test.ok").state == "state2" assert _state_empty_context(hass, "test.ok").state == "state2"
def test_saving_state_include_domain_glob_exclude_entity(hass_recorder):
"""Test saving and restoring a state."""
hass = hass_recorder(
{
"exclude": {"entities": ["test.recorder", "test2.included_entity"]},
"include": {"domains": "test", "entity_globs": "*._included_*"},
}
)
states = _add_entities(
hass, ["test.recorder", "test2.recorder", "test.ok", "test2.included_entity"]
)
assert len(states) == 1
assert _state_empty_context(hass, "test.ok") == states[0]
assert _state_empty_context(hass, "test.ok").state == "state2"
def test_recorder_setup_failure(): def test_recorder_setup_failure():
"""Test some exceptions.""" """Test some exceptions."""
hass = get_test_home_assistant() hass = get_test_home_assistant()
@ -220,8 +277,8 @@ def test_recorder_setup_failure():
uri="sqlite://", uri="sqlite://",
db_max_retries=10, db_max_retries=10,
db_retry_wait=3, db_retry_wait=3,
include={}, entity_filter=CONFIG_SCHEMA({DOMAIN: {}}),
exclude={}, exclude_t=[],
) )
rec.start() rec.start()
rec.join() rec.join()
@ -243,6 +300,7 @@ async def test_defaults_set(hass):
assert await async_setup_component(hass, "history", {}) assert await async_setup_component(hass, "history", {})
assert recorder_config is not None assert recorder_config is not None
# pylint: disable=unsubscriptable-object
assert recorder_config["auto_purge"] assert recorder_config["auto_purge"]
assert recorder_config["purge_keep_days"] == 10 assert recorder_config["purge_keep_days"] == 10

View file

@ -58,6 +58,7 @@ class TestSplunk(unittest.TestCase):
def _setup(self, mock_requests): def _setup(self, mock_requests):
"""Test the setup.""" """Test the setup."""
# pylint: disable=attribute-defined-outside-init
self.mock_post = mock_requests.post self.mock_post = mock_requests.post
self.mock_request_exception = Exception self.mock_request_exception = Exception
mock_requests.exceptions.RequestException = self.mock_request_exception mock_requests.exceptions.RequestException = self.mock_request_exception
@ -115,7 +116,7 @@ class TestSplunk(unittest.TestCase):
) )
self.mock_post.reset_mock() self.mock_post.reset_mock()
def _setup_with_filter(self): def _setup_with_filter(self, addl_filters=None):
"""Test the setup.""" """Test the setup."""
config = { config = {
"splunk": { "splunk": {
@ -128,12 +129,15 @@ class TestSplunk(unittest.TestCase):
}, },
} }
} }
if addl_filters:
config["splunk"]["filter"].update(addl_filters)
setup_component(self.hass, splunk.DOMAIN, config) setup_component(self.hass, splunk.DOMAIN, config)
@mock.patch.object(splunk, "post_request") @mock.patch.object(splunk, "post_request")
def test_splunk_entityfilter(self, mock_requests): def test_splunk_entityfilter(self, mock_requests):
"""Test event listener.""" """Test event listener."""
# pylint: disable=no-member
self._setup_with_filter() self._setup_with_filter()
testdata = [ testdata = [
@ -152,3 +156,27 @@ class TestSplunk(unittest.TestCase):
assert splunk.post_request.called assert splunk.post_request.called
splunk.post_request.reset_mock() splunk.post_request.reset_mock()
@mock.patch.object(splunk, "post_request")
def test_splunk_entityfilter_with_glob_filter(self, mock_requests):
"""Test event listener."""
# pylint: disable=no-member
self._setup_with_filter({"exclude_entity_globs": ["*.skip_*"]})
testdata = [
{"entity_id": "other_domain.other_entity", "filter_expected": False},
{"entity_id": "other_domain.excluded_entity", "filter_expected": True},
{"entity_id": "excluded_domain.other_entity", "filter_expected": True},
{"entity_id": "test.skip_me", "filter_expected": True},
]
for test in testdata:
mock_state_change_event(self.hass, State(test["entity_id"], "on"))
self.hass.block_till_done()
if test["filter_expected"]:
assert not splunk.post_request.called
else:
assert splunk.post_request.called
splunk.post_request.reset_mock()

View file

@ -1,5 +1,9 @@
"""The tests for the EntityFilter component.""" """The tests for the EntityFilter component."""
from homeassistant.helpers.entityfilter import FILTER_SCHEMA, generate_filter from homeassistant.helpers.entityfilter import (
FILTER_SCHEMA,
INCLUDE_EXCLUDE_FILTER_SCHEMA,
generate_filter,
)
def test_no_filters_case_1(): def test_no_filters_case_1():
@ -29,6 +33,27 @@ def test_includes_only_case_2():
assert testfilter("sun.sun") is False assert testfilter("sun.sun") is False
def test_includes_only_with_glob_case_2():
"""If include specified, only pass if specified (Case 2)."""
incl_dom = {"light", "sensor"}
incl_glob = {"cover.*_window"}
incl_ent = {"binary_sensor.working"}
excl_dom = {}
excl_glob = {}
excl_ent = {}
testfilter = generate_filter(
incl_dom, incl_ent, excl_dom, excl_ent, incl_glob, excl_glob
)
assert testfilter("sensor.test")
assert testfilter("light.test")
assert testfilter("cover.bedroom_window")
assert testfilter("binary_sensor.working")
assert testfilter("binary_sensor.notworking") is False
assert testfilter("sun.sun") is False
assert testfilter("cover.garage_door") is False
def test_excludes_only_case_3(): def test_excludes_only_case_3():
"""If exclude specified, pass all but specified (Case 3).""" """If exclude specified, pass all but specified (Case 3)."""
incl_dom = {} incl_dom = {}
@ -44,6 +69,27 @@ def test_excludes_only_case_3():
assert testfilter("sun.sun") is True assert testfilter("sun.sun") is True
def test_excludes_only_with_glob_case_3():
"""If exclude specified, pass all but specified (Case 3)."""
incl_dom = {}
incl_glob = {}
incl_ent = {}
excl_dom = {"light", "sensor"}
excl_glob = {"cover.*_window"}
excl_ent = {"binary_sensor.working"}
testfilter = generate_filter(
incl_dom, incl_ent, excl_dom, excl_ent, incl_glob, excl_glob
)
assert testfilter("sensor.test") is False
assert testfilter("light.test") is False
assert testfilter("cover.bedroom_window") is False
assert testfilter("binary_sensor.working") is False
assert testfilter("binary_sensor.another")
assert testfilter("sun.sun") is True
assert testfilter("cover.garage_door")
def test_with_include_domain_case4a(): def test_with_include_domain_case4a():
"""Test case 4a - include and exclude specified, with included domain.""" """Test case 4a - include and exclude specified, with included domain."""
incl_dom = {"light", "sensor"} incl_dom = {"light", "sensor"}
@ -61,6 +107,49 @@ def test_with_include_domain_case4a():
assert testfilter("sun.sun") is False assert testfilter("sun.sun") is False
def test_with_include_glob_case4a():
"""Test case 4a - include and exclude specified, with included glob."""
incl_dom = {}
incl_glob = {"light.*", "sensor.*"}
incl_ent = {"binary_sensor.working"}
excl_dom = {}
excl_glob = {}
excl_ent = {"light.ignoreme", "sensor.notworking"}
testfilter = generate_filter(
incl_dom, incl_ent, excl_dom, excl_ent, incl_glob, excl_glob
)
assert testfilter("sensor.test")
assert testfilter("sensor.notworking") is False
assert testfilter("light.test")
assert testfilter("light.ignoreme") is False
assert testfilter("binary_sensor.working")
assert testfilter("binary_sensor.another") is False
assert testfilter("sun.sun") is False
def test_with_include_domain_glob_filtering_case4a():
"""Test case 4a - include and exclude specified, both have domains and globs."""
incl_dom = {"light"}
incl_glob = {"*working"}
incl_ent = {}
excl_dom = {"binary_sensor"}
excl_glob = {"*notworking"}
excl_ent = {"light.ignoreme"}
testfilter = generate_filter(
incl_dom, incl_ent, excl_dom, excl_ent, incl_glob, excl_glob
)
assert testfilter("sensor.working")
assert testfilter("sensor.notworking") is False
assert testfilter("light.test")
assert testfilter("light.notworking") is False
assert testfilter("light.ignoreme") is False
assert testfilter("binary_sensor.not_working") is False
assert testfilter("binary_sensor.another") is False
assert testfilter("sun.sun") is False
def test_exclude_domain_case4b(): def test_exclude_domain_case4b():
"""Test case 4b - include and exclude specified, with excluded domain.""" """Test case 4b - include and exclude specified, with excluded domain."""
incl_dom = {} incl_dom = {}
@ -78,6 +167,27 @@ def test_exclude_domain_case4b():
assert testfilter("sun.sun") is True assert testfilter("sun.sun") is True
def test_exclude_glob_case4b():
"""Test case 4b - include and exclude specified, with excluded glob."""
incl_dom = {}
incl_glob = {}
incl_ent = {"binary_sensor.working"}
excl_dom = {}
excl_glob = {"binary_sensor.*"}
excl_ent = {"light.ignoreme", "sensor.notworking"}
testfilter = generate_filter(
incl_dom, incl_ent, excl_dom, excl_ent, incl_glob, excl_glob
)
assert testfilter("sensor.test")
assert testfilter("sensor.notworking") is False
assert testfilter("light.test")
assert testfilter("light.ignoreme") is False
assert testfilter("binary_sensor.working")
assert testfilter("binary_sensor.another") is False
assert testfilter("sun.sun") is True
def test_no_domain_case4c(): def test_no_domain_case4c():
"""Test case 4c - include and exclude specified, with no domains.""" """Test case 4c - include and exclude specified, with no domains."""
incl_dom = {} incl_dom = {}
@ -104,4 +214,37 @@ def test_filter_schema():
"exclude_entities": ["light.kitchen"], "exclude_entities": ["light.kitchen"],
} }
filt = FILTER_SCHEMA(conf) filt = FILTER_SCHEMA(conf)
conf.update({"include_entity_globs": [], "exclude_entity_globs": []})
assert filt.config == conf
def test_filter_schema_with_globs():
"""Test filter schema with glob options."""
conf = {
"include_domains": ["light"],
"include_entity_globs": ["sensor.kitchen_*"],
"include_entities": ["switch.kitchen"],
"exclude_domains": ["cover"],
"exclude_entity_globs": ["sensor.weather_*"],
"exclude_entities": ["light.kitchen"],
}
filt = FILTER_SCHEMA(conf)
assert filt.config == conf
def test_filter_schema_include_exclude():
"""Test the include exclude filter schema."""
conf = {
"include": {
"domains": ["light"],
"entity_globs": ["sensor.kitchen_*"],
"entities": ["switch.kitchen"],
},
"exclude": {
"domains": ["cover"],
"entity_globs": ["sensor.weather_*"],
"entities": ["light.kitchen"],
},
}
filt = INCLUDE_EXCLUDE_FILTER_SCHEMA(conf)
assert filt.config == conf assert filt.config == conf