Make client tracker use common UniFi entity class (#84942)

* Make client tracker use common UniFi entity class

* Fix tests

* Fix mypy

* Remove legacy data

* Fix comment: skip else use return

* Minor change

* Remove missed stuff from previous rebase

* Import async_device_available_fn from entities.py rather than specifying it in device_tracker

* Avoid using asserts

* Keep explicit parenthesis for readability

* Allow loading entities on option changes
This commit is contained in:
Robert Svensson 2023-03-11 06:23:49 +01:00 committed by GitHub
parent d6a223f0e1
commit 288a4203ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 144 additions and 262 deletions

View file

@ -2,20 +2,21 @@
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import timedelta
import logging
from typing import Generic, TypeVar
from typing import Any, Generic
import aiounifi
from aiounifi.interfaces.api_handlers import ItemEvent
from aiounifi.interfaces.clients import Clients
from aiounifi.interfaces.devices import Devices
from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT
from aiounifi.models.client import Client
from aiounifi.models.device import Device
from aiounifi.models.event import EventKey
from aiounifi.models.event import Event, EventKey
from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType
from homeassistant.components.device_tracker import ScannerEntity, SourceType
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -24,8 +25,13 @@ import homeassistant.util.dt as dt_util
from .const import DOMAIN as UNIFI_DOMAIN
from .controller import UniFiController
from .entity import UnifiEntity, UnifiEntityDescription
from .unifi_client import UniFiClientBase
from .entity import (
DataT,
HandlerT,
UnifiEntity,
UnifiEntityDescription,
async_device_available_fn,
)
LOGGER = logging.getLogger(__name__)
@ -58,6 +64,7 @@ CLIENT_STATIC_ATTRIBUTES = [
CLIENT_CONNECTED_ALL_ATTRIBUTES = CLIENT_CONNECTED_ATTRIBUTES + CLIENT_STATIC_ATTRIBUTES
WIRED_CONNECTION = (EventKey.WIRED_CLIENT_CONNECTED,)
WIRED_DISCONNECTION = (EventKey.WIRED_CLIENT_DISCONNECTED,)
WIRELESS_CONNECTION = (
EventKey.WIRELESS_CLIENT_CONNECTED,
EventKey.WIRELESS_CLIENT_ROAM,
@ -66,17 +73,57 @@ WIRELESS_CONNECTION = (
EventKey.WIRELESS_GUEST_ROAM,
EventKey.WIRELESS_GUEST_ROAM_RADIO,
)
_DataT = TypeVar("_DataT", bound=Device)
_HandlerT = TypeVar("_HandlerT", bound=Devices)
WIRELESS_DISCONNECTION = (
EventKey.WIRELESS_CLIENT_DISCONNECTED,
EventKey.WIRELESS_GUEST_DISCONNECTED,
)
@callback
def async_device_available_fn(controller: UniFiController, obj_id: str) -> bool:
def async_client_allowed_fn(controller: UniFiController, obj_id: str) -> bool:
"""Check if client is allowed."""
if not controller.option_track_clients:
return False
client = controller.api.clients[obj_id]
if client.mac not in controller.wireless_clients:
if not controller.option_track_wired_clients:
return False
elif (
client.essid
and controller.option_ssid_filter
and client.essid not in controller.option_ssid_filter
):
return False
return True
@callback
def async_client_is_connected_fn(controller: UniFiController, obj_id: str) -> bool:
"""Check if device object is disabled."""
device = controller.api.devices[obj_id]
return controller.available and not device.disabled
client = controller.api.clients[obj_id]
if client.is_wired != (obj_id not in controller.wireless_clients):
if not controller.option_ignore_wired_bug:
return False # Wired bug in action
if (
not client.is_wired
and client.essid
and controller.option_ssid_filter
and client.essid not in controller.option_ssid_filter
):
return False
if (
dt_util.utcnow() - dt_util.utc_from_timestamp(client.last_seen or 0)
> controller.option_detection_time
):
return False
return True
@callback
@ -89,7 +136,7 @@ def async_device_heartbeat_timedelta_fn(
@dataclass
class UnifiEntityTrackerDescriptionMixin(Generic[_HandlerT, _DataT]):
class UnifiEntityTrackerDescriptionMixin(Generic[HandlerT, DataT]):
"""Device tracker local functions."""
heartbeat_timedelta_fn: Callable[[UniFiController, str], timedelta]
@ -100,13 +147,36 @@ class UnifiEntityTrackerDescriptionMixin(Generic[_HandlerT, _DataT]):
@dataclass
class UnifiTrackerEntityDescription(
UnifiEntityDescription[_HandlerT, _DataT],
UnifiEntityTrackerDescriptionMixin[_HandlerT, _DataT],
UnifiEntityDescription[HandlerT, DataT],
UnifiEntityTrackerDescriptionMixin[HandlerT, DataT],
):
"""Class describing UniFi device tracker entity."""
ENTITY_DESCRIPTIONS: tuple[UnifiTrackerEntityDescription, ...] = (
UnifiTrackerEntityDescription[Clients, Client](
key="Client device scanner",
has_entity_name=True,
allowed_fn=async_client_allowed_fn,
api_handler_fn=lambda api: api.clients,
available_fn=lambda controller, obj_id: controller.available,
device_info_fn=lambda api, obj_id: None,
event_is_on=(WIRED_CONNECTION + WIRELESS_CONNECTION),
event_to_subscribe=(
WIRED_CONNECTION
+ WIRED_DISCONNECTION
+ WIRELESS_CONNECTION
+ WIRELESS_DISCONNECTION
),
heartbeat_timedelta_fn=lambda controller, _: controller.option_detection_time,
is_connected_fn=async_client_is_connected_fn,
name_fn=lambda client: client.name or client.hostname,
object_fn=lambda api, obj_id: api.clients[obj_id],
supported_fn=lambda controller, obj_id: True,
unique_id_fn=lambda controller, obj_id: f"{obj_id}-{controller.site}",
ip_address_fn=lambda api, obj_id: api.clients[obj_id].ip,
hostname_fn=lambda api, obj_id: None,
),
UnifiTrackerEntityDescription[Devices, Device](
key="Device scanner",
has_entity_name=True,
@ -140,239 +210,13 @@ async def async_setup_entry(
UnifiScannerEntity, ENTITY_DESCRIPTIONS, async_add_entities
)
controller.entities[DOMAIN] = {CLIENT_TRACKER: set(), DEVICE_TRACKER: set()}
@callback
def items_added(
clients: set = controller.api.clients, devices: set = controller.api.devices
) -> None:
"""Update the values of the controller."""
if controller.option_track_clients:
add_client_entities(controller, async_add_entities, clients)
for signal in (controller.signal_update, controller.signal_options_update):
config_entry.async_on_unload(
async_dispatcher_connect(hass, signal, items_added)
)
items_added()
@callback
def add_client_entities(controller, async_add_entities, clients):
"""Add new client tracker entities from the controller."""
trackers = []
for mac in clients:
if mac in controller.entities[DOMAIN][UniFiClientTracker.TYPE] or not (
client := controller.api.clients.get(mac)
):
continue
if mac not in controller.wireless_clients:
if not controller.option_track_wired_clients:
continue
elif (
client.essid
and controller.option_ssid_filter
and client.essid not in controller.option_ssid_filter
):
continue
trackers.append(UniFiClientTracker(client, controller))
async_add_entities(trackers)
class UniFiClientTracker(UniFiClientBase, ScannerEntity):
"""Representation of a network client."""
DOMAIN = DOMAIN
TYPE = CLIENT_TRACKER
def __init__(self, client, controller):
"""Set up tracked client."""
super().__init__(client, controller)
self._controller_connection_state_changed = False
self._only_listen_to_data_source = False
last_seen = client.last_seen or 0
self.schedule_update = self._is_connected = (
self.is_wired == client.is_wired
and dt_util.utcnow() - dt_util.utc_from_timestamp(float(last_seen))
< controller.option_detection_time
)
@callback
def _async_log_debug_data(self, method: str) -> None:
"""Print debug data about entity."""
if not LOGGER.isEnabledFor(logging.DEBUG):
return
last_seen = self.client.last_seen or 0
LOGGER.debug(
"%s [%s, %s] [%s %s] [%s] %s (%s)",
method,
self.entity_id,
self.client.mac,
self.schedule_update,
self._is_connected,
dt_util.utc_from_timestamp(float(last_seen)),
dt_util.utcnow() - dt_util.utc_from_timestamp(float(last_seen)),
last_seen,
)
async def async_added_to_hass(self) -> None:
"""Watch object when added."""
self.async_on_remove(
async_dispatcher_connect(
self.hass,
f"{self.controller.signal_heartbeat_missed}_{self.unique_id}",
self._make_disconnected,
)
)
await super().async_added_to_hass()
self._async_log_debug_data("added_to_hass")
async def async_will_remove_from_hass(self) -> None:
"""Disconnect object when removed."""
self.controller.async_heartbeat(self.unique_id)
await super().async_will_remove_from_hass()
@callback
def async_signal_reachable_callback(self) -> None:
"""Call when controller connection state change."""
self._controller_connection_state_changed = True
super().async_signal_reachable_callback()
@callback
def async_update_callback(self) -> None:
"""Update the clients state."""
if self._controller_connection_state_changed:
self._controller_connection_state_changed = False
if self.controller.available:
self.schedule_update = True
else:
self.controller.async_heartbeat(self.unique_id)
super().async_update_callback()
elif (
self.client.last_updated == SOURCE_DATA
and self.is_wired == self.client.is_wired
):
self._is_connected = True
self.schedule_update = True
self._only_listen_to_data_source = True
elif (
self.client.last_updated == SOURCE_EVENT
and not self._only_listen_to_data_source
):
if (self.is_wired and self.client.event.key in WIRED_CONNECTION) or (
not self.is_wired and self.client.event.key in WIRELESS_CONNECTION
):
self._is_connected = True
self.schedule_update = False
self.controller.async_heartbeat(self.unique_id)
super().async_update_callback()
else:
self.schedule_update = True
self._async_log_debug_data("update_callback")
if self.schedule_update:
self.schedule_update = False
self.controller.async_heartbeat(
self.unique_id, dt_util.utcnow() + self.controller.option_detection_time
)
super().async_update_callback()
@callback
def _make_disconnected(self, *_):
"""No heart beat by device."""
self._is_connected = False
self.async_write_ha_state()
self._async_log_debug_data("make_disconnected")
@property
def is_connected(self):
"""Return true if the client is connected to the network."""
if (
not self.is_wired
and self.client.essid
and self.controller.option_ssid_filter
and self.client.essid not in self.controller.option_ssid_filter
):
return False
return self._is_connected
@property
def source_type(self) -> SourceType:
"""Return the source type of the client."""
return SourceType.ROUTER
@property
def unique_id(self) -> str:
"""Return a unique identifier for this client."""
return f"{self.client.mac}-{self.controller.site}"
@property
def extra_state_attributes(self):
"""Return the client state attributes."""
raw = self.client.raw
attributes_to_check = CLIENT_STATIC_ATTRIBUTES
if self.is_connected:
attributes_to_check = CLIENT_CONNECTED_ALL_ATTRIBUTES
attributes = {k: raw[k] for k in attributes_to_check if k in raw}
attributes["is_wired"] = self.is_wired
return attributes
@property
def ip_address(self) -> str:
"""Return the primary ip address of the device."""
return self.client.raw.get("ip")
@property
def mac_address(self) -> str:
"""Return the mac address of the device."""
return self.client.raw.get("mac")
@property
def hostname(self) -> str:
"""Return hostname of the device."""
return self.client.raw.get("hostname")
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_track_clients:
await self.remove_item({self.client.mac})
elif self.is_wired:
if not self.controller.option_track_wired_clients:
await self.remove_item({self.client.mac})
elif (
self.controller.option_ssid_filter
and self.client.essid not in self.controller.option_ssid_filter
):
await self.remove_item({self.client.mac})
class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity):
class UnifiScannerEntity(UnifiEntity[HandlerT, DataT], ScannerEntity):
"""Representation of a UniFi scanner."""
entity_description: UnifiTrackerEntityDescription
_event_is_on: tuple[EventKey, ...]
_ignore_events: bool
_is_connected: bool
@ -383,8 +227,15 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity):
Initiate is_connected.
"""
description = self.entity_description
self._event_is_on = description.event_is_on or ()
self._ignore_events = False
self._is_connected = description.is_connected_fn(self.controller, self._obj_id)
if self.is_connected:
self.controller.async_heartbeat(
self.unique_id,
dt_util.utcnow()
+ description.heartbeat_timedelta_fn(self.controller, self._obj_id),
)
@property
def is_connected(self) -> bool:
@ -452,13 +303,33 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity):
+ description.heartbeat_timedelta_fn(self.controller, self._obj_id),
)
@callback
def async_event_callback(self, event: Event) -> None:
"""Event subscription callback."""
if event.mac != self._obj_id or self._ignore_events:
return
if event.key in self._event_is_on:
self.controller.async_heartbeat(self.unique_id)
self._is_connected = True
self.async_write_ha_state()
return
self.controller.async_heartbeat(
self.unique_id,
dt_util.utcnow()
+ self.entity_description.heartbeat_timedelta_fn(
self.controller, self._obj_id
),
)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
self.async_on_remove(
async_dispatcher_connect(
self.hass,
f"{self.controller.signal_heartbeat_missed}_{self._obj_id}",
f"{self.controller.signal_heartbeat_missed}_{self.unique_id}",
self._make_disconnected,
)
)
@ -467,3 +338,20 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity):
"""Disconnect object when removed."""
await super().async_will_remove_from_hass()
self.controller.async_heartbeat(self.unique_id)
@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return the client state attributes."""
if self.entity_description.key != "Client device scanner":
return None
client = self.entity_description.object_fn(self.controller.api, self._obj_id)
raw = client.raw
attributes_to_check = CLIENT_STATIC_ATTRIBUTES
if self.is_connected:
attributes_to_check = CLIENT_CONNECTED_ALL_ATTRIBUTES
attributes = {k: raw[k] for k in attributes_to_check if k in raw}
return attributes

View file

@ -156,7 +156,7 @@ async def test_tracked_clients(
# State change signalling works
client_1["last_seen"] += 1
client_1["last_seen"] = dt_util.as_timestamp(dt_util.utcnow())
mock_unifi_websocket(message=MessageKey.CLIENT, data=client_1)
await hass.async_block_till_done()
@ -244,6 +244,7 @@ async def test_tracked_wireless_clients_event_source(
# New data
client["last_seen"] = dt_util.as_timestamp(dt_util.utcnow())
mock_unifi_websocket(message=MessageKey.CLIENT, data=client)
await hass.async_block_till_done()
assert hass.states.get("device_tracker.client").state == STATE_HOME
@ -703,6 +704,11 @@ async def test_option_ssid_filter(
mock_unifi_websocket(message=MessageKey.CLIENT, data=client_on_ssid2)
await hass.async_block_till_done()
new_time = dt_util.utcnow() + controller.option_detection_time
with patch("homeassistant.util.dt.utcnow", return_value=new_time):
async_fire_time_changed(hass, new_time)
await hass.async_block_till_done()
# SSID filter marks client as away
assert hass.states.get("device_tracker.client").state == STATE_NOT_HOME
@ -726,7 +732,7 @@ async def test_option_ssid_filter(
# Time pass to mark client as away
new_time = dt_util.utcnow() + controller.option_detection_time
new_time += controller.option_detection_time
with patch("homeassistant.util.dt.utcnow", return_value=new_time):
async_fire_time_changed(hass, new_time)
await hass.async_block_till_done()
@ -745,9 +751,7 @@ async def test_option_ssid_filter(
mock_unifi_websocket(message=MessageKey.CLIENT, data=client_on_ssid2)
await hass.async_block_till_done()
new_time = (
dt_util.utcnow() + controller.option_detection_time + timedelta(seconds=1)
)
new_time += controller.option_detection_time
with patch("homeassistant.util.dt.utcnow", return_value=new_time):
async_fire_time_changed(hass, new_time)
await hass.async_block_till_done()
@ -784,10 +788,9 @@ async def test_wireless_client_go_wired_issue(
# Client is wireless
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is False
# Trigger wired bug
client["last_seen"] += 1
client["last_seen"] = dt_util.as_timestamp(dt_util.utcnow())
client["is_wired"] = True
mock_unifi_websocket(message=MessageKey.CLIENT, data=client)
await hass.async_block_till_done()
@ -795,7 +798,6 @@ async def test_wireless_client_go_wired_issue(
# Wired bug fix keeps client marked as wireless
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is False
# Pass time
new_time = dt_util.utcnow() + controller.option_detection_time
@ -806,7 +808,6 @@ async def test_wireless_client_go_wired_issue(
# Marked as home according to the timer
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_NOT_HOME
assert client_state.attributes["is_wired"] is False
# Try to mark client as connected
client["last_seen"] += 1
@ -816,7 +817,6 @@ async def test_wireless_client_go_wired_issue(
# Make sure it don't go online again until wired bug disappears
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_NOT_HOME
assert client_state.attributes["is_wired"] is False
# Make client wireless
client["last_seen"] += 1
@ -827,7 +827,6 @@ async def test_wireless_client_go_wired_issue(
# Client is no longer affected by wired bug and can be marked online
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is False
async def test_option_ignore_wired_bug(
@ -859,7 +858,6 @@ async def test_option_ignore_wired_bug(
# Client is wireless
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is False
# Trigger wired bug
client["is_wired"] = True
@ -869,7 +867,6 @@ async def test_option_ignore_wired_bug(
# Wired bug in effect
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is True
# pass time
new_time = dt_util.utcnow() + controller.option_detection_time
@ -880,7 +877,6 @@ async def test_option_ignore_wired_bug(
# Timer marks client as away
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_NOT_HOME
assert client_state.attributes["is_wired"] is True
# Mark client as connected again
client["last_seen"] += 1
@ -890,7 +886,6 @@ async def test_option_ignore_wired_bug(
# Ignoring wired bug allows client to go home again even while affected
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is True
# Make client wireless
client["last_seen"] += 1
@ -901,7 +896,6 @@ async def test_option_ignore_wired_bug(
# Client is wireless and still connected
client_state = hass.states.get("device_tracker.client")
assert client_state.state == STATE_HOME
assert client_state.attributes["is_wired"] is False
async def test_restoring_client(