From a9ffc149f8abcd739b95a1f798e182f439a0cb2a Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Fri, 21 Aug 2020 22:42:05 +0200 Subject: [PATCH] Allow templating keys in data_template (#39008) Co-authored-by: Paulus Schoutsen Co-authored-by: Paulus Schoutsen --- homeassistant/helpers/config_validation.py | 15 ++++++++------- homeassistant/helpers/template.py | 20 +++++++++++++++----- tests/helpers/test_script.py | 20 ++++++++++++++++++++ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index f70812d5a4f..7ab81385c63 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -535,12 +535,13 @@ def template_complex(value: Any) -> Any: return_list[idx] = template_complex(element) return return_list if isinstance(value, dict): - return_dict = value.copy() - for key, element in return_dict.items(): - return_dict[key] = template_complex(element) - return return_dict - if isinstance(value, str): + return { + template_complex(key): template_complex(element) + for key, element in value.items() + } + if isinstance(value, str) and template_helper.is_template_string(value): return template(value) + return value @@ -858,7 +859,7 @@ EVENT_SCHEMA = vol.Schema( vol.Optional(CONF_ALIAS): string, vol.Required(CONF_EVENT): string, vol.Optional(CONF_EVENT_DATA): dict, - vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex}, + vol.Optional(CONF_EVENT_DATA_TEMPLATE): template_complex, } ) @@ -869,7 +870,7 @@ SERVICE_SCHEMA = vol.All( vol.Exclusive(CONF_SERVICE, "service name"): service, vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template, vol.Optional("data"): dict, - vol.Optional("data_template"): {match_all: template_complex}, + vol.Optional("data_template"): template_complex, vol.Optional(CONF_ENTITY_ID): comp_entity_ids, } ), diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 00f2c4c7296..39317d873f5 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -65,8 +65,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None: for child in obj: attach(hass, child) elif isinstance(obj, dict): - for child in obj.values(): - attach(hass, child) + for child_key, child_value in obj.items(): + attach(hass, child_key) + attach(hass, child_value) elif isinstance(obj, Template): obj.hass = hass @@ -76,19 +77,28 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: if isinstance(value, list): return [render_complex(item, variables) for item in value] if isinstance(value, dict): - return {key: render_complex(item, variables) for key, item in value.items()} + return { + render_complex(key, variables): render_complex(item, variables) + for key, item in value.items() + } if isinstance(value, Template): return value.async_render(variables) + return value +def is_template_string(maybe_template: str) -> bool: + """Check if the input is a Jinja2 template.""" + return _RE_JINJA_DELIMITERS.search(maybe_template) is not None + + def extract_entities( hass: HomeAssistantType, template: Optional[str], variables: TemplateVarsType = None, ) -> Union[str, List[str]]: """Extract all entities for state_changed listener from template string.""" - if template is None or _RE_JINJA_DELIMITERS.search(template) is None: + if template is None or not is_template_string(template): return [] if _RE_NONE_ENTITIES.search(template): @@ -262,7 +272,7 @@ class Template: render_info.exception = ex finally: del self.hass.data[_RENDER_INFO] - if _RE_JINJA_DELIMITERS.search(self.template) is None: + if not is_template_string(self.template): render_info._freeze_static() else: render_info._freeze() diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 1b2d190fc52..d5aa15ffe38 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -144,6 +144,26 @@ async def test_calling_service_template(hass): assert calls[0].data.get("hello") == "world" +async def test_data_template_with_templated_key(hass): + """Test the calling of a service with a data_template with a templated key.""" + context = Context() + calls = async_mock_service(hass, "test", "script") + + sequence = cv.SCRIPT_SCHEMA( + {"service": "test.script", "data_template": {"{{ hello_var }}": "world"}} + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run( + MappingProxyType({"hello_var": "hello"}), context=context + ) + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context is context + assert "hello" in calls[0].data + + async def test_multiple_runs_no_wait(hass): """Test multiple runs with no wait in script.""" logger = logging.getLogger("TEST")