Enable overriding connection port for tplink devices (#129619)

Enable setting a port override during manual config entry setup.

The feature will be undocumented as it's quite a specialized use case generally used for testing purposes.
This commit is contained in:
Steven B. 2024-11-08 13:41:00 +00:00 committed by GitHub
parent f49547d598
commit 03c3d09583
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 163 additions and 16 deletions

View file

@ -31,6 +31,7 @@ from homeassistant.const import (
CONF_MAC,
CONF_MODEL,
CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant, callback
@ -141,6 +142,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH)
entry_use_http = entry.data.get(CONF_USES_HTTP, False)
entry_aes_keys = entry.data.get(CONF_AES_KEYS)
port_override = entry.data.get(CONF_PORT)
conn_params: Device.ConnectionParameters | None = None
if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS):
@ -157,6 +159,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
timeout=CONNECT_TIMEOUT,
http_client=client,
aes_keys=entry_aes_keys,
port_override=port_override,
)
if conn_params:
config.connection_type = conn_params

View file

@ -32,6 +32,7 @@ from homeassistant.const import (
CONF_MAC,
CONF_MODEL,
CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME,
)
from homeassistant.core import callback
@ -69,6 +70,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
host: str | None = None
port: int | None = None
def __init__(self) -> None:
"""Initialize the config flow."""
@ -260,6 +262,26 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
step_id="discovery_confirm", description_placeholders=placeholders
)
@staticmethod
def _async_get_host_port(host_str: str) -> tuple[str, int | None]:
"""Parse the host string for host and port."""
if "[" in host_str:
_, _, bracketed = host_str.partition("[")
host, _, port_str = bracketed.partition("]")
_, _, port_str = port_str.partition(":")
else:
host, _, port_str = host_str.partition(":")
if not port_str:
return host, None
try:
port = int(port_str)
except ValueError:
return host, None
return host, port
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
@ -270,14 +292,29 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None:
if not (host := user_input[CONF_HOST]):
return await self.async_step_pick_device()
self._async_abort_entries_match({CONF_HOST: host})
host, port = self._async_get_host_port(host)
match_dict = {CONF_HOST: host}
if port:
self.port = port
match_dict[CONF_PORT] = port
self._async_abort_entries_match(match_dict)
self.host = host
credentials = await get_credentials(self.hass)
try:
device = await self._async_try_discover_and_update(
host, credentials, raise_on_progress=False, raise_on_timeout=False
host,
credentials,
raise_on_progress=False,
raise_on_timeout=False,
port=port,
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
host,
credentials=credentials,
raise_on_progress=False,
port=port,
)
except AuthenticationError:
return await self.async_step_user_auth_confirm()
@ -318,7 +355,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
)
else:
device = await self._async_try_connect_all(
self.host, credentials=credentials, raise_on_progress=False
self.host,
credentials=credentials,
raise_on_progress=False,
port=self.port,
)
except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth"
@ -420,6 +460,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
data[CONF_AES_KEYS] = device.config.aes_keys
if device.credentials_hash:
data[CONF_CREDENTIALS_HASH] = device.credentials_hash
if port := device.config.port_override:
data[CONF_PORT] = port
return self.async_create_entry(
title=f"{device.alias} {device.model}",
data=data,
@ -430,6 +472,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
host: str,
credentials: Credentials | None,
raise_on_progress: bool,
*,
port: int | None = None,
) -> Device | None:
"""Try to connect to the device speculatively.
@ -441,12 +485,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
host,
credentials=credentials,
http_client=create_async_tplink_clientsession(self.hass),
port=port,
)
else:
# This will just try the legacy protocol that doesn't require auth
# and doesn't use http
try:
device = await Device.connect(config=DeviceConfig(host))
device = await Device.connect(
config=DeviceConfig(host, port_override=port)
)
except Exception: # noqa: BLE001
return None
if device:
@ -462,6 +509,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials: Credentials | None,
raise_on_progress: bool,
raise_on_timeout: bool,
*,
port: int | None = None,
) -> Device | None:
"""Try to discover the device and call update.
@ -470,7 +519,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
self._discovered_device = None
try:
self._discovered_device = await Discover.discover_single(
host, credentials=credentials
host,
credentials=credentials,
port=port,
)
except TimeoutError as ex:
if raise_on_timeout:
@ -526,6 +577,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
reauth_entry = self._get_reauth_entry()
entry_data = reauth_entry.data
host = entry_data[CONF_HOST]
port = entry_data.get(CONF_PORT)
if user_input:
username = user_input[CONF_USERNAME]
@ -537,8 +589,12 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials=credentials,
raise_on_progress=False,
raise_on_timeout=False,
port=port,
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
host,
credentials=credentials,
raise_on_progress=False,
port=port,
)
except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth"

View file

@ -37,7 +37,7 @@ def mock_discovery():
device = _mocked_device(
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
credentials_hash=CREDENTIALS_HASH_KLAP,
alias=None,
alias="My Bulb",
)
devices = {
"127.0.0.1": _mocked_device(

View file

@ -2,7 +2,7 @@
from contextlib import contextmanager
import logging
from unittest.mock import AsyncMock, patch
from unittest.mock import ANY, AsyncMock, patch
from kasa import TimeoutError
import pytest
@ -30,6 +30,7 @@ from homeassistant.const import (
CONF_HOST,
CONF_MAC,
CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant
@ -665,6 +666,93 @@ async def test_manual_auth_errors(
await hass.async_block_till_done()
@pytest.mark.parametrize(
("host_str", "host", "port"),
[
(f"{IP_ADDRESS}:1234", IP_ADDRESS, 1234),
("[2001:db8:0::1]:4321", "2001:db8:0::1", 4321),
],
)
async def test_manual_port_override(
hass: HomeAssistant,
mock_connect: AsyncMock,
mock_discovery: AsyncMock,
host_str,
host,
port,
) -> None:
"""Test manually setup."""
mock_discovery["mock_device"].config.port_override = port
mock_discovery["mock_device"].host = host
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
# side_effects to cause auth confirm as the port override usually only
# works with direct connections.
mock_discovery["discover_single"].side_effect = TimeoutError
mock_connect["connect"].side_effect = AuthenticationError
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: host_str}
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM
assert result2["step_id"] == "user_auth_confirm"
assert not result2["errors"]
creds = Credentials("fake_username", "fake_password")
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
mock_discovery["try_connect_all"].assert_called_once_with(
host, credentials=creds, port=port, http_client=ANY
)
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["title"] == DEFAULT_ENTRY_TITLE
assert result3["data"] == {
**CREATE_ENTRY_DATA_KLAP,
CONF_PORT: port,
CONF_HOST: host,
}
assert result3["context"]["unique_id"] == MAC_ADDRESS
async def test_manual_port_override_invalid(
hass: HomeAssistant, mock_connect: AsyncMock, mock_discovery: AsyncMock
) -> None:
"""Test manually setup."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: f"{IP_ADDRESS}:foo"}
)
await hass.async_block_till_done()
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=None, port=None
)
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == DEFAULT_ENTRY_TITLE
assert result2["data"] == CREATE_ENTRY_DATA_KLAP
assert result2["context"]["unique_id"] == MAC_ADDRESS
async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
"""Test we get the form with discovery and abort for dhcp source when we get both."""
@ -1072,7 +1160,7 @@ async def test_reauth(
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT
@ -1107,7 +1195,7 @@ async def test_reauth_try_connect_all(
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["type"] is FlowResultType.ABORT
@ -1145,7 +1233,7 @@ async def test_reauth_try_connect_all_fail(
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["errors"] == {"base": "cannot_connect"}
@ -1214,7 +1302,7 @@ async def test_reauth_update_with_encryption_change(
assert "Connection type changed for 127.0.0.2" in caplog.text
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.2", credentials=credentials
"127.0.0.2", credentials=credentials, port=None
)
mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT
@ -1416,7 +1504,7 @@ async def test_reauth_errors(
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.FORM
@ -1434,7 +1522,7 @@ async def test_reauth_errors(
)
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["mock_device"].update.assert_called_once_with()
@ -1643,7 +1731,7 @@ async def test_reauth_update_other_flows(
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
"127.0.0.1", credentials=credentials, port=None
)
mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT