diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index 60aab156144..e416d939914 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Coroutine +from functools import partial import logging from typing import Any @@ -13,6 +14,14 @@ from homeassistant.util.logging import catch_log_exception _LOGGER = logging.getLogger(__name__) DATA_DISPATCHER = "dispatcher" +_DispatcherDataType = dict[ + str, + dict[ + Callable[..., Any], + HassJob[..., None | Coroutine[Any, Any, None]] | None, + ], +] + @bind_hass def dispatcher_connect( @@ -30,6 +39,26 @@ def dispatcher_connect( return remove_dispatcher +@callback +def _async_remove_dispatcher( + dispatchers: _DispatcherDataType, + signal: str, + target: Callable[..., Any], +) -> None: + """Remove signal listener.""" + try: + signal_dispatchers = dispatchers[signal] + del signal_dispatchers[target] + # Cleanup the signal dict if it is now empty + # to prevent memory leaks + if not signal_dispatchers: + del dispatchers[signal] + except (KeyError, ValueError): + # KeyError is key target listener did not exist + # ValueError if listener did not exist within signal + _LOGGER.warning("Unable to remove unknown dispatcher %s", target) + + @callback @bind_hass def async_dispatcher_connect( @@ -41,19 +70,18 @@ def async_dispatcher_connect( """ if DATA_DISPATCHER not in hass.data: hass.data[DATA_DISPATCHER] = {} - hass.data[DATA_DISPATCHER].setdefault(signal, {})[target] = None - @callback - def async_remove_dispatcher() -> None: - """Remove signal listener.""" - try: - del hass.data[DATA_DISPATCHER][signal][target] - except (KeyError, ValueError): - # KeyError is key target listener did not exist - # ValueError if listener did not exist within signal - _LOGGER.warning("Unable to remove unknown dispatcher %s", target) + dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER] - return async_remove_dispatcher + if signal not in dispatchers: + dispatchers[signal] = {} + + dispatchers[signal][target] = None + # Use a partial for the remove since it uses + # less memory than a full closure since a partial copies + # the body of the function and we don't have to store + # many different copies of the same function + return partial(_async_remove_dispatcher, dispatchers, signal, target) @bind_hass @@ -87,21 +115,14 @@ def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None: This method must be run in the event loop. """ - target_list: dict[ - Callable[..., Any], HassJob[..., None | Coroutine[Any, Any, None]] | None - ] = hass.data.get(DATA_DISPATCHER, {}).get(signal, {}) + if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None: + return + dispatchers: _DispatcherDataType = maybe_dispatchers + if (target_list := dispatchers.get(signal)) is None: + return - run: list[HassJob[..., None | Coroutine[Any, Any, None]]] = [] - for target, job in target_list.items(): + for target, job in list(target_list.items()): if job is None: job = _generate_job(signal, target) target_list[target] = job - - # Run the jobs all at the end - # to ensure no jobs add more dispatchers - # which can result in the target_list - # changing size during iteration - run.append(job) - - for job in run: hass.async_run_hass_job(job, *args) diff --git a/tests/helpers/test_dispatcher.py b/tests/helpers/test_dispatcher.py index e30aaa6e0d9..a251b20b0f4 100644 --- a/tests/helpers/test_dispatcher.py +++ b/tests/helpers/test_dispatcher.py @@ -151,3 +151,25 @@ async def test_callback_exception_gets_logged( f"Exception in functools.partial({bad_handler}) when dispatching 'test': ('bad',)" in caplog.text ) + + +async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None: + """Test adding a dispatcher from a dispatcher.""" + calls = [] + + @callback + def _new_dispatcher(data): + calls.append(data) + + @callback + def _add_new_dispatcher(data): + calls.append(data) + async_dispatcher_connect(hass, "test", _new_dispatcher) + + async_dispatcher_connect(hass, "test", _add_new_dispatcher) + + async_dispatcher_send(hass, "test", 3) + async_dispatcher_send(hass, "test", 4) + async_dispatcher_send(hass, "test", 5) + + assert calls == [3, 4, 4, 5, 5]