Improve dispatcher typing (#106872)
This commit is contained in:
parent
ea4143154b
commit
fde7a6e9ef
5 changed files with 161 additions and 15 deletions
|
@ -1,4 +1,15 @@
|
|||
"""Consts for Cast integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pychromecast.controllers.homeassistant import HomeAssistantController
|
||||
|
||||
from homeassistant.helpers.dispatcher import SignalType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .helpers import ChromecastInfo
|
||||
|
||||
|
||||
DOMAIN = "cast"
|
||||
|
||||
|
@ -14,14 +25,16 @@ CAST_BROWSER_KEY = "cast_browser"
|
|||
|
||||
# Dispatcher signal fired with a ChromecastInfo every time we discover a new
|
||||
# Chromecast or receive it through configuration
|
||||
SIGNAL_CAST_DISCOVERED = "cast_discovered"
|
||||
SIGNAL_CAST_DISCOVERED: SignalType[ChromecastInfo] = SignalType("cast_discovered")
|
||||
|
||||
# Dispatcher signal fired with a ChromecastInfo every time a Chromecast is
|
||||
# removed
|
||||
SIGNAL_CAST_REMOVED = "cast_removed"
|
||||
SIGNAL_CAST_REMOVED: SignalType[ChromecastInfo] = SignalType("cast_removed")
|
||||
|
||||
# Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view.
|
||||
SIGNAL_HASS_CAST_SHOW_VIEW = "cast_show_view"
|
||||
SIGNAL_HASS_CAST_SHOW_VIEW: SignalType[
|
||||
HomeAssistantController, str, str, str | None
|
||||
] = SignalType("cast_show_view")
|
||||
|
||||
CONF_IGNORE_CEC = "ignore_cec"
|
||||
CONF_KNOWN_HOSTS = "known_hosts"
|
||||
|
|
|
@ -26,6 +26,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
|
|||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
SignalType,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
|
@ -69,7 +70,9 @@ PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
|
|||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||
|
||||
SIGNAL_CLOUD_CONNECTION_STATE = "CLOUD_CONNECTION_STATE"
|
||||
SIGNAL_CLOUD_CONNECTION_STATE: SignalType[CloudConnectionState] = SignalType(
|
||||
"CLOUD_CONNECTION_STATE"
|
||||
)
|
||||
|
||||
STARTUP_REPAIR_DELAY = 1 # 1 hour
|
||||
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
"""Constants for the cloud component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.helpers.dispatcher import SignalType
|
||||
|
||||
DOMAIN = "cloud"
|
||||
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
|
||||
REQUEST_TIMEOUT = 10
|
||||
|
@ -64,6 +70,6 @@ CONF_SERVICEHANDLERS_SERVER = "servicehandlers_server"
|
|||
MODE_DEV = "development"
|
||||
MODE_PROD = "production"
|
||||
|
||||
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
|
||||
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||
|
||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||
|
|
|
@ -2,30 +2,73 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVarTuple, overload
|
||||
|
||||
from homeassistant.core import HassJob, HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.logging import catch_log_exception
|
||||
|
||||
_Ts = TypeVarTuple("_Ts")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_DISPATCHER = "dispatcher"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SignalType(Generic[*_Ts]):
|
||||
"""Generic string class for signal to improve typing."""
|
||||
|
||||
name: str
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of name."""
|
||||
|
||||
return hash(self.name)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check equality for dict keys to be compatible with str."""
|
||||
|
||||
if isinstance(other, str):
|
||||
return self.name == other
|
||||
if isinstance(other, SignalType):
|
||||
return self.name == other.name
|
||||
return False
|
||||
|
||||
|
||||
_DispatcherDataType = dict[
|
||||
str,
|
||||
SignalType[*_Ts] | str,
|
||||
dict[
|
||||
Callable[..., Any],
|
||||
Callable[[*_Ts], Any] | Callable[..., Any],
|
||||
HassJob[..., None | Coroutine[Any, Any, None]] | None,
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
@bind_hass
|
||||
def dispatcher_connect(
|
||||
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], None]
|
||||
) -> Callable[[], None]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@bind_hass
|
||||
def dispatcher_connect(
|
||||
hass: HomeAssistant, signal: str, target: Callable[..., None]
|
||||
) -> Callable[[], None]:
|
||||
...
|
||||
|
||||
|
||||
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
|
||||
def dispatcher_connect(
|
||||
hass: HomeAssistant,
|
||||
signal: SignalType[*_Ts],
|
||||
target: Callable[[*_Ts], None],
|
||||
) -> Callable[[], None]:
|
||||
"""Connect a callable function to a signal."""
|
||||
async_unsub = run_callback_threadsafe(
|
||||
|
@ -41,9 +84,9 @@ def dispatcher_connect(
|
|||
|
||||
@callback
|
||||
def _async_remove_dispatcher(
|
||||
dispatchers: _DispatcherDataType,
|
||||
signal: str,
|
||||
target: Callable[..., Any],
|
||||
dispatchers: _DispatcherDataType[*_Ts],
|
||||
signal: SignalType[*_Ts] | str,
|
||||
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||
) -> None:
|
||||
"""Remove signal listener."""
|
||||
try:
|
||||
|
@ -59,10 +102,30 @@ def _async_remove_dispatcher(
|
|||
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
|
||||
|
||||
|
||||
@overload
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_connect(
|
||||
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], Any]
|
||||
) -> Callable[[], None]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_connect(
|
||||
hass: HomeAssistant, signal: str, target: Callable[..., Any]
|
||||
) -> Callable[[], None]:
|
||||
...
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_connect(
|
||||
hass: HomeAssistant,
|
||||
signal: SignalType[*_Ts] | str,
|
||||
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||
) -> Callable[[], None]:
|
||||
"""Connect a callable function to a signal.
|
||||
|
||||
|
@ -71,7 +134,7 @@ def async_dispatcher_connect(
|
|||
if DATA_DISPATCHER not in hass.data:
|
||||
hass.data[DATA_DISPATCHER] = {}
|
||||
|
||||
dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER]
|
||||
dispatchers: _DispatcherDataType[*_Ts] = hass.data[DATA_DISPATCHER]
|
||||
|
||||
if signal not in dispatchers:
|
||||
dispatchers[signal] = {}
|
||||
|
@ -84,13 +147,29 @@ def async_dispatcher_connect(
|
|||
return partial(_async_remove_dispatcher, dispatchers, signal, target)
|
||||
|
||||
|
||||
@overload
|
||||
@bind_hass
|
||||
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@bind_hass
|
||||
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
||||
...
|
||||
|
||||
|
||||
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
|
||||
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
|
||||
"""Send signal and data."""
|
||||
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
|
||||
|
||||
|
||||
def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
|
||||
def _format_err(
|
||||
signal: SignalType[*_Ts] | str,
|
||||
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||
*args: Any,
|
||||
) -> str:
|
||||
"""Format error message."""
|
||||
return "Exception in {} when dispatching '{}': {}".format(
|
||||
# Functions wrapped in partial do not have a __name__
|
||||
|
@ -101,7 +180,7 @@ def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
|
|||
|
||||
|
||||
def _generate_job(
|
||||
signal: str, target: Callable[..., Any]
|
||||
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
|
||||
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
|
||||
"""Generate a HassJob for a signal and target."""
|
||||
return HassJob(
|
||||
|
@ -110,16 +189,34 @@ def _generate_job(
|
|||
)
|
||||
|
||||
|
||||
@overload
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_send(
|
||||
hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts
|
||||
) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
||||
...
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dispatcher_send(
|
||||
hass: HomeAssistant, signal: SignalType[*_Ts] | str, *args: *_Ts
|
||||
) -> None:
|
||||
"""Send signal and data.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
|
||||
return
|
||||
dispatchers: _DispatcherDataType = maybe_dispatchers
|
||||
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers
|
||||
if (target_list := dispatchers.get(signal)) is None:
|
||||
return
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
SignalType,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
|
@ -30,6 +31,32 @@ async def test_simple_function(hass: HomeAssistant) -> None:
|
|||
assert calls == [3, "bla"]
|
||||
|
||||
|
||||
async def test_signal_type(hass: HomeAssistant) -> None:
|
||||
"""Test dispatcher with SignalType."""
|
||||
signal: SignalType[str, int] = SignalType("test")
|
||||
calls: list[tuple[str, int]] = []
|
||||
|
||||
def test_funct(data1: str, data2: int) -> None:
|
||||
calls.append((data1, data2))
|
||||
|
||||
async_dispatcher_connect(hass, signal, test_funct)
|
||||
async_dispatcher_send(hass, signal, "Hello", 2)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == [("Hello", 2)]
|
||||
|
||||
async_dispatcher_send(hass, signal, "World", 3)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == [("Hello", 2), ("World", 3)]
|
||||
|
||||
# Test compatibility with string keys
|
||||
async_dispatcher_send(hass, "test", "x", 4)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == [("Hello", 2), ("World", 3), ("x", 4)]
|
||||
|
||||
|
||||
async def test_simple_function_unsub(hass: HomeAssistant) -> None:
|
||||
"""Test simple function (executor) and unsub."""
|
||||
calls1 = []
|
||||
|
|
Loading…
Add table
Reference in a new issue