Add types to event tracker data (#118010)

* Add types to event tracker data

* fixes

* do not test event internals in other tests

* fixes

* Update homeassistant/helpers/event.py

* cleanup

* cleanup
This commit is contained in:
J. Nick Koston 2024-05-24 04:09:39 -10:00 committed by GitHub
parent 7183260d95
commit a8fba691ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 41 additions and 62 deletions

View file

@ -54,30 +54,21 @@ from .sun import get_astral_event_next
from .template import RenderInfo, Template, result_as_boolean
from .typing import TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_change_listener"
_TRACK_STATE_CHANGE_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = HassKey(
"track_state_change_data"
)
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_added_domain_listener"
_TRACK_STATE_ADDED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
HassKey("track_state_added_domain_data")
)
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_removed_domain_listener"
)
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_entity_registry_updated_listener"
)
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_device_registry_updated_listener"
_TRACK_STATE_REMOVED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
HassKey("track_state_removed_domain_data")
)
_TRACK_ENTITY_REGISTRY_UPDATED_DATA: HassKey[
_KeyedEventData[EventEntityRegistryUpdatedData]
] = HassKey("track_entity_registry_updated_data")
_TRACK_DEVICE_REGISTRY_UPDATED_DATA: HassKey[
_KeyedEventData[EventDeviceRegistryUpdatedData]
] = HassKey("track_device_registry_updated_data")
_ALL_LISTENER = "all"
_DOMAINS_LISTENER = "domains"
@ -99,8 +90,7 @@ _TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any])
class _KeyedEventTracker(Generic[_TypedDictT]):
"""Class to track events by key."""
listeners_key: HassKey[Callable[[], None]]
callbacks_key: str
key: HassKey[_KeyedEventData[_TypedDictT]]
event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[
[
@ -120,6 +110,14 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
]
@dataclass(slots=True, frozen=True)
class _KeyedEventData(Generic[_TypedDictT]):
"""Class to track data for events by key."""
listener: CALLBACK_TYPE
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]]
@dataclass(slots=True)
class TrackStates:
"""Class for keeping track of states being tracked.
@ -354,8 +352,7 @@ def _async_state_change_filter(
_KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker(
listeners_key=TRACK_STATE_CHANGE_LISTENER,
callbacks_key=TRACK_STATE_CHANGE_CALLBACKS,
key=_TRACK_STATE_CHANGE_DATA,
event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_entity_id_event,
filter_callable=_async_state_change_filter,
@ -380,10 +377,10 @@ def _remove_empty_listener() -> None:
"""Remove a listener that does nothing."""
@callback # type: ignore[arg-type] # mypy bug?
@callback
def _remove_listener(
hass: HomeAssistant,
listeners_key: HassKey[Callable[[], None]],
tracker: _KeyedEventTracker[_TypedDictT],
keys: Iterable[str],
job: HassJob[[Event[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
@ -391,12 +388,11 @@ def _remove_listener(
"""Remove listener."""
for key in keys:
callbacks[key].remove(job)
if len(callbacks[key]) == 0:
if not callbacks[key]:
del callbacks[key]
if not callbacks:
hass.data[listeners_key]()
del hass.data[listeners_key]
hass.data.pop(tracker.key).listener()
# tracker, not hass is intentionally the first argument here since its
@ -411,26 +407,24 @@ def _async_track_event(
"""Track an event by a specific key.
This function is intended for internal use only.
The dispatcher_callable, filter_callable, event_type, and run_immediately
must always be the same for the listener_key as the first call to this
function will set the listener_key in hass.data.
"""
if not keys:
return _remove_empty_listener
hass_data = hass.data
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]] | None
if not (callbacks := hass_data.get(tracker.callbacks_key)):
callbacks = hass_data[tracker.callbacks_key] = defaultdict(list)
listeners_key = tracker.listeners_key
if tracker.listeners_key not in hass_data:
hass_data[tracker.listeners_key] = hass.bus.async_listen(
tracker_key = tracker.key
if tracker_key in hass_data:
event_data = hass_data[tracker_key]
callbacks = event_data.callbacks
else:
callbacks = defaultdict(list)
listener = hass.bus.async_listen(
tracker.event_type,
partial(tracker.dispatcher_callable, hass, callbacks),
event_filter=partial(tracker.filter_callable, hass, callbacks),
)
event_data = _KeyedEventData(listener, callbacks)
hass_data[tracker_key] = event_data
job = HassJob(action, f"track {tracker.event_type} event {keys}", job_type=job_type)
@ -441,12 +435,12 @@ def _async_track_event(
# during startup, and we want to avoid the overhead of
# creating empty lists and throwing them away.
callbacks[keys].append(job)
keys = [keys]
keys = (keys,)
else:
for key in keys:
callbacks[key].append(job)
return partial(_remove_listener, hass, listeners_key, keys, job, callbacks)
return partial(_remove_listener, hass, tracker, keys, job, callbacks)
@callback
@ -484,8 +478,7 @@ def _async_entity_registry_updated_filter(
_KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker(
listeners_key=TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
callbacks_key=TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
key=_TRACK_ENTITY_REGISTRY_UPDATED_DATA,
event_type=EVENT_ENTITY_REGISTRY_UPDATED,
dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event,
filter_callable=_async_entity_registry_updated_filter,
@ -542,8 +535,7 @@ def _async_dispatch_device_id_event(
_KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker(
listeners_key=TRACK_DEVICE_REGISTRY_UPDATED_LISTENER,
callbacks_key=TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS,
key=_TRACK_DEVICE_REGISTRY_UPDATED_DATA,
event_type=EVENT_DEVICE_REGISTRY_UPDATED,
dispatcher_callable=_async_dispatch_device_id_event,
filter_callable=_async_device_registry_updated_filter,
@ -613,8 +605,7 @@ def async_track_state_added_domain(
_KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker(
listeners_key=TRACK_STATE_ADDED_DOMAIN_LISTENER,
callbacks_key=TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
key=_TRACK_STATE_ADDED_DOMAIN_DATA,
event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_domain_event,
filter_callable=_async_domain_added_filter,
@ -651,8 +642,7 @@ def _async_domain_removed_filter(
_KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker(
listeners_key=TRACK_STATE_REMOVED_DOMAIN_LISTENER,
callbacks_key=TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
key=_TRACK_STATE_REMOVED_DOMAIN_DATA,
event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_domain_event,
filter_callable=_async_domain_removed_filter,

View file

@ -33,7 +33,6 @@ from homeassistant.const import (
)
from homeassistant.core import CoreState, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
from homeassistant.setup import async_setup_component
from . import common
@ -901,10 +900,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
"group.test_group",
]
assert hass.bus.async_listeners()["state_changed"] == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["hello.world"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
with patch(
"homeassistant.config.load_yaml_config_file",
@ -920,9 +915,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
"group.hello",
]
assert hass.bus.async_listeners()["state_changed"] == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
async def test_modify_group(hass: HomeAssistant) -> None:

View file

@ -48,7 +48,6 @@ from homeassistant.const import (
__version__ as hass_version,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
from tests.common import async_mock_service
@ -66,9 +65,7 @@ async def test_accessory_cancels_track_state_change_on_stop(
"homeassistant.components.homekit.accessories.HomeAccessory.async_update_state"
):
acc.run()
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS][entity_id]) == 1
await acc.stop()
assert entity_id not in hass.data[TRACK_STATE_CHANGE_CALLBACKS]
async def test_home_accessory(hass: HomeAssistant, hk_driver) -> None: