Allow conditions to be implemented in platforms (#88509)
* Allow conditions to be implemented in platforms * Update tests * Tweak typing * Rebase fixes
This commit is contained in:
parent
2f826a6f86
commit
d90ee85118
3 changed files with 77 additions and 52 deletions
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
|||
from homeassistant.const import CONF_DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.condition import ConditionProtocol, trace_condition_function
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import DeviceAutomationType, async_get_device_automation_platform
|
||||
|
@ -17,24 +18,13 @@ if TYPE_CHECKING:
|
|||
from homeassistant.helpers import condition
|
||||
|
||||
|
||||
class DeviceAutomationConditionProtocol(Protocol):
|
||||
class DeviceAutomationConditionProtocol(ConditionProtocol, Protocol):
|
||||
"""Define the format of device_condition modules.
|
||||
|
||||
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
||||
Each module must define either CONDITION_SCHEMA or async_validate_condition_config
|
||||
from ConditionProtocol.
|
||||
"""
|
||||
|
||||
CONDITION_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_condition_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
|
||||
def async_condition_from_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> condition.ConditionCheckerType:
|
||||
"""Evaluate state based on configuration."""
|
||||
|
||||
async def async_get_condition_capabilities(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> dict[str, vol.Schema]:
|
||||
|
@ -62,4 +52,4 @@ async def async_condition_from_config(
|
|||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
return platform.async_condition_from_config(hass, config)
|
||||
return trace_condition_function(platform.async_condition_from_config(hass, config))
|
||||
|
|
|
@ -7,15 +7,13 @@ from collections.abc import Callable, Container, Generator
|
|||
from contextlib import contextmanager
|
||||
from datetime import datetime, time as dt_time, timedelta
|
||||
import functools as ft
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import zone as zone_cmp
|
||||
from homeassistant.components.device_automation import condition as device_condition
|
||||
from homeassistant.components.sensor import SensorDeviceClass
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
|
@ -55,6 +53,7 @@ from homeassistant.exceptions import (
|
|||
HomeAssistantError,
|
||||
TemplateError,
|
||||
)
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
|
@ -77,12 +76,44 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
|
|||
FROM_CONFIG_FORMAT = "{}_from_config"
|
||||
VALIDATE_CONFIG_FORMAT = "{}_validate_config"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_PLATFORM_ALIASES = {
|
||||
"and": None,
|
||||
"device": "device_automation",
|
||||
"not": None,
|
||||
"numeric_state": None,
|
||||
"or": None,
|
||||
"state": None,
|
||||
"sun": None,
|
||||
"template": None,
|
||||
"time": None,
|
||||
"trigger": None,
|
||||
"zone": None,
|
||||
}
|
||||
|
||||
INPUT_ENTITY_ID = re.compile(
|
||||
r"^input_(?:select|text|number|boolean|datetime)\.(?!.+__)(?!_)[\da-z_]+(?<!_)$"
|
||||
)
|
||||
|
||||
|
||||
class ConditionProtocol(Protocol):
|
||||
"""Define the format of device_condition modules.
|
||||
|
||||
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
||||
"""
|
||||
|
||||
CONDITION_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_condition_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
|
||||
def async_condition_from_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
"""Evaluate state based on configuration."""
|
||||
|
||||
|
||||
ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
||||
|
||||
|
||||
|
@ -152,6 +183,27 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
|
|||
return wrapper
|
||||
|
||||
|
||||
async def _async_get_condition_platform(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionProtocol | None:
|
||||
platform = config[CONF_CONDITION]
|
||||
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||
if platform is None:
|
||||
return None
|
||||
try:
|
||||
integration = await async_get_integration(hass, platform)
|
||||
except IntegrationNotFound:
|
||||
raise HomeAssistantError(
|
||||
f'Invalid condition "{platform}" specified {config}'
|
||||
) from None
|
||||
try:
|
||||
return integration.get_platform("condition")
|
||||
except ImportError:
|
||||
raise HomeAssistantError(
|
||||
f"Integration '{platform}' does not provide condition support"
|
||||
) from None
|
||||
|
||||
|
||||
async def async_from_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
|
@ -160,15 +212,18 @@ async def async_from_config(
|
|||
|
||||
Should be run on the event loop.
|
||||
"""
|
||||
condition = config.get(CONF_CONDITION)
|
||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
||||
factory: Any = None
|
||||
platform = await _async_get_condition_platform(hass, config)
|
||||
|
||||
if factory:
|
||||
break
|
||||
if platform is None:
|
||||
condition = config.get(CONF_CONDITION)
|
||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
|
||||
|
||||
if factory is None:
|
||||
raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}')
|
||||
if factory:
|
||||
break
|
||||
else:
|
||||
factory = platform.async_condition_from_config
|
||||
|
||||
# Check if condition is not enabled
|
||||
if not config.get(CONF_ENABLED, True):
|
||||
|
@ -928,14 +983,6 @@ def zone_from_config(config: ConfigType) -> ConditionCheckerType:
|
|||
return if_in_zone
|
||||
|
||||
|
||||
async def async_device_from_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
"""Test a device condition."""
|
||||
checker = await device_condition.async_condition_from_config(hass, config)
|
||||
return trace_condition_function(checker)
|
||||
|
||||
|
||||
async def async_trigger_from_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
|
@ -991,10 +1038,10 @@ async def async_validate_condition_config(
|
|||
config["conditions"] = conditions
|
||||
return config
|
||||
|
||||
if condition == "device":
|
||||
return await device_condition.async_validate_condition_config(hass, config)
|
||||
|
||||
if condition in ("numeric_state", "state"):
|
||||
platform = await _async_get_condition_platform(hass, config)
|
||||
if platform is not None and hasattr(platform, "async_validate_condition_config"):
|
||||
return await platform.async_validate_condition_config(hass, config)
|
||||
if platform is None and condition in ("numeric_state", "state"):
|
||||
validator = cast(
|
||||
Callable[[HomeAssistant, ConfigType], ConfigType],
|
||||
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
|
||||
|
|
|
@ -2406,13 +2406,7 @@ async def test_repeat_var_in_condition(hass: HomeAssistant, condition) -> None:
|
|||
script_obj = script.Script(
|
||||
hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain"
|
||||
)
|
||||
|
||||
with mock.patch(
|
||||
"homeassistant.helpers.condition._LOGGER.error",
|
||||
side_effect=AssertionError("Template Error"),
|
||||
):
|
||||
await script_obj.async_run(context=Context())
|
||||
|
||||
await script_obj.async_run(context=Context())
|
||||
assert len(events) == 2
|
||||
|
||||
if condition == "while":
|
||||
|
@ -2545,13 +2539,7 @@ async def test_repeat_nested(
|
|||
]
|
||||
)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
|
||||
with mock.patch(
|
||||
"homeassistant.helpers.condition._LOGGER.error",
|
||||
side_effect=AssertionError("Template Error"),
|
||||
):
|
||||
await script_obj.async_run(variables, Context())
|
||||
|
||||
await script_obj.async_run(variables, Context())
|
||||
assert len(events) == 10
|
||||
assert events[0].data == first_last
|
||||
assert events[-1].data == first_last
|
||||
|
|
Loading…
Add table
Reference in a new issue