Add condition to trigger template entities (#119689)

* Add conditions to trigger template entities

* Add tests

* Fix ruff error

* Ruff

* Apply suggestions from code review

* Deduplicate

* Tweak name used in debug message

* Add and improve type annotations of modified code

* Adjust typing

* Adjust typing

* Add typing and remove unused parameter

* Adjust typing

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Adjust return type

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
chammp 2024-09-11 09:36:49 +02:00 committed by GitHub
parent 74834b2d88
commit b3377fe5fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 265 additions and 49 deletions

View file

@ -47,14 +47,7 @@ from homeassistant.core import (
split_entity_id, split_entity_id,
valid_entity_id, valid_entity_id,
) )
from homeassistant.exceptions import ( from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, TemplateError
ConditionError,
ConditionErrorContainer,
ConditionErrorIndex,
HomeAssistantError,
ServiceNotFound,
TemplateError,
)
from homeassistant.helpers import condition from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.deprecation import ( from homeassistant.helpers.deprecation import (
@ -1146,38 +1139,13 @@ async def _async_process_if(
"""Process if checks.""" """Process if checks."""
if_configs = config[CONF_CONDITION] if_configs = config[CONF_CONDITION]
checks: list[condition.ConditionCheckerType] = [] try:
for if_config in if_configs: if_action = await condition.async_conditions_from_config(
try: hass, if_configs, LOGGER, name
checks.append(await condition.async_from_config(hass, if_config)) )
except HomeAssistantError as ex: except HomeAssistantError as ex:
LOGGER.warning("Invalid condition: %s", ex) LOGGER.warning("Invalid condition: %s", ex)
return None return None
def if_action(variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions."""
errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks):
try:
with trace_path(["condition", str(index)]):
if check(hass, variables) is False:
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex(
"condition", index=index, total=len(checks), error=ex
)
)
if errors:
LOGGER.warning(
"Error evaluating condition in '%s':\n%s",
name,
ConditionErrorContainer("condition", errors=errors),
)
return False
return True
result: IfAction = if_action # type: ignore[assignment] result: IfAction = if_action # type: ignore[assignment]
result.config = if_configs result.config = if_configs

View file

@ -15,6 +15,7 @@ from homeassistant.config import async_log_schema_error, config_without_domain
from homeassistant.const import CONF_BINARY_SENSORS, CONF_SENSORS, CONF_UNIQUE_ID from homeassistant.const import CONF_BINARY_SENSORS, CONF_SENSORS, CONF_UNIQUE_ID
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import async_validate_conditions_config
from homeassistant.helpers.trigger import async_validate_trigger_config from homeassistant.helpers.trigger import async_validate_trigger_config
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_notify_setup_error from homeassistant.setup import async_notify_setup_error
@ -28,7 +29,7 @@ from . import (
sensor as sensor_platform, sensor as sensor_platform,
weather as weather_platform, weather as weather_platform,
) )
from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN
PACKAGE_MERGE_HINT = "list" PACKAGE_MERGE_HINT = "list"
@ -36,6 +37,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): cv.CONDITIONS_SCHEMA,
vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA,
vol.Optional(NUMBER_DOMAIN): vol.All( vol.Optional(NUMBER_DOMAIN): vol.All(
cv.ensure_list, [number_platform.NUMBER_SCHEMA] cv.ensure_list, [number_platform.NUMBER_SCHEMA]
@ -83,6 +85,11 @@ async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> Conf
cfg[CONF_TRIGGER] = await async_validate_trigger_config( cfg[CONF_TRIGGER] = await async_validate_trigger_config(
hass, cfg[CONF_TRIGGER] hass, cfg[CONF_TRIGGER]
) )
if CONF_CONDITION in cfg:
cfg[CONF_CONDITION] = await async_validate_conditions_config(
hass, cfg[CONF_CONDITION]
)
except vol.Invalid as err: except vol.Invalid as err:
async_log_schema_error(err, DOMAIN, cfg, hass) async_log_schema_error(err, DOMAIN, cfg, hass)
async_notify_setup_error(hass, DOMAIN) async_notify_setup_error(hass, DOMAIN)

View file

@ -7,6 +7,7 @@ CONF_ATTRIBUTE_TEMPLATES = "attribute_templates"
CONF_ATTRIBUTES = "attributes" CONF_ATTRIBUTES = "attributes"
CONF_AVAILABILITY = "availability" CONF_AVAILABILITY = "availability"
CONF_AVAILABILITY_TEMPLATE = "availability_template" CONF_AVAILABILITY_TEMPLATE = "availability_template"
CONF_CONDITION = "condition"
CONF_MAX = "max" CONF_MAX = "max"
CONF_MIN = "min" CONF_MIN = "min"
CONF_OBJECT_ID = "object_id" CONF_OBJECT_ID = "object_id"

View file

@ -1,16 +1,18 @@
"""Data update coordinator for trigger based template entities.""" """Data update coordinator for trigger based template entities."""
from collections.abc import Callable from collections.abc import Callable, Mapping
import logging import logging
from typing import TYPE_CHECKING, Any
from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import Context, CoreState, callback from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import discovery, trigger as trigger_helper from homeassistant.helpers import condition, discovery, trigger as trigger_helper
from homeassistant.helpers.script import Script from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.trace import trace_get
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORMS
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -24,6 +26,7 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
"""Instantiate trigger data.""" """Instantiate trigger data."""
super().__init__(hass, _LOGGER, name="Trigger Update Coordinator") super().__init__(hass, _LOGGER, name="Trigger Update Coordinator")
self.config = config self.config = config
self._cond_func: Callable[[Mapping[str, Any] | None], bool] | None = None
self._unsub_start: Callable[[], None] | None = None self._unsub_start: Callable[[], None] | None = None
self._unsub_trigger: Callable[[], None] | None = None self._unsub_trigger: Callable[[], None] | None = None
self._script: Script | None = None self._script: Script | None = None
@ -73,6 +76,11 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
DOMAIN, DOMAIN,
) )
if CONF_CONDITION in self.config:
self._cond_func = await condition.async_conditions_from_config(
self.hass, self.config[CONF_CONDITION], _LOGGER, "template entity"
)
if start_event is not None: if start_event is not None:
self._unsub_start = None self._unsub_start = None
@ -91,16 +99,43 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
start_event is not None, start_event is not None,
) )
async def _handle_triggered_with_script(self, run_variables, context=None): async def _handle_triggered_with_script(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
if not self._check_condition(run_variables):
return
# Create a context referring to the trigger context. # Create a context referring to the trigger context.
trigger_context_id = None if context is None else context.id trigger_context_id = None if context is None else context.id
script_context = Context(parent_id=trigger_context_id) script_context = Context(parent_id=trigger_context_id)
if TYPE_CHECKING:
# This method is only called if there's a script
assert self._script is not None
if script_result := await self._script.async_run(run_variables, script_context): if script_result := await self._script.async_run(run_variables, script_context):
run_variables = script_result.variables run_variables = script_result.variables
self._handle_triggered(run_variables, context) self._execute_update(run_variables, context)
async def _handle_triggered(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
if not self._check_condition(run_variables):
return
self._execute_update(run_variables, context)
def _check_condition(self, run_variables: TemplateVarsType) -> bool:
if not self._cond_func:
return True
condition_result = self._cond_func(run_variables)
if condition_result is False:
_LOGGER.debug(
"Conditions not met, aborting template trigger update. Condition summary: %s",
trace_get(clear=False),
)
return condition_result
@callback @callback
def _handle_triggered(self, run_variables, context=None): def _execute_update(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
self.async_set_updated_data( self.async_set_updated_data(
{"run_variables": run_variables, "context": context} {"run_variables": run_variables, "context": context}
) )

View file

@ -8,6 +8,7 @@ from collections.abc import Callable, Container, Generator
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta from datetime import datetime, time as dt_time, timedelta
import functools as ft import functools as ft
import logging
import re import re
import sys import sys
from typing import Any, Protocol, cast from typing import Any, Protocol, cast
@ -1064,6 +1065,46 @@ async def async_validate_conditions_config(
return [await async_validate_condition_config(hass, cond) for cond in conditions] return [await async_validate_condition_config(hass, cond) for cond in conditions]
async def async_conditions_from_config(
hass: HomeAssistant,
condition_configs: list[ConfigType],
logger: logging.Logger,
name: str,
) -> Callable[[TemplateVarsType], bool]:
"""AND all conditions."""
checks: list[ConditionCheckerType] = [
await async_from_config(hass, condition_config)
for condition_config in condition_configs
]
def check_conditions(variables: TemplateVarsType = None) -> bool:
"""AND all conditions."""
errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks):
try:
with trace_path(["condition", str(index)]):
if check(hass, variables) is False:
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex(
"condition", index=index, total=len(checks), error=ex
)
)
if errors:
logger.warning(
"Error evaluating condition in '%s':\n%s",
name,
ConditionErrorContainer("condition", errors=errors),
)
return False
return True
return check_conditions
@callback @callback
def async_extract_entities(config: ConfigType | Template) -> set[str]: def async_extract_entities(config: ConfigType | Template) -> set[str]:
"""Extract entities from a condition.""" """Extract entities from a condition."""

View file

@ -1349,7 +1349,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
) )
type _VarsType = dict[str, Any] | MappingProxyType[str, Any] type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None: def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:

View file

@ -1207,6 +1207,124 @@ async def test_trigger_entity(
assert state.context is context assert state.context is context
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer >= 42 }}",
}
],
"sensor": [
{
"name": "Enough Name",
"unique_id": "enough-id",
"state": "You had enough Beer.",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity(hass: HomeAssistant, start_ha) -> None:
"""Test conditional trigger entity works."""
state = hass.states.get("sensor.enough_name")
assert state is not None
assert state.state == STATE_UNKNOWN
hass.bus.async_fire("test_event", {"beer": 2})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state.state == STATE_UNKNOWN
hass.bus.async_fire("test_event", {"beer": 42})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state.state == "You had enough Beer."
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer / 0 == 'narf' }}",
}
],
"sensor": [
{
"name": "Enough Name",
"unique_id": "enough-id",
"state": "You had enough Beer.",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity_evaluation_error(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, start_ha
) -> None:
"""Test trigger entity is not updated when condition evaluation fails."""
hass.bus.async_fire("test_event", {"beer": 1})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state is not None
assert state.state == STATE_UNKNOWN
assert "Error evaluating condition in 'template entity'" in caplog.text
@pytest.mark.parametrize(("count", "domain"), [(0, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{"condition": "template", "value_template": "{{ invalid"}
],
"sensor": [
{
"name": "Will Not Exist Name",
"state": "Unimportant",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity_invalid_condition(
hass: HomeAssistant, start_ha
) -> None:
"""Test trigger entity is not created when condition is invalid."""
state = hass.states.get("sensor.will_not_exist_name")
assert state is None
@pytest.mark.parametrize(("count", "domain"), [(1, "template")]) @pytest.mark.parametrize(("count", "domain"), [(1, "template")])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config", "config",
@ -1903,6 +2021,52 @@ async def test_trigger_action(
assert events[0].context.parent_id == context.id assert events[0].context.parent_id == context.id
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer >= 42 }}",
}
],
"action": [
{"event": "test_event_by_action"},
],
"sensor": [
{
"name": "Not That Important",
"state": "Really not.",
}
],
},
],
},
],
)
async def test_trigger_conditional_action(hass: HomeAssistant, start_ha) -> None:
"""Test conditional trigger entity with an action works."""
event = "test_event_by_action"
events = async_capture_events(hass, event)
hass.bus.async_fire("test_event", {"beer": 1})
await hass.async_block_till_done()
assert len(events) == 0
hass.bus.async_fire("test_event", {"beer": 42})
await hass.async_block_till_done()
assert len(events) == 1
async def test_device_id( async def test_device_id(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,