UniFi config entry options (#26113)

Introduce config entry options for Unifi integration
Allow configuration.yaml options to be imported to new options
This commit is contained in:
Robert Svensson 2019-08-21 22:22:42 +02:00 committed by GitHub
parent 7ab36e0381
commit 588eac82c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 250 additions and 75 deletions

View file

@ -1,26 +1,41 @@
{ {
"config": { "config": {
"abort": { "title": "UniFi Controller",
"already_configured": "Controller site is already configured", "step": {
"user_privilege": "User needs to be administrator" "user": {
"title": "Set up UniFi Controller",
"data": {
"host": "Host",
"username": "User name",
"password": "Password",
"port": "Port",
"site": "Site ID",
"verify_ssl": "Controller using proper certificate"
}
}
}, },
"error": { "error": {
"faulty_credentials": "Bad user credentials", "faulty_credentials": "Bad user credentials",
"service_unavailable": "No service available" "service_unavailable": "No service available"
}, },
"abort": {
"already_configured": "Controller site is already configured",
"user_privilege": "User needs to be administrator"
}
},
"options": {
"step": { "step": {
"user": { "init": {
"data": {}
},
"device_tracker": {
"data": { "data": {
"host": "Host", "detection_time": "Time in seconds from last seen until considered away",
"password": "Password", "track_clients": "Track network clients",
"port": "Port", "track_devices": "Track network devices (Ubiquiti devices)",
"site": "Site ID", "track_wired_clients": "Include wired network clients"
"username": "User name", }
"verify_ssl": "Controller using proper certificate"
},
"title": "Set up UniFi Controller"
} }
}, }
"title": "UniFi Controller"
} }
} }

View file

@ -11,9 +11,6 @@ from .const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME, CONF_DETECTION_TIME,
CONF_DONT_TRACK_CLIENTS,
CONF_DONT_TRACK_DEVICES,
CONF_DONT_TRACK_WIRED_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
CONF_SSID_FILTER, CONF_SSID_FILTER,
CONTROLLER_ID, CONTROLLER_ID,
@ -23,6 +20,9 @@ from .const import (
from .controller import UniFiController from .controller import UniFiController
CONF_CONTROLLERS = "controllers" CONF_CONTROLLERS = "controllers"
CONF_DONT_TRACK_CLIENTS = "dont_track_clients"
CONF_DONT_TRACK_DEVICES = "dont_track_devices"
CONF_DONT_TRACK_WIRED_CLIENTS = "dont_track_wired_clients"
CONTROLLER_SCHEMA = vol.Schema( CONTROLLER_SCHEMA = vol.Schema(
{ {
@ -34,9 +34,7 @@ CONTROLLER_SCHEMA = vol.Schema(
vol.Optional(CONF_DONT_TRACK_CLIENTS): cv.boolean, vol.Optional(CONF_DONT_TRACK_CLIENTS): cv.boolean,
vol.Optional(CONF_DONT_TRACK_DEVICES): cv.boolean, vol.Optional(CONF_DONT_TRACK_DEVICES): cv.boolean,
vol.Optional(CONF_DONT_TRACK_WIRED_CLIENTS): cv.boolean, vol.Optional(CONF_DONT_TRACK_WIRED_CLIENTS): cv.boolean,
vol.Optional(CONF_DETECTION_TIME): vol.All( vol.Optional(CONF_DETECTION_TIME): cv.positive_int,
cv.time_period, cv.positive_timedelta
),
vol.Optional(CONF_SSID_FILTER): vol.All(cv.ensure_list, [cv.string]), vol.Optional(CONF_SSID_FILTER): vol.All(cv.ensure_list, [cv.string]),
} }
) )

View file

@ -2,6 +2,7 @@
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_PASSWORD, CONF_PASSWORD,
@ -10,7 +11,20 @@ from homeassistant.const import (
CONF_VERIFY_SSL, CONF_VERIFY_SSL,
) )
from .const import CONF_CONTROLLER, CONF_SITE_ID, DOMAIN, LOGGER from .const import (
CONF_CONTROLLER,
CONF_TRACK_CLIENTS,
CONF_TRACK_DEVICES,
CONF_TRACK_WIRED_CLIENTS,
CONF_DETECTION_TIME,
CONF_SITE_ID,
DEFAULT_TRACK_CLIENTS,
DEFAULT_TRACK_DEVICES,
DEFAULT_TRACK_WIRED_CLIENTS,
DEFAULT_DETECTION_TIME,
DOMAIN,
LOGGER,
)
from .controller import get_controller from .controller import get_controller
from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect
@ -26,6 +40,12 @@ class UnifiFlowHandler(config_entries.ConfigFlow):
VERSION = 1 VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL
@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return UnifiOptionsFlowHandler(config_entry)
def __init__(self): def __init__(self):
"""Initialize the UniFi flow.""" """Initialize the UniFi flow."""
self.config = None self.config = None
@ -142,3 +162,52 @@ class UnifiFlowHandler(config_entries.ConfigFlow):
self.desc = import_config[CONF_SITE_ID] self.desc = import_config[CONF_SITE_ID]
return await self.async_step_user(user_input=config) return await self.async_step_user(user_input=config)
class UnifiOptionsFlowHandler(config_entries.OptionsFlow):
"""Handle Unifi options."""
def __init__(self, config_entry):
"""Initialize UniFi options flow."""
self.config_entry = config_entry
async def async_step_init(self, user_input=None):
"""Manage the UniFi options."""
return await self.async_step_device_tracker()
async def async_step_device_tracker(self, user_input=None):
"""Manage the device tracker options."""
if user_input is not None:
return self.async_create_entry(title="", data=user_input)
return self.async_show_form(
step_id="device_tracker",
data_schema=vol.Schema(
{
vol.Optional(
CONF_TRACK_CLIENTS,
default=self.config_entry.options.get(
CONF_TRACK_CLIENTS, DEFAULT_TRACK_CLIENTS
),
): bool,
vol.Optional(
CONF_TRACK_WIRED_CLIENTS,
default=self.config_entry.options.get(
CONF_TRACK_WIRED_CLIENTS, DEFAULT_TRACK_WIRED_CLIENTS
),
): bool,
vol.Optional(
CONF_TRACK_DEVICES,
default=self.config_entry.options.get(
CONF_TRACK_DEVICES, DEFAULT_TRACK_DEVICES
),
): bool,
vol.Optional(
CONF_DETECTION_TIME,
default=self.config_entry.options.get(
CONF_DETECTION_TIME, DEFAULT_DETECTION_TIME
),
): int,
}
),
)

View file

@ -13,9 +13,16 @@ UNIFI_CONFIG = "unifi_config"
CONF_BLOCK_CLIENT = "block_client" CONF_BLOCK_CLIENT = "block_client"
CONF_DETECTION_TIME = "detection_time" CONF_DETECTION_TIME = "detection_time"
CONF_DONT_TRACK_CLIENTS = "dont_track_clients" CONF_TRACK_CLIENTS = "track_clients"
CONF_DONT_TRACK_DEVICES = "dont_track_devices" CONF_TRACK_DEVICES = "track_devices"
CONF_DONT_TRACK_WIRED_CLIENTS = "dont_track_wired_clients" CONF_TRACK_WIRED_CLIENTS = "track_wired_clients"
CONF_SSID_FILTER = "ssid_filter" CONF_SSID_FILTER = "ssid_filter"
DEFAULT_BLOCK_CLIENTS = []
DEFAULT_TRACK_CLIENTS = True
DEFAULT_TRACK_DEVICES = True
DEFAULT_TRACK_WIRED_CLIENTS = True
DEFAULT_DETECTION_TIME = 300
DEFAULT_SSID_FILTER = []
ATTR_MANUFACTURER = "Ubiquiti Networks" ATTR_MANUFACTURER = "Ubiquiti Networks"

View file

@ -1,4 +1,6 @@
"""UniFi Controller abstraction.""" """UniFi Controller abstraction."""
from datetime import timedelta
import asyncio import asyncio
import ssl import ssl
import async_timeout import async_timeout
@ -15,8 +17,19 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from .const import ( from .const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME,
CONF_TRACK_CLIENTS,
CONF_TRACK_DEVICES,
CONF_TRACK_WIRED_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
CONF_SSID_FILTER,
CONTROLLER_ID, CONTROLLER_ID,
DEFAULT_BLOCK_CLIENTS,
DEFAULT_TRACK_CLIENTS,
DEFAULT_TRACK_DEVICES,
DEFAULT_TRACK_WIRED_CLIENTS,
DEFAULT_DETECTION_TIME,
DEFAULT_SSID_FILTER,
LOGGER, LOGGER,
UNIFI_CONFIG, UNIFI_CONFIG,
) )
@ -59,9 +72,40 @@ class UniFiController:
return self._site_role return self._site_role
@property @property
def block_clients(self): def option_block_clients(self):
"""Return list of clients to block.""" """Config entry option with list of clients to control network access."""
return self.unifi_config.get(CONF_BLOCK_CLIENT, []) return self.config_entry.options.get(CONF_BLOCK_CLIENT, DEFAULT_BLOCK_CLIENTS)
@property
def option_track_clients(self):
"""Config entry option to not track clients."""
return self.config_entry.options.get(CONF_TRACK_CLIENTS, DEFAULT_TRACK_CLIENTS)
@property
def option_track_devices(self):
"""Config entry option to not track devices."""
return self.config_entry.options.get(CONF_TRACK_DEVICES, DEFAULT_TRACK_DEVICES)
@property
def option_track_wired_clients(self):
"""Config entry option to not track wired clients."""
return self.config_entry.options.get(
CONF_TRACK_WIRED_CLIENTS, DEFAULT_TRACK_WIRED_CLIENTS
)
@property
def option_detection_time(self):
"""Config entry option defining number of seconds from last seen to away."""
return timedelta(
seconds=self.config_entry.options.get(
CONF_DETECTION_TIME, DEFAULT_DETECTION_TIME
)
)
@property
def option_ssid_filter(self):
"""Config entry option listing what SSIDs are being used to track clients."""
return self.config_entry.options.get(CONF_SSID_FILTER, DEFAULT_SSID_FILTER)
@property @property
def mac(self): def mac(self):
@ -96,7 +140,7 @@ class UniFiController:
with async_timeout.timeout(10): with async_timeout.timeout(10):
await self.api.clients.update() await self.api.clients.update()
await self.api.devices.update() await self.api.devices.update()
if self.block_clients: if self.option_block_clients:
await self.api.clients_all.update() await self.api.clients_all.update()
except aiounifi.LoginRequired: except aiounifi.LoginRequired:
@ -155,6 +199,30 @@ class UniFiController:
self.unifi_config = unifi_config self.unifi_config = unifi_config
break break
options = dict(self.config_entry.options)
if CONF_BLOCK_CLIENT in self.unifi_config:
options[CONF_BLOCK_CLIENT] = self.unifi_config[CONF_BLOCK_CLIENT]
if CONF_TRACK_CLIENTS in self.unifi_config:
options[CONF_TRACK_CLIENTS] = self.unifi_config[CONF_TRACK_CLIENTS]
if CONF_TRACK_DEVICES in self.unifi_config:
options[CONF_TRACK_DEVICES] = self.unifi_config[CONF_TRACK_DEVICES]
if CONF_TRACK_WIRED_CLIENTS in self.unifi_config:
options[CONF_TRACK_WIRED_CLIENTS] = self.unifi_config[
CONF_TRACK_WIRED_CLIENTS
]
if CONF_DETECTION_TIME in self.unifi_config:
options[CONF_DETECTION_TIME] = self.unifi_config[CONF_DETECTION_TIME]
if CONF_SSID_FILTER in self.unifi_config:
options[CONF_SSID_FILTER] = self.unifi_config[CONF_SSID_FILTER]
hass.config_entries.async_update_entry(self.config_entry, options=options)
for platform in ["device_tracker", "switch"]: for platform in ["device_tracker", "switch"]:
hass.async_create_task( hass.async_create_task(
hass.config_entries.async_forward_entry_setup( hass.config_entries.async_forward_entry_setup(

View file

@ -27,12 +27,7 @@ import homeassistant.util.dt as dt_util
from .const import ( from .const import (
ATTR_MANUFACTURER, ATTR_MANUFACTURER,
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_DETECTION_TIME,
CONF_DONT_TRACK_CLIENTS,
CONF_DONT_TRACK_DEVICES,
CONF_DONT_TRACK_WIRED_CLIENTS,
CONF_SITE_ID, CONF_SITE_ID,
CONF_SSID_FILTER,
CONTROLLER_ID, CONTROLLER_ID,
DOMAIN as UNIFI_DOMAIN, DOMAIN as UNIFI_DOMAIN,
) )
@ -151,11 +146,11 @@ def update_items(controller, async_add_entities, tracked):
"""Update tracked device state from the controller.""" """Update tracked device state from the controller."""
new_tracked = [] new_tracked = []
if not controller.unifi_config.get(CONF_DONT_TRACK_CLIENTS, False): if controller.option_track_clients:
for client_id in controller.api.clients: for client_id in controller.api.clients:
if client_id in tracked: if client_id in tracked and tracked[client_id].entity_id:
LOGGER.debug( LOGGER.debug(
"Updating UniFi tracked client %s (%s)", "Updating UniFi tracked client %s (%s)",
tracked[client_id].entity_id, tracked[client_id].entity_id,
@ -168,15 +163,12 @@ def update_items(controller, async_add_entities, tracked):
if ( if (
not client.is_wired not client.is_wired
and CONF_SSID_FILTER in controller.unifi_config and controller.option_ssid_filter
and client.essid not in controller.unifi_config[CONF_SSID_FILTER] and client.essid not in controller.option_ssid_filter
): ):
continue continue
if ( if not controller.option_track_wired_clients and client.is_wired:
controller.unifi_config.get(CONF_DONT_TRACK_WIRED_CLIENTS, False)
and client.is_wired
):
continue continue
tracked[client_id] = UniFiClientTracker(client, controller) tracked[client_id] = UniFiClientTracker(client, controller)
@ -187,11 +179,11 @@ def update_items(controller, async_add_entities, tracked):
client.mac, client.mac,
) )
if not controller.unifi_config.get(CONF_DONT_TRACK_DEVICES, False): if controller.option_track_devices:
for device_id in controller.api.devices: for device_id in controller.api.devices:
if device_id in tracked: if device_id in tracked and tracked[device_id].entity_id:
LOGGER.debug( LOGGER.debug(
"Updating UniFi tracked device %s (%s)", "Updating UniFi tracked device %s (%s)",
tracked[device_id].entity_id, tracked[device_id].entity_id,
@ -229,14 +221,11 @@ class UniFiClientTracker(ScannerEntity):
@property @property
def is_connected(self): def is_connected(self):
"""Return true if the client is connected to the network.""" """Return true if the client is connected to the network."""
detection_time = self.controller.unifi_config.get(
CONF_DETECTION_TIME, DEFAULT_DETECTION_TIME
)
if ( if (
dt_util.utcnow() - dt_util.utc_from_timestamp(float(self.client.last_seen)) dt_util.utcnow() - dt_util.utc_from_timestamp(float(self.client.last_seen))
) < detection_time: ) < self.controller.option_detection_time:
return True return True
return False return False
@property @property
@ -291,15 +280,12 @@ class UniFiDeviceTracker(ScannerEntity):
@property @property
def is_connected(self): def is_connected(self):
"""Return true if the device is connected to the network.""" """Return true if the device is connected to the network."""
detection_time = self.controller.unifi_config.get(
CONF_DETECTION_TIME, DEFAULT_DETECTION_TIME
)
if self.device.state == 1 and ( if self.device.state == 1 and (
dt_util.utcnow() - dt_util.utc_from_timestamp(float(self.device.last_seen)) dt_util.utcnow() - dt_util.utc_from_timestamp(float(self.device.last_seen))
< detection_time < self.controller.option_detection_time
): ):
return True return True
return False return False
@property @property

View file

@ -22,5 +22,20 @@
"already_configured": "Controller site is already configured", "already_configured": "Controller site is already configured",
"user_privilege": "User needs to be administrator" "user_privilege": "User needs to be administrator"
} }
},
"options": {
"step": {
"init": {
"data": {}
},
"device_tracker": {
"data": {
"detection_time": "Time in seconds from last seen until considered away",
"track_clients": "Track network clients",
"track_devices": "Track network devices (Ubiquiti devices)",
"track_wired_clients": "Include wired network clients"
}
}
}
} }
} }

View file

@ -74,7 +74,7 @@ def update_items(controller, async_add_entities, switches, switches_off):
devices = controller.api.devices devices = controller.api.devices
# block client # block client
for client_id in controller.block_clients: for client_id in controller.option_block_clients:
block_client_id = "block-{}".format(client_id) block_client_id = "block-{}".format(client_id)

View file

@ -37,9 +37,23 @@ ENTRY_CONFIG = {CONF_CONTROLLER: CONTROLLER_DATA}
async def test_controller_setup(): async def test_controller_setup():
"""Successful setup.""" """Successful setup."""
hass = Mock() hass = Mock()
hass.data = {UNIFI_CONFIG: {}} hass.data = {
UNIFI_CONFIG: [
{
CONF_HOST: CONTROLLER_DATA[CONF_HOST],
CONF_SITE_ID: "nice name",
controller.CONF_BLOCK_CLIENT: [],
controller.CONF_TRACK_CLIENTS: True,
controller.CONF_TRACK_DEVICES: True,
controller.CONF_TRACK_WIRED_CLIENTS: True,
controller.CONF_DETECTION_TIME: 300,
controller.CONF_SSID_FILTER: [],
}
]
}
entry = Mock() entry = Mock()
entry.data = ENTRY_CONFIG entry.data = ENTRY_CONFIG
entry.options = {}
api = Mock() api = Mock()
api.initialize.return_value = mock_coro(True) api.initialize.return_value = mock_coro(True)
api.sites.return_value = mock_coro(CONTROLLER_SITES) api.sites.return_value = mock_coro(CONTROLLER_SITES)
@ -89,6 +103,7 @@ async def test_controller_mac():
hass.data = {UNIFI_CONFIG: {}} hass.data = {UNIFI_CONFIG: {}}
entry = Mock() entry = Mock()
entry.data = ENTRY_CONFIG entry.data = ENTRY_CONFIG
entry.options = {}
client = Mock() client = Mock()
client.ip = "1.2.3.4" client.ip = "1.2.3.4"
client.mac = "00:11:22:33:44:55" client.mac = "00:11:22:33:44:55"
@ -111,6 +126,7 @@ async def test_controller_no_mac():
hass.data = {UNIFI_CONFIG: {}} hass.data = {UNIFI_CONFIG: {}}
entry = Mock() entry = Mock()
entry.data = ENTRY_CONFIG entry.data = ENTRY_CONFIG
entry.options = {}
client = Mock() client = Mock()
client.ip = "5.6.7.8" client.ip = "5.6.7.8"
api = Mock() api = Mock()
@ -182,6 +198,7 @@ async def test_reset_unloads_entry_if_setup():
hass.data = {UNIFI_CONFIG: {}} hass.data = {UNIFI_CONFIG: {}}
entry = Mock() entry = Mock()
entry.data = ENTRY_CONFIG entry.data = ENTRY_CONFIG
entry.options = {}
api = Mock() api = Mock()
api.initialize.return_value = mock_coro(True) api.initialize.return_value = mock_coro(True)
api.sites.return_value = mock_coro(CONTROLLER_SITES) api.sites.return_value = mock_coro(CONTROLLER_SITES)

View file

@ -14,6 +14,7 @@ from homeassistant.components import unifi
from homeassistant.components.unifi.const import ( from homeassistant.components.unifi.const import (
CONF_CONTROLLER, CONF_CONTROLLER,
CONF_SITE_ID, CONF_SITE_ID,
CONF_SSID_FILTER,
UNIFI_CONFIG, UNIFI_CONFIG,
) )
from homeassistant.const import ( from homeassistant.const import (
@ -133,7 +134,7 @@ def mock_controller(hass):
return controller return controller
async def setup_controller(hass, mock_controller): async def setup_controller(hass, mock_controller, options={}):
"""Load the UniFi switch platform with the provided controller.""" """Load the UniFi switch platform with the provided controller."""
hass.config.components.add(unifi.DOMAIN) hass.config.components.add(unifi.DOMAIN)
hass.data[unifi.DOMAIN] = {CONTROLLER_ID: mock_controller} hass.data[unifi.DOMAIN] = {CONTROLLER_ID: mock_controller}
@ -146,6 +147,7 @@ async def setup_controller(hass, mock_controller):
config_entries.CONN_CLASS_LOCAL_POLL, config_entries.CONN_CLASS_LOCAL_POLL,
entry_id=1, entry_id=1,
system_options={}, system_options={},
options=options,
) )
mock_controller.config_entry = config_entry mock_controller.config_entry = config_entry
@ -182,9 +184,9 @@ async def test_tracked_devices(hass, mock_controller):
"""Test the update_items function with some clients.""" """Test the update_items function with some clients."""
mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_2, CLIENT_3]) mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_2, CLIENT_3])
mock_controller.mock_device_responses.append([DEVICE_1, DEVICE_2]) mock_controller.mock_device_responses.append([DEVICE_1, DEVICE_2])
mock_controller.unifi_config = {unifi_dt.CONF_SSID_FILTER: ["ssid"]} options = {CONF_SSID_FILTER: ["ssid"]}
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 2 assert len(mock_controller.mock_requests) == 2
assert len(hass.states.async_all()) == 5 assert len(hass.states.async_all()) == 5
@ -234,7 +236,7 @@ async def test_restoring_client(hass, mock_controller):
mock_controller.mock_client_responses.append([CLIENT_2]) mock_controller.mock_client_responses.append([CLIENT_2])
mock_controller.mock_device_responses.append({}) mock_controller.mock_device_responses.append({})
mock_controller.mock_client_all_responses.append([CLIENT_1]) mock_controller.mock_client_all_responses.append([CLIENT_1])
mock_controller.unifi_config = {unifi.CONF_BLOCK_CLIENT: True} options = {unifi.CONF_BLOCK_CLIENT: True}
config_entry = config_entries.ConfigEntry( config_entry = config_entries.ConfigEntry(
1, 1,
@ -263,7 +265,7 @@ async def test_restoring_client(hass, mock_controller):
config_entry=config_entry, config_entry=config_entry,
) )
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 3 assert len(mock_controller.mock_requests) == 3
assert len(hass.states.async_all()) == 4 assert len(hass.states.async_all()) == 4
@ -275,9 +277,9 @@ async def test_dont_track_clients(hass, mock_controller):
"""Test dont track clients config works.""" """Test dont track clients config works."""
mock_controller.mock_client_responses.append([CLIENT_1]) mock_controller.mock_client_responses.append([CLIENT_1])
mock_controller.mock_device_responses.append([DEVICE_1]) mock_controller.mock_device_responses.append([DEVICE_1])
mock_controller.unifi_config = {unifi.CONF_DONT_TRACK_CLIENTS: True} options = {unifi.controller.CONF_TRACK_CLIENTS: False}
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 2 assert len(mock_controller.mock_requests) == 2
assert len(hass.states.async_all()) == 3 assert len(hass.states.async_all()) == 3
@ -293,9 +295,9 @@ async def test_dont_track_devices(hass, mock_controller):
"""Test dont track devices config works.""" """Test dont track devices config works."""
mock_controller.mock_client_responses.append([CLIENT_1]) mock_controller.mock_client_responses.append([CLIENT_1])
mock_controller.mock_device_responses.append([DEVICE_1]) mock_controller.mock_device_responses.append([DEVICE_1])
mock_controller.unifi_config = {unifi.CONF_DONT_TRACK_DEVICES: True} options = {unifi.controller.CONF_TRACK_DEVICES: False}
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 2 assert len(mock_controller.mock_requests) == 2
assert len(hass.states.async_all()) == 3 assert len(hass.states.async_all()) == 3
@ -311,9 +313,9 @@ async def test_dont_track_wired_clients(hass, mock_controller):
"""Test dont track wired clients config works.""" """Test dont track wired clients config works."""
mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_2]) mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_2])
mock_controller.mock_device_responses.append({}) mock_controller.mock_device_responses.append({})
mock_controller.unifi_config = {unifi.CONF_DONT_TRACK_WIRED_CLIENTS: True} options = {unifi.controller.CONF_TRACK_WIRED_CLIENTS: False}
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 2 assert len(mock_controller.mock_requests) == 2
assert len(hass.states.async_all()) == 3 assert len(hass.states.async_all()) == 3

View file

@ -1,5 +1,4 @@
"""Test UniFi setup process.""" """Test UniFi setup process."""
from datetime import timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from homeassistant.components import unifi from homeassistant.components import unifi
@ -44,7 +43,7 @@ async def test_setup_with_config(hass):
unifi.CONF_HOST: "1.2.3.4", unifi.CONF_HOST: "1.2.3.4",
unifi.CONF_SITE_ID: "My site", unifi.CONF_SITE_ID: "My site",
unifi.CONF_BLOCK_CLIENT: ["12:34:56:78:90:AB"], unifi.CONF_BLOCK_CLIENT: ["12:34:56:78:90:AB"],
unifi.CONF_DETECTION_TIME: timedelta(seconds=3), unifi.CONF_DETECTION_TIME: 3,
unifi.CONF_SSID_FILTER: ["ssid"], unifi.CONF_SSID_FILTER: ["ssid"],
} }
] ]

View file

@ -250,7 +250,7 @@ def mock_controller(hass):
return controller return controller
async def setup_controller(hass, mock_controller): async def setup_controller(hass, mock_controller, options={}):
"""Load the UniFi switch platform with the provided controller.""" """Load the UniFi switch platform with the provided controller."""
hass.config.components.add(unifi.DOMAIN) hass.config.components.add(unifi.DOMAIN)
hass.data[unifi.DOMAIN] = {CONTROLLER_ID: mock_controller} hass.data[unifi.DOMAIN] = {CONTROLLER_ID: mock_controller}
@ -263,6 +263,7 @@ async def setup_controller(hass, mock_controller):
config_entries.CONN_CLASS_LOCAL_POLL, config_entries.CONN_CLASS_LOCAL_POLL,
entry_id=1, entry_id=1,
system_options={}, system_options={},
options=options,
) )
mock_controller.config_entry = config_entry mock_controller.config_entry = config_entry
@ -320,11 +321,9 @@ async def test_switches(hass, mock_controller):
mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_4]) mock_controller.mock_client_responses.append([CLIENT_1, CLIENT_4])
mock_controller.mock_device_responses.append([DEVICE_1]) mock_controller.mock_device_responses.append([DEVICE_1])
mock_controller.mock_client_all_responses.append([BLOCKED, UNBLOCKED, CLIENT_1]) mock_controller.mock_client_all_responses.append([BLOCKED, UNBLOCKED, CLIENT_1])
mock_controller.unifi_config = { options = {unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]]}
unifi.CONF_BLOCK_CLIENT: [BLOCKED["mac"], UNBLOCKED["mac"]]
}
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 3 assert len(mock_controller.mock_requests) == 3
assert len(hass.states.async_all()) == 5 assert len(hass.states.async_all()) == 5
@ -467,7 +466,7 @@ async def test_restoring_client(hass, mock_controller):
mock_controller.mock_client_responses.append([CLIENT_2]) mock_controller.mock_client_responses.append([CLIENT_2])
mock_controller.mock_device_responses.append([DEVICE_1]) mock_controller.mock_device_responses.append([DEVICE_1])
mock_controller.mock_client_all_responses.append([CLIENT_1]) mock_controller.mock_client_all_responses.append([CLIENT_1])
mock_controller.unifi_config = {unifi.CONF_BLOCK_CLIENT: ["random mac"]} options = {unifi.CONF_BLOCK_CLIENT: ["random mac"]}
config_entry = config_entries.ConfigEntry( config_entry = config_entries.ConfigEntry(
1, 1,
@ -496,7 +495,7 @@ async def test_restoring_client(hass, mock_controller):
config_entry=config_entry, config_entry=config_entry,
) )
await setup_controller(hass, mock_controller) await setup_controller(hass, mock_controller, options)
assert len(mock_controller.mock_requests) == 3 assert len(mock_controller.mock_requests) == 3
assert len(hass.states.async_all()) == 3 assert len(hass.states.async_all()) == 3