From 250af90acbc2a68dc2dc05b1b23ee2f8686ddc4c Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 7 Jan 2022 08:01:27 +0100 Subject: [PATCH] Improve callable typing [helpers.event] (#63543) --- homeassistant/components/flux_led/__init__.py | 2 +- homeassistant/components/netatmo/__init__.py | 9 +++- .../components/netatmo/data_handler.py | 4 +- homeassistant/components/wemo/__init__.py | 7 ++- homeassistant/helpers/event.py | 45 ++++++++++--------- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/homeassistant/components/flux_led/__init__.py b/homeassistant/components/flux_led/__init__.py index 6c126ec633a..81fabfe61f9 100644 --- a/homeassistant/components/flux_led/__init__.py +++ b/homeassistant/components/flux_led/__init__.py @@ -135,7 +135,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await device.async_set_time() await _async_sync_time() # set at startup - entry.async_on_unload(async_track_time_change(hass, _async_sync_time, 2, 40, 30)) # type: ignore[arg-type] + entry.async_on_unload(async_track_time_change(hass, _async_sync_time, 2, 40, 30)) return True diff --git a/homeassistant/components/netatmo/__init__.py b/homeassistant/components/netatmo/__init__.py index 530e4608621..2ada6577df4 100644 --- a/homeassistant/components/netatmo/__init__.py +++ b/homeassistant/components/netatmo/__init__.py @@ -1,6 +1,7 @@ """The Netatmo integration.""" from __future__ import annotations +from datetime import datetime from http import HTTPStatus import logging import secrets @@ -150,7 +151,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: _webhook_retries = 0 - async def unregister_webhook(call_or_event: ServiceCall | Event | None) -> None: + async def unregister_webhook( + call_or_event_or_dt: ServiceCall | Event | datetime | None, + ) -> None: if CONF_WEBHOOK_ID not in entry.data: return _LOGGER.debug("Unregister Netatmo webhook (%s)", entry.data[CONF_WEBHOOK_ID]) @@ -172,7 +175,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: _webhook_retries += 1 async_call_later(hass, 30, register_webhook) - async def register_webhook(call_or_event: ServiceCall | Event | None) -> None: + async def register_webhook( + call_or_event_or_dt: ServiceCall | Event | datetime | None, + ) -> None: if CONF_WEBHOOK_ID not in entry.data: data = {**entry.data, CONF_WEBHOOK_ID: secrets.token_hex()} hass.config_entries.async_update_entry(entry, data=data) diff --git a/homeassistant/components/netatmo/data_handler.py b/homeassistant/components/netatmo/data_handler.py index 7a97ec3748f..ace5934adbd 100644 --- a/homeassistant/components/netatmo/data_handler.py +++ b/homeassistant/components/netatmo/data_handler.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections import deque from dataclasses import dataclass -from datetime import timedelta +from datetime import datetime, timedelta from itertools import islice import logging from time import time @@ -105,7 +105,7 @@ class NetatmoDataHandler: ) ) - async def async_update(self, event_time: timedelta) -> None: + async def async_update(self, event_time: datetime) -> None: """ Update device. diff --git a/homeassistant/components/wemo/__init__.py b/homeassistant/components/wemo/__init__.py index 8d75b9bddae..31a34befda8 100644 --- a/homeassistant/components/wemo/__init__.py +++ b/homeassistant/components/wemo/__init__.py @@ -2,8 +2,9 @@ from __future__ import annotations from collections.abc import Sequence +from datetime import datetime import logging -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import pywemo import voluptuous as vol @@ -197,7 +198,9 @@ class WemoDiscovery: self._scan_delay = 0 self._static_config = static_config - async def async_discover_and_schedule(self, *_: tuple[Any]) -> None: + async def async_discover_and_schedule( + self, event_time: datetime | None = None + ) -> None: """Periodically scan the network looking for WeMo devices.""" _LOGGER.debug("Scanning network for WeMo devices") try: diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 24b0b66bc97..7872d3ead31 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -12,6 +12,7 @@ import time from typing import Any, Callable, List, Union, cast import attr +from typing_extensions import Concatenate, ParamSpec from homeassistant.const import ( ATTR_ENTITY_ID, @@ -61,6 +62,8 @@ _ENTITIES_LISTENER = "entities" _LOGGER = logging.getLogger(__name__) +_P = ParamSpec("_P") + @dataclass class TrackStates: @@ -110,20 +113,20 @@ class TrackTemplateResult: def threaded_listener_factory( - async_factory: Callable[..., Any] -) -> Callable[..., CALLBACK_TYPE]: + async_factory: Callable[Concatenate[HomeAssistant, _P], Any] # type: ignore[misc] +) -> Callable[Concatenate[HomeAssistant, _P], CALLBACK_TYPE]: # type: ignore[misc] """Convert an async event helper to a threaded one.""" @ft.wraps(async_factory) - def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE: + def factory( + hass: HomeAssistant, *args: _P.args, **kwargs: _P.kwargs + ) -> CALLBACK_TYPE: """Call async event helper safely.""" - hass = args[0] - if not isinstance(hass, HomeAssistant): raise TypeError("First parameter needs to be a hass instance") async_remove = run_callback_threadsafe( - hass.loop, ft.partial(async_factory, *args, **kwargs) + hass.loop, ft.partial(async_factory, hass, *args, **kwargs) ).result() def remove() -> None: @@ -233,7 +236,7 @@ def async_track_state_change_event( hass: HomeAssistant, entity_ids: str | Iterable[str], action: Callable[[Event], Any], -) -> Callable[[], None]: +) -> CALLBACK_TYPE: """Track specific state change events indexed by entity_id. Unlike async_track_state_change, async_track_state_change_event @@ -329,7 +332,7 @@ def async_track_entity_registry_updated_event( hass: HomeAssistant, entity_ids: str | Iterable[str], action: Callable[[Event], Any], -) -> Callable[[], None]: +) -> CALLBACK_TYPE: """Track specific entity registry updated events indexed by entity_id. Similar to async_track_state_change_event. @@ -414,7 +417,7 @@ def async_track_state_added_domain( hass: HomeAssistant, domains: str | Iterable[str], action: Callable[[Event], Any], -) -> Callable[[], None]: +) -> CALLBACK_TYPE: """Track state change events when an entity is added to domains.""" if not (domains := _async_string_to_lower_list(domains)): return _remove_empty_listener @@ -466,7 +469,7 @@ def async_track_state_removed_domain( hass: HomeAssistant, domains: str | Iterable[str], action: Callable[[Event], Any], -) -> Callable[[], None]: +) -> CALLBACK_TYPE: """Track state change events when an entity is removed from domains.""" if not (domains := _async_string_to_lower_list(domains)): return _remove_empty_listener @@ -680,7 +683,7 @@ def async_track_template( template: Template, action: Callable[[str, State | None, State | None], Awaitable[None] | None], variables: TemplateVarsType | None = None, -) -> Callable[[], None]: +) -> CALLBACK_TYPE: """Add a listener that fires when a a template evaluates to 'true'. Listen for the result of the template becoming true, or a true-like @@ -1152,7 +1155,7 @@ def async_track_template_result( def async_track_same_state( hass: HomeAssistant, period: timedelta, - action: Callable[..., Awaitable[None] | None], + action: Callable[[], Awaitable[None] | None], async_check_same_func: Callable[[str, State | None, State | None], bool], entity_ids: str | Iterable[str] = MATCH_ALL, ) -> CALLBACK_TYPE: @@ -1221,7 +1224,7 @@ track_same_state = threaded_listener_factory(async_track_same_state) @bind_hass def async_track_point_in_time( hass: HomeAssistant, - action: HassJob | Callable[..., Awaitable[None] | None], + action: HassJob | Callable[[datetime], Awaitable[None] | None], point_in_time: datetime, ) -> CALLBACK_TYPE: """Add a listener that fires once after a specific point in time.""" @@ -1242,7 +1245,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time) @bind_hass def async_track_point_in_utc_time( hass: HomeAssistant, - action: HassJob | Callable[..., Awaitable[None] | None], + action: HassJob | Callable[[datetime], Awaitable[None] | None], point_in_time: datetime, ) -> CALLBACK_TYPE: """Add a listener that fires once after a specific point in UTC time.""" @@ -1294,7 +1297,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim def async_call_later( hass: HomeAssistant, delay: float | timedelta, - action: HassJob | Callable[..., Awaitable[None] | None], + action: HassJob | Callable[[datetime], Awaitable[None] | None], ) -> CALLBACK_TYPE: """Add a listener that is called in .""" if not isinstance(delay, timedelta): @@ -1309,7 +1312,7 @@ call_later = threaded_listener_factory(async_call_later) @bind_hass def async_track_time_interval( hass: HomeAssistant, - action: Callable[..., Awaitable[None] | None], + action: Callable[[datetime], Awaitable[None] | None], interval: timedelta, ) -> CALLBACK_TYPE: """Add a listener that fires repetitively at every timedelta interval.""" @@ -1409,7 +1412,7 @@ class SunListener: @callback @bind_hass def async_track_sunrise( - hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None + hass: HomeAssistant, action: Callable[[], None], offset: timedelta | None = None ) -> CALLBACK_TYPE: """Add a listener that will fire a specified offset from sunrise daily.""" listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset) @@ -1423,7 +1426,7 @@ track_sunrise = threaded_listener_factory(async_track_sunrise) @callback @bind_hass def async_track_sunset( - hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None + hass: HomeAssistant, action: Callable[[], None], offset: timedelta | None = None ) -> CALLBACK_TYPE: """Add a listener that will fire a specified offset from sunset daily.""" listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset) @@ -1441,7 +1444,7 @@ time_tracker_utcnow = dt_util.utcnow @bind_hass def async_track_utc_time_change( hass: HomeAssistant, - action: Callable[..., Awaitable[None] | None], + action: Callable[[datetime], Awaitable[None] | None], hour: Any | None = None, minute: Any | None = None, second: Any | None = None, @@ -1456,7 +1459,7 @@ def async_track_utc_time_change( @callback def time_change_listener(event: Event) -> None: """Fire every time event that comes in.""" - hass.async_run_hass_job(job, event.data[ATTR_NOW]) + hass.async_run_hass_job(job, cast(datetime, event.data[ATTR_NOW])) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) @@ -1507,7 +1510,7 @@ track_utc_time_change = threaded_listener_factory(async_track_utc_time_change) @bind_hass def async_track_time_change( hass: HomeAssistant, - action: Callable[..., None], + action: Callable[[datetime], Awaitable[None] | None], hour: Any | None = None, minute: Any | None = None, second: Any | None = None,