Create variable with result of wait_template and accept template for timeout option (#38634)

This commit is contained in:
Phil Bruckner 2020-08-12 13:42:06 -05:00 committed by GitHub
parent 45526f4e8a
commit 580e229cf2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 164 additions and 128 deletions

View file

@ -47,11 +47,7 @@ TRIGGER_SCHEMA = vol.All(
vol.Optional(CONF_BELOW): vol.Coerce(float),
vol.Optional(CONF_ABOVE): vol.Coerce(float),
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_FOR): vol.Any(
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
vol.Optional(CONF_FOR): cv.positive_time_period_template,
}
),
cv.has_at_least_one_key(CONF_BELOW, CONF_ABOVE),
@ -141,20 +137,9 @@ async def async_attach_trigger(
}
try:
if isinstance(time_delta, template.Template):
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta.async_render(variables)
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(
template.render_complex(time_delta, variables)
)
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta_data
)
else:
period[entity] = time_delta
period[entity] = cv.positive_time_period(
template.render_complex(time_delta, variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error(
"Error rendering '%s' for template: %s",

View file

@ -33,11 +33,7 @@ TRIGGER_SCHEMA = vol.All(
# These are str on purpose. Want to catch YAML conversions
vol.Optional(CONF_FROM): vol.Any(str, [str]),
vol.Optional(CONF_TO): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.Any(
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
vol.Optional(CONF_FOR): cv.positive_time_period_template,
}
),
cv.key_dependency(CONF_FOR, CONF_TO),
@ -115,18 +111,9 @@ async def async_attach_trigger(
}
try:
if isinstance(time_delta, template.Template):
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta.async_render(variables)
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(template.render_complex(time_delta, variables))
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta_data
)
else:
period[entity] = time_delta
period[entity] = cv.positive_time_period(
template.render_complex(time_delta, variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error(
"Error rendering '%s' for template: %s", automation_info["name"], ex

View file

@ -17,11 +17,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema(
{
vol.Required(CONF_PLATFORM): "template",
vol.Required(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_FOR): vol.Any(
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
vol.Optional(CONF_FOR): cv.positive_time_period_template,
}
)
@ -73,16 +69,9 @@ async def async_attach_trigger(
}
try:
if isinstance(time_delta, template.Template):
period = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta.async_render(variables)
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(template.render_complex(time_delta, variables))
period = vol.All(cv.time_period, cv.positive_timedelta)(time_delta_data)
else:
period = time_delta
period = cv.positive_time_period(
template.render_complex(time_delta, variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error(
"Error rendering '%s' for template: %s", automation_info["name"], ex

View file

@ -68,13 +68,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_SENSOR): cv.entity_id,
vol.Optional(CONF_AC_MODE): cv.boolean,
vol.Optional(CONF_MAX_TEMP): vol.Coerce(float),
vol.Optional(CONF_MIN_DUR): vol.All(cv.time_period, cv.positive_timedelta),
vol.Optional(CONF_MIN_DUR): cv.positive_time_period,
vol.Optional(CONF_MIN_TEMP): vol.Coerce(float),
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_COLD_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float),
vol.Optional(CONF_HOT_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float),
vol.Optional(CONF_TARGET_TEMP): vol.Coerce(float),
vol.Optional(CONF_KEEP_ALIVE): vol.All(cv.time_period, cv.positive_timedelta),
vol.Optional(CONF_KEEP_ALIVE): cv.positive_time_period,
vol.Optional(CONF_INITIAL_HVAC_MODE): vol.In(
[HVAC_MODE_COOL, HVAC_MODE_HEAT, HVAC_MODE_OFF]
),

View file

@ -56,7 +56,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Required(CONF_DEVICE_PORT): cv.port,
vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): vol.All(cv.time_period, cv.positive_timedelta),
): cv.positive_time_period,
vol.Optional(CONF_ZONES, default=DEFAULT_ZONES): vol.All(
cv.ensure_list, [ZONE_SCHEMA]
),

View file

@ -35,7 +35,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_SERVER_ID): cv.positive_int,
vol.Optional(
CONF_SCAN_INTERVAL, default=timedelta(minutes=DEFAULT_SCAN_INTERVAL)
): vol.All(cv.time_period, cv.positive_timedelta),
): cv.positive_time_period,
vol.Optional(CONF_MANUAL, default=False): cv.boolean,
vol.Optional(
CONF_MONITORED_CONDITIONS, default=list(SENSOR_TYPES)

View file

@ -49,8 +49,8 @@ SENSOR_SCHEMA = vol.Schema(
vol.Optional(ATTR_FRIENDLY_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_DELAY_ON): vol.All(cv.time_period, cv.positive_timedelta),
vol.Optional(CONF_DELAY_OFF): vol.All(cv.time_period, cv.positive_timedelta),
vol.Optional(CONF_DELAY_ON): cv.positive_time_period,
vol.Optional(CONF_DELAY_OFF): cv.positive_time_period,
vol.Optional(CONF_UNIQUE_ID): cv.string,
}
)

View file

@ -48,7 +48,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Required(CONF_CLIENT_SECRET): cv.string,
vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): vol.All(cv.time_period, cv.positive_timedelta),
): cv.positive_time_period,
}
),
)

View file

@ -102,7 +102,7 @@ SERVICE_SCHEMA_SET_SCENE = XIAOMI_MIIO_SERVICE_SCHEMA.extend(
)
SERVICE_SCHEMA_SET_DELAYED_TURN_OFF = XIAOMI_MIIO_SERVICE_SCHEMA.extend(
{vol.Required(ATTR_TIME_PERIOD): vol.All(cv.time_period, cv.positive_timedelta)}
{vol.Required(ATTR_TIME_PERIOD): cv.positive_time_period}
)
SERVICE_TO_METHOD = {

View file

@ -402,6 +402,7 @@ def positive_timedelta(value: timedelta) -> timedelta:
positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
positive_time_period = vol.All(time_period, positive_timedelta)
def remove_falsy(value: List[T]) -> List[T]:
@ -530,6 +531,11 @@ def template_complex(value: Any) -> Any:
return value
positive_time_period_template = vol.Any(
positive_time_period, template, template_complex
)
def datetime(value: Any) -> datetime_sys:
"""Validate datetime."""
if isinstance(value, datetime_sys):
@ -876,7 +882,7 @@ STATE_CONDITION_SCHEMA = vol.All(
vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Required(CONF_STATE): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta),
vol.Optional(CONF_FOR): positive_time_period,
# To support use_trigger_value in automation
# Deprecated 2016/04/25
vol.Optional("from"): str,
@ -992,9 +998,7 @@ CONDITION_SCHEMA: vol.Schema = key_value_schemas(
_SCRIPT_DELAY_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_DELAY): vol.Any(
vol.All(time_period, positive_timedelta), template, template_complex
),
vol.Required(CONF_DELAY): positive_time_period_template,
}
)
@ -1002,7 +1006,7 @@ _SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_WAIT_TEMPLATE): template,
vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta),
vol.Optional(CONF_TIMEOUT): positive_time_period_template,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
}
)

View file

@ -1,6 +1,6 @@
"""Helpers to execute scripts."""
import asyncio
from datetime import datetime
from datetime import datetime, timedelta
from functools import partial
import itertools
import logging
@ -241,21 +241,25 @@ class _ScriptRun:
level=level,
)
async def _async_delay_step(self):
"""Handle delay."""
def _get_pos_time_period_template(self, key):
try:
delay = vol.All(cv.time_period, cv.positive_timedelta)(
template.render_complex(self._action[CONF_DELAY], self._variables)
return cv.positive_time_period(
template.render_complex(self._action[key], self._variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
self._log(
"Error rendering %s delay template: %s",
"Error rendering %s %s template: %s",
self._script.name,
key,
ex,
level=logging.ERROR,
)
raise _StopScript
async def _async_delay_step(self):
"""Handle delay."""
delay = self._get_pos_time_period_template(CONF_DELAY)
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action)
@ -269,41 +273,55 @@ class _ScriptRun:
async def _async_wait_template_step(self):
"""Handle a wait template."""
if CONF_TIMEOUT in self._action:
delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
delay = None
self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
self._log("Executing step %s", self._script.last_action)
self._log(
"Executing step %s%s",
self._script.last_action,
"" if delay is None else f" (timeout: {timedelta(seconds=delay)})",
)
self._variables["wait"] = {"remaining": delay, "completed": False}
wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass
# check if condition already okay
if condition.async_template(self._hass, wait_template, self._variables):
self._variables["wait"]["completed"] = True
return
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._variables["wait"] = {
"remaining": to_context.remaining if to_context else delay,
"completed": True,
}
done.set()
to_context = None
unsub = async_track_template(
self._hass, wait_template, async_script_wait, self._variables
)
self._changed()
try:
delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
delay = None
done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay):
async with timeout(delay) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
self._variables["wait"]["remaining"] = 0.0
finally:
for task in tasks:
task.cancel()

View file

@ -16,7 +16,6 @@ import homeassistant.components.scene as scene
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import config_validation as cv, script
from homeassistant.helpers.event import async_call_later
import homeassistant.util.dt as dt_util
from tests.async_mock import patch
@ -29,49 +28,6 @@ from tests.common import (
ENTITY_ID = "script.test"
@pytest.fixture
def mock_timeout(hass, monkeypatch):
"""Mock async_timeout.timeout."""
class MockTimeout:
def __init__(self, timeout):
self._timeout = timeout
self._loop = asyncio.get_event_loop()
self._task = None
self._cancelled = False
self._unsub = None
async def __aenter__(self):
if self._timeout is None:
return self
self._task = asyncio.Task.current_task()
if self._timeout <= 0:
self._loop.call_soon(self._cancel_task)
return self
# Wait for a time_changed event instead of real time passing.
self._unsub = async_call_later(hass, self._timeout, self._cancel_task)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is asyncio.CancelledError and self._cancelled:
self._unsub = None
self._task = None
raise asyncio.TimeoutError
if self._timeout is not None and self._unsub:
self._unsub()
self._unsub = None
self._task = None
return None
@callback
def _cancel_task(self, now=None):
if self._task is not None:
self._task.cancel()
self._cancelled = True
monkeypatch.setattr(script, "timeout", MockTimeout)
def async_watch_for_action(script_obj, message):
"""Watch for message in last_action."""
flag = asyncio.Event()
@ -326,7 +282,7 @@ async def test_stop_no_wait(hass, count):
assert len(events) == 0
async def test_delay_basic(hass, mock_timeout):
async def test_delay_basic(hass):
"""Test the delay."""
delay_alias = "delay step"
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias})
@ -350,7 +306,7 @@ async def test_delay_basic(hass, mock_timeout):
assert script_obj.last_action is None
async def test_multiple_runs_delay(hass, mock_timeout):
async def test_multiple_runs_delay(hass):
"""Test multiple runs with delay in script."""
event = "test_event"
events = async_capture_events(hass, event)
@ -393,7 +349,7 @@ async def test_multiple_runs_delay(hass, mock_timeout):
assert events[-1].data["value"] == 2
async def test_delay_template_ok(hass, mock_timeout):
async def test_delay_template_ok(hass):
"""Test the delay as a template."""
sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
@ -441,7 +397,7 @@ async def test_delay_template_invalid(hass, caplog):
assert len(events) == 1
async def test_delay_template_complex_ok(hass, mock_timeout):
async def test_delay_template_complex_ok(hass):
"""Test the delay with a working complex template."""
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
@ -647,11 +603,56 @@ async def test_wait_template_not_schedule(hass):
assert len(events) == 2
@pytest.mark.parametrize(
"timeout_param", [5, "{{ 5 }}", {"seconds": 5}, {"seconds": "{{ 5 }}"}]
)
async def test_wait_template_timeout(hass, caplog, timeout_param):
"""Test the wait timeout option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"wait_template": "{{ states.switch.test.state == 'off' }}",
"timeout": timeout_param,
"continue_on_timeout": True,
},
{"event": event},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert len(events) == 0
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
cur_time = dt_util.utcnow()
async_fire_time_changed(hass, cur_time + timedelta(seconds=4))
await asyncio.sleep(0)
assert len(events) == 0
async_fire_time_changed(hass, cur_time + timedelta(seconds=5))
await hass.async_block_till_done()
assert not script_obj.is_running
assert len(events) == 1
assert "(timeout: 0:00:05)" in caplog.text
@pytest.mark.parametrize(
"continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)]
)
async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_events):
"""Test the wait template, halt on timeout."""
async def test_wait_template_continue_on_timeout(hass, continue_on_timeout, n_events):
"""Test the wait template continue_on_timeout option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = [
@ -682,8 +683,8 @@ async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_
assert len(events) == n_events
async def test_wait_template_variables(hass):
"""Test the wait template with variables."""
async def test_wait_template_variables_in(hass):
"""Test the wait template with input variables."""
sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -706,6 +707,58 @@ async def test_wait_template_variables(hass):
assert not script_obj.is_running
@pytest.mark.parametrize("mode", ["no_timeout", "timeout_finish", "timeout_not_finish"])
async def test_wait_template_variables_out(hass, mode):
"""Test the wait template output variable."""
event = "test_event"
events = async_capture_events(hass, event)
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
if mode != "no_timeout":
action["timeout"] = 5
action["continue_on_timeout"] = True
sequence = [
action,
{
"event": event,
"event_data_template": {
"completed": "{{ wait.completed }}",
"remaining": "{{ wait.remaining }}",
},
},
]
sequence = cv.SCRIPT_SCHEMA(sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert len(events) == 0
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
if mode == "timeout_not_finish":
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5))
else:
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
assert not script_obj.is_running
assert len(events) == 1
assert events[0].data["completed"] == str(mode != "timeout_not_finish")
remaining = events[0].data["remaining"]
if mode == "no_timeout":
assert remaining == "None"
elif mode == "timeout_finish":
assert 0.0 < float(remaining) < 5
else:
assert float(remaining) == 0.0
async def test_condition_basic(hass):
"""Test if we can use conditions in a script."""
event = "test_event"