diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index c845b6d5d38..a5b153d7f36 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -2,20 +2,21 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass from datetime import timedelta import logging -from typing import Generic, TypeVar +from typing import Any, Generic import aiounifi from aiounifi.interfaces.api_handlers import ItemEvent +from aiounifi.interfaces.clients import Clients from aiounifi.interfaces.devices import Devices -from aiounifi.models.api import SOURCE_DATA, SOURCE_EVENT +from aiounifi.models.client import Client from aiounifi.models.device import Device -from aiounifi.models.event import EventKey +from aiounifi.models.event import Event, EventKey -from homeassistant.components.device_tracker import DOMAIN, ScannerEntity, SourceType +from homeassistant.components.device_tracker import ScannerEntity, SourceType from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -24,8 +25,13 @@ import homeassistant.util.dt as dt_util from .const import DOMAIN as UNIFI_DOMAIN from .controller import UniFiController -from .entity import UnifiEntity, UnifiEntityDescription -from .unifi_client import UniFiClientBase +from .entity import ( + DataT, + HandlerT, + UnifiEntity, + UnifiEntityDescription, + async_device_available_fn, +) LOGGER = logging.getLogger(__name__) @@ -58,6 +64,7 @@ CLIENT_STATIC_ATTRIBUTES = [ CLIENT_CONNECTED_ALL_ATTRIBUTES = CLIENT_CONNECTED_ATTRIBUTES + CLIENT_STATIC_ATTRIBUTES WIRED_CONNECTION = (EventKey.WIRED_CLIENT_CONNECTED,) +WIRED_DISCONNECTION = (EventKey.WIRED_CLIENT_DISCONNECTED,) WIRELESS_CONNECTION = ( EventKey.WIRELESS_CLIENT_CONNECTED, EventKey.WIRELESS_CLIENT_ROAM, @@ -66,17 +73,57 @@ WIRELESS_CONNECTION = ( EventKey.WIRELESS_GUEST_ROAM, EventKey.WIRELESS_GUEST_ROAM_RADIO, ) - - -_DataT = TypeVar("_DataT", bound=Device) -_HandlerT = TypeVar("_HandlerT", bound=Devices) +WIRELESS_DISCONNECTION = ( + EventKey.WIRELESS_CLIENT_DISCONNECTED, + EventKey.WIRELESS_GUEST_DISCONNECTED, +) @callback -def async_device_available_fn(controller: UniFiController, obj_id: str) -> bool: +def async_client_allowed_fn(controller: UniFiController, obj_id: str) -> bool: + """Check if client is allowed.""" + if not controller.option_track_clients: + return False + + client = controller.api.clients[obj_id] + if client.mac not in controller.wireless_clients: + if not controller.option_track_wired_clients: + return False + + elif ( + client.essid + and controller.option_ssid_filter + and client.essid not in controller.option_ssid_filter + ): + return False + + return True + + +@callback +def async_client_is_connected_fn(controller: UniFiController, obj_id: str) -> bool: """Check if device object is disabled.""" - device = controller.api.devices[obj_id] - return controller.available and not device.disabled + client = controller.api.clients[obj_id] + + if client.is_wired != (obj_id not in controller.wireless_clients): + if not controller.option_ignore_wired_bug: + return False # Wired bug in action + + if ( + not client.is_wired + and client.essid + and controller.option_ssid_filter + and client.essid not in controller.option_ssid_filter + ): + return False + + if ( + dt_util.utcnow() - dt_util.utc_from_timestamp(client.last_seen or 0) + > controller.option_detection_time + ): + return False + + return True @callback @@ -89,7 +136,7 @@ def async_device_heartbeat_timedelta_fn( @dataclass -class UnifiEntityTrackerDescriptionMixin(Generic[_HandlerT, _DataT]): +class UnifiEntityTrackerDescriptionMixin(Generic[HandlerT, DataT]): """Device tracker local functions.""" heartbeat_timedelta_fn: Callable[[UniFiController, str], timedelta] @@ -100,13 +147,36 @@ class UnifiEntityTrackerDescriptionMixin(Generic[_HandlerT, _DataT]): @dataclass class UnifiTrackerEntityDescription( - UnifiEntityDescription[_HandlerT, _DataT], - UnifiEntityTrackerDescriptionMixin[_HandlerT, _DataT], + UnifiEntityDescription[HandlerT, DataT], + UnifiEntityTrackerDescriptionMixin[HandlerT, DataT], ): """Class describing UniFi device tracker entity.""" ENTITY_DESCRIPTIONS: tuple[UnifiTrackerEntityDescription, ...] = ( + UnifiTrackerEntityDescription[Clients, Client]( + key="Client device scanner", + has_entity_name=True, + allowed_fn=async_client_allowed_fn, + api_handler_fn=lambda api: api.clients, + available_fn=lambda controller, obj_id: controller.available, + device_info_fn=lambda api, obj_id: None, + event_is_on=(WIRED_CONNECTION + WIRELESS_CONNECTION), + event_to_subscribe=( + WIRED_CONNECTION + + WIRED_DISCONNECTION + + WIRELESS_CONNECTION + + WIRELESS_DISCONNECTION + ), + heartbeat_timedelta_fn=lambda controller, _: controller.option_detection_time, + is_connected_fn=async_client_is_connected_fn, + name_fn=lambda client: client.name or client.hostname, + object_fn=lambda api, obj_id: api.clients[obj_id], + supported_fn=lambda controller, obj_id: True, + unique_id_fn=lambda controller, obj_id: f"{obj_id}-{controller.site}", + ip_address_fn=lambda api, obj_id: api.clients[obj_id].ip, + hostname_fn=lambda api, obj_id: None, + ), UnifiTrackerEntityDescription[Devices, Device]( key="Device scanner", has_entity_name=True, @@ -140,239 +210,13 @@ async def async_setup_entry( UnifiScannerEntity, ENTITY_DESCRIPTIONS, async_add_entities ) - controller.entities[DOMAIN] = {CLIENT_TRACKER: set(), DEVICE_TRACKER: set()} - @callback - def items_added( - clients: set = controller.api.clients, devices: set = controller.api.devices - ) -> None: - """Update the values of the controller.""" - if controller.option_track_clients: - add_client_entities(controller, async_add_entities, clients) - - for signal in (controller.signal_update, controller.signal_options_update): - config_entry.async_on_unload( - async_dispatcher_connect(hass, signal, items_added) - ) - - items_added() - - -@callback -def add_client_entities(controller, async_add_entities, clients): - """Add new client tracker entities from the controller.""" - trackers = [] - - for mac in clients: - if mac in controller.entities[DOMAIN][UniFiClientTracker.TYPE] or not ( - client := controller.api.clients.get(mac) - ): - continue - - if mac not in controller.wireless_clients: - if not controller.option_track_wired_clients: - continue - elif ( - client.essid - and controller.option_ssid_filter - and client.essid not in controller.option_ssid_filter - ): - continue - - trackers.append(UniFiClientTracker(client, controller)) - - async_add_entities(trackers) - - -class UniFiClientTracker(UniFiClientBase, ScannerEntity): - """Representation of a network client.""" - - DOMAIN = DOMAIN - TYPE = CLIENT_TRACKER - - def __init__(self, client, controller): - """Set up tracked client.""" - super().__init__(client, controller) - - self._controller_connection_state_changed = False - - self._only_listen_to_data_source = False - - last_seen = client.last_seen or 0 - self.schedule_update = self._is_connected = ( - self.is_wired == client.is_wired - and dt_util.utcnow() - dt_util.utc_from_timestamp(float(last_seen)) - < controller.option_detection_time - ) - - @callback - def _async_log_debug_data(self, method: str) -> None: - """Print debug data about entity.""" - if not LOGGER.isEnabledFor(logging.DEBUG): - return - last_seen = self.client.last_seen or 0 - LOGGER.debug( - "%s [%s, %s] [%s %s] [%s] %s (%s)", - method, - self.entity_id, - self.client.mac, - self.schedule_update, - self._is_connected, - dt_util.utc_from_timestamp(float(last_seen)), - dt_util.utcnow() - dt_util.utc_from_timestamp(float(last_seen)), - last_seen, - ) - - async def async_added_to_hass(self) -> None: - """Watch object when added.""" - self.async_on_remove( - async_dispatcher_connect( - self.hass, - f"{self.controller.signal_heartbeat_missed}_{self.unique_id}", - self._make_disconnected, - ) - ) - await super().async_added_to_hass() - self._async_log_debug_data("added_to_hass") - - async def async_will_remove_from_hass(self) -> None: - """Disconnect object when removed.""" - self.controller.async_heartbeat(self.unique_id) - await super().async_will_remove_from_hass() - - @callback - def async_signal_reachable_callback(self) -> None: - """Call when controller connection state change.""" - self._controller_connection_state_changed = True - super().async_signal_reachable_callback() - - @callback - def async_update_callback(self) -> None: - """Update the clients state.""" - - if self._controller_connection_state_changed: - self._controller_connection_state_changed = False - - if self.controller.available: - self.schedule_update = True - - else: - self.controller.async_heartbeat(self.unique_id) - super().async_update_callback() - - elif ( - self.client.last_updated == SOURCE_DATA - and self.is_wired == self.client.is_wired - ): - self._is_connected = True - self.schedule_update = True - self._only_listen_to_data_source = True - - elif ( - self.client.last_updated == SOURCE_EVENT - and not self._only_listen_to_data_source - ): - if (self.is_wired and self.client.event.key in WIRED_CONNECTION) or ( - not self.is_wired and self.client.event.key in WIRELESS_CONNECTION - ): - self._is_connected = True - self.schedule_update = False - self.controller.async_heartbeat(self.unique_id) - super().async_update_callback() - - else: - self.schedule_update = True - - self._async_log_debug_data("update_callback") - - if self.schedule_update: - self.schedule_update = False - self.controller.async_heartbeat( - self.unique_id, dt_util.utcnow() + self.controller.option_detection_time - ) - - super().async_update_callback() - - @callback - def _make_disconnected(self, *_): - """No heart beat by device.""" - self._is_connected = False - self.async_write_ha_state() - self._async_log_debug_data("make_disconnected") - - @property - def is_connected(self): - """Return true if the client is connected to the network.""" - if ( - not self.is_wired - and self.client.essid - and self.controller.option_ssid_filter - and self.client.essid not in self.controller.option_ssid_filter - ): - return False - - return self._is_connected - - @property - def source_type(self) -> SourceType: - """Return the source type of the client.""" - return SourceType.ROUTER - - @property - def unique_id(self) -> str: - """Return a unique identifier for this client.""" - return f"{self.client.mac}-{self.controller.site}" - - @property - def extra_state_attributes(self): - """Return the client state attributes.""" - raw = self.client.raw - - attributes_to_check = CLIENT_STATIC_ATTRIBUTES - if self.is_connected: - attributes_to_check = CLIENT_CONNECTED_ALL_ATTRIBUTES - - attributes = {k: raw[k] for k in attributes_to_check if k in raw} - attributes["is_wired"] = self.is_wired - - return attributes - - @property - def ip_address(self) -> str: - """Return the primary ip address of the device.""" - return self.client.raw.get("ip") - - @property - def mac_address(self) -> str: - """Return the mac address of the device.""" - return self.client.raw.get("mac") - - @property - def hostname(self) -> str: - """Return hostname of the device.""" - return self.client.raw.get("hostname") - - async def options_updated(self) -> None: - """Config entry options are updated, remove entity if option is disabled.""" - if not self.controller.option_track_clients: - await self.remove_item({self.client.mac}) - - elif self.is_wired: - if not self.controller.option_track_wired_clients: - await self.remove_item({self.client.mac}) - - elif ( - self.controller.option_ssid_filter - and self.client.essid not in self.controller.option_ssid_filter - ): - await self.remove_item({self.client.mac}) - - -class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity): +class UnifiScannerEntity(UnifiEntity[HandlerT, DataT], ScannerEntity): """Representation of a UniFi scanner.""" entity_description: UnifiTrackerEntityDescription + _event_is_on: tuple[EventKey, ...] _ignore_events: bool _is_connected: bool @@ -383,8 +227,15 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity): Initiate is_connected. """ description = self.entity_description + self._event_is_on = description.event_is_on or () self._ignore_events = False self._is_connected = description.is_connected_fn(self.controller, self._obj_id) + if self.is_connected: + self.controller.async_heartbeat( + self.unique_id, + dt_util.utcnow() + + description.heartbeat_timedelta_fn(self.controller, self._obj_id), + ) @property def is_connected(self) -> bool: @@ -452,13 +303,33 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity): + description.heartbeat_timedelta_fn(self.controller, self._obj_id), ) + @callback + def async_event_callback(self, event: Event) -> None: + """Event subscription callback.""" + if event.mac != self._obj_id or self._ignore_events: + return + + if event.key in self._event_is_on: + self.controller.async_heartbeat(self.unique_id) + self._is_connected = True + self.async_write_ha_state() + return + + self.controller.async_heartbeat( + self.unique_id, + dt_util.utcnow() + + self.entity_description.heartbeat_timedelta_fn( + self.controller, self._obj_id + ), + ) + async def async_added_to_hass(self) -> None: """Register callbacks.""" await super().async_added_to_hass() self.async_on_remove( async_dispatcher_connect( self.hass, - f"{self.controller.signal_heartbeat_missed}_{self._obj_id}", + f"{self.controller.signal_heartbeat_missed}_{self.unique_id}", self._make_disconnected, ) ) @@ -467,3 +338,20 @@ class UnifiScannerEntity(UnifiEntity[_HandlerT, _DataT], ScannerEntity): """Disconnect object when removed.""" await super().async_will_remove_from_hass() self.controller.async_heartbeat(self.unique_id) + + @property + def extra_state_attributes(self) -> Mapping[str, Any] | None: + """Return the client state attributes.""" + if self.entity_description.key != "Client device scanner": + return None + + client = self.entity_description.object_fn(self.controller.api, self._obj_id) + raw = client.raw + + attributes_to_check = CLIENT_STATIC_ATTRIBUTES + if self.is_connected: + attributes_to_check = CLIENT_CONNECTED_ALL_ATTRIBUTES + + attributes = {k: raw[k] for k in attributes_to_check if k in raw} + + return attributes diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 5dcf1fc6932..1e68b497111 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -156,7 +156,7 @@ async def test_tracked_clients( # State change signalling works - client_1["last_seen"] += 1 + client_1["last_seen"] = dt_util.as_timestamp(dt_util.utcnow()) mock_unifi_websocket(message=MessageKey.CLIENT, data=client_1) await hass.async_block_till_done() @@ -244,6 +244,7 @@ async def test_tracked_wireless_clients_event_source( # New data + client["last_seen"] = dt_util.as_timestamp(dt_util.utcnow()) mock_unifi_websocket(message=MessageKey.CLIENT, data=client) await hass.async_block_till_done() assert hass.states.get("device_tracker.client").state == STATE_HOME @@ -703,6 +704,11 @@ async def test_option_ssid_filter( mock_unifi_websocket(message=MessageKey.CLIENT, data=client_on_ssid2) await hass.async_block_till_done() + new_time = dt_util.utcnow() + controller.option_detection_time + with patch("homeassistant.util.dt.utcnow", return_value=new_time): + async_fire_time_changed(hass, new_time) + await hass.async_block_till_done() + # SSID filter marks client as away assert hass.states.get("device_tracker.client").state == STATE_NOT_HOME @@ -726,7 +732,7 @@ async def test_option_ssid_filter( # Time pass to mark client as away - new_time = dt_util.utcnow() + controller.option_detection_time + new_time += controller.option_detection_time with patch("homeassistant.util.dt.utcnow", return_value=new_time): async_fire_time_changed(hass, new_time) await hass.async_block_till_done() @@ -745,9 +751,7 @@ async def test_option_ssid_filter( mock_unifi_websocket(message=MessageKey.CLIENT, data=client_on_ssid2) await hass.async_block_till_done() - new_time = ( - dt_util.utcnow() + controller.option_detection_time + timedelta(seconds=1) - ) + new_time += controller.option_detection_time with patch("homeassistant.util.dt.utcnow", return_value=new_time): async_fire_time_changed(hass, new_time) await hass.async_block_till_done() @@ -784,10 +788,9 @@ async def test_wireless_client_go_wired_issue( # Client is wireless client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is False # Trigger wired bug - client["last_seen"] += 1 + client["last_seen"] = dt_util.as_timestamp(dt_util.utcnow()) client["is_wired"] = True mock_unifi_websocket(message=MessageKey.CLIENT, data=client) await hass.async_block_till_done() @@ -795,7 +798,6 @@ async def test_wireless_client_go_wired_issue( # Wired bug fix keeps client marked as wireless client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is False # Pass time new_time = dt_util.utcnow() + controller.option_detection_time @@ -806,7 +808,6 @@ async def test_wireless_client_go_wired_issue( # Marked as home according to the timer client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_NOT_HOME - assert client_state.attributes["is_wired"] is False # Try to mark client as connected client["last_seen"] += 1 @@ -816,7 +817,6 @@ async def test_wireless_client_go_wired_issue( # Make sure it don't go online again until wired bug disappears client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_NOT_HOME - assert client_state.attributes["is_wired"] is False # Make client wireless client["last_seen"] += 1 @@ -827,7 +827,6 @@ async def test_wireless_client_go_wired_issue( # Client is no longer affected by wired bug and can be marked online client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is False async def test_option_ignore_wired_bug( @@ -859,7 +858,6 @@ async def test_option_ignore_wired_bug( # Client is wireless client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is False # Trigger wired bug client["is_wired"] = True @@ -869,7 +867,6 @@ async def test_option_ignore_wired_bug( # Wired bug in effect client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is True # pass time new_time = dt_util.utcnow() + controller.option_detection_time @@ -880,7 +877,6 @@ async def test_option_ignore_wired_bug( # Timer marks client as away client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_NOT_HOME - assert client_state.attributes["is_wired"] is True # Mark client as connected again client["last_seen"] += 1 @@ -890,7 +886,6 @@ async def test_option_ignore_wired_bug( # Ignoring wired bug allows client to go home again even while affected client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is True # Make client wireless client["last_seen"] += 1 @@ -901,7 +896,6 @@ async def test_option_ignore_wired_bug( # Client is wireless and still connected client_state = hass.states.get("device_tracker.client") assert client_state.state == STATE_HOME - assert client_state.attributes["is_wired"] is False async def test_restoring_client(