From a8fba691ee6a9eb8fc172bf4f3bc42e4ed4d8bff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 24 May 2024 04:09:39 -1000 Subject: [PATCH] Add types to event tracker data (#118010) * Add types to event tracker data * fixes * do not test event internals in other tests * fixes * Update homeassistant/helpers/event.py * cleanup * cleanup --- homeassistant/helpers/event.py | 92 +++++++++----------- tests/components/group/test_init.py | 8 -- tests/components/homekit/test_accessories.py | 3 - 3 files changed, 41 insertions(+), 62 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index fd97afbcaaf..4150d871b6b 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -54,30 +54,21 @@ from .sun import get_astral_event_next from .template import RenderInfo, Template, result_as_boolean from .typing import TemplateVarsType -TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" -TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey( - "track_state_change_listener" +_TRACK_STATE_CHANGE_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = HassKey( + "track_state_change_data" ) - -TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" -TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey( - "track_state_added_domain_listener" +_TRACK_STATE_ADDED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = ( + HassKey("track_state_added_domain_data") ) - -TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks" -TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey( - "track_state_removed_domain_listener" -) - -TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" -TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey( - "track_entity_registry_updated_listener" -) - -TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks" -TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey( - "track_device_registry_updated_listener" +_TRACK_STATE_REMOVED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = ( + HassKey("track_state_removed_domain_data") ) +_TRACK_ENTITY_REGISTRY_UPDATED_DATA: HassKey[ + _KeyedEventData[EventEntityRegistryUpdatedData] +] = HassKey("track_entity_registry_updated_data") +_TRACK_DEVICE_REGISTRY_UPDATED_DATA: HassKey[ + _KeyedEventData[EventDeviceRegistryUpdatedData] +] = HassKey("track_device_registry_updated_data") _ALL_LISTENER = "all" _DOMAINS_LISTENER = "domains" @@ -99,8 +90,7 @@ _TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any]) class _KeyedEventTracker(Generic[_TypedDictT]): """Class to track events by key.""" - listeners_key: HassKey[Callable[[], None]] - callbacks_key: str + key: HassKey[_KeyedEventData[_TypedDictT]] event_type: EventType[_TypedDictT] | str dispatcher_callable: Callable[ [ @@ -120,6 +110,14 @@ class _KeyedEventTracker(Generic[_TypedDictT]): ] +@dataclass(slots=True, frozen=True) +class _KeyedEventData(Generic[_TypedDictT]): + """Class to track data for events by key.""" + + listener: CALLBACK_TYPE + callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]] + + @dataclass(slots=True) class TrackStates: """Class for keeping track of states being tracked. @@ -354,8 +352,7 @@ def _async_state_change_filter( _KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker( - listeners_key=TRACK_STATE_CHANGE_LISTENER, - callbacks_key=TRACK_STATE_CHANGE_CALLBACKS, + key=_TRACK_STATE_CHANGE_DATA, event_type=EVENT_STATE_CHANGED, dispatcher_callable=_async_dispatch_entity_id_event, filter_callable=_async_state_change_filter, @@ -380,10 +377,10 @@ def _remove_empty_listener() -> None: """Remove a listener that does nothing.""" -@callback # type: ignore[arg-type] # mypy bug? +@callback def _remove_listener( hass: HomeAssistant, - listeners_key: HassKey[Callable[[], None]], + tracker: _KeyedEventTracker[_TypedDictT], keys: Iterable[str], job: HassJob[[Event[_TypedDictT]], Any], callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]], @@ -391,12 +388,11 @@ def _remove_listener( """Remove listener.""" for key in keys: callbacks[key].remove(job) - if len(callbacks[key]) == 0: + if not callbacks[key]: del callbacks[key] if not callbacks: - hass.data[listeners_key]() - del hass.data[listeners_key] + hass.data.pop(tracker.key).listener() # tracker, not hass is intentionally the first argument here since its @@ -411,26 +407,24 @@ def _async_track_event( """Track an event by a specific key. This function is intended for internal use only. - - The dispatcher_callable, filter_callable, event_type, and run_immediately - must always be the same for the listener_key as the first call to this - function will set the listener_key in hass.data. """ if not keys: return _remove_empty_listener hass_data = hass.data - callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]] | None - if not (callbacks := hass_data.get(tracker.callbacks_key)): - callbacks = hass_data[tracker.callbacks_key] = defaultdict(list) - - listeners_key = tracker.listeners_key - if tracker.listeners_key not in hass_data: - hass_data[tracker.listeners_key] = hass.bus.async_listen( + tracker_key = tracker.key + if tracker_key in hass_data: + event_data = hass_data[tracker_key] + callbacks = event_data.callbacks + else: + callbacks = defaultdict(list) + listener = hass.bus.async_listen( tracker.event_type, partial(tracker.dispatcher_callable, hass, callbacks), event_filter=partial(tracker.filter_callable, hass, callbacks), ) + event_data = _KeyedEventData(listener, callbacks) + hass_data[tracker_key] = event_data job = HassJob(action, f"track {tracker.event_type} event {keys}", job_type=job_type) @@ -441,12 +435,12 @@ def _async_track_event( # during startup, and we want to avoid the overhead of # creating empty lists and throwing them away. callbacks[keys].append(job) - keys = [keys] + keys = (keys,) else: for key in keys: callbacks[key].append(job) - return partial(_remove_listener, hass, listeners_key, keys, job, callbacks) + return partial(_remove_listener, hass, tracker, keys, job, callbacks) @callback @@ -484,8 +478,7 @@ def _async_entity_registry_updated_filter( _KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker( - listeners_key=TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, - callbacks_key=TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, + key=_TRACK_ENTITY_REGISTRY_UPDATED_DATA, event_type=EVENT_ENTITY_REGISTRY_UPDATED, dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event, filter_callable=_async_entity_registry_updated_filter, @@ -542,8 +535,7 @@ def _async_dispatch_device_id_event( _KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker( - listeners_key=TRACK_DEVICE_REGISTRY_UPDATED_LISTENER, - callbacks_key=TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS, + key=_TRACK_DEVICE_REGISTRY_UPDATED_DATA, event_type=EVENT_DEVICE_REGISTRY_UPDATED, dispatcher_callable=_async_dispatch_device_id_event, filter_callable=_async_device_registry_updated_filter, @@ -613,8 +605,7 @@ def async_track_state_added_domain( _KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker( - listeners_key=TRACK_STATE_ADDED_DOMAIN_LISTENER, - callbacks_key=TRACK_STATE_ADDED_DOMAIN_CALLBACKS, + key=_TRACK_STATE_ADDED_DOMAIN_DATA, event_type=EVENT_STATE_CHANGED, dispatcher_callable=_async_dispatch_domain_event, filter_callable=_async_domain_added_filter, @@ -651,8 +642,7 @@ def _async_domain_removed_filter( _KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker( - listeners_key=TRACK_STATE_REMOVED_DOMAIN_LISTENER, - callbacks_key=TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, + key=_TRACK_STATE_REMOVED_DOMAIN_DATA, event_type=EVENT_STATE_CHANGED, dispatcher_callable=_async_dispatch_domain_event, filter_callable=_async_domain_removed_filter, diff --git a/tests/components/group/test_init.py b/tests/components/group/test_init.py index d83f8be6993..4f928e0a8c2 100644 --- a/tests/components/group/test_init.py +++ b/tests/components/group/test_init.py @@ -33,7 +33,6 @@ from homeassistant.const import ( ) from homeassistant.core import CoreState, HomeAssistant from homeassistant.helpers import entity_registry as er -from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS from homeassistant.setup import async_setup_component from . import common @@ -901,10 +900,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None: "group.test_group", ] assert hass.bus.async_listeners()["state_changed"] == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["hello.world"]) == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1 with patch( "homeassistant.config.load_yaml_config_file", @@ -920,9 +915,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None: "group.hello", ] assert hass.bus.async_listeners()["state_changed"] == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1 - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1 async def test_modify_group(hass: HomeAssistant) -> None: diff --git a/tests/components/homekit/test_accessories.py b/tests/components/homekit/test_accessories.py index 11a2675382a..32cd6622492 100644 --- a/tests/components/homekit/test_accessories.py +++ b/tests/components/homekit/test_accessories.py @@ -48,7 +48,6 @@ from homeassistant.const import ( __version__ as hass_version, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS from tests.common import async_mock_service @@ -66,9 +65,7 @@ async def test_accessory_cancels_track_state_change_on_stop( "homeassistant.components.homekit.accessories.HomeAccessory.async_update_state" ): acc.run() - assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS][entity_id]) == 1 await acc.stop() - assert entity_id not in hass.data[TRACK_STATE_CHANGE_CALLBACKS] async def test_home_accessory(hass: HomeAssistant, hk_driver) -> None: