Add default variables to script helper (#39895)

This commit is contained in:
Paulus Schoutsen 2020-09-10 20:41:42 +02:00 committed by GitHub
parent b5005430be
commit aa9dff572e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 190 additions and 15 deletions

View file

@ -13,6 +13,7 @@ from homeassistant.const import (
CONF_ID, CONF_ID,
CONF_MODE, CONF_MODE,
CONF_PLATFORM, CONF_PLATFORM,
CONF_VARIABLES,
CONF_ZONE, CONF_ZONE,
EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STARTED,
SERVICE_RELOAD, SERVICE_RELOAD,
@ -29,7 +30,7 @@ from homeassistant.core import (
split_entity_id, split_entity_id,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, extract_domain_configs from homeassistant.helpers import condition, extract_domain_configs, template
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -104,6 +105,7 @@ PLATFORM_SCHEMA = vol.All(
vol.Optional(CONF_HIDE_ENTITY): cv.boolean, vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
}, },
SCRIPT_MODE_SINGLE, SCRIPT_MODE_SINGLE,
@ -239,6 +241,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
cond_func, cond_func,
action_script, action_script,
initial_state, initial_state,
variables,
): ):
"""Initialize an automation entity.""" """Initialize an automation entity."""
self._id = automation_id self._id = automation_id
@ -253,6 +256,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_entities: Optional[Set[str]] = None self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None
self._logger = _LOGGER self._logger = _LOGGER
self._variables = variables
self._variables_dynamic = template.is_complex(variables)
@property @property
def name(self): def name(self):
@ -329,6 +334,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
"""Startup with initial state or previous state.""" """Startup with initial state or previous state."""
await super().async_added_to_hass() await super().async_added_to_hass()
if self._variables_dynamic:
template.attach(cast(HomeAssistant, self.hass), self._variables)
self._logger = logging.getLogger( self._logger = logging.getLogger(
f"{__name__}.{split_entity_id(self.entity_id)[1]}" f"{__name__}.{split_entity_id(self.entity_id)[1]}"
) )
@ -378,11 +386,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
else: else:
await self.async_disable() await self.async_disable()
async def async_trigger(self, variables, context=None, skip_condition=False): async def async_trigger(self, run_variables, context=None, skip_condition=False):
"""Trigger automation. """Trigger automation.
This method is a coroutine. This method is a coroutine.
""" """
if self._variables:
if self._variables_dynamic:
variables = template.render_complex(self._variables, run_variables)
else:
variables = dict(self._variables)
else:
variables = {}
if run_variables:
variables.update(run_variables)
if ( if (
not skip_condition not skip_condition
and self._cond_func is not None and self._cond_func is not None
@ -518,6 +537,9 @@ async def _async_process_config(hass, config, component):
max_runs=config_block[CONF_MAX], max_runs=config_block[CONF_MAX],
max_exceeded=config_block[CONF_MAX_EXCEEDED], max_exceeded=config_block[CONF_MAX_EXCEEDED],
logger=_LOGGER, logger=_LOGGER,
# We don't pass variables here
# Automation will already render them to use them in the condition
# and so will pass them on to the script.
) )
if CONF_CONDITION in config_block: if CONF_CONDITION in config_block:
@ -535,6 +557,7 @@ async def _async_process_config(hass, config, component):
cond_func, cond_func,
action_script, action_script,
initial_state, initial_state,
config_block.get(CONF_VARIABLES),
) )
entities.append(entity) entities.append(entity)

View file

@ -12,6 +12,7 @@ from homeassistant.const import (
CONF_ICON, CONF_ICON,
CONF_MODE, CONF_MODE,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_VARIABLES,
SERVICE_RELOAD, SERVICE_RELOAD,
SERVICE_TOGGLE, SERVICE_TOGGLE,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
@ -59,6 +60,7 @@ SCRIPT_ENTRY_SCHEMA = make_script_schema(
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string, vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Optional(CONF_FIELDS, default={}): { vol.Optional(CONF_FIELDS, default={}): {
cv.string: { cv.string: {
vol.Optional(CONF_DESCRIPTION): cv.string, vol.Optional(CONF_DESCRIPTION): cv.string,
@ -75,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema( SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
{vol.Optional(ATTR_VARIABLES): dict} {vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
) )
RELOAD_SERVICE_SCHEMA = vol.Schema({}) RELOAD_SERVICE_SCHEMA = vol.Schema({})
@ -263,6 +265,7 @@ class ScriptEntity(ToggleEntity):
max_runs=cfg[CONF_MAX], max_runs=cfg[CONF_MAX],
max_exceeded=cfg[CONF_MAX_EXCEEDED], max_exceeded=cfg[CONF_MAX_EXCEEDED],
logger=logging.getLogger(f"{__name__}.{object_id}"), logger=logging.getLogger(f"{__name__}.{object_id}"),
variables=cfg.get(CONF_VARIABLES),
) )
self._changed = asyncio.Event() self._changed = asyncio.Event()

View file

@ -179,6 +179,7 @@ CONF_UNTIL = "until"
CONF_URL = "url" CONF_URL = "url"
CONF_USERNAME = "username" CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template" CONF_VALUE_TEMPLATE = "value_template"
CONF_VARIABLES = "variables"
CONF_VERIFY_SSL = "verify_ssl" CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_FOR_TRIGGER = "wait_for_trigger" CONF_WAIT_FOR_TRIGGER = "wait_for_trigger"
CONF_WAIT_TEMPLATE = "wait_template" CONF_WAIT_TEMPLATE = "wait_template"

View file

@ -863,6 +863,9 @@ def make_entity_service_schema(
) )
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
def script_action(value: Any) -> dict: def script_action(value: Any) -> dict:
"""Validate a script action.""" """Validate a script action."""
if not isinstance(value, dict): if not isinstance(value, dict):

View file

@ -53,11 +53,7 @@ from homeassistant.const import (
SERVICE_TURN_ON, SERVICE_TURN_ON,
) )
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
from homeassistant.helpers import ( from homeassistant.helpers import condition, config_validation as cv, template
condition,
config_validation as cv,
template as template,
)
from homeassistant.helpers.event import async_call_later, async_track_template from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.service import ( from homeassistant.helpers.service import (
CONF_SERVICE_DATA, CONF_SERVICE_DATA,
@ -721,6 +717,7 @@ class Script:
logger: Optional[logging.Logger] = None, logger: Optional[logging.Logger] = None,
log_exceptions: bool = True, log_exceptions: bool = True,
top_level: bool = True, top_level: bool = True,
variables: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Initialize the script.""" """Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS) all_scripts = hass.data.get(DATA_SCRIPTS)
@ -759,6 +756,10 @@ class Script:
self._choose_data: Dict[int, Dict[str, Any]] = {} self._choose_data: Dict[int, Dict[str, Any]] = {}
self._referenced_entities: Optional[Set[str]] = None self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None
self.variables = variables
self._variables_dynamic = template.is_complex(variables)
if self._variables_dynamic:
template.attach(hass, variables)
def _set_logger(self, logger: Optional[logging.Logger] = None) -> None: def _set_logger(self, logger: Optional[logging.Logger] = None) -> None:
if logger: if logger:
@ -867,7 +868,7 @@ class Script:
async def async_run( async def async_run(
self, self,
variables: Optional[_VarsType] = None, run_variables: Optional[_VarsType] = None,
context: Optional[Context] = None, context: Optional[Context] = None,
started_action: Optional[Callable[..., Any]] = None, started_action: Optional[Callable[..., Any]] = None,
) -> None: ) -> None:
@ -898,8 +899,19 @@ class Script:
# are read-only, but more importantly, so as not to leak any variables created # are read-only, but more importantly, so as not to leak any variables created
# during the run back to the caller. # during the run back to the caller.
if self._top_level: if self._top_level:
variables = dict(variables) if variables is not None else {} if self.variables:
if self._variables_dynamic:
variables = template.render_complex(self.variables, run_variables)
else:
variables = dict(self.variables)
else:
variables = {}
if run_variables:
variables.update(run_variables)
variables["context"] = context variables["context"] = context
else:
variables = cast(dict, run_variables)
if self.script_mode != SCRIPT_MODE_QUEUED: if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun cls = _ScriptRun

View file

@ -65,7 +65,7 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
if isinstance(obj, list): if isinstance(obj, list):
for child in obj: for child in obj:
attach(hass, child) attach(hass, child)
elif isinstance(obj, dict): elif isinstance(obj, collections.abc.Mapping):
for child_key, child_value in obj.items(): for child_key, child_value in obj.items():
attach(hass, child_key) attach(hass, child_key)
attach(hass, child_value) attach(hass, child_value)
@ -77,7 +77,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
"""Recursive template creator helper function.""" """Recursive template creator helper function."""
if isinstance(value, list): if isinstance(value, list):
return [render_complex(item, variables) for item in value] return [render_complex(item, variables) for item in value]
if isinstance(value, dict): if isinstance(value, collections.abc.Mapping):
return { return {
render_complex(key, variables): render_complex(item, variables) render_complex(key, variables): render_complex(item, variables)
for key, item in value.items() for key, item in value.items()
@ -88,6 +88,19 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
return value return value
def is_complex(value: Any) -> bool:
"""Test if data structure is a complex template."""
if isinstance(value, Template):
return True
if isinstance(value, list):
return any(is_complex(val) for val in value)
if isinstance(value, collections.abc.Mapping):
return any(is_complex(val) for val in value.keys()) or any(
is_complex(val) for val in value.values()
)
return False
def is_template_string(maybe_template: str) -> bool: def is_template_string(maybe_template: str) -> bool:
"""Check if the input is a Jinja2 template.""" """Check if the input is a Jinja2 template."""
return _RE_JINJA_DELIMITERS.search(maybe_template) is not None return _RE_JINJA_DELIMITERS.search(maybe_template) is not None

View file

@ -1,5 +1,5 @@
"""Typing Helpers for Home Assistant.""" """Typing Helpers for Home Assistant."""
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Mapping, Optional, Tuple, Union
import homeassistant.core import homeassistant.core
@ -12,7 +12,7 @@ HomeAssistantType = homeassistant.core.HomeAssistant
ServiceCallType = homeassistant.core.ServiceCall ServiceCallType = homeassistant.core.ServiceCall
ServiceDataType = Dict[str, Any] ServiceDataType = Dict[str, Any]
StateType = Union[None, str, int, float] StateType = Union[None, str, int, float]
TemplateVarsType = Optional[Dict[str, Any]] TemplateVarsType = Optional[Mapping[str, Any]]
# Custom type for recorder Queries # Custom type for recorder Queries
QueryType = Any QueryType = Any

View file

@ -1134,3 +1134,57 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["domain"] == "automation" assert event2["domain"] == "automation"
assert event2["message"] == "has been triggered by source of trigger" assert event2["message"] == "has been triggered by source of trigger"
assert event2["entity_id"] == "automation.bye" assert event2["entity_id"] == "automation.bye"
async def test_automation_variables(hass):
"""Test automation variables."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"variables": {
"test_var": "defined_in_config",
"event_type": "{{ trigger.event.event_type }}",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
{
"variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event_2"},
"condition": {
"condition": "template",
"value_template": "{{ trigger.event.data.pass_condition }}",
},
"action": {
"service": "test.automation",
},
},
]
},
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["value"] == "defined_in_config"
assert calls[0].data["event_type"] == "test_event"
hass.bus.async_fire("test_event_2")
await hass.async_block_till_done()
assert len(calls) == 1
hass.bus.async_fire("test_event_2", {"pass_condition": True})
await hass.async_block_till_done()
assert len(calls) == 2

View file

@ -23,7 +23,7 @@ from homeassistant.loader import bind_hass
from homeassistant.setup import async_setup_component, setup_component from homeassistant.setup import async_setup_component, setup_component
from tests.async_mock import Mock, patch from tests.async_mock import Mock, patch
from tests.common import get_test_home_assistant from tests.common import async_mock_service, get_test_home_assistant
from tests.components.logbook.test_init import MockLazyEventPartialState from tests.components.logbook.test_init import MockLazyEventPartialState
ENTITY_ID = "script.test" ENTITY_ID = "script.test"
@ -615,3 +615,69 @@ async def test_concurrent_script(hass, concurrently):
assert not script.is_on(hass, "script.script1") assert not script.is_on(hass, "script.script1")
assert not script.is_on(hass, "script.script2") assert not script.is_on(hass, "script.script2")
async def test_script_variables(hass):
"""Test defining scripts."""
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"variables": {
"test_var": "from_config",
"templated_config_var": "{{ var_from_service | default('config-default') }}",
},
"sequence": [
{
"service": "test.script",
"data": {
"value": "{{ test_var }}",
"templated_config_var": "{{ templated_config_var }}",
},
},
],
},
"script2": {
"variables": {
"test_var": "from_config",
},
"sequence": [
{
"service": "test.script",
"data": {
"value": "{{ test_var }}",
},
},
],
},
}
},
)
mock_calls = async_mock_service(hass, "test", "script")
await hass.services.async_call(
"script", "script1", {"var_from_service": "hello"}, blocking=True
)
assert len(mock_calls) == 1
assert mock_calls[0].data["value"] == "from_config"
assert mock_calls[0].data["templated_config_var"] == "hello"
await hass.services.async_call(
"script", "script1", {"test_var": "from_service"}, blocking=True
)
assert len(mock_calls) == 2
assert mock_calls[1].data["value"] == "from_service"
assert mock_calls[1].data["templated_config_var"] == "config-default"
# Call script with vars but no templates in it
await hass.services.async_call(
"script", "script2", {"test_var": "from_service"}, blocking=True
)
assert len(mock_calls) == 3
assert mock_calls[2].data["value"] == "from_service"