Add support for checking if an entity is explicitly included in an entity filter (#64463)

This commit is contained in:
J. Nick Koston 2022-01-19 20:38:48 -10:00 committed by GitHub
parent a3281f9bda
commit 2083f0b3c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 28 deletions

View file

@ -22,19 +22,54 @@ CONF_EXCLUDE_ENTITIES = "exclude_entities"
CONF_ENTITY_GLOBS = "entity_globs"
def convert_filter(config: dict[str, list[str]]) -> Callable[[str], bool]:
class EntityFilter:
"""A entity filter."""
def __init__(self, config: dict[str, list[str]]) -> None:
"""Init the filter."""
self.empty_filter: bool = sum(len(val) for val in config.values()) == 0
self.config = config
self._include_e = set(config[CONF_INCLUDE_ENTITIES])
self._exclude_e = set(config[CONF_EXCLUDE_ENTITIES])
self._include_d = set(config[CONF_INCLUDE_DOMAINS])
self._exclude_d = set(config[CONF_EXCLUDE_DOMAINS])
self._include_eg = _convert_globs_to_pattern_list(
config[CONF_INCLUDE_ENTITY_GLOBS]
)
self._exclude_eg = _convert_globs_to_pattern_list(
config[CONF_EXCLUDE_ENTITY_GLOBS]
)
self._filter: Callable[[str], bool] | None = None
def explicitly_included(self, entity_id: str) -> bool:
"""Check if an entity is explicitly included."""
return entity_id in self._include_e or _test_against_patterns(
self._include_eg, entity_id
)
def explicitly_excluded(self, entity_id: str) -> bool:
"""Check if an entity is explicitly excluded."""
return entity_id in self._exclude_e or _test_against_patterns(
self._exclude_eg, entity_id
)
def __call__(self, entity_id: str) -> bool:
"""Run the filter."""
if self._filter is None:
self._filter = _generate_filter_from_sets_and_pattern_lists(
self._include_d,
self._include_e,
self._exclude_d,
self._exclude_e,
self._include_eg,
self._exclude_eg,
)
return self._filter(entity_id)
def convert_filter(config: dict[str, list[str]]) -> EntityFilter:
"""Convert the filter schema into a filter."""
filt = generate_filter(
config[CONF_INCLUDE_DOMAINS],
config[CONF_INCLUDE_ENTITIES],
config[CONF_EXCLUDE_DOMAINS],
config[CONF_EXCLUDE_ENTITIES],
config[CONF_INCLUDE_ENTITY_GLOBS],
config[CONF_EXCLUDE_ENTITY_GLOBS],
)
setattr(filt, "config", config)
setattr(filt, "empty_filter", sum(len(val) for val in config.values()) == 0)
return filt
return EntityFilter(config)
BASE_FILTER_SCHEMA = vol.Schema(
@ -61,11 +96,11 @@ FILTER_SCHEMA = vol.All(BASE_FILTER_SCHEMA, convert_filter)
def convert_include_exclude_filter(
config: dict[str, dict[str, list[str]]]
) -> Callable[[str], bool]:
) -> EntityFilter:
"""Convert the include exclude filter schema into a filter."""
include = config[CONF_INCLUDE]
exclude = config[CONF_EXCLUDE]
filt = convert_filter(
return convert_filter(
{
CONF_INCLUDE_DOMAINS: include[CONF_DOMAINS],
CONF_INCLUDE_ENTITY_GLOBS: include[CONF_ENTITY_GLOBS],
@ -75,8 +110,6 @@ def convert_include_exclude_filter(
CONF_EXCLUDE_ENTITIES: exclude[CONF_ENTITIES],
}
)
setattr(filt, "config", config)
return filt
INCLUDE_EXCLUDE_FILTER_SCHEMA_INNER = vol.Schema(
@ -119,6 +152,11 @@ def _test_against_patterns(patterns: list[re.Pattern[str]], entity_id: str) -> b
return False
def _convert_globs_to_pattern_list(globs: list[str] | None) -> list[re.Pattern[str]]:
"""Convert a list of globs to a re pattern list."""
return list(map(_glob_to_re, set(globs or [])))
def generate_filter(
include_domains: list[str],
include_entities: list[str],
@ -128,19 +166,25 @@ def generate_filter(
exclude_entity_globs: list[str] | None = None,
) -> Callable[[str], bool]:
"""Return a function that will filter entities based on the args."""
include_d = set(include_domains)
include_e = set(include_entities)
exclude_d = set(exclude_domains)
exclude_e = set(exclude_entities)
include_eg_set = (
set(include_entity_globs) if include_entity_globs is not None else set()
return _generate_filter_from_sets_and_pattern_lists(
set(include_domains),
set(include_entities),
set(exclude_domains),
set(exclude_entities),
_convert_globs_to_pattern_list(include_entity_globs),
_convert_globs_to_pattern_list(exclude_entity_globs),
)
exclude_eg_set = (
set(exclude_entity_globs) if exclude_entity_globs is not None else set()
)
include_eg = list(map(_glob_to_re, include_eg_set))
exclude_eg = list(map(_glob_to_re, exclude_eg_set))
def _generate_filter_from_sets_and_pattern_lists(
include_d: set[str],
include_e: set[str],
exclude_d: set[str],
exclude_e: set[str],
include_eg: list[re.Pattern[str]],
exclude_eg: list[re.Pattern[str]],
) -> Callable[[str], bool]:
"""Generate a filter from pre-comuted sets and pattern lists."""
have_exclude = bool(exclude_e or exclude_d or exclude_eg)
have_include = bool(include_e or include_d or include_eg)

View file

@ -2,6 +2,7 @@
from homeassistant.helpers.entityfilter import (
FILTER_SCHEMA,
INCLUDE_EXCLUDE_FILTER_SCHEMA,
EntityFilter,
generate_filter,
)
@ -267,5 +268,38 @@ def test_filter_schema_include_exclude():
},
}
filt = INCLUDE_EXCLUDE_FILTER_SCHEMA(conf)
assert filt.config == conf
assert filt.config == {
"include_domains": ["light"],
"include_entity_globs": ["sensor.kitchen_*"],
"include_entities": ["switch.kitchen"],
"exclude_domains": ["cover"],
"exclude_entity_globs": ["sensor.weather_*"],
"exclude_entities": ["light.kitchen"],
}
assert not filt.empty_filter
def test_exlictly_included():
"""Test if an entity is explicitly included."""
conf = {
"include": {
"domains": ["light"],
"entity_globs": ["sensor.kitchen_*"],
"entities": ["switch.kitchen"],
},
"exclude": {
"domains": ["cover"],
"entity_globs": ["sensor.weather_*"],
"entities": ["light.kitchen"],
},
}
filt: EntityFilter = INCLUDE_EXCLUDE_FILTER_SCHEMA(conf)
assert not filt.explicitly_included("light.any")
assert not filt.explicitly_included("switch.other")
assert filt.explicitly_included("sensor.kitchen_4")
assert filt.explicitly_included("switch.kitchen")
assert not filt.explicitly_excluded("light.any")
assert not filt.explicitly_excluded("switch.other")
assert filt.explicitly_excluded("sensor.weather_5")
assert filt.explicitly_excluded("light.kitchen")