render_with_collect method for template (#23283)

* Make entity_filter be a modifiable builder

* Add render_with_collect method

* Use sync render_with_collect and non-class based test case

* Refactor: Template renders to RenderInfo

* Freeze with exception too

* Finish merging test changes

* Removed unused sync interface

* Final bits of the diff
This commit is contained in:
Penny Wood 2019-05-01 10:54:25 +08:00 committed by GitHub
parent 581b16e9fa
commit 5b9d01139d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 474 additions and 90 deletions

View file

@ -23,7 +23,6 @@ from homeassistant.const import (
TEMP_CELSIUS, TEMP_FAHRENHEIT, WEEKDAYS, __version__) TEMP_CELSIUS, TEMP_FAHRENHEIT, WEEKDAYS, __version__)
from homeassistant.core import valid_entity_id, split_entity_id from homeassistant.core import valid_entity_id, split_entity_id
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template as template_helper
from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import slugify as util_slugify from homeassistant.util import slugify as util_slugify
@ -445,6 +444,8 @@ unit_system = vol.All(vol.Lower, vol.Any(CONF_UNIT_SYSTEM_METRIC,
def template(value): def template(value):
"""Validate a jinja2 template.""" """Validate a jinja2 template."""
from homeassistant.helpers import template as template_helper
if value is None: if value is None:
raise vol.Invalid('template value is None') raise vol.Invalid('template value is None')
if isinstance(value, (list, dict, template_helper.Template)): if isinstance(value, (list, dict, template_helper.Template)):

View file

@ -1,21 +1,21 @@
"""Template helper methods for rendering strings with Home Assistant data.""" """Template helper methods for rendering strings with Home Assistant data."""
from datetime import datetime import base64
import json import json
import logging import logging
import math import math
import random import random
import base64
import re import re
from datetime import datetime
import jinja2 import jinja2
from jinja2 import contextfilter from jinja2 import contextfilter
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.utils import Namespace from jinja2.utils import Namespace
from homeassistant.const import ( from homeassistant.const import (ATTR_LATITUDE, ATTR_LONGITUDE, MATCH_ALL,
ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_UNIT_OF_MEASUREMENT, MATCH_ALL, ATTR_UNIT_OF_MEASUREMENT, STATE_UNKNOWN)
STATE_UNKNOWN) from homeassistant.core import (
from homeassistant.core import State, valid_entity_id State, callback, valid_entity_id, split_entity_id)
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper from homeassistant.helpers import location as loc_helper
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
@ -29,6 +29,8 @@ _LOGGER = logging.getLogger(__name__)
_SENTINEL = object() _SENTINEL = object()
DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S"
_RENDER_INFO = 'template.render_info'
_RE_NONE_ENTITIES = re.compile(r"distance\(|closest\(", re.I | re.M) _RE_NONE_ENTITIES = re.compile(r"distance\(|closest\(", re.I | re.M)
_RE_GET_ENTITIES = re.compile( _RE_GET_ENTITIES = re.compile(
r"(?:(?:states\.|(?:is_state|is_state_attr|state_attr|states)" r"(?:(?:states\.|(?:is_state|is_state_attr|state_attr|states)"
@ -89,6 +91,54 @@ def extract_entities(template, variables=None):
return MATCH_ALL return MATCH_ALL
def _true(arg) -> bool:
return True
class RenderInfo:
"""Holds information about a template render."""
def __init__(self, template):
"""Initialise."""
self.template = template
# Will be set sensibly once frozen.
self.filter_lifecycle = _true
self._result = None
self._exception = None
self._all_states = False
self._domains = []
self._entities = []
def filter(self, entity_id: str) -> bool:
"""Template should re-render if the state changes."""
return entity_id in self._entities
def _filter_lifecycle(self, entity_id: str) -> bool:
"""Template should re-render if the state changes."""
return (
split_entity_id(entity_id)[0] in self._domains
or entity_id in self._entities)
@property
def result(self) -> str:
"""Results of the template computation."""
if self._exception is not None:
raise self._exception # pylint: disable=raising-bad-type
return self._result
def _freeze(self) -> None:
self._entities = frozenset(self._entities)
if self._all_states:
# Leave lifecycle_filter as True
del self._domains
elif not self._domains:
del self._domains
self.filter_lifecycle = self.filter
else:
self._domains = frozenset(self._domains)
self.filter_lifecycle = self._filter_lifecycle
class Template: class Template:
"""Class to hold a template and manage caching and rendering.""" """Class to hold a template and manage caching and rendering."""
@ -124,6 +174,7 @@ class Template:
return run_callback_threadsafe( return run_callback_threadsafe(
self.hass.loop, self.async_render, kwargs).result() self.hass.loop, self.async_render, kwargs).result()
@callback
def async_render(self, variables: TemplateVarsType = None, def async_render(self, variables: TemplateVarsType = None,
**kwargs) -> str: **kwargs) -> str:
"""Render given template. """Render given template.
@ -141,6 +192,23 @@ class Template:
except jinja2.TemplateError as err: except jinja2.TemplateError as err:
raise TemplateError(err) raise TemplateError(err)
@callback
def async_render_to_info(
self, variables: TemplateVarsType = None,
**kwargs) -> RenderInfo:
"""Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data
render_info = self.hass.data[_RENDER_INFO] = RenderInfo(self)
# pylint: disable=protected-access
try:
render_info._result = self.async_render(variables, **kwargs)
except TemplateError as ex:
render_info._exception = ex
finally:
del self.hass.data[_RENDER_INFO]
render_info._freeze()
return render_info
def render_with_possible_json_value(self, value, error_value=_SENTINEL): def render_with_possible_json_value(self, value, error_value=_SENTINEL):
"""Render template with value exposed. """Render template with value exposed.
@ -150,6 +218,7 @@ class Template:
self.hass.loop, self.async_render_with_possible_json_value, value, self.hass.loop, self.async_render_with_possible_json_value, value,
error_value).result() error_value).result()
@callback
def async_render_with_possible_json_value(self, value, def async_render_with_possible_json_value(self, value,
error_value=_SENTINEL, error_value=_SENTINEL,
variables=None): variables=None):
@ -190,7 +259,7 @@ class Template:
global_vars = ENV.make_globals({ global_vars = ENV.make_globals({
'closest': template_methods.closest, 'closest': template_methods.closest,
'distance': template_methods.distance, 'distance': template_methods.distance,
'is_state': self.hass.states.is_state, 'is_state': template_methods.is_state,
'is_state_attr': template_methods.is_state_attr, 'is_state_attr': template_methods.is_state_attr,
'state_attr': template_methods.state_attr, 'state_attr': template_methods.state_attr,
'states': AllStates(self.hass), 'states': AllStates(self.hass),
@ -207,6 +276,14 @@ class Template:
self.template == other.template and self.template == other.template and
self.hass == other.hass) self.hass == other.hass)
def __hash__(self):
"""Hash code for template."""
return hash(self.template)
def __repr__(self):
"""Representation of Template."""
return 'Template(\"' + self.template + '\")'
class AllStates: class AllStates:
"""Class to expose all HA states as attributes.""" """Class to expose all HA states as attributes."""
@ -217,24 +294,42 @@ class AllStates:
def __getattr__(self, name): def __getattr__(self, name):
"""Return the domain state.""" """Return the domain state."""
if '.' in name:
if not valid_entity_id(name):
raise TemplateError("Invalid entity ID '{}'".format(name))
return _get_state(self._hass, name)
if not valid_entity_id(name + '.entity'):
raise TemplateError("Invalid domain name '{}'".format(name))
return DomainStates(self._hass, name) return DomainStates(self._hass, name)
def _collect_all(self):
render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
# pylint: disable=protected-access
render_info._all_states = True
def __iter__(self): def __iter__(self):
"""Return all states.""" """Return all states."""
self._collect_all()
return iter( return iter(
_wrap_state(state) for state in _wrap_state(self._hass, state) for state in
sorted(self._hass.states.async_all(), sorted(self._hass.states.async_all(),
key=lambda state: state.entity_id)) key=lambda state: state.entity_id))
def __len__(self): def __len__(self):
"""Return number of states.""" """Return number of states."""
self._collect_all()
return len(self._hass.states.async_entity_ids()) return len(self._hass.states.async_entity_ids())
def __call__(self, entity_id): def __call__(self, entity_id):
"""Return the states.""" """Return the states."""
state = self._hass.states.get(entity_id) state = _get_state(self._hass, entity_id)
return STATE_UNKNOWN if state is None else state.state return STATE_UNKNOWN if state is None else state.state
def __repr__(self):
"""Representation of All States."""
return '<template AllStates>'
class DomainStates: class DomainStates:
"""Class to expose a specific HA domain as attributes.""" """Class to expose a specific HA domain as attributes."""
@ -246,34 +341,56 @@ class DomainStates:
def __getattr__(self, name): def __getattr__(self, name):
"""Return the states.""" """Return the states."""
return _wrap_state( entity_id = '{}.{}'.format(self._domain, name)
self._hass.states.get('{}.{}'.format(self._domain, name))) if not valid_entity_id(entity_id):
raise TemplateError("Invalid entity ID '{}'".format(entity_id))
return _get_state(self._hass, entity_id)
def _collect_domain(self):
entity_collect = self._hass.data.get(_RENDER_INFO)
if entity_collect is not None:
# pylint: disable=protected-access
entity_collect._domains.append(self._domain)
def __iter__(self): def __iter__(self):
"""Return the iteration over all the states.""" """Return the iteration over all the states."""
self._collect_domain()
return iter(sorted( return iter(sorted(
(_wrap_state(state) for state in self._hass.states.async_all() (_wrap_state(self._hass, state)
for state in self._hass.states.async_all()
if state.domain == self._domain), if state.domain == self._domain),
key=lambda state: state.entity_id)) key=lambda state: state.entity_id))
def __len__(self): def __len__(self):
"""Return number of states.""" """Return number of states."""
self._collect_domain()
return len(self._hass.states.async_entity_ids(self._domain)) return len(self._hass.states.async_entity_ids(self._domain))
def __repr__(self):
"""Representation of Domain States."""
return '<template DomainStates(\'{}\')>'.format(self._domain)
class TemplateState(State): class TemplateState(State):
"""Class to represent a state object in a template.""" """Class to represent a state object in a template."""
# Inheritance is done so functions that check against State keep working # Inheritance is done so functions that check against State keep working
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
def __init__(self, state): def __init__(self, hass, state):
"""Initialize template state.""" """Initialize template state."""
self._hass = hass
self._state = state self._state = state
def _access_state(self):
state = object.__getattribute__(self, '_state')
hass = object.__getattribute__(self, '_hass')
_collect_state(hass, state.entity_id)
return state
@property @property
def state_with_unit(self): def state_with_unit(self):
"""Return the state concatenated with the unit if available.""" """Return the state concatenated with the unit if available."""
state = object.__getattribute__(self, '_state') state = object.__getattribute__(self, '_access_state')()
unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
if unit is None: if unit is None:
return state.state return state.state
@ -281,19 +398,44 @@ class TemplateState(State):
def __getattribute__(self, name): def __getattribute__(self, name):
"""Return an attribute of the state.""" """Return an attribute of the state."""
# This one doesn't count as an access of the state
# since we either found it by looking direct for the ID
# or got it off an iterator.
if name == 'entity_id' or name in object.__dict__:
state = object.__getattribute__(self, '_state')
return getattr(state, name)
if name in TemplateState.__dict__: if name in TemplateState.__dict__:
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
return getattr(object.__getattribute__(self, '_state'), name) state = object.__getattribute__(self, '_access_state')()
return getattr(state, name)
def __repr__(self): def __repr__(self):
"""Representation of Template State.""" """Representation of Template State."""
rep = object.__getattribute__(self, '_state').__repr__() state = object.__getattribute__(self, '_access_state')()
rep = state.__repr__()
return '<template ' + rep[1:] return '<template ' + rep[1:]
def _wrap_state(state): def _collect_state(hass, entity_id):
entity_collect = hass.data.get(_RENDER_INFO)
if entity_collect is not None:
# pylint: disable=protected-access
entity_collect._entities.append(entity_id)
def _wrap_state(hass, state):
"""Wrap a state.""" """Wrap a state."""
return None if state is None else TemplateState(state) return None if state is None else TemplateState(hass, state)
def _get_state(hass, entity_id):
state = hass.states.get(entity_id)
if state is None:
# Only need to collect if none, if not none collect first actuall
# access to the state properties in the state wrapper.
_collect_state(hass, entity_id)
return None
return _wrap_state(hass, state)
class TemplateMethods: class TemplateMethods:
@ -359,12 +501,14 @@ class TemplateMethods:
else: else:
gr_entity_id = str(entities) gr_entity_id = str(entities)
group = self._hass.components.group _collect_state(self._hass, gr_entity_id)
states = [self._hass.states.get(entity_id) for entity_id group = self._hass.components.group
states = [_get_state(self._hass, entity_id) for entity_id
in group.expand_entity_ids([gr_entity_id])] in group.expand_entity_ids([gr_entity_id])]
return _wrap_state(loc_helper.closest(latitude, longitude, states)) # state will already be wrapped here
return loc_helper.closest(latitude, longitude, states)
def distance(self, *args): def distance(self, *args):
"""Calculate distance. """Calculate distance.
@ -407,12 +551,6 @@ class TemplateMethods:
latitude = point_state.attributes.get(ATTR_LATITUDE) latitude = point_state.attributes.get(ATTR_LATITUDE)
longitude = point_state.attributes.get(ATTR_LONGITUDE) longitude = point_state.attributes.get(ATTR_LONGITUDE)
if latitude is None or longitude is None:
_LOGGER.warning(
"Distance:State does not contains a location: %s",
value)
return None
locations.append((latitude, longitude)) locations.append((latitude, longitude))
if len(locations) == 1: if len(locations) == 1:
@ -421,14 +559,19 @@ class TemplateMethods:
return self._hass.config.units.length( return self._hass.config.units.length(
loc_util.distance(*locations[0] + locations[1]), 'm') loc_util.distance(*locations[0] + locations[1]), 'm')
def is_state(self, entity_id: str, state: State) -> bool:
"""Test if a state is a specific value."""
state_obj = _get_state(self._hass, entity_id)
return state_obj is not None and state_obj.state == state
def is_state_attr(self, entity_id, name, value): def is_state_attr(self, entity_id, name, value):
"""Test if a state is a specific attribute.""" """Test if a state's attribute is a specific value."""
state_attr = self.state_attr(entity_id, name) state_attr = self.state_attr(entity_id, name)
return state_attr is not None and state_attr == value return state_attr is not None and state_attr == value
def state_attr(self, entity_id, name): def state_attr(self, entity_id, name):
"""Get a specific attribute from a state.""" """Get a specific attribute from a state."""
state_obj = self._hass.states.get(entity_id) state_obj = _get_state(self._hass, entity_id)
if state_obj is not None: if state_obj is not None:
return state_obj.attributes.get(name) return state_obj.attributes.get(name)
return None return None
@ -438,7 +581,7 @@ class TemplateMethods:
if isinstance(entity_id_or_state, State): if isinstance(entity_id_or_state, State):
return entity_id_or_state return entity_id_or_state
if isinstance(entity_id_or_state, str): if isinstance(entity_id_or_state, str):
return self._hass.states.get(entity_id_or_state) return _get_state(self._hass, entity_id_or_state)
return None return None

View file

@ -1,25 +1,19 @@
"""Test Home Assistant template helper methods.""" """Test Home Assistant template helper methods."""
from datetime import datetime
import random
import math import math
import random
from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import pytz import pytz
import homeassistant.util.dt as dt_util
from homeassistant.components import group from homeassistant.components import group
from homeassistant.const import (LENGTH_METERS, MASS_GRAMS, MATCH_ALL,
PRESSURE_PA, TEMP_CELSIUS, VOLUME_LITERS)
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.util.unit_system import UnitSystem from homeassistant.util.unit_system import UnitSystem
from homeassistant.const import (
LENGTH_METERS,
TEMP_CELSIUS,
MASS_GRAMS,
PRESSURE_PA,
VOLUME_LITERS,
MATCH_ALL,
)
import homeassistant.util.dt as dt_util
def _set_up_units(hass): def _set_up_units(hass):
@ -29,32 +23,148 @@ def _set_up_units(hass):
MASS_GRAMS, PRESSURE_PA) MASS_GRAMS, PRESSURE_PA)
def render_to_info(hass, template_str, variables=None):
"""Create render info from template."""
tmp = template.Template(template_str, hass)
return tmp.async_render_to_info(variables)
def extract_entities(hass, template_str, variables=None):
"""Extract entities from a template."""
info = render_to_info(hass, template_str, variables)
# pylint: disable=protected-access
assert not hasattr(info, '_domains')
return info._entities
def assert_result_info(
info, result, entities=None, domains=None, all_states=False):
"""Check result info."""
assert info.result == result
# pylint: disable=protected-access
assert info._all_states == all_states
assert info.filter_lifecycle('invalid_entity_name.somewhere') == all_states
if entities is not None:
assert info._entities == frozenset(entities)
assert all([info.filter(entity) for entity in entities])
assert not info.filter('invalid_entity_name.somewhere')
else:
assert not info._entities
if domains is not None:
assert info._domains == frozenset(domains)
assert all([info.filter_lifecycle(domain + ".entity")
for domain in domains])
else:
assert not hasattr(info, '_domains')
def test_template_equality():
"""Test template comparison and hashing."""
template_one = template.Template("{{ template_one }}")
template_one_1 = template.Template("{{ template_" + "one }}")
template_two = template.Template("{{ template_two }}")
assert template_one == template_one_1
assert template_one != template_two
assert hash(template_one) == hash(template_one_1)
assert hash(template_one) != hash(template_two)
assert str(template_one_1) == 'Template("{{ template_one }}")'
with pytest.raises(TypeError):
template.Template(["{{ template_one }}"])
def test_invalid_template(hass):
"""Invalid template raises error."""
tmpl = template.Template("{{", hass)
with pytest.raises(TemplateError):
tmpl.ensure_valid()
with pytest.raises(TemplateError):
tmpl.async_render()
info = tmpl.async_render_to_info()
with pytest.raises(TemplateError):
assert info.result == "impossible"
tmpl = template.Template("{{states(keyword)}}", hass)
tmpl.ensure_valid()
with pytest.raises(TemplateError):
tmpl.async_render()
def test_referring_states_by_entity_id(hass): def test_referring_states_by_entity_id(hass):
"""Test referring states by entity id.""" """Test referring states by entity id."""
hass.states.async_set('test.object', 'happy') hass.states.async_set('test.object', 'happy')
assert template.Template( assert template.Template(
'{{ states.test.object.state }}', hass).async_render() == 'happy' '{{ states.test.object.state }}', hass).async_render() == 'happy'
assert template.Template(
'{{ states["test.object"].state }}',
hass).async_render() == 'happy'
assert template.Template(
'{{ states("test.object") }}', hass).async_render() == 'happy'
def test_invalid_entity_id(hass):
"""Test referring states by entity id."""
with pytest.raises(TemplateError):
template.Template(
'{{ states["big.fat..."] }}', hass).async_render()
with pytest.raises(TemplateError):
template.Template(
'{{ states.test["big.fat..."] }}', hass).async_render()
with pytest.raises(TemplateError):
template.Template(
'{{ states["invalid/domain"] }}', hass).async_render()
def test_raise_exception_on_error(hass):
"""Test raising an exception on error."""
with pytest.raises(TemplateError):
template.Template('{{ invalid_syntax').ensure_valid()
def test_iterating_all_states(hass): def test_iterating_all_states(hass):
"""Test iterating all states.""" """Test iterating all states."""
tmpl_str = '{% for state in states %}{{ state.state }}{% endfor %}'
info = render_to_info(hass, tmpl_str)
assert_result_info(info, '', all_states=True)
hass.states.async_set('test.object', 'happy') hass.states.async_set('test.object', 'happy')
hass.states.async_set('sensor.temperature', 10) hass.states.async_set('sensor.temperature', 10)
assert template.Template( info = render_to_info(hass, tmpl_str)
'{% for state in states %}{{ state.state }}{% endfor %}', assert_result_info(
hass).async_render() == '10happy' info, '10happy',
entities=['test.object', 'sensor.temperature'],
all_states=True)
def test_iterating_domain_states(hass): def test_iterating_domain_states(hass):
"""Test iterating domain states.""" """Test iterating domain states."""
tmpl_str = \
"{% for state in states.sensor %}" \
"{{ state.state }}{% endfor %}"
info = render_to_info(hass, tmpl_str)
assert_result_info(info, '', domains=['sensor'])
hass.states.async_set('test.object', 'happy') hass.states.async_set('test.object', 'happy')
hass.states.async_set('sensor.back_door', 'open') hass.states.async_set('sensor.back_door', 'open')
hass.states.async_set('sensor.temperature', 10) hass.states.async_set('sensor.temperature', 10)
assert template.Template(""" info = render_to_info(hass, tmpl_str)
{% for state in states.sensor %}{{ state.state }}{% endfor %} assert_result_info(
""", hass).async_render() == 'open10' info, 'open10',
entities=['sensor.back_door', 'sensor.temperature'],
domains=['sensor'])
def test_float(hass): def test_float(hass):
@ -69,6 +179,10 @@ def test_float(hass):
'{{ float(states.sensor.temperature.state) > 11 }}', '{{ float(states.sensor.temperature.state) > 11 }}',
hass).async_render() == 'True' hass).async_render() == 'True'
assert template.Template(
'{{ float(\'forgiving\') }}',
hass).async_render() == 'forgiving'
def test_rounding_value(hass): def test_rounding_value(hass):
"""Test rounding value.""" """Test rounding value."""
@ -140,7 +254,8 @@ def test_sine(hass):
(math.pi / 2, '1.0'), (math.pi / 2, '1.0'),
(math.pi, '0.0'), (math.pi, '0.0'),
(math.pi * 1.5, '-1.0'), (math.pi * 1.5, '-1.0'),
(math.pi / 10, '0.309') (math.pi / 10, '0.309'),
('"duck"', 'duck'),
] ]
for value, expected in tests: for value, expected in tests:
@ -156,7 +271,8 @@ def test_cos(hass):
(math.pi / 2, '0.0'), (math.pi / 2, '0.0'),
(math.pi, '-1.0'), (math.pi, '-1.0'),
(math.pi * 1.5, '-0.0'), (math.pi * 1.5, '-0.0'),
(math.pi / 10, '0.951') (math.pi / 10, '0.951'),
("'error'", 'error'),
] ]
for value, expected in tests: for value, expected in tests:
@ -172,7 +288,8 @@ def test_tan(hass):
(math.pi, '-0.0'), (math.pi, '-0.0'),
(math.pi / 180 * 45, '1.0'), (math.pi / 180 * 45, '1.0'),
(math.pi / 180 * 90, '1.633123935319537e+16'), (math.pi / 180 * 90, '1.633123935319537e+16'),
(math.pi / 180 * 135, '-1.0') (math.pi / 180 * 135, '-1.0'),
("'error'", 'error'),
] ]
for value, expected in tests: for value, expected in tests:
@ -189,6 +306,7 @@ def test_sqrt(hass):
(2, '1.414'), (2, '1.414'),
(10, '3.162'), (10, '3.162'),
(100, '10.0'), (100, '10.0'),
("'error'", 'error'),
] ]
for value, expected in tests: for value, expected in tests:
@ -290,6 +408,9 @@ def test_ordinal(hass):
(3, '3rd'), (3, '3rd'),
(4, '4th'), (4, '4th'),
(5, '5th'), (5, '5th'),
(12, '12th'),
(100, '100th'),
(101, '101st'),
] ]
for value, expected in tests: for value, expected in tests:
@ -433,12 +554,6 @@ def test_render_with_possible_json_value_non_string_value(hass):
assert tpl.async_render_with_possible_json_value(value) == expected assert tpl.async_render_with_possible_json_value(value) == expected
def test_raise_exception_on_error(hass):
"""Test raising an exception on error."""
with pytest.raises(TemplateError):
template.Template('{{ invalid_syntax').ensure_valid()
def test_if_state_exists(hass): def test_if_state_exists(hass):
"""Test if state exists works.""" """Test if state exists works."""
hass.states.async_set('test.object', 'available') hass.states.async_set('test.object', 'available')
@ -539,6 +654,11 @@ def test_regex_match(hass):
""", hass) """, hass)
assert tpl.async_render() == 'False' assert tpl.async_render() == 'False'
tpl = template.Template("""
{{ ['home assistant test'] | regex_match('.*assist') }}
""", hass)
assert tpl.async_render() == 'True'
def test_regex_search(hass): def test_regex_search(hass):
"""Test regex_search method.""" """Test regex_search method."""
@ -557,6 +677,11 @@ def test_regex_search(hass):
""", hass) """, hass)
assert tpl.async_render() == 'True' assert tpl.async_render() == 'True'
tpl = template.Template("""
{{ ['home assistant test'] | regex_search('assist') }}
""", hass)
assert tpl.async_render() == 'True'
def test_regex_replace(hass): def test_regex_replace(hass):
"""Test regex_replace method.""" """Test regex_replace method."""
@ -565,6 +690,11 @@ def test_regex_replace(hass):
""", hass) """, hass)
assert tpl.async_render() == 'World' assert tpl.async_render() == 'World'
tpl = template.Template("""
{{ ['home hinderant test'] | regex_replace('hinder', 'assist') }}
""", hass)
assert tpl.async_render() == "['home assistant test']"
def test_regex_findall_index(hass): def test_regex_findall_index(hass):
"""Test regex_findall_index method.""" """Test regex_findall_index method."""
@ -578,6 +708,11 @@ def test_regex_findall_index(hass):
""", hass) """, hass)
assert tpl.async_render() == 'LHR' assert tpl.async_render() == 'LHR'
tpl = template.Template("""
{{ ['JFK', 'LHR'] | regex_findall_index('([A-Z]{3})', 1) }}
""", hass)
assert tpl.async_render() == 'LHR'
def test_bitwise_and(hass): def test_bitwise_and(hass):
"""Test bitwise_and method.""" """Test bitwise_and method."""
@ -779,9 +914,10 @@ async def test_closest_function_home_vs_group_entity_id(hass):
await group.Group.async_create_group( await group.Group.async_create_group(
hass, 'location group', ['test_domain.object']) hass, 'location group', ['test_domain.object'])
assert template.Template( info = render_to_info(
'{{ closest("group.location_group").entity_id }}', hass, '{{ closest("group.location_group").entity_id }}')
hass).async_render() == 'test_domain.object' assert_result_info(info, 'test_domain.object', [
'test_domain.object', 'group.location_group'])
async def test_closest_function_home_vs_group_state(hass): async def test_closest_function_home_vs_group_state(hass):
@ -799,9 +935,17 @@ async def test_closest_function_home_vs_group_state(hass):
await group.Group.async_create_group( await group.Group.async_create_group(
hass, 'location group', ['test_domain.object']) hass, 'location group', ['test_domain.object'])
assert template.Template( info = render_to_info(
'{{ closest(states.group.location_group).entity_id }}', hass, '{{ closest("group.location_group").entity_id }}')
hass).async_render() == 'test_domain.object' assert_result_info(
info, 'test_domain.object',
['test_domain.object', 'group.location_group'])
info = render_to_info(
hass, '{{ closest(states.group.location_group).entity_id }}')
assert_result_info(
info, 'test_domain.object',
['test_domain.object', 'group.location_group'])
def test_closest_function_to_coord(hass): def test_closest_function_to_coord(hass):
@ -846,10 +990,18 @@ def test_closest_function_to_entity_id(hass):
'longitude': hass.config.longitude + 0.3, 'longitude': hass.config.longitude + 0.3,
}) })
assert template.Template( info = render_to_info(
'{{ closest("zone.far_away", ' hass,
'states.test_domain).entity_id }}', hass).async_render() == \ '{{ closest(zone, states.test_domain).entity_id }}',
'test_domain.closest_zone' {
'zone': 'zone.far_away'
})
assert_result_info(
info, 'test_domain.closest_zone',
['test_domain.closest_home', 'test_domain.closest_zone',
'zone.far_away'],
["test_domain"])
def test_closest_function_to_state(hass): def test_closest_function_to_state(hass):
@ -935,11 +1087,83 @@ def test_extract_entities_no_match_entities(hass):
assert template.extract_entities( assert template.extract_entities(
"{{ value_json.tst | timestamp_custom('%Y' True) }}") == MATCH_ALL "{{ value_json.tst | timestamp_custom('%Y' True) }}") == MATCH_ALL
assert template.extract_entities(""" info = render_to_info(hass, """
{% for state in states.sensor %} {% for state in states.sensor %}
{{ state.entity_id }}={{ state.state }},d {{ state.entity_id }}={{ state.state }},d
{% endfor %} {% endfor %}
""") == MATCH_ALL """)
assert_result_info(info, '', domains=['sensor'])
def test_generate_filter_iterators(hass):
"""Test extract entities function with none entities stuff."""
info = render_to_info(hass, """
{% for state in states %}
{{ state.entity_id }}
{% endfor %}
""")
assert_result_info(info, '', all_states=True)
info = render_to_info(hass, """
{% for state in states.sensor %}
{{ state.entity_id }}
{% endfor %}
""")
assert_result_info(info, '', domains=['sensor'])
hass.states.async_set('sensor.test_sensor', 'off', {
'attr': 'value'})
# Don't need the entity because the state is not accessed
info = render_to_info(hass, """
{% for state in states.sensor %}
{{ state.entity_id }}
{% endfor %}
""")
assert_result_info(info, 'sensor.test_sensor', domains=['sensor'])
# But we do here because the state gets accessed
info = render_to_info(hass, """
{% for state in states.sensor %}
{{ state.entity_id }}={{ state.state }},
{% endfor %}
""")
assert_result_info(
info, 'sensor.test_sensor=off,',
['sensor.test_sensor'],
['sensor'])
info = render_to_info(hass, """
{% for state in states.sensor %}
{{ state.entity_id }}={{ state.attributes.attr }},
{% endfor %}
""")
assert_result_info(
info, 'sensor.test_sensor=value,',
['sensor.test_sensor'],
['sensor'])
def test_generate_select(hass):
"""Test extract entities function with none entities stuff."""
template_str = """
{{ states.sensor|selectattr("state","equalto","off")
|join(",", attribute="entity_id") }}
"""
tmp = template.Template(template_str, hass)
info = tmp.async_render_to_info()
assert_result_info(info, '', [], ['sensor'])
hass.states.async_set('sensor.test_sensor', 'off', {
'attr': 'value'})
hass.states.async_set('sensor.test_sensor_on', 'on')
info = tmp.async_render_to_info()
assert_result_info(
info, 'sensor.test_sensor',
['sensor.test_sensor', 'sensor.test_sensor_on'],
['sensor'])
def test_extract_entities_match_entities(hass): def test_extract_entities_match_entities(hass):
@ -960,6 +1184,10 @@ Hercules is at {{ states('device_tracker.phone_1') }}.
{{ states("binary_sensor.garage_door") }} {{ states("binary_sensor.garage_door") }}
""") == ['binary_sensor.garage_door'] """) == ['binary_sensor.garage_door']
hass.states.async_set('device_tracker.phone_2', 'not_home', {
'battery': 20
})
assert template.extract_entities(""" assert template.extract_entities("""
{{ is_state_attr('device_tracker.phone_2', 'battery', 40) }} {{ is_state_attr('device_tracker.phone_2', 'battery', 40) }}
""") == ['device_tracker.phone_2'] """) == ['device_tracker.phone_2']
@ -1000,30 +1228,42 @@ states.sensor.pick_humidity.state ~ „ %“
def test_extract_entities_with_variables(hass): def test_extract_entities_with_variables(hass):
"""Test extract entities function with variables and entities stuff.""" """Test extract entities function with variables and entities stuff."""
assert template.extract_entities( hass.states.async_set('input_boolean.switch', 'on')
"{{ is_state('input_boolean.switch', 'off') }}", {}) == \ assert {'input_boolean.switch'} == \
['input_boolean.switch'] extract_entities(
hass, "{{ is_state('input_boolean.switch', 'off') }}", {})
assert template.extract_entities( assert {'input_boolean.switch'} == extract_entities(
"{{ is_state(trigger.entity_id, 'off') }}", {}) == \ hass, "{{ is_state(trigger.entity_id, 'off') }}", {
['trigger.entity_id'] 'trigger': {
'entity_id': 'input_boolean.switch'
}
})
assert template.extract_entities( assert {'no_state'} == extract_entities(
"{{ is_state(data, 'off') }}", {}) == MATCH_ALL hass,
"{{ is_state(data, 'off') }}", {
'data': 'no_state'
})
assert template.extract_entities( assert {'input_boolean.switch'} == \
"{{ is_state(data, 'off') }}", extract_entities(
{'data': 'input_boolean.switch'}) == \ hass,
['input_boolean.switch'] "{{ is_state(data, 'off') }}",
{'data': 'input_boolean.switch'})
assert template.extract_entities( assert {'input_boolean.switch'} == \
"{{ is_state(trigger.entity_id, 'off') }}", extract_entities(
{'trigger': {'entity_id': 'input_boolean.switch'}}) == \ hass,
['input_boolean.switch'] "{{ is_state(trigger.entity_id, 'off') }}",
{'trigger': {'entity_id': 'input_boolean.switch'}})
assert template.extract_entities( hass.states.async_set('media_player.livingroom', 'off')
"{{ is_state('media_player.' ~ where , 'playing') }}", assert {'media_player.livingroom'} == \
{'where': 'livingroom'}) == MATCH_ALL extract_entities(
hass,
"{{ is_state('media_player.' ~ where , 'playing') }}",
{'where': 'livingroom'})
def test_jinja_namespace(hass): def test_jinja_namespace(hass):
@ -1044,7 +1284,7 @@ def test_jinja_namespace(hass):
assert test_template.async_render() == 'another value' assert test_template.async_render() == 'another value'
async def test_state_with_unit(hass): def test_state_with_unit(hass):
"""Test the state_with_unit property helper.""" """Test the state_with_unit property helper."""
hass.states.async_set('sensor.test', '23', { hass.states.async_set('sensor.test', '23', {
'unit_of_measurement': 'beers', 'unit_of_measurement': 'beers',
@ -1073,7 +1313,7 @@ async def test_state_with_unit(hass):
assert tpl.async_render() == '' assert tpl.async_render() == ''
async def test_length_of_states(hass): def test_length_of_states(hass):
"""Test fetching the length of states.""" """Test fetching the length of states."""
hass.states.async_set('sensor.test', '23') hass.states.async_set('sensor.test', '23')
hass.states.async_set('sensor.test2', 'wow') hass.states.async_set('sensor.test2', 'wow')