Improve callable typing [helpers.event] (#63543)

This commit is contained in:
Marc Mueller 2022-01-07 08:01:27 +01:00 committed by GitHub
parent ad68d0795e
commit 250af90acb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 28 deletions

View file

@ -135,7 +135,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await device.async_set_time() await device.async_set_time()
await _async_sync_time() # set at startup await _async_sync_time() # set at startup
entry.async_on_unload(async_track_time_change(hass, _async_sync_time, 2, 40, 30)) # type: ignore[arg-type] entry.async_on_unload(async_track_time_change(hass, _async_sync_time, 2, 40, 30))
return True return True

View file

@ -1,6 +1,7 @@
"""The Netatmo integration.""" """The Netatmo integration."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
import logging import logging
import secrets import secrets
@ -150,7 +151,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
_webhook_retries = 0 _webhook_retries = 0
async def unregister_webhook(call_or_event: ServiceCall | Event | None) -> None: async def unregister_webhook(
call_or_event_or_dt: ServiceCall | Event | datetime | None,
) -> None:
if CONF_WEBHOOK_ID not in entry.data: if CONF_WEBHOOK_ID not in entry.data:
return return
_LOGGER.debug("Unregister Netatmo webhook (%s)", entry.data[CONF_WEBHOOK_ID]) _LOGGER.debug("Unregister Netatmo webhook (%s)", entry.data[CONF_WEBHOOK_ID])
@ -172,7 +175,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
_webhook_retries += 1 _webhook_retries += 1
async_call_later(hass, 30, register_webhook) async_call_later(hass, 30, register_webhook)
async def register_webhook(call_or_event: ServiceCall | Event | None) -> None: async def register_webhook(
call_or_event_or_dt: ServiceCall | Event | datetime | None,
) -> None:
if CONF_WEBHOOK_ID not in entry.data: if CONF_WEBHOOK_ID not in entry.data:
data = {**entry.data, CONF_WEBHOOK_ID: secrets.token_hex()} data = {**entry.data, CONF_WEBHOOK_ID: secrets.token_hex()}
hass.config_entries.async_update_entry(entry, data=data) hass.config_entries.async_update_entry(entry, data=data)

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import datetime, timedelta
from itertools import islice from itertools import islice
import logging import logging
from time import time from time import time
@ -105,7 +105,7 @@ class NetatmoDataHandler:
) )
) )
async def async_update(self, event_time: timedelta) -> None: async def async_update(self, event_time: datetime) -> None:
""" """
Update device. Update device.

View file

@ -2,8 +2,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime
import logging import logging
from typing import Any, Optional, Tuple from typing import Optional, Tuple
import pywemo import pywemo
import voluptuous as vol import voluptuous as vol
@ -197,7 +198,9 @@ class WemoDiscovery:
self._scan_delay = 0 self._scan_delay = 0
self._static_config = static_config self._static_config = static_config
async def async_discover_and_schedule(self, *_: tuple[Any]) -> None: async def async_discover_and_schedule(
self, event_time: datetime | None = None
) -> None:
"""Periodically scan the network looking for WeMo devices.""" """Periodically scan the network looking for WeMo devices."""
_LOGGER.debug("Scanning network for WeMo devices") _LOGGER.debug("Scanning network for WeMo devices")
try: try:

View file

@ -12,6 +12,7 @@ import time
from typing import Any, Callable, List, Union, cast from typing import Any, Callable, List, Union, cast
import attr import attr
from typing_extensions import Concatenate, ParamSpec
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
@ -61,6 +62,8 @@ _ENTITIES_LISTENER = "entities"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_P = ParamSpec("_P")
@dataclass @dataclass
class TrackStates: class TrackStates:
@ -110,20 +113,20 @@ class TrackTemplateResult:
def threaded_listener_factory( def threaded_listener_factory(
async_factory: Callable[..., Any] async_factory: Callable[Concatenate[HomeAssistant, _P], Any] # type: ignore[misc]
) -> Callable[..., CALLBACK_TYPE]: ) -> Callable[Concatenate[HomeAssistant, _P], CALLBACK_TYPE]: # type: ignore[misc]
"""Convert an async event helper to a threaded one.""" """Convert an async event helper to a threaded one."""
@ft.wraps(async_factory) @ft.wraps(async_factory)
def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE: def factory(
hass: HomeAssistant, *args: _P.args, **kwargs: _P.kwargs
) -> CALLBACK_TYPE:
"""Call async event helper safely.""" """Call async event helper safely."""
hass = args[0]
if not isinstance(hass, HomeAssistant): if not isinstance(hass, HomeAssistant):
raise TypeError("First parameter needs to be a hass instance") raise TypeError("First parameter needs to be a hass instance")
async_remove = run_callback_threadsafe( async_remove = run_callback_threadsafe(
hass.loop, ft.partial(async_factory, *args, **kwargs) hass.loop, ft.partial(async_factory, hass, *args, **kwargs)
).result() ).result()
def remove() -> None: def remove() -> None:
@ -233,7 +236,7 @@ def async_track_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: str | Iterable[str], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> CALLBACK_TYPE:
"""Track specific state change events indexed by entity_id. """Track specific state change events indexed by entity_id.
Unlike async_track_state_change, async_track_state_change_event Unlike async_track_state_change, async_track_state_change_event
@ -329,7 +332,7 @@ def async_track_entity_registry_updated_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: str | Iterable[str], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> CALLBACK_TYPE:
"""Track specific entity registry updated events indexed by entity_id. """Track specific entity registry updated events indexed by entity_id.
Similar to async_track_state_change_event. Similar to async_track_state_change_event.
@ -414,7 +417,7 @@ def async_track_state_added_domain(
hass: HomeAssistant, hass: HomeAssistant,
domains: str | Iterable[str], domains: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains.""" """Track state change events when an entity is added to domains."""
if not (domains := _async_string_to_lower_list(domains)): if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener return _remove_empty_listener
@ -466,7 +469,7 @@ def async_track_state_removed_domain(
hass: HomeAssistant, hass: HomeAssistant,
domains: str | Iterable[str], domains: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> CALLBACK_TYPE:
"""Track state change events when an entity is removed from domains.""" """Track state change events when an entity is removed from domains."""
if not (domains := _async_string_to_lower_list(domains)): if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener return _remove_empty_listener
@ -680,7 +683,7 @@ def async_track_template(
template: Template, template: Template,
action: Callable[[str, State | None, State | None], Awaitable[None] | None], action: Callable[[str, State | None, State | None], Awaitable[None] | None],
variables: TemplateVarsType | None = None, variables: TemplateVarsType | None = None,
) -> Callable[[], None]: ) -> CALLBACK_TYPE:
"""Add a listener that fires when a a template evaluates to 'true'. """Add a listener that fires when a a template evaluates to 'true'.
Listen for the result of the template becoming true, or a true-like Listen for the result of the template becoming true, or a true-like
@ -1152,7 +1155,7 @@ def async_track_template_result(
def async_track_same_state( def async_track_same_state(
hass: HomeAssistant, hass: HomeAssistant,
period: timedelta, period: timedelta,
action: Callable[..., Awaitable[None] | None], action: Callable[[], Awaitable[None] | None],
async_check_same_func: Callable[[str, State | None, State | None], bool], async_check_same_func: Callable[[str, State | None, State | None], bool],
entity_ids: str | Iterable[str] = MATCH_ALL, entity_ids: str | Iterable[str] = MATCH_ALL,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@ -1221,7 +1224,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@bind_hass @bind_hass
def async_track_point_in_time( def async_track_point_in_time(
hass: HomeAssistant, hass: HomeAssistant,
action: HassJob | Callable[..., Awaitable[None] | None], action: HassJob | Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in time.""" """Add a listener that fires once after a specific point in time."""
@ -1242,7 +1245,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@bind_hass @bind_hass
def async_track_point_in_utc_time( def async_track_point_in_utc_time(
hass: HomeAssistant, hass: HomeAssistant,
action: HassJob | Callable[..., Awaitable[None] | None], action: HassJob | Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime, point_in_time: datetime,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in UTC time.""" """Add a listener that fires once after a specific point in UTC time."""
@ -1294,7 +1297,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
def async_call_later( def async_call_later(
hass: HomeAssistant, hass: HomeAssistant,
delay: float | timedelta, delay: float | timedelta,
action: HassJob | Callable[..., Awaitable[None] | None], action: HassJob | Callable[[datetime], Awaitable[None] | None],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>.""" """Add a listener that is called in <delay>."""
if not isinstance(delay, timedelta): if not isinstance(delay, timedelta):
@ -1309,7 +1312,7 @@ call_later = threaded_listener_factory(async_call_later)
@bind_hass @bind_hass
def async_track_time_interval( def async_track_time_interval(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., Awaitable[None] | None], action: Callable[[datetime], Awaitable[None] | None],
interval: timedelta, interval: timedelta,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that fires repetitively at every timedelta interval.""" """Add a listener that fires repetitively at every timedelta interval."""
@ -1409,7 +1412,7 @@ class SunListener:
@callback @callback
@bind_hass @bind_hass
def async_track_sunrise( def async_track_sunrise(
hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None hass: HomeAssistant, action: Callable[[], None], offset: timedelta | None = None
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunrise daily.""" """Add a listener that will fire a specified offset from sunrise daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset) listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset)
@ -1423,7 +1426,7 @@ track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback @callback
@bind_hass @bind_hass
def async_track_sunset( def async_track_sunset(
hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None hass: HomeAssistant, action: Callable[[], None], offset: timedelta | None = None
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunset daily.""" """Add a listener that will fire a specified offset from sunset daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset) listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset)
@ -1441,7 +1444,7 @@ time_tracker_utcnow = dt_util.utcnow
@bind_hass @bind_hass
def async_track_utc_time_change( def async_track_utc_time_change(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., Awaitable[None] | None], action: Callable[[datetime], Awaitable[None] | None],
hour: Any | None = None, hour: Any | None = None,
minute: Any | None = None, minute: Any | None = None,
second: Any | None = None, second: Any | None = None,
@ -1456,7 +1459,7 @@ def async_track_utc_time_change(
@callback @callback
def time_change_listener(event: Event) -> None: def time_change_listener(event: Event) -> None:
"""Fire every time event that comes in.""" """Fire every time event that comes in."""
hass.async_run_hass_job(job, event.data[ATTR_NOW]) hass.async_run_hass_job(job, cast(datetime, event.data[ATTR_NOW]))
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
@ -1507,7 +1510,7 @@ track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
@bind_hass @bind_hass
def async_track_time_change( def async_track_time_change(
hass: HomeAssistant, hass: HomeAssistant,
action: Callable[..., None], action: Callable[[datetime], Awaitable[None] | None],
hour: Any | None = None, hour: Any | None = None,
minute: Any | None = None, minute: Any | None = None,
second: Any | None = None, second: Any | None = None,