Add DPI Restriction switch to UniFi integration (#42499)

* initial implementation for controlling DPI restrictions

* address PR review comments and add DataUpdateCoordinator

* fix existing tests against new lib version

* add tests for DPI switches

* bump aiounifi

* listen to events instead of polling

* fix tests

* remove useless test

* bump aiounifi

* rename device to UniFi Controller per PR feedback
This commit is contained in:
Jason Hunter 2020-11-03 02:36:37 -05:00 committed by GitHub
parent aab0ff2ea5
commit 5a4c1dbcc4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 309 additions and 42 deletions

View file

@ -20,6 +20,7 @@ from .const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
CONF_DPI_RESTRICTIONS,
CONF_IGNORE_WIRED_BUG, CONF_IGNORE_WIRED_BUG,
CONF_POE_CLIENTS, CONF_POE_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
@ -28,6 +29,7 @@ from .const import (
CONF_TRACK_DEVICES, CONF_TRACK_DEVICES,
CONF_TRACK_WIRED_CLIENTS, CONF_TRACK_WIRED_CLIENTS,
CONTROLLER_ID, CONTROLLER_ID,
DEFAULT_DPI_RESTRICTIONS,
DEFAULT_POE_CLIENTS, DEFAULT_POE_CLIENTS,
DOMAIN as UNIFI_DOMAIN, DOMAIN as UNIFI_DOMAIN,
LOGGER, LOGGER,
@ -295,6 +297,12 @@ class UnifiOptionsFlowHandler(config_entries.OptionsFlow):
CONF_POE_CLIENTS, CONF_POE_CLIENTS,
default=self.options.get(CONF_POE_CLIENTS, DEFAULT_POE_CLIENTS), default=self.options.get(CONF_POE_CLIENTS, DEFAULT_POE_CLIENTS),
): bool, ): bool,
vol.Optional(
CONF_DPI_RESTRICTIONS,
default=self.options.get(
CONF_DPI_RESTRICTIONS, DEFAULT_DPI_RESTRICTIONS
),
): bool,
} }
), ),
errors=errors, errors=errors,

View file

@ -15,6 +15,7 @@ CONF_ALLOW_BANDWIDTH_SENSORS = "allow_bandwidth_sensors"
CONF_ALLOW_UPTIME_SENSORS = "allow_uptime_sensors" CONF_ALLOW_UPTIME_SENSORS = "allow_uptime_sensors"
CONF_BLOCK_CLIENT = "block_client" CONF_BLOCK_CLIENT = "block_client"
CONF_DETECTION_TIME = "detection_time" CONF_DETECTION_TIME = "detection_time"
CONF_DPI_RESTRICTIONS = "dpi_restrictions"
CONF_IGNORE_WIRED_BUG = "ignore_wired_bug" CONF_IGNORE_WIRED_BUG = "ignore_wired_bug"
CONF_POE_CLIENTS = "poe_clients" CONF_POE_CLIENTS = "poe_clients"
CONF_TRACK_CLIENTS = "track_clients" CONF_TRACK_CLIENTS = "track_clients"
@ -24,6 +25,7 @@ CONF_SSID_FILTER = "ssid_filter"
DEFAULT_ALLOW_BANDWIDTH_SENSORS = False DEFAULT_ALLOW_BANDWIDTH_SENSORS = False
DEFAULT_ALLOW_UPTIME_SENSORS = False DEFAULT_ALLOW_UPTIME_SENSORS = False
DEFAULT_DPI_RESTRICTIONS = True
DEFAULT_IGNORE_WIRED_BUG = False DEFAULT_IGNORE_WIRED_BUG = False
DEFAULT_POE_CLIENTS = True DEFAULT_POE_CLIENTS = True
DEFAULT_TRACK_CLIENTS = True DEFAULT_TRACK_CLIENTS = True

View file

@ -7,6 +7,8 @@ from aiohttp import CookieJar
import aiounifi import aiounifi
from aiounifi.controller import ( from aiounifi.controller import (
DATA_CLIENT_REMOVED, DATA_CLIENT_REMOVED,
DATA_DPI_GROUP,
DATA_DPI_GROUP_REMOVED,
DATA_EVENT, DATA_EVENT,
SIGNAL_CONNECTION_STATE, SIGNAL_CONNECTION_STATE,
SIGNAL_DATA, SIGNAL_DATA,
@ -37,6 +39,7 @@ from .const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
CONF_DPI_RESTRICTIONS,
CONF_IGNORE_WIRED_BUG, CONF_IGNORE_WIRED_BUG,
CONF_POE_CLIENTS, CONF_POE_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
@ -48,6 +51,7 @@ from .const import (
DEFAULT_ALLOW_BANDWIDTH_SENSORS, DEFAULT_ALLOW_BANDWIDTH_SENSORS,
DEFAULT_ALLOW_UPTIME_SENSORS, DEFAULT_ALLOW_UPTIME_SENSORS,
DEFAULT_DETECTION_TIME, DEFAULT_DETECTION_TIME,
DEFAULT_DPI_RESTRICTIONS,
DEFAULT_IGNORE_WIRED_BUG, DEFAULT_IGNORE_WIRED_BUG,
DEFAULT_POE_CLIENTS, DEFAULT_POE_CLIENTS,
DEFAULT_TRACK_CLIENTS, DEFAULT_TRACK_CLIENTS,
@ -177,6 +181,13 @@ class UniFiController:
"""Config entry option with list of clients to control network access.""" """Config entry option with list of clients to control network access."""
return self.config_entry.options.get(CONF_BLOCK_CLIENT, []) return self.config_entry.options.get(CONF_BLOCK_CLIENT, [])
@property
def option_dpi_restrictions(self):
"""Config entry option to control DPI restriction groups."""
return self.config_entry.options.get(
CONF_DPI_RESTRICTIONS, DEFAULT_DPI_RESTRICTIONS
)
# Statistics sensor options # Statistics sensor options
@property @property
@ -248,6 +259,18 @@ class UniFiController:
self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED] self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED]
) )
elif DATA_DPI_GROUP in data:
for key in data[DATA_DPI_GROUP]:
if self.api.dpi_groups[key].dpiapp_ids:
async_dispatcher_send(self.hass, self.signal_update)
else:
async_dispatcher_send(self.hass, self.signal_remove, {key})
elif DATA_DPI_GROUP_REMOVED in data:
async_dispatcher_send(
self.hass, self.signal_remove, data[DATA_DPI_GROUP_REMOVED]
)
@property @property
def signal_reachable(self) -> str: def signal_reachable(self) -> str:
"""Integration specific event to signal a change in connection status.""" """Integration specific event to signal a change in connection status."""

View file

@ -3,7 +3,7 @@
"name": "Ubiquiti UniFi", "name": "Ubiquiti UniFi",
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/unifi", "documentation": "https://www.home-assistant.io/integrations/unifi",
"requirements": ["aiounifi==23"], "requirements": ["aiounifi==25"],
"codeowners": ["@Kane610"], "codeowners": ["@Kane610"],
"quality_scale": "platinum" "quality_scale": "platinum"
} }

View file

@ -39,7 +39,8 @@
"client_control": { "client_control": {
"data": { "data": {
"block_client": "Network access controlled clients", "block_client": "Network access controlled clients",
"poe_clients": "Allow POE control of clients" "poe_clients": "Allow POE control of clients",
"dpi_restrictions": "Allow control of DPI restriction groups"
}, },
"description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.",
"title": "UniFi options 2/3" "title": "UniFi options 2/3"

View file

@ -1,5 +1,6 @@
"""Support for devices connected to UniFi POE.""" """Support for devices connected to UniFi POE."""
import logging import logging
from typing import Any
from aiounifi.api import SOURCE_EVENT from aiounifi.api import SOURCE_EVENT
from aiounifi.events import ( from aiounifi.events import (
@ -14,12 +15,14 @@ from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from .const import DOMAIN as UNIFI_DOMAIN from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN
from .unifi_client import UniFiClient from .unifi_client import UniFiClient
from .unifi_entity_base import UniFiBase
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
BLOCK_SWITCH = "block" BLOCK_SWITCH = "block"
DPI_SWITCH = "dpi"
POE_SWITCH = "poe" POE_SWITCH = "poe"
CLIENT_BLOCKED = (WIRED_CLIENT_BLOCKED, WIRELESS_CLIENT_BLOCKED) CLIENT_BLOCKED = (WIRED_CLIENT_BLOCKED, WIRELESS_CLIENT_BLOCKED)
@ -32,7 +35,11 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
Switches are controlling network access and switch ports with POE. Switches are controlling network access and switch ports with POE.
""" """
controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id]
controller.entities[DOMAIN] = {BLOCK_SWITCH: set(), POE_SWITCH: set()} controller.entities[DOMAIN] = {
BLOCK_SWITCH: set(),
POE_SWITCH: set(),
DPI_SWITCH: set(),
}
if controller.site_role != "admin": if controller.site_role != "admin":
return return
@ -59,7 +66,9 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
@callback @callback
def items_added( def items_added(
clients: set = controller.api.clients, devices: set = controller.api.devices clients: set = controller.api.clients,
devices: set = controller.api.devices,
dpi_groups: set = controller.api.dpi_groups,
) -> None: ) -> None:
"""Update the values of the controller.""" """Update the values of the controller."""
if controller.option_block_clients: if controller.option_block_clients:
@ -70,6 +79,9 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
controller, async_add_entities, clients, previously_known_poe_clients controller, async_add_entities, clients, previously_known_poe_clients
) )
if controller.option_dpi_restrictions:
add_dpi_entities(controller, async_add_entities, dpi_groups)
for signal in (controller.signal_update, controller.signal_options_update): for signal in (controller.signal_update, controller.signal_options_update):
controller.listeners.append(async_dispatcher_connect(hass, signal, items_added)) controller.listeners.append(async_dispatcher_connect(hass, signal, items_added))
@ -143,6 +155,24 @@ def add_poe_entities(
async_add_entities(switches) async_add_entities(switches)
@callback
def add_dpi_entities(controller, async_add_entities, dpi_groups):
"""Add new switch entities from the controller."""
switches = []
for group in dpi_groups:
if (
group in controller.entities[DOMAIN][DPI_SWITCH]
or not dpi_groups[group].dpiapp_ids
):
continue
switches.append(UniFiDPIRestrictionSwitch(dpi_groups[group], controller))
if switches:
async_add_entities(switches)
class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity): class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity):
"""Representation of a client that uses POE.""" """Representation of a client that uses POE."""
@ -284,3 +314,61 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchEntity):
"""Config entry options are updated, remove entity if option is disabled.""" """Config entry options are updated, remove entity if option is disabled."""
if self.client.mac not in self.controller.option_block_clients: if self.client.mac not in self.controller.option_block_clients:
await self.remove_item({self.client.mac}) await self.remove_item({self.client.mac})
class UniFiDPIRestrictionSwitch(UniFiBase, SwitchEntity):
"""Representation of a DPI restriction group."""
DOMAIN = DOMAIN
TYPE = DPI_SWITCH
@property
def key(self) -> Any:
"""Return item key."""
return self._item.id
@property
def unique_id(self):
"""Return a unique identifier for this switch."""
return self._item.id
@property
def name(self) -> str:
"""Return the name of the client."""
return self._item.name
@property
def icon(self):
"""Return the icon to use in the frontend."""
if self._item.enabled:
return "mdi:network"
return "mdi:network-off"
@property
def is_on(self):
"""Return true if client is allowed to connect."""
return self._item.enabled
async def async_turn_on(self, **kwargs):
"""Turn on connectivity for client."""
await self.controller.api.dpi_groups.async_enable(self._item)
async def async_turn_off(self, **kwargs):
"""Turn off connectivity for client."""
await self.controller.api.dpi_groups.async_disable(self._item)
async def options_updated(self) -> None:
"""Config entry options are updated, remove entity if option is disabled."""
if not self.controller.option_dpi_restrictions:
await self.remove_item({self.key})
@property
def device_info(self) -> dict:
"""Return a service description for device registry."""
return {
"identifiers": {(DOMAIN, f"unifi_controller_{self._item.site_id}")},
"name": "UniFi Controller",
"manufacturer": ATTR_MANUFACTURER,
"model": "UniFi Controller",
"entry_type": "service",
}

View file

@ -27,6 +27,7 @@
"client_control": { "client_control": {
"data": { "data": {
"block_client": "Network access controlled clients", "block_client": "Network access controlled clients",
"dpi_restrictions": "Allow control of DPI restriction groups",
"poe_clients": "Allow POE control of clients" "poe_clients": "Allow POE control of clients"
}, },
"description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.",

View file

@ -1,5 +1,6 @@
"""Base class for UniFi entities.""" """Base class for UniFi entities."""
import logging import logging
from typing import Any
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -22,12 +23,20 @@ class UniFiBase(Entity):
""" """
self._item = item self._item = item
self.controller = controller self.controller = controller
self.controller.entities[self.DOMAIN][self.TYPE].add(item.mac) self.controller.entities[self.DOMAIN][self.TYPE].add(self.key)
@property
def key(self) -> Any:
"""Return item key."""
return self._item.mac
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Entity created.""" """Entity created."""
_LOGGER.debug( _LOGGER.debug(
"New %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac "New %s entity %s (%s)",
self.TYPE,
self.entity_id,
self.key,
) )
for signal, method in ( for signal, method in (
(self.controller.signal_reachable, self.async_update_callback), (self.controller.signal_reachable, self.async_update_callback),
@ -40,16 +49,22 @@ class UniFiBase(Entity):
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Disconnect object when removed.""" """Disconnect object when removed."""
_LOGGER.debug( _LOGGER.debug(
"Removing %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac "Removing %s entity %s (%s)",
self.TYPE,
self.entity_id,
self.key,
) )
self._item.remove_callback(self.async_update_callback) self._item.remove_callback(self.async_update_callback)
self.controller.entities[self.DOMAIN][self.TYPE].remove(self._item.mac) self.controller.entities[self.DOMAIN][self.TYPE].remove(self.key)
@callback @callback
def async_update_callback(self) -> None: def async_update_callback(self) -> None:
"""Update the entity's state.""" """Update the entity's state."""
_LOGGER.debug( _LOGGER.debug(
"Updating %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac "Updating %s entity %s (%s)",
self.TYPE,
self.entity_id,
self.key,
) )
self.async_write_ha_state() self.async_write_ha_state()
@ -57,15 +72,15 @@ class UniFiBase(Entity):
"""Config entry options are updated, remove entity if option is disabled.""" """Config entry options are updated, remove entity if option is disabled."""
raise NotImplementedError raise NotImplementedError
async def remove_item(self, mac_addresses: set) -> None: async def remove_item(self, keys: set) -> None:
"""Remove entity if MAC is part of set. """Remove entity if key is part of set.
Remove entity if no entry in entity registry exist. Remove entity if no entry in entity registry exist.
Remove entity registry entry if no entry in device registry exist. Remove entity registry entry if no entry in device registry exist.
Remove device registry entry if there is only one linked entity (this entity). Remove device registry entry if there is only one linked entity (this entity).
Remove entity registry entry if there are more than one entity linked to the device registry entry. Remove entity registry entry if there are more than one entity linked to the device registry entry.
""" """
if self._item.mac not in mac_addresses: if self.key not in keys:
return return
entity_registry = await self.hass.helpers.entity_registry.async_get_registry() entity_registry = await self.hass.helpers.entity_registry.async_get_registry()

View file

@ -227,7 +227,7 @@ aioshelly==0.5.0
aioswitcher==1.2.1 aioswitcher==1.2.1
# homeassistant.components.unifi # homeassistant.components.unifi
aiounifi==23 aiounifi==25
# homeassistant.components.yandex_transport # homeassistant.components.yandex_transport
aioymaps==1.1.0 aioymaps==1.1.0

View file

@ -143,7 +143,7 @@ aioshelly==0.5.0
aioswitcher==1.2.1 aioswitcher==1.2.1
# homeassistant.components.unifi # homeassistant.components.unifi
aiounifi==23 aiounifi==25
# homeassistant.components.yandex_transport # homeassistant.components.yandex_transport
aioymaps==1.1.0 aioymaps==1.1.0

View file

@ -8,6 +8,7 @@ from homeassistant.components.unifi.const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
CONF_DPI_RESTRICTIONS,
CONF_IGNORE_WIRED_BUG, CONF_IGNORE_WIRED_BUG,
CONF_POE_CLIENTS, CONF_POE_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
@ -72,6 +73,14 @@ WLANS = [
{"name": "SSID 2", "name_combine_enabled": False, "name_combine_suffix": "_IOT"}, {"name": "SSID 2", "name_combine_enabled": False, "name_combine_suffix": "_IOT"},
] ]
DPI_GROUPS = [
{
"_id": "5ba29dd8e3c58f026e9d7c4a",
"name": "Default",
"site_id": "5ba29dd4e3c58f026e9d7c38",
},
]
async def test_flow_works(hass, aioclient_mock, mock_discovery): async def test_flow_works(hass, aioclient_mock, mock_discovery):
"""Test config flow.""" """Test config flow."""
@ -307,7 +316,12 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock):
async def test_advanced_option_flow(hass): async def test_advanced_option_flow(hass):
"""Test advanced config flow options.""" """Test advanced config flow options."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, clients_response=CLIENTS, devices_response=DEVICES, wlans_response=WLANS hass,
clients_response=CLIENTS,
devices_response=DEVICES,
wlans_response=WLANS,
dpigroup_response=DPI_GROUPS,
dpiapp_response=[],
) )
result = await hass.config_entries.options.async_init( result = await hass.config_entries.options.async_init(
@ -336,7 +350,11 @@ async def test_advanced_option_flow(hass):
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]], CONF_POE_CLIENTS: False}, user_input={
CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]],
CONF_POE_CLIENTS: False,
CONF_DPI_RESTRICTIONS: False,
},
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
@ -359,6 +377,7 @@ async def test_advanced_option_flow(hass):
CONF_DETECTION_TIME: 100, CONF_DETECTION_TIME: 100,
CONF_IGNORE_WIRED_BUG: False, CONF_IGNORE_WIRED_BUG: False,
CONF_POE_CLIENTS: False, CONF_POE_CLIENTS: False,
CONF_DPI_RESTRICTIONS: False,
CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]], CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]],
CONF_ALLOW_BANDWIDTH_SENSORS: True, CONF_ALLOW_BANDWIDTH_SENSORS: True,
CONF_ALLOW_UPTIME_SENSORS: True, CONF_ALLOW_UPTIME_SENSORS: True,
@ -368,7 +387,11 @@ async def test_advanced_option_flow(hass):
async def test_simple_option_flow(hass): async def test_simple_option_flow(hass):
"""Test simple config flow options.""" """Test simple config flow options."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, clients_response=CLIENTS, wlans_response=WLANS hass,
clients_response=CLIENTS,
wlans_response=WLANS,
dpigroup_response=DPI_GROUPS,
dpiapp_response=[],
) )
result = await hass.config_entries.options.async_init( result = await hass.config_entries.options.async_init(

View file

@ -81,6 +81,8 @@ async def setup_unifi_integration(
devices_response=None, devices_response=None,
clients_all_response=None, clients_all_response=None,
wlans_response=None, wlans_response=None,
dpigroup_response=None,
dpiapp_response=None,
known_wireless_clients=None, known_wireless_clients=None,
controllers=None, controllers=None,
): ):
@ -116,6 +118,14 @@ async def setup_unifi_integration(
if wlans_response: if wlans_response:
mock_wlans_responses.append(wlans_response) mock_wlans_responses.append(wlans_response)
mock_dpigroup_responses = deque()
if dpigroup_response:
mock_dpigroup_responses.append(dpigroup_response)
mock_dpiapp_responses = deque()
if dpiapp_response:
mock_dpiapp_responses.append(dpiapp_response)
mock_requests = [] mock_requests = []
async def mock_request(self, method, path, json=None): async def mock_request(self, method, path, json=None):
@ -129,6 +139,10 @@ async def setup_unifi_integration(
return mock_client_all_responses.popleft() return mock_client_all_responses.popleft()
if path == "/rest/wlanconf" and mock_wlans_responses: if path == "/rest/wlanconf" and mock_wlans_responses:
return mock_wlans_responses.popleft() return mock_wlans_responses.popleft()
if path == "/rest/dpigroup" and mock_dpigroup_responses:
return mock_dpigroup_responses.popleft()
if path == "/rest/dpiapp" and mock_dpiapp_responses:
return mock_dpiapp_responses.popleft()
return {} return {}
with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch( with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch(

View file

@ -71,7 +71,7 @@ async def test_no_clients(hass):
}, },
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 0
@ -88,7 +88,7 @@ async def test_sensors(hass):
clients_response=CLIENTS, clients_response=CLIENTS,
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 6 assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 6
wired_client_rx = hass.states.get("sensor.wired_client_name_rx") wired_client_rx = hass.states.get("sensor.wired_client_name_rx")

View file

@ -9,6 +9,8 @@ from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.components.unifi.const import ( from homeassistant.components.unifi.const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_DPI_RESTRICTIONS,
CONF_POE_CLIENTS,
CONF_TRACK_CLIENTS, CONF_TRACK_CLIENTS,
CONF_TRACK_DEVICES, CONF_TRACK_DEVICES,
DOMAIN as UNIFI_DOMAIN, DOMAIN as UNIFI_DOMAIN,
@ -251,6 +253,35 @@ EVENT_CLIENT_2_CONNECTED = {
} }
DPI_GROUPS = [
{
"_id": "5ba29dd8e3c58f026e9d7c4a",
"attr_no_delete": True,
"attr_hidden_id": "Default",
"name": "Default",
"site_id": "name",
},
{
"_id": "5f976f4ae3c58f018ec7dff6",
"name": "Block Media Streaming",
"site_id": "name",
"dpiapp_ids": ["5f976f62e3c58f018ec7e17d"],
},
]
DPI_APPS = [
{
"_id": "5f976f62e3c58f018ec7e17d",
"apps": [],
"blocked": True,
"cats": ["4"],
"enabled": True,
"log": True,
"site_id": "name",
}
]
async def test_platform_manually_configured(hass): async def test_platform_manually_configured(hass):
"""Test that we do not discover anything or try to set up a controller.""" """Test that we do not discover anything or try to set up a controller."""
assert ( assert (
@ -266,10 +297,14 @@ async def test_no_clients(hass):
"""Test the update_clients function when no clients are found.""" """Test the update_clients function when no clients are found."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, options={
CONF_TRACK_CLIENTS: False,
CONF_TRACK_DEVICES: False,
CONF_DPI_RESTRICTIONS: False,
},
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
@ -282,7 +317,7 @@ async def test_controller_not_client(hass):
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
cloudkey = hass.states.get("switch.cloud_key") cloudkey = hass.states.get("switch.cloud_key")
assert cloudkey is None assert cloudkey is None
@ -300,7 +335,7 @@ async def test_not_admin(hass):
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
@ -316,10 +351,12 @@ async def test_switches(hass):
clients_response=[CLIENT_1, CLIENT_4], clients_response=[CLIENT_1, CLIENT_4],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
clients_all_response=[BLOCKED, UNBLOCKED, CLIENT_1], clients_all_response=[BLOCKED, UNBLOCKED, CLIENT_1],
dpigroup_response=DPI_GROUPS,
dpiapp_response=DPI_APPS,
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 3 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 4
switch_1 = hass.states.get("switch.poe_client_1") switch_1 = hass.states.get("switch.poe_client_1")
assert switch_1 is not None assert switch_1 is not None
@ -340,11 +377,15 @@ async def test_switches(hass):
assert unblocked is not None assert unblocked is not None
assert unblocked.state == "on" assert unblocked.state == "on"
dpi_switch = hass.states.get("switch.block_media_streaming")
assert dpi_switch is not None
assert dpi_switch.state == "on"
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 5 assert len(controller.mock_requests) == 7
assert controller.mock_requests[4] == { assert controller.mock_requests[6] == {
"json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"},
"method": "post", "method": "post",
"path": "/cmd/stamgr", "path": "/cmd/stamgr",
@ -353,13 +394,39 @@ async def test_switches(hass):
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 6 assert len(controller.mock_requests) == 8
assert controller.mock_requests[5] == { assert controller.mock_requests[7] == {
"json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"},
"method": "post", "method": "post",
"path": "/cmd/stamgr", "path": "/cmd/stamgr",
} }
await hass.services.async_call(
SWITCH_DOMAIN,
"turn_off",
{"entity_id": "switch.block_media_streaming"},
blocking=True,
)
assert len(controller.mock_requests) == 9
assert controller.mock_requests[8] == {
"json": {"enabled": False},
"method": "put",
"path": "/rest/dpiapp/5f976f62e3c58f018ec7e17d",
}
await hass.services.async_call(
SWITCH_DOMAIN,
"turn_on",
{"entity_id": "switch.block_media_streaming"},
blocking=True,
)
assert len(controller.mock_requests) == 10
assert controller.mock_requests[9] == {
"json": {"enabled": True},
"method": "put",
"path": "/rest/dpiapp/5f976f62e3c58f018ec7e17d",
}
async def test_remove_switches(hass): async def test_remove_switches(hass):
"""Test the update_items function with some clients.""" """Test the update_items function with some clients."""
@ -443,8 +510,8 @@ async def test_block_switches(hass):
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 5 assert len(controller.mock_requests) == 7
assert controller.mock_requests[4] == { assert controller.mock_requests[6] == {
"json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"},
"method": "post", "method": "post",
"path": "/cmd/stamgr", "path": "/cmd/stamgr",
@ -453,8 +520,8 @@ async def test_block_switches(hass):
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 6 assert len(controller.mock_requests) == 8
assert controller.mock_requests[5] == { assert controller.mock_requests[7] == {
"json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"},
"method": "post", "method": "post",
"path": "/cmd/stamgr", "path": "/cmd/stamgr",
@ -469,10 +536,11 @@ async def test_new_client_discovered_on_block_control(hass):
CONF_BLOCK_CLIENT: [BLOCKED["mac"]], CONF_BLOCK_CLIENT: [BLOCKED["mac"]],
CONF_TRACK_CLIENTS: False, CONF_TRACK_CLIENTS: False,
CONF_TRACK_DEVICES: False, CONF_TRACK_DEVICES: False,
CONF_DPI_RESTRICTIONS: False,
}, },
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
blocked = hass.states.get("switch.block_client_1") blocked = hass.states.get("switch.block_client_1")
@ -541,6 +609,30 @@ async def test_option_block_clients(hass):
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
async def test_option_remove_switches(hass):
"""Test removal of DPI switch when options updated."""
controller = await setup_unifi_integration(
hass,
options={
CONF_TRACK_CLIENTS: False,
CONF_TRACK_DEVICES: False,
},
clients_response=[CLIENT_1],
devices_response=[DEVICE_1],
dpigroup_response=DPI_GROUPS,
dpiapp_response=DPI_APPS,
)
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2
# Disable DPI Switches
hass.config_entries.async_update_entry(
controller.config_entry,
options={CONF_DPI_RESTRICTIONS: False, CONF_POE_CLIENTS: False},
)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0
async def test_new_client_discovered_on_poe_control(hass): async def test_new_client_discovered_on_poe_control(hass):
"""Test if 2nd update has a new client.""" """Test if 2nd update has a new client."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
@ -550,7 +642,7 @@ async def test_new_client_discovered_on_poe_control(hass):
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1
controller.api.websocket._data = { controller.api.websocket._data = {
@ -576,9 +668,9 @@ async def test_new_client_discovered_on_poe_control(hass):
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.poe_client_1"}, blocking=True SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.poe_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 5 assert len(controller.mock_requests) == 7
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2
assert controller.mock_requests[4] == { assert controller.mock_requests[6] == {
"json": { "json": {
"port_overrides": [{"port_idx": 1, "portconf_id": "1a1", "poe_mode": "off"}] "port_overrides": [{"port_idx": 1, "portconf_id": "1a1", "poe_mode": "off"}]
}, },
@ -589,8 +681,8 @@ async def test_new_client_discovered_on_poe_control(hass):
await hass.services.async_call( await hass.services.async_call(
SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.poe_client_1"}, blocking=True SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.poe_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 6 assert len(controller.mock_requests) == 8
assert controller.mock_requests[4] == { assert controller.mock_requests[7] == {
"json": { "json": {
"port_overrides": [ "port_overrides": [
{"port_idx": 1, "portconf_id": "1a1", "poe_mode": "auto"} {"port_idx": 1, "portconf_id": "1a1", "poe_mode": "auto"}
@ -613,7 +705,7 @@ async def test_ignore_multiple_poe_clients_on_same_port(hass):
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 3 assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 3
switch_1 = hass.states.get("switch.poe_client_1") switch_1 = hass.states.get("switch.poe_client_1")
@ -664,7 +756,7 @@ async def test_restoring_client(hass):
clients_all_response=[CLIENT_1], clients_all_response=[CLIENT_1],
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 6
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2
device_1 = hass.states.get("switch.client_1") device_1 = hass.states.get("switch.client_1")