From 8b2759d810cb7f949caeb305ac6748e6fb8fa0d5 Mon Sep 17 00:00:00 2001 From: Robert Svensson Date: Sat, 9 Mar 2024 10:52:59 +0100 Subject: [PATCH] Move restoring inactive clients method into UniFi entity loader (#112805) * Move restoring inactive clients method into UniFi entity loader * Use an initialize method in entity_loader --- .../components/unifi/hub/entity_loader.py | 28 +++++++++++++++++++ homeassistant/components/unifi/hub/hub.py | 23 ++------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/unifi/hub/entity_loader.py b/homeassistant/components/unifi/hub/entity_loader.py index 17e31836790..940d4dbdcad 100644 --- a/homeassistant/components/unifi/hub/entity_loader.py +++ b/homeassistant/components/unifi/hub/entity_loader.py @@ -13,9 +13,12 @@ from typing import TYPE_CHECKING from aiounifi.interfaces.api_handlers import ItemEvent +from homeassistant.const import Platform from homeassistant.core import callback +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.entity_registry import async_entries_for_config_entry from ..const import LOGGER from ..entity import UnifiEntity, UnifiEntityDescription @@ -56,6 +59,11 @@ class UnifiEntityLoader: self.known_objects: set[tuple[str, str]] = set() """Tuples of entity description key and object ID of loaded entities.""" + async def initialize(self) -> None: + """Initialize API data and extra client support.""" + await self.refresh_api_data() + self.restore_inactive_clients() + async def refresh_api_data(self) -> None: """Refresh API data from network application.""" results = await asyncio.gather( @@ -66,6 +74,26 @@ class UnifiEntityLoader: if result is not None: LOGGER.warning("Exception on update %s", result) + @callback + def restore_inactive_clients(self) -> None: + """Restore inactive clients. + + Provide inactive clients to device tracker and switch platform. + """ + config = self.hub.config + macs: list[str] = [] + entity_registry = er.async_get(self.hub.hass) + for entry in async_entries_for_config_entry( + entity_registry, config.entry.entry_id + ): + if entry.domain == Platform.DEVICE_TRACKER and "-" in entry.unique_id: + macs.append(entry.unique_id.split("-", 1)[1]) + + api = self.hub.api + for mac in config.option_supported_clients + config.option_block_clients + macs: + if mac not in api.clients and mac in api.clients_all: + api.clients.process_raw([dict(api.clients_all[mac].raw)]) + @callback def register_platform( self, diff --git a/homeassistant/components/unifi/hub/hub.py b/homeassistant/components/unifi/hub/hub.py index f152c928659..b17e0d154a7 100644 --- a/homeassistant/components/unifi/hub/hub.py +++ b/homeassistant/components/unifi/hub/hub.py @@ -8,16 +8,14 @@ import aiounifi from aiounifi.models.device import DeviceSetPoePortModeRequest from homeassistant.config_entries import ConfigEntry -from homeassistant.const import Platform from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.device_registry import ( DeviceEntry, DeviceEntryType, DeviceInfo, ) from homeassistant.helpers.dispatcher import async_dispatcher_send -from homeassistant.helpers.entity_registry import async_entries_for_config_entry from homeassistant.helpers.event import async_call_later, async_track_time_interval import homeassistant.util.dt as dt_util @@ -88,28 +86,11 @@ class UnifiHub: async def initialize(self) -> None: """Set up a UniFi Network instance.""" - await self.entity_loader.refresh_api_data() + await self.entity_loader.initialize() assert self.config.entry.unique_id is not None self.is_admin = self.api.sites[self.config.entry.unique_id].role == "admin" - # Restore device tracker clients that are not a part of active clients list. - macs: list[str] = [] - entity_registry = er.async_get(self.hass) - for entry in async_entries_for_config_entry( - entity_registry, self.config.entry.entry_id - ): - if entry.domain == Platform.DEVICE_TRACKER and "-" in entry.unique_id: - macs.append(entry.unique_id.split("-", 1)[1]) - - for mac in ( - self.config.option_supported_clients - + self.config.option_block_clients - + macs - ): - if mac not in self.api.clients and mac in self.api.clients_all: - self.api.clients.process_raw([dict(self.api.clients_all[mac].raw)]) - self.wireless_clients.update_clients(set(self.api.clients.values())) self.config.entry.add_update_listener(self.async_config_entry_updated)