Template: Expand method to expand groups, and closest as filter (#23691)

* Implement expand method

* Allow expand and closest to be used as filters

* Correct patch

* Addresses review comments
This commit is contained in:
Penny Wood 2019-06-22 15:32:32 +08:00 committed by Paulus Schoutsen
parent a6eef22fbc
commit 22d9bee41a
3 changed files with 346 additions and 185 deletions

View file

@ -6,23 +6,24 @@ import math
import random
import re
from datetime import datetime
from functools import wraps
from typing import Iterable
import jinja2
from jinja2 import contextfilter
from jinja2 import contextfilter, contextfunction
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.utils import Namespace
from homeassistant.const import (ATTR_LATITUDE, ATTR_LONGITUDE, MATCH_ALL,
ATTR_UNIT_OF_MEASUREMENT, STATE_UNKNOWN)
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_UNIT_OF_MEASUREMENT,
MATCH_ALL, STATE_UNKNOWN)
from homeassistant.core import (
State, callback, valid_entity_id, split_entity_id)
State, callback, split_entity_id, valid_entity_id)
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import convert
from homeassistant.util import dt as dt_util
from homeassistant.util import location as loc_util
from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe
_LOGGER = logging.getLogger(__name__)
@ -30,6 +31,7 @@ _SENTINEL = object()
DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S"
_RENDER_INFO = 'template.render_info'
_ENVIRONMENT = 'template.environment'
_RE_NONE_ENTITIES = re.compile(r"distance\(|closest\(", re.I | re.M)
_RE_GET_ENTITIES = re.compile(
@ -152,13 +154,22 @@ class Template:
self._compiled = None
self.hass = hass
@property
def _env(self):
if self.hass is None:
return _NO_HASS_ENV
ret = self.hass.data.get(_ENVIRONMENT)
if ret is None:
ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass)
return ret
def ensure_valid(self):
"""Return if template is valid."""
if self._compiled_code is not None:
return
try:
self._compiled_code = ENV.compile(self.template)
self._compiled_code = self._env.compile(self.template)
except jinja2.exceptions.TemplateSyntaxError as err:
raise TemplateError(err)
@ -254,19 +265,10 @@ class Template:
assert self.hass is not None, 'hass variable not set on template'
template_methods = TemplateMethods(self.hass)
global_vars = ENV.make_globals({
'closest': template_methods.closest,
'distance': template_methods.distance,
'is_state': template_methods.is_state,
'is_state_attr': template_methods.is_state_attr,
'state_attr': template_methods.state_attr,
'states': AllStates(self.hass),
})
env = self._env
self._compiled = jinja2.Template.from_code(
ENV, self._compiled_code, global_vars, None)
env, self._compiled_code, env.globals, None)
return self._compiled
@ -384,6 +386,7 @@ class TemplateState(State):
def _access_state(self):
state = object.__getattribute__(self, '_state')
hass = object.__getattribute__(self, '_hass')
_collect_state(hass, state.entity_id)
return state
@ -438,151 +441,184 @@ def _get_state(hass, entity_id):
return _wrap_state(hass, state)
class TemplateMethods:
"""Class to expose helpers to templates."""
def _resolve_state(hass, entity_id_or_state):
"""Return state or entity_id if given."""
if isinstance(entity_id_or_state, State):
return entity_id_or_state
if isinstance(entity_id_or_state, str):
return _get_state(hass, entity_id_or_state)
return None
def __init__(self, hass):
"""Initialize the helpers."""
self._hass = hass
def closest(self, *args):
"""Find closest entity.
def expand(hass, *args) -> Iterable[State]:
"""Expand out any groups into entity states."""
search = list(args)
found = {}
while search:
entity = search.pop()
if isinstance(entity, str):
entity_id = entity
entity = _get_state(hass, entity)
if entity is None:
continue
elif isinstance(entity, State):
entity_id = entity.entity_id
elif isinstance(entity, Iterable):
search += entity
continue
else:
# ignore other types
continue
Closest to home:
closest(states)
closest(states.device_tracker)
closest('group.children')
closest(states.group.children)
from homeassistant.components import group
if split_entity_id(entity_id)[0] == group.DOMAIN:
# Collect state will be called in here since it's wrapped
group_entities = entity.attributes.get(ATTR_ENTITY_ID)
if group_entities:
search += group_entities
else:
found[entity_id] = entity
return sorted(found.values(), key=lambda a: a.entity_id)
Closest to a point:
closest(23.456, 23.456, 'group.children')
closest('zone.school', 'group.children')
closest(states.zone.school, 'group.children')
"""
if len(args) == 1:
latitude = self._hass.config.latitude
longitude = self._hass.config.longitude
entities = args[0]
elif len(args) == 2:
point_state = self._resolve_state(args[0])
def closest(hass, *args):
"""Find closest entity.
if point_state is None:
_LOGGER.warning("Closest:Unable to find state %s", args[0])
Closest to home:
closest(states)
closest(states.device_tracker)
closest('group.children')
closest(states.group.children)
Closest to a point:
closest(23.456, 23.456, 'group.children')
closest('zone.school', 'group.children')
closest(states.zone.school, 'group.children')
As a filter:
states | closest
states.device_tracker | closest
['group.children', states.device_tracker] | closest
'group.children' | closest(23.456, 23.456)
states.device_tracker | closest('zone.school')
'group.children' | closest(states.zone.school)
"""
if len(args) == 1:
latitude = hass.config.latitude
longitude = hass.config.longitude
entities = args[0]
elif len(args) == 2:
point_state = _resolve_state(hass, args[0])
if point_state is None:
_LOGGER.warning("Closest:Unable to find state %s", args[0])
return None
if not loc_helper.has_location(point_state):
_LOGGER.warning(
"Closest:State does not contain valid location: %s",
point_state)
return None
latitude = point_state.attributes.get(ATTR_LATITUDE)
longitude = point_state.attributes.get(ATTR_LONGITUDE)
entities = args[1]
else:
latitude = convert(args[0], float)
longitude = convert(args[1], float)
if latitude is None or longitude is None:
_LOGGER.warning(
"Closest:Received invalid coordinates: %s, %s",
args[0], args[1])
return None
entities = args[2]
states = expand(hass, entities)
# state will already be wrapped here
return loc_helper.closest(latitude, longitude, states)
def closest_filter(hass, *args):
"""Call closest as a filter. Need to reorder arguments."""
new_args = list(args[1:])
new_args.append(args[0])
return closest(hass, *new_args)
def distance(hass, *args):
"""Calculate distance.
Will calculate distance from home to a point or between points.
Points can be passed in using state objects or lat/lng coordinates.
"""
locations = []
to_process = list(args)
while to_process:
value = to_process.pop(0)
point_state = _resolve_state(hass, value)
if point_state is None:
# We expect this and next value to be lat&lng
if not to_process:
_LOGGER.warning(
"Distance:Expected latitude and longitude, got %s",
value)
return None
value_2 = to_process.pop(0)
latitude = convert(value, float)
longitude = convert(value_2, float)
if latitude is None or longitude is None:
_LOGGER.warning("Distance:Unable to process latitude and "
"longitude: %s, %s", value, value_2)
return None
else:
if not loc_helper.has_location(point_state):
_LOGGER.warning(
"Closest:State does not contain valid location: %s",
"distance:State does not contain valid location: %s",
point_state)
return None
latitude = point_state.attributes.get(ATTR_LATITUDE)
longitude = point_state.attributes.get(ATTR_LONGITUDE)
entities = args[1]
locations.append((latitude, longitude))
else:
latitude = convert(args[0], float)
longitude = convert(args[1], float)
if len(locations) == 1:
return hass.config.distance(*locations[0])
if latitude is None or longitude is None:
_LOGGER.warning(
"Closest:Received invalid coordinates: %s, %s",
args[0], args[1])
return None
return hass.config.units.length(
loc_util.distance(*locations[0] + locations[1]), 'm')
entities = args[2]
if isinstance(entities, (AllStates, DomainStates)):
states = list(entities)
else:
if isinstance(entities, State):
gr_entity_id = entities.entity_id
else:
gr_entity_id = str(entities)
def is_state(hass, entity_id: str, state: State) -> bool:
"""Test if a state is a specific value."""
state_obj = _get_state(hass, entity_id)
return state_obj is not None and state_obj.state == state
_collect_state(self._hass, gr_entity_id)
group = self._hass.components.group
states = [_get_state(self._hass, entity_id) for entity_id
in group.expand_entity_ids([gr_entity_id])]
def is_state_attr(hass, entity_id, name, value):
"""Test if a state's attribute is a specific value."""
attr = state_attr(hass, entity_id, name)
return attr is not None and attr == value
# state will already be wrapped here
return loc_helper.closest(latitude, longitude, states)
def distance(self, *args):
"""Calculate distance.
Will calculate distance from home to a point or between points.
Points can be passed in using state objects or lat/lng coordinates.
"""
locations = []
to_process = list(args)
while to_process:
value = to_process.pop(0)
point_state = self._resolve_state(value)
if point_state is None:
# We expect this and next value to be lat&lng
if not to_process:
_LOGGER.warning(
"Distance:Expected latitude and longitude, got %s",
value)
return None
value_2 = to_process.pop(0)
latitude = convert(value, float)
longitude = convert(value_2, float)
if latitude is None or longitude is None:
_LOGGER.warning("Distance:Unable to process latitude and "
"longitude: %s, %s", value, value_2)
return None
else:
if not loc_helper.has_location(point_state):
_LOGGER.warning(
"distance:State does not contain valid location: %s",
point_state)
return None
latitude = point_state.attributes.get(ATTR_LATITUDE)
longitude = point_state.attributes.get(ATTR_LONGITUDE)
locations.append((latitude, longitude))
if len(locations) == 1:
return self._hass.config.distance(*locations[0])
return self._hass.config.units.length(
loc_util.distance(*locations[0] + locations[1]), 'm')
def is_state(self, entity_id: str, state: State) -> bool:
"""Test if a state is a specific value."""
state_obj = _get_state(self._hass, entity_id)
return state_obj is not None and state_obj.state == state
def is_state_attr(self, entity_id, name, value):
"""Test if a state's attribute is a specific value."""
state_attr = self.state_attr(entity_id, name)
return state_attr is not None and state_attr == value
def state_attr(self, entity_id, name):
"""Get a specific attribute from a state."""
state_obj = _get_state(self._hass, entity_id)
if state_obj is not None:
return state_obj.attributes.get(name)
return None
def _resolve_state(self, entity_id_or_state):
"""Return state or entity_id if given."""
if isinstance(entity_id_or_state, State):
return entity_id_or_state
if isinstance(entity_id_or_state, str):
return _get_state(self._hass, entity_id_or_state)
return None
def state_attr(hass, entity_id, name):
"""Get a specific attribute from a state."""
state_obj = _get_state(hass, entity_id)
if state_obj is not None:
return state_obj.attributes.get(name)
return None
def forgiving_round(value, precision=0, method="common"):
@ -790,6 +826,71 @@ def random_every_time(context, values):
class TemplateEnvironment(ImmutableSandboxedEnvironment):
"""The Home Assistant template environment."""
def __init__(self, hass):
"""Initialise template environment."""
super().__init__()
self.hass = hass
self.filters['round'] = forgiving_round
self.filters['multiply'] = multiply
self.filters['log'] = logarithm
self.filters['sin'] = sine
self.filters['cos'] = cosine
self.filters['tan'] = tangent
self.filters['sqrt'] = square_root
self.filters['as_timestamp'] = forgiving_as_timestamp
self.filters['timestamp_custom'] = timestamp_custom
self.filters['timestamp_local'] = timestamp_local
self.filters['timestamp_utc'] = timestamp_utc
self.filters['is_defined'] = fail_when_undefined
self.filters['max'] = max
self.filters['min'] = min
self.filters['random'] = random_every_time
self.filters['base64_encode'] = base64_encode
self.filters['base64_decode'] = base64_decode
self.filters['ordinal'] = ordinal
self.filters['regex_match'] = regex_match
self.filters['regex_replace'] = regex_replace
self.filters['regex_search'] = regex_search
self.filters['regex_findall_index'] = regex_findall_index
self.filters['bitwise_and'] = bitwise_and
self.filters['bitwise_or'] = bitwise_or
self.globals['log'] = logarithm
self.globals['sin'] = sine
self.globals['cos'] = cosine
self.globals['tan'] = tangent
self.globals['sqrt'] = square_root
self.globals['pi'] = math.pi
self.globals['tau'] = math.pi * 2
self.globals['e'] = math.e
self.globals['float'] = forgiving_float
self.globals['now'] = dt_util.now
self.globals['utcnow'] = dt_util.utcnow
self.globals['as_timestamp'] = forgiving_as_timestamp
self.globals['relative_time'] = dt_util.get_age
self.globals['strptime'] = strptime
if hass is None:
return
# We mark these as a context functions to ensure they get
# evaluated fresh with every execution, rather than executed
# at compile time and the value stored. The context itself
# can be discarded, we only need to get at the hass object.
def hassfunction(func):
"""Wrap function that depend on hass."""
@wraps(func)
def wrapper(*args, **kwargs):
return func(hass, *args[1:], **kwargs)
return contextfunction(wrapper)
self.globals['expand'] = hassfunction(expand)
self.filters['expand'] = contextfilter(self.globals['expand'])
self.globals['closest'] = hassfunction(closest)
self.filters['closest'] = contextfilter(hassfunction(closest_filter))
self.globals['distance'] = hassfunction(distance)
self.globals['is_state'] = hassfunction(is_state)
self.globals['is_state_attr'] = hassfunction(is_state_attr)
self.globals['state_attr'] = hassfunction(state_attr)
self.globals['states'] = AllStates(hass)
def is_safe_callable(self, obj):
"""Test if callback is safe."""
return isinstance(obj, AllStates) or super().is_safe_callable(obj)
@ -800,42 +901,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
super().is_safe_attribute(obj, attr, value)
ENV = TemplateEnvironment()
ENV.filters['round'] = forgiving_round
ENV.filters['multiply'] = multiply
ENV.filters['log'] = logarithm
ENV.filters['sin'] = sine
ENV.filters['cos'] = cosine
ENV.filters['tan'] = tangent
ENV.filters['sqrt'] = square_root
ENV.filters['as_timestamp'] = forgiving_as_timestamp
ENV.filters['timestamp_custom'] = timestamp_custom
ENV.filters['timestamp_local'] = timestamp_local
ENV.filters['timestamp_utc'] = timestamp_utc
ENV.filters['is_defined'] = fail_when_undefined
ENV.filters['max'] = max
ENV.filters['min'] = min
ENV.filters['random'] = random_every_time
ENV.filters['base64_encode'] = base64_encode
ENV.filters['base64_decode'] = base64_decode
ENV.filters['ordinal'] = ordinal
ENV.filters['regex_match'] = regex_match
ENV.filters['regex_replace'] = regex_replace
ENV.filters['regex_search'] = regex_search
ENV.filters['regex_findall_index'] = regex_findall_index
ENV.filters['bitwise_and'] = bitwise_and
ENV.filters['bitwise_or'] = bitwise_or
ENV.globals['log'] = logarithm
ENV.globals['sin'] = sine
ENV.globals['cos'] = cosine
ENV.globals['tan'] = tangent
ENV.globals['sqrt'] = square_root
ENV.globals['pi'] = math.pi
ENV.globals['tau'] = math.pi * 2
ENV.globals['e'] = math.e
ENV.globals['float'] = forgiving_float
ENV.globals['now'] = dt_util.now
ENV.globals['utcnow'] = dt_util.utcnow
ENV.globals['as_timestamp'] = forgiving_as_timestamp
ENV.globals['relative_time'] = dt_util.get_age
ENV.globals['strptime'] = strptime
_NO_HASS_ENV = TemplateEnvironment(None)

View file

@ -5,7 +5,6 @@ import unittest
from unittest.mock import patch
import pytest
import pytz
from homeassistant.helpers import template
from homeassistant.const import STATE_UNKNOWN
from homeassistant.setup import setup_component
@ -50,10 +49,12 @@ class TestHistoryStatsSensor(unittest.TestCase):
state = self.hass.states.get('sensor.test')
assert state.state == STATE_UNKNOWN
def test_period_parsing(self):
@patch('homeassistant.helpers.template.TemplateEnvironment.'
'is_safe_callable', return_value=True)
def test_period_parsing(self, mock):
"""Test the conversion from templates to period."""
now = datetime(2019, 1, 1, 23, 30, 0, tzinfo=pytz.utc)
with patch.dict(template.ENV.globals, {'now': lambda: now}):
with patch('homeassistant.util.dt.now', return_value=now):
today = Template('{{ now().replace(hour=0).replace(minute=0)'
'.replace(second=0) }}', self.hass)
duration = timedelta(hours=2, minutes=1)

View file

@ -620,7 +620,7 @@ def test_states_function(hass):
def test_now(mock_is_safe, hass):
"""Test now method."""
now = dt_util.now()
with patch.dict(template.ENV.globals, {'now': lambda: now}):
with patch('homeassistant.util.dt.now', return_value=now):
assert now.isoformat() == \
template.Template('{{ now().isoformat() }}',
hass).async_render()
@ -631,7 +631,7 @@ def test_now(mock_is_safe, hass):
def test_utcnow(mock_is_safe, hass):
"""Test utcnow method."""
now = dt_util.utcnow()
with patch.dict(template.ENV.globals, {'utcnow': lambda: now}):
with patch('homeassistant.util.dt.utcnow', return_value=now):
assert now.isoformat() == \
template.Template('{{ utcnow().isoformat() }}',
hass).async_render()
@ -882,6 +882,9 @@ def test_closest_function_home_vs_domain(hass):
assert template.Template('{{ closest(states.test_domain).entity_id }}',
hass).async_render() == 'test_domain.object'
assert template.Template('{{ (states.test_domain | closest).entity_id }}',
hass).async_render() == 'test_domain.object'
def test_closest_function_home_vs_all_states(hass):
"""Test closest function home vs all states."""
@ -898,6 +901,9 @@ def test_closest_function_home_vs_all_states(hass):
assert template.Template('{{ closest(states).entity_id }}',
hass).async_render() == 'test_domain_2.and_closer'
assert template.Template('{{ (states | closest).entity_id }}',
hass).async_render() == 'test_domain_2.and_closer'
async def test_closest_function_home_vs_group_entity_id(hass):
"""Test closest function home vs group entity id."""
@ -948,6 +954,74 @@ async def test_closest_function_home_vs_group_state(hass):
['test_domain.object', 'group.location_group'])
async def test_expand(hass):
"""Test expand function."""
info = render_to_info(
hass, "{{ expand('test.object') }}")
assert_result_info(
info, '[]',
['test.object'])
info = render_to_info(
hass, "{{ expand(56) }}")
assert_result_info(
info, '[]')
hass.states.async_set('test.object', 'happy')
info = render_to_info(
hass, "{{ expand('test.object') | map(attribute='entity_id')"
" | join(', ') }}")
assert_result_info(
info, 'test.object',
[])
info = render_to_info(
hass, "{{ expand('group.new_group') | map(attribute='entity_id')"
" | join(', ') }}")
assert_result_info(
info, '',
['group.new_group'])
info = render_to_info(
hass, "{{ expand(states.group) | map(attribute='entity_id')"
" | join(', ') }}")
assert_result_info(
info, '',
[], ['group'])
await group.Group.async_create_group(
hass, 'new group', ['test.object'])
info = render_to_info(
hass, "{{ expand('group.new_group') | map(attribute='entity_id')"
" | join(', ') }}")
assert_result_info(
info, 'test.object',
['group.new_group'])
info = render_to_info(
hass, "{{ expand(states.group) | map(attribute='entity_id')"
" | join(', ') }}")
assert_result_info(
info, 'test.object',
['group.new_group'], ['group'])
info = render_to_info(
hass, "{{ expand('group.new_group', 'test.object')"
" | map(attribute='entity_id') | join(', ') }}")
assert_result_info(
info, 'test.object',
['group.new_group'])
info = render_to_info(
hass, "{{ ['group.new_group', 'test.object'] | expand"
" | map(attribute='entity_id') | join(', ') }}")
assert_result_info(
info, 'test.object',
['group.new_group'])
def test_closest_function_to_coord(hass):
"""Test closest function to coord."""
hass.states.async_set('test_domain.closest_home', 'happy', {
@ -972,6 +1046,13 @@ def test_closest_function_to_coord(hass):
assert tpl.async_render() == 'test_domain.closest_zone'
tpl = template.Template(
'{{ (states.test_domain | closest("%s", %s)).entity_id }}'
% (hass.config.latitude + 0.3,
hass.config.longitude + 0.3), hass)
assert tpl.async_render() == 'test_domain.closest_zone'
def test_closest_function_to_entity_id(hass):
"""Test closest function to entity id."""
@ -1003,6 +1084,20 @@ def test_closest_function_to_entity_id(hass):
'zone.far_away'],
["test_domain"])
info = render_to_info(
hass,
"{{ ([states.test_domain, 'test_domain.closest_zone'] "
"| closest(zone)).entity_id }}",
{
'zone': 'zone.far_away'
})
assert_result_info(
info, 'test_domain.closest_zone',
['test_domain.closest_home', 'test_domain.closest_zone',
'zone.far_away'],
["test_domain"])
def test_closest_function_to_state(hass):
"""Test closest function to state."""
@ -1060,6 +1155,8 @@ def test_closest_function_invalid_coordinates(hass):
assert template.Template('{{ closest("invalid", "coord", states) }}',
hass).async_render() == 'None'
assert template.Template('{{ states | closest("invalid", "coord") }}',
hass).async_render() == 'None'
def test_closest_function_no_location_states(hass):