Improve device_automation typing (#66621)

This commit is contained in:
Marc Mueller 2022-02-17 22:08:43 +01:00 committed by GitHub
parent 8bf19655f1
commit 1a247f7d1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 19 deletions

View file

@ -7,14 +7,14 @@ from enum import Enum
from functools import wraps from functools import wraps
import logging import logging
from types import ModuleType from types import ModuleType
from typing import Any, NamedTuple from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
import voluptuous as vol import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
@ -27,6 +27,13 @@ from homeassistant.requirements import async_get_integration_with_requirements
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
if TYPE_CHECKING:
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.helpers import condition
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
DOMAIN = "device_automation" DOMAIN = "device_automation"
@ -76,6 +83,77 @@ TYPES = {
} }
class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError
class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError
class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.
Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""
ACTION_SCHEMA: vol.Schema
async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError
@bind_hass @bind_hass
async def async_get_device_automations( async def async_get_device_automations(
hass: HomeAssistant, hass: HomeAssistant,
@ -115,9 +193,51 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.TRIGGER],
) -> DeviceAutomationTriggerProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.CONDITION],
) -> DeviceAutomationConditionProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.ACTION],
) -> DeviceAutomationActionProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
...
async def async_get_device_automation_platform( async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> ModuleType: ) -> DeviceAutomationPlatformType:
"""Load device automation platform for integration. """Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.

View file

@ -1,7 +1,15 @@
"""Offer device oriented automation.""" """Offer device oriented automation."""
from typing import cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.const import CONF_DOMAIN from homeassistant.const import CONF_DOMAIN
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from . import ( from . import (
DEVICE_TRIGGER_BASE_SCHEMA, DEVICE_TRIGGER_BASE_SCHEMA,
@ -10,26 +18,31 @@ from . import (
) )
from .exceptions import InvalidDeviceAutomationConfig from .exceptions import InvalidDeviceAutomationConfig
# mypy: allow-untyped-defs, no-check-untyped-defs
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
async def async_validate_trigger_config(hass, config): async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config.""" """Validate config."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
) )
if not hasattr(platform, "async_validate_trigger_config"): if not hasattr(platform, "async_validate_trigger_config"):
return platform.TRIGGER_SCHEMA(config) return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
try: try:
return await getattr(platform, "async_validate_trigger_config")(hass, config) return await platform.async_validate_trigger_config(hass, config)
except InvalidDeviceAutomationConfig as err: except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
async def async_attach_trigger(hass, config, action, automation_info): async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Listen for trigger.""" """Listen for trigger."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER

View file

@ -875,12 +875,7 @@ async def async_device_from_config(
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
return trace_condition_function( return trace_condition_function(platform.async_condition_from_config(hass, config))
cast(
ConditionCheckerType,
platform.async_condition_from_config(hass, config),
)
)
async def async_trigger_from_config( async def async_trigger_from_config(
@ -943,14 +938,15 @@ async def async_validate_condition_config(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
if hasattr(platform, "async_validate_condition_config"): if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config) # type: ignore return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config)) return cast(ConfigType, platform.CONDITION_SCHEMA(config))
if condition in ("numeric_state", "state"): if condition in ("numeric_state", "state"):
validator = getattr( validator = cast(
sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition) Callable[[HomeAssistant, ConfigType], ConfigType],
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
) )
return validator(hass, config) # type: ignore return validator(hass, config)
return config return config