Add type hints to helpers.condition (#20266)
This commit is contained in:
parent
5b8cb10ad7
commit
58bb6f2e99
6 changed files with 101 additions and 55 deletions
|
@ -56,7 +56,7 @@ def async_active_zone(hass, latitude, longitude, radius=0):
|
|||
return closest
|
||||
|
||||
|
||||
def in_zone(zone, latitude, longitude, radius=0):
|
||||
def in_zone(zone, latitude, longitude, radius=0) -> bool:
|
||||
"""Test if given latitude, longitude is in given zone.
|
||||
|
||||
Async friendly.
|
||||
|
|
|
@ -678,7 +678,7 @@ class State:
|
|||
"State max length is 255 characters.").format(entity_id))
|
||||
|
||||
self.entity_id = entity_id.lower()
|
||||
self.state = state
|
||||
self.state = state # type: str
|
||||
self.attributes = MappingProxyType(attributes or {})
|
||||
self.last_updated = last_updated or dt_util.utcnow()
|
||||
self.last_changed = last_changed or self.last_updated
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"""Offer reusable conditions."""
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
import functools as ft
|
||||
import logging
|
||||
import sys
|
||||
from typing import Callable, Container, Optional, Union, cast
|
||||
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.components import zone as zone_cmp
|
||||
from homeassistant.const import (
|
||||
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
|
||||
|
@ -29,25 +31,30 @@ _LOGGER = logging.getLogger(__name__)
|
|||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _threaded_factory(async_factory):
|
||||
def _threaded_factory(async_factory:
|
||||
Callable[[ConfigType, bool], Callable[..., bool]]) \
|
||||
-> Callable[[ConfigType, bool], Callable[..., bool]]:
|
||||
"""Create threaded versions of async factories."""
|
||||
@ft.wraps(async_factory)
|
||||
def factory(config, config_validation=True):
|
||||
def factory(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Threaded factory."""
|
||||
async_check = async_factory(config, config_validation)
|
||||
|
||||
def condition_if(hass, variables=None):
|
||||
def condition_if(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate condition."""
|
||||
return run_callback_threadsafe(
|
||||
return cast(bool, run_callback_threadsafe(
|
||||
hass.loop, async_check, hass, variables,
|
||||
).result()
|
||||
).result())
|
||||
|
||||
return condition_if
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def async_from_config(config: ConfigType, config_validation: bool = True):
|
||||
def async_from_config(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Turn a condition configuration into a method.
|
||||
|
||||
Should be run on the event loop.
|
||||
|
@ -64,20 +71,22 @@ def async_from_config(config: ConfigType, config_validation: bool = True):
|
|||
raise HomeAssistantError('Invalid condition "{}" specified {}'.format(
|
||||
config.get(CONF_CONDITION), config))
|
||||
|
||||
return factory(config, config_validation)
|
||||
return cast(Callable[..., bool], factory(config, config_validation))
|
||||
|
||||
|
||||
from_config = _threaded_factory(async_from_config)
|
||||
|
||||
|
||||
def async_and_from_config(config: ConfigType, config_validation: bool = True):
|
||||
def async_and_from_config(config: ConfigType,
|
||||
config_validation: bool = True) \
|
||||
-> Callable[..., bool]:
|
||||
"""Create multi condition matcher using 'AND'."""
|
||||
if config_validation:
|
||||
config = cv.AND_CONDITION_SCHEMA(config)
|
||||
checks = None
|
||||
|
||||
def if_and_condition(hass: HomeAssistant,
|
||||
variables=None) -> bool:
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test and condition."""
|
||||
nonlocal checks
|
||||
|
||||
|
@ -101,14 +110,16 @@ def async_and_from_config(config: ConfigType, config_validation: bool = True):
|
|||
and_from_config = _threaded_factory(async_and_from_config)
|
||||
|
||||
|
||||
def async_or_from_config(config: ConfigType, config_validation: bool = True):
|
||||
def async_or_from_config(config: ConfigType,
|
||||
config_validation: bool = True) \
|
||||
-> Callable[..., bool]:
|
||||
"""Create multi condition matcher using 'OR'."""
|
||||
if config_validation:
|
||||
config = cv.OR_CONDITION_SCHEMA(config)
|
||||
checks = None
|
||||
|
||||
def if_or_condition(hass: HomeAssistant,
|
||||
variables=None) -> bool:
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test and condition."""
|
||||
nonlocal checks
|
||||
|
||||
|
@ -131,17 +142,22 @@ def async_or_from_config(config: ConfigType, config_validation: bool = True):
|
|||
or_from_config = _threaded_factory(async_or_from_config)
|
||||
|
||||
|
||||
def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
||||
value_template=None, variables=None):
|
||||
def numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
|
||||
below: Optional[float] = None, above: Optional[float] = None,
|
||||
value_template: Optional[Template] = None,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test a numeric state condition."""
|
||||
return run_callback_threadsafe(
|
||||
return cast(bool, run_callback_threadsafe(
|
||||
hass.loop, async_numeric_state, hass, entity, below, above,
|
||||
value_template, variables,
|
||||
).result()
|
||||
).result())
|
||||
|
||||
|
||||
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
||||
value_template=None, variables=None):
|
||||
def async_numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
|
||||
below: Optional[float] = None,
|
||||
above: Optional[float] = None,
|
||||
value_template: Optional[Template] = None,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test a numeric state condition."""
|
||||
if isinstance(entity, str):
|
||||
entity = hass.states.get(entity)
|
||||
|
@ -164,22 +180,24 @@ def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
|
|||
return False
|
||||
|
||||
try:
|
||||
value = float(value)
|
||||
fvalue = float(value)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Value cannot be processed as a number: %s "
|
||||
"(Offending entity: %s)", entity, value)
|
||||
return False
|
||||
|
||||
if below is not None and value >= below:
|
||||
if below is not None and fvalue >= below:
|
||||
return False
|
||||
|
||||
if above is not None and value <= above:
|
||||
if above is not None and fvalue <= above:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def async_numeric_state_from_config(config, config_validation=True):
|
||||
def async_numeric_state_from_config(config: ConfigType,
|
||||
config_validation: bool = True) \
|
||||
-> Callable[..., bool]:
|
||||
"""Wrap action method with state based condition."""
|
||||
if config_validation:
|
||||
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
|
||||
|
@ -188,7 +206,8 @@ def async_numeric_state_from_config(config, config_validation=True):
|
|||
above = config.get(CONF_ABOVE)
|
||||
value_template = config.get(CONF_VALUE_TEMPLATE)
|
||||
|
||||
def if_numeric_state(hass, variables=None):
|
||||
def if_numeric_state(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test numeric state condition."""
|
||||
if value_template is not None:
|
||||
value_template.hass = hass
|
||||
|
@ -202,7 +221,8 @@ def async_numeric_state_from_config(config, config_validation=True):
|
|||
numeric_state_from_config = _threaded_factory(async_numeric_state_from_config)
|
||||
|
||||
|
||||
def state(hass, entity, req_state, for_period=None):
|
||||
def state(hass: HomeAssistant, entity: Union[None, str, State], req_state: str,
|
||||
for_period: Optional[timedelta] = None) -> bool:
|
||||
"""Test if state matches requirements.
|
||||
|
||||
Async friendly.
|
||||
|
@ -212,6 +232,7 @@ def state(hass, entity, req_state, for_period=None):
|
|||
|
||||
if entity is None:
|
||||
return False
|
||||
assert isinstance(entity, State)
|
||||
|
||||
is_state = entity.state == req_state
|
||||
|
||||
|
@ -221,22 +242,26 @@ def state(hass, entity, req_state, for_period=None):
|
|||
return dt_util.utcnow() - for_period > entity.last_changed
|
||||
|
||||
|
||||
def state_from_config(config, config_validation=True):
|
||||
def state_from_config(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Wrap action method with state based condition."""
|
||||
if config_validation:
|
||||
config = cv.STATE_CONDITION_SCHEMA(config)
|
||||
entity_id = config.get(CONF_ENTITY_ID)
|
||||
req_state = config.get(CONF_STATE)
|
||||
req_state = cast(str, config.get(CONF_STATE))
|
||||
for_period = config.get('for')
|
||||
|
||||
def if_state(hass, variables=None):
|
||||
def if_state(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test if condition."""
|
||||
return state(hass, entity_id, req_state, for_period)
|
||||
|
||||
return if_state
|
||||
|
||||
|
||||
def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
|
||||
def sun(hass: HomeAssistant, before: Optional[str] = None,
|
||||
after: Optional[str] = None, before_offset: Optional[timedelta] = None,
|
||||
after_offset: Optional[timedelta] = None) -> bool:
|
||||
"""Test if current time matches sun requirements."""
|
||||
utcnow = dt_util.utcnow()
|
||||
today = dt_util.as_local(utcnow).date()
|
||||
|
@ -254,22 +279,27 @@ def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
|
|||
# There is no sunset today
|
||||
return False
|
||||
|
||||
if before == SUN_EVENT_SUNRISE and utcnow > sunrise + before_offset:
|
||||
if before == SUN_EVENT_SUNRISE and \
|
||||
utcnow > cast(datetime, sunrise) + before_offset:
|
||||
return False
|
||||
|
||||
if before == SUN_EVENT_SUNSET and utcnow > sunset + before_offset:
|
||||
if before == SUN_EVENT_SUNSET and \
|
||||
utcnow > cast(datetime, sunset) + before_offset:
|
||||
return False
|
||||
|
||||
if after == SUN_EVENT_SUNRISE and utcnow < sunrise + after_offset:
|
||||
if after == SUN_EVENT_SUNRISE and \
|
||||
utcnow < cast(datetime, sunrise) + after_offset:
|
||||
return False
|
||||
|
||||
if after == SUN_EVENT_SUNSET and utcnow < sunset + after_offset:
|
||||
if after == SUN_EVENT_SUNSET and \
|
||||
utcnow < cast(datetime, sunset) + after_offset:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def sun_from_config(config, config_validation=True):
|
||||
def sun_from_config(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Wrap action method with sun based condition."""
|
||||
if config_validation:
|
||||
config = cv.SUN_CONDITION_SCHEMA(config)
|
||||
|
@ -278,21 +308,24 @@ def sun_from_config(config, config_validation=True):
|
|||
before_offset = config.get('before_offset')
|
||||
after_offset = config.get('after_offset')
|
||||
|
||||
def time_if(hass, variables=None):
|
||||
def time_if(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate time based if-condition."""
|
||||
return sun(hass, before, after, before_offset, after_offset)
|
||||
|
||||
return time_if
|
||||
|
||||
|
||||
def template(hass, value_template, variables=None):
|
||||
def template(hass: HomeAssistant, value_template: Template,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test if template condition matches."""
|
||||
return run_callback_threadsafe(
|
||||
return cast(bool, run_callback_threadsafe(
|
||||
hass.loop, async_template, hass, value_template, variables,
|
||||
).result()
|
||||
).result())
|
||||
|
||||
|
||||
def async_template(hass, value_template, variables=None):
|
||||
def async_template(hass: HomeAssistant, value_template: Template,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test if template condition matches."""
|
||||
try:
|
||||
value = value_template.async_render(variables)
|
||||
|
@ -303,13 +336,16 @@ def async_template(hass, value_template, variables=None):
|
|||
return value.lower() == 'true'
|
||||
|
||||
|
||||
def async_template_from_config(config, config_validation=True):
|
||||
def async_template_from_config(config: ConfigType,
|
||||
config_validation: bool = True) \
|
||||
-> Callable[..., bool]:
|
||||
"""Wrap action method with state based condition."""
|
||||
if config_validation:
|
||||
config = cv.TEMPLATE_CONDITION_SCHEMA(config)
|
||||
value_template = config.get(CONF_VALUE_TEMPLATE)
|
||||
value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))
|
||||
|
||||
def template_if(hass, variables=None):
|
||||
def template_if(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate template based if-condition."""
|
||||
value_template.hass = hass
|
||||
|
||||
|
@ -321,7 +357,9 @@ def async_template_from_config(config, config_validation=True):
|
|||
template_from_config = _threaded_factory(async_template_from_config)
|
||||
|
||||
|
||||
def time(before=None, after=None, weekday=None):
|
||||
def time(before: Optional[dt_util.dt.time] = None,
|
||||
after: Optional[dt_util.dt.time] = None,
|
||||
weekday: Union[None, str, Container[str]] = None) -> bool:
|
||||
"""Test if local time condition matches.
|
||||
|
||||
Handle the fact that time is continuous and we may be testing for
|
||||
|
@ -354,7 +392,8 @@ def time(before=None, after=None, weekday=None):
|
|||
return True
|
||||
|
||||
|
||||
def time_from_config(config, config_validation=True):
|
||||
def time_from_config(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Wrap action method with time based condition."""
|
||||
if config_validation:
|
||||
config = cv.TIME_CONDITION_SCHEMA(config)
|
||||
|
@ -362,14 +401,16 @@ def time_from_config(config, config_validation=True):
|
|||
after = config.get(CONF_AFTER)
|
||||
weekday = config.get(CONF_WEEKDAY)
|
||||
|
||||
def time_if(hass, variables=None):
|
||||
def time_if(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate time based if-condition."""
|
||||
return time(before, after, weekday)
|
||||
|
||||
return time_if
|
||||
|
||||
|
||||
def zone(hass, zone_ent, entity):
|
||||
def zone(hass: HomeAssistant, zone_ent: Union[None, str, State],
|
||||
entity: Union[None, str, State]) -> bool:
|
||||
"""Test if zone-condition matches.
|
||||
|
||||
Async friendly.
|
||||
|
@ -396,14 +437,16 @@ def zone(hass, zone_ent, entity):
|
|||
entity.attributes.get(ATTR_GPS_ACCURACY, 0))
|
||||
|
||||
|
||||
def zone_from_config(config, config_validation=True):
|
||||
def zone_from_config(config: ConfigType,
|
||||
config_validation: bool = True) -> Callable[..., bool]:
|
||||
"""Wrap action method with zone based condition."""
|
||||
if config_validation:
|
||||
config = cv.ZONE_CONDITION_SCHEMA(config)
|
||||
entity_id = config.get(CONF_ENTITY_ID)
|
||||
zone_entity_id = config.get(CONF_ZONE)
|
||||
|
||||
def if_in_zone(hass, variables=None):
|
||||
def if_in_zone(hass: HomeAssistant,
|
||||
variables: TemplateVarsType = None) -> bool:
|
||||
"""Test if condition."""
|
||||
return zone(hass, zone_entity_id, entity_id)
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import State, valid_entity_id
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import location as loc_helper
|
||||
from homeassistant.helpers.typing import TemplateVarsType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import convert
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -115,7 +116,7 @@ class Template:
|
|||
"""Extract all entities for state_changed listener."""
|
||||
return extract_entities(self.template, variables)
|
||||
|
||||
def render(self, variables=None, **kwargs):
|
||||
def render(self, variables: TemplateVarsType = None, **kwargs):
|
||||
"""Render given template."""
|
||||
if variables is not None:
|
||||
kwargs.update(variables)
|
||||
|
@ -123,7 +124,8 @@ class Template:
|
|||
return run_callback_threadsafe(
|
||||
self.hass.loop, self.async_render, kwargs).result()
|
||||
|
||||
def async_render(self, variables=None, **kwargs):
|
||||
def async_render(self, variables: TemplateVarsType = None,
|
||||
**kwargs) -> str:
|
||||
"""Render given template.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Typing Helpers for Home Assistant."""
|
||||
from typing import Dict, Any, Tuple
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
|
||||
import homeassistant.core
|
||||
|
||||
|
@ -9,6 +9,7 @@ GPSType = Tuple[float, float]
|
|||
ConfigType = Dict[str, Any]
|
||||
HomeAssistantType = homeassistant.core.HomeAssistant
|
||||
ServiceDataType = Dict[str, Any]
|
||||
TemplateVarsType = Optional[Dict[str, Any]]
|
||||
|
||||
# Custom type for recorder Queries
|
||||
QueryType = Any
|
||||
|
|
2
tox.ini
2
tox.ini
|
@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
|
|||
deps =
|
||||
-r{toxinidir}/requirements_test.txt
|
||||
commands =
|
||||
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'
|
||||
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'
|
||||
|
|
Loading…
Add table
Reference in a new issue