From a9ffc149f8abcd739b95a1f798e182f439a0cb2a Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Fri, 21 Aug 2020 22:42:05 +0200
Subject: [PATCH] Allow templating keys in data_template (#39008)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
---
 homeassistant/helpers/config_validation.py | 15 ++++++++-------
 homeassistant/helpers/template.py          | 20 +++++++++++++++-----
 tests/helpers/test_script.py               | 20 ++++++++++++++++++++
 3 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py
index f70812d5a4f..7ab81385c63 100644
--- a/homeassistant/helpers/config_validation.py
+++ b/homeassistant/helpers/config_validation.py
@@ -535,12 +535,13 @@ def template_complex(value: Any) -> Any:
             return_list[idx] = template_complex(element)
         return return_list
     if isinstance(value, dict):
-        return_dict = value.copy()
-        for key, element in return_dict.items():
-            return_dict[key] = template_complex(element)
-        return return_dict
-    if isinstance(value, str):
+        return {
+            template_complex(key): template_complex(element)
+            for key, element in value.items()
+        }
+    if isinstance(value, str) and template_helper.is_template_string(value):
         return template(value)
+
     return value
 
 
@@ -858,7 +859,7 @@ EVENT_SCHEMA = vol.Schema(
         vol.Optional(CONF_ALIAS): string,
         vol.Required(CONF_EVENT): string,
         vol.Optional(CONF_EVENT_DATA): dict,
-        vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex},
+        vol.Optional(CONF_EVENT_DATA_TEMPLATE): template_complex,
     }
 )
 
@@ -869,7 +870,7 @@ SERVICE_SCHEMA = vol.All(
             vol.Exclusive(CONF_SERVICE, "service name"): service,
             vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template,
             vol.Optional("data"): dict,
-            vol.Optional("data_template"): {match_all: template_complex},
+            vol.Optional("data_template"): template_complex,
             vol.Optional(CONF_ENTITY_ID): comp_entity_ids,
         }
     ),
diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py
index 00f2c4c7296..39317d873f5 100644
--- a/homeassistant/helpers/template.py
+++ b/homeassistant/helpers/template.py
@@ -65,8 +65,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
         for child in obj:
             attach(hass, child)
     elif isinstance(obj, dict):
-        for child in obj.values():
-            attach(hass, child)
+        for child_key, child_value in obj.items():
+            attach(hass, child_key)
+            attach(hass, child_value)
     elif isinstance(obj, Template):
         obj.hass = hass
 
@@ -76,19 +77,28 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
     if isinstance(value, list):
         return [render_complex(item, variables) for item in value]
     if isinstance(value, dict):
-        return {key: render_complex(item, variables) for key, item in value.items()}
+        return {
+            render_complex(key, variables): render_complex(item, variables)
+            for key, item in value.items()
+        }
     if isinstance(value, Template):
         return value.async_render(variables)
+
     return value
 
 
+def is_template_string(maybe_template: str) -> bool:
+    """Check if the input is a Jinja2 template."""
+    return _RE_JINJA_DELIMITERS.search(maybe_template) is not None
+
+
 def extract_entities(
     hass: HomeAssistantType,
     template: Optional[str],
     variables: TemplateVarsType = None,
 ) -> Union[str, List[str]]:
     """Extract all entities for state_changed listener from template string."""
-    if template is None or _RE_JINJA_DELIMITERS.search(template) is None:
+    if template is None or not is_template_string(template):
         return []
 
     if _RE_NONE_ENTITIES.search(template):
@@ -262,7 +272,7 @@ class Template:
             render_info.exception = ex
         finally:
             del self.hass.data[_RENDER_INFO]
-            if _RE_JINJA_DELIMITERS.search(self.template) is None:
+            if not is_template_string(self.template):
                 render_info._freeze_static()
             else:
                 render_info._freeze()
diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py
index 1b2d190fc52..d5aa15ffe38 100644
--- a/tests/helpers/test_script.py
+++ b/tests/helpers/test_script.py
@@ -144,6 +144,26 @@ async def test_calling_service_template(hass):
     assert calls[0].data.get("hello") == "world"
 
 
+async def test_data_template_with_templated_key(hass):
+    """Test the calling of a service with a data_template with a templated key."""
+    context = Context()
+    calls = async_mock_service(hass, "test", "script")
+
+    sequence = cv.SCRIPT_SCHEMA(
+        {"service": "test.script", "data_template": {"{{ hello_var }}": "world"}}
+    )
+    script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
+
+    await script_obj.async_run(
+        MappingProxyType({"hello_var": "hello"}), context=context
+    )
+    await hass.async_block_till_done()
+
+    assert len(calls) == 1
+    assert calls[0].context is context
+    assert "hello" in calls[0].data
+
+
 async def test_multiple_runs_no_wait(hass):
     """Test multiple runs with no wait in script."""
     logger = logging.getLogger("TEST")