Allow templates for enabling actions (#117049)

* Allow templates for enabling automation actions

* Use `cv.template` instead of `cv.template_complex`

* Rename test function
This commit is contained in:
Matthias Alphart 2024-05-15 21:03:52 +02:00 committed by GitHub
parent 076f57ee07
commit ec4c8ae228
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 59 additions and 10 deletions

View file

@ -1311,7 +1311,7 @@ SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
SCRIPT_ACTION_BASE_SCHEMA = { SCRIPT_ACTION_BASE_SCHEMA = {
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
vol.Optional(CONF_CONTINUE_ON_ERROR): boolean, vol.Optional(CONF_CONTINUE_ON_ERROR): boolean,
vol.Optional(CONF_ENABLED): boolean, vol.Optional(CONF_ENABLED): vol.Any(boolean, template),
} }
EVENT_SCHEMA = vol.Schema( EVENT_SCHEMA = vol.Schema(

View file

@ -89,6 +89,7 @@ from .condition import ConditionCheckerType, trace_condition_function
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
from .event import async_call_later, async_track_template from .event import async_call_later, async_track_template
from .script_variables import ScriptVariables from .script_variables import ScriptVariables
from .template import Template
from .trace import ( from .trace import (
TraceElement, TraceElement,
async_trace_path, async_trace_path,
@ -500,12 +501,24 @@ class _ScriptRun:
action = cv.determine_script_action(self._action) action = cv.determine_script_action(self._action)
if not self._action.get(CONF_ENABLED, True): if CONF_ENABLED in self._action:
self._log( enabled = self._action[CONF_ENABLED]
"Skipped disabled step %s", self._action.get(CONF_ALIAS, action) if isinstance(enabled, Template):
) try:
trace_set_result(enabled=False) enabled = enabled.async_render(limited=True)
return except exceptions.TemplateError as ex:
self._handle_exception(
ex,
continue_on_error,
self._log_exceptions or log_exceptions,
)
if not enabled:
self._log(
"Skipped disabled step %s",
self._action.get(CONF_ALIAS, action),
)
trace_set_result(enabled=False)
return
handler = f"_async_{action}_step" handler = f"_async_{action}_step"
try: try:

View file

@ -5764,8 +5764,9 @@ async def test_continue_on_error_unknown_error(hass: HomeAssistant) -> None:
) )
@pytest.mark.parametrize("enabled_value", [False, "{{ 1 == 9 }}"])
async def test_disabled_actions( async def test_disabled_actions(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture, enabled_value: bool | str
) -> None: ) -> None:
"""Test disabled action steps.""" """Test disabled action steps."""
events = async_capture_events(hass, "test_event") events = async_capture_events(hass, "test_event")
@ -5782,10 +5783,14 @@ async def test_disabled_actions(
{"event": "test_event"}, {"event": "test_event"},
{ {
"alias": "Hello", "alias": "Hello",
"enabled": False, "enabled": enabled_value,
"service": "broken.service", "service": "broken.service",
}, },
{"alias": "World", "enabled": False, "event": "test_event"}, {
"alias": "World",
"enabled": enabled_value,
"event": "test_event",
},
{"event": "test_event"}, {"event": "test_event"},
] ]
) )
@ -5807,6 +5812,37 @@ async def test_disabled_actions(
) )
async def test_enabled_error_non_limited_template(hass: HomeAssistant) -> None:
"""Test that a script aborts when an action enabled uses non-limited template."""
await async_setup_component(hass, "homeassistant", {})
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"event": event,
"enabled": "{{ states('sensor.limited') }}",
}
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
with pytest.raises(exceptions.TemplateError):
await script_obj.async_run(context=Context())
assert len(events) == 0
assert not script_obj.is_running
expected_trace = {
"0": [
{
"error": "TemplateError: Use of 'states' is not supported in limited templates"
}
],
}
assert_action_trace(expected_trace, expected_script_execution="error")
async def test_condition_and_shorthand( async def test_condition_and_shorthand(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None: