diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index db97c3a321a..392ca710000 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -13,6 +13,7 @@ from homeassistant.const import ( CONF_ID, CONF_MODE, CONF_PLATFORM, + CONF_VARIABLES, CONF_ZONE, EVENT_HOMEASSISTANT_STARTED, SERVICE_RELOAD, @@ -29,7 +30,7 @@ from homeassistant.core import ( split_entity_id, ) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import condition, extract_domain_configs +from homeassistant.helpers import condition, extract_domain_configs, template import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent @@ -104,6 +105,7 @@ PLATFORM_SCHEMA = vol.All( vol.Optional(CONF_HIDE_ENTITY): cv.boolean, vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, + vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, }, SCRIPT_MODE_SINGLE, @@ -239,6 +241,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): cond_func, action_script, initial_state, + variables, ): """Initialize an automation entity.""" self._id = automation_id @@ -253,6 +256,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None self._logger = _LOGGER + self._variables = variables + self._variables_dynamic = template.is_complex(variables) @property def name(self): @@ -329,6 +334,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity): """Startup with initial state or previous state.""" await super().async_added_to_hass() + if self._variables_dynamic: + template.attach(cast(HomeAssistant, self.hass), self._variables) + self._logger = logging.getLogger( f"{__name__}.{split_entity_id(self.entity_id)[1]}" ) @@ -378,11 +386,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity): else: await self.async_disable() - async def async_trigger(self, variables, context=None, skip_condition=False): + async def async_trigger(self, run_variables, context=None, skip_condition=False): """Trigger automation. This method is a coroutine. """ + if self._variables: + if self._variables_dynamic: + variables = template.render_complex(self._variables, run_variables) + else: + variables = dict(self._variables) + else: + variables = {} + + if run_variables: + variables.update(run_variables) + if ( not skip_condition and self._cond_func is not None @@ -518,6 +537,9 @@ async def _async_process_config(hass, config, component): max_runs=config_block[CONF_MAX], max_exceeded=config_block[CONF_MAX_EXCEEDED], logger=_LOGGER, + # We don't pass variables here + # Automation will already render them to use them in the condition + # and so will pass them on to the script. ) if CONF_CONDITION in config_block: @@ -535,6 +557,7 @@ async def _async_process_config(hass, config, component): cond_func, action_script, initial_state, + config_block.get(CONF_VARIABLES), ) entities.append(entity) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 20f12361621..1e0fad9be5d 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -12,6 +12,7 @@ from homeassistant.const import ( CONF_ICON, CONF_MODE, CONF_SEQUENCE, + CONF_VARIABLES, SERVICE_RELOAD, SERVICE_TOGGLE, SERVICE_TURN_OFF, @@ -59,6 +60,7 @@ SCRIPT_ENTRY_SCHEMA = make_script_schema( vol.Optional(CONF_ICON): cv.icon, vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, vol.Optional(CONF_DESCRIPTION, default=""): cv.string, + vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA, vol.Optional(CONF_FIELDS, default={}): { cv.string: { vol.Optional(CONF_DESCRIPTION): cv.string, @@ -75,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema( SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema( - {vol.Optional(ATTR_VARIABLES): dict} + {vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA} ) RELOAD_SERVICE_SCHEMA = vol.Schema({}) @@ -263,6 +265,7 @@ class ScriptEntity(ToggleEntity): max_runs=cfg[CONF_MAX], max_exceeded=cfg[CONF_MAX_EXCEEDED], logger=logging.getLogger(f"{__name__}.{object_id}"), + variables=cfg.get(CONF_VARIABLES), ) self._changed = asyncio.Event() diff --git a/homeassistant/const.py b/homeassistant/const.py index 4411de047d5..81f2243bca3 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -179,6 +179,7 @@ CONF_UNTIL = "until" CONF_URL = "url" CONF_USERNAME = "username" CONF_VALUE_TEMPLATE = "value_template" +CONF_VARIABLES = "variables" CONF_VERIFY_SSL = "verify_ssl" CONF_WAIT_FOR_TRIGGER = "wait_for_trigger" CONF_WAIT_TEMPLATE = "wait_template" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index c3842c538d8..a54f97ec7e5 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -863,6 +863,9 @@ def make_entity_service_schema( ) +SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex}) + + def script_action(value: Any) -> dict: """Validate a script action.""" if not isinstance(value, dict): diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 74660f8b391..cd664974431 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -53,11 +53,7 @@ from homeassistant.const import ( SERVICE_TURN_ON, ) from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback -from homeassistant.helpers import ( - condition, - config_validation as cv, - template as template, -) +from homeassistant.helpers import condition, config_validation as cv, template from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.service import ( CONF_SERVICE_DATA, @@ -721,6 +717,7 @@ class Script: logger: Optional[logging.Logger] = None, log_exceptions: bool = True, top_level: bool = True, + variables: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the script.""" all_scripts = hass.data.get(DATA_SCRIPTS) @@ -759,6 +756,10 @@ class Script: self._choose_data: Dict[int, Dict[str, Any]] = {} self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None + self.variables = variables + self._variables_dynamic = template.is_complex(variables) + if self._variables_dynamic: + template.attach(hass, variables) def _set_logger(self, logger: Optional[logging.Logger] = None) -> None: if logger: @@ -867,7 +868,7 @@ class Script: async def async_run( self, - variables: Optional[_VarsType] = None, + run_variables: Optional[_VarsType] = None, context: Optional[Context] = None, started_action: Optional[Callable[..., Any]] = None, ) -> None: @@ -898,8 +899,19 @@ class Script: # are read-only, but more importantly, so as not to leak any variables created # during the run back to the caller. if self._top_level: - variables = dict(variables) if variables is not None else {} + if self.variables: + if self._variables_dynamic: + variables = template.render_complex(self.variables, run_variables) + else: + variables = dict(self.variables) + else: + variables = {} + + if run_variables: + variables.update(run_variables) variables["context"] = context + else: + variables = cast(dict, run_variables) if self.script_mode != SCRIPT_MODE_QUEUED: cls = _ScriptRun diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index c771992caa4..917581fac07 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -65,7 +65,7 @@ def attach(hass: HomeAssistantType, obj: Any) -> None: if isinstance(obj, list): for child in obj: attach(hass, child) - elif isinstance(obj, dict): + elif isinstance(obj, collections.abc.Mapping): for child_key, child_value in obj.items(): attach(hass, child_key) attach(hass, child_value) @@ -77,7 +77,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: """Recursive template creator helper function.""" if isinstance(value, list): return [render_complex(item, variables) for item in value] - if isinstance(value, dict): + if isinstance(value, collections.abc.Mapping): return { render_complex(key, variables): render_complex(item, variables) for key, item in value.items() @@ -88,6 +88,19 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: return value +def is_complex(value: Any) -> bool: + """Test if data structure is a complex template.""" + if isinstance(value, Template): + return True + if isinstance(value, list): + return any(is_complex(val) for val in value) + if isinstance(value, collections.abc.Mapping): + return any(is_complex(val) for val in value.keys()) or any( + is_complex(val) for val in value.values() + ) + return False + + def is_template_string(maybe_template: str) -> bool: """Check if the input is a Jinja2 template.""" return _RE_JINJA_DELIMITERS.search(maybe_template) is not None diff --git a/homeassistant/helpers/typing.py b/homeassistant/helpers/typing.py index 6bcc98c10a8..bed0d2b8d17 100644 --- a/homeassistant/helpers/typing.py +++ b/homeassistant/helpers/typing.py @@ -1,5 +1,5 @@ """Typing Helpers for Home Assistant.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Mapping, Optional, Tuple, Union import homeassistant.core @@ -12,7 +12,7 @@ HomeAssistantType = homeassistant.core.HomeAssistant ServiceCallType = homeassistant.core.ServiceCall ServiceDataType = Dict[str, Any] StateType = Union[None, str, int, float] -TemplateVarsType = Optional[Dict[str, Any]] +TemplateVarsType = Optional[Mapping[str, Any]] # Custom type for recorder Queries QueryType = Any diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 3952e781952..5ee0ff62af2 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1134,3 +1134,57 @@ async def test_logbook_humanify_automation_triggered_event(hass): assert event2["domain"] == "automation" assert event2["message"] == "has been triggered by source of trigger" assert event2["entity_id"] == "automation.bye" + + +async def test_automation_variables(hass): + """Test automation variables.""" + calls = async_mock_service(hass, "test", "automation") + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "variables": { + "test_var": "defined_in_config", + "event_type": "{{ trigger.event.event_type }}", + }, + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": { + "service": "test.automation", + "data": { + "value": "{{ test_var }}", + "event_type": "{{ event_type }}", + }, + }, + }, + { + "variables": { + "test_var": "defined_in_config", + }, + "trigger": {"platform": "event", "event_type": "test_event_2"}, + "condition": { + "condition": "template", + "value_template": "{{ trigger.event.data.pass_condition }}", + }, + "action": { + "service": "test.automation", + }, + }, + ] + }, + ) + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["value"] == "defined_in_config" + assert calls[0].data["event_type"] == "test_event" + + hass.bus.async_fire("test_event_2") + await hass.async_block_till_done() + assert len(calls) == 1 + + hass.bus.async_fire("test_event_2", {"pass_condition": True}) + await hass.async_block_till_done() + assert len(calls) == 2 diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index 22625d46530..5fb832d0f36 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -23,7 +23,7 @@ from homeassistant.loader import bind_hass from homeassistant.setup import async_setup_component, setup_component from tests.async_mock import Mock, patch -from tests.common import get_test_home_assistant +from tests.common import async_mock_service, get_test_home_assistant from tests.components.logbook.test_init import MockLazyEventPartialState ENTITY_ID = "script.test" @@ -615,3 +615,69 @@ async def test_concurrent_script(hass, concurrently): assert not script.is_on(hass, "script.script1") assert not script.is_on(hass, "script.script2") + + +async def test_script_variables(hass): + """Test defining scripts.""" + assert await async_setup_component( + hass, + "script", + { + "script": { + "script1": { + "variables": { + "test_var": "from_config", + "templated_config_var": "{{ var_from_service | default('config-default') }}", + }, + "sequence": [ + { + "service": "test.script", + "data": { + "value": "{{ test_var }}", + "templated_config_var": "{{ templated_config_var }}", + }, + }, + ], + }, + "script2": { + "variables": { + "test_var": "from_config", + }, + "sequence": [ + { + "service": "test.script", + "data": { + "value": "{{ test_var }}", + }, + }, + ], + }, + } + }, + ) + + mock_calls = async_mock_service(hass, "test", "script") + + await hass.services.async_call( + "script", "script1", {"var_from_service": "hello"}, blocking=True + ) + + assert len(mock_calls) == 1 + assert mock_calls[0].data["value"] == "from_config" + assert mock_calls[0].data["templated_config_var"] == "hello" + + await hass.services.async_call( + "script", "script1", {"test_var": "from_service"}, blocking=True + ) + + assert len(mock_calls) == 2 + assert mock_calls[1].data["value"] == "from_service" + assert mock_calls[1].data["templated_config_var"] == "config-default" + + # Call script with vars but no templates in it + await hass.services.async_call( + "script", "script2", {"test_var": "from_service"}, blocking=True + ) + + assert len(mock_calls) == 3 + assert mock_calls[2].data["value"] == "from_service"