Improve device_automation typing (#66621)

This commit is contained in:
Marc Mueller 2022-02-17 22:08:43 +01:00 committed by GitHub
parent 8bf19655f1
commit 1a247f7d1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 19 deletions

View file

@ -7,14 +7,14 @@ from enum import Enum
from functools import wraps
import logging
from types import ModuleType
from typing import Any, NamedTuple
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
import voluptuous as vol
import voluptuous_serialize
from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import HomeAssistant
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
@ -27,6 +27,13 @@ from homeassistant.requirements import async_get_integration_with_requirements
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
if TYPE_CHECKING:
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.helpers import condition
# mypy: allow-untyped-calls, allow-untyped-defs
DOMAIN = "device_automation"
@ -76,6 +83,77 @@ TYPES = {
}
class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError
class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError
class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.
Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""
ACTION_SCHEMA: vol.Schema
async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
@ -115,9 +193,51 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.TRIGGER],
) -> DeviceAutomationTriggerProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.CONDITION],
) -> DeviceAutomationConditionProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
domain: str,
automation_type: Literal[DeviceAutomationType.ACTION],
) -> DeviceAutomationActionProtocol:
...
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
...
async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> ModuleType:
) -> DeviceAutomationPlatformType:
"""Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.

View file

@ -1,7 +1,15 @@
"""Offer device oriented automation."""
from typing import cast
import voluptuous as vol
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from . import (
DEVICE_TRIGGER_BASE_SCHEMA,
@ -10,26 +18,31 @@ from . import (
)
from .exceptions import InvalidDeviceAutomationConfig
# mypy: allow-untyped-defs, no-check-untyped-defs
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
async def async_validate_trigger_config(hass, config):
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return platform.TRIGGER_SCHEMA(config)
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
try:
return await getattr(platform, "async_validate_trigger_config")(hass, config)
return await platform.async_validate_trigger_config(hass, config)
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
async def async_attach_trigger(hass, config, action, automation_info):
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Listen for trigger."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER

View file

@ -875,12 +875,7 @@ async def async_device_from_config(
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(
cast(
ConditionCheckerType,
platform.async_condition_from_config(hass, config),
)
)
return trace_condition_function(platform.async_condition_from_config(hass, config))
async def async_trigger_from_config(
@ -943,14 +938,15 @@ async def async_validate_condition_config(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config) # type: ignore
return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
if condition in ("numeric_state", "state"):
validator = getattr(
sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)
validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType],
getattr(sys.modules[__name__], VALIDATE_CONFIG_FORMAT.format(condition)),
)
return validator(hass, config) # type: ignore
return validator(hass, config)
return config