Speed up ZHA initialization and improve startup responsiveness (#108103)
* Limit concurrency of startup traffic to allow for interactive usage
* Drop `retryable_req`, we already have request retrying
* Oops, `min` -> `max`
* Add a comment describing why `async_initialize` is not concurrent
* Fix existing unit tests
* Break out fetching mains state into its own function to unit test
This commit is contained in:
parent
99f9f0205a
commit
304b950f1a
7 changed files with 149 additions and 73 deletions
|
@ -42,7 +42,7 @@ from ..const import (
|
|||
ZHA_CLUSTER_HANDLER_MSG_DATA,
|
||||
ZHA_CLUSTER_HANDLER_READS_PER_REQ,
|
||||
)
|
||||
from ..helpers import LogMixin, retryable_req, safe_read
|
||||
from ..helpers import LogMixin, safe_read
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..endpoint import Endpoint
|
||||
|
@ -362,7 +362,6 @@ class ClusterHandler(LogMixin):
|
|||
self.debug("skipping cluster handler configuration")
|
||||
self._status = ClusterHandlerStatus.CONFIGURED
|
||||
|
||||
@retryable_req(delays=(1, 1, 3))
|
||||
async def async_initialize(self, from_cache: bool) -> None:
|
||||
"""Initialize cluster handler."""
|
||||
if not from_cache and self._endpoint.device.skip_configuration:
|
||||
|
|
|
@ -592,12 +592,17 @@ class ZHADevice(LogMixin):
|
|||
self.debug("started initialization")
|
||||
await self._zdo_handler.async_initialize(from_cache)
|
||||
self._zdo_handler.debug("'async_initialize' stage succeeded")
|
||||
await asyncio.gather(
|
||||
*(
|
||||
endpoint.async_initialize(from_cache)
|
||||
for endpoint in self._endpoints.values()
|
||||
)
|
||||
)
|
||||
|
||||
# We intentionally do not use `gather` here! This is so that if, for example,
|
||||
# three `device.async_initialize()`s are spawned, only three concurrent requests
|
||||
# will ever be in flight at once. Startup concurrency is managed at the device
|
||||
# level.
|
||||
for endpoint in self._endpoints.values():
|
||||
try:
|
||||
await endpoint.async_initialize(from_cache)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
self.debug("Failed to initialize endpoint", exc_info=True)
|
||||
|
||||
self.debug("power source: %s", self.power_source)
|
||||
self.status = DeviceStatus.INITIALIZED
|
||||
self.debug("completed initialization")
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
||||
|
||||
|
@ -11,6 +12,7 @@ from zigpy.typing import EndpointType as ZigpyEndpointType
|
|||
from homeassistant.const import Platform
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||
|
||||
from . import const, discovery, registries
|
||||
from .cluster_handlers import ClusterHandler
|
||||
|
@ -169,20 +171,32 @@ class Endpoint:
|
|||
|
||||
async def async_initialize(self, from_cache: bool = False) -> None:
|
||||
"""Initialize claimed cluster handlers."""
|
||||
await self._execute_handler_tasks("async_initialize", from_cache)
|
||||
await self._execute_handler_tasks(
|
||||
"async_initialize", from_cache, max_concurrency=1
|
||||
)
|
||||
|
||||
async def async_configure(self) -> None:
|
||||
"""Configure claimed cluster handlers."""
|
||||
await self._execute_handler_tasks("async_configure")
|
||||
|
||||
async def _execute_handler_tasks(self, func_name: str, *args: Any) -> None:
|
||||
async def _execute_handler_tasks(
|
||||
self, func_name: str, *args: Any, max_concurrency: int | None = None
|
||||
) -> None:
|
||||
"""Add a throttled cluster handler task and swallow exceptions."""
|
||||
cluster_handlers = [
|
||||
*self.claimed_cluster_handlers.values(),
|
||||
*self.client_cluster_handlers.values(),
|
||||
]
|
||||
tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
gather: Callable[..., Awaitable]
|
||||
|
||||
if max_concurrency is None:
|
||||
gather = asyncio.gather
|
||||
else:
|
||||
gather = functools.partial(gather_with_limited_concurrency, max_concurrency)
|
||||
|
||||
results = await gather(*tasks, return_exceptions=True)
|
||||
for cluster_handler, outcome in zip(cluster_handlers, results):
|
||||
if isinstance(outcome, Exception):
|
||||
cluster_handler.warning(
|
||||
|
|
|
@ -11,7 +11,7 @@ import itertools
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Self
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Self, cast
|
||||
|
||||
from zigpy.application import ControllerApplication
|
||||
from zigpy.config import (
|
||||
|
@ -36,6 +36,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
|
|||
from homeassistant.helpers.device_registry import DeviceInfo
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||
|
||||
from . import discovery
|
||||
from .const import (
|
||||
|
@ -292,6 +293,39 @@ class ZHAGateway:
|
|||
# entity registry tied to the devices
|
||||
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
||||
|
||||
@property
def radio_concurrency(self) -> int:
    """Return the maximum configured radio concurrency.

    The value is read from the ``max_value`` of the application
    controller's private request semaphore.
    """
    semaphore = self.application_controller._concurrent_requests_semaphore  # pylint: disable=protected-access
    return semaphore.max_value
|
||||
|
||||
async def async_fetch_updated_state_mains(self) -> None:
    """Re-initialize online mains powered devices to refresh their state.

    A device is polled only when it is mains powered, has a recorded
    ``last_seen`` timestamp, and was seen more recently than its
    ``consider_unavailable_time``. The most recently seen devices are
    polled first, and the number of devices initialized at once is capped
    below the radio's concurrency limit.
    """
    _LOGGER.debug("Fetching current state for mains powered devices")

    timestamp = time.time()

    # Only delay startup to poll mains-powered devices that are online
    reachable = []
    for device in self.devices.values():
        if not device.is_mains_powered:
            continue
        if device.last_seen is None:
            continue
        if (timestamp - device.last_seen) < device.consider_unavailable_time:
            reachable.append(device)

    # Prioritize devices that have recently been contacted
    reachable.sort(
        key=lambda device: cast(float, device.last_seen),
        reverse=True,
    )

    # Make sure that we always leave slots for non-startup requests
    poll_limit = max(1, self.radio_concurrency - 4)

    pending = [device.async_initialize(from_cache=False) for device in reachable]
    await gather_with_limited_concurrency(poll_limit, *pending)

    _LOGGER.debug("completed fetching current state for mains powered devices")
|
||||
|
||||
async def async_initialize_devices_and_entities(self) -> None:
|
||||
"""Initialize devices and load entities."""
|
||||
|
||||
|
@ -302,17 +336,8 @@ class ZHAGateway:
|
|||
|
||||
async def fetch_updated_state() -> None:
|
||||
"""Fetch updated state for mains powered devices."""
|
||||
_LOGGER.debug("Fetching current state for mains powered devices")
|
||||
await asyncio.gather(
|
||||
*(
|
||||
dev.async_initialize(from_cache=False)
|
||||
for dev in self.devices.values()
|
||||
if dev.is_mains_powered
|
||||
)
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"completed fetching current state for mains powered devices - allowing polled requests"
|
||||
)
|
||||
await self.async_fetch_updated_state_mains()
|
||||
_LOGGER.debug("Allowing polled requests")
|
||||
self.hass.data[DATA_ZHA].allow_polling = True
|
||||
|
||||
# background the fetching of state for mains powered devices
|
||||
|
|
|
@ -5,17 +5,13 @@ https://home-assistant.io/integrations/zha/
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import collections
|
||||
from collections.abc import Callable, Iterator
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from random import uniform
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
|
@ -318,49 +314,6 @@ class LogMixin:
|
|||
return self.log(logging.ERROR, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def retryable_req(
    delays=(1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800), raise_=False
):
    """Make a method with ZCL requests retryable.

    The decorated coroutine is attempted once and then retried once per entry
    in ``delays``, so it runs at most ``len(delays) + 1`` times. Before each
    retry the wrapper sleeps for the matching delay, randomized to between
    75% and 125% of its value.

    Only ``zigpy.exceptions.ZigbeeException`` and ``asyncio.TimeoutError`` are
    retried; any other exception propagates immediately.

    delays: seconds to wait before each retry. A delay of 0 retries without
        waiting.
    raise_: if the final attempt should raise the exception. Otherwise the
        wrapper logs a warning and returns None.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(cluster_handler, *args, **kwargs):
            exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError)
            try_count, errors = 1, []
            # The trailing None marks the final attempt: there is no further
            # retry to wait for once it fails.
            for delay in itertools.chain(delays, [None]):
                try:
                    return await func(cluster_handler, *args, **kwargs)
                except exceptions as ex:
                    errors.append(ex)
                    # Compare against the sentinel explicitly. A plain
                    # truthiness check would mistake a configured delay of 0
                    # for the final attempt and give up early.
                    if delay is None:
                        cluster_handler.warning(
                            "%s: all attempts have failed: %s", func.__name__, errors
                        )
                        if raise_:
                            raise
                        return None

                    jittered = uniform(delay * 0.75, delay * 1.25)
                    cluster_handler.debug(
                        "%s: retryable request #%d failed: %s. Retrying in %ss",
                        func.__name__,
                        try_count,
                        ex,
                        round(jittered, 1),
                    )
                    try_count += 1
                    await asyncio.sleep(jittered)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def convert_install_code(value: str) -> bytes:
|
||||
"""Convert string to install code bytes and validate length."""
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ def _wrap_mock_instance(obj: Any) -> MagicMock:
|
|||
real_attr = getattr(obj, attr_name)
|
||||
mock_attr = getattr(mock, attr_name)
|
||||
|
||||
if callable(real_attr):
|
||||
if callable(real_attr) and not hasattr(real_attr, "__aenter__"):
|
||||
mock_attr.side_effect = real_attr
|
||||
else:
|
||||
setattr(mock, attr_name, real_attr)
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"""Test ZHA Gateway."""
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from zigpy.application import ControllerApplication
|
||||
import zigpy.profiles.zha as zha
|
||||
import zigpy.types
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import zigpy.zcl.clusters.lighting as lighting
|
||||
import zigpy.zdo.types
|
||||
|
||||
from homeassistant.components.zha.core.gateway import ZHAGateway
|
||||
from homeassistant.components.zha.core.group import GroupMember
|
||||
|
@ -321,3 +323,81 @@ async def test_single_reload_on_multiple_connection_loss(
|
|||
assert len(mock_reload.mock_calls) == 1
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("radio_concurrency", [1, 2, 8])
async def test_startup_concurrency_limit(
    radio_concurrency: int,
    hass: HomeAssistant,
    zigpy_app_controller: ControllerApplication,
    config_entry: MockConfigEntry,
    zigpy_device_mock,
):
    """Test ZHA gateway limits concurrency on startup."""
    config_entry.add_to_hass(hass)
    zha_gateway = ZHAGateway(hass, {}, config_entry)

    with patch(
        "bellows.zigbee.application.ControllerApplication.new",
        return_value=zigpy_app_controller,
    ):
        await zha_gateway.async_initialize()

    # Register 50 mains powered lights so there are far more devices than
    # any of the parametrized concurrency limits.
    for i in range(50):
        zigpy_dev = zigpy_device_mock(
            {
                1: {
                    SIG_EP_INPUT: [
                        general.OnOff.cluster_id,
                        general.LevelControl.cluster_id,
                        lighting.Color.cluster_id,
                        general.Groups.cluster_id,
                    ],
                    SIG_EP_OUTPUT: [],
                    SIG_EP_TYPE: zha.DeviceType.COLOR_DIMMABLE_LIGHT,
                    SIG_EP_PROFILE: zha.PROFILE_ID,
                }
            },
            ieee=f"11:22:33:44:{i:08x}",
            nwk=0x1234 + i,
        )
        zigpy_dev.node_desc.mac_capability_flags |= (
            zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered
        )

        zha_gateway._async_get_or_create_device(zigpy_dev, restored=True)

    # Keep track of request concurrency during initialization
    current_concurrency = 0
    concurrencies = []

    async def mock_send_packet(*args, **kwargs):
        nonlocal current_concurrency

        current_concurrency += 1
        concurrencies.append(current_concurrency)

        await asyncio.sleep(0.001)

        current_concurrency -= 1
        concurrencies.append(current_concurrency)

    # Override the property only for the duration of the fetch. Assigning a
    # PropertyMock directly onto the class would never be undone and would
    # leak the fake concurrency into every later test using ZHAGateway.
    with patch.object(
        ZHAGateway,
        "radio_concurrency",
        new_callable=PropertyMock,
        return_value=radio_concurrency,
    ), patch(
        "homeassistant.components.zha.core.device.ZHADevice.async_initialize",
        side_effect=mock_send_packet,
    ):
        assert zha_gateway.radio_concurrency == radio_concurrency
        await zha_gateway.async_fetch_updated_state_mains()

    await zha_gateway.shutdown()

    # Make sure concurrency was always limited
    assert current_concurrency == 0
    assert min(concurrencies) == 0

    # The property patch is gone by now, so compare against the
    # parametrized value rather than the gateway's real property.
    if radio_concurrency > 1:
        assert 1 <= max(concurrencies) < radio_concurrency
    else:
        assert 1 == max(concurrencies) == radio_concurrency
|
||||
|
|
Loading…
Add table
Reference in a new issue