Improve device_automation
typing (#66621)
This commit is contained in:
parent
8bf19655f1
commit
1a247f7d1b
3 changed files with 148 additions and 19 deletions
|
@ -7,14 +7,14 @@ from enum import Enum
|
|||
from functools import wraps
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Any, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
|
||||
|
||||
import voluptuous as vol
|
||||
import voluptuous_serialize
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
|
@ -27,6 +27,13 @@ from homeassistant.requirements import async_get_integration_with_requirements
|
|||
|
||||
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.components.automation import (
|
||||
AutomationActionType,
|
||||
AutomationTriggerInfo,
|
||||
)
|
||||
from homeassistant.helpers import condition
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
DOMAIN = "device_automation"
|
||||
|
@ -76,6 +83,77 @@ TYPES = {
|
|||
}
|
||||
|
||||
|
||||
class DeviceAutomationTriggerProtocol(Protocol):
|
||||
"""Define the format of device_trigger modules.
|
||||
|
||||
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
|
||||
"""
|
||||
|
||||
TRIGGER_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_trigger_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_attach_trigger(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
action: AutomationActionType,
|
||||
automation_info: AutomationTriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DeviceAutomationConditionProtocol(Protocol):
|
||||
"""Define the format of device_condition modules.
|
||||
|
||||
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
||||
"""
|
||||
|
||||
CONDITION_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_condition_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
raise NotImplementedError
|
||||
|
||||
def async_condition_from_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> condition.ConditionCheckerType:
|
||||
"""Evaluate state based on configuration."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DeviceAutomationActionProtocol(Protocol):
|
||||
"""Define the format of device_action modules.
|
||||
|
||||
Each module must define either ACTION_SCHEMA or async_validate_action_config.
|
||||
"""
|
||||
|
||||
ACTION_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_action_config(
|
||||
self, hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_call_action_from_config(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
variables: dict[str, Any],
|
||||
context: Context | None,
|
||||
) -> None:
|
||||
"""Execute a device action."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_device_automations(
|
||||
hass: HomeAssistant,
|
||||
|
@ -115,9 +193,51 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
DeviceAutomationPlatformType = Union[
|
||||
ModuleType,
|
||||
DeviceAutomationTriggerProtocol,
|
||||
DeviceAutomationConditionProtocol,
|
||||
DeviceAutomationActionProtocol,
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
async def async_get_device_automation_platform( # noqa: D103
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
automation_type: Literal[DeviceAutomationType.TRIGGER],
|
||||
) -> DeviceAutomationTriggerProtocol:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def async_get_device_automation_platform( # noqa: D103
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
automation_type: Literal[DeviceAutomationType.CONDITION],
|
||||
) -> DeviceAutomationConditionProtocol:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def async_get_device_automation_platform( # noqa: D103
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
automation_type: Literal[DeviceAutomationType.ACTION],
|
||||
) -> DeviceAutomationActionProtocol:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def async_get_device_automation_platform( # noqa: D103
|
||||
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
||||
) -> DeviceAutomationPlatformType:
|
||||
...
|
||||
|
||||
|
||||
async def async_get_device_automation_platform(
|
||||
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
||||
) -> ModuleType:
|
||||
) -> DeviceAutomationPlatformType:
|
||||
"""Load device automation platform for integration.
|
||||
|
||||
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
"""Offer device oriented automation."""
|
||||
from typing import cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.automation import (
|
||||
AutomationActionType,
|
||||
AutomationTriggerInfo,
|
||||
)
|
||||
from homeassistant.const import CONF_DOMAIN
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import (
|
||||
DEVICE_TRIGGER_BASE_SCHEMA,
|
||||
|
@ -10,26 +18,31 @@ from . import (
|
|||
)
|
||||
from .exceptions import InvalidDeviceAutomationConfig
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
async def async_validate_trigger_config(hass, config):
|
||||
async def async_validate_trigger_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||
)
|
||||
if not hasattr(platform, "async_validate_trigger_config"):
|
||||
return platform.TRIGGER_SCHEMA(config)
|
||||
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
|
||||
|
||||
try:
|
||||
return await getattr(platform, "async_validate_trigger_config")(hass, config)
|
||||
return await platform.async_validate_trigger_config(hass, config)
|
||||
except InvalidDeviceAutomationConfig as err:
|
||||
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
|
||||
|
||||
|
||||
async def async_attach_trigger(hass, config, action, automation_info):
|
||||
async def async_attach_trigger(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
action: AutomationActionType,
|
||||
automation_info: AutomationTriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen for trigger."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||
|
|
|
@ -875,12 +875,7 @@ async def async_device_from_config(
|
|||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
return trace_condition_function(
|
||||
cast(
|
||||
ConditionCheckerType,
|
||||
platform.async_condition_from_config(hass, config),
|
||||
)
|
||||
)
|
||||
return trace_condition_function(platform.async_condition_from_config(hass, config))
|
||||
|
||||
|
||||
async def async_trigger_from_config(
|
||||
|
@ -943,14 +938,15 @@ async def async_validate_condition_config(
|
|||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
if hasattr(platform, "async_validate_condition_config"):
|
||||
return await platform.async_validate_condition_config(hass, config) # type: ignore
|
||||
return await platform.async_validate_condition_config(hass, config)
|
||||
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
|
||||
|
||||
if condition in ("numeric_state", "state"):
|
||||
validator = getattr(
|
||||
sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)
|
||||
validator = cast(
|
||||
Callable[[HomeAssistant, ConfigType], ConfigType],
|
||||
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
|
||||
)
|
||||
return validator(hass, config) # type: ignore
|
||||
return validator(hass, config)
|
||||
|
||||
return config
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue