Improve dispatcher typing (#106872)
This commit is contained in:
parent
ea4143154b
commit
fde7a6e9ef
5 changed files with 161 additions and 15 deletions
|
@ -1,4 +1,15 @@
|
||||||
"""Consts for Cast integration."""
|
"""Consts for Cast integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from pychromecast.controllers.homeassistant import HomeAssistantController
|
||||||
|
|
||||||
|
from homeassistant.helpers.dispatcher import SignalType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .helpers import ChromecastInfo
|
||||||
|
|
||||||
|
|
||||||
DOMAIN = "cast"
|
DOMAIN = "cast"
|
||||||
|
|
||||||
|
@ -14,14 +25,16 @@ CAST_BROWSER_KEY = "cast_browser"
|
||||||
|
|
||||||
# Dispatcher signal fired with a ChromecastInfo every time we discover a new
|
# Dispatcher signal fired with a ChromecastInfo every time we discover a new
|
||||||
# Chromecast or receive it through configuration
|
# Chromecast or receive it through configuration
|
||||||
SIGNAL_CAST_DISCOVERED = "cast_discovered"
|
SIGNAL_CAST_DISCOVERED: SignalType[ChromecastInfo] = SignalType("cast_discovered")
|
||||||
|
|
||||||
# Dispatcher signal fired with a ChromecastInfo every time a Chromecast is
|
# Dispatcher signal fired with a ChromecastInfo every time a Chromecast is
|
||||||
# removed
|
# removed
|
||||||
SIGNAL_CAST_REMOVED = "cast_removed"
|
SIGNAL_CAST_REMOVED: SignalType[ChromecastInfo] = SignalType("cast_removed")
|
||||||
|
|
||||||
# Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view.
|
# Dispatcher signal fired when a Chromecast should show a Home Assistant Cast view.
|
||||||
SIGNAL_HASS_CAST_SHOW_VIEW = "cast_show_view"
|
SIGNAL_HASS_CAST_SHOW_VIEW: SignalType[
|
||||||
|
HomeAssistantController, str, str, str | None
|
||||||
|
] = SignalType("cast_show_view")
|
||||||
|
|
||||||
CONF_IGNORE_CEC = "ignore_cec"
|
CONF_IGNORE_CEC = "ignore_cec"
|
||||||
CONF_KNOWN_HOSTS = "known_hosts"
|
CONF_KNOWN_HOSTS = "known_hosts"
|
||||||
|
|
|
@ -26,6 +26,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.discovery import async_load_platform
|
from homeassistant.helpers.discovery import async_load_platform
|
||||||
from homeassistant.helpers.dispatcher import (
|
from homeassistant.helpers.dispatcher import (
|
||||||
|
SignalType,
|
||||||
async_dispatcher_connect,
|
async_dispatcher_connect,
|
||||||
async_dispatcher_send,
|
async_dispatcher_send,
|
||||||
)
|
)
|
||||||
|
@ -69,7 +70,9 @@ PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
|
||||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||||
|
|
||||||
SIGNAL_CLOUD_CONNECTION_STATE = "CLOUD_CONNECTION_STATE"
|
SIGNAL_CLOUD_CONNECTION_STATE: SignalType[CloudConnectionState] = SignalType(
|
||||||
|
"CLOUD_CONNECTION_STATE"
|
||||||
|
)
|
||||||
|
|
||||||
STARTUP_REPAIR_DELAY = 1 # 1 hour
|
STARTUP_REPAIR_DELAY = 1 # 1 hour
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,10 @@
|
||||||
"""Constants for the cloud component."""
|
"""Constants for the cloud component."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.helpers.dispatcher import SignalType
|
||||||
|
|
||||||
DOMAIN = "cloud"
|
DOMAIN = "cloud"
|
||||||
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
|
DATA_PLATFORMS_SETUP = "cloud_platforms_setup"
|
||||||
REQUEST_TIMEOUT = 10
|
REQUEST_TIMEOUT = 10
|
||||||
|
@ -64,6 +70,6 @@ CONF_SERVICEHANDLERS_SERVER = "servicehandlers_server"
|
||||||
MODE_DEV = "development"
|
MODE_DEV = "development"
|
||||||
MODE_PROD = "production"
|
MODE_PROD = "production"
|
||||||
|
|
||||||
DISPATCHER_REMOTE_UPDATE = "cloud_remote_update"
|
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||||
|
|
||||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||||
|
|
|
@ -2,30 +2,73 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Generic, TypeVarTuple, overload
|
||||||
|
|
||||||
from homeassistant.core import HassJob, HomeAssistant, callback
|
from homeassistant.core import HassJob, HomeAssistant, callback
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
from homeassistant.util.logging import catch_log_exception
|
from homeassistant.util.logging import catch_log_exception
|
||||||
|
|
||||||
|
_Ts = TypeVarTuple("_Ts")
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DATA_DISPATCHER = "dispatcher"
|
DATA_DISPATCHER = "dispatcher"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SignalType(Generic[*_Ts]):
|
||||||
|
"""Generic string class for signal to improve typing."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""Return hash of name."""
|
||||||
|
|
||||||
|
return hash(self.name)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Check equality for dict keys to be compatible with str."""
|
||||||
|
|
||||||
|
if isinstance(other, str):
|
||||||
|
return self.name == other
|
||||||
|
if isinstance(other, SignalType):
|
||||||
|
return self.name == other.name
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
_DispatcherDataType = dict[
|
_DispatcherDataType = dict[
|
||||||
str,
|
SignalType[*_Ts] | str,
|
||||||
dict[
|
dict[
|
||||||
Callable[..., Any],
|
Callable[[*_Ts], Any] | Callable[..., Any],
|
||||||
HassJob[..., None | Coroutine[Any, Any, None]] | None,
|
HassJob[..., None | Coroutine[Any, Any, None]] | None,
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@bind_hass
|
||||||
|
def dispatcher_connect(
|
||||||
|
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], None]
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def dispatcher_connect(
|
def dispatcher_connect(
|
||||||
hass: HomeAssistant, signal: str, target: Callable[..., None]
|
hass: HomeAssistant, signal: str, target: Callable[..., None]
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
|
||||||
|
def dispatcher_connect(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
signal: SignalType[*_Ts],
|
||||||
|
target: Callable[[*_Ts], None],
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Connect a callable function to a signal."""
|
"""Connect a callable function to a signal."""
|
||||||
async_unsub = run_callback_threadsafe(
|
async_unsub = run_callback_threadsafe(
|
||||||
|
@ -41,9 +84,9 @@ def dispatcher_connect(
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_remove_dispatcher(
|
def _async_remove_dispatcher(
|
||||||
dispatchers: _DispatcherDataType,
|
dispatchers: _DispatcherDataType[*_Ts],
|
||||||
signal: str,
|
signal: SignalType[*_Ts] | str,
|
||||||
target: Callable[..., Any],
|
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remove signal listener."""
|
"""Remove signal listener."""
|
||||||
try:
|
try:
|
||||||
|
@ -59,10 +102,30 @@ def _async_remove_dispatcher(
|
||||||
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
|
_LOGGER.warning("Unable to remove unknown dispatcher %s", target)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@callback
|
||||||
|
@bind_hass
|
||||||
|
def async_dispatcher_connect(
|
||||||
|
hass: HomeAssistant, signal: SignalType[*_Ts], target: Callable[[*_Ts], Any]
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_dispatcher_connect(
|
def async_dispatcher_connect(
|
||||||
hass: HomeAssistant, signal: str, target: Callable[..., Any]
|
hass: HomeAssistant, signal: str, target: Callable[..., Any]
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@bind_hass
|
||||||
|
def async_dispatcher_connect(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
signal: SignalType[*_Ts] | str,
|
||||||
|
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Connect a callable function to a signal.
|
"""Connect a callable function to a signal.
|
||||||
|
|
||||||
|
@ -71,7 +134,7 @@ def async_dispatcher_connect(
|
||||||
if DATA_DISPATCHER not in hass.data:
|
if DATA_DISPATCHER not in hass.data:
|
||||||
hass.data[DATA_DISPATCHER] = {}
|
hass.data[DATA_DISPATCHER] = {}
|
||||||
|
|
||||||
dispatchers: _DispatcherDataType = hass.data[DATA_DISPATCHER]
|
dispatchers: _DispatcherDataType[*_Ts] = hass.data[DATA_DISPATCHER]
|
||||||
|
|
||||||
if signal not in dispatchers:
|
if signal not in dispatchers:
|
||||||
dispatchers[signal] = {}
|
dispatchers[signal] = {}
|
||||||
|
@ -84,13 +147,29 @@ def async_dispatcher_connect(
|
||||||
return partial(_async_remove_dispatcher, dispatchers, signal, target)
|
return partial(_async_remove_dispatcher, dispatchers, signal, target)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@bind_hass
|
||||||
|
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass # type: ignore[misc] # workaround; exclude typing of 2 overload in func def
|
||||||
|
def dispatcher_send(hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts) -> None:
|
||||||
"""Send signal and data."""
|
"""Send signal and data."""
|
||||||
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
|
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
|
||||||
|
|
||||||
|
|
||||||
def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
|
def _format_err(
|
||||||
|
signal: SignalType[*_Ts] | str,
|
||||||
|
target: Callable[[*_Ts], Any] | Callable[..., Any],
|
||||||
|
*args: Any,
|
||||||
|
) -> str:
|
||||||
"""Format error message."""
|
"""Format error message."""
|
||||||
return "Exception in {} when dispatching '{}': {}".format(
|
return "Exception in {} when dispatching '{}': {}".format(
|
||||||
# Functions wrapped in partial do not have a __name__
|
# Functions wrapped in partial do not have a __name__
|
||||||
|
@ -101,7 +180,7 @@ def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _generate_job(
|
def _generate_job(
|
||||||
signal: str, target: Callable[..., Any]
|
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
|
||||||
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
|
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
|
||||||
"""Generate a HassJob for a signal and target."""
|
"""Generate a HassJob for a signal and target."""
|
||||||
return HassJob(
|
return HassJob(
|
||||||
|
@ -110,16 +189,34 @@ def _generate_job(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@callback
|
||||||
|
@bind_hass
|
||||||
|
def async_dispatcher_send(
|
||||||
|
hass: HomeAssistant, signal: SignalType[*_Ts], *args: *_Ts
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@bind_hass
|
||||||
|
def async_dispatcher_send(
|
||||||
|
hass: HomeAssistant, signal: SignalType[*_Ts] | str, *args: *_Ts
|
||||||
|
) -> None:
|
||||||
"""Send signal and data.
|
"""Send signal and data.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
|
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
|
||||||
return
|
return
|
||||||
dispatchers: _DispatcherDataType = maybe_dispatchers
|
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers
|
||||||
if (target_list := dispatchers.get(signal)) is None:
|
if (target_list := dispatchers.get(signal)) is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.dispatcher import (
|
from homeassistant.helpers.dispatcher import (
|
||||||
|
SignalType,
|
||||||
async_dispatcher_connect,
|
async_dispatcher_connect,
|
||||||
async_dispatcher_send,
|
async_dispatcher_send,
|
||||||
)
|
)
|
||||||
|
@ -30,6 +31,32 @@ async def test_simple_function(hass: HomeAssistant) -> None:
|
||||||
assert calls == [3, "bla"]
|
assert calls == [3, "bla"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_signal_type(hass: HomeAssistant) -> None:
|
||||||
|
"""Test dispatcher with SignalType."""
|
||||||
|
signal: SignalType[str, int] = SignalType("test")
|
||||||
|
calls: list[tuple[str, int]] = []
|
||||||
|
|
||||||
|
def test_funct(data1: str, data2: int) -> None:
|
||||||
|
calls.append((data1, data2))
|
||||||
|
|
||||||
|
async_dispatcher_connect(hass, signal, test_funct)
|
||||||
|
async_dispatcher_send(hass, signal, "Hello", 2)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert calls == [("Hello", 2)]
|
||||||
|
|
||||||
|
async_dispatcher_send(hass, signal, "World", 3)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert calls == [("Hello", 2), ("World", 3)]
|
||||||
|
|
||||||
|
# Test compatibility with string keys
|
||||||
|
async_dispatcher_send(hass, "test", "x", 4)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert calls == [("Hello", 2), ("World", 3), ("x", 4)]
|
||||||
|
|
||||||
|
|
||||||
async def test_simple_function_unsub(hass: HomeAssistant) -> None:
|
async def test_simple_function_unsub(hass: HomeAssistant) -> None:
|
||||||
"""Test simple function (executor) and unsub."""
|
"""Test simple function (executor) and unsub."""
|
||||||
calls1 = []
|
calls1 = []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue