From 22d9bee41aa71e53a93527faa6fe748f32a9b70b Mon Sep 17 00:00:00 2001 From: Penny Wood Date: Sat, 22 Jun 2019 15:32:32 +0800 Subject: [PATCH] Template: Expand method to expand groups, and closest as filter (#23691) * Implement expand method * Allow expand and closest to be used as filters * Correct patch * Addresses review comments --- homeassistant/helpers/template.py | 423 ++++++++++-------- tests/components/history_stats/test_sensor.py | 7 +- tests/helpers/test_template.py | 101 ++++- 3 files changed, 346 insertions(+), 185 deletions(-) diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 203e460aaa5..55db75642f4 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -6,23 +6,24 @@ import math import random import re from datetime import datetime +from functools import wraps +from typing import Iterable import jinja2 -from jinja2 import contextfilter +from jinja2 import contextfilter, contextfunction from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.utils import Namespace -from homeassistant.const import (ATTR_LATITUDE, ATTR_LONGITUDE, MATCH_ALL, - ATTR_UNIT_OF_MEASUREMENT, STATE_UNKNOWN) +from homeassistant.const import ( + ATTR_ENTITY_ID, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_UNIT_OF_MEASUREMENT, + MATCH_ALL, STATE_UNKNOWN) from homeassistant.core import ( - State, callback, valid_entity_id, split_entity_id) + State, callback, split_entity_id, valid_entity_id) from homeassistant.exceptions import TemplateError from homeassistant.helpers import location as loc_helper from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass -from homeassistant.util import convert -from homeassistant.util import dt as dt_util -from homeassistant.util import location as loc_util +from homeassistant.util import convert, dt as dt_util, location as loc_util from homeassistant.util.async_ import run_callback_threadsafe _LOGGER = logging.getLogger(__name__) @@ -30,6 +31,7 @@ _SENTINEL = object() DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" _RENDER_INFO = 'template.render_info' +_ENVIRONMENT = 'template.environment' _RE_NONE_ENTITIES = re.compile(r"distance\(|closest\(", re.I | re.M) _RE_GET_ENTITIES = re.compile( @@ -152,13 +154,22 @@ class Template: self._compiled = None self.hass = hass + @property + def _env(self): + if self.hass is None: + return _NO_HASS_ENV + ret = self.hass.data.get(_ENVIRONMENT) + if ret is None: + ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) + return ret + def ensure_valid(self): """Return if template is valid.""" if self._compiled_code is not None: return try: - self._compiled_code = ENV.compile(self.template) + self._compiled_code = self._env.compile(self.template) except jinja2.exceptions.TemplateSyntaxError as err: raise TemplateError(err) @@ -254,19 +265,10 @@ class Template: assert self.hass is not None, 'hass variable not set on template' - template_methods = TemplateMethods(self.hass) - - global_vars = ENV.make_globals({ - 'closest': template_methods.closest, - 'distance': template_methods.distance, - 'is_state': template_methods.is_state, - 'is_state_attr': template_methods.is_state_attr, - 'state_attr': template_methods.state_attr, - 'states': AllStates(self.hass), - }) + env = self._env self._compiled = jinja2.Template.from_code( - ENV, self._compiled_code, global_vars, None) + env, self._compiled_code, env.globals, None) return self._compiled @@ -384,6 +386,7 @@ class TemplateState(State): def _access_state(self): state = object.__getattribute__(self, '_state') hass = object.__getattribute__(self, '_hass') + _collect_state(hass, state.entity_id) return state @@ -438,151 +441,184 @@ def _get_state(hass, entity_id): return _wrap_state(hass, state) -class TemplateMethods: - """Class to expose helpers to templates.""" +def _resolve_state(hass, entity_id_or_state): + """Return state or entity_id if given.""" + if isinstance(entity_id_or_state, State): + return entity_id_or_state + if isinstance(entity_id_or_state, str): + return _get_state(hass, entity_id_or_state) + return None - def __init__(self, hass): - """Initialize the helpers.""" - self._hass = hass - def closest(self, *args): - """Find closest entity. +def expand(hass, *args) -> Iterable[State]: + """Expand out any groups into entity states.""" + search = list(args) + found = {} + while search: + entity = search.pop() + if isinstance(entity, str): + entity_id = entity + entity = _get_state(hass, entity) + if entity is None: + continue + elif isinstance(entity, State): + entity_id = entity.entity_id + elif isinstance(entity, Iterable): + search += entity + continue + else: + # ignore other types + continue - Closest to home: - closest(states) - closest(states.device_tracker) - closest('group.children') - closest(states.group.children) + from homeassistant.components import group + if split_entity_id(entity_id)[0] == group.DOMAIN: + # Collect state will be called in here since it's wrapped + group_entities = entity.attributes.get(ATTR_ENTITY_ID) + if group_entities: + search += group_entities + else: + found[entity_id] = entity + return sorted(found.values(), key=lambda a: a.entity_id) - Closest to a point: - closest(23.456, 23.456, 'group.children') - closest('zone.school', 'group.children') - closest(states.zone.school, 'group.children') - """ - if len(args) == 1: - latitude = self._hass.config.latitude - longitude = self._hass.config.longitude - entities = args[0] - elif len(args) == 2: - point_state = self._resolve_state(args[0]) +def closest(hass, *args): + """Find closest entity. - if point_state is None: - _LOGGER.warning("Closest:Unable to find state %s", args[0]) + Closest to home: + closest(states) + closest(states.device_tracker) + closest('group.children') + closest(states.group.children) + + Closest to a point: + closest(23.456, 23.456, 'group.children') + closest('zone.school', 'group.children') + closest(states.zone.school, 'group.children') + + As a filter: + states | closest + states.device_tracker | closest + ['group.children', states.device_tracker] | closest + 'group.children' | closest(23.456, 23.456) + states.device_tracker | closest('zone.school') + 'group.children' | closest(states.zone.school) + + """ + if len(args) == 1: + latitude = hass.config.latitude + longitude = hass.config.longitude + entities = args[0] + + elif len(args) == 2: + point_state = _resolve_state(hass, args[0]) + + if point_state is None: + _LOGGER.warning("Closest:Unable to find state %s", args[0]) + return None + if not loc_helper.has_location(point_state): + _LOGGER.warning( + "Closest:State does not contain valid location: %s", + point_state) + return None + + latitude = point_state.attributes.get(ATTR_LATITUDE) + longitude = point_state.attributes.get(ATTR_LONGITUDE) + + entities = args[1] + + else: + latitude = convert(args[0], float) + longitude = convert(args[1], float) + + if latitude is None or longitude is None: + _LOGGER.warning( + "Closest:Received invalid coordinates: %s, %s", + args[0], args[1]) + return None + + entities = args[2] + + states = expand(hass, entities) + + # state will already be wrapped here + return loc_helper.closest(latitude, longitude, states) + + +def closest_filter(hass, *args): + """Call closest as a filter. Need to reorder arguments.""" + new_args = list(args[1:]) + new_args.append(args[0]) + return closest(hass, *new_args) + + +def distance(hass, *args): + """Calculate distance. + + Will calculate distance from home to a point or between points. + Points can be passed in using state objects or lat/lng coordinates. + """ + locations = [] + + to_process = list(args) + + while to_process: + value = to_process.pop(0) + point_state = _resolve_state(hass, value) + + if point_state is None: + # We expect this and next value to be lat&lng + if not to_process: + _LOGGER.warning( + "Distance:Expected latitude and longitude, got %s", + value) return None + + value_2 = to_process.pop(0) + latitude = convert(value, float) + longitude = convert(value_2, float) + + if latitude is None or longitude is None: + _LOGGER.warning("Distance:Unable to process latitude and " + "longitude: %s, %s", value, value_2) + return None + + else: if not loc_helper.has_location(point_state): _LOGGER.warning( - "Closest:State does not contain valid location: %s", + "distance:State does not contain valid location: %s", point_state) return None latitude = point_state.attributes.get(ATTR_LATITUDE) longitude = point_state.attributes.get(ATTR_LONGITUDE) - entities = args[1] + locations.append((latitude, longitude)) - else: - latitude = convert(args[0], float) - longitude = convert(args[1], float) + if len(locations) == 1: + return hass.config.distance(*locations[0]) - if latitude is None or longitude is None: - _LOGGER.warning( - "Closest:Received invalid coordinates: %s, %s", - args[0], args[1]) - return None + return hass.config.units.length( + loc_util.distance(*locations[0] + locations[1]), 'm') - entities = args[2] - if isinstance(entities, (AllStates, DomainStates)): - states = list(entities) - else: - if isinstance(entities, State): - gr_entity_id = entities.entity_id - else: - gr_entity_id = str(entities) +def is_state(hass, entity_id: str, state: State) -> bool: + """Test if a state is a specific value.""" + state_obj = _get_state(hass, entity_id) + return state_obj is not None and state_obj.state == state - _collect_state(self._hass, gr_entity_id) - group = self._hass.components.group - states = [_get_state(self._hass, entity_id) for entity_id - in group.expand_entity_ids([gr_entity_id])] +def is_state_attr(hass, entity_id, name, value): + """Test if a state's attribute is a specific value.""" + attr = state_attr(hass, entity_id, name) + return attr is not None and attr == value - # state will already be wrapped here - return loc_helper.closest(latitude, longitude, states) - def distance(self, *args): - """Calculate distance. - - Will calculate distance from home to a point or between points. - Points can be passed in using state objects or lat/lng coordinates. - """ - locations = [] - - to_process = list(args) - - while to_process: - value = to_process.pop(0) - point_state = self._resolve_state(value) - - if point_state is None: - # We expect this and next value to be lat&lng - if not to_process: - _LOGGER.warning( - "Distance:Expected latitude and longitude, got %s", - value) - return None - - value_2 = to_process.pop(0) - latitude = convert(value, float) - longitude = convert(value_2, float) - - if latitude is None or longitude is None: - _LOGGER.warning("Distance:Unable to process latitude and " - "longitude: %s, %s", value, value_2) - return None - - else: - if not loc_helper.has_location(point_state): - _LOGGER.warning( - "distance:State does not contain valid location: %s", - point_state) - return None - - latitude = point_state.attributes.get(ATTR_LATITUDE) - longitude = point_state.attributes.get(ATTR_LONGITUDE) - - locations.append((latitude, longitude)) - - if len(locations) == 1: - return self._hass.config.distance(*locations[0]) - - return self._hass.config.units.length( - loc_util.distance(*locations[0] + locations[1]), 'm') - - def is_state(self, entity_id: str, state: State) -> bool: - """Test if a state is a specific value.""" - state_obj = _get_state(self._hass, entity_id) - return state_obj is not None and state_obj.state == state - - def is_state_attr(self, entity_id, name, value): - """Test if a state's attribute is a specific value.""" - state_attr = self.state_attr(entity_id, name) - return state_attr is not None and state_attr == value - - def state_attr(self, entity_id, name): - """Get a specific attribute from a state.""" - state_obj = _get_state(self._hass, entity_id) - if state_obj is not None: - return state_obj.attributes.get(name) - return None - - def _resolve_state(self, entity_id_or_state): - """Return state or entity_id if given.""" - if isinstance(entity_id_or_state, State): - return entity_id_or_state - if isinstance(entity_id_or_state, str): - return _get_state(self._hass, entity_id_or_state) - return None +def state_attr(hass, entity_id, name): + """Get a specific attribute from a state.""" + state_obj = _get_state(hass, entity_id) + if state_obj is not None: + return state_obj.attributes.get(name) + return None def forgiving_round(value, precision=0, method="common"): @@ -790,6 +826,71 @@ def random_every_time(context, values): class TemplateEnvironment(ImmutableSandboxedEnvironment): """The Home Assistant template environment.""" + def __init__(self, hass): + """Initialise template environment.""" + super().__init__() + self.hass = hass + self.filters['round'] = forgiving_round + self.filters['multiply'] = multiply + self.filters['log'] = logarithm + self.filters['sin'] = sine + self.filters['cos'] = cosine + self.filters['tan'] = tangent + self.filters['sqrt'] = square_root + self.filters['as_timestamp'] = forgiving_as_timestamp + self.filters['timestamp_custom'] = timestamp_custom + self.filters['timestamp_local'] = timestamp_local + self.filters['timestamp_utc'] = timestamp_utc + self.filters['is_defined'] = fail_when_undefined + self.filters['max'] = max + self.filters['min'] = min + self.filters['random'] = random_every_time + self.filters['base64_encode'] = base64_encode + self.filters['base64_decode'] = base64_decode + self.filters['ordinal'] = ordinal + self.filters['regex_match'] = regex_match + self.filters['regex_replace'] = regex_replace + self.filters['regex_search'] = regex_search + self.filters['regex_findall_index'] = regex_findall_index + self.filters['bitwise_and'] = bitwise_and + self.filters['bitwise_or'] = bitwise_or + self.globals['log'] = logarithm + self.globals['sin'] = sine + self.globals['cos'] = cosine + self.globals['tan'] = tangent + self.globals['sqrt'] = square_root + self.globals['pi'] = math.pi + self.globals['tau'] = math.pi * 2 + self.globals['e'] = math.e + self.globals['float'] = forgiving_float + self.globals['now'] = dt_util.now + self.globals['utcnow'] = dt_util.utcnow + self.globals['as_timestamp'] = forgiving_as_timestamp + self.globals['relative_time'] = dt_util.get_age + self.globals['strptime'] = strptime + if hass is None: + return + + # We mark these as a context functions to ensure they get + # evaluated fresh with every execution, rather than executed + # at compile time and the value stored. The context itself + # can be discarded, we only need to get at the hass object. + def hassfunction(func): + """Wrap function that depend on hass.""" + @wraps(func) + def wrapper(*args, **kwargs): + return func(hass, *args[1:], **kwargs) + return contextfunction(wrapper) + self.globals['expand'] = hassfunction(expand) + self.filters['expand'] = contextfilter(self.globals['expand']) + self.globals['closest'] = hassfunction(closest) + self.filters['closest'] = contextfilter(hassfunction(closest_filter)) + self.globals['distance'] = hassfunction(distance) + self.globals['is_state'] = hassfunction(is_state) + self.globals['is_state_attr'] = hassfunction(is_state_attr) + self.globals['state_attr'] = hassfunction(state_attr) + self.globals['states'] = AllStates(hass) + def is_safe_callable(self, obj): """Test if callback is safe.""" return isinstance(obj, AllStates) or super().is_safe_callable(obj) @@ -800,42 +901,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): super().is_safe_attribute(obj, attr, value) -ENV = TemplateEnvironment() -ENV.filters['round'] = forgiving_round -ENV.filters['multiply'] = multiply -ENV.filters['log'] = logarithm -ENV.filters['sin'] = sine -ENV.filters['cos'] = cosine -ENV.filters['tan'] = tangent -ENV.filters['sqrt'] = square_root -ENV.filters['as_timestamp'] = forgiving_as_timestamp -ENV.filters['timestamp_custom'] = timestamp_custom -ENV.filters['timestamp_local'] = timestamp_local -ENV.filters['timestamp_utc'] = timestamp_utc -ENV.filters['is_defined'] = fail_when_undefined -ENV.filters['max'] = max -ENV.filters['min'] = min -ENV.filters['random'] = random_every_time -ENV.filters['base64_encode'] = base64_encode -ENV.filters['base64_decode'] = base64_decode -ENV.filters['ordinal'] = ordinal -ENV.filters['regex_match'] = regex_match -ENV.filters['regex_replace'] = regex_replace -ENV.filters['regex_search'] = regex_search -ENV.filters['regex_findall_index'] = regex_findall_index -ENV.filters['bitwise_and'] = bitwise_and -ENV.filters['bitwise_or'] = bitwise_or -ENV.globals['log'] = logarithm -ENV.globals['sin'] = sine -ENV.globals['cos'] = cosine -ENV.globals['tan'] = tangent -ENV.globals['sqrt'] = square_root -ENV.globals['pi'] = math.pi -ENV.globals['tau'] = math.pi * 2 -ENV.globals['e'] = math.e -ENV.globals['float'] = forgiving_float -ENV.globals['now'] = dt_util.now -ENV.globals['utcnow'] = dt_util.utcnow -ENV.globals['as_timestamp'] = forgiving_as_timestamp -ENV.globals['relative_time'] = dt_util.get_age -ENV.globals['strptime'] = strptime +_NO_HASS_ENV = TemplateEnvironment(None) diff --git a/tests/components/history_stats/test_sensor.py b/tests/components/history_stats/test_sensor.py index 05a2d585d16..beceb32154e 100644 --- a/tests/components/history_stats/test_sensor.py +++ b/tests/components/history_stats/test_sensor.py @@ -5,7 +5,6 @@ import unittest from unittest.mock import patch import pytest import pytz -from homeassistant.helpers import template from homeassistant.const import STATE_UNKNOWN from homeassistant.setup import setup_component @@ -50,10 +49,12 @@ class TestHistoryStatsSensor(unittest.TestCase): state = self.hass.states.get('sensor.test') assert state.state == STATE_UNKNOWN - def test_period_parsing(self): + @patch('homeassistant.helpers.template.TemplateEnvironment.' + 'is_safe_callable', return_value=True) + def test_period_parsing(self, mock): """Test the conversion from templates to period.""" now = datetime(2019, 1, 1, 23, 30, 0, tzinfo=pytz.utc) - with patch.dict(template.ENV.globals, {'now': lambda: now}): + with patch('homeassistant.util.dt.now', return_value=now): today = Template('{{ now().replace(hour=0).replace(minute=0)' '.replace(second=0) }}', self.hass) duration = timedelta(hours=2, minutes=1) diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index 032f613d258..f7e4e7dd2ec 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -620,7 +620,7 @@ def test_states_function(hass): def test_now(mock_is_safe, hass): """Test now method.""" now = dt_util.now() - with patch.dict(template.ENV.globals, {'now': lambda: now}): + with patch('homeassistant.util.dt.now', return_value=now): assert now.isoformat() == \ template.Template('{{ now().isoformat() }}', hass).async_render() @@ -631,7 +631,7 @@ def test_now(mock_is_safe, hass): def test_utcnow(mock_is_safe, hass): """Test utcnow method.""" now = dt_util.utcnow() - with patch.dict(template.ENV.globals, {'utcnow': lambda: now}): + with patch('homeassistant.util.dt.utcnow', return_value=now): assert now.isoformat() == \ template.Template('{{ utcnow().isoformat() }}', hass).async_render() @@ -882,6 +882,9 @@ def test_closest_function_home_vs_domain(hass): assert template.Template('{{ closest(states.test_domain).entity_id }}', hass).async_render() == 'test_domain.object' + assert template.Template('{{ (states.test_domain | closest).entity_id }}', + hass).async_render() == 'test_domain.object' + def test_closest_function_home_vs_all_states(hass): """Test closest function home vs all states.""" @@ -898,6 +901,9 @@ def test_closest_function_home_vs_all_states(hass): assert template.Template('{{ closest(states).entity_id }}', hass).async_render() == 'test_domain_2.and_closer' + assert template.Template('{{ (states | closest).entity_id }}', + hass).async_render() == 'test_domain_2.and_closer' + async def test_closest_function_home_vs_group_entity_id(hass): """Test closest function home vs group entity id.""" @@ -948,6 +954,74 @@ async def test_closest_function_home_vs_group_state(hass): ['test_domain.object', 'group.location_group']) +async def test_expand(hass): + """Test expand function.""" + info = render_to_info( + hass, "{{ expand('test.object') }}") + assert_result_info( + info, '[]', + ['test.object']) + + info = render_to_info( + hass, "{{ expand(56) }}") + assert_result_info( + info, '[]') + + hass.states.async_set('test.object', 'happy') + + info = render_to_info( + hass, "{{ expand('test.object') | map(attribute='entity_id')" + " | join(', ') }}") + assert_result_info( + info, 'test.object', + []) + + info = render_to_info( + hass, "{{ expand('group.new_group') | map(attribute='entity_id')" + " | join(', ') }}") + assert_result_info( + info, '', + ['group.new_group']) + + info = render_to_info( + hass, "{{ expand(states.group) | map(attribute='entity_id')" + " | join(', ') }}") + assert_result_info( + info, '', + [], ['group']) + + await group.Group.async_create_group( + hass, 'new group', ['test.object']) + + info = render_to_info( + hass, "{{ expand('group.new_group') | map(attribute='entity_id')" + " | join(', ') }}") + assert_result_info( + info, 'test.object', + ['group.new_group']) + + info = render_to_info( + hass, "{{ expand(states.group) | map(attribute='entity_id')" + " | join(', ') }}") + assert_result_info( + info, 'test.object', + ['group.new_group'], ['group']) + + info = render_to_info( + hass, "{{ expand('group.new_group', 'test.object')" + " | map(attribute='entity_id') | join(', ') }}") + assert_result_info( + info, 'test.object', + ['group.new_group']) + + info = render_to_info( + hass, "{{ ['group.new_group', 'test.object'] | expand" + " | map(attribute='entity_id') | join(', ') }}") + assert_result_info( + info, 'test.object', + ['group.new_group']) + + def test_closest_function_to_coord(hass): """Test closest function to coord.""" hass.states.async_set('test_domain.closest_home', 'happy', { @@ -972,6 +1046,13 @@ def test_closest_function_to_coord(hass): assert tpl.async_render() == 'test_domain.closest_zone' + tpl = template.Template( + '{{ (states.test_domain | closest("%s", %s)).entity_id }}' + % (hass.config.latitude + 0.3, + hass.config.longitude + 0.3), hass) + + assert tpl.async_render() == 'test_domain.closest_zone' + def test_closest_function_to_entity_id(hass): """Test closest function to entity id.""" @@ -1003,6 +1084,20 @@ def test_closest_function_to_entity_id(hass): 'zone.far_away'], ["test_domain"]) + info = render_to_info( + hass, + "{{ ([states.test_domain, 'test_domain.closest_zone'] " + "| closest(zone)).entity_id }}", + { + 'zone': 'zone.far_away' + }) + + assert_result_info( + info, 'test_domain.closest_zone', + ['test_domain.closest_home', 'test_domain.closest_zone', + 'zone.far_away'], + ["test_domain"]) + def test_closest_function_to_state(hass): """Test closest function to state.""" @@ -1060,6 +1155,8 @@ def test_closest_function_invalid_coordinates(hass): assert template.Template('{{ closest("invalid", "coord", states) }}', hass).async_render() == 'None' + assert template.Template('{{ states | closest("invalid", "coord") }}', + hass).async_render() == 'None' def test_closest_function_no_location_states(hass):