Improve zha typing [api] (5) (#68684)

This commit is contained in:
Marc Mueller 2022-03-30 15:54:31 +02:00 committed by GitHub
parent cde989cd38
commit 006fa9b700
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 56 deletions

View file

@ -2,10 +2,8 @@
from __future__ import annotations
import asyncio
import collections
from collections.abc import Mapping
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
import voluptuous as vol
from zigpy.config.validators import cv_boolean
@ -14,7 +12,7 @@ from zigpy.zcl.clusters.security import IasAce
import zigpy.zdo.types as zdo_types
from homeassistant.components import websocket_api
from homeassistant.const import ATTR_COMMAND, ATTR_NAME
from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, ServiceCall, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -31,6 +29,7 @@ from .core.const import (
ATTR_LEVEL,
ATTR_MANUFACTURER,
ATTR_MEMBERS,
ATTR_TYPE,
ATTR_VALUE,
ATTR_WARNING_DEVICE_DURATION,
ATTR_WARNING_DEVICE_MODE,
@ -201,7 +200,56 @@ SERVICE_SCHEMAS = {
),
}
ClusterBinding = collections.namedtuple("ClusterBinding", "id endpoint_id type name")
class ClusterBinding(NamedTuple):
"""Describes a cluster binding."""
name: str
type: str
id: int
endpoint_id: int
def _cv_group_member(value: dict[str, Any]) -> GroupMember:
"""Transform a group member."""
return GroupMember(
ieee=value[ATTR_IEEE],
endpoint_id=value[ATTR_ENDPOINT_ID],
)
def _cv_cluster_binding(value: dict[str, Any]) -> ClusterBinding:
"""Transform a cluster binding."""
return ClusterBinding(
name=value[ATTR_NAME],
type=value[ATTR_TYPE],
id=value[ATTR_ID],
endpoint_id=value[ATTR_ENDPOINT_ID],
)
GROUP_MEMBER_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(ATTR_IEEE): IEEE_SCHEMA,
vol.Required(ATTR_ENDPOINT_ID): int,
}
),
_cv_group_member,
)
CLUSTER_BINDING_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(ATTR_NAME): cv.string,
vol.Required(ATTR_TYPE): cv.string,
vol.Required(ATTR_ID): int,
vol.Required(ATTR_ENDPOINT_ID): int,
}
),
_cv_cluster_binding,
)
@websocket_api.require_admin
@ -374,27 +422,13 @@ async def websocket_get_group(
connection.send_result(msg[ID], group_info)
def cv_group_member(value: Any) -> GroupMember:
"""Validate and transform a group member."""
if not isinstance(value, Mapping):
raise vol.Invalid("Not a group member")
try:
group_member = GroupMember(
ieee=EUI64.convert(value["ieee"]), endpoint_id=value["endpoint_id"]
)
except KeyError as err:
raise vol.Invalid("Not a group member") from err
return group_member
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zha/group/add",
vol.Required(GROUP_NAME): cv.string,
vol.Optional(GROUP_ID): cv.positive_int,
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -441,7 +475,7 @@ async def websocket_remove_groups(
{
vol.Required(TYPE): "zha/group/members/add",
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -471,7 +505,7 @@ async def websocket_add_group_members(
{
vol.Required(TYPE): "zha/group/members/remove",
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]),
vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]),
}
)
@websocket_api.async_response
@ -837,30 +871,13 @@ async def websocket_unbind_devices(
)
def is_cluster_binding(value: Any) -> ClusterBinding:
"""Validate and transform a cluster binding."""
if not isinstance(value, Mapping):
raise vol.Invalid("Not a cluster binding")
try:
cluster_binding = ClusterBinding(
name=value["name"],
type=value["type"],
id=value["id"],
endpoint_id=value["endpoint_id"],
)
except KeyError as err:
raise vol.Invalid("Not a cluster binding") from err
return cluster_binding
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required(TYPE): "zha/groups/bind",
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
}
)
@websocket_api.async_response
@ -882,7 +899,7 @@ async def websocket_bind_group(
vol.Required(TYPE): "zha/groups/unbind",
vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA,
vol.Required(GROUP_ID): cv.positive_int,
vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]),
vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]),
}
)
@websocket_api.async_response