Type hint improvements (#28260)
* Add and improve core and config_entries type hints * Complete and improve config_entries type hints * More entity registry type hints * Complete helpers.event type hints
This commit is contained in:
parent
f7a64019b6
commit
f88ead597a
9 changed files with 135 additions and 89 deletions
|
@ -1,13 +1,14 @@
|
|||
"""Helpers for listening to events."""
|
||||
from datetime import datetime, timedelta
|
||||
import functools as ft
|
||||
from typing import Any, Callable, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union, cast
|
||||
|
||||
import attr
|
||||
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.helpers.sun import get_astral_event_next
|
||||
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State
|
||||
from homeassistant.const import (
|
||||
ATTR_NOW,
|
||||
EVENT_STATE_CHANGED,
|
||||
|
@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util
|
|||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
# PyLint does not like the use of threaded_listener_factory
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
def threaded_listener_factory(async_factory):
|
||||
def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
|
||||
"""Convert an async event helper to a threaded one."""
|
||||
|
||||
@ft.wraps(async_factory)
|
||||
def factory(*args, **kwargs):
|
||||
def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE:
|
||||
"""Call async event helper safely."""
|
||||
hass = args[0]
|
||||
|
||||
|
@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory):
|
|||
hass.loop, ft.partial(async_factory, *args, **kwargs)
|
||||
).result()
|
||||
|
||||
def remove():
|
||||
def remove() -> None:
|
||||
"""Threadsafe removal."""
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
|
@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory):
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None):
|
||||
def async_track_state_change(
|
||||
hass: HomeAssistant,
|
||||
entity_ids: Union[str, Iterable[str]],
|
||||
action: Callable[[str, State, State], None],
|
||||
from_state: Union[None, str, Iterable[str]] = None,
|
||||
to_state: Union[None, str, Iterable[str]] = None,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track specific state changes.
|
||||
|
||||
entity_ids, from_state and to_state can be string or list.
|
||||
|
@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state
|
|||
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
|
||||
|
||||
@callback
|
||||
def state_change_listener(event):
|
||||
def state_change_listener(event: Event) -> None:
|
||||
"""Handle specific state changes."""
|
||||
if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids:
|
||||
if (
|
||||
entity_ids != MATCH_ALL
|
||||
and cast(str, event.data.get("entity_id")) not in entity_ids
|
||||
):
|
||||
return
|
||||
|
||||
old_state = event.data.get("old_state")
|
||||
|
@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change)
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_template(hass, template, action, variables=None):
|
||||
def async_track_template(
|
||||
hass: HomeAssistant,
|
||||
template: Template,
|
||||
action: Callable[[str, State, State], None],
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that track state changes with template condition."""
|
||||
from . import condition
|
||||
|
||||
|
@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None):
|
|||
already_triggered = False
|
||||
|
||||
@callback
|
||||
def template_condition_listener(entity_id, from_s, to_s):
|
||||
def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None:
|
||||
"""Check if condition is correct and run action."""
|
||||
nonlocal already_triggered
|
||||
template_result = condition.async_template(hass, template, variables)
|
||||
|
@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template)
|
|||
@callback
|
||||
@bind_hass
|
||||
def async_track_same_state(
|
||||
hass, period, action, async_check_same_func, entity_ids=MATCH_ALL
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
period: timedelta,
|
||||
action: Callable[..., None],
|
||||
async_check_same_func: Callable[[str, State, State], bool],
|
||||
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track the state of entities for a period and run an action.
|
||||
|
||||
If async_check_func is None it use the state of orig_value.
|
||||
Without entity_ids we track all state changes.
|
||||
"""
|
||||
async_remove_state_for_cancel = None
|
||||
async_remove_state_for_listener = None
|
||||
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
|
||||
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
@callback
|
||||
def clear_listener():
|
||||
def clear_listener() -> None:
|
||||
"""Clear all unsub listener."""
|
||||
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
|
||||
|
||||
|
@ -157,7 +175,7 @@ def async_track_same_state(
|
|||
async_remove_state_for_cancel = None
|
||||
|
||||
@callback
|
||||
def state_for_listener(now):
|
||||
def state_for_listener(now: Any) -> None:
|
||||
"""Fire on state changes after a delay and calls action."""
|
||||
nonlocal async_remove_state_for_listener
|
||||
async_remove_state_for_listener = None
|
||||
|
@ -165,7 +183,9 @@ def async_track_same_state(
|
|||
hass.async_run_job(action)
|
||||
|
||||
@callback
|
||||
def state_for_cancel_listener(entity, from_state, to_state):
|
||||
def state_for_cancel_listener(
|
||||
entity: str, from_state: State, to_state: State
|
||||
) -> None:
|
||||
"""Fire on changes and cancel for listener if changed."""
|
||||
if not async_check_same_func(entity, from_state, to_state):
|
||||
clear_listener()
|
||||
|
@ -193,7 +213,7 @@ def async_track_point_in_time(
|
|||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@callback
|
||||
def utc_converter(utc_now):
|
||||
def utc_converter(utc_now: datetime) -> None:
|
||||
"""Convert passed in UTC now to local now."""
|
||||
hass.async_run_job(action, dt_util.as_local(utc_now))
|
||||
|
||||
|
@ -213,7 +233,7 @@ def async_track_point_in_utc_time(
|
|||
point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@callback
|
||||
def point_in_time_listener(event):
|
||||
def point_in_time_listener(event: Event) -> None:
|
||||
"""Listen for matching time_changed events."""
|
||||
now = event.data[ATTR_NOW]
|
||||
|
||||
|
@ -225,7 +245,7 @@ def async_track_point_in_utc_time(
|
|||
# available to execute this listener it might occur that the
|
||||
# listener gets lined up twice to be executed. This will make
|
||||
# sure the second time it does nothing.
|
||||
point_in_time_listener.run = True
|
||||
setattr(point_in_time_listener, "run", True)
|
||||
async_unsub()
|
||||
|
||||
hass.async_run_job(action, now)
|
||||
|
@ -260,12 +280,12 @@ def async_track_time_interval(
|
|||
"""Add a listener that fires repetitively at every timedelta interval."""
|
||||
remove = None
|
||||
|
||||
def next_interval():
|
||||
def next_interval() -> datetime:
|
||||
"""Return the next interval."""
|
||||
return dt_util.utcnow() + interval
|
||||
|
||||
@callback
|
||||
def interval_listener(now):
|
||||
def interval_listener(now: datetime) -> None:
|
||||
"""Handle elapsed intervals."""
|
||||
nonlocal remove
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
|
@ -273,7 +293,7 @@ def async_track_time_interval(
|
|||
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
|
||||
def remove_listener():
|
||||
def remove_listener() -> None:
|
||||
"""Remove interval listener."""
|
||||
remove()
|
||||
|
||||
|
@ -387,7 +407,7 @@ def async_track_utc_time_change(
|
|||
if all(val is None for val in (hour, minute, second)):
|
||||
|
||||
@callback
|
||||
def time_change_listener(event):
|
||||
def time_change_listener(event: Event) -> None:
|
||||
"""Fire every time event that comes in."""
|
||||
hass.async_run_job(action, event.data[ATTR_NOW])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue