Add type hints to helpers.condition (#20266)

This commit is contained in:
Ville Skyttä 2019-01-21 01:03:12 +02:00 committed by Fabian Affolter
parent 5b8cb10ad7
commit 58bb6f2e99
6 changed files with 101 additions and 55 deletions

View file

@ -56,7 +56,7 @@ def async_active_zone(hass, latitude, longitude, radius=0):
return closest
def in_zone(zone, latitude, longitude, radius=0):
def in_zone(zone, latitude, longitude, radius=0) -> bool:
"""Test if given latitude, longitude is in given zone.
Async friendly.

View file

@ -678,7 +678,7 @@ class State:
"State max length is 255 characters.").format(entity_id))
self.entity_id = entity_id.lower()
self.state = state
self.state = state # type: str
self.attributes = MappingProxyType(attributes or {})
self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated

View file

@ -1,12 +1,14 @@
"""Offer reusable conditions."""
from datetime import timedelta
from datetime import datetime, timedelta
import functools as ft
import logging
import sys
from typing import Callable, Container, Optional, Union, cast
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, State
from homeassistant.components import zone as zone_cmp
from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
@ -29,25 +31,30 @@ _LOGGER = logging.getLogger(__name__)
# pylint: disable=invalid-name
def _threaded_factory(async_factory):
def _threaded_factory(async_factory:
Callable[[ConfigType, bool], Callable[..., bool]]) \
-> Callable[[ConfigType, bool], Callable[..., bool]]:
"""Create threaded versions of async factories."""
@ft.wraps(async_factory)
def factory(config, config_validation=True):
def factory(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Threaded factory."""
async_check = async_factory(config, config_validation)
def condition_if(hass, variables=None):
def condition_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate condition."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_check, hass, variables,
).result()
).result())
return condition_if
return factory
def async_from_config(config: ConfigType, config_validation: bool = True):
def async_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Turn a condition configuration into a method.
Should be run on the event loop.
@ -64,20 +71,22 @@ def async_from_config(config: ConfigType, config_validation: bool = True):
raise HomeAssistantError('Invalid condition "{}" specified {}'.format(
config.get(CONF_CONDITION), config))
return factory(config, config_validation)
return cast(Callable[..., bool], factory(config, config_validation))
from_config = _threaded_factory(async_from_config)
def async_and_from_config(config: ConfigType, config_validation: bool = True):
def async_and_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'AND'."""
if config_validation:
config = cv.AND_CONDITION_SCHEMA(config)
checks = None
def if_and_condition(hass: HomeAssistant,
variables=None) -> bool:
variables: TemplateVarsType = None) -> bool:
"""Test and condition."""
nonlocal checks
@ -101,14 +110,16 @@ def async_and_from_config(config: ConfigType, config_validation: bool = True):
and_from_config = _threaded_factory(async_and_from_config)
def async_or_from_config(config: ConfigType, config_validation: bool = True):
def async_or_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'OR'."""
if config_validation:
config = cv.OR_CONDITION_SCHEMA(config)
checks = None
def if_or_condition(hass: HomeAssistant,
variables=None) -> bool:
variables: TemplateVarsType = None) -> bool:
"""Test and condition."""
nonlocal checks
@ -131,17 +142,22 @@ def async_or_from_config(config: ConfigType, config_validation: bool = True):
or_from_config = _threaded_factory(async_or_from_config)
def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
def numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
below: Optional[float] = None, above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_numeric_state, hass, entity, below, above,
value_template, variables,
).result()
).result())
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
def async_numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
below: Optional[float] = None,
above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition."""
if isinstance(entity, str):
entity = hass.states.get(entity)
@ -164,22 +180,24 @@ def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
return False
try:
value = float(value)
fvalue = float(value)
except ValueError:
_LOGGER.warning("Value cannot be processed as a number: %s "
"(Offending entity: %s)", entity, value)
return False
if below is not None and value >= below:
if below is not None and fvalue >= below:
return False
if above is not None and value <= above:
if above is not None and fvalue <= above:
return False
return True
def async_numeric_state_from_config(config, config_validation=True):
def async_numeric_state_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
@ -188,7 +206,8 @@ def async_numeric_state_from_config(config, config_validation=True):
above = config.get(CONF_ABOVE)
value_template = config.get(CONF_VALUE_TEMPLATE)
def if_numeric_state(hass, variables=None):
def if_numeric_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test numeric state condition."""
if value_template is not None:
value_template.hass = hass
@ -202,7 +221,8 @@ def async_numeric_state_from_config(config, config_validation=True):
numeric_state_from_config = _threaded_factory(async_numeric_state_from_config)
def state(hass, entity, req_state, for_period=None):
def state(hass: HomeAssistant, entity: Union[None, str, State], req_state: str,
for_period: Optional[timedelta] = None) -> bool:
"""Test if state matches requirements.
Async friendly.
@ -212,6 +232,7 @@ def state(hass, entity, req_state, for_period=None):
if entity is None:
return False
assert isinstance(entity, State)
is_state = entity.state == req_state
@ -221,22 +242,26 @@ def state(hass, entity, req_state, for_period=None):
return dt_util.utcnow() - for_period > entity.last_changed
def state_from_config(config, config_validation=True):
def state_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
req_state = config.get(CONF_STATE)
req_state = cast(str, config.get(CONF_STATE))
for_period = config.get('for')
def if_state(hass, variables=None):
def if_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return state(hass, entity_id, req_state, for_period)
return if_state
def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
def sun(hass: HomeAssistant, before: Optional[str] = None,
after: Optional[str] = None, before_offset: Optional[timedelta] = None,
after_offset: Optional[timedelta] = None) -> bool:
"""Test if current time matches sun requirements."""
utcnow = dt_util.utcnow()
today = dt_util.as_local(utcnow).date()
@ -254,22 +279,27 @@ def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
# There is no sunset today
return False
if before == SUN_EVENT_SUNRISE and utcnow > sunrise + before_offset:
if before == SUN_EVENT_SUNRISE and \
utcnow > cast(datetime, sunrise) + before_offset:
return False
if before == SUN_EVENT_SUNSET and utcnow > sunset + before_offset:
if before == SUN_EVENT_SUNSET and \
utcnow > cast(datetime, sunset) + before_offset:
return False
if after == SUN_EVENT_SUNRISE and utcnow < sunrise + after_offset:
if after == SUN_EVENT_SUNRISE and \
utcnow < cast(datetime, sunrise) + after_offset:
return False
if after == SUN_EVENT_SUNSET and utcnow < sunset + after_offset:
if after == SUN_EVENT_SUNSET and \
utcnow < cast(datetime, sunset) + after_offset:
return False
return True
def sun_from_config(config, config_validation=True):
def sun_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with sun based condition."""
if config_validation:
config = cv.SUN_CONDITION_SCHEMA(config)
@ -278,21 +308,24 @@ def sun_from_config(config, config_validation=True):
before_offset = config.get('before_offset')
after_offset = config.get('after_offset')
def time_if(hass, variables=None):
def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)
return time_if
def template(hass, value_template, variables=None):
def template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_template, hass, value_template, variables,
).result()
).result())
def async_template(hass, value_template, variables=None):
def async_template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches."""
try:
value = value_template.async_render(variables)
@ -303,13 +336,16 @@ def async_template(hass, value_template, variables=None):
return value.lower() == 'true'
def async_template_from_config(config, config_validation=True):
def async_template_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.TEMPLATE_CONDITION_SCHEMA(config)
value_template = config.get(CONF_VALUE_TEMPLATE)
value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))
def template_if(hass, variables=None):
def template_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate template based if-condition."""
value_template.hass = hass
@ -321,7 +357,9 @@ def async_template_from_config(config, config_validation=True):
template_from_config = _threaded_factory(async_template_from_config)
def time(before=None, after=None, weekday=None):
def time(before: Optional[dt_util.dt.time] = None,
after: Optional[dt_util.dt.time] = None,
weekday: Union[None, str, Container[str]] = None) -> bool:
"""Test if local time condition matches.
Handle the fact that time is continuous and we may be testing for
@ -354,7 +392,8 @@ def time(before=None, after=None, weekday=None):
return True
def time_from_config(config, config_validation=True):
def time_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with time based condition."""
if config_validation:
config = cv.TIME_CONDITION_SCHEMA(config)
@ -362,14 +401,16 @@ def time_from_config(config, config_validation=True):
after = config.get(CONF_AFTER)
weekday = config.get(CONF_WEEKDAY)
def time_if(hass, variables=None):
def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return time(before, after, weekday)
return time_if
def zone(hass, zone_ent, entity):
def zone(hass: HomeAssistant, zone_ent: Union[None, str, State],
entity: Union[None, str, State]) -> bool:
"""Test if zone-condition matches.
Async friendly.
@ -396,14 +437,16 @@ def zone(hass, zone_ent, entity):
entity.attributes.get(ATTR_GPS_ACCURACY, 0))
def zone_from_config(config, config_validation=True):
def zone_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with zone based condition."""
if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
zone_entity_id = config.get(CONF_ZONE)
def if_in_zone(hass, variables=None):
def if_in_zone(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return zone(hass, zone_entity_id, entity_id)

View file

@ -18,6 +18,7 @@ from homeassistant.const import (
from homeassistant.core import State, valid_entity_id
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import convert
from homeassistant.util import dt as dt_util
@ -115,7 +116,7 @@ class Template:
"""Extract all entities for state_changed listener."""
return extract_entities(self.template, variables)
def render(self, variables=None, **kwargs):
def render(self, variables: TemplateVarsType = None, **kwargs):
"""Render given template."""
if variables is not None:
kwargs.update(variables)
@ -123,7 +124,8 @@ class Template:
return run_callback_threadsafe(
self.hass.loop, self.async_render, kwargs).result()
def async_render(self, variables=None, **kwargs):
def async_render(self, variables: TemplateVarsType = None,
**kwargs) -> str:
"""Render given template.
This method must be run in the event loop.

View file

@ -1,5 +1,5 @@
"""Typing Helpers for Home Assistant."""
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, Optional
import homeassistant.core
@ -9,6 +9,7 @@ GPSType = Tuple[float, float]
ConfigType = Dict[str, Any]
HomeAssistantType = homeassistant.core.HomeAssistant
ServiceDataType = Dict[str, Any]
TemplateVarsType = Optional[Dict[str, Any]]
# Custom type for recorder Queries
QueryType = Any

View file

@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps =
-r{toxinidir}/requirements_test.txt
commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'