Enable overriding connection port for tplink devices (#129619)
Enable setting a port override during manual config entry setup. The feature will be undocumented as it's quite a specialized use case generally used for testing purposes.
This commit is contained in:
parent
f49547d598
commit
03c3d09583
4 changed files with 163 additions and 16 deletions
|
@ -31,6 +31,7 @@ from homeassistant.const import (
|
|||
CONF_MAC,
|
||||
CONF_MODEL,
|
||||
CONF_PASSWORD,
|
||||
CONF_PORT,
|
||||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
@ -141,6 +142,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
|
|||
entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH)
|
||||
entry_use_http = entry.data.get(CONF_USES_HTTP, False)
|
||||
entry_aes_keys = entry.data.get(CONF_AES_KEYS)
|
||||
port_override = entry.data.get(CONF_PORT)
|
||||
|
||||
conn_params: Device.ConnectionParameters | None = None
|
||||
if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS):
|
||||
|
@ -157,6 +159,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
|
|||
timeout=CONNECT_TIMEOUT,
|
||||
http_client=client,
|
||||
aes_keys=entry_aes_keys,
|
||||
port_override=port_override,
|
||||
)
|
||||
if conn_params:
|
||||
config.connection_type = conn_params
|
||||
|
|
|
@ -32,6 +32,7 @@ from homeassistant.const import (
|
|||
CONF_MAC,
|
||||
CONF_MODEL,
|
||||
CONF_PASSWORD,
|
||||
CONF_PORT,
|
||||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
|
@ -69,6 +70,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
|
||||
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the config flow."""
|
||||
|
@ -260,6 +262,26 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
step_id="discovery_confirm", description_placeholders=placeholders
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _async_get_host_port(host_str: str) -> tuple[str, int | None]:
|
||||
"""Parse the host string for host and port."""
|
||||
if "[" in host_str:
|
||||
_, _, bracketed = host_str.partition("[")
|
||||
host, _, port_str = bracketed.partition("]")
|
||||
_, _, port_str = port_str.partition(":")
|
||||
else:
|
||||
host, _, port_str = host_str.partition(":")
|
||||
|
||||
if not port_str:
|
||||
return host, None
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
return host, None
|
||||
|
||||
return host, port
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
|
@ -270,14 +292,29 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
if user_input is not None:
|
||||
if not (host := user_input[CONF_HOST]):
|
||||
return await self.async_step_pick_device()
|
||||
self._async_abort_entries_match({CONF_HOST: host})
|
||||
|
||||
host, port = self._async_get_host_port(host)
|
||||
|
||||
match_dict = {CONF_HOST: host}
|
||||
if port:
|
||||
self.port = port
|
||||
match_dict[CONF_PORT] = port
|
||||
self._async_abort_entries_match(match_dict)
|
||||
|
||||
self.host = host
|
||||
credentials = await get_credentials(self.hass)
|
||||
try:
|
||||
device = await self._async_try_discover_and_update(
|
||||
host, credentials, raise_on_progress=False, raise_on_timeout=False
|
||||
host,
|
||||
credentials,
|
||||
raise_on_progress=False,
|
||||
raise_on_timeout=False,
|
||||
port=port,
|
||||
) or await self._async_try_connect_all(
|
||||
host, credentials=credentials, raise_on_progress=False
|
||||
host,
|
||||
credentials=credentials,
|
||||
raise_on_progress=False,
|
||||
port=port,
|
||||
)
|
||||
except AuthenticationError:
|
||||
return await self.async_step_user_auth_confirm()
|
||||
|
@ -318,7 +355,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
)
|
||||
else:
|
||||
device = await self._async_try_connect_all(
|
||||
self.host, credentials=credentials, raise_on_progress=False
|
||||
self.host,
|
||||
credentials=credentials,
|
||||
raise_on_progress=False,
|
||||
port=self.port,
|
||||
)
|
||||
except AuthenticationError as ex:
|
||||
errors[CONF_PASSWORD] = "invalid_auth"
|
||||
|
@ -420,6 +460,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
data[CONF_AES_KEYS] = device.config.aes_keys
|
||||
if device.credentials_hash:
|
||||
data[CONF_CREDENTIALS_HASH] = device.credentials_hash
|
||||
if port := device.config.port_override:
|
||||
data[CONF_PORT] = port
|
||||
return self.async_create_entry(
|
||||
title=f"{device.alias} {device.model}",
|
||||
data=data,
|
||||
|
@ -430,6 +472,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
host: str,
|
||||
credentials: Credentials | None,
|
||||
raise_on_progress: bool,
|
||||
*,
|
||||
port: int | None = None,
|
||||
) -> Device | None:
|
||||
"""Try to connect to the device speculatively.
|
||||
|
||||
|
@ -441,12 +485,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
host,
|
||||
credentials=credentials,
|
||||
http_client=create_async_tplink_clientsession(self.hass),
|
||||
port=port,
|
||||
)
|
||||
else:
|
||||
# This will just try the legacy protocol that doesn't require auth
|
||||
# and doesn't use http
|
||||
try:
|
||||
device = await Device.connect(config=DeviceConfig(host))
|
||||
device = await Device.connect(
|
||||
config=DeviceConfig(host, port_override=port)
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
if device:
|
||||
|
@ -462,6 +509,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
credentials: Credentials | None,
|
||||
raise_on_progress: bool,
|
||||
raise_on_timeout: bool,
|
||||
*,
|
||||
port: int | None = None,
|
||||
) -> Device | None:
|
||||
"""Try to discover the device and call update.
|
||||
|
||||
|
@ -470,7 +519,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
self._discovered_device = None
|
||||
try:
|
||||
self._discovered_device = await Discover.discover_single(
|
||||
host, credentials=credentials
|
||||
host,
|
||||
credentials=credentials,
|
||||
port=port,
|
||||
)
|
||||
except TimeoutError as ex:
|
||||
if raise_on_timeout:
|
||||
|
@ -526,6 +577,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
reauth_entry = self._get_reauth_entry()
|
||||
entry_data = reauth_entry.data
|
||||
host = entry_data[CONF_HOST]
|
||||
port = entry_data.get(CONF_PORT)
|
||||
|
||||
if user_input:
|
||||
username = user_input[CONF_USERNAME]
|
||||
|
@ -537,8 +589,12 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
credentials=credentials,
|
||||
raise_on_progress=False,
|
||||
raise_on_timeout=False,
|
||||
port=port,
|
||||
) or await self._async_try_connect_all(
|
||||
host, credentials=credentials, raise_on_progress=False
|
||||
host,
|
||||
credentials=credentials,
|
||||
raise_on_progress=False,
|
||||
port=port,
|
||||
)
|
||||
except AuthenticationError as ex:
|
||||
errors[CONF_PASSWORD] = "invalid_auth"
|
||||
|
|
|
@ -37,7 +37,7 @@ def mock_discovery():
|
|||
device = _mocked_device(
|
||||
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
|
||||
credentials_hash=CREDENTIALS_HASH_KLAP,
|
||||
alias=None,
|
||||
alias="My Bulb",
|
||||
)
|
||||
devices = {
|
||||
"127.0.0.1": _mocked_device(
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import ANY, AsyncMock, patch
|
||||
|
||||
from kasa import TimeoutError
|
||||
import pytest
|
||||
|
@ -30,6 +30,7 @@ from homeassistant.const import (
|
|||
CONF_HOST,
|
||||
CONF_MAC,
|
||||
CONF_PASSWORD,
|
||||
CONF_PORT,
|
||||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -665,6 +666,93 @@ async def test_manual_auth_errors(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("host_str", "host", "port"),
|
||||
[
|
||||
(f"{IP_ADDRESS}:1234", IP_ADDRESS, 1234),
|
||||
("[2001:db8:0::1]:4321", "2001:db8:0::1", 4321),
|
||||
],
|
||||
)
|
||||
async def test_manual_port_override(
|
||||
hass: HomeAssistant,
|
||||
mock_connect: AsyncMock,
|
||||
mock_discovery: AsyncMock,
|
||||
host_str,
|
||||
host,
|
||||
port,
|
||||
) -> None:
|
||||
"""Test manually setup."""
|
||||
mock_discovery["mock_device"].config.port_override = port
|
||||
mock_discovery["mock_device"].host = host
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "user"
|
||||
assert not result["errors"]
|
||||
|
||||
# side_effects to cause auth confirm as the port override usually only
|
||||
# works with direct connections.
|
||||
mock_discovery["discover_single"].side_effect = TimeoutError
|
||||
mock_connect["connect"].side_effect = AuthenticationError
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_HOST: host_str}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
assert result2["step_id"] == "user_auth_confirm"
|
||||
assert not result2["errors"]
|
||||
|
||||
creds = Credentials("fake_username", "fake_password")
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
user_input={
|
||||
CONF_USERNAME: "fake_username",
|
||||
CONF_PASSWORD: "fake_password",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
mock_discovery["try_connect_all"].assert_called_once_with(
|
||||
host, credentials=creds, port=port, http_client=ANY
|
||||
)
|
||||
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result3["title"] == DEFAULT_ENTRY_TITLE
|
||||
assert result3["data"] == {
|
||||
**CREATE_ENTRY_DATA_KLAP,
|
||||
CONF_PORT: port,
|
||||
CONF_HOST: host,
|
||||
}
|
||||
assert result3["context"]["unique_id"] == MAC_ADDRESS
|
||||
|
||||
|
||||
async def test_manual_port_override_invalid(
|
||||
hass: HomeAssistant, mock_connect: AsyncMock, mock_discovery: AsyncMock
|
||||
) -> None:
|
||||
"""Test manually setup."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "user"
|
||||
assert not result["errors"]
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_HOST: f"{IP_ADDRESS}:foo"}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=None, port=None
|
||||
)
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == DEFAULT_ENTRY_TITLE
|
||||
assert result2["data"] == CREATE_ENTRY_DATA_KLAP
|
||||
assert result2["context"]["unique_id"] == MAC_ADDRESS
|
||||
|
||||
|
||||
async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
|
||||
"""Test we get the form with discovery and abort for dhcp source when we get both."""
|
||||
|
||||
|
@ -1072,7 +1160,7 @@ async def test_reauth(
|
|||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["mock_device"].update.assert_called_once_with()
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
|
@ -1107,7 +1195,7 @@ async def test_reauth_try_connect_all(
|
|||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["try_connect_all"].assert_called_once()
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
|
@ -1145,7 +1233,7 @@ async def test_reauth_try_connect_all_fail(
|
|||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["try_connect_all"].assert_called_once()
|
||||
assert result2["errors"] == {"base": "cannot_connect"}
|
||||
|
@ -1214,7 +1302,7 @@ async def test_reauth_update_with_encryption_change(
|
|||
assert "Connection type changed for 127.0.0.2" in caplog.text
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.2", credentials=credentials
|
||||
"127.0.0.2", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["mock_device"].update.assert_called_once_with()
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
|
@ -1416,7 +1504,7 @@ async def test_reauth_errors(
|
|||
credentials = Credentials("fake_username", "fake_password")
|
||||
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["mock_device"].update.assert_called_once_with()
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
|
@ -1434,7 +1522,7 @@ async def test_reauth_errors(
|
|||
)
|
||||
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["mock_device"].update.assert_called_once_with()
|
||||
|
||||
|
@ -1643,7 +1731,7 @@ async def test_reauth_update_other_flows(
|
|||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
"127.0.0.1", credentials=credentials, port=None
|
||||
)
|
||||
mock_discovery["mock_device"].update.assert_called_once_with()
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
|
|
Loading…
Add table
Reference in a new issue