hass-core/homeassistant/components/zha/core/endpoint.py
puddly 304b950f1a
Speed up ZHA initialization and improve startup responsiveness ()
* Limit concurrency of startup traffic to allow for interactive usage

* Drop `retryable_req`, we already have request retrying

* Oops, `min` -> `max`

* Add a comment describing why `async_initialize` is not concurrent

* Fix existing unit tests

* Break out fetching mains state into its own function to unit test
2024-01-19 15:41:45 +01:00

252 lines
9.4 KiB
Python

"""Representation of a Zigbee endpoint for zha."""
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import functools
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar

from zigpy.typing import EndpointType as ZigpyEndpointType

from homeassistant.const import Platform
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.util.async_ import gather_with_limited_concurrency

from . import const, discovery, registries
from .cluster_handlers import ClusterHandler
from .helpers import get_zha_data

if TYPE_CHECKING:
    from .cluster_handlers import ClientClusterHandler
    from .device import ZHADevice
# Keys of the dict returned by ``Endpoint.zigbee_signature``.
ATTR_PROFILE_ID: Final[str] = "profile_id"
ATTR_DEVICE_TYPE: Final[str] = "device_type"
ATTR_IN_CLUSTERS: Final[str] = "input_clusters"
ATTR_OUT_CLUSTERS: Final[str] = "output_clusters"

_LOGGER = logging.getLogger(__name__)

# Type variable for the entity class handed to ``Endpoint.async_new_entity``.
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
class Endpoint:
    """Endpoint for a zha device.

    Wraps a zigpy endpoint and holds the ZHA cluster handlers created for its
    input (server) clusters and, where registered, its output (client) clusters.
    """

    def __init__(self, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> None:
        """Initialize instance.

        No cluster handlers are created here; use :meth:`new` to build an
        endpoint with its cluster handlers populated and entities discovered.
        """
        assert zigpy_endpoint is not None
        assert device is not None
        self._zigpy_endpoint: ZigpyEndpointType = zigpy_endpoint
        self._device: ZHADevice = device
        # Server cluster handlers, keyed by cluster handler id.
        self._all_cluster_handlers: dict[str, ClusterHandler] = {}
        # Server cluster handlers claimed via claim_cluster_handlers(), by id.
        self._claimed_cluster_handlers: dict[str, ClusterHandler] = {}
        # Client cluster handlers for registered output clusters, by id.
        self._client_cluster_handlers: dict[str, ClientClusterHandler] = {}
        self._unique_id: str = f"{str(device.ieee)}-{zigpy_endpoint.endpoint_id}"

    @property
    def device(self) -> ZHADevice:
        """Return the device this endpoint belongs to."""
        return self._device

    @property
    def all_cluster_handlers(self) -> dict[str, ClusterHandler]:
        """All server cluster handlers of an endpoint."""
        return self._all_cluster_handlers

    @property
    def claimed_cluster_handlers(self) -> dict[str, ClusterHandler]:
        """Cluster handlers in use."""
        return self._claimed_cluster_handlers

    @property
    def client_cluster_handlers(self) -> dict[str, ClientClusterHandler]:
        """Return a dict of client cluster handlers."""
        return self._client_cluster_handlers

    @property
    def zigpy_endpoint(self) -> ZigpyEndpointType:
        """Return endpoint of zigpy device."""
        return self._zigpy_endpoint

    @property
    def id(self) -> int:
        """Return endpoint id."""
        return self._zigpy_endpoint.endpoint_id

    @property
    def unique_id(self) -> str:
        """Return the unique id for this endpoint."""
        return self._unique_id

    @property
    def zigbee_signature(self) -> tuple[int, dict[str, Any]]:
        """Get the zigbee signature for the endpoint this pool represents.

        Returns ``(endpoint_id, info)``, where ``info`` holds the profile id and
        device type as ``0x%04x`` strings (``""`` when unset), plus the sorted
        input and output cluster ids as lists of ``0x%04x`` strings.
        """
        return (
            self.id,
            {
                ATTR_PROFILE_ID: f"0x{self._zigpy_endpoint.profile_id:04x}"
                if self._zigpy_endpoint.profile_id is not None
                else "",
                ATTR_DEVICE_TYPE: f"0x{self._zigpy_endpoint.device_type:04x}"
                if self._zigpy_endpoint.device_type is not None
                else "",
                ATTR_IN_CLUSTERS: [
                    f"0x{cluster_id:04x}"
                    for cluster_id in sorted(self._zigpy_endpoint.in_clusters)
                ],
                ATTR_OUT_CLUSTERS: [
                    f"0x{cluster_id:04x}"
                    for cluster_id in sorted(self._zigpy_endpoint.out_clusters)
                ],
            },
        )

    @classmethod
    def new(cls, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> Endpoint:
        """Create new endpoint and populate cluster handlers.

        Builds the server and client cluster handlers, then runs entity
        discovery for the endpoint unless the device is the coordinator.
        """
        endpoint = cls(zigpy_endpoint, device)
        endpoint.add_all_cluster_handlers()
        endpoint.add_client_cluster_handlers()
        if not device.is_coordinator:
            discovery.PROBE.discover_entities(endpoint)
        return endpoint

    def add_all_cluster_handlers(self) -> None:
        """Create and add cluster handlers for all input clusters.

        The handler class comes from ``ZIGBEE_CLUSTER_HANDLER_REGISTRY``. A
        quirk-specific class is preferred when the device's quirk id has an
        entry; otherwise the ``None`` entry is used. The generic
        ``ClusterHandler`` is the fallback. A handler whose constructor raises
        ``KeyError`` is logged and skipped.
        """
        for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
            cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
                cluster_id, {None: ClusterHandler}
            )
            quirk_id = (
                self.device.quirk_id
                if self.device.quirk_id in cluster_handler_classes
                else None
            )
            cluster_handler_class = cluster_handler_classes.get(
                quirk_id, ClusterHandler
            )

            # Allow cluster handler to filter out bad matches
            if not cluster_handler_class.matches(cluster, self):
                cluster_handler_class = ClusterHandler

            _LOGGER.info(
                "Creating cluster handler for cluster id: %s class: %s",
                cluster_id,
                cluster_handler_class,
            )

            try:
                cluster_handler = cluster_handler_class(cluster, self)
            except KeyError as err:
                _LOGGER.warning(
                    "Cluster handler %s for cluster %s on endpoint %s is invalid: %s",
                    cluster_handler_class,
                    cluster,
                    self,
                    err,
                )
                continue

            # A few well-known handlers are also exposed directly on the device.
            if cluster_handler.name == const.CLUSTER_HANDLER_POWER_CONFIGURATION:
                self._device.power_configuration_ch = cluster_handler
            elif cluster_handler.name == const.CLUSTER_HANDLER_IDENTIFY:
                self._device.identify_ch = cluster_handler
            elif cluster_handler.name == const.CLUSTER_HANDLER_BASIC:
                self._device.basic_ch = cluster_handler
            self._all_cluster_handlers[cluster_handler.id] = cluster_handler

    def add_client_cluster_handlers(self) -> None:
        """Create client cluster handlers for all output clusters if in the registry."""
        for (
            cluster_id,
            cluster_handler_class,
        ) in registries.CLIENT_CLUSTER_HANDLER_REGISTRY.items():
            cluster = self.zigpy_endpoint.out_clusters.get(cluster_id)
            if cluster is not None:
                cluster_handler = cluster_handler_class(cluster, self)
                self.client_cluster_handlers[cluster_handler.id] = cluster_handler

    async def async_initialize(self, from_cache: bool = False) -> None:
        """Initialize claimed cluster handlers.

        Calls ``async_initialize(from_cache)`` on every claimed and client
        cluster handler.
        """
        # Initialize one handler at a time (max_concurrency=1). This limits how
        # much startup traffic is in flight at once, so the network stays
        # usable for interactive requests while devices initialize.
        await self._execute_handler_tasks(
            "async_initialize", from_cache, max_concurrency=1
        )

    async def async_configure(self) -> None:
        """Configure claimed cluster handlers.

        Calls ``async_configure()`` on every claimed and client cluster handler
        without a concurrency limit.
        """
        await self._execute_handler_tasks("async_configure")

    async def _execute_handler_tasks(
        self, func_name: str, *args: Any, max_concurrency: int | None = None
    ) -> None:
        """Run a coroutine method on every claimed and client cluster handler.

        Calls ``getattr(handler, func_name)(*args)`` on each handler and awaits
        the resulting coroutines together. With ``max_concurrency`` set, they are
        awaited through ``gather_with_limited_concurrency``, which caps how many
        run at once. Otherwise they all run concurrently via ``asyncio.gather``.

        Exceptions raised by the awaited coroutines are not propagated. Each
        handler's outcome is logged through that handler's own logger instead.
        """
        cluster_handlers = [
            *self.claimed_cluster_handlers.values(),
            *self.client_cluster_handlers.values(),
        ]
        tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]

        gather: Callable[..., Awaitable]
        if max_concurrency is None:
            gather = asyncio.gather
        else:
            gather = functools.partial(gather_with_limited_concurrency, max_concurrency)

        # With return_exceptions=True, the results line up 1:1 with
        # cluster_handlers, and a raised exception appears in place of a result.
        results = await gather(*tasks, return_exceptions=True)
        for cluster_handler, outcome in zip(cluster_handlers, results):
            # Check BaseException rather than Exception. A cancelled handler is
            # returned as asyncio.CancelledError, a BaseException subclass, and
            # must not be reported as a success.
            if isinstance(outcome, BaseException):
                cluster_handler.warning(
                    "'%s' stage failed: %s", func_name, str(outcome), exc_info=outcome
                )
                continue
            cluster_handler.debug("'%s' stage succeeded", func_name)

    def async_new_entity(
        self,
        platform: Platform,
        entity_class: CALLABLE_T,
        unique_id: str,
        cluster_handlers: list[ClusterHandler],
    ) -> None:
        """Create a new entity.

        Appends ``(entity_class, (unique_id, device, cluster_handlers))`` to the
        ZHA data's list for ``platform``. Does nothing when the device status is
        already ``INITIALIZED``.
        """
        # Local import; presumably avoids a circular import with .device
        # (ZHADevice is only imported under TYPE_CHECKING in this module).
        from .device import DeviceStatus  # pylint: disable=import-outside-toplevel

        # NOTE(review): presumably entities for an already-initialized device
        # were registered earlier, so re-adding them is skipped — confirm.
        if self.device.status == DeviceStatus.INITIALIZED:
            return

        zha_data = get_zha_data(self.device.hass)
        zha_data.platforms[platform].append(
            (entity_class, (unique_id, self.device, cluster_handlers))
        )

    @callback
    def async_send_signal(self, signal: str, *args: Any) -> None:
        """Send a signal through hass dispatcher."""
        async_dispatcher_send(self.device.hass, signal, *args)

    def send_event(self, signal: dict[str, Any]) -> None:
        """Broadcast an event from this endpoint.

        The event payload is ``signal`` plus this endpoint's unique id and
        endpoint id. Keys in ``signal`` override those two.
        """
        self.device.zha_send_event(
            {
                const.ATTR_UNIQUE_ID: self.unique_id,
                const.ATTR_ENDPOINT_ID: self.id,
                **signal,
            }
        )

    def claim_cluster_handlers(self, cluster_handlers: list[ClusterHandler]) -> None:
        """Claim cluster handlers."""
        self.claimed_cluster_handlers.update({ch.id: ch for ch in cluster_handlers})

    def unclaimed_cluster_handlers(self) -> list[ClusterHandler]:
        """Return a list of available (unclaimed) cluster handlers.

        Handlers are returned in the insertion order of ``all_cluster_handlers``.
        This keeps the result deterministic across runs, which a set difference
        (hash-randomized for str keys) does not.
        """
        claimed = self.claimed_cluster_handlers
        return [
            cluster_handler
            for cluster_id, cluster_handler in self.all_cluster_handlers.items()
            if cluster_id not in claimed
        ]