* Limit concurrency of startup traffic to allow for interactive usage
* Drop `retryable_req`; request retrying already exists
* Fix `min` -> `max`
* Add a comment describing why `async_initialize` is not concurrent
* Fix existing unit tests
* Break out fetching mains state into its own function so it can be unit tested
252 lines
9.4 KiB
Python
252 lines
9.4 KiB
Python
"""Representation of a Zigbee endpoint for zha."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Awaitable, Callable
|
|
import functools
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
|
|
|
from zigpy.typing import EndpointType as ZigpyEndpointType
|
|
|
|
from homeassistant.const import Platform
|
|
from homeassistant.core import callback
|
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
from homeassistant.util.async_ import gather_with_limited_concurrency
|
|
|
|
from . import const, discovery, registries
|
|
from .cluster_handlers import ClusterHandler
|
|
from .helpers import get_zha_data
|
|
|
|
if TYPE_CHECKING:
|
|
from .cluster_handlers import ClientClusterHandler
|
|
from .device import ZHADevice
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Keys of the dictionary built by Endpoint.zigbee_signature.
ATTR_PROFILE_ID: Final[str] = "profile_id"
ATTR_DEVICE_TYPE: Final[str] = "device_type"
ATTR_IN_CLUSTERS: Final[str] = "input_clusters"
ATTR_OUT_CLUSTERS: Final[str] = "output_clusters"

# Type variable bound to callables; annotates the entity class argument of
# Endpoint.async_new_entity.
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
|
|
|
|
|
|
class Endpoint:
    """Endpoint for a zha device.

    Wraps a zigpy endpoint and owns the cluster handlers created for it: one per
    input (server) cluster, plus one per output (client) cluster that has an
    entry in the client cluster handler registry.
    """

    def __init__(self, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> None:
        """Initialize instance.

        No cluster handlers are created here; use :meth:`new` to build a fully
        populated endpoint.
        """
        assert zigpy_endpoint is not None
        assert device is not None
        self._zigpy_endpoint: ZigpyEndpointType = zigpy_endpoint
        self._device: ZHADevice = device
        # Handlers for every input (server) cluster, keyed by handler id.
        self._all_cluster_handlers: dict[str, ClusterHandler] = {}
        # Handlers marked as in use via claim_cluster_handlers().
        self._claimed_cluster_handlers: dict[str, ClusterHandler] = {}
        # Handlers for output clusters found in the client handler registry.
        self._client_cluster_handlers: dict[str, ClientClusterHandler] = {}
        self._unique_id: str = f"{str(device.ieee)}-{zigpy_endpoint.endpoint_id}"

    @property
    def device(self) -> ZHADevice:
        """Return the device this endpoint belongs to."""
        return self._device

    @property
    def all_cluster_handlers(self) -> dict[str, ClusterHandler]:
        """All server cluster handlers of an endpoint."""
        return self._all_cluster_handlers

    @property
    def claimed_cluster_handlers(self) -> dict[str, ClusterHandler]:
        """Cluster handlers in use."""
        return self._claimed_cluster_handlers

    @property
    def client_cluster_handlers(self) -> dict[str, ClientClusterHandler]:
        """Return a dict of client cluster handlers."""
        return self._client_cluster_handlers

    @property
    def zigpy_endpoint(self) -> ZigpyEndpointType:
        """Return endpoint of zigpy device."""
        return self._zigpy_endpoint

    @property
    def id(self) -> int:
        """Return endpoint id."""
        return self._zigpy_endpoint.endpoint_id

    @property
    def unique_id(self) -> str:
        """Return the unique id for this endpoint."""
        return self._unique_id

    @property
    def zigbee_signature(self) -> tuple[int, dict[str, Any]]:
        """Get the zigbee signature of this endpoint.

        Returns a tuple of the endpoint id and a dict holding the profile id and
        device type as ``0x%04x`` strings (``""`` when unset), plus the sorted
        input and output cluster ids as ``0x%04x`` strings.
        """
        return (
            self.id,
            {
                ATTR_PROFILE_ID: f"0x{self._zigpy_endpoint.profile_id:04x}"
                if self._zigpy_endpoint.profile_id is not None
                else "",
                ATTR_DEVICE_TYPE: f"0x{self._zigpy_endpoint.device_type:04x}"
                if self._zigpy_endpoint.device_type is not None
                else "",
                ATTR_IN_CLUSTERS: [
                    f"0x{cluster_id:04x}"
                    for cluster_id in sorted(self._zigpy_endpoint.in_clusters)
                ],
                ATTR_OUT_CLUSTERS: [
                    f"0x{cluster_id:04x}"
                    for cluster_id in sorted(self._zigpy_endpoint.out_clusters)
                ],
            },
        )

    @classmethod
    def new(cls, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> Endpoint:
        """Create new endpoint and populate cluster handlers.

        Entity discovery is run for the endpoint unless the device is the
        coordinator.
        """
        endpoint = cls(zigpy_endpoint, device)
        endpoint.add_all_cluster_handlers()
        endpoint.add_client_cluster_handlers()
        if not device.is_coordinator:
            discovery.PROBE.discover_entities(endpoint)
        return endpoint

    def add_all_cluster_handlers(self) -> None:
        """Create and add cluster handlers for all input clusters."""
        for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
            # Registry entries map a quirk id (or None) to a handler class.
            # Unregistered clusters get the generic ClusterHandler.
            cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
                cluster_id, {None: ClusterHandler}
            )
            # Prefer the class registered for this device's quirk, if any.
            quirk_id = (
                self.device.quirk_id
                if self.device.quirk_id in cluster_handler_classes
                else None
            )
            cluster_handler_class = cluster_handler_classes.get(
                quirk_id, ClusterHandler
            )

            # Allow cluster handler to filter out bad matches
            if not cluster_handler_class.matches(cluster, self):
                cluster_handler_class = ClusterHandler

            _LOGGER.info(
                "Creating cluster handler for cluster id: %s class: %s",
                cluster_id,
                cluster_handler_class,
            )

            try:
                cluster_handler = cluster_handler_class(cluster, self)
            except KeyError as err:
                # A handler that cannot be built for this cluster is skipped,
                # so the remaining clusters still get handlers.
                _LOGGER.warning(
                    "Cluster handler %s for cluster %s on endpoint %s is invalid: %s",
                    cluster_handler_class,
                    cluster,
                    self,
                    err,
                )
                continue

            # Keep direct references to a few well-known handlers on the device.
            if cluster_handler.name == const.CLUSTER_HANDLER_POWER_CONFIGURATION:
                self._device.power_configuration_ch = cluster_handler
            elif cluster_handler.name == const.CLUSTER_HANDLER_IDENTIFY:
                self._device.identify_ch = cluster_handler
            elif cluster_handler.name == const.CLUSTER_HANDLER_BASIC:
                self._device.basic_ch = cluster_handler
            self._all_cluster_handlers[cluster_handler.id] = cluster_handler

    def add_client_cluster_handlers(self) -> None:
        """Create client cluster handlers for all output clusters if in the registry."""
        for (
            cluster_id,
            cluster_handler_class,
        ) in registries.CLIENT_CLUSTER_HANDLER_REGISTRY.items():
            cluster = self.zigpy_endpoint.out_clusters.get(cluster_id)
            if cluster is not None:
                cluster_handler = cluster_handler_class(cluster, self)
                self.client_cluster_handlers[cluster_handler.id] = cluster_handler

    async def async_initialize(self, from_cache: bool = False) -> None:
        """Initialize claimed and client cluster handlers.

        Handlers are initialized one at a time (``max_concurrency=1``). This
        limits the burst of startup traffic so the network stays usable for
        interactive requests while devices are being initialized.
        """
        await self._execute_handler_tasks(
            "async_initialize", from_cache, max_concurrency=1
        )

    async def async_configure(self) -> None:
        """Configure claimed and client cluster handlers, all concurrently."""
        await self._execute_handler_tasks("async_configure")

    async def _execute_handler_tasks(
        self, func_name: str, *args: Any, max_concurrency: int | None = None
    ) -> None:
        """Await ``func_name(*args)`` on every claimed and client cluster handler.

        With ``max_concurrency`` left as ``None`` all calls are gathered with
        :func:`asyncio.gather`. Otherwise they go through
        ``gather_with_limited_concurrency`` capped at ``max_concurrency``.
        Failures are logged as warnings on the failing handler instead of being
        raised, so one failing handler does not abort the others.
        """
        cluster_handlers = [
            *self.claimed_cluster_handlers.values(),
            *self.client_cluster_handlers.values(),
        ]
        tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]

        gather: Callable[..., Awaitable]
        if max_concurrency is None:
            gather = asyncio.gather
        else:
            gather = functools.partial(gather_with_limited_concurrency, max_concurrency)

        results = await gather(*tasks, return_exceptions=True)
        for cluster_handler, outcome in zip(cluster_handlers, results):
            # With return_exceptions=True a child's asyncio.CancelledError (a
            # BaseException, not an Exception) is also returned as a result.
            # Checking BaseException keeps it from being logged as a success.
            if isinstance(outcome, BaseException):
                cluster_handler.warning(
                    "'%s' stage failed: %s", func_name, str(outcome), exc_info=outcome
                )
                continue
            cluster_handler.debug("'%s' stage succeeded", func_name)

    def async_new_entity(
        self,
        platform: Platform,
        entity_class: CALLABLE_T,
        unique_id: str,
        cluster_handlers: list[ClusterHandler],
    ) -> None:
        """Queue a new entity for creation on ``platform``.

        Does nothing once the device status is ``DeviceStatus.INITIALIZED``.
        Otherwise ``(entity_class, (unique_id, device, cluster_handlers))`` is
        appended to the platform's list in the integration data.
        """
        # Imported here rather than at module level; presumably to avoid an
        # import cycle with .device (TODO confirm).
        from .device import DeviceStatus  # pylint: disable=import-outside-toplevel

        if self.device.status == DeviceStatus.INITIALIZED:
            return

        zha_data = get_zha_data(self.device.hass)
        zha_data.platforms[platform].append(
            (entity_class, (unique_id, self.device, cluster_handlers))
        )

    @callback
    def async_send_signal(self, signal: str, *args: Any) -> None:
        """Send a signal through hass dispatcher."""
        async_dispatcher_send(self.device.hass, signal, *args)

    def send_event(self, signal: dict[str, Any]) -> None:
        """Broadcast an event from this endpoint.

        The event payload is ``signal`` plus this endpoint's unique id and
        endpoint id. Keys in ``signal`` override those two.
        """
        self.device.zha_send_event(
            {
                const.ATTR_UNIQUE_ID: self.unique_id,
                const.ATTR_ENDPOINT_ID: self.id,
                **signal,
            }
        )

    def claim_cluster_handlers(self, cluster_handlers: list[ClusterHandler]) -> None:
        """Mark the given cluster handlers as in use, keyed by their id."""
        self.claimed_cluster_handlers.update({ch.id: ch for ch in cluster_handlers})

    def unclaimed_cluster_handlers(self) -> list[ClusterHandler]:
        """Return a list of available (unclaimed) cluster handlers.

        Results follow the insertion order of ``all_cluster_handlers``, so the
        order is stable across runs. A set difference of string ids would
        depend on hash randomization.
        """
        claimed = self.claimed_cluster_handlers
        return [
            cluster_handler
            for cluster_id, cluster_handler in self.all_cluster_handlers.items()
            if cluster_id not in claimed
        ]
|