Add shorthand notation for Template conditions (#39705)

This commit is contained in:
Franck Nijhof 2020-09-06 16:55:06 +02:00 committed by GitHub
parent da9b077c11
commit a3c45a6f89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 87 additions and 43 deletions

View file

@ -52,26 +52,30 @@ ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
async def async_from_config(
hass: HomeAssistant, config: ConfigType, config_validation: bool = True
hass: HomeAssistant,
config: Union[ConfigType, Template],
config_validation: bool = True,
) -> ConditionCheckerType:
"""Turn a condition configuration into a method.
Should be run on the event loop.
"""
if isinstance(config, Template):
# We got a condition template, wrap it in a configuration to pass along.
config = {
CONF_CONDITION: "template",
CONF_VALUE_TEMPLATE: config,
}
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(
sys.modules[__name__], fmt.format(config.get(CONF_CONDITION)), None
)
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
if factory is None:
raise HomeAssistantError(
'Invalid condition "{}" specified {}'.format(
config.get(CONF_CONDITION), config
)
)
raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}')
# Check for partials to properly determine if coroutine function
check_factory = factory
@ -584,9 +588,12 @@ async def async_device_from_config(
async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
hass: HomeAssistant, config: Union[ConfigType, Template]
) -> Union[ConfigType, Template]:
"""Validate config."""
if isinstance(config, Template):
return config
condition = config[CONF_CONDITION]
if condition in ("and", "not", "or"):
conditions = []
@ -597,6 +604,7 @@ async def async_validate_condition_config(
if condition == "device":
config = cv.DEVICE_CONDITION_SCHEMA(config)
assert not isinstance(config, Template)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition"
)

View file

@ -1018,7 +1018,9 @@ DEVICE_CONDITION_BASE_SCHEMA = vol.Schema(
DEVICE_CONDITION_SCHEMA = DEVICE_CONDITION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
CONDITION_SCHEMA: vol.Schema = key_value_schemas(
CONDITION_SCHEMA: vol.Schema = vol.Schema(
vol.Any(
key_value_schemas(
CONF_CONDITION,
{
"numeric_state": NUMERIC_STATE_CONDITION_SCHEMA,
@ -1032,6 +1034,9 @@ CONDITION_SCHEMA: vol.Schema = key_value_schemas(
"not": NOT_CONDITION_SCHEMA,
"device": DEVICE_CONDITION_SCHEMA,
},
),
dynamic_template,
)
)
TRIGGER_SCHEMA = vol.All(

View file

@ -935,6 +935,9 @@ class Script:
await asyncio.shield(self._async_stop(update_state))
async def _async_get_condition(self, config):
if isinstance(config, template.Template):
config_cache_key = config.template
else:
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:

View file

@ -232,6 +232,31 @@ async def test_two_conditions_with_and(hass, calls):
assert len(calls) == 1
async def test_shorthand_conditions_template(hass, calls):
"""Test shorthand nation form in conditions."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": [{"platform": "event", "event_type": "test_event"}],
"condition": "{{ is_state('test.entity', 'hello') }}",
"action": {"service": "test.automation"},
}
},
)
hass.states.async_set("test.entity", "hello")
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
hass.states.async_set("test.entity", "goodbye")
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_automation_list_setting(hass, calls):
"""Event is not a valid condition."""
assert await async_setup_component(

View file

@ -128,10 +128,7 @@ async def test_or_condition_with_template(hass):
{
"condition": "or",
"conditions": [
{
"condition": "template",
"value_template": '{{ states.sensor.temperature.state == "100" }}',
},
{'{{ states.sensor.temperature.state == "100" }}'},
{
"condition": "numeric_state",
"entity_id": "sensor.temperature",

View file

@ -982,7 +982,8 @@ async def test_repeat_count(hass):
@pytest.mark.parametrize("condition", ["while", "until"])
async def test_repeat_conditional(hass, condition):
@pytest.mark.parametrize("direct_template", [False, True])
async def test_repeat_conditional(hass, condition, direct_template):
"""Test repeat action w/ while option."""
event = "test_event"
events = async_capture_events(hass, event)
@ -1004,14 +1005,22 @@ async def test_repeat_conditional(hass, condition):
}
}
if condition == "while":
template = "{{ not is_state('sensor.test', 'done') }}"
if direct_template:
sequence["repeat"]["while"] = template
else:
sequence["repeat"]["while"] = {
"condition": "template",
"value_template": "{{ not is_state('sensor.test', 'done') }}",
"value_template": template,
}
else:
template = "{{ is_state('sensor.test', 'done') }}"
if direct_template:
sequence["repeat"]["until"] = template
else:
sequence["repeat"]["until"] = {
"condition": "template",
"value_template": "{{ is_state('sensor.test', 'done') }}",
"value_template": template,
}
script_obj = script.Script(
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
@ -1193,10 +1202,7 @@ async def test_choose(hass, var, result):
"sequence": {"event": event, "event_data": {"choice": "first"}},
},
{
"conditions": {
"condition": "template",
"value_template": "{{ var == 2 }}",
},
"conditions": "{{ var == 2 }}",
"sequence": {"event": event, "event_data": {"choice": "second"}},
},
],