Improve device_automation
typing (#66621)
This commit is contained in:
parent
8bf19655f1
commit
1a247f7d1b
3 changed files with 148 additions and 19 deletions
|
@ -7,14 +7,14 @@ from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, NamedTuple
|
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
import voluptuous_serialize
|
import voluptuous_serialize
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
|
@ -27,6 +27,13 @@ from homeassistant.requirements import async_get_integration_with_requirements
|
||||||
|
|
||||||
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from homeassistant.components.automation import (
|
||||||
|
AutomationActionType,
|
||||||
|
AutomationTriggerInfo,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers import condition
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||||
|
|
||||||
DOMAIN = "device_automation"
|
DOMAIN = "device_automation"
|
||||||
|
@ -76,6 +83,77 @@ TYPES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAutomationTriggerProtocol(Protocol):
|
||||||
|
"""Define the format of device_trigger modules.
|
||||||
|
|
||||||
|
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TRIGGER_SCHEMA: vol.Schema
|
||||||
|
|
||||||
|
async def async_validate_trigger_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_attach_trigger(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
action: AutomationActionType,
|
||||||
|
automation_info: AutomationTriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach a trigger."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAutomationConditionProtocol(Protocol):
|
||||||
|
"""Define the format of device_condition modules.
|
||||||
|
|
||||||
|
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONDITION_SCHEMA: vol.Schema
|
||||||
|
|
||||||
|
async def async_validate_condition_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def async_condition_from_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> condition.ConditionCheckerType:
|
||||||
|
"""Evaluate state based on configuration."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAutomationActionProtocol(Protocol):
|
||||||
|
"""Define the format of device_action modules.
|
||||||
|
|
||||||
|
Each module must define either ACTION_SCHEMA or async_validate_action_config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTION_SCHEMA: vol.Schema
|
||||||
|
|
||||||
|
async def async_validate_action_config(
|
||||||
|
self, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_call_action_from_config(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
context: Context | None,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a device action."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_get_device_automations(
|
async def async_get_device_automations(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -115,9 +193,51 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
DeviceAutomationPlatformType = Union[
|
||||||
|
ModuleType,
|
||||||
|
DeviceAutomationTriggerProtocol,
|
||||||
|
DeviceAutomationConditionProtocol,
|
||||||
|
DeviceAutomationActionProtocol,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_get_device_automation_platform( # noqa: D103
|
||||||
|
hass: HomeAssistant,
|
||||||
|
domain: str,
|
||||||
|
automation_type: Literal[DeviceAutomationType.TRIGGER],
|
||||||
|
) -> DeviceAutomationTriggerProtocol:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_get_device_automation_platform( # noqa: D103
|
||||||
|
hass: HomeAssistant,
|
||||||
|
domain: str,
|
||||||
|
automation_type: Literal[DeviceAutomationType.CONDITION],
|
||||||
|
) -> DeviceAutomationConditionProtocol:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_get_device_automation_platform( # noqa: D103
|
||||||
|
hass: HomeAssistant,
|
||||||
|
domain: str,
|
||||||
|
automation_type: Literal[DeviceAutomationType.ACTION],
|
||||||
|
) -> DeviceAutomationActionProtocol:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_get_device_automation_platform( # noqa: D103
|
||||||
|
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
||||||
|
) -> DeviceAutomationPlatformType:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
async def async_get_device_automation_platform(
|
async def async_get_device_automation_platform(
|
||||||
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
||||||
) -> ModuleType:
|
) -> DeviceAutomationPlatformType:
|
||||||
"""Load device automation platform for integration.
|
"""Load device automation platform for integration.
|
||||||
|
|
||||||
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
|
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
"""Offer device oriented automation."""
|
"""Offer device oriented automation."""
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.automation import (
|
||||||
|
AutomationActionType,
|
||||||
|
AutomationTriggerInfo,
|
||||||
|
)
|
||||||
from homeassistant.const import CONF_DOMAIN
|
from homeassistant.const import CONF_DOMAIN
|
||||||
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
DEVICE_TRIGGER_BASE_SCHEMA,
|
DEVICE_TRIGGER_BASE_SCHEMA,
|
||||||
|
@ -10,26 +18,31 @@ from . import (
|
||||||
)
|
)
|
||||||
from .exceptions import InvalidDeviceAutomationConfig
|
from .exceptions import InvalidDeviceAutomationConfig
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
|
||||||
|
|
||||||
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_trigger_config(hass, config):
|
async def async_validate_trigger_config(
|
||||||
|
hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
"""Validate config."""
|
"""Validate config."""
|
||||||
platform = await async_get_device_automation_platform(
|
platform = await async_get_device_automation_platform(
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||||
)
|
)
|
||||||
if not hasattr(platform, "async_validate_trigger_config"):
|
if not hasattr(platform, "async_validate_trigger_config"):
|
||||||
return platform.TRIGGER_SCHEMA(config)
|
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await getattr(platform, "async_validate_trigger_config")(hass, config)
|
return await platform.async_validate_trigger_config(hass, config)
|
||||||
except InvalidDeviceAutomationConfig as err:
|
except InvalidDeviceAutomationConfig as err:
|
||||||
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
|
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
|
||||||
|
|
||||||
|
|
||||||
async def async_attach_trigger(hass, config, action, automation_info):
|
async def async_attach_trigger(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
action: AutomationActionType,
|
||||||
|
automation_info: AutomationTriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
"""Listen for trigger."""
|
"""Listen for trigger."""
|
||||||
platform = await async_get_device_automation_platform(
|
platform = await async_get_device_automation_platform(
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||||
|
|
|
@ -875,12 +875,7 @@ async def async_device_from_config(
|
||||||
platform = await async_get_device_automation_platform(
|
platform = await async_get_device_automation_platform(
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||||
)
|
)
|
||||||
return trace_condition_function(
|
return trace_condition_function(platform.async_condition_from_config(hass, config))
|
||||||
cast(
|
|
||||||
ConditionCheckerType,
|
|
||||||
platform.async_condition_from_config(hass, config),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_trigger_from_config(
|
async def async_trigger_from_config(
|
||||||
|
@ -943,14 +938,15 @@ async def async_validate_condition_config(
|
||||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||||
)
|
)
|
||||||
if hasattr(platform, "async_validate_condition_config"):
|
if hasattr(platform, "async_validate_condition_config"):
|
||||||
return await platform.async_validate_condition_config(hass, config) # type: ignore
|
return await platform.async_validate_condition_config(hass, config)
|
||||||
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
|
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
|
||||||
|
|
||||||
if condition in ("numeric_state", "state"):
|
if condition in ("numeric_state", "state"):
|
||||||
validator = getattr(
|
validator = cast(
|
||||||
sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)
|
Callable[[HomeAssistant, ConfigType], ConfigType],
|
||||||
|
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
|
||||||
)
|
)
|
||||||
return validator(hass, config) # type: ignore
|
return validator(hass, config)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue