Type hint improvements (#28260)

* Add and improve core and config_entries type hints

* Complete and improve config_entries type hints

* More entity registry type hints

* Complete helpers.event type hints
This commit is contained in:
Ville Skyttä 2019-10-28 22:36:26 +02:00 committed by Paulus Schoutsen
parent f7a64019b6
commit f88ead597a
9 changed files with 135 additions and 89 deletions

View file

@ -1,13 +1,14 @@
"""Helpers for listening to events."""
from datetime import datetime, timedelta
import functools as ft
from typing import Any, Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union, cast
import attr
from homeassistant.loader import bind_hass
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event
from homeassistant.helpers.template import Template
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State
from homeassistant.const import (
ATTR_NOW,
EVENT_STATE_CHANGED,
@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# PyLint does not like the use of threaded_listener_factory
# pylint: disable=invalid-name
def threaded_listener_factory(async_factory):
def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
"""Convert an async event helper to a threaded one."""
@ft.wraps(async_factory)
def factory(*args, **kwargs):
def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE:
"""Call async event helper safely."""
hass = args[0]
@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory):
hass.loop, ft.partial(async_factory, *args, **kwargs)
).result()
def remove():
def remove() -> None:
"""Threadsafe removal."""
run_callback_threadsafe(hass.loop, async_remove).result()
@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory):
@callback
@bind_hass
def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None):
def async_track_state_change(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
action: Callable[[str, State, State], None],
from_state: Union[None, str, Iterable[str]] = None,
to_state: Union[None, str, Iterable[str]] = None,
) -> CALLBACK_TYPE:
"""Track specific state changes.
entity_ids, from_state and to_state can be string or list.
@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
@callback
def state_change_listener(event):
def state_change_listener(event: Event) -> None:
"""Handle specific state changes."""
if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids:
if (
entity_ids != MATCH_ALL
and cast(str, event.data.get("entity_id")) not in entity_ids
):
return
old_state = event.data.get("old_state")
@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@callback
@bind_hass
def async_track_template(hass, template, action, variables=None):
def async_track_template(
hass: HomeAssistant,
template: Template,
action: Callable[[str, State, State], None],
variables: Optional[Dict[str, Any]] = None,
) -> CALLBACK_TYPE:
"""Add a listener that track state changes with template condition."""
from . import condition
@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None):
already_triggered = False
@callback
def template_condition_listener(entity_id, from_s, to_s):
def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None:
"""Check if condition is correct and run action."""
nonlocal already_triggered
template_result = condition.async_template(hass, template, variables)
@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template)
@callback
@bind_hass
def async_track_same_state(
hass, period, action, async_check_same_func, entity_ids=MATCH_ALL
):
hass: HomeAssistant,
period: timedelta,
action: Callable[..., None],
async_check_same_func: Callable[[str, State, State], bool],
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
) -> CALLBACK_TYPE:
"""Track the state of entities for a period and run an action.
If async_check_func is None it use the state of orig_value.
Without entity_ids we track all state changes.
"""
async_remove_state_for_cancel = None
async_remove_state_for_listener = None
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
@callback
def clear_listener():
def clear_listener() -> None:
"""Clear all unsub listener."""
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
@ -157,7 +175,7 @@ def async_track_same_state(
async_remove_state_for_cancel = None
@callback
def state_for_listener(now):
def state_for_listener(now: Any) -> None:
"""Fire on state changes after a delay and calls action."""
nonlocal async_remove_state_for_listener
async_remove_state_for_listener = None
@ -165,7 +183,9 @@ def async_track_same_state(
hass.async_run_job(action)
@callback
def state_for_cancel_listener(entity, from_state, to_state):
def state_for_cancel_listener(
entity: str, from_state: State, to_state: State
) -> None:
"""Fire on changes and cancel for listener if changed."""
if not async_check_same_func(entity, from_state, to_state):
clear_listener()
@ -193,7 +213,7 @@ def async_track_point_in_time(
utc_point_in_time = dt_util.as_utc(point_in_time)
@callback
def utc_converter(utc_now):
def utc_converter(utc_now: datetime) -> None:
"""Convert passed in UTC now to local now."""
hass.async_run_job(action, dt_util.as_local(utc_now))
@ -213,7 +233,7 @@ def async_track_point_in_utc_time(
point_in_time = dt_util.as_utc(point_in_time)
@callback
def point_in_time_listener(event):
def point_in_time_listener(event: Event) -> None:
"""Listen for matching time_changed events."""
now = event.data[ATTR_NOW]
@ -225,7 +245,7 @@ def async_track_point_in_utc_time(
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed. This will make
# sure the second time it does nothing.
point_in_time_listener.run = True
setattr(point_in_time_listener, "run", True)
async_unsub()
hass.async_run_job(action, now)
@ -260,12 +280,12 @@ def async_track_time_interval(
"""Add a listener that fires repetitively at every timedelta interval."""
remove = None
def next_interval():
def next_interval() -> datetime:
"""Return the next interval."""
return dt_util.utcnow() + interval
@callback
def interval_listener(now):
def interval_listener(now: datetime) -> None:
"""Handle elapsed intervals."""
nonlocal remove
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
@ -273,7 +293,7 @@ def async_track_time_interval(
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
def remove_listener():
def remove_listener() -> None:
"""Remove interval listener."""
remove()
@ -387,7 +407,7 @@ def async_track_utc_time_change(
if all(val is None for val in (hour, minute, second)):
@callback
def time_change_listener(event):
def time_change_listener(event: Event) -> None:
"""Fire every time event that comes in."""
hass.async_run_job(action, event.data[ATTR_NOW])