Drop custom type (CALLABLE_T) from zha (#73736)

* Drop CALLABLE_T from zha

* Adjust .coveragerc

* Apply suggestions from code review

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Add TypeVar

* Apply suggestions from code review

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* One more

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Flake8

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
epenet 2022-06-20 14:50:27 +02:00 committed by GitHub
parent c075760ca0
commit b6d3e34ebc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 32 deletions

View file

@ -1516,7 +1516,6 @@ omit =
homeassistant/components/zha/core/gateway.py
homeassistant/components/zha/core/helpers.py
homeassistant/components/zha/core/registries.py
homeassistant/components/zha/core/typing.py
homeassistant/components/zha/entity.py
homeassistant/components/zha/light.py
homeassistant/components/zha/sensor.py

View file

@ -7,7 +7,8 @@ https://home-assistant.io/integrations/zha/
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from zigpy.exceptions import ZigbeeException
import zigpy.zcl
@ -25,7 +26,6 @@ from ..const import (
WARNING_DEVICE_STROBE_HIGH,
WARNING_DEVICE_STROBE_YES,
)
from ..typing import CALLABLE_T
from .base import ChannelStatus, ZigbeeChannel
if TYPE_CHECKING:
@ -55,7 +55,7 @@ class IasAce(ZigbeeChannel):
def __init__(self, cluster: zigpy.zcl.Cluster, ch_pool: ChannelPool) -> None:
"""Initialize IAS Ancillary Control Equipment channel."""
super().__init__(cluster, ch_pool)
self.command_map: dict[int, CALLABLE_T] = {
self.command_map: dict[int, Callable[..., Any]] = {
IAS_ACE_ARM: self.arm,
IAS_ACE_BYPASS: self._bypass,
IAS_ACE_EMERGENCY: self._emergency,
@ -67,7 +67,7 @@ class IasAce(ZigbeeChannel):
IAS_ACE_GET_BYPASSED_ZONE_LIST: self._get_bypassed_zone_list,
IAS_ACE_GET_ZONE_STATUS: self._get_zone_status,
}
self.arm_map: dict[AceCluster.ArmMode, CALLABLE_T] = {
self.arm_map: dict[AceCluster.ArmMode, Callable[..., Any]] = {
AceCluster.ArmMode.Disarm: self._disarm,
AceCluster.ArmMode.Arm_All_Zones: self._arm_away,
AceCluster.ArmMode.Arm_Day_Home_Only: self._arm_day,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import collections
from collections.abc import Callable
import dataclasses
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar
import attr
from zigpy import zcl
@ -17,11 +17,15 @@ from homeassistant.const import Platform
# importing channels updates registries
from . import channels as zha_channels # noqa: F401 pylint: disable=unused-import
from .decorators import DictRegistry, SetRegistry
from .typing import CALLABLE_T
if TYPE_CHECKING:
from ..entity import ZhaEntity, ZhaGroupEntity
from .channels.base import ClientChannel, ZigbeeChannel
_ZhaEntityT = TypeVar("_ZhaEntityT", bound=type["ZhaEntity"])
_ZhaGroupEntityT = TypeVar("_ZhaGroupEntityT", bound=type["ZhaGroupEntity"])
GROUP_ENTITY_DOMAINS = [Platform.LIGHT, Platform.SWITCH, Platform.FAN]
PHILLIPS_REMOTE_CLUSTER = 0xFC00
@ -215,7 +219,7 @@ class MatchRule:
class EntityClassAndChannels:
"""Container for entity class and corresponding channels."""
entity_class: CALLABLE_T
entity_class: type[ZhaEntity]
claimed_channel: list[ZigbeeChannel]
@ -225,19 +229,19 @@ class ZHAEntityRegistry:
def __init__(self):
"""Initialize Registry instance."""
self._strict_registry: dict[
str, dict[MatchRule, CALLABLE_T]
str, dict[MatchRule, type[ZhaEntity]]
] = collections.defaultdict(dict)
self._multi_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]]
str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]]
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._config_diagnostic_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]]
str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]]
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._group_registry: dict[str, CALLABLE_T] = {}
self._group_registry: dict[str, type[ZhaGroupEntity]] = {}
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
] = collections.defaultdict(lambda: collections.defaultdict(list))
@ -248,8 +252,8 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
channels: list[ZigbeeChannel],
default: CALLABLE_T = None,
) -> tuple[CALLABLE_T, list[ZigbeeChannel]]:
default: type[ZhaEntity] | None = None,
) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]:
"""Match a ZHA Channels to a ZHA Entity class."""
matches = self._strict_registry[component]
for match in sorted(matches, key=lambda x: x.weight, reverse=True):
@ -310,7 +314,7 @@ class ZHAEntityRegistry:
return result, list(all_claimed)
def get_group_entity(self, component: str) -> CALLABLE_T:
def get_group_entity(self, component: str) -> type[ZhaGroupEntity] | None:
"""Match a ZHA group to a ZHA Entity class."""
return self._group_registry.get(component)
@ -322,14 +326,14 @@ class ZHAEntityRegistry:
manufacturers: Callable | set[str] | str = None,
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a strict match rule."""
rule = MatchRule(
channel_names, generic_ids, manufacturers, models, aux_channels
)
def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
"""Register a strict match rule.
All non empty fields of a match rule must match.
@ -348,7 +352,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
stop_on_match_group: int | str | None = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
rule = MatchRule(
@ -359,7 +363,7 @@ class ZHAEntityRegistry:
aux_channels,
)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
"""Register a loose match rule.
All non empty fields of a match rule must match.
@ -381,7 +385,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
stop_on_match_group: int | str | None = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
rule = MatchRule(
@ -392,7 +396,7 @@ class ZHAEntityRegistry:
aux_channels,
)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
"""Register a loose match rule.
All non empty fields of a match rule must match.
@ -405,10 +409,12 @@ class ZHAEntityRegistry:
return decorator
def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
def group_match(
self, component: str
) -> Callable[[_ZhaGroupEntityT], _ZhaGroupEntityT]:
"""Decorate a group match rule."""
def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:
def decorator(zha_ent: _ZhaGroupEntityT) -> _ZhaGroupEntityT:
"""Register a group match rule."""
self._group_registry[component] = zha_ent
return zha_ent

View file

@ -1,6 +0,0 @@
"""Typing helpers for ZHA component."""
from collections.abc import Callable
from typing import TypeVar
# pylint: disable=invalid-name
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import functools
import logging
from typing import TYPE_CHECKING, Any
@ -29,7 +30,6 @@ from .core.const import (
SIGNAL_REMOVE,
)
from .core.helpers import LogMixin
from .core.typing import CALLABLE_T
if TYPE_CHECKING:
from .core.channels.base import ZigbeeChannel
@ -57,7 +57,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
self._state: Any = None
self._extra_state_attributes: dict[str, Any] = {}
self._zha_device = zha_device
self._unsubs: list[CALLABLE_T] = []
self._unsubs: list[Callable[[], None]] = []
self.remove_future: asyncio.Future[Any] = asyncio.Future()
@property
@ -130,7 +130,7 @@ class BaseZhaEntity(LogMixin, entity.Entity):
self,
channel: ZigbeeChannel,
signal: str,
func: CALLABLE_T,
func: Callable[[], Any],
signal_override=False,
):
"""Accept a signal from a channel."""