Improve dispatcher typing (#106872)

This commit is contained in:
Marc Mueller 2024-01-08 09:45:37 +01:00 committed by GitHub
parent ea4143154b
commit fde7a6e9ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 161 additions and 15 deletions

View file

@ -1,4 +1,15 @@
"""Consts for Cast integration."""
from __future__ import annotations
from typing import TYPE_CHECKING
from pychromecast.controllers.homeassistant import HomeAssistantController
from homeassistant.helpers.dispatcher import SignalType
if TYPE_CHECKING:
from .helpers import ChromecastInfo
DOMAIN = "cast"
@ -14,14 +25,16 @@ CAST_BROWSER_KEY = "cast_browser"
# Dispatcher signal fired with a ChromecastInfo every time we discover a new
# Chromecast or receive it through configuration
SIGNAL_CAST_DISCOVERED = "cast_discovered"
SIGNAL_CAST_DISCOVERED: SignalType[ChromecastInfo] = SignalType("cast_discovered")
# Dispatcher signal fired with a ChromecastInfo every time a Chromecast is
# removed
SIGNAL_CAST_REMOVED = "cast_removed"
SIGNAL_CAST_REMOVED: SignalType[ChromecastInfo] = SignalType("cast_removed")
# Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view.
SIGNAL_HASS_CAST_SHOW_VIEW = "cast_show_view"
SIGNAL_HASS_CAST_SHOW_VIEW: SignalType[
HomeAssistantController, str, str, str | None
] = SignalType("cast_show_view")
CONF_IGNORE_CEC = "ignore_cec"
CONF_KNOWN_HOSTS = "known_hosts"

View file

@ -26,6 +26,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.dispatcher import (
SignalType,
async_dispatcher_connect,
async_dispatcher_send,
)
@ -69,7 +70,9 @@ PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
SIGNAL_CLOUD_CONNECTION_STATE = "CLOUD_CONNECTION_STATE"
SIGNAL_CLOUD_CONNECTION_STATE: SignalType[CloudConnectionState] = SignalType(
"CLOUD_CONNECTION_STATE"
)
STARTUP_REPAIR_DELAY = 1 # 1 hour

View file

@ -1,4 +1,10 @@
"""Constants for the cloud component."""
from __future__ import annotations
from typing import Any
from homeassistant.helpers.dispatcher import SignalType
DOMAIN = "cloud"
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
REQUEST_TIMEOUT = 10
@ -64,6 +70,6 @@ CONF_SERVICEHANDLERS_SERVER = "servicehandlers_server"
MODE_DEV = "development"
MODE_PROD = "production"
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"

View file

@ -2,30 +2,73 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from functools import partial
import logging
from typing import Any
from typing import Any, Generic, TypeVarTuple, overload
from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception
_Ts = TypeVarTuple("_Ts")
_LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = "dispatcher"
@dataclass(frozen=True)
class SignalType(Generic[*_Ts]):
"""Generic string class for signal to improve typing."""
name: str
def __hash__(self) -> int:
"""Return hash of name."""
return hash(self.name)
def __eq__(self, other: Any) -> bool:
"""Check equality for dict keys to be compatible with str."""
if isinstance(other, str):
return self.name == other
if isinstance(other, SignalType):
return self.name == other.name
return False
_DispatcherDataType = dict[
str,
SignalType[*_Ts] | str,
dict[
Callable[..., Any],
Callable[[*_Ts], Any] | Callable[..., Any],
HassJob[..., None | Coroutine[Any, Any, None]] | None,
],
]
@overload
@bind_hass
def dispatcher_connect(
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], None]
) -> Callable[[], None]:
...
@overload
@bind_hass
def dispatcher_connect(
hass: HomeAssistant, signal: str, target: Callable[..., None]
) -> Callable[[], None]:
...
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
def dispatcher_connect(
hass: HomeAssistant,
signal: SignalType[*_Ts],
target: Callable[[*_Ts], None],
) -> Callable[[], None]:
"""Connect a callable function to a signal."""
async_unsub = run_callback_threadsafe(
@ -41,9 +84,9 @@ def dispatcher_connect(
@callback
def _async_remove_dispatcher(
dispatchers: _DispatcherDataType,
signal: str,
target: Callable[..., Any],
dispatchers: _DispatcherDataType[*_Ts],
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
) -> None:
"""Remove signal listener."""
try:
@ -59,10 +102,30 @@ def _async_remove_dispatcher(
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
@overload
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], Any]
) -> Callable[[], None]:
...
@overload
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant, signal: str, target: Callable[..., Any]
) -> Callable[[], None]:
...
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistant,
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
) -> Callable[[], None]:
"""Connect a callable function to a signal.
@ -71,7 +134,7 @@ def async_dispatcher_connect(
if DATA_DISPATCHER not in hass.data:
hass.data[DATA_DISPATCHER] = {}
dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER]
dispatchers: _DispatcherDataType[*_Ts] = hass.data[DATA_DISPATCHER]
if signal not in dispatchers:
dispatchers[signal] = {}
@ -84,13 +147,29 @@ def async_dispatcher_connect(
return partial(_async_remove_dispatcher, dispatchers, signal, target)
@overload
@bind_hass
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
...
@overload
@bind_hass
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
...
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
"""Send signal and data."""
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
def _format_err(
signal: SignalType[*_Ts] | str,
target: Callable[[*_Ts], Any] | Callable[..., Any],
*args: Any,
) -> str:
"""Format error message."""
return "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
@ -101,7 +180,7 @@ def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
def _generate_job(
signal: str, target: Callable[..., Any]
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target."""
return HassJob(
@ -110,16 +189,34 @@ def _generate_job(
)
@overload
@callback
@bind_hass
def async_dispatcher_send(
hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts
) -> None:
...
@overload
@callback
@bind_hass
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
...
@callback
@bind_hass
def async_dispatcher_send(
hass: HomeAssistant, signal: SignalType[*_Ts] | str, *args: *_Ts
) -> None:
"""Send signal and data.
This method must be run in the event loop.
"""
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
return
dispatchers: _DispatcherDataType = maybe_dispatchers
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers
if (target_list := dispatchers.get(signal)) is None:
return

View file

@ -5,6 +5,7 @@ import pytest
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import (
SignalType,
async_dispatcher_connect,
async_dispatcher_send,
)
@ -30,6 +31,32 @@ async def test_simple_function(hass: HomeAssistant) -> None:
assert calls == [3, "bla"]
async def test_signal_type(hass: HomeAssistant) -> None:
"""Test dispatcher with SignalType."""
signal: SignalType[str, int] = SignalType("test")
calls: list[tuple[str, int]] = []
def test_funct(data1: str, data2: int) -> None:
calls.append((data1, data2))
async_dispatcher_connect(hass, signal, test_funct)
async_dispatcher_send(hass, signal, "Hello", 2)
await hass.async_block_till_done()
assert calls == [("Hello", 2)]
async_dispatcher_send(hass, signal, "World", 3)
await hass.async_block_till_done()
assert calls == [("Hello", 2), ("World", 3)]
# Test compatibility with string keys
async_dispatcher_send(hass, "test", "x", 4)
await hass.async_block_till_done()
assert calls == [("Hello", 2), ("World", 3), ("x", 4)]
async def test_simple_function_unsub(hass: HomeAssistant) -> None:
"""Test simple function (executor) and unsub."""
calls1 = []