Adjust device_automation type hints in zha (#72142)

This commit is contained in:
epenet 2022-05-23 17:35:35 +02:00 committed by GitHub
parent 5b4fdb081e
commit 1b5a46a5ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 15 deletions

View file

@ -15,7 +15,7 @@ import itertools
import logging import logging
from random import uniform from random import uniform
import re import re
from typing import Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
import voluptuous as vol import voluptuous as vol
import zigpy.exceptions import zigpy.exceptions
@ -24,7 +24,7 @@ import zigpy.util
import zigpy.zdo.types as zdo_types import zigpy.zdo.types as zdo_types
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from .const import ( from .const import (
@ -37,6 +37,10 @@ from .const import (
from .registries import BINDABLE_CLUSTERS from .registries import BINDABLE_CLUSTERS
from .typing import ZhaDeviceType, ZigpyClusterType from .typing import ZhaDeviceType, ZigpyClusterType
if TYPE_CHECKING:
from .device import ZHADevice
from .gateway import ZHAGateway
_T = TypeVar("_T") _T = TypeVar("_T")
@ -161,11 +165,11 @@ def async_cluster_exists(hass, cluster_id):
@callback @callback
def async_get_zha_device(hass, device_id): def async_get_zha_device(hass: HomeAssistant, device_id: str) -> ZHADevice:
"""Get a ZHA device for the given device registry id.""" """Get a ZHA device for the given device registry id."""
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
registry_device = device_registry.async_get(device_id) registry_device = device_registry.async_get(device_id)
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ieee_address = list(list(registry_device.identifiers)[0])[1] ieee_address = list(list(registry_device.identifiers)[0])[1]
ieee = zigpy.types.EUI64.convert(ieee_address) ieee = zigpy.types.EUI64.convert(ieee_address)
return zha_gateway.devices[ieee] return zha_gateway.devices[ieee]

View file

@ -48,7 +48,7 @@ async def async_call_action_from_config(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
variables: TemplateVarsType, variables: TemplateVarsType,
context: Context, context: Context | None,
) -> None: ) -> None:
"""Perform an action based on configuration.""" """Perform an action based on configuration."""
await ZHA_ACTION_TYPES[DEVICE_ACTION_TYPES[config[CONF_TYPE]]]( await ZHA_ACTION_TYPES[DEVICE_ACTION_TYPES[config[CONF_TYPE]]](

View file

@ -1,12 +1,19 @@
"""Provides device automations for ZHA devices that emit events.""" """Provides device automations for ZHA devices that emit events."""
import voluptuous as vol import voluptuous as vol
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.device_automation.exceptions import ( from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig, InvalidDeviceAutomationConfig,
) )
from homeassistant.components.homeassistant.triggers import event as event_trigger from homeassistant.components.homeassistant.triggers import event as event_trigger
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from . import DOMAIN from . import DOMAIN
from .core.helpers import async_get_zha_device from .core.helpers import async_get_zha_device
@ -21,7 +28,9 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend(
) )
async def async_validate_trigger_config(hass, config): async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config.""" """Validate config."""
config = TRIGGER_SCHEMA(config) config = TRIGGER_SCHEMA(config)
@ -40,18 +49,25 @@ async def async_validate_trigger_config(hass, config):
return config return config
async def async_attach_trigger(hass, config, action, automation_info): async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) trigger_key: tuple[str, str] = (config[CONF_TYPE], config[CONF_SUBTYPE])
try: try:
zha_device = async_get_zha_device(hass, config[CONF_DEVICE_ID]) zha_device = async_get_zha_device(hass, config[CONF_DEVICE_ID])
except (KeyError, AttributeError): except (KeyError, AttributeError) as err:
return None raise HomeAssistantError(
f"Unable to get zha device {config[CONF_DEVICE_ID]}"
) from err
if trigger not in zha_device.device_automation_triggers: if trigger_key not in zha_device.device_automation_triggers:
return None raise HomeAssistantError(f"Unable to find trigger {trigger_key}")
trigger = zha_device.device_automation_triggers[trigger] trigger = zha_device.device_automation_triggers[trigger_key]
event_config = { event_config = {
event_trigger.CONF_PLATFORM: "event", event_trigger.CONF_PLATFORM: "event",
@ -65,7 +81,9 @@ async def async_attach_trigger(hass, config, action, automation_info):
) )
async def async_get_triggers(hass, device_id): async def async_get_triggers(
hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device triggers. """List device triggers.
Make sure the device supports device automations and Make sure the device supports device automations and
@ -74,7 +92,7 @@ async def async_get_triggers(hass, device_id):
zha_device = async_get_zha_device(hass, device_id) zha_device = async_get_zha_device(hass, device_id)
if not zha_device.device_automation_triggers: if not zha_device.device_automation_triggers:
return return []
triggers = [] triggers = []
for trigger, subtype in zha_device.device_automation_triggers.keys(): for trigger, subtype in zha_device.device_automation_triggers.keys():