Add trigger condition (#51710)

* Add trigger condition

* Tweaks, add tests
This commit is contained in:
Erik Montnemery 2021-06-11 15:05:57 +02:00 committed by GitHub
parent fa3ae9b83c
commit b01b33c304
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 7 deletions

View file

@ -31,6 +31,7 @@ from homeassistant.const import (
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_DOMAIN, CONF_DOMAIN,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_ID,
CONF_STATE, CONF_STATE,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
CONF_WEEKDAY, CONF_WEEKDAY,
@ -930,6 +931,26 @@ async def async_device_from_config(
) )
async def async_trigger_from_config(
hass: HomeAssistant, config: ConfigType, config_validation: bool = True
) -> ConditionCheckerType:
"""Test a trigger condition."""
if config_validation:
config = cv.TRIGGER_CONDITION_SCHEMA(config)
trigger_id = config[CONF_ID]
@trace_condition_function
def trigger_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate trigger based if-condition."""
return (
variables is not None
and "trigger" in variables
and variables["trigger"].get("id") in trigger_id
)
return trigger_if
async def async_validate_condition_config( async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType | Template hass: HomeAssistant, config: ConfigType | Template
) -> ConfigType | Template: ) -> ConfigType | Template:

View file

@ -45,6 +45,7 @@ from homeassistant.const import (
CONF_EVENT_DATA, CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE, CONF_EVENT_DATA_TEMPLATE,
CONF_FOR, CONF_FOR,
CONF_ID,
CONF_PLATFORM, CONF_PLATFORM,
CONF_REPEAT, CONF_REPEAT,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
@ -1026,6 +1027,14 @@ TIME_CONDITION_SCHEMA = vol.All(
has_at_least_one_key("before", "after", "weekday"), has_at_least_one_key("before", "after", "weekday"),
) )
TRIGGER_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "trigger",
vol.Required(CONF_ID): vol.All(ensure_list, [string]),
}
)
ZONE_CONDITION_SCHEMA = vol.Schema( ZONE_CONDITION_SCHEMA = vol.Schema(
{ {
**CONDITION_BASE_SCHEMA, **CONDITION_BASE_SCHEMA,
@ -1090,23 +1099,26 @@ CONDITION_SCHEMA: vol.Schema = vol.Schema(
key_value_schemas( key_value_schemas(
CONF_CONDITION, CONF_CONDITION,
{ {
"and": AND_CONDITION_SCHEMA,
"device": DEVICE_CONDITION_SCHEMA,
"not": NOT_CONDITION_SCHEMA,
"numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA,
"or": OR_CONDITION_SCHEMA,
"state": STATE_CONDITION_SCHEMA, "state": STATE_CONDITION_SCHEMA,
"sun": SUN_CONDITION_SCHEMA, "sun": SUN_CONDITION_SCHEMA,
"template": TEMPLATE_CONDITION_SCHEMA, "template": TEMPLATE_CONDITION_SCHEMA,
"time": TIME_CONDITION_SCHEMA, "time": TIME_CONDITION_SCHEMA,
"trigger": TRIGGER_CONDITION_SCHEMA,
"zone": ZONE_CONDITION_SCHEMA, "zone": ZONE_CONDITION_SCHEMA,
"and": AND_CONDITION_SCHEMA,
"or": OR_CONDITION_SCHEMA,
"not": NOT_CONDITION_SCHEMA,
"device": DEVICE_CONDITION_SCHEMA,
}, },
), ),
dynamic_template, dynamic_template,
) )
) )
TRIGGER_BASE_SCHEMA = vol.Schema({vol.Required(CONF_PLATFORM): str}) TRIGGER_BASE_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str}
)
TRIGGER_SCHEMA = vol.All( TRIGGER_SCHEMA = vol.All(
ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)] ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)]

View file

@ -8,7 +8,7 @@ from typing import Any, Callable
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -74,7 +74,8 @@ async def async_initialize_triggers(
triggers = [] triggers = []
for idx, conf in enumerate(trigger_config): for idx, conf in enumerate(trigger_config):
platform = await _async_get_trigger_platform(hass, conf) platform = await _async_get_trigger_platform(hass, conf)
info = {**info, "trigger_id": f"{idx}"} trigger_id = conf.get(CONF_ID, f"{idx}")
info = {**info, "trigger_id": trigger_id}
triggers.append(platform.async_attach_trigger(hass, conf, action, info)) triggers.append(platform.async_attach_trigger(hass, conf, action, info))
attach_results = await asyncio.gather(*triggers, return_exceptions=True) attach_results = await asyncio.gather(*triggers, return_exceptions=True)

View file

@ -1405,3 +1405,97 @@ async def test_trigger_service(hass, calls):
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data.get("trigger") == {"platform": None} assert calls[0].data.get("trigger") == {"platform": None}
assert calls[0].context.parent_id is context.id assert calls[0].context.parent_id is context.id
async def test_trigger_condition_implicit_id(hass, calls):
"""Test triggers."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": [
{"platform": "event", "event_type": "test_event1"},
{"platform": "event", "event_type": "test_event2"},
{"platform": "event", "event_type": "test_event3"},
],
"action": {
"choose": [
{
"conditions": {"condition": "trigger", "id": [0, "2"]},
"sequence": {
"service": "test.automation",
"data": {"param": "one"},
},
},
{
"conditions": {"condition": "trigger", "id": "1"},
"sequence": {
"service": "test.automation",
"data": {"param": "two"},
},
},
]
},
}
},
)
hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[-1].data.get("param") == "one"
hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[-1].data.get("param") == "two"
hass.bus.async_fire("test_event3")
await hass.async_block_till_done()
assert len(calls) == 3
assert calls[-1].data.get("param") == "one"
async def test_trigger_condition_explicit_id(hass, calls):
"""Test triggers."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": [
{"platform": "event", "event_type": "test_event1", "id": "one"},
{"platform": "event", "event_type": "test_event2", "id": "two"},
],
"action": {
"choose": [
{
"conditions": {"condition": "trigger", "id": "one"},
"sequence": {
"service": "test.automation",
"data": {"param": "one"},
},
},
{
"conditions": {"condition": "trigger", "id": "two"},
"sequence": {
"service": "test.automation",
"data": {"param": "two"},
},
},
]
},
}
},
)
hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[-1].data.get("param") == "one"
hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[-1].data.get("param") == "two"

View file

@ -2829,3 +2829,17 @@ async def test_if_action_after_sunset_no_offset_kotzebue(hass, hass_ws_client, c
"sun", "sun",
{"result": True, "wanted_time_after": "2015-07-23T11:22:18.467277+00:00"}, {"result": True, "wanted_time_after": "2015-07-23T11:22:18.467277+00:00"},
) )
async def test_trigger(hass):
"""Test trigger condition."""
test = await condition.async_from_config(
hass,
{"alias": "Trigger Cond", "condition": "trigger", "id": "123456"},
)
assert not test(hass)
assert not test(hass, {})
assert not test(hass, {"other_var": "123456"})
assert not test(hass, {"trigger": {"trigger_id": "123456"}})
assert test(hass, {"trigger": {"id": "123456"}})