Add FlowManager.async_has_matching_flow (#126804)

* Add FlowManager.async_flow_has_matching_flow

* Revert changes from the future

* Apply suggested changes to apple_tv config flow

* Rename methods after discussion

* Update homeassistant/data_entry_flow.py

Co-authored-by: J. Nick Koston <nick@koston.org>

* Move deduplication functions to config_entries, add tests

* Adjust tests

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2024-09-27 10:51:36 +02:00 committed by GitHub
parent 26b5dab12b
commit 3c0be47d3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 262 additions and 121 deletions

View file

@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable, Mapping
from ipaddress import ip_address
import logging
from random import randrange
from typing import Any
from typing import Any, Self
from pyatv import exceptions, pair, scan
from pyatv.const import DeviceModel, PairingRequirement, Protocol
@ -98,8 +98,11 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
scan_filter: str | None = None
all_identifiers: set[str]
atv: BaseConfig | None = None
atv_identifiers: list[str] | None = None
_host: str # host in zeroconf discovery info, should not be accessed by other flows
host: str | None = None # set by _async_aggregate_discoveries, for other flows
protocol: Protocol | None = None
pairing: PairingHandler | None = None
protocols_to_pair: deque[Protocol] | None = None
@ -157,7 +160,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"type": "Apple TV",
}
self.scan_filter = self.unique_id
self.context["identifier"] = self.unique_id
return await self.async_step_restore_device()
async def async_step_restore_device(
@ -192,7 +194,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self.device_identifier, raise_on_progress=False
)
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers
self.all_identifiers = set(self.atv.all_identifiers)
return await self.async_step_confirm()
return self.async_show_form(
@ -207,7 +209,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle device found via zeroconf."""
if discovery_info.ip_address.version == 6:
return self.async_abort(reason="ipv6_not_supported")
host = discovery_info.host
self._host = host = discovery_info.host
service_type = discovery_info.type[:-1] # Remove leading .
name = discovery_info.name.replace(f".{service_type}.", "")
properties = discovery_info.properties
@ -255,7 +257,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
# as two separate flows.
#
# To solve this, all identifiers are stored as
# "all_identifiers" in the flow context. When a new service is discovered, the
# "all_identifiers" in the flow. When a new service is discovered, the
# code below will check these identifiers for all active flows and abort if a
# match is found. Before aborting, the original flow is updated with any
# potentially new identifiers. In the example above, when service C is
@ -277,32 +279,32 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self._async_check_and_update_in_progress(host, unique_id)
# Host must only be set AFTER checking and updating in progress
# flows or we will have a race condition where no flows move forward.
self.context[CONF_ADDRESS] = host
self.host = host
@callback
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
"""Check for in-progress flows and update them with identifiers if needed."""
for flow in self._async_in_progress(include_uninitialized=True):
context = flow["context"]
if (
context.get("source") != SOURCE_ZEROCONF
or context.get(CONF_ADDRESS) != host
):
continue
if (
"all_identifiers" in context
and unique_id not in context["all_identifiers"]
):
# Add potentially new identifiers from this device to the existing flow
context["all_identifiers"].append(unique_id)
if self.hass.config_entries.flow.async_has_matching_flow(self):
raise AbortFlow("already_in_progress")
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
if (
other_flow.context.get("source") != SOURCE_ZEROCONF
or other_flow.host != self._host
):
return False
if self.unique_id is not None:
# Add potentially new identifiers from this device to the existing flow
other_flow.all_identifiers.add(self.unique_id)
return True
async def async_found_zeroconf_device(
self, user_input: dict[str, str] | None = None
) -> ConfigFlowResult:
"""Handle device found after Zeroconf discovery."""
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers
self.all_identifiers = set(self.atv.all_identifiers)
# Also abort if an integration with this identifier already exists
await self.async_set_unique_id(self.device_identifier)
# but be sure to update the address if its changed so the scanner
@ -310,7 +312,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured(
updates={CONF_ADDRESS: str(self.atv.address)}
)
self.context["identifier"] = self.unique_id
return await self.async_step_confirm()
async def async_find_device_wrapper(
@ -390,7 +391,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle user-confirmation of discovered node."""
assert self.atv
if user_input is not None:
expected_identifier_count = len(self.context["all_identifiers"])
expected_identifier_count = len(self.all_identifiers)
# If number of services found during device scan mismatch number of
# identifiers collected during Zeroconf discovery, then trigger a new scan
# with hopes of finding all services.

View file

@ -1544,6 +1544,35 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
notification_id=DISCOVERY_NOTIFICATION_ID,
)
@callback
def async_has_matching_discovery_flow(
self, handler: str, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching discovery flow is in progress.
A flow with the same handler, context, and data.
If match_context is passed, only return flows with a context that is a
superset of match_context.
"""
if not (flows := self._handler_progress_index.get(handler)):
return False
match_items = match_context.items()
for progress in flows:
if match_items <= progress.context.items() and progress.init_data == data:
return True
return False
@callback
def async_has_matching_flow(self, flow: ConfigFlow) -> bool:
"""Check if an existing matching flow is in progress."""
if not (flows := self._handler_progress_index.get(flow.handler)):
return False
for other_flow in flows:
if other_flow is not flow and flow.is_matching(other_flow): # type: ignore[arg-type]
return True
return False
class ConfigEntryItems(UserDict[str, ConfigEntry]):
"""Container for config items, maps config_entry_id -> entry.
@ -2693,6 +2722,10 @@ class ConfigFlow(ConfigEntryBaseFlow):
self.hass.config_entries.async_schedule_reload(entry.entry_id)
return self.async_abort(reason=reason)
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
raise NotImplementedError
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry."""

View file

@ -237,25 +237,6 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
) -> None:
"""Entry has finished executing its first step asynchronously."""
@callback
def async_has_matching_flow(
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching flow is in progress.
A flow with the same handler, context, and data.
If match_context is passed, only return flows with a context that is a
superset of match_context.
"""
if not (flows := self._handler_progress_index.get(handler)):
return False
match_items = match_context.items()
for progress in flows:
if match_items <= progress.context.items() and progress.init_data == data:
return True
return False
@callback
def async_get(self, flow_id: str) -> _FlowResultT:
"""Return a flow in progress as a partial FlowResult."""

View file

@ -78,7 +78,9 @@ def _async_init_flow(
# which can overload devices since zeroconf/ssdp updates can happen
# multiple times in the same minute
if (
hass.config_entries.flow.async_has_matching_flow(domain, context, data)
hass.config_entries.flow.async_has_matching_discovery_flow(
domain, context, data
)
or hass.is_stopping
):
return None

View file

@ -91,7 +91,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch(
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
"homeassistant.config_entries.ConfigEntriesFlowManager.async_has_matching_discovery_flow",
return_value=True,
):
discovery_flow.async_create_flow(

View file

@ -7,7 +7,7 @@ from collections.abc import Generator
from datetime import timedelta
from functools import cached_property
import logging
from typing import Any
from typing import Any, Self
from unittest.mock import ANY, AsyncMock, Mock, patch
from freezegun import freeze_time
@ -6180,3 +6180,204 @@ async def test_async_loaded_entries(
assert await hass.config_entries.async_unload(entry1.entry_id)
assert hass.config_entries.async_loaded_entries("comp") == []
async def test_async_has_matching_discovery_flow(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test we can check for matching discovery flows."""
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
with mock_config_flow("test", TestFlow):
result = await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.flow.async_progress()) == 1
assert len(manager.flow.async_progress_by_handler("test")) == 1
assert (
len(
manager.flow.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
)
)
== 1
)
assert (
len(
manager.flow.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
)
)
== 0
)
assert manager.flow.async_get(result["flow_id"])["handler"] == "test"
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is True
)
assert (
manager.flow.async_has_matching_discovery_flow(
"test",
{"source": config_entries.SOURCE_SSDP},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
assert (
manager.flow.async_has_matching_discovery_flow(
"other",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return True
# Initiate a flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
assert manager.flow.async_has_matching_flow(flow) is False
# Initiate another flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert manager.flow.async_has_matching_flow(flow) is True
async def test_async_has_matching_flow_no_flows(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
with mock_config_flow("test", TestFlow):
result = await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
# Abort the flow before checking for matching flows
manager.flow.async_abort(result["flow_id"])
assert manager.flow.async_has_matching_flow(flow) is False
async def test_async_has_matching_flow_not_implemented(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test check for matching flows when there is no active flow."""
mock_integration(hass, MockModule("test"))
mock_platform(hass, "test.config_flow", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
async def async_step_homekit(self, discovery_info=None):
return await self.async_step_init(discovery_info)
# Initiate a flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
flow = list(manager.flow._handler_progress_index.get("test"))[0]
# Initiate another flow
with mock_config_flow("test", TestFlow):
await manager.flow.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
# The flow does not implement is_matching
with pytest.raises(NotImplementedError):
manager.flow.async_has_matching_flow(flow)

View file

@ -781,83 +781,6 @@ async def test_async_get_unknown_flow(manager: MockFlowManager) -> None:
await manager.async_get("does_not_exist")
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: MockFlowManager
) -> None:
"""Test we can check for matching flows."""
manager.hass = hass
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
result = await manager.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert (
len(
manager.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
)
)
== 1
)
assert (
len(
manager.async_progress_by_handler(
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
)
)
== 0
)
assert manager.async_get(result["flow_id"])["handler"] == "test"
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is True
)
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_SSDP},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
assert (
manager.async_has_matching_flow(
"other",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
manager: MockFlowManager,
) -> None: