Support templating for in state conditions (#88411)

This commit is contained in:
Erik Montnemery 2023-02-20 18:57:00 +01:00 committed by GitHub
parent 0b81c836ef
commit cc4a179ca8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 49 deletions

View file

@ -12,6 +12,8 @@ import re
import sys import sys
from typing import Any, cast from typing import Any, cast
import voluptuous as vol
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import condition as device_condition from homeassistant.components.device_automation import condition as device_condition
from homeassistant.components.sensor import SensorDeviceClass from homeassistant.components.sensor import SensorDeviceClass
@ -29,6 +31,7 @@ from homeassistant.const import (
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_ENABLED, CONF_ENABLED,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_FOR,
CONF_ID, CONF_ID,
CONF_MATCH, CONF_MATCH,
CONF_STATE, CONF_STATE,
@ -57,7 +60,7 @@ import homeassistant.util.dt as dt_util
from . import config_validation as cv, entity_registry as er from . import config_validation as cv, entity_registry as er
from .sun import get_astral_event_date from .sun import get_astral_event_date
from .template import Template from .template import Template, attach as template_attach, render_complex
from .trace import ( from .trace import (
TraceElement, TraceElement,
trace_append_element, trace_append_element,
@ -481,6 +484,7 @@ def state(
req_state: Any, req_state: Any,
for_period: timedelta | None = None, for_period: timedelta | None = None,
attribute: str | None = None, attribute: str | None = None,
variables: TemplateVarsType = None,
) -> bool: ) -> bool:
"""Test if state matches requirements. """Test if state matches requirements.
@ -534,7 +538,14 @@ def state(
condition_trace_set_result(is_state, state=value, wanted_state=state_value) condition_trace_set_result(is_state, state=value, wanted_state=state_value)
return is_state return is_state
duration = dt_util.utcnow() - for_period try:
for_period = cv.positive_time_period(render_complex(for_period, variables))
except TemplateError as ex:
raise ConditionErrorMessage("state", f"template error: {ex}") from ex
except vol.Invalid as ex:
raise ConditionErrorMessage("state", f"schema error: {ex}") from ex
duration = dt_util.utcnow() - cast(timedelta, for_period)
duration_ok = duration > entity.last_changed duration_ok = duration > entity.last_changed
condition_trace_set_result(duration_ok, state=value, duration=duration) condition_trace_set_result(duration_ok, state=value, duration=duration)
return duration_ok return duration_ok
@ -544,7 +555,7 @@ def state_from_config(config: ConfigType) -> ConditionCheckerType:
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
entity_ids = config.get(CONF_ENTITY_ID, []) entity_ids = config.get(CONF_ENTITY_ID, [])
req_states: str | list[str] = config.get(CONF_STATE, []) req_states: str | list[str] = config.get(CONF_STATE, [])
for_period = config.get("for") for_period = config.get(CONF_FOR)
attribute = config.get(CONF_ATTRIBUTE) attribute = config.get(CONF_ATTRIBUTE)
match = config.get(CONF_MATCH, ENTITY_MATCH_ALL) match = config.get(CONF_MATCH, ENTITY_MATCH_ALL)
@ -554,12 +565,15 @@ def state_from_config(config: ConfigType) -> ConditionCheckerType:
@trace_condition_function @trace_condition_function
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
template_attach(hass, for_period)
errors = [] errors = []
result: bool = match != ENTITY_MATCH_ANY result: bool = match != ENTITY_MATCH_ANY
for index, entity_id in enumerate(entity_ids): for index, entity_id in enumerate(entity_ids):
try: try:
with trace_path(["entity_id", str(index)]), trace_condition(variables): with trace_path(["entity_id", str(index)]), trace_condition(variables):
if state(hass, entity_id, req_states, for_period, attribute): if state(
hass, entity_id, req_states, for_period, attribute, variables
):
result = True result = True
elif match == ENTITY_MATCH_ALL: elif match == ENTITY_MATCH_ALL:
return False return False

View file

@ -386,6 +386,8 @@ def icon(value: Any) -> str:
raise vol.Invalid('Icons should be specified in the form "prefix:name"') raise vol.Invalid('Icons should be specified in the form "prefix:name"')
_TIME_PERIOD_DICT_KEYS = ("days", "hours", "minutes", "seconds", "milliseconds")
time_period_dict = vol.All( time_period_dict = vol.All(
dict, dict,
vol.Schema( vol.Schema(
@ -397,7 +399,7 @@ time_period_dict = vol.All(
"milliseconds": vol.Coerce(float), "milliseconds": vol.Coerce(float),
} }
), ),
has_at_least_one_key("days", "hours", "minutes", "seconds", "milliseconds"), has_at_least_one_key(*_TIME_PERIOD_DICT_KEYS),
lambda value: timedelta(**value), lambda value: timedelta(**value),
) )
@ -639,8 +641,24 @@ def template_complex(value: Any) -> Any:
return value return value
def _positive_time_period_template_complex(value: Any) -> Any:
"""Do basic validation of a positive time period expressed as a templated dict."""
if not isinstance(value, dict) or not value:
raise vol.Invalid("template should be a dict")
for key, element in value.items():
if not isinstance(key, str):
raise vol.Invalid("key should be a string")
if not template_helper.is_template_string(key):
vol.In(_TIME_PERIOD_DICT_KEYS)(key)
if not isinstance(element, str) or (
isinstance(element, str) and not template_helper.is_template_string(element)
):
vol.All(vol.Coerce(float), vol.Range(min=0))(element)
return template_complex(value)
positive_time_period_template = vol.Any( positive_time_period_template = vol.Any(
positive_time_period, template, template_complex positive_time_period, dynamic_template, _positive_time_period_template_complex
) )
@ -1166,7 +1184,7 @@ STATE_CONDITION_BASE_SCHEMA = {
vol.Lower, vol.Any(ENTITY_MATCH_ALL, ENTITY_MATCH_ANY) vol.Lower, vol.Any(ENTITY_MATCH_ALL, ENTITY_MATCH_ANY)
), ),
vol.Optional(CONF_ATTRIBUTE): str, vol.Optional(CONF_ATTRIBUTE): str,
vol.Optional(CONF_FOR): positive_time_period, vol.Optional(CONF_FOR): positive_time_period_template,
# To support use_trigger_value in automation # To support use_trigger_value in automation
# Deprecated 2016/04/25 # Deprecated 2016/04/25
vol.Optional("from"): str, vol.Optional("from"): str,

View file

@ -1044,27 +1044,23 @@ async def test_if_fails_setup_bad_for(hass, calls, above, below):
hass.states.async_set("test.entity", 5) hass.states.async_set("test.entity", 5)
await hass.async_block_till_done() await hass.async_block_till_done()
assert await async_setup_component( with assert_setup_component(0, automation.DOMAIN):
hass, assert await async_setup_component(
automation.DOMAIN, hass,
{ automation.DOMAIN,
automation.DOMAIN: { {
"trigger": { automation.DOMAIN: {
"platform": "numeric_state", "trigger": {
"entity_id": "test.entity", "platform": "numeric_state",
"above": above, "entity_id": "test.entity",
"below": below, "above": above,
"for": {"invalid": 5}, "below": below,
}, "for": {"invalid": 5},
"action": {"service": "homeassistant.turn_on"}, },
} "action": {"service": "homeassistant.turn_on"},
}, }
) },
)
with patch.object(numeric_state_trigger, "_LOGGER") as mock_logger:
hass.states.async_set("test.entity", 9)
await hass.async_block_till_done()
assert mock_logger.error.called
async def test_if_fails_setup_for_without_above_below(hass, calls): async def test_if_fails_setup_for_without_above_below(hass, calls):

View file

@ -578,26 +578,22 @@ async def test_if_fails_setup_if_from_boolean_value(hass, calls):
async def test_if_fails_setup_bad_for(hass, calls): async def test_if_fails_setup_bad_for(hass, calls):
"""Test for setup failure for bad for.""" """Test for setup failure for bad for."""
assert await async_setup_component( with assert_setup_component(0, automation.DOMAIN):
hass, assert await async_setup_component(
automation.DOMAIN, hass,
{ automation.DOMAIN,
automation.DOMAIN: { {
"trigger": { automation.DOMAIN: {
"platform": "state", "trigger": {
"entity_id": "test.entity", "platform": "state",
"to": "world", "entity_id": "test.entity",
"for": {"invalid": 5}, "to": "world",
}, "for": {"invalid": 5},
"action": {"service": "homeassistant.turn_on"}, },
} "action": {"service": "homeassistant.turn_on"},
}, }
) },
)
with patch.object(state_trigger, "_LOGGER") as mock_logger:
hass.states.async_set("test.entity", "world")
await hass.async_block_till_done()
assert mock_logger.error.called
async def test_if_not_fires_on_entity_change_with_for(hass, calls): async def test_if_not_fires_on_entity_change_with_for(hass, calls):

View file

@ -1,5 +1,5 @@
"""Test the condition helper.""" """Test the condition helper."""
from datetime import datetime from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -1125,6 +1125,81 @@ async def test_state_raises(hass: HomeAssistant) -> None:
test(hass) test(hass)
async def test_state_for(hass: HomeAssistant) -> None:
"""Test state with duration."""
config = {
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": ["sensor.temperature"],
"state": "100",
"for": {"seconds": 5},
},
],
}
config = cv.CONDITION_SCHEMA(config)
config = await condition.async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
hass.states.async_set("sensor.temperature", 100)
assert not test(hass)
now = dt_util.utcnow() + timedelta(seconds=5)
with patch("homeassistant.util.dt.utcnow", return_value=now):
assert test(hass)
async def test_state_for_template(hass: HomeAssistant) -> None:
"""Test state with templated duration."""
config = {
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": ["sensor.temperature"],
"state": "100",
"for": {"seconds": "{{ states('input_number.test')|int }}"},
},
],
}
config = cv.CONDITION_SCHEMA(config)
config = await condition.async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
hass.states.async_set("sensor.temperature", 100)
hass.states.async_set("input_number.test", 5)
assert not test(hass)
now = dt_util.utcnow() + timedelta(seconds=5)
with patch("homeassistant.util.dt.utcnow", return_value=now):
assert test(hass)
@pytest.mark.parametrize("for_template", [{"{{invalid}}": 5}, {"hours": "{{ 1/0 }}"}])
async def test_state_for_invalid_template(hass: HomeAssistant, for_template) -> None:
"""Test state with invalid templated duration."""
config = {
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": ["sensor.temperature"],
"state": "100",
"for": for_template,
},
],
}
config = cv.CONDITION_SCHEMA(config)
config = await condition.async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
hass.states.async_set("sensor.temperature", 100)
hass.states.async_set("input_number.test", 5)
with pytest.raises(ConditionError):
assert not test(hass)
async def test_state_unknown_attribute(hass: HomeAssistant) -> None: async def test_state_unknown_attribute(hass: HomeAssistant) -> None:
"""Test that state returns False on unknown attribute.""" """Test that state returns False on unknown attribute."""
# Unknown attribute # Unknown attribute

View file

@ -1376,3 +1376,28 @@ def test_language() -> None:
for value in ("en", "sv"): for value in ("en", "sv"):
assert schema(value) assert schema(value)
def test_positive_time_period_template() -> None:
"""Test positive time period template validation."""
schema = vol.Schema(cv.positive_time_period_template)
with pytest.raises(vol.MultipleInvalid):
schema({})
with pytest.raises(vol.MultipleInvalid):
schema({5: 5})
with pytest.raises(vol.MultipleInvalid):
schema({"invalid": 5})
with pytest.raises(vol.MultipleInvalid):
schema("invalid")
# Time periods pass
schema("00:01")
schema("00:00:01")
schema("00:00:00.500")
schema({"minutes": 5})
# Templates are not evaluated and will pass
schema("{{ 'invalid' }}")
schema({"{{ 'invalid' }}": 5})
schema({"minutes": "{{ 'invalid' }}"})