Use unifi direct connect w/ssl verify for unifiprotect when possible (#64395)

This commit is contained in:
J. Nick Koston 2022-01-18 14:40:55 -10:00 committed by GitHub
parent 04a2227f4b
commit 3c7005d4dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 190 additions and 33 deletions

View file

@ -40,6 +40,11 @@ from .utils import _async_short_mac, _async_unifi_mac_from_hass
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def _host_is_direct_connect(host: str) -> bool:
"""Check if a host is a unifi direct connect domain."""
return host.endswith(".ui.direct")
class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a UniFi Protect config flow.""" """Handle a UniFi Protect config flow."""
@ -74,10 +79,33 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
) -> FlowResult: ) -> FlowResult:
"""Handle discovery.""" """Handle discovery."""
self._discovered_device = discovery_info self._discovered_device = discovery_info
mac = _async_unifi_mac_from_hass(discovery_info["mac"]) mac = _async_unifi_mac_from_hass(discovery_info["hw_addr"])
await self.async_set_unique_id(mac) await self.async_set_unique_id(mac)
for entry in self._async_current_entries(include_ignore=False):
if entry.unique_id != mac:
continue
new_host = None
if (
_host_is_direct_connect(entry.data[CONF_HOST])
and discovery_info["direct_connect_domain"]
and entry.data[CONF_HOST] != discovery_info["direct_connect_domain"]
):
new_host = discovery_info["direct_connect_domain"]
elif (
not _host_is_direct_connect(entry.data[CONF_HOST])
and entry.data[CONF_HOST] != discovery_info["source_ip"]
):
new_host = discovery_info["source_ip"]
if new_host:
self.hass.config_entries.async_update_entry(
entry, data={**entry.data, CONF_HOST: new_host}
)
self.hass.async_create_task(
self.hass.config_entries.async_reload(entry.entry_id)
)
return self.async_abort(reason="already_configured")
self._abort_if_unique_id_configured( self._abort_if_unique_id_configured(
updates={CONF_HOST: discovery_info["ip_address"]} updates={CONF_HOST: discovery_info["source_ip"]}
) )
return await self.async_step_discovery_confirm() return await self.async_step_discovery_confirm()
@ -88,8 +116,15 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {} errors: dict[str, str] = {}
discovery_info = self._discovered_device discovery_info = self._discovered_device
if user_input is not None: if user_input is not None:
user_input[CONF_HOST] = discovery_info["ip_address"]
user_input[CONF_PORT] = DEFAULT_PORT user_input[CONF_PORT] = DEFAULT_PORT
nvr_data = None
if discovery_info["direct_connect_domain"]:
user_input[CONF_HOST] = discovery_info["direct_connect_domain"]
user_input[CONF_VERIFY_SSL] = True
nvr_data, errors = await self._async_get_nvr_data(user_input)
if not nvr_data or errors:
user_input[CONF_HOST] = discovery_info["source_ip"]
user_input[CONF_VERIFY_SSL] = False
nvr_data, errors = await self._async_get_nvr_data(user_input) nvr_data, errors = await self._async_get_nvr_data(user_input)
if nvr_data and not errors: if nvr_data and not errors:
return self._async_create_entry(nvr_data.name, user_input) return self._async_create_entry(nvr_data.name, user_input)
@ -97,8 +132,8 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
placeholders = { placeholders = {
"name": discovery_info["hostname"] "name": discovery_info["hostname"]
or discovery_info["platform"] or discovery_info["platform"]
or f"NVR {_async_short_mac(discovery_info['mac'])}", or f"NVR {_async_short_mac(discovery_info['hw_addr'])}",
"ip_address": discovery_info["ip_address"], "ip_address": discovery_info["source_ip"],
} }
self.context["title_placeholders"] = placeholders self.context["title_placeholders"] = placeholders
user_input = user_input or {} user_input = user_input or {}
@ -107,10 +142,6 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders=placeholders, description_placeholders=placeholders,
data_schema=vol.Schema( data_schema=vol.Schema(
{ {
vol.Required(
CONF_VERIFY_SSL,
default=user_input.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL),
): bool,
vol.Required( vol.Required(
CONF_USERNAME, default=user_input.get(CONF_USERNAME) CONF_USERNAME, default=user_input.get(CONF_USERNAME)
): str, ): str,

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import asdict
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any from typing import Any
@ -56,11 +57,6 @@ def async_trigger_discovery(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": config_entries.SOURCE_DISCOVERY}, context={"source": config_entries.SOURCE_DISCOVERY},
data={ data=asdict(device),
"ip_address": device.source_ip,
"mac": device.hw_addr,
"hostname": device.hostname, # can be None
"platform": device.platform, # can be None
},
) )
) )

View file

@ -24,7 +24,6 @@
"discovery_confirm": { "discovery_confirm": {
"description": "Do you want to setup {name} ({ip_address})?", "description": "Do you want to setup {name} ({ip_address})?",
"data": { "data": {
"verify_ssl": "[%key:common::config_flow::data::verify_ssl%]",
"username": "[%key:common::config_flow::data::username%]", "username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]" "password": "[%key:common::config_flow::data::password%]"
} }

View file

@ -15,8 +15,7 @@
"discovery_confirm": { "discovery_confirm": {
"data": { "data": {
"password": "Password", "password": "Password",
"username": "Username", "username": "Username"
"verify_ssl": "Verify SSL certificate"
}, },
"description": "Do you want to setup {name} ({ip_address})?" "description": "Do you want to setup {name} ({ip_address})?"
}, },

View file

@ -16,6 +16,14 @@ UNIFI_DISCOVERY = UnifiDevice(
platform=DEVICE_HOSTNAME, platform=DEVICE_HOSTNAME,
hostname=DEVICE_HOSTNAME, hostname=DEVICE_HOSTNAME,
services={UnifiService.Protect: True}, services={UnifiService.Protect: True},
direct_connect_domain="x.ui.direct",
)
UNIFI_DISCOVERY_PARTIAL = UnifiDevice(
source_ip=DEVICE_IP_ADDRESS,
hw_addr=DEVICE_MAC_ADDRESS,
services={UnifiService.Protect: True},
) )

View file

@ -1,6 +1,7 @@
"""Test the UniFi Protect config flow.""" """Test the UniFi Protect config flow."""
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -15,6 +16,7 @@ from homeassistant.components.unifiprotect.const import (
CONF_OVERRIDE_CHOST, CONF_OVERRIDE_CHOST,
DOMAIN, DOMAIN,
) )
from homeassistant.const import CONF_HOST
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import ( from homeassistant.data_entry_flow import (
RESULT_TYPE_ABORT, RESULT_TYPE_ABORT,
@ -23,7 +25,14 @@ from homeassistant.data_entry_flow import (
) )
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from . import DEVICE_HOSTNAME, DEVICE_IP_ADDRESS, DEVICE_MAC_ADDRESS, _patch_discovery from . import (
DEVICE_HOSTNAME,
DEVICE_IP_ADDRESS,
DEVICE_MAC_ADDRESS,
UNIFI_DISCOVERY,
UNIFI_DISCOVERY_PARTIAL,
_patch_discovery,
)
from .conftest import MAC_ADDR from .conftest import MAC_ADDR
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -45,18 +54,9 @@ SSDP_DISCOVERY = (
}, },
), ),
) )
UNIFI_DISCOVERY_DICT = {
"ip_address": DEVICE_IP_ADDRESS, UNIFI_DISCOVERY_DICT = asdict(UNIFI_DISCOVERY)
"mac": DEVICE_MAC_ADDRESS, UNIFI_DISCOVERY_DICT_PARTIAL = asdict(UNIFI_DISCOVERY_PARTIAL)
"hostname": DEVICE_HOSTNAME,
"platform": DEVICE_HOSTNAME,
}
UNIFI_DISCOVERY_DICT_PARTIAL = {
"ip_address": DEVICE_IP_ADDRESS,
"mac": DEVICE_MAC_ADDRESS,
"hostname": None,
"platform": None,
}
async def test_form(hass: HomeAssistant, mock_nvr: NVR) -> None: async def test_form(hass: HomeAssistant, mock_nvr: NVR) -> None:
@ -292,7 +292,7 @@ async def test_discovered_by_ssdp_or_dhcp(
assert result["reason"] == "discovery_started" assert result["reason"] == "discovery_started"
async def test_discovered_by_unifi_discovery( async def test_discovered_by_unifi_discovery_direct_connect(
hass: HomeAssistant, mock_nvr: NVR hass: HomeAssistant, mock_nvr: NVR
) -> None: ) -> None:
"""Test a discovery from unifi-discovery.""" """Test a discovery from unifi-discovery."""
@ -331,6 +331,130 @@ async def test_discovered_by_unifi_discovery(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
assert result2["title"] == "UnifiProtect"
assert result2["data"] == {
"host": "x.ui.direct",
"username": "test-username",
"password": "test-password",
"id": "UnifiProtect",
"port": 443,
"verify_ssl": True,
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_discovered_by_unifi_discovery_direct_connect_updated(
hass: HomeAssistant, mock_nvr: NVR
) -> None:
"""Test a discovery from unifi-discovery updates the direct connect host."""
mock_config = MockConfigEntry(
domain=DOMAIN,
data={
"host": "y.ui.direct",
"username": "test-username",
"password": "test-password",
"id": "UnifiProtect",
"port": 443,
"verify_ssl": True,
},
version=2,
unique_id=DEVICE_MAC_ADDRESS.replace(":", "").upper(),
)
mock_config.add_to_hass(hass)
with _patch_discovery(), patch(
"homeassistant.components.unifiprotect.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DISCOVERY},
data=UNIFI_DISCOVERY_DICT,
)
await hass.async_block_till_done()
assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
assert len(mock_setup_entry.mock_calls) == 1
assert mock_config.data[CONF_HOST] == "x.ui.direct"
async def test_discovered_by_unifi_discovery_direct_connect_updated_but_not_using_direct_connect(
hass: HomeAssistant, mock_nvr: NVR
) -> None:
"""Test a discovery from unifi-discovery updates the host but not direct connect if its not in use."""
mock_config = MockConfigEntry(
domain=DOMAIN,
data={
"host": "1.2.2.2",
"username": "test-username",
"password": "test-password",
"id": "UnifiProtect",
"port": 443,
"verify_ssl": False,
},
version=2,
unique_id=DEVICE_MAC_ADDRESS.replace(":", "").upper(),
)
mock_config.add_to_hass(hass)
with _patch_discovery(), patch(
"homeassistant.components.unifiprotect.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DISCOVERY},
data=UNIFI_DISCOVERY_DICT,
)
await hass.async_block_till_done()
assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
assert len(mock_setup_entry.mock_calls) == 1
assert mock_config.data[CONF_HOST] == "127.0.0.1"
async def test_discovered_by_unifi_discovery(
hass: HomeAssistant, mock_nvr: NVR
) -> None:
"""Test a discovery from unifi-discovery."""
with _patch_discovery():
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DISCOVERY},
data=UNIFI_DISCOVERY_DICT,
)
await hass.async_block_till_done()
assert result["type"] == RESULT_TYPE_FORM
assert result["step_id"] == "discovery_confirm"
flows = hass.config_entries.flow.async_progress_by_handler(DOMAIN)
assert flows[0]["context"]["title_placeholders"] == {
"ip_address": DEVICE_IP_ADDRESS,
"name": DEVICE_HOSTNAME,
}
assert not result["errors"]
with patch(
"homeassistant.components.unifiprotect.config_flow.ProtectApiClient.get_nvr",
side_effect=[NotAuthorized, mock_nvr],
), patch(
"homeassistant.components.unifiprotect.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
"username": "test-username",
"password": "test-password",
},
)
await hass.async_block_till_done()
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
assert result2["title"] == "UnifiProtect" assert result2["title"] == "UnifiProtect"
assert result2["data"] == { assert result2["data"] == {