From 1dc1fd421b47d6047154bcad8e0acabda23b1cf5 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 2 Oct 2024 19:59:24 +0200 Subject: [PATCH] Use ConfigFlow.has_matching_flow to deduplicate tplink flows (#127164) --- .../components/tplink/config_flow.py | 27 +++++++++++-------- tests/components/tplink/test_config_flow.py | 17 +++++++++++- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index 03234d545b5..ae7543218c7 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Mapping import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Self from kasa import ( AuthenticationError, @@ -67,6 +67,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION + + host: str | None = None reauth_entry: ConfigEntry | None = None def __init__(self) -> None: @@ -156,10 +158,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return result self._abort_if_unique_id_configured(updates={CONF_HOST: host}) self._async_abort_entries_match({CONF_HOST: host}) - self.context[CONF_HOST] = host - for progress in self._async_in_progress(): - if progress.get("context", {}).get(CONF_HOST) == host: - return self.async_abort(reason="already_in_progress") + self.host = host + if self.hass.config_entries.flow.async_has_matching_flow(self): + return self.async_abort(reason="already_in_progress") credentials = await get_credentials(self.hass) try: if device: @@ -176,6 +177,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return await self.async_step_discovery_confirm() + def is_matching(self, other_flow: Self) -> bool: + """Return True if other_flow is matching this flow.""" + return other_flow.host == self.host + async def async_step_discovery_auth_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: @@ -263,7 +268,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): if not (host := user_input[CONF_HOST]): return await self.async_step_pick_device() self._async_abort_entries_match({CONF_HOST: host}) - self.context[CONF_HOST] = host + self.host = host credentials = await get_credentials(self.hass) try: device = await self._async_try_discover_and_update( @@ -289,8 +294,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Dialog that informs the user that auth is required.""" errors: dict[str, str] = {} - host = self.context[CONF_HOST] - placeholders: dict[str, str] = {CONF_HOST: host} + if TYPE_CHECKING: + # self.host is set by async_step_user and async_step_pick_device + assert self.host is not None + placeholders: dict[str, str] = {CONF_HOST: self.host} assert self._discovered_device is not None if user_input: @@ -329,9 +336,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): mac = user_input[CONF_DEVICE] await self.async_set_unique_id(mac, raise_on_progress=False) self._discovered_device = self._discovered_devices[mac] - host = self._discovered_device.host - - self.context[CONF_HOST] = host + self.host = self._discovered_device.host credentials = await get_credentials(self.hass) try: diff --git a/tests/components/tplink/test_config_flow.py b/tests/components/tplink/test_config_flow.py index 7b24769c858..40bd4383513 100644 --- a/tests/components/tplink/test_config_flow.py +++ b/tests/components/tplink/test_config_flow.py @@ -17,6 +17,7 @@ from homeassistant.components.tplink import ( DeviceConfig, KasaException, ) +from homeassistant.components.tplink.config_flow import TPLinkConfigFlow from homeassistant.components.tplink.const import ( CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, @@ -682,7 +683,19 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["errors"] is None - with _patch_discovery(), _patch_single_discovery(), _patch_connect(): + real_is_matching = TPLinkConfigFlow.is_matching + return_values = [] + + def is_matching(self, other_flow) -> bool: + return_values.append(real_is_matching(self, other_flow)) + return return_values[-1] + + with ( + _patch_discovery(), + _patch_single_discovery(), + _patch_connect(), + patch.object(TPLinkConfigFlow, "is_matching", wraps=is_matching, autospec=True), + ): result2 = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_DHCP}, @@ -693,6 +706,8 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert result2["type"] is FlowResultType.ABORT assert result2["reason"] == "already_in_progress" + # Ensure the is_matching method returned True + assert return_values == [True] with _patch_discovery(), _patch_single_discovery(), _patch_connect(): result3 = await hass.config_entries.flow.async_init(