Add types to event tracker data (#118010)
* Add types to event tracker data * fixes * do not test event internals in other tests * fixes * Update homeassistant/helpers/event.py * cleanup * cleanup
This commit is contained in:
parent
7183260d95
commit
a8fba691ee
3 changed files with 41 additions and 62 deletions
|
@ -54,30 +54,21 @@ from .sun import get_astral_event_next
|
|||
from .template import RenderInfo, Template, result_as_boolean
|
||||
from .typing import TemplateVarsType
|
||||
|
||||
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
||||
TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_change_listener"
|
||||
_TRACK_STATE_CHANGE_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = HassKey(
|
||||
"track_state_change_data"
|
||||
)
|
||||
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_added_domain_listener"
|
||||
_TRACK_STATE_ADDED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
|
||||
HassKey("track_state_added_domain_data")
|
||||
)
|
||||
|
||||
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks"
|
||||
TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_state_removed_domain_listener"
|
||||
)
|
||||
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_entity_registry_updated_listener"
|
||||
)
|
||||
|
||||
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
|
||||
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
|
||||
"track_device_registry_updated_listener"
|
||||
_TRACK_STATE_REMOVED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
|
||||
HassKey("track_state_removed_domain_data")
|
||||
)
|
||||
_TRACK_ENTITY_REGISTRY_UPDATED_DATA: HassKey[
|
||||
_KeyedEventData[EventEntityRegistryUpdatedData]
|
||||
] = HassKey("track_entity_registry_updated_data")
|
||||
_TRACK_DEVICE_REGISTRY_UPDATED_DATA: HassKey[
|
||||
_KeyedEventData[EventDeviceRegistryUpdatedData]
|
||||
] = HassKey("track_device_registry_updated_data")
|
||||
|
||||
_ALL_LISTENER = "all"
|
||||
_DOMAINS_LISTENER = "domains"
|
||||
|
@ -99,8 +90,7 @@ _TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any])
|
|||
class _KeyedEventTracker(Generic[_TypedDictT]):
|
||||
"""Class to track events by key."""
|
||||
|
||||
listeners_key: HassKey[Callable[[], None]]
|
||||
callbacks_key: str
|
||||
key: HassKey[_KeyedEventData[_TypedDictT]]
|
||||
event_type: EventType[_TypedDictT] | str
|
||||
dispatcher_callable: Callable[
|
||||
[
|
||||
|
@ -120,6 +110,14 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
|
|||
]
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class _KeyedEventData(Generic[_TypedDictT]):
|
||||
"""Class to track data for events by key."""
|
||||
|
||||
listener: CALLBACK_TYPE
|
||||
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TrackStates:
|
||||
"""Class for keeping track of states being tracked.
|
||||
|
@ -354,8 +352,7 @@ def _async_state_change_filter(
|
|||
|
||||
|
||||
_KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker(
|
||||
listeners_key=TRACK_STATE_CHANGE_LISTENER,
|
||||
callbacks_key=TRACK_STATE_CHANGE_CALLBACKS,
|
||||
key=_TRACK_STATE_CHANGE_DATA,
|
||||
event_type=EVENT_STATE_CHANGED,
|
||||
dispatcher_callable=_async_dispatch_entity_id_event,
|
||||
filter_callable=_async_state_change_filter,
|
||||
|
@ -380,10 +377,10 @@ def _remove_empty_listener() -> None:
|
|||
"""Remove a listener that does nothing."""
|
||||
|
||||
|
||||
@callback # type: ignore[arg-type] # mypy bug?
|
||||
@callback
|
||||
def _remove_listener(
|
||||
hass: HomeAssistant,
|
||||
listeners_key: HassKey[Callable[[], None]],
|
||||
tracker: _KeyedEventTracker[_TypedDictT],
|
||||
keys: Iterable[str],
|
||||
job: HassJob[[Event[_TypedDictT]], Any],
|
||||
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
|
||||
|
@ -391,12 +388,11 @@ def _remove_listener(
|
|||
"""Remove listener."""
|
||||
for key in keys:
|
||||
callbacks[key].remove(job)
|
||||
if len(callbacks[key]) == 0:
|
||||
if not callbacks[key]:
|
||||
del callbacks[key]
|
||||
|
||||
if not callbacks:
|
||||
hass.data[listeners_key]()
|
||||
del hass.data[listeners_key]
|
||||
hass.data.pop(tracker.key).listener()
|
||||
|
||||
|
||||
# tracker, not hass is intentionally the first argument here since its
|
||||
|
@ -411,26 +407,24 @@ def _async_track_event(
|
|||
"""Track an event by a specific key.
|
||||
|
||||
This function is intended for internal use only.
|
||||
|
||||
The dispatcher_callable, filter_callable, event_type, and run_immediately
|
||||
must always be the same for the listener_key as the first call to this
|
||||
function will set the listener_key in hass.data.
|
||||
"""
|
||||
if not keys:
|
||||
return _remove_empty_listener
|
||||
|
||||
hass_data = hass.data
|
||||
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]] | None
|
||||
if not (callbacks := hass_data.get(tracker.callbacks_key)):
|
||||
callbacks = hass_data[tracker.callbacks_key] = defaultdict(list)
|
||||
|
||||
listeners_key = tracker.listeners_key
|
||||
if tracker.listeners_key not in hass_data:
|
||||
hass_data[tracker.listeners_key] = hass.bus.async_listen(
|
||||
tracker_key = tracker.key
|
||||
if tracker_key in hass_data:
|
||||
event_data = hass_data[tracker_key]
|
||||
callbacks = event_data.callbacks
|
||||
else:
|
||||
callbacks = defaultdict(list)
|
||||
listener = hass.bus.async_listen(
|
||||
tracker.event_type,
|
||||
partial(tracker.dispatcher_callable, hass, callbacks),
|
||||
event_filter=partial(tracker.filter_callable, hass, callbacks),
|
||||
)
|
||||
event_data = _KeyedEventData(listener, callbacks)
|
||||
hass_data[tracker_key] = event_data
|
||||
|
||||
job = HassJob(action, f"track {tracker.event_type} event {keys}", job_type=job_type)
|
||||
|
||||
|
@ -441,12 +435,12 @@ def _async_track_event(
|
|||
# during startup, and we want to avoid the overhead of
|
||||
# creating empty lists and throwing them away.
|
||||
callbacks[keys].append(job)
|
||||
keys = [keys]
|
||||
keys = (keys,)
|
||||
else:
|
||||
for key in keys:
|
||||
callbacks[key].append(job)
|
||||
|
||||
return partial(_remove_listener, hass, listeners_key, keys, job, callbacks)
|
||||
return partial(_remove_listener, hass, tracker, keys, job, callbacks)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -484,8 +478,7 @@ def _async_entity_registry_updated_filter(
|
|||
|
||||
|
||||
_KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker(
|
||||
listeners_key=TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
|
||||
callbacks_key=TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
|
||||
key=_TRACK_ENTITY_REGISTRY_UPDATED_DATA,
|
||||
event_type=EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event,
|
||||
filter_callable=_async_entity_registry_updated_filter,
|
||||
|
@ -542,8 +535,7 @@ def _async_dispatch_device_id_event(
|
|||
|
||||
|
||||
_KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker(
|
||||
listeners_key=TRACK_DEVICE_REGISTRY_UPDATED_LISTENER,
|
||||
callbacks_key=TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS,
|
||||
key=_TRACK_DEVICE_REGISTRY_UPDATED_DATA,
|
||||
event_type=EVENT_DEVICE_REGISTRY_UPDATED,
|
||||
dispatcher_callable=_async_dispatch_device_id_event,
|
||||
filter_callable=_async_device_registry_updated_filter,
|
||||
|
@ -613,8 +605,7 @@ def async_track_state_added_domain(
|
|||
|
||||
|
||||
_KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker(
|
||||
listeners_key=TRACK_STATE_ADDED_DOMAIN_LISTENER,
|
||||
callbacks_key=TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
|
||||
key=_TRACK_STATE_ADDED_DOMAIN_DATA,
|
||||
event_type=EVENT_STATE_CHANGED,
|
||||
dispatcher_callable=_async_dispatch_domain_event,
|
||||
filter_callable=_async_domain_added_filter,
|
||||
|
@ -651,8 +642,7 @@ def _async_domain_removed_filter(
|
|||
|
||||
|
||||
_KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker(
|
||||
listeners_key=TRACK_STATE_REMOVED_DOMAIN_LISTENER,
|
||||
callbacks_key=TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
|
||||
key=_TRACK_STATE_REMOVED_DOMAIN_DATA,
|
||||
event_type=EVENT_STATE_CHANGED,
|
||||
dispatcher_callable=_async_dispatch_domain_event,
|
||||
filter_callable=_async_domain_removed_filter,
|
||||
|
|
|
@ -33,7 +33,6 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import CoreState, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import common
|
||||
|
@ -901,10 +900,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
|
|||
"group.test_group",
|
||||
]
|
||||
assert hass.bus.async_listeners()["state_changed"] == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["hello.world"]) == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
|
@ -920,9 +915,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
|
|||
"group.hello",
|
||||
]
|
||||
assert hass.bus.async_listeners()["state_changed"] == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
|
||||
|
||||
|
||||
async def test_modify_group(hass: HomeAssistant) -> None:
|
||||
|
|
|
@ -48,7 +48,6 @@ from homeassistant.const import (
|
|||
__version__ as hass_version,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
|
@ -66,9 +65,7 @@ async def test_accessory_cancels_track_state_change_on_stop(
|
|||
"homeassistant.components.homekit.accessories.HomeAccessory.async_update_state"
|
||||
):
|
||||
acc.run()
|
||||
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS][entity_id]) == 1
|
||||
await acc.stop()
|
||||
assert entity_id not in hass.data[TRACK_STATE_CHANGE_CALLBACKS]
|
||||
|
||||
|
||||
async def test_home_accessory(hass: HomeAssistant, hk_driver) -> None:
|
||||
|
|
Loading…
Add table
Reference in a new issue