Use ConfigFlow.has_matching_flow to deduplicate tplink flows (#127164)

This commit is contained in:
Erik Montnemery 2024-10-02 19:59:24 +02:00 committed by GitHub
parent fed953023d
commit 1dc1fd421b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 32 additions and 12 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
from typing import TYPE_CHECKING, Any, Self
from kasa import (
AuthenticationError,
@ -67,6 +67,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
host: str | None = None
reauth_entry: ConfigEntry | None = None
def __init__(self) -> None:
@ -156,10 +158,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return result
self._abort_if_unique_id_configured(updates={CONF_HOST: host})
self._async_abort_entries_match({CONF_HOST: host})
self.context[CONF_HOST] = host
for progress in self._async_in_progress():
if progress.get("context", {}).get(CONF_HOST) == host:
return self.async_abort(reason="already_in_progress")
self.host = host
if self.hass.config_entries.flow.async_has_matching_flow(self):
return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass)
try:
if device:
@ -176,6 +177,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return await self.async_step_discovery_confirm()
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return other_flow.host == self.host
async def async_step_discovery_auth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
@ -263,7 +268,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
if not (host := user_input[CONF_HOST]):
return await self.async_step_pick_device()
self._async_abort_entries_match({CONF_HOST: host})
self.context[CONF_HOST] = host
self.host = host
credentials = await get_credentials(self.hass)
try:
device = await self._async_try_discover_and_update(
@ -289,8 +294,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Dialog that informs the user that auth is required."""
errors: dict[str, str] = {}
host = self.context[CONF_HOST]
placeholders: dict[str, str] = {CONF_HOST: host}
if TYPE_CHECKING:
# self.host is set by async_step_user and async_step_pick_device
assert self.host is not None
placeholders: dict[str, str] = {CONF_HOST: self.host}
assert self._discovered_device is not None
if user_input:
@ -329,9 +336,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
mac = user_input[CONF_DEVICE]
await self.async_set_unique_id(mac, raise_on_progress=False)
self._discovered_device = self._discovered_devices[mac]
host = self._discovered_device.host
self.context[CONF_HOST] = host
self.host = self._discovered_device.host
credentials = await get_credentials(self.hass)
try:

View file

@ -17,6 +17,7 @@ from homeassistant.components.tplink import (
DeviceConfig,
KasaException,
)
from homeassistant.components.tplink.config_flow import TPLinkConfigFlow
from homeassistant.components.tplink.const import (
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH,
@ -682,7 +683,19 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM
assert result["errors"] is None
with _patch_discovery(), _patch_single_discovery(), _patch_connect():
real_is_matching = TPLinkConfigFlow.is_matching
return_values = []
def is_matching(self, other_flow) -> bool:
return_values.append(real_is_matching(self, other_flow))
return return_values[-1]
with (
_patch_discovery(),
_patch_single_discovery(),
_patch_connect(),
patch.object(TPLinkConfigFlow, "is_matching", wraps=is_matching, autospec=True),
):
result2 = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DHCP},
@ -693,6 +706,8 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "already_in_progress"
# Ensure the is_matching method returned True
assert return_values == [True]
with _patch_discovery(), _patch_single_discovery(), _patch_connect():
result3 = await hass.config_entries.flow.async_init(