Type hint improvements (#28260)

* Add and improve core and config_entries type hints

* Complete and improve config_entries type hints

* More entity registry type hints

* Complete helpers.event type hints
This commit is contained in:
Ville Skyttä 2019-10-28 22:36:26 +02:00 committed by Paulus Schoutsen
parent f7a64019b6
commit f88ead597a
9 changed files with 135 additions and 89 deletions

View file

@ -16,7 +16,7 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import State, callback from homeassistant.core import CALLBACK_TYPE, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.helpers.typing import ConfigType, HomeAssistantType
@ -96,7 +96,7 @@ class LightGroup(light.Light):
self._effect_list: Optional[List[str]] = None self._effect_list: Optional[List[str]] = None
self._effect: Optional[str] = None self._effect: Optional[str] = None
self._supported_features: int = 0 self._supported_features: int = 0
self._async_unsub_state_changed = None self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register callbacks.""" """Register callbacks."""
@ -108,6 +108,7 @@ class LightGroup(light.Light):
"""Handle child updates.""" """Handle child updates."""
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
assert self.hass is not None
self._async_unsub_state_changed = async_track_state_change( self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener self.hass, self._entity_ids, async_state_changed_listener
) )

View file

@ -12,7 +12,7 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import State, callback from homeassistant.core import CALLBACK_TYPE, State, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
@ -56,7 +56,7 @@ class LightSwitch(Light):
self._switch_entity_id = switch_entity_id self._switch_entity_id = switch_entity_id
self._is_on = False self._is_on = False
self._available = False self._available = False
self._async_unsub_state_changed = None self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
@property @property
def name(self) -> str: def name(self) -> str:
@ -113,6 +113,7 @@ class LightSwitch(Light):
"""Handle child updates.""" """Handle child updates."""
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
assert self.hass is not None
self._async_unsub_state_changed = async_track_state_change( self._async_unsub_state_changed = async_track_state_change(
self.hass, self._switch_entity_id, async_state_changed_listener self.hass, self._switch_entity_id, async_state_changed_listener
) )

View file

@ -3,7 +3,7 @@ import asyncio
import logging import logging
import functools import functools
import uuid import uuid
from typing import Any, Callable, List, Optional, Set from typing import Any, Callable, Dict, List, Optional, Set, cast
import weakref import weakref
import attr import attr
@ -14,11 +14,11 @@ from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
from homeassistant.setup import async_setup_component, async_process_deps_reqs from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
from homeassistant.helpers.event import Event
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF = object() _UNDEF: dict = {}
SOURCE_USER = "user" SOURCE_USER = "user"
SOURCE_DISCOVERY = "discovery" SOURCE_DISCOVERY = "discovery"
@ -205,7 +205,7 @@ class ConfigEntry:
wait_time, wait_time,
) )
async def setup_again(now): async def setup_again(now: Any) -> None:
"""Run setup again.""" """Run setup again."""
self._async_cancel_retry_setup = None self._async_cancel_retry_setup = None
await self.async_setup(hass, integration=integration, tries=tries) await self.async_setup(hass, integration=integration, tries=tries)
@ -357,7 +357,7 @@ class ConfigEntry:
return lambda: self.update_listeners.remove(weak_listener) return lambda: self.update_listeners.remove(weak_listener)
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this entry.""" """Return dictionary version of this entry."""
return { return {
"entry_id": self.entry_id, "entry_id": self.entry_id,
@ -418,7 +418,7 @@ class ConfigEntries:
return list(self._entries) return list(self._entries)
return [entry for entry in self._entries if entry.domain == domain] return [entry for entry in self._entries if entry.domain == domain]
async def async_remove(self, entry_id): async def async_remove(self, entry_id: str) -> Dict[str, Any]:
"""Remove an entry.""" """Remove an entry."""
entry = self.async_get_entry(entry_id) entry = self.async_get_entry(entry_id)
@ -529,8 +529,13 @@ class ConfigEntries:
@callback @callback
def async_update_entry( def async_update_entry(
self, entry, *, data=_UNDEF, options=_UNDEF, system_options=_UNDEF self,
): entry: ConfigEntry,
*,
data: dict = _UNDEF,
options: dict = _UNDEF,
system_options: dict = _UNDEF,
) -> None:
"""Update a config entry.""" """Update a config entry."""
if data is not _UNDEF: if data is not _UNDEF:
entry.data = data entry.data = data
@ -547,7 +552,7 @@ class ConfigEntries:
self._async_schedule_save() self._async_schedule_save()
async def async_forward_entry_setup(self, entry, domain): async def async_forward_entry_setup(self, entry: ConfigEntry, domain: str) -> bool:
"""Forward the setup of an entry to a different component. """Forward the setup of an entry to a different component.
By default an entry is setup with the component it belongs to. If that By default an entry is setup with the component it belongs to. If that
@ -567,8 +572,9 @@ class ConfigEntries:
integration = await loader.async_get_integration(self.hass, domain) integration = await loader.async_get_integration(self.hass, domain)
await entry.async_setup(self.hass, integration=integration) await entry.async_setup(self.hass, integration=integration)
return True
async def async_forward_entry_unload(self, entry, domain): async def async_forward_entry_unload(self, entry: ConfigEntry, domain: str) -> bool:
"""Forward the unloading of an entry to a different component.""" """Forward the unloading of an entry to a different component."""
# It was never loaded. # It was never loaded.
if domain not in self.hass.config.components: if domain not in self.hass.config.components:
@ -578,7 +584,9 @@ class ConfigEntries:
return await entry.async_unload(self.hass, integration=integration) return await entry.async_unload(self.hass, integration=integration)
async def _async_finish_flow(self, flow, result): async def _async_finish_flow(
self, flow: "ConfigFlow", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
# Remove notification if no other discovery config entries in progress # Remove notification if no other discovery config entries in progress
if not any( if not any(
@ -611,7 +619,9 @@ class ConfigEntries:
result["result"] = entry result["result"] = entry
return result return result
async def _async_create_flow(self, handler_key, *, context, data): async def _async_create_flow(
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> "ConfigFlow":
"""Create a flow for specified handler. """Create a flow for specified handler.
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
@ -654,7 +664,7 @@ class ConfigEntries:
notification_id=DISCOVERY_NOTIFICATION_ID, notification_id=DISCOVERY_NOTIFICATION_ID,
) )
flow = handler() flow = cast(ConfigFlow, handler())
flow.init_step = source flow.init_step = source
return flow return flow
@ -663,12 +673,12 @@ class ConfigEntries:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self): def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
"""Return data to save.""" """Return data to save."""
return {"entries": [entry.as_dict() for entry in self._entries]} return {"entries": [entry.as_dict() for entry in self._entries]}
async def _old_conf_migrator(old_config): async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
"""Migrate the pre-0.73 config format to the latest version.""" """Migrate the pre-0.73 config format to the latest version."""
return {"entries": old_config} return {"entries": old_config}
@ -686,18 +696,20 @@ class ConfigFlow(data_entry_flow.FlowHandler):
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow(config_entry): def async_get_options_flow(config_entry: ConfigEntry) -> "OptionsFlow":
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
@callback @callback
def _async_current_entries(self): def _async_current_entries(self) -> List[ConfigEntry]:
"""Return current entries.""" """Return current entries."""
assert self.hass is not None
return self.hass.config_entries.async_entries(self.handler) return self.hass.config_entries.async_entries(self.handler)
@callback @callback
def _async_in_progress(self): def _async_in_progress(self) -> List[Dict]:
"""Return other in progress flows for current domain.""" """Return other in progress flows for current domain."""
assert self.hass is not None
return [ return [
flw flw
for flw in self.hass.config_entries.flow.async_progress() for flw in self.hass.config_entries.flow.async_progress()
@ -715,29 +727,33 @@ class OptionsFlowManager:
hass, self._async_create_flow, self._async_finish_flow hass, self._async_create_flow, self._async_finish_flow
) )
async def _async_create_flow(self, entry_id, *, context, data): async def _async_create_flow(
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> Optional["OptionsFlow"]:
"""Create an options flow for a config entry. """Create an options flow for a config entry.
Entry_id and flow.handler is the same thing to map entry with flow. Entry_id and flow.handler is the same thing to map entry with flow.
""" """
entry = self.hass.config_entries.async_get_entry(entry_id) entry = self.hass.config_entries.async_get_entry(entry_id)
if entry is None: if entry is None:
return return None
if entry.domain not in HANDLERS: if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
flow = HANDLERS[entry.domain].async_get_options_flow(entry) flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
return flow return flow
async def _async_finish_flow(self, flow, result): async def _async_finish_flow(
self, flow: "OptionsFlow", result: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry. Flow.handler and entry_id is the same thing to map flow with entry.
""" """
entry = self.hass.config_entries.async_get_entry(flow.handler) entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None: if entry is None:
return return None
self.hass.config_entries.async_update_entry(entry, options=result["data"]) self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True result["result"] = True
@ -747,7 +763,7 @@ class OptionsFlowManager:
class OptionsFlow(data_entry_flow.FlowHandler): class OptionsFlow(data_entry_flow.FlowHandler):
"""Base class for config option flows.""" """Base class for config option flows."""
pass handler: str
@attr.s(slots=True) @attr.s(slots=True)
@ -756,11 +772,11 @@ class SystemOptions:
disable_new_entities = attr.ib(type=bool, default=False) disable_new_entities = attr.ib(type=bool, default=False)
def update(self, *, disable_new_entities): def update(self, *, disable_new_entities: bool) -> None:
"""Update properties.""" """Update properties."""
self.disable_new_entities = disable_new_entities self.disable_new_entities = disable_new_entities
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this config entrys system options.""" """Return dictionary version of this config entrys system options."""
return {"disable_new_entities": self.disable_new_entities} return {"disable_new_entities": self.disable_new_entities}
@ -784,7 +800,7 @@ class EntityRegistryDisabledHandler:
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
) )
async def _handle_entry_updated(self, event): async def _handle_entry_updated(self, event: Event) -> None:
"""Handle entity registry entry update.""" """Handle entity registry entry update."""
if ( if (
event.data["action"] != "update" event.data["action"] != "update"
@ -811,6 +827,7 @@ class EntityRegistryDisabledHandler:
config_entry = self.hass.config_entries.async_get_entry( config_entry = self.hass.config_entries.async_get_entry(
entity_entry.config_entry_id entity_entry.config_entry_id
) )
assert config_entry is not None
if config_entry.entry_id not in self.changed and await support_entry_unload( if config_entry.entry_id not in self.changed and await support_entry_unload(
self.hass, config_entry.domain self.hass, config_entry.domain
@ -830,7 +847,7 @@ class EntityRegistryDisabledHandler:
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
) )
async def _handle_reload(self, _now): async def _handle_reload(self, _now: Any) -> None:
"""Handle a reload.""" """Handle a reload."""
self._remove_call_later = None self._remove_call_later = None
to_reload = self.changed to_reload = self.changed

View file

@ -1283,7 +1283,7 @@ class Config:
self.skip_pip: bool = False self.skip_pip: bool = False
# List of loaded components # List of loaded components
self.components: set = set() self.components: Set[str] = set()
# API (HTTP) server configuration, see components.http.ApiConfig # API (HTTP) server configuration, see components.http.ApiConfig
self.api: Optional[Any] = None self.api: Optional[Any] = None

View file

@ -1,6 +1,6 @@
"""Classes to help gather user submissions.""" """Classes to help gather user submissions."""
import logging import logging
from typing import Dict, Any, Callable, Hashable, List, Optional from typing import Dict, Any, Callable, List, Optional
import uuid import uuid
import voluptuous as vol import voluptuous as vol
from .core import callback, HomeAssistant from .core import callback, HomeAssistant
@ -58,7 +58,7 @@ class FlowManager:
] ]
async def async_init( async def async_init(
self, handler: Hashable, *, context: Optional[Dict] = None, data: Any = None self, handler: str, *, context: Optional[Dict] = None, data: Any = None
) -> Any: ) -> Any:
"""Start a configuration flow.""" """Start a configuration flow."""
if context is None: if context is None:
@ -170,7 +170,7 @@ class FlowHandler:
# Set by flow manager # Set by flow manager
flow_id: str = None # type: ignore flow_id: str = None # type: ignore
hass: Optional[HomeAssistant] = None hass: Optional[HomeAssistant] = None
handler: Optional[Hashable] = None handler: Optional[str] = None
cur_step: Optional[Dict[str, str]] = None cur_step: Optional[Dict[str, str]] = None
context: Dict context: Dict

View file

@ -399,7 +399,7 @@ class OAuth2Session:
new_token = await self.implementation.async_refresh_token(token) new_token = await self.implementation.async_refresh_token(token)
self.hass.config_entries.async_update_entry( # type: ignore self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, "token": new_token} self.config_entry, data={**self.config_entry.data, "token": new_token}
) )

View file

@ -7,15 +7,15 @@ The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the registered. Registering a new entity while a timer is in progress resets the
timer. timer.
""" """
from asyncio import Event import asyncio
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
from typing import List, Optional, cast from typing import Any, Dict, Iterable, List, Optional, cast
import attr import attr
from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
@ -24,8 +24,7 @@ from homeassistant.util.yaml import load_yaml
from .typing import HomeAssistantType from .typing import HomeAssistantType
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
PATH_REGISTRY = "entity_registry.yaml" PATH_REGISTRY = "entity_registry.yaml"
DATA_REGISTRY = "entity_registry" DATA_REGISTRY = "entity_registry"
@ -51,7 +50,7 @@ class RegistryEntry:
platform = attr.ib(type=str) platform = attr.ib(type=str)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
device_id = attr.ib(type=str, default=None) device_id = attr.ib(type=str, default=None)
config_entry_id = attr.ib(type=str, default=None) config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by = attr.ib( disabled_by = attr.ib(
type=Optional[str], type=Optional[str],
default=None, default=None,
@ -68,12 +67,12 @@ class RegistryEntry:
domain = attr.ib(type=str, init=False, repr=False) domain = attr.ib(type=str, init=False, repr=False)
@domain.default @domain.default
def _domain_default(self): def _domain_default(self) -> str:
"""Compute domain value.""" """Compute domain value."""
return split_entity_id(self.entity_id)[0] return split_entity_id(self.entity_id)[0]
@property @property
def disabled(self): def disabled(self) -> bool:
"""Return if entry is disabled.""" """Return if entry is disabled."""
return self.disabled_by is not None return self.disabled_by is not None
@ -81,17 +80,17 @@ class RegistryEntry:
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
def __init__(self, hass): def __init__(self, hass: HomeAssistantType):
"""Initialize the registry.""" """Initialize the registry."""
self.hass = hass self.hass = hass
self.entities = None self.entities: Dict[str, RegistryEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen( self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
) )
@callback @callback
def async_is_registered(self, entity_id): def async_is_registered(self, entity_id: str) -> bool:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
return entity_id in self.entities return entity_id in self.entities
@ -116,8 +115,11 @@ class EntityRegistry:
@callback @callback
def async_generate_entity_id( def async_generate_entity_id(
self, domain, suggested_object_id, known_object_ids=None self,
): domain: str,
suggested_object_id: str,
known_object_ids: Optional[Iterable[str]] = None,
) -> str:
"""Generate an entity ID that does not conflict. """Generate an entity ID that does not conflict.
Conflicts checked against registered and currently existing entities. Conflicts checked against registered and currently existing entities.
@ -195,7 +197,7 @@ class EntityRegistry:
return entity return entity
@callback @callback
def async_remove(self, entity_id): def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry.""" """Remove an entity from registry."""
self.entities.pop(entity_id) self.entities.pop(entity_id)
self.hass.bus.async_fire( self.hass.bus.async_fire(
@ -204,7 +206,7 @@ class EntityRegistry:
self.async_schedule_save() self.async_schedule_save()
@callback @callback
def async_device_removed(self, event): def async_device_removed(self, event: Event) -> None:
"""Handle the removal of a device. """Handle the removal of a device.
Remove entities from the registry that are associated to a device when Remove entities from the registry that are associated to a device when
@ -309,7 +311,7 @@ class EntityRegistry:
return new return new
async def async_load(self): async def async_load(self) -> None:
"""Load the entity registry.""" """Load the entity registry."""
data = await self.hass.helpers.storage.async_migrator( data = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_REGISTRY), self.hass.config.path(PATH_REGISTRY),
@ -317,7 +319,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml, old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate, old_conf_migrate_func=_async_migrate,
) )
entities = OrderedDict() entities: Dict[str, RegistryEntry] = OrderedDict()
if data is not None: if data is not None:
for entity in data["entities"]: for entity in data["entities"]:
@ -334,12 +336,12 @@ class EntityRegistry:
self.entities = entities self.entities = entities
@callback @callback
def async_schedule_save(self): def async_schedule_save(self) -> None:
"""Schedule saving the entity registry.""" """Schedule saving the entity registry."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self): def _data_to_save(self) -> Dict[str, Any]:
"""Return data of entity registry to store in a file.""" """Return data of entity registry to store in a file."""
data = {} data = {}
@ -359,7 +361,7 @@ class EntityRegistry:
return data return data
@callback @callback
def async_clear_config_entry(self, config_entry): def async_clear_config_entry(self, config_entry: str) -> None:
"""Clear config entry from registry entries.""" """Clear config entry from registry entries."""
for entity_id in [ for entity_id in [
entity_id entity_id
@ -375,7 +377,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
reg_or_evt = hass.data.get(DATA_REGISTRY) reg_or_evt = hass.data.get(DATA_REGISTRY)
if not reg_or_evt: if not reg_or_evt:
evt = hass.data[DATA_REGISTRY] = Event() evt = hass.data[DATA_REGISTRY] = asyncio.Event()
reg = EntityRegistry(hass) reg = EntityRegistry(hass)
await reg.async_load() await reg.async_load()
@ -384,7 +386,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
evt.set() evt.set()
return reg return reg
if isinstance(reg_or_evt, Event): if isinstance(reg_or_evt, asyncio.Event):
evt = reg_or_evt evt = reg_or_evt
await evt.wait() await evt.wait()
return cast(EntityRegistry, hass.data.get(DATA_REGISTRY)) return cast(EntityRegistry, hass.data.get(DATA_REGISTRY))
@ -402,7 +404,7 @@ def async_entries_for_device(
] ]
async def _async_migrate(entities): async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
"""Migrate the YAML config file to storage helper format.""" """Migrate the YAML config file to storage helper format."""
return { return {
"entities": [ "entities": [

View file

@ -1,13 +1,14 @@
"""Helpers for listening to events.""" """Helpers for listening to events."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
from typing import Any, Callable, Iterable, Optional, Union from typing import Any, Callable, Dict, Iterable, Optional, Union, cast
import attr import attr
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.helpers.sun import get_astral_event_next from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event from homeassistant.helpers.template import Template
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State
from homeassistant.const import ( from homeassistant.const import (
ATTR_NOW, ATTR_NOW,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# PyLint does not like the use of threaded_listener_factory # PyLint does not like the use of threaded_listener_factory
# pylint: disable=invalid-name # pylint: disable=invalid-name
def threaded_listener_factory(async_factory): def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
"""Convert an async event helper to a threaded one.""" """Convert an async event helper to a threaded one."""
@ft.wraps(async_factory) @ft.wraps(async_factory)
def factory(*args, **kwargs): def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE:
"""Call async event helper safely.""" """Call async event helper safely."""
hass = args[0] hass = args[0]
@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory):
hass.loop, ft.partial(async_factory, *args, **kwargs) hass.loop, ft.partial(async_factory, *args, **kwargs)
).result() ).result()
def remove(): def remove() -> None:
"""Threadsafe removal.""" """Threadsafe removal."""
run_callback_threadsafe(hass.loop, async_remove).result() run_callback_threadsafe(hass.loop, async_remove).result()
@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory):
@callback @callback
@bind_hass @bind_hass
def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None): def async_track_state_change(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
action: Callable[[str, State, State], None],
from_state: Union[None, str, Iterable[str]] = None,
to_state: Union[None, str, Iterable[str]] = None,
) -> CALLBACK_TYPE:
"""Track specific state changes. """Track specific state changes.
entity_ids, from_state and to_state can be string or list. entity_ids, from_state and to_state can be string or list.
@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
@callback @callback
def state_change_listener(event): def state_change_listener(event: Event) -> None:
"""Handle specific state changes.""" """Handle specific state changes."""
if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids: if (
entity_ids != MATCH_ALL
and cast(str, event.data.get("entity_id")) not in entity_ids
):
return return
old_state = event.data.get("old_state") old_state = event.data.get("old_state")
@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@callback @callback
@bind_hass @bind_hass
def async_track_template(hass, template, action, variables=None): def async_track_template(
hass: HomeAssistant,
template: Template,
action: Callable[[str, State, State], None],
variables: Optional[Dict[str, Any]] = None,
) -> CALLBACK_TYPE:
"""Add a listener that track state changes with template condition.""" """Add a listener that track state changes with template condition."""
from . import condition from . import condition
@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None):
already_triggered = False already_triggered = False
@callback @callback
def template_condition_listener(entity_id, from_s, to_s): def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None:
"""Check if condition is correct and run action.""" """Check if condition is correct and run action."""
nonlocal already_triggered nonlocal already_triggered
template_result = condition.async_template(hass, template, variables) template_result = condition.async_template(hass, template, variables)
@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template)
@callback @callback
@bind_hass @bind_hass
def async_track_same_state( def async_track_same_state(
hass, period, action, async_check_same_func, entity_ids=MATCH_ALL hass: HomeAssistant,
): period: timedelta,
action: Callable[..., None],
async_check_same_func: Callable[[str, State, State], bool],
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
) -> CALLBACK_TYPE:
"""Track the state of entities for a period and run an action. """Track the state of entities for a period and run an action.
If async_check_func is None it use the state of orig_value. If async_check_func is None it use the state of orig_value.
Without entity_ids we track all state changes. Without entity_ids we track all state changes.
""" """
async_remove_state_for_cancel = None async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
async_remove_state_for_listener = None async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
@callback @callback
def clear_listener(): def clear_listener() -> None:
"""Clear all unsub listener.""" """Clear all unsub listener."""
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
@ -157,7 +175,7 @@ def async_track_same_state(
async_remove_state_for_cancel = None async_remove_state_for_cancel = None
@callback @callback
def state_for_listener(now): def state_for_listener(now: Any) -> None:
"""Fire on state changes after a delay and calls action.""" """Fire on state changes after a delay and calls action."""
nonlocal async_remove_state_for_listener nonlocal async_remove_state_for_listener
async_remove_state_for_listener = None async_remove_state_for_listener = None
@ -165,7 +183,9 @@ def async_track_same_state(
hass.async_run_job(action) hass.async_run_job(action)
@callback @callback
def state_for_cancel_listener(entity, from_state, to_state): def state_for_cancel_listener(
entity: str, from_state: State, to_state: State
) -> None:
"""Fire on changes and cancel for listener if changed.""" """Fire on changes and cancel for listener if changed."""
if not async_check_same_func(entity, from_state, to_state): if not async_check_same_func(entity, from_state, to_state):
clear_listener() clear_listener()
@ -193,7 +213,7 @@ def async_track_point_in_time(
utc_point_in_time = dt_util.as_utc(point_in_time) utc_point_in_time = dt_util.as_utc(point_in_time)
@callback @callback
def utc_converter(utc_now): def utc_converter(utc_now: datetime) -> None:
"""Convert passed in UTC now to local now.""" """Convert passed in UTC now to local now."""
hass.async_run_job(action, dt_util.as_local(utc_now)) hass.async_run_job(action, dt_util.as_local(utc_now))
@ -213,7 +233,7 @@ def async_track_point_in_utc_time(
point_in_time = dt_util.as_utc(point_in_time) point_in_time = dt_util.as_utc(point_in_time)
@callback @callback
def point_in_time_listener(event): def point_in_time_listener(event: Event) -> None:
"""Listen for matching time_changed events.""" """Listen for matching time_changed events."""
now = event.data[ATTR_NOW] now = event.data[ATTR_NOW]
@ -225,7 +245,7 @@ def async_track_point_in_utc_time(
# available to execute this listener it might occur that the # available to execute this listener it might occur that the
# listener gets lined up twice to be executed. This will make # listener gets lined up twice to be executed. This will make
# sure the second time it does nothing. # sure the second time it does nothing.
point_in_time_listener.run = True setattr(point_in_time_listener, "run", True)
async_unsub() async_unsub()
hass.async_run_job(action, now) hass.async_run_job(action, now)
@ -260,12 +280,12 @@ def async_track_time_interval(
"""Add a listener that fires repetitively at every timedelta interval.""" """Add a listener that fires repetitively at every timedelta interval."""
remove = None remove = None
def next_interval(): def next_interval() -> datetime:
"""Return the next interval.""" """Return the next interval."""
return dt_util.utcnow() + interval return dt_util.utcnow() + interval
@callback @callback
def interval_listener(now): def interval_listener(now: datetime) -> None:
"""Handle elapsed intervals.""" """Handle elapsed intervals."""
nonlocal remove nonlocal remove
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
@ -273,7 +293,7 @@ def async_track_time_interval(
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
def remove_listener(): def remove_listener() -> None:
"""Remove interval listener.""" """Remove interval listener."""
remove() remove()
@ -387,7 +407,7 @@ def async_track_utc_time_change(
if all(val is None for val in (hour, minute, second)): if all(val is None for val in (hour, minute, second)):
@callback @callback
def time_change_listener(event): def time_change_listener(event: Event) -> None:
"""Fire every time event that comes in.""" """Fire every time event that comes in."""
hass.async_run_job(action, event.data[ATTR_NOW]) hass.async_run_job(action, event.data[ATTR_NOW])

View file

@ -7,7 +7,7 @@ import random
import re import re
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Any, Iterable from typing import Any, Dict, Iterable, List, Optional, Union
import jinja2 import jinja2
from jinja2 import contextfilter, contextfunction from jinja2 import contextfilter, contextfunction
@ -72,7 +72,9 @@ def render_complex(value, variables=None):
return value.async_render(variables) return value.async_render(variables)
def extract_entities(template, variables=None): def extract_entities(
template: Optional[str], variables: Optional[Dict[str, Any]] = None
) -> Union[str, List[str]]:
"""Extract all entities for state_changed listener from template string.""" """Extract all entities for state_changed listener from template string."""
if template is None or _RE_JINJA_DELIMITERS.search(template) is None: if template is None or _RE_JINJA_DELIMITERS.search(template) is None:
return [] return []
@ -86,6 +88,7 @@ def extract_entities(template, variables=None):
for result in extraction: for result in extraction:
if ( if (
result[0] == "trigger.entity_id" result[0] == "trigger.entity_id"
and variables
and "trigger" in variables and "trigger" in variables
and "entity_id" in variables["trigger"] and "entity_id" in variables["trigger"]
): ):
@ -163,7 +166,7 @@ class Template:
if not isinstance(template, str): if not isinstance(template, str):
raise TypeError("Expected template to be a string") raise TypeError("Expected template to be a string")
self.template = template self.template: str = template
self._compiled_code = None self._compiled_code = None
self._compiled = None self._compiled = None
self.hass = hass self.hass = hass
@ -187,7 +190,9 @@ class Template:
except jinja2.exceptions.TemplateSyntaxError as err: except jinja2.exceptions.TemplateSyntaxError as err:
raise TemplateError(err) raise TemplateError(err)
def extract_entities(self, variables=None): def extract_entities(
self, variables: Dict[str, Any] = None
) -> Union[str, List[str]]:
"""Extract all entities for state_changed listener.""" """Extract all entities for state_changed listener."""
return extract_entities(self.template, variables) return extract_entities(self.template, variables)