Refactor dispatcher to reduce run time and memory overhead (#99676)

* Fix memory leak in dispatcher removal

When we removed the last job/callable from the dict for the
signal we did not remove the dict for the signal which meant
it leaked

* comment

* cleanup a bit more
This commit is contained in:
J. Nick Koston 2023-09-05 20:18:27 -05:00 committed by GitHub
parent b69cc29a78
commit a2dae60170
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 24 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from functools import partial
import logging
from typing import Any
@ -13,6 +14,14 @@ from homeassistant.util.logging import catch_log_exception
_LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = "dispatcher"
_DispatcherDataType = dict[
str,
dict[
Callable[..., Any],
HassJob[..., None | Coroutine[Any, Any, None]] | None,
],
]
@bind_hass
def dispatcher_connect(
@ -30,6 +39,26 @@ def dispatcher_connect(
return remove_dispatcher
@callback
def _async_remove_dispatcher(
dispatchers: _DispatcherDataType,
signal: str,
target: Callable[..., Any],
) -> None:
"""Remove signal listener."""
try:
signal_dispatchers = dispatchers[signal]
del signal_dispatchers[target]
# Cleanup the signal dict if it is now empty
# to prevent memory leaks
if not signal_dispatchers:
del dispatchers[signal]
except (KeyError, ValueError):
# KeyError is key target listener did not exist
# ValueError if listener did not exist within signal
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
@callback
@bind_hass
def async_dispatcher_connect(
@ -41,19 +70,18 @@ def async_dispatcher_connect(
"""
if DATA_DISPATCHER not in hass.data:
hass.data[DATA_DISPATCHER] = {}
hass.data[DATA_DISPATCHER].setdefault(signal, {})[target] = None
@callback
def async_remove_dispatcher() -> None:
"""Remove signal listener."""
try:
del hass.data[DATA_DISPATCHER][signal][target]
except (KeyError, ValueError):
# KeyError is key target listener did not exist
# ValueError if listener did not exist within signal
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER]
return async_remove_dispatcher
if signal not in dispatchers:
dispatchers[signal] = {}
dispatchers[signal][target] = None
# Use a partial for the remove since it uses
# less memory than a full closure since a partial copies
# the body of the function and we don't have to store
# many different copies of the same function
return partial(_async_remove_dispatcher, dispatchers, signal, target)
@bind_hass
@ -87,21 +115,14 @@ def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
This method must be run in the event loop.
"""
target_list: dict[
Callable[..., Any], HassJob[..., None | Coroutine[Any, Any, None]] | None
] = hass.data.get(DATA_DISPATCHER, {}).get(signal, {})
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
return
dispatchers: _DispatcherDataType = maybe_dispatchers
if (target_list := dispatchers.get(signal)) is None:
return
run: list[HassJob[..., None | Coroutine[Any, Any, None]]] = []
for target, job in target_list.items():
for target, job in list(target_list.items()):
if job is None:
job = _generate_job(signal, target)
target_list[target] = job
# Run the jobs all at the end
# to ensure no jobs add more dispatchers
# which can result in the target_list
# changing size during iteration
run.append(job)
for job in run:
hass.async_run_hass_job(job, *args)

View file

@ -151,3 +151,25 @@ async def test_callback_exception_gets_logged(
f"Exception in functools.partial({bad_handler}) when dispatching 'test': ('bad',)"
in caplog.text
)
async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None:
"""Test adding a dispatcher from a dispatcher."""
calls = []
@callback
def _new_dispatcher(data):
calls.append(data)
@callback
def _add_new_dispatcher(data):
calls.append(data)
async_dispatcher_connect(hass, "test", _new_dispatcher)
async_dispatcher_connect(hass, "test", _add_new_dispatcher)
async_dispatcher_send(hass, "test", 3)
async_dispatcher_send(hass, "test", 4)
async_dispatcher_send(hass, "test", 5)
assert calls == [3, 4, 4, 5, 5]