From c9df42b69a0e9e07dc9bac70ffb63386619f01a4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Feb 2021 09:42:55 -1000 Subject: [PATCH] Add support for pre-filtering events to the event bus (#46371) --- homeassistant/components/recorder/__init__.py | 24 +++-- homeassistant/config_entries.py | 18 ++-- homeassistant/core.py | 59 ++++++++---- homeassistant/helpers/device_registry.py | 19 ++-- homeassistant/helpers/entity_registry.py | 14 ++- homeassistant/helpers/event.py | 61 +++++++++++-- homeassistant/scripts/benchmark/__init__.py | 89 ++++++++++++++++--- .../mqtt/test_device_tracker_discovery.py | 1 + tests/components/mqtt/test_discovery.py | 1 + tests/components/unifi/test_device_tracker.py | 1 + tests/test_core.py | 29 ++++++ 11 files changed, 256 insertions(+), 60 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 0f8a5ae7f8f..16232bcaa16 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -252,7 +252,22 @@ class Recorder(threading.Thread): @callback def async_initialize(self): """Initialize the recorder.""" - self.hass.bus.async_listen(MATCH_ALL, self.event_listener) + self.hass.bus.async_listen( + MATCH_ALL, self.event_listener, event_filter=self._async_event_filter + ) + + @callback + def _async_event_filter(self, event): + """Filter events.""" + if event.event_type in self.exclude_t: + return False + + entity_id = event.data.get(ATTR_ENTITY_ID) + if entity_id is not None: + if not self.entity_filter(entity_id): + return False + + return True def do_adhoc_purge(self, **kwargs): """Trigger an adhoc purge retaining keep_days worth of data.""" @@ -378,13 +393,6 @@ class Recorder(threading.Thread): self._timechanges_seen = 0 self._commit_event_session_or_retry() continue - if event.event_type in self.exclude_t: - continue - - entity_id = event.data.get(ATTR_ENTITY_ID) - if entity_id is not None: - if not self.entity_filter(entity_id): - continue try: if event.event_type == EVENT_STATE_CHANGED: diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index bbc1479524a..7225b7c375d 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1139,17 +1139,13 @@ class EntityRegistryDisabledHandler: def async_setup(self) -> None: """Set up the disable handler.""" self.hass.bus.async_listen( - entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + self._handle_entry_updated, + event_filter=_handle_entry_updated_filter, ) async def _handle_entry_updated(self, event: Event) -> None: """Handle entity registry entry update.""" - if ( - event.data["action"] != "update" - or "disabled_by" not in event.data["changes"] - ): - return - if self.registry is None: self.registry = await entity_registry.async_get_registry(self.hass) @@ -1203,6 +1199,14 @@ class EntityRegistryDisabledHandler: ) +@callback +def _handle_entry_updated_filter(event: Event) -> bool: + """Handle entity registry entry update filter.""" + if event.data["action"] != "update" or "disabled_by" not in event.data["changes"]: + return False + return True + + async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool: """Test if a domain supports entry unloading.""" integration = await loader.async_get_integration(hass, domain) diff --git a/homeassistant/core.py b/homeassistant/core.py index fff16cdd31f..b62dd1ee7d5 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -28,6 +28,7 @@ from typing import ( Mapping, Optional, Set, + Tuple, TypeVar, Union, cast, @@ -661,7 +662,7 @@ class EventBus: def __init__(self, hass: HomeAssistant) -> None: """Initialize a new event bus.""" - self._listeners: Dict[str, List[HassJob]] = {} + self._listeners: Dict[str, List[Tuple[HassJob, Optional[Callable]]]] = {} self._hass = hass @callback @@ -717,7 +718,14 @@ class EventBus: if not listeners: return - for job in listeners: + for job, event_filter in listeners: + if event_filter is not None: + try: + if not event_filter(event): + continue + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error in event filter") + continue self._hass.async_add_hass_job(job, event) def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: @@ -737,23 +745,38 @@ class EventBus: return remove_listener @callback - def async_listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: + def async_listen( + self, + event_type: str, + listener: Callable, + event_filter: Optional[Callable] = None, + ) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. To listen to all events specify the constant ``MATCH_ALL`` as event_type. + An optional event_filter, which must be a callable decorated with + @callback that returns a boolean value, determines if the + listener callable should run. + This method must be run in the event loop. """ - return self._async_listen_job(event_type, HassJob(listener)) + if event_filter is not None and not is_callback(event_filter): + raise HomeAssistantError(f"Event filter {event_filter} is not a callback") + return self._async_listen_filterable_job( + event_type, (HassJob(listener), event_filter) + ) @callback - def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE: - self._listeners.setdefault(event_type, []).append(hassjob) + def _async_listen_filterable_job( + self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]] + ) -> CALLBACK_TYPE: + self._listeners.setdefault(event_type, []).append(filterable_job) def remove_listener() -> None: """Remove the listener.""" - self._async_remove_listener(event_type, hassjob) + self._async_remove_listener(event_type, filterable_job) return remove_listener @@ -786,12 +809,12 @@ class EventBus: This method must be run in the event loop. """ - job: Optional[HassJob] = None + filterable_job: Optional[Tuple[HassJob, Optional[Callable]]] = None @callback def _onetime_listener(event: Event) -> None: """Remove listener from event bus and then fire listener.""" - nonlocal job + nonlocal filterable_job if hasattr(_onetime_listener, "run"): return # Set variable so that we will never run twice. @@ -800,22 +823,24 @@ class EventBus: # multiple times as well. # This will make sure the second time it does nothing. setattr(_onetime_listener, "run", True) - assert job is not None - self._async_remove_listener(event_type, job) + assert filterable_job is not None + self._async_remove_listener(event_type, filterable_job) self._hass.async_run_job(listener, event) - job = HassJob(_onetime_listener) + filterable_job = (HassJob(_onetime_listener), None) - return self._async_listen_job(event_type, job) + return self._async_listen_filterable_job(event_type, filterable_job) @callback - def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None: + def _async_remove_listener( + self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]] + ) -> None: """Remove a listener of a specific event_type. This method must be run in the event loop. """ try: - self._listeners[event_type].remove(hassjob) + self._listeners[event_type].remove(filterable_job) # delete event_type list if empty if not self._listeners[event_type]: @@ -823,7 +848,9 @@ class EventBus: except (KeyError, ValueError): # KeyError is key event_type listener did not exist # ValueError if listener did not exist within event_type - _LOGGER.exception("Unable to remove unknown job listener %s", hassjob) + _LOGGER.exception( + "Unable to remove unknown job listener %s", filterable_job + ) class State: diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 0d62b2cab47..77dc2cdf609 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -686,25 +686,34 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non ) async def entity_registry_changed(event: Event) -> None: - """Handle entity updated or removed.""" + """Handle entity updated or removed dispatch.""" + await debounced_cleanup.async_call() + + @callback + def entity_registry_changed_filter(event: Event) -> bool: + """Handle entity updated or removed filter.""" if ( event.data["action"] == "update" and "device_id" not in event.data["changes"] ) or event.data["action"] == "create": - return + return False - await debounced_cleanup.async_call() + return True if hass.is_running: hass.bus.async_listen( - entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + entity_registry_changed, + event_filter=entity_registry_changed_filter, ) return async def startup_clean(event: Event) -> None: """Clean up on startup.""" hass.bus.async_listen( - entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + entity_registry_changed, + event_filter=entity_registry_changed_filter, ) await debounced_cleanup.async_call() diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 418c3f90304..0938ea9165f 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -641,12 +641,14 @@ def async_setup_entity_restore( ) -> None: """Set up the entity restore mechanism.""" + @callback + def cleanup_restored_states_filter(event: Event) -> bool: + """Clean up restored states filter.""" + return bool(event.data["action"] == "remove") + @callback def cleanup_restored_states(event: Event) -> None: """Clean up restored states.""" - if event.data["action"] != "remove": - return - state = hass.states.get(event.data["entity_id"]) if state is None or not state.attributes.get(ATTR_RESTORED): @@ -654,7 +656,11 @@ def async_setup_entity_restore( hass.states.async_remove(event.data["entity_id"], context=event.context) - hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states) + hass.bus.async_listen( + EVENT_ENTITY_REGISTRY_UPDATED, + cleanup_restored_states, + event_filter=cleanup_restored_states_filter, + ) if hass.is_running: return diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 102a84863bd..f496c7088a4 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -180,7 +180,7 @@ def async_track_state_change( job = HassJob(action) @callback - def state_change_listener(event: Event) -> None: + def state_change_filter(event: Event) -> bool: """Handle specific state changes.""" if from_state is not None: old_state = event.data.get("old_state") @@ -188,15 +188,21 @@ def async_track_state_change( old_state = old_state.state if not match_from_state(old_state): - return + return False + if to_state is not None: new_state = event.data.get("new_state") if new_state is not None: new_state = new_state.state if not match_to_state(new_state): - return + return False + return True + + @callback + def state_change_dispatcher(event: Event) -> None: + """Handle specific state changes.""" hass.async_run_hass_job( job, event.data.get("entity_id"), @@ -204,6 +210,14 @@ def async_track_state_change( event.data.get("new_state"), ) + @callback + def state_change_listener(event: Event) -> None: + """Handle specific state changes.""" + if not state_change_filter(event): + return + + state_change_dispatcher(event) + if entity_ids != MATCH_ALL: # If we have a list of entity ids we use # async_track_state_change_event to route @@ -215,7 +229,9 @@ def async_track_state_change( # entity_id. return async_track_state_change_event(hass, entity_ids, state_change_listener) - return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) + return hass.bus.async_listen( + EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter + ) track_state_change = threaded_listener_factory(async_track_state_change) @@ -246,6 +262,11 @@ def async_track_state_change_event( if TRACK_STATE_CHANGE_LISTENER not in hass.data: + @callback + def _async_state_change_filter(event: Event) -> bool: + """Filter state changes by entity_id.""" + return event.data.get("entity_id") in entity_callbacks + @callback def _async_state_change_dispatcher(event: Event) -> None: """Dispatch state changes by entity_id.""" @@ -263,7 +284,9 @@ def async_track_state_change_event( ) hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, _async_state_change_dispatcher + EVENT_STATE_CHANGED, + _async_state_change_dispatcher, + event_filter=_async_state_change_filter, ) job = HassJob(action) @@ -329,6 +352,12 @@ def async_track_entity_registry_updated_event( if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: + @callback + def _async_entity_registry_updated_filter(event: Event) -> bool: + """Filter entity registry updates by entity_id.""" + entity_id = event.data.get("old_entity_id", event.data["entity_id"]) + return entity_id in entity_callbacks + @callback def _async_entity_registry_updated_dispatcher(event: Event) -> None: """Dispatch entity registry updates by entity_id.""" @@ -347,7 +376,9 @@ def async_track_entity_registry_updated_event( ) hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen( - EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher + EVENT_ENTITY_REGISTRY_UPDATED, + _async_entity_registry_updated_dispatcher, + event_filter=_async_entity_registry_updated_filter, ) job = HassJob(action) @@ -404,6 +435,11 @@ def async_track_state_added_domain( if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: + @callback + def _async_state_change_filter(event: Event) -> bool: + """Filter state changes by entity_id.""" + return event.data.get("old_state") is None + @callback def _async_state_change_dispatcher(event: Event) -> None: """Dispatch state changes by entity_id.""" @@ -413,7 +449,9 @@ def async_track_state_added_domain( _async_dispatch_domain_event(hass, event, domain_callbacks) hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, _async_state_change_dispatcher + EVENT_STATE_CHANGED, + _async_state_change_dispatcher, + event_filter=_async_state_change_filter, ) job = HassJob(action) @@ -450,6 +488,11 @@ def async_track_state_removed_domain( if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: + @callback + def _async_state_change_filter(event: Event) -> bool: + """Filter state changes by entity_id.""" + return event.data.get("new_state") is None + @callback def _async_state_change_dispatcher(event: Event) -> None: """Dispatch state changes by entity_id.""" @@ -459,7 +502,9 @@ def async_track_state_removed_domain( _async_dispatch_domain_event(hass, event, domain_callbacks) hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, _async_state_change_dispatcher + EVENT_STATE_CHANGED, + _async_state_change_dispatcher, + event_filter=_async_state_change_filter, ) job = HassJob(action) diff --git a/homeassistant/scripts/benchmark/__init__.py b/homeassistant/scripts/benchmark/__init__.py index 48e6d7d5302..3f590362504 100644 --- a/homeassistant/scripts/benchmark/__init__.py +++ b/homeassistant/scripts/benchmark/__init__.py @@ -62,7 +62,7 @@ async def fire_events(hass): """Fire a million events.""" count = 0 event_name = "benchmark_event" - event = asyncio.Event() + events_to_fire = 10 ** 6 @core.callback def listener(_): @@ -70,17 +70,48 @@ async def fire_events(hass): nonlocal count count += 1 - if count == 10 ** 6: - event.set() - hass.bus.async_listen(event_name, listener) - for _ in range(10 ** 6): + for _ in range(events_to_fire): hass.bus.async_fire(event_name) start = timer() - await event.wait() + await hass.async_block_till_done() + + assert count == events_to_fire + + return timer() - start + + +@benchmark +async def fire_events_with_filter(hass): + """Fire a million events with a filter that rejects them.""" + count = 0 + event_name = "benchmark_event" + events_to_fire = 10 ** 6 + + @core.callback + def event_filter(event): + """Filter event.""" + return False + + @core.callback + def listener(_): + """Handle event.""" + nonlocal count + count += 1 + + hass.bus.async_listen(event_name, listener, event_filter=event_filter) + + for _ in range(events_to_fire): + hass.bus.async_fire(event_name) + + start = timer() + + await hass.async_block_till_done() + + assert count == 0 return timer() - start @@ -154,7 +185,7 @@ async def state_changed_event_helper(hass): """Run a million events through state changed event helper with 1000 entities.""" count = 0 entity_id = "light.kitchen" - event = asyncio.Event() + events_to_fire = 10 ** 6 @core.callback def listener(*args): @@ -162,9 +193,6 @@ async def state_changed_event_helper(hass): nonlocal count count += 1 - if count == 10 ** 6: - event.set() - hass.helpers.event.async_track_state_change_event( [f"{entity_id}{idx}" for idx in range(1000)], listener ) @@ -175,12 +203,49 @@ async def state_changed_event_helper(hass): "new_state": core.State(entity_id, "on"), } - for _ in range(10 ** 6): + for _ in range(events_to_fire): hass.bus.async_fire(EVENT_STATE_CHANGED, event_data) start = timer() - await event.wait() + await hass.async_block_till_done() + + assert count == events_to_fire + + return timer() - start + + +@benchmark +async def state_changed_event_filter_helper(hass): + """Run a million events through state changed event helper with 1000 entities that all get filtered.""" + count = 0 + entity_id = "light.kitchen" + events_to_fire = 10 ** 6 + + @core.callback + def listener(*args): + """Handle event.""" + nonlocal count + count += 1 + + hass.helpers.event.async_track_state_change_event( + [f"{entity_id}{idx}" for idx in range(1000)], listener + ) + + event_data = { + "entity_id": "switch.no_listeners", + "old_state": core.State(entity_id, "off"), + "new_state": core.State(entity_id, "on"), + } + + for _ in range(events_to_fire): + hass.bus.async_fire(EVENT_STATE_CHANGED, event_data) + + start = timer() + + await hass.async_block_till_done() + + assert count == 0 return timer() - start diff --git a/tests/components/mqtt/test_device_tracker_discovery.py b/tests/components/mqtt/test_device_tracker_discovery.py index 2c445ee0fa5..f158a878fcd 100644 --- a/tests/components/mqtt/test_device_tracker_discovery.py +++ b/tests/components/mqtt/test_device_tracker_discovery.py @@ -194,6 +194,7 @@ async def test_cleanup_device_tracker(hass, device_reg, entity_reg, mqtt_mock): device_reg.async_remove_device(device_entry.id) await hass.async_block_till_done() + await hass.async_block_till_done() # Verify device and registry entries are cleared device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index c9b0879d490..fed0dfa54d6 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -411,6 +411,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): device_reg.async_remove_device(device_entry.id) await hass.async_block_till_done() + await hass.async_block_till_done() # Verify device and registry entries are cleared device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 39465a34aef..e8081a831c2 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -353,6 +353,7 @@ async def test_remove_clients(hass, aioclient_mock): } controller.api.session_handler(SIGNAL_DATA) await hass.async_block_till_done() + await hass.async_block_till_done() assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 1 diff --git a/tests/test_core.py b/tests/test_core.py index dfd5b925e1c..88b4e1d58f6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -379,6 +379,35 @@ async def test_eventbus_add_remove_listener(hass): unsub() +async def test_eventbus_filtered_listener(hass): + """Test we can prefilter events.""" + calls = [] + + @ha.callback + def listener(event): + """Mock listener.""" + calls.append(event) + + @ha.callback + def filter(event): + """Mock filter.""" + return not event.data["filtered"] + + unsub = hass.bus.async_listen("test", listener, event_filter=filter) + + hass.bus.async_fire("test", {"filtered": True}) + await hass.async_block_till_done() + + assert len(calls) == 0 + + hass.bus.async_fire("test", {"filtered": False}) + await hass.async_block_till_done() + + assert len(calls) == 1 + + unsub() + + async def test_eventbus_unsubscribe_listener(hass): """Test unsubscribe listener from returned function.""" calls = []