From 876776e2915c8c1d18b6bdfd82a0f25dc3d6b2ea Mon Sep 17 00:00:00 2001 From: MarkGodwin Date: Mon, 6 Mar 2023 04:47:45 +0000 Subject: [PATCH] Fix host IP and scheme entry issues in TP-Link Omada (#89130) Fixing host IP and scheme entry issues --- .../components/tplink_omada/config_flow.py | 27 +++++++- .../tplink_omada/test_config_flow.py | 68 +++++++++++++++++-- 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/tplink_omada/config_flow.py b/homeassistant/components/tplink_omada/config_flow.py index 6b958b7d258..f6a75abe6d8 100644 --- a/homeassistant/components/tplink_omada/config_flow.py +++ b/homeassistant/components/tplink_omada/config_flow.py @@ -3,9 +3,12 @@ from __future__ import annotations from collections.abc import Mapping import logging +import re from types import MappingProxyType from typing import Any, NamedTuple +from urllib.parse import urlsplit +from aiohttp import CookieJar from tplink_omada_client.exceptions import ( ConnectionFailed, LoginFailed, @@ -20,7 +23,10 @@ from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, CONF_VE from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import selector -from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.aiohttp_client import ( + async_create_clientsession, + async_get_clientsession, +) from .const import DOMAIN @@ -42,11 +48,26 @@ async def create_omada_client( hass: HomeAssistant, data: MappingProxyType[str, Any] ) -> OmadaClient: """Create a TP-Link Omada client API for the given config entry.""" - host = data[CONF_HOST] + + host: str = data[CONF_HOST] verify_ssl = bool(data[CONF_VERIFY_SSL]) + + if not host.lower().startswith(("http://", "https://")): + host = "https://" + host + host_parts = urlsplit(host) + if ( + host_parts.hostname + and re.fullmatch(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", host_parts.hostname) + is not None + ): + # TP-Link API uses cookies for login session, so an unsafe cookie jar is required for IP addresses + websession = async_create_clientsession(hass, cookie_jar=CookieJar(unsafe=True)) + else: + websession = async_get_clientsession(hass, verify_ssl=verify_ssl) + username = data[CONF_USERNAME] password = data[CONF_PASSWORD] - websession = async_get_clientsession(hass, verify_ssl=verify_ssl) + return OmadaClient(host, username, password, websession=websession) diff --git a/tests/components/tplink_omada/test_config_flow.py b/tests/components/tplink_omada/test_config_flow.py index fd32b357b7c..cf3fddf5943 100644 --- a/tests/components/tplink_omada/test_config_flow.py +++ b/tests/components/tplink_omada/test_config_flow.py @@ -22,14 +22,14 @@ from homeassistant.data_entry_flow import FlowResultType from tests.common import MockConfigEntry MOCK_USER_DATA = { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "username": "test-username", "password": "test-password", } MOCK_ENTRY_DATA = { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "SiteId", "username": "test-username", @@ -111,7 +111,7 @@ async def test_form_multiple_sites(hass: HomeAssistant) -> None: assert result3["type"] == FlowResultType.CREATE_ENTRY assert result3["title"] == "OC200 (Site 2)" assert result3["data"] == { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "second", "username": "test-username", @@ -272,7 +272,7 @@ async def test_async_step_reauth_success(hass: HomeAssistant) -> None: mocked_validate.assert_called_once_with( hass, { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "SiteId", "username": "new_uname", @@ -353,6 +353,64 @@ async def test_create_omada_client_parses_args(hass: HomeAssistant) -> None: assert result is not None mock_client.assert_called_once_with( - "1.1.1.1", "test-username", "test-password", "ws" + "https://fake.omada.host", "test-username", "test-password", "ws" ) mock_clientsession.assert_called_once_with(hass, verify_ssl=True) + + +async def test_create_omada_client_adds_missing_scheme(hass: HomeAssistant) -> None: + """Test config arguments are passed to Omada client.""" + + with patch( + "homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True + ) as mock_client, patch( + "homeassistant.components.tplink_omada.config_flow.async_get_clientsession", + return_value="ws", + ) as mock_clientsession: + result = await create_omada_client( + hass, + { + "host": "fake.omada.host", + "verify_ssl": True, + "username": "test-username", + "password": "test-password", + }, + ) + + assert result is not None + mock_client.assert_called_once_with( + "https://fake.omada.host", "test-username", "test-password", "ws" + ) + mock_clientsession.assert_called_once_with(hass, verify_ssl=True) + + +async def test_create_omada_client_with_ip_creates_clientsession( + hass: HomeAssistant, +) -> None: + """Test config arguments are passed to Omada client.""" + + with patch( + "homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True + ) as mock_client, patch( + "homeassistant.components.tplink_omada.config_flow.CookieJar", autospec=True + ) as mock_jar, patch( + "homeassistant.components.tplink_omada.config_flow.async_create_clientsession", + return_value="ws", + ) as mock_create_clientsession: + result = await create_omada_client( + hass, + { + "host": "10.10.10.10", + "verify_ssl": True, # Verify is meaningless for IP + "username": "test-username", + "password": "test-password", + }, + ) + + assert result is not None + mock_client.assert_called_once_with( + "https://10.10.10.10", "test-username", "test-password", "ws" + ) + mock_create_clientsession.assert_called_once_with( + hass, cookie_jar=mock_jar.return_value + )