"""Config flow for UniFi.""" import socket import voluptuous as vol from homeassistant import config_entries from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME, CONF_VERIFY_SSL, ) from homeassistant.core import callback import homeassistant.helpers.config_validation as cv from .const import ( CONF_ALLOW_BANDWIDTH_SENSORS, CONF_ALLOW_UPTIME_SENSORS, CONF_BLOCK_CLIENT, CONF_CONTROLLER, CONF_DETECTION_TIME, CONF_DPI_RESTRICTIONS, CONF_IGNORE_WIRED_BUG, CONF_POE_CLIENTS, CONF_SITE_ID, CONF_SSID_FILTER, CONF_TRACK_CLIENTS, CONF_TRACK_DEVICES, CONF_TRACK_WIRED_CLIENTS, CONTROLLER_ID, DEFAULT_DPI_RESTRICTIONS, DEFAULT_POE_CLIENTS, DOMAIN as UNIFI_DOMAIN, LOGGER, ) from .controller import get_controller from .errors import AlreadyConfigured, AuthenticationRequired, CannotConnect DEFAULT_PORT = 8443 DEFAULT_SITE_ID = "default" DEFAULT_VERIFY_SSL = False @callback def get_controller_id_from_config_entry(config_entry): """Return controller with a matching bridge id.""" return CONTROLLER_ID.format( host=config_entry.data[CONF_CONTROLLER][CONF_HOST], site=config_entry.data[CONF_CONTROLLER][CONF_SITE_ID], ) class UnifiFlowHandler(config_entries.ConfigFlow, domain=UNIFI_DOMAIN): """Handle a UniFi config flow.""" VERSION = 1 CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL @staticmethod @callback def async_get_options_flow(config_entry): """Get the options flow for this handler.""" return UnifiOptionsFlowHandler(config_entry) def __init__(self): """Initialize the UniFi flow.""" self.config = None self.sites = None self.reauth_config_entry = {} self.reauth_config = {} self.reauth_schema = {} async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" errors = {} if user_input is not None: try: self.config = { CONF_HOST: user_input[CONF_HOST], CONF_USERNAME: user_input[CONF_USERNAME], CONF_PASSWORD: user_input[CONF_PASSWORD], CONF_PORT: user_input.get(CONF_PORT), CONF_VERIFY_SSL: user_input.get(CONF_VERIFY_SSL), CONF_SITE_ID: DEFAULT_SITE_ID, } controller = await get_controller(self.hass, **self.config) sites = await controller.sites() self.sites = {site["name"]: site["desc"] for site in sites.values()} if self.reauth_config.get(CONF_SITE_ID) in self.sites: return await self.async_step_site( {CONF_SITE_ID: self.reauth_config[CONF_SITE_ID]} ) return await self.async_step_site() except AuthenticationRequired: errors["base"] = "faulty_credentials" except CannotConnect: errors["base"] = "service_unavailable" except Exception: # pylint: disable=broad-except LOGGER.error( "Unknown error connecting with UniFi Controller at %s", user_input[CONF_HOST], ) return self.async_abort(reason="unknown") host = "" if await async_discover_unifi(self.hass): host = "unifi" data = self.reauth_schema or { vol.Required(CONF_HOST, default=host): str, vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str, vol.Optional(CONF_PORT, default=DEFAULT_PORT): int, vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): bool, } return self.async_show_form( step_id="user", data_schema=vol.Schema(data), errors=errors, ) async def async_step_site(self, user_input=None): """Select site to control.""" errors = {} if user_input is not None: try: self.config[CONF_SITE_ID] = user_input[CONF_SITE_ID] data = {CONF_CONTROLLER: self.config} if self.reauth_config_entry: self.hass.config_entries.async_update_entry( self.reauth_config_entry, data=data ) await self.hass.config_entries.async_reload( self.reauth_config_entry.entry_id ) return self.async_abort(reason="reauth_successful") for entry in self._async_current_entries(): controller = entry.data[CONF_CONTROLLER] if ( controller[CONF_HOST] == self.config[CONF_HOST] and controller[CONF_SITE_ID] == self.config[CONF_SITE_ID] ): raise AlreadyConfigured site_nice_name = self.sites[self.config[CONF_SITE_ID]] return self.async_create_entry(title=site_nice_name, data=data) except AlreadyConfigured: return self.async_abort(reason="already_configured") if len(self.sites) == 1: return await self.async_step_site({CONF_SITE_ID: next(iter(self.sites))}) return self.async_show_form( step_id="site", data_schema=vol.Schema({vol.Required(CONF_SITE_ID): vol.In(self.sites)}), errors=errors, ) async def async_step_reauth(self, config_entry: dict): """Trigger a reauthentication flow.""" self.reauth_config_entry = config_entry self.reauth_config = config_entry.data[CONF_CONTROLLER] # pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167 self.context["title_placeholders"] = { CONF_HOST: self.reauth_config[CONF_HOST], CONF_SITE_ID: config_entry.title, } self.reauth_schema = { vol.Required(CONF_HOST, default=self.reauth_config[CONF_HOST]): str, vol.Required(CONF_USERNAME, default=self.reauth_config[CONF_USERNAME]): str, vol.Required(CONF_PASSWORD): str, vol.Required(CONF_PORT, default=self.reauth_config[CONF_PORT]): int, vol.Required( CONF_VERIFY_SSL, default=self.reauth_config[CONF_VERIFY_SSL] ): bool, } return await self.async_step_user() class UnifiOptionsFlowHandler(config_entries.OptionsFlow): """Handle Unifi options.""" def __init__(self, config_entry): """Initialize UniFi options flow.""" self.config_entry = config_entry self.options = dict(config_entry.options) self.controller = None async def async_step_init(self, user_input=None): """Manage the UniFi options.""" self.controller = self.hass.data[UNIFI_DOMAIN][self.config_entry.entry_id] self.options[CONF_BLOCK_CLIENT] = self.controller.option_block_clients if self.show_advanced_options: return await self.async_step_device_tracker() return await self.async_step_simple_options() async def async_step_simple_options(self, user_input=None): """For simple Jack.""" if user_input is not None: self.options.update(user_input) return await self._update_options() clients_to_block = {} for client in self.controller.api.clients.values(): clients_to_block[ client.mac ] = f"{client.name or client.hostname} ({client.mac})" return self.async_show_form( step_id="simple_options", data_schema=vol.Schema( { vol.Optional( CONF_TRACK_CLIENTS, default=self.controller.option_track_clients, ): bool, vol.Optional( CONF_TRACK_DEVICES, default=self.controller.option_track_devices, ): bool, vol.Optional( CONF_BLOCK_CLIENT, default=self.options[CONF_BLOCK_CLIENT] ): cv.multi_select(clients_to_block), } ), ) async def async_step_device_tracker(self, user_input=None): """Manage the device tracker options.""" if user_input is not None: self.options.update(user_input) return await self.async_step_client_control() ssids = ( set(self.controller.api.wlans) | { f"{wlan.name}{wlan.name_combine_suffix}" for wlan in self.controller.api.wlans.values() if not wlan.name_combine_enabled } | { wlan["name"] for ap in self.controller.api.devices.values() for wlan in ap.wlan_overrides if "name" in wlan } ) ssid_filter = {ssid: ssid for ssid in sorted(list(ssids))} return self.async_show_form( step_id="device_tracker", data_schema=vol.Schema( { vol.Optional( CONF_TRACK_CLIENTS, default=self.controller.option_track_clients, ): bool, vol.Optional( CONF_TRACK_WIRED_CLIENTS, default=self.controller.option_track_wired_clients, ): bool, vol.Optional( CONF_TRACK_DEVICES, default=self.controller.option_track_devices, ): bool, vol.Optional( CONF_SSID_FILTER, default=self.controller.option_ssid_filter ): cv.multi_select(ssid_filter), vol.Optional( CONF_DETECTION_TIME, default=int( self.controller.option_detection_time.total_seconds() ), ): int, vol.Optional( CONF_IGNORE_WIRED_BUG, default=self.controller.option_ignore_wired_bug, ): bool, } ), ) async def async_step_client_control(self, user_input=None): """Manage configuration of network access controlled clients.""" errors = {} if user_input is not None: self.options.update(user_input) return await self.async_step_statistics_sensors() clients_to_block = {} for client in self.controller.api.clients.values(): clients_to_block[ client.mac ] = f"{client.name or client.hostname} ({client.mac})" return self.async_show_form( step_id="client_control", data_schema=vol.Schema( { vol.Optional( CONF_BLOCK_CLIENT, default=self.options[CONF_BLOCK_CLIENT] ): cv.multi_select(clients_to_block), vol.Optional( CONF_POE_CLIENTS, default=self.options.get(CONF_POE_CLIENTS, DEFAULT_POE_CLIENTS), ): bool, vol.Optional( CONF_DPI_RESTRICTIONS, default=self.options.get( CONF_DPI_RESTRICTIONS, DEFAULT_DPI_RESTRICTIONS ), ): bool, } ), errors=errors, ) async def async_step_statistics_sensors(self, user_input=None): """Manage the statistics sensors options.""" if user_input is not None: self.options.update(user_input) return await self._update_options() return self.async_show_form( step_id="statistics_sensors", data_schema=vol.Schema( { vol.Optional( CONF_ALLOW_BANDWIDTH_SENSORS, default=self.controller.option_allow_bandwidth_sensors, ): bool, vol.Optional( CONF_ALLOW_UPTIME_SENSORS, default=self.controller.option_allow_uptime_sensors, ): bool, } ), ) async def _update_options(self): """Update config entry options.""" return self.async_create_entry(title="", data=self.options) async def async_discover_unifi(hass): """Discover UniFi address.""" try: return await hass.async_add_executor_job(socket.gethostbyname, "unifi") except socket.gaierror: return None