Improve zha typing [api] (5) (#68684)
This commit is contained in:
parent
cde989cd38
commit
006fa9b700
3 changed files with 93 additions and 56 deletions
|
@ -2,10 +2,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
import voluptuous as vol
|
||||
from zigpy.config.validators import cv_boolean
|
||||
|
@ -14,7 +12,7 @@ from zigpy.zcl.clusters.security import IasAce
|
|||
import zigpy.zdo.types as zdo_types
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import ATTR_COMMAND, ATTR_NAME
|
||||
from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
|
@ -31,6 +29,7 @@ from .core.const import (
|
|||
ATTR_LEVEL,
|
||||
ATTR_MANUFACTURER,
|
||||
ATTR_MEMBERS,
|
||||
ATTR_TYPE,
|
||||
ATTR_VALUE,
|
||||
ATTR_WARNING_DEVICE_DURATION,
|
||||
ATTR_WARNING_DEVICE_MODE,
|
||||
|
@ -201,7 +200,56 @@ SERVICE_SCHEMAS = {
|
|||
),
|
||||
}
|
||||
|
||||
ClusterBinding = collections.namedtuple("ClusterBinding", "id endpoint_id type name")
|
||||
|
||||
class ClusterBinding(NamedTuple):
|
||||
"""Describes a cluster binding."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
id: int
|
||||
endpoint_id: int
|
||||
|
||||
|
||||
def _cv_group_member(value: dict[str, Any]) -> GroupMember:
|
||||
"""Transform a group member."""
|
||||
return GroupMember(
|
||||
ieee=value[ATTR_IEEE],
|
||||
endpoint_id=value[ATTR_ENDPOINT_ID],
|
||||
)
|
||||
|
||||
|
||||
def _cv_cluster_binding(value: dict[str, Any]) -> ClusterBinding:
|
||||
"""Transform a cluster binding."""
|
||||
return ClusterBinding(
|
||||
name=value[ATTR_NAME],
|
||||
type=value[ATTR_TYPE],
|
||||
id=value[ATTR_ID],
|
||||
endpoint_id=value[ATTR_ENDPOINT_ID],
|
||||
)
|
||||
|
||||
|
||||
GROUP_MEMBER_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_IEEE): IEEE_SCHEMA,
|
||||
vol.Required(ATTR_ENDPOINT_ID): int,
|
||||
}
|
||||
),
|
||||
_cv_group_member,
|
||||
)
|
||||
|
||||
|
||||
CLUSTER_BINDING_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_NAME): cv.string,
|
||||
vol.Required(ATTR_TYPE): cv.string,
|
||||
vol.Required(ATTR_ID): int,
|
||||
vol.Required(ATTR_ENDPOINT_ID): int,
|
||||
}
|
||||
),
|
||||
_cv_cluster_binding,
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
|
@ -374,27 +422,13 @@ async def websocket_get_group(
|
|||
connection.send_result(msg[ID], group_info)
|
||||
|
||||
|
||||
def cv_group_member(value: Any) -> GroupMember:
|
||||
"""Validate and transform a group member."""
|
||||
if not isinstance(value, Mapping):
|
||||
raise vol.Invalid("Not a group member")
|
||||
try:
|
||||
group_member = GroupMember(
|
||||
ieee=EUI64.convert(value["ieee"]), endpoint_id=value["endpoint_id"]
|
||||
)
|
||||
except KeyError as err:
|
||||
raise vol.Invalid("Not a group member") from err
|
||||
|
||||
return group_member
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required(TYPE): "zha/group/add",
|
||||
vol.Required(GROUP_NAME): cv.string,
|
||||
vol.Optional(GROUP_ID): cv.positive_int,
|
||||
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
|
||||
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -441,7 +475,7 @@ async def websocket_remove_groups(
|
|||
{
|
||||
vol.Required(TYPE): "zha/group/members/add",
|
||||
vol.Required(GROUP_ID): cv.positive_int,
|
||||
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
|
||||
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -471,7 +505,7 @@ async def websocket_add_group_members(
|
|||
{
|
||||
vol.Required(TYPE): "zha/group/members/remove",
|
||||
vol.Required(GROUP_ID): cv.positive_int,
|
||||
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
|
||||
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -837,30 +871,13 @@ async def websocket_unbind_devices(
|
|||
)
|
||||
|
||||
|
||||
def is_cluster_binding(value: Any) -> ClusterBinding:
|
||||
"""Validate and transform a cluster binding."""
|
||||
if not isinstance(value, Mapping):
|
||||
raise vol.Invalid("Not a cluster binding")
|
||||
try:
|
||||
cluster_binding = ClusterBinding(
|
||||
name=value["name"],
|
||||
type=value["type"],
|
||||
id=value["id"],
|
||||
endpoint_id=value["endpoint_id"],
|
||||
)
|
||||
except KeyError as err:
|
||||
raise vol.Invalid("Not a cluster binding") from err
|
||||
|
||||
return cluster_binding
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required(TYPE): "zha/groups/bind",
|
||||
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
|
||||
vol.Required(GROUP_ID): cv.positive_int,
|
||||
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
|
||||
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -882,7 +899,7 @@ async def websocket_bind_group(
|
|||
vol.Required(TYPE): "zha/groups/unbind",
|
||||
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
|
||||
vol.Required(GROUP_ID): cv.positive_int,
|
||||
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
|
||||
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
|
|
@ -8,7 +8,7 @@ from enum import Enum
|
|||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from zigpy import types
|
||||
import zigpy.exceptions
|
||||
|
@ -75,6 +75,9 @@ from .const import (
|
|||
)
|
||||
from .helpers import LogMixin, async_get_zha_config_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..api import ClusterBinding
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_UPDATE_ALIVE_INTERVAL = (60, 90)
|
||||
_CHECKIN_GRACE_PERIODS = 2
|
||||
|
@ -655,7 +658,7 @@ class ZHADevice(LogMixin):
|
|||
)
|
||||
return response
|
||||
|
||||
async def async_add_to_group(self, group_id):
|
||||
async def async_add_to_group(self, group_id: int) -> None:
|
||||
"""Add this device to the provided zigbee group."""
|
||||
try:
|
||||
await self._zigpy_device.add_to_group(group_id)
|
||||
|
@ -667,7 +670,7 @@ class ZHADevice(LogMixin):
|
|||
str(ex),
|
||||
)
|
||||
|
||||
async def async_remove_from_group(self, group_id):
|
||||
async def async_remove_from_group(self, group_id: int) -> None:
|
||||
"""Remove this device from the provided zigbee group."""
|
||||
try:
|
||||
await self._zigpy_device.remove_from_group(group_id)
|
||||
|
@ -679,10 +682,12 @@ class ZHADevice(LogMixin):
|
|||
str(ex),
|
||||
)
|
||||
|
||||
async def async_add_endpoint_to_group(self, endpoint_id, group_id):
|
||||
async def async_add_endpoint_to_group(
|
||||
self, endpoint_id: int, group_id: int
|
||||
) -> None:
|
||||
"""Add the device endpoint to the provided zigbee group."""
|
||||
try:
|
||||
await self._zigpy_device.endpoints[int(endpoint_id)].add_to_group(group_id)
|
||||
await self._zigpy_device.endpoints[endpoint_id].add_to_group(group_id)
|
||||
except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex:
|
||||
self.debug(
|
||||
"Failed to add endpoint: %s for device: '%s' to group: 0x%04x ex: %s",
|
||||
|
@ -692,12 +697,12 @@ class ZHADevice(LogMixin):
|
|||
str(ex),
|
||||
)
|
||||
|
||||
async def async_remove_endpoint_from_group(self, endpoint_id, group_id):
|
||||
async def async_remove_endpoint_from_group(
|
||||
self, endpoint_id: int, group_id: int
|
||||
) -> None:
|
||||
"""Remove the device endpoint from the provided zigbee group."""
|
||||
try:
|
||||
await self._zigpy_device.endpoints[int(endpoint_id)].remove_from_group(
|
||||
group_id
|
||||
)
|
||||
await self._zigpy_device.endpoints[endpoint_id].remove_from_group(group_id)
|
||||
except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex:
|
||||
self.debug(
|
||||
"Failed to remove endpoint: %s for device '%s' from group: 0x%04x ex: %s",
|
||||
|
@ -707,21 +712,28 @@ class ZHADevice(LogMixin):
|
|||
str(ex),
|
||||
)
|
||||
|
||||
async def async_bind_to_group(self, group_id, cluster_bindings):
|
||||
async def async_bind_to_group(
|
||||
self, group_id: int, cluster_bindings: list[ClusterBinding]
|
||||
) -> None:
|
||||
"""Directly bind this device to a group for the given clusters."""
|
||||
await self._async_group_binding_operation(
|
||||
group_id, zdo_types.ZDOCmd.Bind_req, cluster_bindings
|
||||
)
|
||||
|
||||
async def async_unbind_from_group(self, group_id, cluster_bindings):
|
||||
async def async_unbind_from_group(
|
||||
self, group_id: int, cluster_bindings: list[ClusterBinding]
|
||||
) -> None:
|
||||
"""Unbind this device from a group for the given clusters."""
|
||||
await self._async_group_binding_operation(
|
||||
group_id, zdo_types.ZDOCmd.Unbind_req, cluster_bindings
|
||||
)
|
||||
|
||||
async def _async_group_binding_operation(
|
||||
self, group_id, operation, cluster_bindings
|
||||
):
|
||||
self,
|
||||
group_id: int,
|
||||
operation: zdo_types.ZDOCmd,
|
||||
cluster_bindings: list[ClusterBinding],
|
||||
) -> None:
|
||||
"""Create or remove a direct zigbee binding between a device and a group."""
|
||||
|
||||
zdo = self._zigpy_device.zdo
|
||||
|
|
|
@ -4,11 +4,12 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
import zigpy.endpoint
|
||||
import zigpy.exceptions
|
||||
import zigpy.group
|
||||
from zigpy.types.named import EUI64
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
||||
|
@ -21,7 +22,14 @@ if TYPE_CHECKING:
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
GroupMember = collections.namedtuple("GroupMember", "ieee endpoint_id")
|
||||
|
||||
class GroupMember(NamedTuple):
|
||||
"""Describes a group member."""
|
||||
|
||||
ieee: EUI64
|
||||
endpoint_id: int
|
||||
|
||||
|
||||
GroupEntityReference = collections.namedtuple(
|
||||
"GroupEntityReference", "name original_name entity_id"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue