Allow conditions to be implemented in platforms (#88509)

* Allow conditions to be implemented in platforms

* Update tests

* Tweak typing

* Rebase fixes
This commit is contained in:
Erik Montnemery 2023-02-24 04:30:51 +01:00 committed by GitHub
parent 2f826a6f86
commit d90ee85118
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 52 deletions

View file

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ConditionProtocol, trace_condition_function
from homeassistant.helpers.typing import ConfigType
from . import DeviceAutomationType, async_get_device_automation_platform
@ -17,24 +18,13 @@ if TYPE_CHECKING:
from homeassistant.helpers import condition
class DeviceAutomationConditionProtocol(Protocol):
class DeviceAutomationConditionProtocol(ConditionProtocol, Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config
from ConditionProtocol.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
async def async_get_condition_capabilities(
self, hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]:
@ -62,4 +52,4 @@ async def async_condition_from_config(
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return platform.async_condition_from_config(hass, config)
return trace_condition_function(platform.async_condition_from_config(hass, config))

View file

@ -7,15 +7,13 @@ from collections.abc import Callable, Container, Generator
from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta
import functools as ft
import logging
import re
import sys
from typing import Any, cast
from typing import Any, Protocol, cast
import voluptuous as vol
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import condition as device_condition
from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.const import (
ATTR_DEVICE_CLASS,
@ -55,6 +53,7 @@ from homeassistant.exceptions import (
HomeAssistantError,
TemplateError,
)
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as dt_util
@ -77,12 +76,44 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_LOGGER = logging.getLogger(__name__)
_PLATFORM_ALIASES = {
"and": None,
"device": "device_automation",
"not": None,
"numeric_state": None,
"or": None,
"state": None,
"sun": None,
"template": None,
"time": None,
"trigger": None,
"zone": None,
}
INPUT_ENTITY_ID = re.compile(
r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$"
)
class ConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Evaluate state based on configuration."""
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
@ -152,6 +183,27 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
return wrapper
async def _async_get_condition_platform(
hass: HomeAssistant, config: ConfigType
) -> ConditionProtocol | None:
platform = config[CONF_CONDITION]
platform = _PLATFORM_ALIASES.get(platform, platform)
if platform is None:
return None
try:
integration = await async_get_integration(hass, platform)
except IntegrationNotFound:
raise HomeAssistantError(
f'Invalid condition "{platform}" specified {config}'
) from None
try:
return integration.get_platform("condition")
except ImportError:
raise HomeAssistantError(
f"Integration '{platform}' does not provide condition support"
) from None
async def async_from_config(
hass: HomeAssistant,
config: ConfigType,
@ -160,15 +212,18 @@ async def async_from_config(
Should be run on the event loop.
"""
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if factory:
break
if platform is None:
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory is None:
raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}')
if factory:
break
else:
factory = platform.async_condition_from_config
# Check if condition is not enabled
if not config.get(CONF_ENABLED, True):
@ -928,14 +983,6 @@ def zone_from_config(config: ConfigType) -> ConditionCheckerType:
return if_in_zone
async def async_device_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Test a device condition."""
checker = await device_condition.async_condition_from_config(hass, config)
return trace_condition_function(checker)
async def async_trigger_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
@ -991,10 +1038,10 @@ async def async_validate_condition_config(
config["conditions"] = conditions
return config
if condition == "device":
return await device_condition.async_validate_condition_config(hass, config)
if condition in ("numeric_state", "state"):
platform = await _async_get_condition_platform(hass, config)
if platform is not None and hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
if platform is None and condition in ("numeric_state", "state"):
validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType],
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),

View file

@ -2406,13 +2406,7 @@ async def test_repeat_var_in_condition(hass: HomeAssistant, condition) -> None:
script_obj = script.Script(
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
)
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run(context=Context())
await script_obj.async_run(context=Context())
assert len(events) == 2
if condition == "while":
@ -2545,13 +2539,7 @@ async def test_repeat_nested(
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
with mock.patch(
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run(variables, Context())
await script_obj.async_run(variables, Context())
assert len(events) == 10
assert events[0].data == first_last
assert events[-1].data == first_last