Use ConfigFlow.has_matching_flow to deduplicate tplink flows (#127164)

This commit is contained in:
Erik Montnemery 2024-10-02 19:59:24 +02:00 committed by GitHub
parent fed953023d
commit 1dc1fd421b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 32 additions and 12 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any, Self
from kasa import ( from kasa import (
AuthenticationError, AuthenticationError,
@ -67,6 +67,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
host: str | None = None
reauth_entry: ConfigEntry | None = None reauth_entry: ConfigEntry | None = None
def __init__(self) -> None: def __init__(self) -> None:
@ -156,10 +158,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return result return result
self._abort_if_unique_id_configured(updates={CONF_HOST: host}) self._abort_if_unique_id_configured(updates={CONF_HOST: host})
self._async_abort_entries_match({CONF_HOST: host}) self._async_abort_entries_match({CONF_HOST: host})
self.context[CONF_HOST] = host self.host = host
for progress in self._async_in_progress(): if self.hass.config_entries.flow.async_has_matching_flow(self):
if progress.get("context", {}).get(CONF_HOST) == host: return self.async_abort(reason="already_in_progress")
return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
if device: if device:
@ -176,6 +177,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return await self.async_step_discovery_confirm() return await self.async_step_discovery_confirm()
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return other_flow.host == self.host
async def async_step_discovery_auth_confirm( async def async_step_discovery_auth_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
@ -263,7 +268,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
if not (host := user_input[CONF_HOST]): if not (host := user_input[CONF_HOST]):
return await self.async_step_pick_device() return await self.async_step_pick_device()
self._async_abort_entries_match({CONF_HOST: host}) self._async_abort_entries_match({CONF_HOST: host})
self.context[CONF_HOST] = host self.host = host
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
device = await self._async_try_discover_and_update( device = await self._async_try_discover_and_update(
@ -289,8 +294,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Dialog that informs the user that auth is required.""" """Dialog that informs the user that auth is required."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
host = self.context[CONF_HOST] if TYPE_CHECKING:
placeholders: dict[str, str] = {CONF_HOST: host} # self.host is set by async_step_user and async_step_pick_device
assert self.host is not None
placeholders: dict[str, str] = {CONF_HOST: self.host}
assert self._discovered_device is not None assert self._discovered_device is not None
if user_input: if user_input:
@ -329,9 +336,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
mac = user_input[CONF_DEVICE] mac = user_input[CONF_DEVICE]
await self.async_set_unique_id(mac, raise_on_progress=False) await self.async_set_unique_id(mac, raise_on_progress=False)
self._discovered_device = self._discovered_devices[mac] self._discovered_device = self._discovered_devices[mac]
host = self._discovered_device.host self.host = self._discovered_device.host
self.context[CONF_HOST] = host
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:

View file

@ -17,6 +17,7 @@ from homeassistant.components.tplink import (
DeviceConfig, DeviceConfig,
KasaException, KasaException,
) )
from homeassistant.components.tplink.config_flow import TPLinkConfigFlow
from homeassistant.components.tplink.const import ( from homeassistant.components.tplink.const import (
CONF_CONNECTION_PARAMETERS, CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
@ -682,7 +683,19 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] is None assert result["errors"] is None
with _patch_discovery(), _patch_single_discovery(), _patch_connect(): real_is_matching = TPLinkConfigFlow.is_matching
return_values = []
def is_matching(self, other_flow) -> bool:
return_values.append(real_is_matching(self, other_flow))
return return_values[-1]
with (
_patch_discovery(),
_patch_single_discovery(),
_patch_connect(),
patch.object(TPLinkConfigFlow, "is_matching", wraps=is_matching, autospec=True),
):
result2 = await hass.config_entries.flow.async_init( result2 = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": config_entries.SOURCE_DHCP}, context={"source": config_entries.SOURCE_DHCP},
@ -693,6 +706,8 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "already_in_progress" assert result2["reason"] == "already_in_progress"
# Ensure the is_matching method returned True
assert return_values == [True]
with _patch_discovery(), _patch_single_discovery(), _patch_connect(): with _patch_discovery(), _patch_single_discovery(), _patch_connect():
result3 = await hass.config_entries.flow.async_init( result3 = await hass.config_entries.flow.async_init(