UniFi - Add block network access control to config option (#32004)

* Add block network access control to config option

* Clean up
This commit is contained in:
Paulus Schoutsen 2020-03-04 21:55:56 -08:00 committed by GitHub
parent 1615a5ee81
commit d216c1f2ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 314 additions and 104 deletions

View file

@ -6,7 +6,8 @@
}, },
"error": { "error": {
"faulty_credentials": "Bad user credentials", "faulty_credentials": "Bad user credentials",
"service_unavailable": "No service available" "service_unavailable": "No service available",
"unknown_client_mac": "No client available on that MAC address"
}, },
"step": { "step": {
"user": { "user": {
@ -34,15 +35,26 @@
"track_wired_clients": "Include wired network clients" "track_wired_clients": "Include wired network clients"
}, },
"description": "Configure device tracking", "description": "Configure device tracking",
"title": "UniFi options" "title": "UniFi options 1/3"
},
"client_control": {
"data": {
"block_client": "Network access controlled clients",
"new_client": "Add new client (MAC) for network access control"
},
"description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.",
"title": "UniFi options 2/3"
}, },
"statistics_sensors": { "statistics_sensors": {
"data": { "data": {
"allow_bandwidth_sensors": "Bandwidth usage sensors for network clients" "allow_bandwidth_sensors": "Bandwidth usage sensors for network clients"
}, },
"description": "Configure statistics sensors", "description": "Configure statistics sensors",
"title": "UniFi options" "title": "UniFi options 3/3"
} }
},
"error": {
"unknown_client_mac": "No client available in UniFi on that MAC address"
} }
} }
} }

View file

@ -16,6 +16,7 @@ import homeassistant.helpers.config_validation as cv
from .const import ( from .const import (
CONF_ALLOW_BANDWIDTH_SENSORS, CONF_ALLOW_BANDWIDTH_SENSORS,
CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
CONF_SITE_ID, CONF_SITE_ID,
@ -30,6 +31,7 @@ from .const import (
from .controller import get_controller from .controller import get_controller
from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect
CONF_NEW_CLIENT = "new_client"
DEFAULT_PORT = 8443 DEFAULT_PORT = 8443
DEFAULT_SITE_ID = "default" DEFAULT_SITE_ID = "default"
DEFAULT_VERIFY_SSL = False DEFAULT_VERIFY_SSL = False
@ -171,61 +173,117 @@ class UnifiOptionsFlowHandler(config_entries.OptionsFlow):
"""Initialize UniFi options flow.""" """Initialize UniFi options flow."""
self.config_entry = config_entry self.config_entry = config_entry
self.options = dict(config_entry.options) self.options = dict(config_entry.options)
self.controller = None
async def async_step_init(self, user_input=None): async def async_step_init(self, user_input=None):
"""Manage the UniFi options.""" """Manage the UniFi options."""
self.controller = get_controller_from_config_entry(self.hass, self.config_entry)
self.options[CONF_BLOCK_CLIENT] = self.controller.option_block_clients
return await self.async_step_device_tracker() return await self.async_step_device_tracker()
async def async_step_device_tracker(self, user_input=None): async def async_step_device_tracker(self, user_input=None):
"""Manage the device tracker options.""" """Manage the device tracker options."""
if user_input is not None: if user_input is not None:
self.options.update(user_input) self.options.update(user_input)
return await self.async_step_statistics_sensors() return await self.async_step_client_control()
controller = get_controller_from_config_entry(self.hass, self.config_entry) ssid_filter = {wlan: wlan for wlan in self.controller.api.wlans}
ssid_filter = {wlan: wlan for wlan in controller.api.wlans}
return self.async_show_form( return self.async_show_form(
step_id="device_tracker", step_id="device_tracker",
data_schema=vol.Schema( data_schema=vol.Schema(
{ {
vol.Optional( vol.Optional(
CONF_TRACK_CLIENTS, default=controller.option_track_clients, CONF_TRACK_CLIENTS,
default=self.controller.option_track_clients,
): bool, ): bool,
vol.Optional( vol.Optional(
CONF_TRACK_WIRED_CLIENTS, CONF_TRACK_WIRED_CLIENTS,
default=controller.option_track_wired_clients, default=self.controller.option_track_wired_clients,
): bool, ): bool,
vol.Optional( vol.Optional(
CONF_TRACK_DEVICES, default=controller.option_track_devices, CONF_TRACK_DEVICES,
default=self.controller.option_track_devices,
): bool, ): bool,
vol.Optional( vol.Optional(
CONF_SSID_FILTER, default=controller.option_ssid_filter CONF_SSID_FILTER, default=self.controller.option_ssid_filter
): cv.multi_select(ssid_filter), ): cv.multi_select(ssid_filter),
vol.Optional( vol.Optional(
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
default=int(controller.option_detection_time.total_seconds()), default=int(
self.controller.option_detection_time.total_seconds()
),
): int, ): int,
} }
), ),
) )
async def async_step_client_control(self, user_input=None):
"""Manage configuration of network access controlled clients."""
errors = {}
if user_input is not None:
new_client = user_input.pop(CONF_NEW_CLIENT, None)
self.options.update(user_input)
if new_client:
if (
new_client in self.controller.api.clients
or new_client in self.controller.api.clients_all
):
self.options[CONF_BLOCK_CLIENT].append(new_client)
else:
errors["base"] = "unknown_client_mac"
else:
return await self.async_step_statistics_sensors()
clients_to_block = {}
for mac in self.options[CONF_BLOCK_CLIENT]:
name = None
for clients in [
self.controller.api.clients,
self.controller.api.clients_all,
]:
if mac in clients:
name = f"{clients[mac].name or clients[mac].hostname} ({mac})"
break
if not name:
name = mac
clients_to_block[mac] = name
return self.async_show_form(
step_id="client_control",
data_schema=vol.Schema(
{
vol.Optional(
CONF_BLOCK_CLIENT, default=self.options[CONF_BLOCK_CLIENT]
): cv.multi_select(clients_to_block),
vol.Optional(CONF_NEW_CLIENT): str,
}
),
errors=errors,
)
async def async_step_statistics_sensors(self, user_input=None): async def async_step_statistics_sensors(self, user_input=None):
"""Manage the statistics sensors options.""" """Manage the statistics sensors options."""
if user_input is not None: if user_input is not None:
self.options.update(user_input) self.options.update(user_input)
return await self._update_options() return await self._update_options()
controller = get_controller_from_config_entry(self.hass, self.config_entry)
return self.async_show_form( return self.async_show_form(
step_id="statistics_sensors", step_id="statistics_sensors",
data_schema=vol.Schema( data_schema=vol.Schema(
{ {
vol.Optional( vol.Optional(
CONF_ALLOW_BANDWIDTH_SENSORS, CONF_ALLOW_BANDWIDTH_SENSORS,
default=controller.option_allow_bandwidth_sensors, default=self.controller.option_allow_bandwidth_sensors,
): bool ): bool
} }
), ),

View file

@ -25,11 +25,9 @@ CONF_DONT_TRACK_DEVICES = "dont_track_devices"
CONF_DONT_TRACK_WIRED_CLIENTS = "dont_track_wired_clients" CONF_DONT_TRACK_WIRED_CLIENTS = "dont_track_wired_clients"
DEFAULT_ALLOW_BANDWIDTH_SENSORS = False DEFAULT_ALLOW_BANDWIDTH_SENSORS = False
DEFAULT_BLOCK_CLIENTS = []
DEFAULT_TRACK_CLIENTS = True DEFAULT_TRACK_CLIENTS = True
DEFAULT_TRACK_DEVICES = True DEFAULT_TRACK_DEVICES = True
DEFAULT_TRACK_WIRED_CLIENTS = True DEFAULT_TRACK_WIRED_CLIENTS = True
DEFAULT_DETECTION_TIME = 300 DEFAULT_DETECTION_TIME = 300
DEFAULT_SSID_FILTER = []
ATTR_MANUFACTURER = "Ubiquiti Networks" ATTR_MANUFACTURER = "Ubiquiti Networks"

View file

@ -31,9 +31,7 @@ from .const import (
CONF_TRACK_WIRED_CLIENTS, CONF_TRACK_WIRED_CLIENTS,
CONTROLLER_ID, CONTROLLER_ID,
DEFAULT_ALLOW_BANDWIDTH_SENSORS, DEFAULT_ALLOW_BANDWIDTH_SENSORS,
DEFAULT_BLOCK_CLIENTS,
DEFAULT_DETECTION_TIME, DEFAULT_DETECTION_TIME,
DEFAULT_SSID_FILTER,
DEFAULT_TRACK_CLIENTS, DEFAULT_TRACK_CLIENTS,
DEFAULT_TRACK_DEVICES, DEFAULT_TRACK_DEVICES,
DEFAULT_TRACK_WIRED_CLIENTS, DEFAULT_TRACK_WIRED_CLIENTS,
@ -99,7 +97,7 @@ class UniFiController:
@property @property
def option_block_clients(self): def option_block_clients(self):
"""Config entry option with list of clients to control network access.""" """Config entry option with list of clients to control network access."""
return self.config_entry.options.get(CONF_BLOCK_CLIENT, DEFAULT_BLOCK_CLIENTS) return self.config_entry.options.get(CONF_BLOCK_CLIENT, [])
@property @property
def option_track_clients(self): def option_track_clients(self):
@ -130,7 +128,7 @@ class UniFiController:
@property @property
def option_ssid_filter(self): def option_ssid_filter(self):
"""Config entry option listing what SSIDs are being used to track clients.""" """Config entry option listing what SSIDs are being used to track clients."""
return self.config_entry.options.get(CONF_SSID_FILTER, DEFAULT_SSID_FILTER) return self.config_entry.options.get(CONF_SSID_FILTER, [])
@property @property
def mac(self): def mac(self):

View file

@ -16,7 +16,8 @@
}, },
"error": { "error": {
"faulty_credentials": "Bad user credentials", "faulty_credentials": "Bad user credentials",
"service_unavailable": "No service available" "service_unavailable": "No service available",
"unknown_client_mac": "No client available on that MAC address"
}, },
"abort": { "abort": {
"already_configured": "Controller site is already configured", "already_configured": "Controller site is already configured",
@ -37,15 +38,26 @@
"track_wired_clients": "Include wired network clients" "track_wired_clients": "Include wired network clients"
}, },
"description": "Configure device tracking", "description": "Configure device tracking",
"title": "UniFi options" "title": "UniFi options 1/3"
},
"client_control": {
"data": {
"block_client": "Network access controlled clients",
"new_client": "Add new client for network access control"
},
"description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.",
"title": "UniFi options 2/3"
}, },
"statistics_sensors": { "statistics_sensors": {
"data": { "data": {
"allow_bandwidth_sensors": "Bandwidth usage sensors for network clients" "allow_bandwidth_sensors": "Bandwidth usage sensors for network clients"
}, },
"description": "Configure statistics sensors", "description": "Configure statistics sensors",
"title": "UniFi options" "title": "UniFi options 3/3"
} }
} }
},
"error": {
"unknown_client_mac": "No client available in UniFi on that MAC address"
} }
} }

View file

@ -4,7 +4,6 @@ import logging
from homeassistant.components.switch import SwitchDevice from homeassistant.components.switch import SwitchDevice
from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.components.unifi.config_flow import get_controller_from_config_entry
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import entity_registry
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
@ -30,10 +29,12 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
switches = {} switches = {}
switches_off = [] switches_off = []
registry = await entity_registry.async_get_registry(hass) option_block_clients = controller.option_block_clients
entity_registry = await hass.helpers.entity_registry.async_get_registry()
# Restore clients that is not a part of active clients list. # Restore clients that is not a part of active clients list.
for entity in registry.entities.values(): for entity in entity_registry.entities.values():
if ( if (
entity.config_entry_id == config_entry.entry_id entity.config_entry_id == config_entry.entry_id
@ -61,6 +62,43 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async_dispatcher_connect(hass, controller.signal_update, update_controller) async_dispatcher_connect(hass, controller.signal_update, update_controller)
) )
@callback
def options_updated():
"""Manage entities affected by config entry options."""
nonlocal option_block_clients
update = set()
remove = set()
if option_block_clients != controller.option_block_clients:
option_block_clients = controller.option_block_clients
for block_client_id, entity in switches.items():
if not isinstance(entity, UniFiBlockClientSwitch):
continue
if entity.client.mac in option_block_clients:
update.add(block_client_id)
else:
remove.add(block_client_id)
for block_client_id in remove:
entity = switches.pop(block_client_id)
if entity_registry.async_is_registered(entity.entity_id):
entity_registry.async_remove(entity.entity_id)
hass.async_create_task(entity.async_remove())
if len(update) != len(option_block_clients):
update_controller()
controller.listeners.append(
async_dispatcher_connect(
hass, controller.signal_options_update, options_updated
)
)
update_controller() update_controller()
switches_off.clear() switches_off.clear()
@ -74,15 +112,21 @@ def add_entities(controller, async_add_entities, switches, switches_off):
# block client # block client
for client_id in controller.option_block_clients: for client_id in controller.option_block_clients:
client = None
block_client_id = f"block-{client_id}" block_client_id = f"block-{client_id}"
if block_client_id in switches: if block_client_id in switches:
continue continue
if client_id not in controller.api.clients_all: if client_id in controller.api.clients:
client = controller.api.clients[client_id]
elif client_id in controller.api.clients_all:
client = controller.api.clients_all[client_id]
if not client:
continue continue
client = controller.api.clients_all[client_id]
switches[block_client_id] = UniFiBlockClientSwitch(client, controller) switches[block_client_id] = UniFiBlockClientSwitch(client, controller)
new_switches.append(switches[block_client_id]) new_switches.append(switches[block_client_id])

View file

@ -5,7 +5,18 @@ from asynctest import patch
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.components import unifi from homeassistant.components import unifi
from homeassistant.components.unifi import config_flow from homeassistant.components.unifi import config_flow
from homeassistant.components.unifi.const import CONF_CONTROLLER, CONF_SITE_ID from homeassistant.components.unifi.config_flow import CONF_NEW_CLIENT
from homeassistant.components.unifi.const import (
CONF_ALLOW_BANDWIDTH_SENSORS,
CONF_BLOCK_CLIENT,
CONF_CONTROLLER,
CONF_DETECTION_TIME,
CONF_SITE_ID,
CONF_SSID_FILTER,
CONF_TRACK_CLIENTS,
CONF_TRACK_DEVICES,
CONF_TRACK_WIRED_CLIENTS,
)
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_PASSWORD, CONF_PASSWORD,
@ -18,6 +29,8 @@ from .test_controller import setup_unifi_integration
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
CLIENTS = [{"mac": "00:00:00:00:00:01"}]
WLANS = [{"name": "SSID 1"}, {"name": "SSID 2"}] WLANS = [{"name": "SSID 1"}, {"name": "SSID 2"}]
@ -28,7 +41,7 @@ async def test_flow_works(hass, aioclient_mock, mock_discovery):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
assert result["data_schema"]({CONF_USERNAME: "", CONF_PASSWORD: ""}) == { assert result["data_schema"]({CONF_USERNAME: "", CONF_PASSWORD: ""}) == {
CONF_HOST: "unifi", CONF_HOST: "unifi",
@ -64,7 +77,7 @@ async def test_flow_works(hass, aioclient_mock, mock_discovery):
}, },
) )
assert result["type"] == "create_entry" assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "Site name" assert result["title"] == "Site name"
assert result["data"] == { assert result["data"] == {
CONF_CONTROLLER: { CONF_CONTROLLER: {
@ -84,7 +97,7 @@ async def test_flow_works_multiple_sites(hass, aioclient_mock):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
aioclient_mock.post( aioclient_mock.post(
@ -116,7 +129,7 @@ async def test_flow_works_multiple_sites(hass, aioclient_mock):
}, },
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "site" assert result["step_id"] == "site"
assert result["data_schema"]({"site": "site name"}) assert result["data_schema"]({"site": "site name"})
assert result["data_schema"]({"site": "site2 name"}) assert result["data_schema"]({"site": "site2 name"})
@ -133,7 +146,7 @@ async def test_flow_fails_site_already_configured(hass, aioclient_mock):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
aioclient_mock.post( aioclient_mock.post(
@ -162,7 +175,7 @@ async def test_flow_fails_site_already_configured(hass, aioclient_mock):
}, },
) )
assert result["type"] == "abort" assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock): async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock):
@ -171,7 +184,7 @@ async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.Unauthorized): with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.Unauthorized):
@ -186,7 +199,7 @@ async def test_flow_fails_user_credentials_faulty(hass, aioclient_mock):
}, },
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["errors"] == {"base": "faulty_credentials"} assert result["errors"] == {"base": "faulty_credentials"}
@ -196,7 +209,7 @@ async def test_flow_fails_controller_unavailable(hass, aioclient_mock):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.RequestError): with patch("aiounifi.Controller.login", side_effect=aiounifi.errors.RequestError):
@ -211,7 +224,7 @@ async def test_flow_fails_controller_unavailable(hass, aioclient_mock):
}, },
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["errors"] == {"base": "service_unavailable"} assert result["errors"] == {"base": "service_unavailable"}
@ -221,7 +234,7 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock):
config_flow.DOMAIN, context={"source": "user"} config_flow.DOMAIN, context={"source": "user"}
) )
assert result["type"] == "form" assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
with patch("aiounifi.Controller.login", side_effect=Exception): with patch("aiounifi.Controller.login", side_effect=Exception):
@ -236,12 +249,14 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock):
}, },
) )
assert result["type"] == "abort" assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
async def test_option_flow(hass): async def test_option_flow(hass):
"""Test config flow options.""" """Test config flow options."""
controller = await setup_unifi_integration(hass, wlans_response=WLANS) controller = await setup_unifi_integration(
hass, clients_response=CLIENTS, wlans_response=WLANS
)
result = await hass.config_entries.options.async_init( result = await hass.config_entries.options.async_init(
controller.config_entry.entry_id controller.config_entry.entry_id
@ -253,27 +268,64 @@ async def test_option_flow(hass):
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={
config_flow.CONF_TRACK_CLIENTS: False, CONF_TRACK_CLIENTS: False,
config_flow.CONF_TRACK_WIRED_CLIENTS: False, CONF_TRACK_WIRED_CLIENTS: False,
config_flow.CONF_TRACK_DEVICES: False, CONF_TRACK_DEVICES: False,
config_flow.CONF_SSID_FILTER: ["SSID 1"], CONF_SSID_FILTER: ["SSID 1"],
config_flow.CONF_DETECTION_TIME: 100, CONF_DETECTION_TIME: 100,
}, },
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "client_control"
clients_to_block = hass.config_entries.options._progress[result["flow_id"]].options[
CONF_BLOCK_CLIENT
]
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
CONF_BLOCK_CLIENT: clients_to_block,
CONF_NEW_CLIENT: "00:00:00:00:00:01",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "client_control"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
CONF_BLOCK_CLIENT: clients_to_block,
CONF_NEW_CLIENT: "00:00:00:00:00:02",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "client_control"
assert result["errors"] == {"base": "unknown_client_mac"}
clients_to_block = hass.config_entries.options._progress[result["flow_id"]].options[
CONF_BLOCK_CLIENT
]
result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_BLOCK_CLIENT: clients_to_block},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "statistics_sensors" assert result["step_id"] == "statistics_sensors"
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={config_flow.CONF_ALLOW_BANDWIDTH_SENSORS: True} result["flow_id"], user_input={CONF_ALLOW_BANDWIDTH_SENSORS: True}
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"] == { assert result["data"] == {
config_flow.CONF_TRACK_CLIENTS: False, CONF_TRACK_CLIENTS: False,
config_flow.CONF_TRACK_WIRED_CLIENTS: False, CONF_TRACK_WIRED_CLIENTS: False,
config_flow.CONF_TRACK_DEVICES: False, CONF_TRACK_DEVICES: False,
config_flow.CONF_DETECTION_TIME: 100, CONF_DETECTION_TIME: 100,
config_flow.CONF_SSID_FILTER: ["SSID 1"], CONF_SSID_FILTER: ["SSID 1"],
config_flow.CONF_ALLOW_BANDWIDTH_SENSORS: True, CONF_BLOCK_CLIENT: ["00:00:00:00:00:01"],
CONF_ALLOW_BANDWIDTH_SENSORS: True,
} }

View file

@ -166,7 +166,7 @@ async def test_controller_setup(hass):
controller.option_allow_bandwidth_sensors controller.option_allow_bandwidth_sensors
== unifi.const.DEFAULT_ALLOW_BANDWIDTH_SENSORS == unifi.const.DEFAULT_ALLOW_BANDWIDTH_SENSORS
) )
assert controller.option_block_clients == unifi.const.DEFAULT_BLOCK_CLIENTS assert isinstance(controller.option_block_clients, list)
assert controller.option_track_clients == unifi.const.DEFAULT_TRACK_CLIENTS assert controller.option_track_clients == unifi.const.DEFAULT_TRACK_CLIENTS
assert controller.option_track_devices == unifi.const.DEFAULT_TRACK_DEVICES assert controller.option_track_devices == unifi.const.DEFAULT_TRACK_DEVICES
assert ( assert (
@ -175,7 +175,7 @@ async def test_controller_setup(hass):
assert controller.option_detection_time == timedelta( assert controller.option_detection_time == timedelta(
seconds=unifi.const.DEFAULT_DETECTION_TIME seconds=unifi.const.DEFAULT_DETECTION_TIME
) )
assert controller.option_ssid_filter == unifi.const.DEFAULT_SSID_FILTER assert isinstance(controller.option_ssid_filter, list)
assert controller.mac is None assert controller.mac is None
@ -235,7 +235,7 @@ async def test_reset_after_successful_setup(hass):
"""Calling reset when the entry has been setup.""" """Calling reset when the entry has been setup."""
controller = await setup_unifi_integration(hass) controller = await setup_unifi_integration(hass)
assert len(controller.listeners) == 5 assert len(controller.listeners) == 6
result = await controller.async_reset() result = await controller.async_reset()
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -4,6 +4,11 @@ from copy import deepcopy
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import unifi from homeassistant.components import unifi
import homeassistant.components.switch as switch import homeassistant.components.switch as switch
from homeassistant.components.unifi.const import (
CONF_BLOCK_CLIENT,
CONF_TRACK_CLIENTS,
CONF_TRACK_DEVICES,
)
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -200,11 +205,7 @@ async def test_platform_manually_configured(hass):
async def test_no_clients(hass): async def test_no_clients(hass):
"""Test the update_clients function when no clients are found.""" """Test the update_clients function when no clients are found."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass, options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False},
options={
unifi.const.CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False,
},
) )
assert len(controller.mock_requests) == 4 assert len(controller.mock_requests) == 4
@ -215,10 +216,7 @@ async def test_controller_not_client(hass):
"""Test that the controller doesn't become a switch.""" """Test that the controller doesn't become a switch."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={ options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False},
unifi.const.CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False,
},
clients_response=[CONTROLLER_HOST], clients_response=[CONTROLLER_HOST],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
@ -235,10 +233,7 @@ async def test_not_admin(hass):
sites["Site name"]["role"] = "not admin" sites["Site name"]["role"] = "not admin"
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={ options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False},
unifi.const.CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False,
},
sites=sites, sites=sites,
clients_response=[CLIENT_1], clients_response=[CLIENT_1],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
@ -253,9 +248,9 @@ async def test_switches(hass):
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={ options={
unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]], CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]],
unifi.const.CONF_TRACK_CLIENTS: False, CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False, CONF_TRACK_DEVICES: False,
}, },
clients_response=[CLIENT_1, CLIENT_4], clients_response=[CLIENT_1, CLIENT_4],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
@ -284,34 +279,10 @@ async def test_switches(hass):
assert unblocked is not None assert unblocked is not None
assert unblocked.state == "on" assert unblocked.state == "on"
async def test_new_client_discovered_on_block_control(hass):
"""Test if 2nd update has a new client."""
controller = await setup_unifi_integration(
hass,
options={
unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"]],
unifi.const.CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False,
},
clients_all_response=[BLOCKED],
)
assert len(controller.mock_requests) == 4
assert len(hass.states.async_all()) == 2
controller.api.websocket._data = {
"meta": {"message": "sta:sync"},
"data": [BLOCKED],
}
controller.api.session_handler("data")
# Calling a service will trigger the updates to run
await hass.services.async_call( await hass.services.async_call(
"switch", "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True "switch", "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True
) )
assert len(controller.mock_requests) == 5 assert len(controller.mock_requests) == 5
assert len(hass.states.async_all()) == 2
assert controller.mock_requests[4] == { assert controller.mock_requests[4] == {
"json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"},
"method": "post", "method": "post",
@ -329,14 +300,79 @@ async def test_new_client_discovered_on_block_control(hass):
} }
async def test_new_client_discovered_on_poe_control(hass): async def test_new_client_discovered_on_block_control(hass):
"""Test if 2nd update has a new client.""" """Test if 2nd update has a new client."""
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={ options={
unifi.const.CONF_TRACK_CLIENTS: False, CONF_BLOCK_CLIENT: [BLOCKED["mac"]],
unifi.const.CONF_TRACK_DEVICES: False, CONF_TRACK_CLIENTS: False,
CONF_TRACK_DEVICES: False,
}, },
)
assert len(controller.mock_requests) == 4
assert len(hass.states.async_all()) == 1
blocked = hass.states.get("switch.block_client_1")
assert blocked is None
controller.api.websocket._data = {
"meta": {"message": "sta:sync"},
"data": [BLOCKED],
}
controller.api.session_handler("data")
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 2
blocked = hass.states.get("switch.block_client_1")
assert blocked is not None
async def test_option_block_clients(hass):
"""Test the changes to option reflects accordingly."""
controller = await setup_unifi_integration(
hass,
options={CONF_BLOCK_CLIENT: [BLOCKED["mac"]]},
clients_all_response=[BLOCKED, UNBLOCKED],
)
assert len(hass.states.async_all()) == 2
# Add a second switch
hass.config_entries.async_update_entry(
controller.config_entry,
options={CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]]},
)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 3
# Remove the second switch again
hass.config_entries.async_update_entry(
controller.config_entry, options={CONF_BLOCK_CLIENT: [BLOCKED["mac"]]},
)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 2
# Enable one and remove another one
hass.config_entries.async_update_entry(
controller.config_entry, options={CONF_BLOCK_CLIENT: [UNBLOCKED["mac"]]},
)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 2
# Remove one
hass.config_entries.async_update_entry(
controller.config_entry, options={CONF_BLOCK_CLIENT: []},
)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 1
async def test_new_client_discovered_on_poe_control(hass):
"""Test if 2nd update has a new client."""
controller = await setup_unifi_integration(
hass,
options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False},
clients_response=[CLIENT_1], clients_response=[CLIENT_1],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
@ -435,9 +471,9 @@ async def test_restoring_client(hass):
controller = await setup_unifi_integration( controller = await setup_unifi_integration(
hass, hass,
options={ options={
unifi.CONF_BLOCK_CLIENT: ["random mac"], CONF_BLOCK_CLIENT: ["random mac"],
unifi.const.CONF_TRACK_CLIENTS: False, CONF_TRACK_CLIENTS: False,
unifi.const.CONF_TRACK_DEVICES: False, CONF_TRACK_DEVICES: False,
}, },
clients_response=[CLIENT_2], clients_response=[CLIENT_2],
devices_response=[DEVICE_1], devices_response=[DEVICE_1],