Automatically remove unregistered TP-Link Omada devices at start up (#124153)

* Adding coordinator for omada device list

* Remove dead omada devices at startup

* Tidy up tests

* Address PR feedback

* Returned to use of read-only properties for coordinators. Tidied up parameters some more

* Update homeassistant/components/tplink_omada/controller.py

* Update homeassistant/components/tplink_omada/controller.py

* Update homeassistant/components/tplink_omada/controller.py

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
MarkGodwin 2024-09-22 16:05:29 +01:00 committed by GitHub
parent 8158ca7c69
commit 2a36ec3e21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 164 additions and 51 deletions

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from tplink_omada_client import OmadaSite
from tplink_omada_client.devices import OmadaListDevice
from tplink_omada_client.exceptions import (
ConnectionFailed,
LoginFailed,
@ -14,6 +15,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr
from .config_flow import CONF_SITE, create_omada_client
from .const import DOMAIN
@ -52,13 +54,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
site_client = await client.get_site_client(OmadaSite("", entry.data[CONF_SITE]))
controller = OmadaSiteController(hass, site_client)
gateway_coordinator = await controller.get_gateway_coordinator()
if gateway_coordinator:
await gateway_coordinator.async_config_entry_first_refresh()
await controller.get_clients_coordinator().async_config_entry_first_refresh()
await controller.initialize_first_refresh()
hass.data[DOMAIN][entry.entry_id] = controller
_remove_old_devices(hass, entry, controller.devices_coordinator.data)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
@ -70,3 +71,20 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
def _remove_old_devices(
hass: HomeAssistant, entry: ConfigEntry, omada_devices: dict[str, OmadaListDevice]
) -> None:
device_registry = dr.async_get(hass)
for registered_device in device_registry.devices.get_devices_for_config_entry_id(
entry.entry_id
):
mac = next(
(i[1] for i in registered_device.identifiers if i[0] == DOMAIN), None
)
if mac and mac not in omada_devices:
device_registry.async_update_device(
registered_device.id, remove_config_entry_id=entry.entry_id
)

View file

@ -34,7 +34,7 @@ async def async_setup_entry(
"""Set up binary sensors."""
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
gateway_coordinator = await controller.get_gateway_coordinator()
gateway_coordinator = controller.gateway_coordinator
if not gateway_coordinator:
return

View file

@ -7,6 +7,7 @@ from homeassistant.core import HomeAssistant
from .coordinator import (
OmadaClientsCoordinator,
OmadaDevicesCoordinator,
OmadaGatewayCoordinator,
OmadaSwitchPortCoordinator,
)
@ -16,15 +17,33 @@ class OmadaSiteController:
"""Controller for the Omada SDN site."""
_gateway_coordinator: OmadaGatewayCoordinator | None = None
_initialized_gateway_coordinator = False
_clients_coordinator: OmadaClientsCoordinator | None = None
def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None:
def __init__(
self,
hass: HomeAssistant,
omada_client: OmadaSiteClient,
) -> None:
"""Create the controller."""
self._hass = hass
self._omada_client = omada_client
self._switch_port_coordinators: dict[str, OmadaSwitchPortCoordinator] = {}
self._devices_coordinator = OmadaDevicesCoordinator(hass, omada_client)
self._clients_coordinator = OmadaClientsCoordinator(hass, omada_client)
async def initialize_first_refresh(self) -> None:
"""Initialize the all coordinators, and perform first refresh."""
await self._devices_coordinator.async_config_entry_first_refresh()
devices = self._devices_coordinator.data.values()
gateway = next((d for d in devices if d.type == "gateway"), None)
if gateway:
self._gateway_coordinator = OmadaGatewayCoordinator(
self._hass, self._omada_client, gateway.mac
)
await self._gateway_coordinator.async_config_entry_first_refresh()
await self.clients_coordinator.async_config_entry_first_refresh()
@property
def omada_client(self) -> OmadaSiteClient:
@ -42,26 +61,17 @@ class OmadaSiteController:
return self._switch_port_coordinators[switch.mac]
async def get_gateway_coordinator(self) -> OmadaGatewayCoordinator | None:
"""Get coordinator for site's gateway, or None if there is no gateway."""
if not self._initialized_gateway_coordinator:
self._initialized_gateway_coordinator = True
devices = await self._omada_client.get_devices()
gateway = next((d for d in devices if d.type == "gateway"), None)
if not gateway:
return None
self._gateway_coordinator = OmadaGatewayCoordinator(
self._hass, self._omada_client, gateway.mac
)
@property
def gateway_coordinator(self) -> OmadaGatewayCoordinator | None:
"""Gets the coordinator for site's gateway, or None if there is no gateway."""
return self._gateway_coordinator
def get_clients_coordinator(self) -> OmadaClientsCoordinator:
"""Get coordinator for site's clients."""
if not self._clients_coordinator:
self._clients_coordinator = OmadaClientsCoordinator(
self._hass, self._omada_client
)
@property
def devices_coordinator(self) -> OmadaDevicesCoordinator:
"""Gets the coordinator for site's devices."""
return self._devices_coordinator
@property
def clients_coordinator(self) -> OmadaClientsCoordinator:
"""Gets the coordinator for site's clients."""
return self._clients_coordinator

View file

@ -6,7 +6,7 @@ import logging
from tplink_omada_client import OmadaSiteClient, OmadaSwitchPortDetails
from tplink_omada_client.clients import OmadaWirelessClient
from tplink_omada_client.devices import OmadaGateway, OmadaSwitch
from tplink_omada_client.devices import OmadaGateway, OmadaListDevice, OmadaSwitch
from tplink_omada_client.exceptions import OmadaClientException
from homeassistant.core import HomeAssistant
@ -17,6 +17,7 @@ _LOGGER = logging.getLogger(__name__)
POLL_SWITCH_PORT = 300
POLL_GATEWAY = 300
POLL_CLIENTS = 300
POLL_DEVICES = 900
class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]):
@ -27,14 +28,14 @@ class OmadaCoordinator[_T](DataUpdateCoordinator[dict[str, _T]]):
hass: HomeAssistant,
omada_client: OmadaSiteClient,
name: str,
poll_delay: int = 300,
poll_delay: int | None = 300,
) -> None:
"""Initialize my coordinator."""
super().__init__(
hass,
_LOGGER,
name=f"Omada API Data - {name}",
update_interval=timedelta(seconds=poll_delay),
update_interval=timedelta(seconds=poll_delay) if poll_delay else None,
)
self.omada_client = omada_client
@ -91,6 +92,22 @@ class OmadaGatewayCoordinator(OmadaCoordinator[OmadaGateway]):
return {self.mac: gateway}
class OmadaDevicesCoordinator(OmadaCoordinator[OmadaListDevice]):
"""Coordinator for generic device lists from the controller."""
def __init__(
self,
hass: HomeAssistant,
omada_client: OmadaSiteClient,
) -> None:
"""Initialize my coordinator."""
super().__init__(hass, omada_client, "DeviceList", POLL_CLIENTS)
async def poll_update(self) -> dict[str, OmadaListDevice]:
"""Poll the site's current registered Omada devices."""
return {d.mac: d for d in await self.omada_client.get_devices()}
class OmadaClientsCoordinator(OmadaCoordinator[OmadaWirelessClient]):
"""Coordinator for getting details about the site's connected clients."""

View file

@ -26,7 +26,6 @@ async def async_setup_entry(
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
clients_coordinator = controller.get_clients_coordinator()
site_id = config_entry.data[CONF_SITE]
# Add all known WiFi devices as potentially tracked devices. They will only be
@ -34,7 +33,7 @@ async def async_setup_entry(
async_add_entities(
[
OmadaClientScannerEntity(
site_id, client.mac, client.name, clients_coordinator
site_id, client.mac, client.name, controller.clients_coordinator
)
async for client in controller.omada_client.get_known_clients()
if isinstance(client, OmadaWirelessClient)

View file

@ -5,7 +5,6 @@ from typing import Any
from tplink_omada_client.devices import OmadaDevice
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN
@ -19,7 +18,7 @@ class OmadaDeviceEntity[_T: OmadaCoordinator[Any]](CoordinatorEntity[_T]):
"""Initialize the device."""
super().__init__(coordinator)
self.device = device
self._attr_device_info = DeviceInfo(
self._attr_device_info = dr.DeviceInfo(
connections={(dr.CONNECTION_NETWORK_MAC, device.mac)},
identifiers={(DOMAIN, device.mac)},
manufacturer="TP-Link",

View file

@ -74,7 +74,7 @@ async def async_setup_entry(
if desc.exists_func(switch, port)
)
gateway_coordinator = await controller.get_gateway_coordinator()
gateway_coordinator = controller.gateway_coordinator
if gateway_coordinator:
for gateway in gateway_coordinator.data.values():
entities.extend(

View file

@ -21,10 +21,9 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .controller import OmadaSiteController
from .coordinator import OmadaCoordinator
from .coordinator import POLL_DEVICES, OmadaCoordinator, OmadaDevicesCoordinator
from .entity import OmadaDeviceEntity
POLL_DELAY_IDLE = 6 * 60 * 60
POLL_DELAY_UPGRADE = 60
@ -35,15 +34,28 @@ class FirmwareUpdateStatus(NamedTuple):
firmware: OmadaFirmwareUpdate | None
class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module
"""Coordinator for getting details about ports on a switch."""
class OmadaFirmwareUpdateCoordinator(OmadaCoordinator[FirmwareUpdateStatus]): # pylint: disable=hass-enforce-class-module
"""Coordinator for getting details about available firmware updates for Omada devices."""
def __init__(self, hass: HomeAssistant, omada_client: OmadaSiteClient) -> None:
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
omada_client: OmadaSiteClient,
devices_coordinator: OmadaDevicesCoordinator,
) -> None:
"""Initialize my coordinator."""
super().__init__(hass, omada_client, "Firmware Updates", POLL_DELAY_IDLE)
super().__init__(hass, omada_client, "Firmware Updates", poll_delay=None)
self._devices_coordinator = devices_coordinator
self._config_entry = config_entry
config_entry.async_on_unload(
devices_coordinator.async_add_listener(self._handle_devices_update)
)
async def _get_firmware_updates(self) -> list[FirmwareUpdateStatus]:
devices = await self.omada_client.get_devices()
devices = self._devices_coordinator.data.values()
updates = [
FirmwareUpdateStatus(
@ -55,12 +67,12 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): #
for d in devices
]
# During a firmware upgrade, poll more frequently
self.update_interval = timedelta(
# During a firmware upgrade, poll device list more frequently
self._devices_coordinator.update_interval = timedelta(
seconds=(
POLL_DELAY_UPGRADE
if any(u.device.fw_download for u in updates)
else POLL_DELAY_IDLE
else POLL_DEVICES
)
)
return updates
@ -69,6 +81,14 @@ class OmadaFirmwareUpdateCoodinator(OmadaCoordinator[FirmwareUpdateStatus]): #
"""Poll the state of Omada Devices firmware update availability."""
return {d.device.mac: d for d in await self._get_firmware_updates()}
@callback
def _handle_devices_update(self) -> None:
"""Handle updated data from the devices coordinator."""
# Trigger a refresh of our data, based on the updated device list
self._config_entry.async_create_background_task(
self.hass, self.async_request_refresh(), "Omada Firmware Update Refresh"
)
async def async_setup_entry(
hass: HomeAssistant,
@ -77,18 +97,21 @@ async def async_setup_entry(
) -> None:
"""Set up switches."""
controller: OmadaSiteController = hass.data[DOMAIN][config_entry.entry_id]
omada_client = controller.omada_client
devices = await omada_client.get_devices()
devices = controller.devices_coordinator.data
coordinator = OmadaFirmwareUpdateCoodinator(hass, omada_client)
coordinator = OmadaFirmwareUpdateCoordinator(
hass, config_entry, controller.omada_client, controller.devices_coordinator
)
async_add_entities(OmadaDeviceUpdate(coordinator, device) for device in devices)
async_add_entities(
OmadaDeviceUpdate(coordinator, device) for device in devices.values()
)
await coordinator.async_request_refresh()
class OmadaDeviceUpdate(
OmadaDeviceEntity[OmadaFirmwareUpdateCoodinator],
OmadaDeviceEntity[OmadaFirmwareUpdateCoordinator],
UpdateEntity,
):
"""Firmware update status for Omada SDN devices."""
@ -103,7 +126,7 @@ class OmadaDeviceUpdate(
def __init__(
self,
coordinator: OmadaFirmwareUpdateCoodinator,
coordinator: OmadaFirmwareUpdateCoordinator,
device: OmadaListDevice,
) -> None:
"""Initialize the update entity."""

View file

@ -0,0 +1,47 @@
"""Tests for TP-Link Omada integration init."""
from unittest.mock import MagicMock
from homeassistant.components.tplink_omada.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from tests.common import MockConfigEntry
MOCK_ENTRY_DATA = {
"host": "https://fake.omada.host",
"verify_ssl": True,
"site": "SiteId",
"username": "test-username",
"password": "test-password",
}
async def test_missing_devices_removed_at_startup(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_omada_client: MagicMock,
) -> None:
"""Test missing devices are removed at startup."""
mock_config_entry = MockConfigEntry(
title="Test Omada Controller",
domain=DOMAIN,
data=dict(MOCK_ENTRY_DATA),
unique_id="12345",
)
mock_config_entry.add_to_hass(hass)
device_entry = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, "AA:BB:CC:DD:EE:FF")},
manufacturer="TPLink",
name="Old Device",
model="Some old model",
)
assert device_registry.async_get(device_entry.id) == device_entry
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert device_registry.async_get(device_entry.id) is None