From b417ae72e571b9b33acb591ca6da2a192fbc80f6 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 30 May 2022 09:22:37 +0200 Subject: [PATCH] Add generic parameters to HassJob (#70973) --- homeassistant/core.py | 20 +++++++++++--------- homeassistant/helpers/event.py | 34 +++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index d7cae4e411e..b8f509abef3 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -37,6 +37,7 @@ from typing import ( ) from urllib.parse import urlparse +from typing_extensions import ParamSpec import voluptuous as vol import yarl @@ -98,6 +99,7 @@ block_async_io.enable() _T = TypeVar("_T") _R = TypeVar("_R") _R_co = TypeVar("_R_co", covariant=True) +_P = ParamSpec("_P") # Internal; not helpers.typing.UNDEFINED due to circular dependency _UNDEF: dict[Any, Any] = {} _CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) @@ -182,7 +184,7 @@ class HassJobType(enum.Enum): Executor = 3 -class HassJob(Generic[_R_co]): +class HassJob(Generic[_P, _R_co]): """Represent a job to be run later. We check the callable type in advance @@ -192,7 +194,7 @@ class HassJob(Generic[_R_co]): __slots__ = ("job_type", "target") - def __init__(self, target: Callable[..., _R_co]) -> None: + def __init__(self, target: Callable[_P, _R_co]) -> None: """Create a job object.""" self.target = target self.job_type = _get_hassjob_callable_job_type(target) @@ -416,20 +418,20 @@ class HomeAssistant: @overload @callback def async_add_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any ) -> asyncio.Future[_R] | None: ... @overload @callback def async_add_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any ) -> asyncio.Future[_R] | None: ... @callback def async_add_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any ) -> asyncio.Future[_R] | None: """Add a HassJob from within the event loop. @@ -512,20 +514,20 @@ class HomeAssistant: @overload @callback def async_run_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any ) -> asyncio.Future[_R] | None: ... @overload @callback def async_run_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any ) -> asyncio.Future[_R] | None: ... @callback def async_run_hass_job( - self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any + self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any ) -> asyncio.Future[_R] | None: """Run a HassJob from within the event loop. @@ -814,7 +816,7 @@ class Event: class _FilterableJob(NamedTuple): """Event listener job to be executed with optional filter.""" - job: HassJob[None | Awaitable[None]] + job: HassJob[[Event], None | Awaitable[None]] event_filter: Callable[[Event], bool] | None run_immediately: bool diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index c1229dc3e7c..c9b569c6601 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -258,7 +258,9 @@ def _async_track_state_change_event( action: Callable[[Event], Any], ) -> CALLBACK_TYPE: """async_track_state_change_event without lowercasing.""" - entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {}) + entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( + TRACK_STATE_CHANGE_CALLBACKS, {} + ) if TRACK_STATE_CHANGE_LISTENER not in hass.data: @@ -319,10 +321,10 @@ def _async_remove_indexed_listeners( data_key: str, listener_key: str, storage_keys: Iterable[str], - job: HassJob[Any], + job: HassJob[[Event], Any], ) -> None: """Remove a listener.""" - callbacks = hass.data[data_key] + callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data[data_key] for storage_key in storage_keys: callbacks[storage_key].remove(job) @@ -347,7 +349,9 @@ def async_track_entity_registry_updated_event( if not (entity_ids := _async_string_to_lower_list(entity_ids)): return _remove_empty_listener - entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}) + entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( + TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {} + ) if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: @@ -401,7 +405,7 @@ def async_track_entity_registry_updated_event( @callback def _async_dispatch_domain_event( - hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[Any]]] + hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[[Event], Any]]] ) -> None: domain = split_entity_id(event.data["entity_id"])[0] @@ -438,7 +442,9 @@ def _async_track_state_added_domain( action: Callable[[Event], Any], ) -> CALLBACK_TYPE: """async_track_state_added_domain without lowercasing.""" - domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}) + domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( + TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {} + ) if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: @@ -490,7 +496,9 @@ def async_track_state_removed_domain( if not (domains := _async_string_to_lower_list(domains)): return _remove_empty_listener - domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}) + domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( + TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {} + ) if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: @@ -1249,7 +1257,7 @@ track_same_state = threaded_listener_factory(async_track_same_state) @bind_hass def async_track_point_in_time( hass: HomeAssistant, - action: HassJob[Awaitable[None] | None] + action: HassJob[[datetime], Awaitable[None] | None] | Callable[[datetime], Awaitable[None] | None], point_in_time: datetime, ) -> CALLBACK_TYPE: @@ -1271,7 +1279,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time) @bind_hass def async_track_point_in_utc_time( hass: HomeAssistant, - action: HassJob[Awaitable[None] | None] + action: HassJob[[datetime], Awaitable[None] | None] | Callable[[datetime], Awaitable[None] | None], point_in_time: datetime, ) -> CALLBACK_TYPE: @@ -1284,7 +1292,7 @@ def async_track_point_in_utc_time( cancel_callback: asyncio.TimerHandle | None = None @callback - def run_action(job: HassJob[Awaitable[None] | None]) -> None: + def run_action(job: HassJob[[datetime], Awaitable[None] | None]) -> None: """Call the action.""" nonlocal cancel_callback @@ -1324,7 +1332,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim def async_call_later( hass: HomeAssistant, delay: float | timedelta, - action: HassJob[Awaitable[None] | None] + action: HassJob[[datetime], Awaitable[None] | None] | Callable[[datetime], Awaitable[None] | None], ) -> CALLBACK_TYPE: """Add a listener that is called in .""" @@ -1345,7 +1353,7 @@ def async_track_time_interval( ) -> CALLBACK_TYPE: """Add a listener that fires repetitively at every timedelta interval.""" remove: CALLBACK_TYPE - interval_listener_job: HassJob[None] + interval_listener_job: HassJob[[datetime], None] job = HassJob(action) @@ -1382,7 +1390,7 @@ class SunListener: """Helper class to help listen to sun events.""" hass: HomeAssistant = attr.ib() - job: HassJob[Awaitable[None] | None] = attr.ib() + job: HassJob[[], Awaitable[None] | None] = attr.ib() event: str = attr.ib() offset: timedelta | None = attr.ib() _unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None)