Improve zha typing [api] (5) (#68684)

This commit is contained in:
Marc Mueller 2022-03-30 15:54:31 +02:00 committed by GitHub
parent cde989cd38
commit 006fa9b700
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 56 deletions

View file

@ -2,10 +2,8 @@
from __future__ import annotations
import asyncio
import collections
from collections.abc import Mapping
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
import voluptuous as vol
from zigpy.config.validators import cv_boolean
@ -14,7 +12,7 @@ from zigpy.zcl.clusters.security import IasAce
import zigpy.zdo.types as zdo_types
from homeassistant.components import websocket_api
from homeassistant.const import ATTR_COMMAND, ATTR_NAME
from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, ServiceCall, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -31,6 +29,7 @@ from .core.const import (
ATTR_LEVEL,
ATTR_MANUFACTURER,
ATTR_MEMBERS,
ATTR_TYPE,
ATTR_VALUE,
ATTR_WARNING_DEVICE_DURATION,
ATTR_WARNING_DEVICE_MODE,
@ -201,7 +200,56 @@ SERVICE_SCHEMAS = {
),
}
ClusterBinding = collections.namedtuple("ClusterBinding", "id endpoint_id type name")
class ClusterBinding(NamedTuple):
"""Describes a cluster binding."""
name: str
type: str
id: int
endpoint_id: int
def _cv_group_member(value: dict[str, Any]) -> GroupMember:
"""Transform a group member."""
return GroupMember(
ieee=value[ATTR_IEEE],
endpoint_id=value[ATTR_ENDPOINT_ID],
)
def _cv_cluster_binding(value: dict[str, Any]) -> ClusterBinding:
"""Transform a cluster binding."""
return ClusterBinding(
name=value[ATTR_NAME],
type=value[ATTR_TYPE],
id=value[ATTR_ID],
endpoint_id=value[ATTR_ENDPOINT_ID],
)
GROUP_MEMBER_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(ATTR_IEEE): IEEE_SCHEMA,
vol.Required(ATTR_ENDPOINT_ID): int,
}
),
_cv_group_member,
)
CLUSTER_BINDING_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(ATTR_NAME): cv.string,
vol.Required(ATTR_TYPE): cv.string,
vol.Required(ATTR_ID): int,
vol.Required(ATTR_ENDPOINT_ID): int,
}
),
_cv_cluster_binding,
)
@websocket_api.require_admin
@ -374,27 +422,13 @@ async def websocket_get_group(
connection.send_result(msg[ID], group_info)
def cv_group_member(value: Any) -> GroupMember:
"""Validate and transform a group member."""
if not isinstance(value, Mapping):
raise vol.Invalid("Not a group member")
try:
group_member = GroupMember(
ieee=EUI64.convert(value["ieee"]), endpoint_id=value["endpoint_id"]
)
except KeyError as err:
raise vol.Invalid("Not a group member") from err
return group_member
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zha/group/add",
vol.Required(GROUP_NAME): cv.string,
vol.Optional(GROUP_ID): cv.positive_int,
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -441,7 +475,7 @@ async def websocket_remove_groups(
{
vol.Required(TYPE): "zha/group/members/add",
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -471,7 +505,7 @@ async def websocket_add_group_members(
{
vol.Required(TYPE): "zha/group/members/remove",
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -837,30 +871,13 @@ async def websocket_unbind_devices(
)
def is_cluster_binding(value: Any) -> ClusterBinding:
"""Validate and transform a cluster binding."""
if not isinstance(value, Mapping):
raise vol.Invalid("Not a cluster binding")
try:
cluster_binding = ClusterBinding(
name=value["name"],
type=value["type"],
id=value["id"],
endpoint_id=value["endpoint_id"],
)
except KeyError as err:
raise vol.Invalid("Not a cluster binding") from err
return cluster_binding
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zha/groups/bind",
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
}
)
@websocket_api.async_response
@ -882,7 +899,7 @@ async def websocket_bind_group(
vol.Required(TYPE): "zha/groups/unbind",
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
}
)
@websocket_api.async_response

View file

@ -8,7 +8,7 @@ from enum import Enum
import logging
import random
import time
from typing import Any
from typing import TYPE_CHECKING, Any
from zigpy import types
import zigpy.exceptions
@ -75,6 +75,9 @@ from .const import (
)
from .helpers import LogMixin, async_get_zha_config_value
if TYPE_CHECKING:
from ..api import ClusterBinding
_LOGGER = logging.getLogger(__name__)
_UPDATE_ALIVE_INTERVAL = (60, 90)
_CHECKIN_GRACE_PERIODS = 2
@ -655,7 +658,7 @@ class ZHADevice(LogMixin):
)
return response
async def async_add_to_group(self, group_id):
async def async_add_to_group(self, group_id: int) -> None:
"""Add this device to the provided zigbee group."""
try:
await self._zigpy_device.add_to_group(group_id)
@ -667,7 +670,7 @@ class ZHADevice(LogMixin):
str(ex),
)
async def async_remove_from_group(self, group_id):
async def async_remove_from_group(self, group_id: int) -> None:
"""Remove this device from the provided zigbee group."""
try:
await self._zigpy_device.remove_from_group(group_id)
@ -679,10 +682,12 @@ class ZHADevice(LogMixin):
str(ex),
)
async def async_add_endpoint_to_group(self, endpoint_id, group_id):
async def async_add_endpoint_to_group(
self, endpoint_id: int, group_id: int
) -> None:
"""Add the device endpoint to the provided zigbee group."""
try:
await self._zigpy_device.endpoints[int(endpoint_id)].add_to_group(group_id)
await self._zigpy_device.endpoints[endpoint_id].add_to_group(group_id)
except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex:
self.debug(
"Failed to add endpoint: %s for device: '%s' to group: 0x%04x ex: %s",
@ -692,12 +697,12 @@ class ZHADevice(LogMixin):
str(ex),
)
async def async_remove_endpoint_from_group(self, endpoint_id, group_id):
async def async_remove_endpoint_from_group(
self, endpoint_id: int, group_id: int
) -> None:
"""Remove the device endpoint from the provided zigbee group."""
try:
await self._zigpy_device.endpoints[int(endpoint_id)].remove_from_group(
group_id
)
await self._zigpy_device.endpoints[endpoint_id].remove_from_group(group_id)
except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex:
self.debug(
"Failed to remove endpoint: %s for device '%s' from group: 0x%04x ex: %s",
@ -707,21 +712,28 @@ class ZHADevice(LogMixin):
str(ex),
)
async def async_bind_to_group(self, group_id, cluster_bindings):
async def async_bind_to_group(
self, group_id: int, cluster_bindings: list[ClusterBinding]
) -> None:
"""Directly bind this device to a group for the given clusters."""
await self._async_group_binding_operation(
group_id, zdo_types.ZDOCmd.Bind_req, cluster_bindings
)
async def async_unbind_from_group(self, group_id, cluster_bindings):
async def async_unbind_from_group(
self, group_id: int, cluster_bindings: list[ClusterBinding]
) -> None:
"""Unbind this device from a group for the given clusters."""
await self._async_group_binding_operation(
group_id, zdo_types.ZDOCmd.Unbind_req, cluster_bindings
)
async def _async_group_binding_operation(
self, group_id, operation, cluster_bindings
):
self,
group_id: int,
operation: zdo_types.ZDOCmd,
cluster_bindings: list[ClusterBinding],
) -> None:
"""Create or remove a direct zigbee binding between a device and a group."""
zdo = self._zigpy_device.zdo

View file

@ -4,11 +4,12 @@ from __future__ import annotations
import asyncio
import collections
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
import zigpy.endpoint
import zigpy.exceptions
import zigpy.group
from zigpy.types.named import EUI64
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_registry import async_entries_for_device
@ -21,7 +22,14 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
GroupMember = collections.namedtuple("GroupMember", "ieee endpoint_id")
class GroupMember(NamedTuple):
"""Describes a group member."""
ieee: EUI64
endpoint_id: int
GroupEntityReference = collections.namedtuple(
"GroupEntityReference", "name original_name entity_id"
)