Add generic parameters to HassJob (#70973)

This commit is contained in:
Marc Mueller 2022-05-30 09:22:37 +02:00 committed by GitHub
parent 6bc09741c7
commit b417ae72e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 22 deletions

View file

@ -37,6 +37,7 @@ from typing import (
) )
from urllib.parse import urlparse from urllib.parse import urlparse
from typing_extensions import ParamSpec
import voluptuous as vol import voluptuous as vol
import yarl import yarl
@ -98,6 +99,7 @@ block_async_io.enable()
_T = TypeVar("_T") _T = TypeVar("_T")
_R = TypeVar("_R") _R = TypeVar("_R")
_R_co = TypeVar("_R_co", covariant=True) _R_co = TypeVar("_R_co", covariant=True)
_P = ParamSpec("_P")
# Internal; not helpers.typing.UNDEFINED due to circular dependency # Internal; not helpers.typing.UNDEFINED due to circular dependency
_UNDEF: dict[Any, Any] = {} _UNDEF: dict[Any, Any] = {}
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
@ -182,7 +184,7 @@ class HassJobType(enum.Enum):
Executor = 3 Executor = 3
class HassJob(Generic[_R_co]): class HassJob(Generic[_P, _R_co]):
"""Represent a job to be run later. """Represent a job to be run later.
We check the callable type in advance We check the callable type in advance
@ -192,7 +194,7 @@ class HassJob(Generic[_R_co]):
__slots__ = ("job_type", "target") __slots__ = ("job_type", "target")
def __init__(self, target: Callable[..., _R_co]) -> None: def __init__(self, target: Callable[_P, _R_co]) -> None:
"""Create a job object.""" """Create a job object."""
self.target = target self.target = target
self.job_type = _get_hassjob_callable_job_type(target) self.job_type = _get_hassjob_callable_job_type(target)
@ -416,20 +418,20 @@ class HomeAssistant:
@overload @overload
@callback @callback
def async_add_hass_job( def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@overload @overload
@callback @callback
def async_add_hass_job( def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@callback @callback
def async_add_hass_job( def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
"""Add a HassJob from within the event loop. """Add a HassJob from within the event loop.
@ -512,20 +514,20 @@ class HomeAssistant:
@overload @overload
@callback @callback
def async_run_hass_job( def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@overload @overload
@callback @callback
def async_run_hass_job( def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@callback @callback
def async_run_hass_job( def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
"""Run a HassJob from within the event loop. """Run a HassJob from within the event loop.
@ -814,7 +816,7 @@ class Event:
class _FilterableJob(NamedTuple): class _FilterableJob(NamedTuple):
"""Event listener job to be executed with optional filter.""" """Event listener job to be executed with optional filter."""
job: HassJob[None | Awaitable[None]] job: HassJob[[Event], None | Awaitable[None]]
event_filter: Callable[[Event], bool] | None event_filter: Callable[[Event], bool] | None
run_immediately: bool run_immediately: bool

View file

@ -258,7 +258,9 @@ def _async_track_state_change_event(
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""async_track_state_change_event without lowercasing.""" """async_track_state_change_event without lowercasing."""
entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {}) entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_CHANGE_CALLBACKS, {}
)
if TRACK_STATE_CHANGE_LISTENER not in hass.data: if TRACK_STATE_CHANGE_LISTENER not in hass.data:
@ -319,10 +321,10 @@ def _async_remove_indexed_listeners(
data_key: str, data_key: str,
listener_key: str, listener_key: str,
storage_keys: Iterable[str], storage_keys: Iterable[str],
job: HassJob[Any], job: HassJob[[Event], Any],
) -> None: ) -> None:
"""Remove a listener.""" """Remove a listener."""
callbacks = hass.data[data_key] callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data[data_key]
for storage_key in storage_keys: for storage_key in storage_keys:
callbacks[storage_key].remove(job) callbacks[storage_key].remove(job)
@ -347,7 +349,9 @@ def async_track_entity_registry_updated_event(
if not (entity_ids := _async_string_to_lower_list(entity_ids)): if not (entity_ids := _async_string_to_lower_list(entity_ids)):
return _remove_empty_listener return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}) entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}
)
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
@ -401,7 +405,7 @@ def async_track_entity_registry_updated_event(
@callback @callback
def _async_dispatch_domain_event( def _async_dispatch_domain_event(
hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[Any]]] hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[[Event], Any]]]
) -> None: ) -> None:
domain = split_entity_id(event.data["entity_id"])[0] domain = split_entity_id(event.data["entity_id"])[0]
@ -438,7 +442,9 @@ def _async_track_state_added_domain(
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""async_track_state_added_domain without lowercasing.""" """async_track_state_added_domain without lowercasing."""
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}) domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}
)
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
@ -490,7 +496,9 @@ def async_track_state_removed_domain(
if not (domains := _async_string_to_lower_list(domains)): if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}) domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}
)
if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data:
@ -1249,7 +1257,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@bind_hass @bind_hass
def async_track_point_in_time( def async_track_point_in_time(
hass: HomeAssistant, hass: HomeAssistant,
action: HassJob[Awaitable[None] | None] action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None], | Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@ -1271,7 +1279,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@bind_hass @bind_hass
def async_track_point_in_utc_time( def async_track_point_in_utc_time(
hass: HomeAssistant, hass: HomeAssistant,
action: HassJob[Awaitable[None] | None] action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None], | Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@ -1284,7 +1292,7 @@ def async_track_point_in_utc_time(
cancel_callback: asyncio.TimerHandle | None = None cancel_callback: asyncio.TimerHandle | None = None
@callback @callback
def run_action(job: HassJob[Awaitable[None] | None]) -> None: def run_action(job: HassJob[[datetime], Awaitable[None] | None]) -> None:
"""Call the action.""" """Call the action."""
nonlocal cancel_callback nonlocal cancel_callback
@ -1324,7 +1332,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
def async_call_later( def async_call_later(
hass: HomeAssistant, hass: HomeAssistant,
delay: float | timedelta, delay: float | timedelta,
action: HassJob[Awaitable[None] | None] action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None], | Callable[[datetime], Awaitable[None] | None],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>.""" """Add a listener that is called in <delay>."""
@ -1345,7 +1353,7 @@ def async_track_time_interval(
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires repetitively at every timedelta interval.""" """Add a listener that fires repetitively at every timedelta interval."""
remove: CALLBACK_TYPE remove: CALLBACK_TYPE
interval_listener_job: HassJob[None] interval_listener_job: HassJob[[datetime], None]
job = HassJob(action) job = HassJob(action)
@ -1382,7 +1390,7 @@ class SunListener:
"""Helper class to help listen to sun events.""" """Helper class to help listen to sun events."""
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
job: HassJob[Awaitable[None] | None] = attr.ib() job: HassJob[[], Awaitable[None] | None] = attr.ib()
event: str = attr.ib() event: str = attr.ib()
offset: timedelta | None = attr.ib() offset: timedelta | None = attr.ib()
_unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None) _unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None)