Implemented event_data_template (new) (#11057)

* Implemented event_data_template

* The hound does not like my indentation

* Added passed variables to tests for event and svc template calls

* Moved recursive function to template.py

* Update template.py

* Update template.py

* Cleaned up service.py and fixed unit tests

* Blank lines

* Removed stray logger statement

* Blank lines again
This commit is contained in:
tschmidty69 2018-01-19 01:13:14 -05:00 committed by Paulus Schoutsen
parent 0e1cc05189
commit 48619c9d7c
6 changed files with 101 additions and 19 deletions

View file

@ -475,6 +475,7 @@ EVENT_SCHEMA = vol.Schema({
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
vol.Required('event'): string, vol.Required('event'): string,
vol.Optional('event_data'): dict, vol.Optional('event_data'): dict,
vol.Optional('event_data_template'): {match_all: template_complex}
}) })
SERVICE_SCHEMA = vol.All(vol.Schema({ SERVICE_SCHEMA = vol.All(vol.Schema({

View file

@ -8,8 +8,10 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import ( from homeassistant.helpers import (
service, condition, template, config_validation as cv) service, condition, template as template,
config_validation as cv)
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
async_track_point_in_utc_time, async_track_template) async_track_point_in_utc_time, async_track_template)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -25,6 +27,7 @@ CONF_SERVICE_DATA = 'data'
CONF_SEQUENCE = 'sequence' CONF_SEQUENCE = 'sequence'
CONF_EVENT = 'event' CONF_EVENT = 'event'
CONF_EVENT_DATA = 'event_data' CONF_EVENT_DATA = 'event_data'
CONF_EVENT_DATA_TEMPLATE = 'event_data_template'
CONF_DELAY = 'delay' CONF_DELAY = 'delay'
CONF_WAIT_TEMPLATE = 'wait_template' CONF_WAIT_TEMPLATE = 'wait_template'
@ -145,7 +148,7 @@ class Script():
break break
elif CONF_EVENT in action: elif CONF_EVENT in action:
self._async_fire_event(action) self._async_fire_event(action, variables)
else: else:
yield from self._async_call_service(action, variables) yield from self._async_call_service(action, variables)
@ -180,12 +183,20 @@ class Script():
yield from service.async_call_from_config( yield from service.async_call_from_config(
self.hass, action, True, variables, validate_config=False) self.hass, action, True, variables, validate_config=False)
def _async_fire_event(self, action): def _async_fire_event(self, action, variables):
"""Fire an event.""" """Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
event_data = dict(action.get(CONF_EVENT_DATA, {}))
if CONF_EVENT_DATA_TEMPLATE in action:
try:
event_data.update(template.render_complex(
action[CONF_EVENT_DATA_TEMPLATE], variables))
except TemplateError as ex:
_LOGGER.error('Error rendering event data template: %s', ex)
self.hass.bus.async_fire(action[CONF_EVENT], self.hass.bus.async_fire(action[CONF_EVENT],
action.get(CONF_EVENT_DATA)) event_data)
def _async_check_condition(self, action, variables): def _async_check_condition(self, action, variables):
"""Test if condition is matching.""" """Test if condition is matching."""

View file

@ -10,6 +10,7 @@ import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template
from homeassistant.loader import get_component, bind_hass from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -67,17 +68,12 @@ def async_call_from_config(hass, config, blocking=False, variables=None,
service_data = dict(config.get(CONF_SERVICE_DATA, {})) service_data = dict(config.get(CONF_SERVICE_DATA, {}))
if CONF_SERVICE_DATA_TEMPLATE in config: if CONF_SERVICE_DATA_TEMPLATE in config:
def _data_template_creator(value): try:
"""Recursive template creator helper function.""" template.attach(hass, config[CONF_SERVICE_DATA_TEMPLATE])
if isinstance(value, list): service_data.update(template.render_complex(
return [_data_template_creator(item) for item in value] config[CONF_SERVICE_DATA_TEMPLATE], variables))
elif isinstance(value, dict): except TemplateError as ex:
return {key: _data_template_creator(item) _LOGGER.error('Error rendering data template: %s', ex)
for key, item in value.items()}
value.hass = hass
return value.async_render(variables)
service_data.update(_data_template_creator(
config[CONF_SERVICE_DATA_TEMPLATE]))
if CONF_SERVICE_ENTITY_ID in config: if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]

View file

@ -44,6 +44,17 @@ def attach(hass, obj):
obj.hass = hass obj.hass = hass
def render_complex(value, variables=None):
"""Recursive template creator helper function."""
if isinstance(value, list):
return [render_complex(item, variables)
for item in value]
elif isinstance(value, dict):
return {key: render_complex(item, variables)
for key, item in value.items()}
return value.async_render(variables)
def extract_entities(template, variables=None): def extract_entities(template, variables=None):
"""Extract all entities for state_changed listener from template string.""" """Extract all entities for state_changed listener from template string."""
if template is None or _RE_NONE_ENTITIES.search(template): if template is None or _RE_NONE_ENTITIES.search(template):

View file

@ -56,6 +56,39 @@ class TestScriptHelper(unittest.TestCase):
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
assert not script_obj.can_cancel assert not script_obj.can_cancel
def test_firing_event_template(self):
"""Test the firing of events."""
event = 'test_event'
calls = []
@callback
def record_event(event):
"""Add recorded event to set."""
calls.append(event)
self.hass.bus.listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA({
'event': event,
'event_data_template': {
'hello': """
{% if is_world == 'yes' %}
world
{% else %}
not world
{% endif %}
"""
}
}))
script_obj.run({'is_world': 'yes'})
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].data.get('hello') == 'world'
assert not script_obj.can_cancel
def test_calling_service(self): def test_calling_service(self):
"""Test the calling of a service.""" """Test the calling of a service."""
calls = [] calls = []
@ -99,14 +132,14 @@ class TestScriptHelper(unittest.TestCase):
{% endif %}""", {% endif %}""",
'data_template': { 'data_template': {
'hello': """ 'hello': """
{% if True %} {% if is_world == 'yes' %}
world world
{% else %} {% else %}
Not world not world
{% endif %} {% endif %}
""" """
} }
}) }, {'is_world': 'yes'})
self.hass.block_till_done() self.hass.block_till_done()
@ -147,7 +180,7 @@ class TestScriptHelper(unittest.TestCase):
def test_delay_template(self): def test_delay_template(self):
"""Test the delay as a template.""" """Test the delay as a template."""
event = 'test_evnt' event = 'test_event'
events = [] events = []
@callback @callback

View file

@ -279,6 +279,36 @@ class TestHelpersTemplate(unittest.TestCase):
'127', '127',
template.Template('{{ hello }}', self.hass).render({'hello': 127})) template.Template('{{ hello }}', self.hass).render({'hello': 127}))
def test_passing_vars_as_list(self):
"""Test passing variables as list."""
self.assertEqual(
"['foo', 'bar']",
template.render_complex(template.Template('{{ hello }}',
self.hass), {'hello': ['foo', 'bar']}))
def test_passing_vars_as_list_element(self):
"""Test passing variables as list."""
self.assertEqual(
'bar',
template.render_complex(template.Template('{{ hello[1] }}',
self.hass),
{'hello': ['foo', 'bar']}))
def test_passing_vars_as_dict_element(self):
"""Test passing variables as list."""
self.assertEqual(
'bar',
template.render_complex(template.Template('{{ hello.foo }}',
self.hass),
{'hello': {'foo': 'bar'}}))
def test_passing_vars_as_dict(self):
"""Test passing variables as list."""
self.assertEqual(
"{'foo': 'bar'}",
template.render_complex(template.Template('{{ hello }}',
self.hass), {'hello': {'foo': 'bar'}}))
def test_render_with_possible_json_value_with_valid_json(self): def test_render_with_possible_json_value_with_valid_json(self):
"""Render with possible JSON value with valid JSON.""" """Render with possible JSON value with valid JSON."""
tpl = template.Template('{{ value_json.hello }}', self.hass) tpl = template.Template('{{ value_json.hello }}', self.hass)