Allow templating keys in data_template (#39008)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Bas Nijholt 2020-08-21 22:42:05 +02:00 committed by GitHub
parent b0f214bd9c
commit a9ffc149f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 12 deletions

View file

@ -535,12 +535,13 @@ def template_complex(value: Any) -> Any:
return_list[idx] = template_complex(element) return_list[idx] = template_complex(element)
return return_list return return_list
if isinstance(value, dict): if isinstance(value, dict):
return_dict = value.copy() return {
for key, element in return_dict.items(): template_complex(key): template_complex(element)
return_dict[key] = template_complex(element) for key, element in value.items()
return return_dict }
if isinstance(value, str): if isinstance(value, str) and template_helper.is_template_string(value):
return template(value) return template(value)
return value return value
@ -858,7 +859,7 @@ EVENT_SCHEMA = vol.Schema(
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_EVENT): string, vol.Required(CONF_EVENT): string,
vol.Optional(CONF_EVENT_DATA): dict, vol.Optional(CONF_EVENT_DATA): dict,
vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex}, vol.Optional(CONF_EVENT_DATA_TEMPLATE): template_complex,
} }
) )
@ -869,7 +870,7 @@ SERVICE_SCHEMA = vol.All(
vol.Exclusive(CONF_SERVICE, "service name"): service, vol.Exclusive(CONF_SERVICE, "service name"): service,
vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template, vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template,
vol.Optional("data"): dict, vol.Optional("data"): dict,
vol.Optional("data_template"): {match_all: template_complex}, vol.Optional("data_template"): template_complex,
vol.Optional(CONF_ENTITY_ID): comp_entity_ids, vol.Optional(CONF_ENTITY_ID): comp_entity_ids,
} }
), ),

View file

@ -65,8 +65,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
for child in obj: for child in obj:
attach(hass, child) attach(hass, child)
elif isinstance(obj, dict): elif isinstance(obj, dict):
for child in obj.values(): for child_key, child_value in obj.items():
attach(hass, child) attach(hass, child_key)
attach(hass, child_value)
elif isinstance(obj, Template): elif isinstance(obj, Template):
obj.hass = hass obj.hass = hass
@ -76,19 +77,28 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
if isinstance(value, list): if isinstance(value, list):
return [render_complex(item, variables) for item in value] return [render_complex(item, variables) for item in value]
if isinstance(value, dict): if isinstance(value, dict):
return {key: render_complex(item, variables) for key, item in value.items()} return {
render_complex(key, variables): render_complex(item, variables)
for key, item in value.items()
}
if isinstance(value, Template): if isinstance(value, Template):
return value.async_render(variables) return value.async_render(variables)
return value return value
def is_template_string(maybe_template: str) -> bool:
"""Check if the input is a Jinja2 template."""
return _RE_JINJA_DELIMITERS.search(maybe_template) is not None
def extract_entities( def extract_entities(
hass: HomeAssistantType, hass: HomeAssistantType,
template: Optional[str], template: Optional[str],
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Extract all entities for state_changed listener from template string.""" """Extract all entities for state_changed listener from template string."""
if template is None or _RE_JINJA_DELIMITERS.search(template) is None: if template is None or not is_template_string(template):
return [] return []
if _RE_NONE_ENTITIES.search(template): if _RE_NONE_ENTITIES.search(template):
@ -262,7 +272,7 @@ class Template:
render_info.exception = ex render_info.exception = ex
finally: finally:
del self.hass.data[_RENDER_INFO] del self.hass.data[_RENDER_INFO]
if _RE_JINJA_DELIMITERS.search(self.template) is None: if not is_template_string(self.template):
render_info._freeze_static() render_info._freeze_static()
else: else:
render_info._freeze() render_info._freeze()

View file

@ -144,6 +144,26 @@ async def test_calling_service_template(hass):
assert calls[0].data.get("hello") == "world" assert calls[0].data.get("hello") == "world"
async def test_data_template_with_templated_key(hass):
"""Test the calling of a service with a data_template with a templated key."""
context = Context()
calls = async_mock_service(hass, "test", "script")
sequence = cv.SCRIPT_SCHEMA(
{"service": "test.script", "data_template": {"{{ hello_var }}": "world"}}
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
await script_obj.async_run(
MappingProxyType({"hello_var": "hello"}), context=context
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].context is context
assert "hello" in calls[0].data
async def test_multiple_runs_no_wait(hass): async def test_multiple_runs_no_wait(hass):
"""Test multiple runs with no wait in script.""" """Test multiple runs with no wait in script."""
logger = logging.getLogger("TEST") logger = logging.getLogger("TEST")