Refactor UniFi DPI switch entities (#80761)

* Refactor UniFi DPI switch entities

* Remove dpi presence from items_added
This commit is contained in:
Robert Svensson 2022-10-23 22:42:24 +02:00 committed by GitHub
parent d75834cd1e
commit 03bf37e12c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 108 deletions

View file

@ -9,12 +9,7 @@ from typing import Any
from aiohttp import CookieJar from aiohttp import CookieJar
import aiounifi import aiounifi
from aiounifi.interfaces.messages import ( from aiounifi.interfaces.messages import DATA_CLIENT_REMOVED, DATA_EVENT
DATA_CLIENT_REMOVED,
DATA_DPI_GROUP,
DATA_DPI_GROUP_REMOVED,
DATA_EVENT,
)
from aiounifi.models.event import EventKey from aiounifi.models.event import EventKey
from aiounifi.websocket import WebsocketSignal, WebsocketState from aiounifi.websocket import WebsocketSignal, WebsocketState
import async_timeout import async_timeout
@ -247,14 +242,6 @@ class UniFiController:
self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED] self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED]
) )
elif DATA_DPI_GROUP in data:
async_dispatcher_send(self.hass, self.signal_update)
elif DATA_DPI_GROUP_REMOVED in data:
async_dispatcher_send(
self.hass, self.signal_remove, data[DATA_DPI_GROUP_REMOVED]
)
@property @property
def signal_reachable(self) -> str: def signal_reachable(self) -> str:
"""Integration specific event to signal a change in connection status.""" """Integration specific event to signal a change in connection status."""

View file

@ -33,7 +33,6 @@ from homeassistant.helpers.restore_state import RestoreEntity
from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN
from .unifi_client import UniFiClient from .unifi_client import UniFiClient
from .unifi_entity_base import UniFiBase
BLOCK_SWITCH = "block" BLOCK_SWITCH = "block"
DPI_SWITCH = "dpi" DPI_SWITCH = "dpi"
@ -88,7 +87,7 @@ async def async_setup_entry(
@callback @callback
def items_added( def items_added(
clients: set = controller.api.clients, clients: set = controller.api.clients,
dpi_groups: set = controller.api.dpi_groups, devices: set = controller.api.devices,
) -> None: ) -> None:
"""Update the values of the controller.""" """Update the values of the controller."""
if controller.option_block_clients: if controller.option_block_clients:
@ -97,9 +96,6 @@ async def async_setup_entry(
if controller.option_poe_clients: if controller.option_poe_clients:
add_poe_entities(controller, async_add_entities, clients, known_poe_clients) add_poe_entities(controller, async_add_entities, clients, known_poe_clients)
if controller.option_dpi_restrictions:
add_dpi_entities(controller, async_add_entities, dpi_groups)
for signal in (controller.signal_update, controller.signal_options_update): for signal in (controller.signal_update, controller.signal_options_update):
config_entry.async_on_unload( config_entry.async_on_unload(
async_dispatcher_connect(hass, signal, items_added) async_dispatcher_connect(hass, signal, items_added)
@ -120,6 +116,20 @@ async def async_setup_entry(
for index in controller.api.outlets: for index in controller.api.outlets:
async_add_outlet_switch(ItemEvent.ADDED, index) async_add_outlet_switch(ItemEvent.ADDED, index)
def async_add_dpi_switch(_: ItemEvent, obj_id: str) -> None:
"""Add DPI switch from UniFi controller."""
if (
not controller.option_dpi_restrictions
or not controller.api.dpi_groups[obj_id].dpiapp_ids
):
return
async_add_entities([UnifiDPIRestrictionSwitch(obj_id, controller)])
controller.api.ports.subscribe(async_add_dpi_switch, ItemEvent.ADDED)
for dpi_group_id in controller.api.dpi_groups:
async_add_dpi_switch(ItemEvent.ADDED, dpi_group_id)
@callback @callback
def async_add_poe_switch(_: ItemEvent, obj_id: str) -> None: def async_add_poe_switch(_: ItemEvent, obj_id: str) -> None:
"""Add port PoE switch from UniFi controller.""" """Add port PoE switch from UniFi controller."""
@ -198,23 +208,6 @@ def add_poe_entities(controller, async_add_entities, clients, known_poe_clients)
async_add_entities(switches) async_add_entities(switches)
@callback
def add_dpi_entities(controller, async_add_entities, dpi_groups):
"""Add new switch entities from the controller."""
switches = []
for group in dpi_groups:
if (
group in controller.entities[DOMAIN][DPI_SWITCH]
or not dpi_groups[group].dpiapp_ids
):
continue
switches.append(UniFiDPIRestrictionSwitch(dpi_groups[group], controller))
async_add_entities(switches)
class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity): class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity):
"""Representation of a client that uses POE.""" """Representation of a client that uses POE."""
@ -367,132 +360,139 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchEntity):
await self.remove_item({self.client.mac}) await self.remove_item({self.client.mac})
class UniFiDPIRestrictionSwitch(UniFiBase, SwitchEntity): class UnifiDPIRestrictionSwitch(SwitchEntity):
"""Representation of a DPI restriction group.""" """Representation of a DPI restriction group."""
DOMAIN = DOMAIN
TYPE = DPI_SWITCH
_attr_entity_category = EntityCategory.CONFIG _attr_entity_category = EntityCategory.CONFIG
def __init__(self, dpi_group, controller): def __init__(self, obj_id: str, controller):
"""Set up dpi switch.""" """Set up dpi switch."""
super().__init__(dpi_group, controller) controller.entities[DOMAIN][DPI_SWITCH].add(obj_id)
self._obj_id = obj_id
self.controller = controller
self._is_enabled = self.calculate_enabled() dpi_group = controller.api.dpi_groups[obj_id]
self._known_app_ids = dpi_group.dpiapp_ids self._known_app_ids = dpi_group.dpiapp_ids
@property self._attr_available = controller.available
def key(self) -> Any: self._attr_is_on = self.calculate_enabled()
"""Return item key.""" self._attr_name = dpi_group.name
return self._item.id self._attr_unique_id = dpi_group.id
self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, f"unifi_controller_{obj_id}")},
manufacturer=ATTR_MANUFACTURER,
model="UniFi Network",
name="UniFi Network",
)
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register callback to known apps.""" """Register callback to known apps."""
await super().async_added_to_hass() self.async_on_remove(
self.controller.api.dpi_groups.subscribe(self.async_signalling_callback)
apps = self.controller.api.dpi_apps )
for app_id in self._item.dpiapp_ids: self.async_on_remove(
apps[app_id].register_callback(self.async_update_callback) self.controller.api.dpi_apps.subscribe(
self.async_signalling_callback, ItemEvent.CHANGED
),
)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_remove, self.remove_item
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self.controller.signal_options_update, self.options_updated
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
self.controller.signal_reachable,
self.async_signal_reachable_callback,
)
)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Remove registered callbacks.""" """Disconnect object when removed."""
apps = self.controller.api.dpi_apps self.controller.entities[DOMAIN][DPI_SWITCH].remove(self._obj_id)
for app_id in self._item.dpiapp_ids:
apps[app_id].remove_callback(self.async_update_callback)
await super().async_will_remove_from_hass()
@callback @callback
def async_update_callback(self) -> None: def async_signalling_callback(self, event: ItemEvent, obj_id: str) -> None:
"""Update the DPI switch state. """Object has new event."""
if event == ItemEvent.DELETED:
Remove entity when no apps are paired with group. self.hass.async_create_task(self.remove_item({self._obj_id}))
Register callbacks to new apps.
Calculate and update entity state if it has changed.
"""
if not self._item.dpiapp_ids:
self.hass.async_create_task(self.remove_item({self.key}))
return return
if self._known_app_ids != self._item.dpiapp_ids: dpi_group = self.controller.api.dpi_groups[self._obj_id]
self._known_app_ids = self._item.dpiapp_ids if not dpi_group.dpiapp_ids:
self.hass.async_create_task(self.remove_item({self._obj_id}))
return
apps = self.controller.api.dpi_apps self._attr_available = self.controller.available
for app_id in self._item.dpiapp_ids: self._attr_is_on = self.calculate_enabled()
apps[app_id].register_callback(self.async_update_callback) self.async_write_ha_state()
if (enabled := self.calculate_enabled()) != self._is_enabled: @callback
self._is_enabled = enabled def async_signal_reachable_callback(self) -> None:
super().async_update_callback() """Call when controller connection state change."""
self.async_signalling_callback(ItemEvent.ADDED, self._obj_id)
@property
def unique_id(self):
"""Return a unique identifier for this switch."""
return self._item.id
@property
def name(self) -> str:
"""Return the name of the DPI group."""
return self._item.name
@property @property
def icon(self): def icon(self):
"""Return the icon to use in the frontend.""" """Return the icon to use in the frontend."""
if self._is_enabled: if self._attr_is_on:
return "mdi:network" return "mdi:network"
return "mdi:network-off" return "mdi:network-off"
def calculate_enabled(self) -> bool: def calculate_enabled(self) -> bool:
"""Calculate if all apps are enabled.""" """Calculate if all apps are enabled."""
dpi_group = self.controller.api.dpi_groups[self._obj_id]
return all( return all(
self.controller.api.dpi_apps[app_id].enabled self.controller.api.dpi_apps[app_id].enabled
for app_id in self._item.dpiapp_ids for app_id in dpi_group.dpiapp_ids
if app_id in self.controller.api.dpi_apps if app_id in self.controller.api.dpi_apps
) )
@property
def is_on(self):
"""Return true if DPI group app restriction is enabled."""
return self._is_enabled
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Restrict access of apps related to DPI group.""" """Restrict access of apps related to DPI group."""
dpi_group = self.controller.api.dpi_groups[self._obj_id]
return await asyncio.gather( return await asyncio.gather(
*[ *[
self.controller.api.request( self.controller.api.request(
DPIRestrictionAppEnableRequest.create(app_id, True) DPIRestrictionAppEnableRequest.create(app_id, True)
) )
for app_id in self._item.dpiapp_ids for app_id in dpi_group.dpiapp_ids
] ]
) )
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Remove restriction of apps related to DPI group.""" """Remove restriction of apps related to DPI group."""
dpi_group = self.controller.api.dpi_groups[self._obj_id]
return await asyncio.gather( return await asyncio.gather(
*[ *[
self.controller.api.request( self.controller.api.request(
DPIRestrictionAppEnableRequest.create(app_id, False) DPIRestrictionAppEnableRequest.create(app_id, False)
) )
for app_id in self._item.dpiapp_ids for app_id in dpi_group.dpiapp_ids
] ]
) )
async def options_updated(self) -> None: async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled.""" """Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_dpi_restrictions: if not self.controller.option_dpi_restrictions:
await self.remove_item({self.key}) await self.remove_item({self._attr_unique_id})
@property async def remove_item(self, keys: set) -> None:
def device_info(self) -> DeviceInfo: """Remove entity if key is part of set."""
"""Return a service description for device registry.""" if self._attr_unique_id not in keys:
return DeviceInfo( return
entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, f"unifi_controller_{self._item.site_id}")}, if self.registry_entry:
manufacturer=ATTR_MANUFACTURER, er.async_get(self.hass).async_remove(self.entity_id)
model="UniFi Network", else:
name="UniFi Network", await self.async_remove(force_remove=True)
)
class UnifiOutletSwitch(SwitchEntity): class UnifiOutletSwitch(SwitchEntity):

View file

@ -761,7 +761,6 @@ async def test_remove_switches(hass, aioclient_mock, mock_unifi_websocket):
mock_unifi_websocket(data=DPI_GROUP_REMOVED_EVENT) mock_unifi_websocket(data=DPI_GROUP_REMOVED_EVENT)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get("switch.block_media_streaming") is None assert hass.states.get("switch.block_media_streaming") is None
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
@ -852,10 +851,21 @@ async def test_dpi_switches(hass, aioclient_mock, mock_unifi_websocket):
assert hass.states.get("switch.block_media_streaming").state == STATE_OFF assert hass.states.get("switch.block_media_streaming").state == STATE_OFF
# Availability signalling
# Controller disconnects
mock_unifi_websocket(state=WebsocketState.DISCONNECTED)
await hass.async_block_till_done()
assert hass.states.get("switch.block_media_streaming").state == STATE_UNAVAILABLE
# Controller reconnects
mock_unifi_websocket(state=WebsocketState.RUNNING)
await hass.async_block_till_done()
assert hass.states.get("switch.block_media_streaming").state == STATE_OFF
# Remove app
mock_unifi_websocket(data=DPI_GROUP_REMOVE_APP) mock_unifi_websocket(data=DPI_GROUP_REMOVE_APP)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get("switch.block_media_streaming") is None assert hass.states.get("switch.block_media_streaming") is None
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0