Add FlowManager.async_has_matching_flow (#126804)

* Add FlowManager.async_flow_has_matching_flow

* Revert changes from the future

* Apply suggested changes to apple_tv config flow

* Rename methods after discussion

* Update homeassistant/data_entry_flow.py

Co-authored-by: J. Nick Koston <nick@koston.org>

* Move deduplication functions to config_entries, add tests

* Adjust tests

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2024-09-27 10:51:36 +02:00 committed by GitHub
parent 26b5dab12b
commit 3c0be47d3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 262 additions and 121 deletions

View file

@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable, Mapping
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
from random import randrange from random import randrange
from typing import Any from typing import Any, Self
from pyatv import exceptions, pair, scan from pyatv import exceptions, pair, scan
from pyatv.const import DeviceModel, PairingRequirement, Protocol from pyatv.const import DeviceModel, PairingRequirement, Protocol
@ -98,8 +98,11 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
scan_filter: str | None = None scan_filter: str | None = None
all_identifiers: set[str]
atv: BaseConfig | None = None atv: BaseConfig | None = None
atv_identifiers: list[str] | None = None atv_identifiers: list[str] | None = None
_host: str # host in zeroconf discovery info, should not be accessed by other flows
host: str | None = None # set by _async_aggregate_discoveries, for other flows
protocol: Protocol | None = None protocol: Protocol | None = None
pairing: PairingHandler | None = None pairing: PairingHandler | None = None
protocols_to_pair: deque[Protocol] | None = None protocols_to_pair: deque[Protocol] | None = None
@ -157,7 +160,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"type": "Apple TV", "type": "Apple TV",
} }
self.scan_filter = self.unique_id self.scan_filter = self.unique_id
self.context["identifier"] = self.unique_id
return await self.async_step_restore_device() return await self.async_step_restore_device()
async def async_step_restore_device( async def async_step_restore_device(
@ -192,7 +194,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self.device_identifier, raise_on_progress=False self.device_identifier, raise_on_progress=False
) )
assert self.atv assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers self.all_identifiers = set(self.atv.all_identifiers)
return await self.async_step_confirm() return await self.async_step_confirm()
return self.async_show_form( return self.async_show_form(
@ -207,7 +209,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle device found via zeroconf.""" """Handle device found via zeroconf."""
if discovery_info.ip_address.version == 6: if discovery_info.ip_address.version == 6:
return self.async_abort(reason="ipv6_not_supported") return self.async_abort(reason="ipv6_not_supported")
host = discovery_info.host self._host = host = discovery_info.host
service_type = discovery_info.type[:-1] # Remove leading . service_type = discovery_info.type[:-1] # Remove leading .
name = discovery_info.name.replace(f".{service_type}.", "") name = discovery_info.name.replace(f".{service_type}.", "")
properties = discovery_info.properties properties = discovery_info.properties
@ -255,7 +257,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
# as two separate flows. # as two separate flows.
# #
# To solve this, all identifiers are stored as # To solve this, all identifiers are stored as
# "all_identifiers" in the flow context. When a new service is discovered, the # "all_identifiers" in the flow. When a new service is discovered, the
# code below will check these identifiers for all active flows and abort if a # code below will check these identifiers for all active flows and abort if a
# match is found. Before aborting, the original flow is updated with any # match is found. Before aborting, the original flow is updated with any
# potentially new identifiers. In the example above, when service C is # potentially new identifiers. In the example above, when service C is
@ -277,32 +279,32 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self._async_check_and_update_in_progress(host, unique_id) self._async_check_and_update_in_progress(host, unique_id)
# Host must only be set AFTER checking and updating in progress # Host must only be set AFTER checking and updating in progress
# flows or we will have a race condition where no flows move forward. # flows or we will have a race condition where no flows move forward.
self.context[CONF_ADDRESS] = host self.host = host
@callback @callback
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None: def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
"""Check for in-progress flows and update them with identifiers if needed.""" """Check for in-progress flows and update them with identifiers if needed."""
for flow in self._async_in_progress(include_uninitialized=True): if self.hass.config_entries.flow.async_has_matching_flow(self):
context = flow["context"]
if (
context.get("source") != SOURCE_ZEROCONF
or context.get(CONF_ADDRESS) != host
):
continue
if (
"all_identifiers" in context
and unique_id not in context["all_identifiers"]
):
# Add potentially new identifiers from this device to the existing flow
context["all_identifiers"].append(unique_id)
raise AbortFlow("already_in_progress") raise AbortFlow("already_in_progress")
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
if (
other_flow.context.get("source") != SOURCE_ZEROCONF
or other_flow.host != self._host
):
return False
if self.unique_id is not None:
# Add potentially new identifiers from this device to the existing flow
other_flow.all_identifiers.add(self.unique_id)
return True
async def async_found_zeroconf_device( async def async_found_zeroconf_device(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle device found after Zeroconf discovery.""" """Handle device found after Zeroconf discovery."""
assert self.atv assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers self.all_identifiers = set(self.atv.all_identifiers)
# Also abort if an integration with this identifier already exists # Also abort if an integration with this identifier already exists
await self.async_set_unique_id(self.device_identifier) await self.async_set_unique_id(self.device_identifier)
# but be sure to update the address if its changed so the scanner # but be sure to update the address if its changed so the scanner
@ -310,7 +312,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured( self._abort_if_unique_id_configured(
updates={CONF_ADDRESS: str(self.atv.address)} updates={CONF_ADDRESS: str(self.atv.address)}
) )
self.context["identifier"] = self.unique_id
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_find_device_wrapper( async def async_find_device_wrapper(
@ -390,7 +391,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle user-confirmation of discovered node.""" """Handle user-confirmation of discovered node."""
assert self.atv assert self.atv
if user_input is not None: if user_input is not None:
expected_identifier_count = len(self.context["all_identifiers"]) expected_identifier_count = len(self.all_identifiers)
# If number of services found during device scan mismatch number of # If number of services found during device scan mismatch number of
# identifiers collected during Zeroconf discovery, then trigger a new scan # identifiers collected during Zeroconf discovery, then trigger a new scan
# with hopes of finding all services. # with hopes of finding all services.

View file

@ -1544,6 +1544,35 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
notification_id=DISCOVERY_NOTIFICATION_ID, notification_id=DISCOVERY_NOTIFICATION_ID,
) )
@callback
def async_has_matching_discovery_flow(
self, handler: str, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching discovery flow is in progress.
A flow with the same handler, context, and data.
If match_context is passed, only return flows with a context that is a
superset of match_context.
"""
if not (flows := self._handler_progress_index.get(handler)):
return False
match_items = match_context.items()
for progress in flows:
if match_items <= progress.context.items() and progress.init_data == data:
return True
return False
@callback
def async_has_matching_flow(self, flow: ConfigFlow) -> bool:
"""Check if an existing matching flow is in progress."""
if not (flows := self._handler_progress_index.get(flow.handler)):
return False
for other_flow in flows:
if other_flow is not flow and flow.is_matching(other_flow): # type: ignore[arg-type]
return True
return False
class ConfigEntryItems(UserDict[str, ConfigEntry]): class ConfigEntryItems(UserDict[str, ConfigEntry]):
"""Container for config items, maps config_entry_id -> entry. """Container for config items, maps config_entry_id -> entry.
@ -2693,6 +2722,10 @@ class ConfigFlow(ConfigEntryBaseFlow):
self.hass.config_entries.async_schedule_reload(entry.entry_id) self.hass.config_entries.async_schedule_reload(entry.entry_id)
return self.async_abort(reason=reason) return self.async_abort(reason=reason)
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
raise NotImplementedError
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""

View file

@ -237,25 +237,6 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
) -> None: ) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@callback
def async_has_matching_flow(
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching flow is in progress.
A flow with the same handler, context, and data.
If match_context is passed, only return flows with a context that is a
superset of match_context.
"""
if not (flows := self._handler_progress_index.get(handler)):
return False
match_items = match_context.items()
for progress in flows:
if match_items <= progress.context.items() and progress.init_data == data:
return True
return False
@callback @callback
def async_get(self, flow_id: str) -> _FlowResultT: def async_get(self, flow_id: str) -> _FlowResultT:
"""Return a flow in progress as a partial FlowResult.""" """Return a flow in progress as a partial FlowResult."""

View file

@ -78,7 +78,9 @@ def _async_init_flow(
# which can overload devices since zeroconf/ssdp updates can happen # which can overload devices since zeroconf/ssdp updates can happen
# multiple times in the same minute # multiple times in the same minute
if ( if (
hass.config_entries.flow.async_has_matching_flow(domain, context, data) hass.config_entries.flow.async_has_matching_discovery_flow(
domain, context, data
)
or hass.is_stopping or hass.is_stopping
): ):
return None return None

View file

@ -91,7 +91,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup.""" """Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch( with patch(
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", "homeassistant.config_entries.ConfigEntriesFlowManager.async_has_matching_discovery_flow",
return_value=True, return_value=True,
): ):
discovery_flow.async_create_flow( discovery_flow.async_create_flow(

View file

@ -7,7 +7,7 @@ from collections.abc import Generator
from datetime import timedelta from datetime import timedelta
from functools import cached_property from functools import cached_property
import logging import logging
from typing import Any from typing import Any, Self
from unittest.mock import ANY, AsyncMock, Mock, patch from unittest.mock import ANY, AsyncMock, Mock, patch
from freezegun import freeze_time from freezegun import freeze_time
@ -6180,3 +6180,204 @@ async def test_async_loaded_entries(
assert await hass.config_entries.async_unload(entry1.entry_id) assert await hass.config_entries.async_unload(entry1.entry_id)
assert hass.config_entries.async_loaded_entries("comp") == [] assert hass.config_entries.async_loaded_entries("comp") == []
async def test_async_has_matching_discovery_flow(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test we can check for matching discovery flows."""
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
with mock_config_flow("test", TestFlow):
result = await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.flow.async_progress()) == 1
assert len(manager.flow.async_progress_by_handler("test")) == 1
assert (
len(
manager.flow.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
)
)
== 1
)
assert (
len(
manager.flow.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
)
)
== 0
)
assert manager.flow.async_get(result["flow_id"])["handler"] == "test"
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is True
)
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_SSDP},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
assert (
manager.flow.async_has_matching_discovery_flow(
"other",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return True
# Initiate a flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
assert manager.flow.async_has_matching_flow(flow) is False
# Initiate another flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert manager.flow.async_has_matching_flow(flow) is True
async def test_async_has_matching_flow_no_flows(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
with mock_config_flow("test", TestFlow):
result = await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
# Abort the flow before checking for matching flows
manager.flow.async_abort(result["flow_id"])
assert manager.flow.async_has_matching_flow(flow) is False
async def test_async_has_matching_flow_not_implemented(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
# Initiate a flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
# Initiate another flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
# The flow does not implement is_matching
with pytest.raises(NotImplementedError):
manager.flow.async_has_matching_flow(flow)

View file

@ -781,83 +781,6 @@ async def test_async_get_unknown_flow(manager: MockFlowManager) -> None:
await manager.async_get("does_not_exist") await manager.async_get("does_not_exist")
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test we can check for matching flows."""
manager.hass = hass
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
result = await manager.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert (
len(
manager.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
)
)
== 1
)
assert (
len(
manager.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
)
)
== 0
)
assert manager.async_get(result["flow_id"])["handler"] == "test"
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is True
)
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_SSDP},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
assert (
manager.async_has_matching_flow(
"other",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
async def test_move_to_unknown_step_raises_and_removes_from_in_progress( async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
manager: MockFlowManager, manager: MockFlowManager,
) -> None: ) -> None: