From fde7a6e9ef288cec56457369eb44e549f8e91a1a Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 8 Jan 2024 09:45:37 +0100 Subject: [PATCH] Improve dispatcher typing (#106872) --- homeassistant/components/cast/const.py | 19 +++- homeassistant/components/cloud/__init__.py | 5 +- homeassistant/components/cloud/const.py | 8 +- homeassistant/helpers/dispatcher.py | 117 +++++++++++++++++++-- tests/helpers/test_dispatcher.py | 27 +++++ 5 files changed, 161 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/cast/const.py b/homeassistant/components/cast/const.py index e8e38a6e72b..730757de8b4 100644 --- a/homeassistant/components/cast/const.py +++ b/homeassistant/components/cast/const.py @@ -1,4 +1,15 @@ """Consts for Cast integration.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pychromecast.controllers.homeassistant import HomeAssistantController + +from homeassistant.helpers.dispatcher import SignalType + +if TYPE_CHECKING: + from .helpers import ChromecastInfo + DOMAIN = "cast" @@ -14,14 +25,16 @@ CAST_BROWSER_KEY = "cast_browser" # Dispatcher signal fired with a ChromecastInfo every time we discover a new # Chromecast or receive it through configuration -SIGNAL_CAST_DISCOVERED = "cast_discovered" +SIGNAL_CAST_DISCOVERED: SignalType[ChromecastInfo] = SignalType("cast_discovered") # Dispatcher signal fired with a ChromecastInfo every time a Chromecast is # removed -SIGNAL_CAST_REMOVED = "cast_removed" +SIGNAL_CAST_REMOVED: SignalType[ChromecastInfo] = SignalType("cast_removed") # Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view. -SIGNAL_HASS_CAST_SHOW_VIEW = "cast_show_view" +SIGNAL_HASS_CAST_SHOW_VIEW: SignalType[ + HomeAssistantController, str, str, str | None +] = SignalType("cast_show_view") CONF_IGNORE_CEC = "ignore_cec" CONF_KNOWN_HOSTS = "known_hosts" diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index 6e5cddd0f28..76369c07e8e 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -26,6 +26,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.discovery import async_load_platform from homeassistant.helpers.dispatcher import ( + SignalType, async_dispatcher_connect, async_dispatcher_send, ) @@ -69,7 +70,9 @@ PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT] SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect" -SIGNAL_CLOUD_CONNECTION_STATE = "CLOUD_CONNECTION_STATE" +SIGNAL_CLOUD_CONNECTION_STATE: SignalType[CloudConnectionState] = SignalType( + "CLOUD_CONNECTION_STATE" +) STARTUP_REPAIR_DELAY = 1 # 1 hour diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index db964607923..da012c20bab 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -1,4 +1,10 @@ """Constants for the cloud component.""" +from __future__ import annotations + +from typing import Any + +from homeassistant.helpers.dispatcher import SignalType + DOMAIN = "cloud" DATA_PLATFORMS_SETUP = "cloud_platforms_setup" REQUEST_TIMEOUT = 10 @@ -64,6 +70,6 @@ CONF_SERVICEHANDLERS_SERVER = "servicehandlers_server" MODE_DEV = "development" MODE_PROD = "production" -DISPATCHER_REMOTE_UPDATE = "cloud_remote_update" +DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update") STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text" diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index 07112226ecf..59d680a60ee 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -2,30 +2,73 @@ from __future__ import annotations from collections.abc import Callable, Coroutine +from dataclasses import dataclass from functools import partial import logging -from typing import Any +from typing import Any, Generic, TypeVarTuple, overload from homeassistant.core import HassJob, HomeAssistant, callback from homeassistant.loader import bind_hass from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.logging import catch_log_exception +_Ts = TypeVarTuple("_Ts") + _LOGGER = logging.getLogger(__name__) DATA_DISPATCHER = "dispatcher" + +@dataclass(frozen=True) +class SignalType(Generic[*_Ts]): + """Generic string class for signal to improve typing.""" + + name: str + + def __hash__(self) -> int: + """Return hash of name.""" + + return hash(self.name) + + def __eq__(self, other: Any) -> bool: + """Check equality for dict keys to be compatible with str.""" + + if isinstance(other, str): + return self.name == other + if isinstance(other, SignalType): + return self.name == other.name + return False + + _DispatcherDataType = dict[ - str, + SignalType[*_Ts] | str, dict[ - Callable[..., Any], + Callable[[*_Ts], Any] | Callable[..., Any], HassJob[..., None | Coroutine[Any, Any, None]] | None, ], ] +@overload +@bind_hass +def dispatcher_connect( + hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], None] +) -> Callable[[], None]: + ... + + +@overload @bind_hass def dispatcher_connect( hass: HomeAssistant, signal: str, target: Callable[..., None] +) -> Callable[[], None]: + ... + + +@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def +def dispatcher_connect( + hass: HomeAssistant, + signal: SignalType[*_Ts], + target: Callable[[*_Ts], None], ) -> Callable[[], None]: """Connect a callable function to a signal.""" async_unsub = run_callback_threadsafe( @@ -41,9 +84,9 @@ def dispatcher_connect( @callback def _async_remove_dispatcher( - dispatchers: _DispatcherDataType, - signal: str, - target: Callable[..., Any], + dispatchers: _DispatcherDataType[*_Ts], + signal: SignalType[*_Ts] | str, + target: Callable[[*_Ts], Any] | Callable[..., Any], ) -> None: """Remove signal listener.""" try: @@ -59,10 +102,30 @@ def _async_remove_dispatcher( _LOGGER.warning("Unable to remove unknown dispatcher %s", target) +@overload +@callback +@bind_hass +def async_dispatcher_connect( + hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], Any] +) -> Callable[[], None]: + ... + + +@overload @callback @bind_hass def async_dispatcher_connect( hass: HomeAssistant, signal: str, target: Callable[..., Any] +) -> Callable[[], None]: + ... + + +@callback +@bind_hass +def async_dispatcher_connect( + hass: HomeAssistant, + signal: SignalType[*_Ts] | str, + target: Callable[[*_Ts], Any] | Callable[..., Any], ) -> Callable[[], None]: """Connect a callable function to a signal. @@ -71,7 +134,7 @@ def async_dispatcher_connect( if DATA_DISPATCHER not in hass.data: hass.data[DATA_DISPATCHER] = {} - dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER] + dispatchers: _DispatcherDataType[*_Ts] = hass.data[DATA_DISPATCHER] if signal not in dispatchers: dispatchers[signal] = {} @@ -84,13 +147,29 @@ def async_dispatcher_connect( return partial(_async_remove_dispatcher, dispatchers, signal, target) +@overload +@bind_hass +def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None: + ... + + +@overload @bind_hass def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None: + ... + + +@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def +def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None: """Send signal and data.""" hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args) -def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str: +def _format_err( + signal: SignalType[*_Ts] | str, + target: Callable[[*_Ts], Any] | Callable[..., Any], + *args: Any, +) -> str: """Format error message.""" return "Exception in {} when dispatching '{}': {}".format( # Functions wrapped in partial do not have a __name__ @@ -101,7 +180,7 @@ def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str: def _generate_job( - signal: str, target: Callable[..., Any] + signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any] ) -> HassJob[..., None | Coroutine[Any, Any, None]]: """Generate a HassJob for a signal and target.""" return HassJob( @@ -110,16 +189,34 @@ def _generate_job( ) +@overload +@callback +@bind_hass +def async_dispatcher_send( + hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts +) -> None: + ... + + +@overload @callback @bind_hass def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None: + ... + + +@callback +@bind_hass +def async_dispatcher_send( + hass: HomeAssistant, signal: SignalType[*_Ts] | str, *args: *_Ts +) -> None: """Send signal and data. This method must be run in the event loop. """ if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None: return - dispatchers: _DispatcherDataType = maybe_dispatchers + dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers if (target_list := dispatchers.get(signal)) is None: return diff --git a/tests/helpers/test_dispatcher.py b/tests/helpers/test_dispatcher.py index 89d23fb4533..add80c941a1 100644 --- a/tests/helpers/test_dispatcher.py +++ b/tests/helpers/test_dispatcher.py @@ -5,6 +5,7 @@ import pytest from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import ( + SignalType, async_dispatcher_connect, async_dispatcher_send, ) @@ -30,6 +31,32 @@ async def test_simple_function(hass: HomeAssistant) -> None: assert calls == [3, "bla"] +async def test_signal_type(hass: HomeAssistant) -> None: + """Test dispatcher with SignalType.""" + signal: SignalType[str, int] = SignalType("test") + calls: list[tuple[str, int]] = [] + + def test_funct(data1: str, data2: int) -> None: + calls.append((data1, data2)) + + async_dispatcher_connect(hass, signal, test_funct) + async_dispatcher_send(hass, signal, "Hello", 2) + await hass.async_block_till_done() + + assert calls == [("Hello", 2)] + + async_dispatcher_send(hass, signal, "World", 3) + await hass.async_block_till_done() + + assert calls == [("Hello", 2), ("World", 3)] + + # Test compatibility with string keys + async_dispatcher_send(hass, "test", "x", 4) + await hass.async_block_till_done() + + assert calls == [("Hello", 2), ("World", 3), ("x", 4)] + + async def test_simple_function_unsub(hass: HomeAssistant) -> None: """Test simple function (executor) and unsub.""" calls1 = []