Add FlowManager.async_has_matching_flow (#126804)
* Add FlowManager.async_flow_has_matching_flow * Revert changes from the future * Apply suggested changes to apple_tv config flow * Rename methods after discussion * Update homeassistant/data_entry_flow.py Co-authored-by: J. Nick Koston <nick@koston.org> * Move deduplication functions to config_entries, add tests * Adjust tests --------- Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
26b5dab12b
commit
3c0be47d3c
7 changed files with 262 additions and 121 deletions
|
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable, Mapping
|
|||
from ipaddress import ip_address
|
||||
import logging
|
||||
from random import randrange
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
|
||||
from pyatv import exceptions, pair, scan
|
||||
from pyatv.const import DeviceModel, PairingRequirement, Protocol
|
||||
|
@ -98,8 +98,11 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
VERSION = 1
|
||||
|
||||
scan_filter: str | None = None
|
||||
all_identifiers: set[str]
|
||||
atv: BaseConfig | None = None
|
||||
atv_identifiers: list[str] | None = None
|
||||
_host: str # host in zeroconf discovery info, should not be accessed by other flows
|
||||
host: str | None = None # set by _async_aggregate_discoveries, for other flows
|
||||
protocol: Protocol | None = None
|
||||
pairing: PairingHandler | None = None
|
||||
protocols_to_pair: deque[Protocol] | None = None
|
||||
|
@ -157,7 +160,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
"type": "Apple TV",
|
||||
}
|
||||
self.scan_filter = self.unique_id
|
||||
self.context["identifier"] = self.unique_id
|
||||
return await self.async_step_restore_device()
|
||||
|
||||
async def async_step_restore_device(
|
||||
|
@ -192,7 +194,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
self.device_identifier, raise_on_progress=False
|
||||
)
|
||||
assert self.atv
|
||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
||||
self.all_identifiers = set(self.atv.all_identifiers)
|
||||
return await self.async_step_confirm()
|
||||
|
||||
return self.async_show_form(
|
||||
|
@ -207,7 +209,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
"""Handle device found via zeroconf."""
|
||||
if discovery_info.ip_address.version == 6:
|
||||
return self.async_abort(reason="ipv6_not_supported")
|
||||
host = discovery_info.host
|
||||
self._host = host = discovery_info.host
|
||||
service_type = discovery_info.type[:-1] # Remove leading .
|
||||
name = discovery_info.name.replace(f".{service_type}.", "")
|
||||
properties = discovery_info.properties
|
||||
|
@ -255,7 +257,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
# as two separate flows.
|
||||
#
|
||||
# To solve this, all identifiers are stored as
|
||||
# "all_identifiers" in the flow context. When a new service is discovered, the
|
||||
# "all_identifiers" in the flow. When a new service is discovered, the
|
||||
# code below will check these identifiers for all active flows and abort if a
|
||||
# match is found. Before aborting, the original flow is updated with any
|
||||
# potentially new identifiers. In the example above, when service C is
|
||||
|
@ -277,32 +279,32 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
self._async_check_and_update_in_progress(host, unique_id)
|
||||
# Host must only be set AFTER checking and updating in progress
|
||||
# flows or we will have a race condition where no flows move forward.
|
||||
self.context[CONF_ADDRESS] = host
|
||||
self.host = host
|
||||
|
||||
@callback
|
||||
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
|
||||
"""Check for in-progress flows and update them with identifiers if needed."""
|
||||
for flow in self._async_in_progress(include_uninitialized=True):
|
||||
context = flow["context"]
|
||||
if (
|
||||
context.get("source") != SOURCE_ZEROCONF
|
||||
or context.get(CONF_ADDRESS) != host
|
||||
):
|
||||
continue
|
||||
if (
|
||||
"all_identifiers" in context
|
||||
and unique_id not in context["all_identifiers"]
|
||||
):
|
||||
# Add potentially new identifiers from this device to the existing flow
|
||||
context["all_identifiers"].append(unique_id)
|
||||
if self.hass.config_entries.flow.async_has_matching_flow(self):
|
||||
raise AbortFlow("already_in_progress")
|
||||
|
||||
def is_matching(self, other_flow: Self) -> bool:
|
||||
"""Return True if other_flow is matching this flow."""
|
||||
if (
|
||||
other_flow.context.get("source") != SOURCE_ZEROCONF
|
||||
or other_flow.host != self._host
|
||||
):
|
||||
return False
|
||||
if self.unique_id is not None:
|
||||
# Add potentially new identifiers from this device to the existing flow
|
||||
other_flow.all_identifiers.add(self.unique_id)
|
||||
return True
|
||||
|
||||
async def async_found_zeroconf_device(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle device found after Zeroconf discovery."""
|
||||
assert self.atv
|
||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
||||
self.all_identifiers = set(self.atv.all_identifiers)
|
||||
# Also abort if an integration with this identifier already exists
|
||||
await self.async_set_unique_id(self.device_identifier)
|
||||
# but be sure to update the address if its changed so the scanner
|
||||
|
@ -310,7 +312,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
self._abort_if_unique_id_configured(
|
||||
updates={CONF_ADDRESS: str(self.atv.address)}
|
||||
)
|
||||
self.context["identifier"] = self.unique_id
|
||||
return await self.async_step_confirm()
|
||||
|
||||
async def async_find_device_wrapper(
|
||||
|
@ -390,7 +391,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||
"""Handle user-confirmation of discovered node."""
|
||||
assert self.atv
|
||||
if user_input is not None:
|
||||
expected_identifier_count = len(self.context["all_identifiers"])
|
||||
expected_identifier_count = len(self.all_identifiers)
|
||||
# If number of services found during device scan mismatch number of
|
||||
# identifiers collected during Zeroconf discovery, then trigger a new scan
|
||||
# with hopes of finding all services.
|
||||
|
|
|
@ -1544,6 +1544,35 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
|||
notification_id=DISCOVERY_NOTIFICATION_ID,
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_has_matching_discovery_flow(
|
||||
self, handler: str, match_context: dict[str, Any], data: Any
|
||||
) -> bool:
|
||||
"""Check if an existing matching discovery flow is in progress.
|
||||
|
||||
A flow with the same handler, context, and data.
|
||||
|
||||
If match_context is passed, only return flows with a context that is a
|
||||
superset of match_context.
|
||||
"""
|
||||
if not (flows := self._handler_progress_index.get(handler)):
|
||||
return False
|
||||
match_items = match_context.items()
|
||||
for progress in flows:
|
||||
if match_items <= progress.context.items() and progress.init_data == data:
|
||||
return True
|
||||
return False
|
||||
|
||||
@callback
|
||||
def async_has_matching_flow(self, flow: ConfigFlow) -> bool:
|
||||
"""Check if an existing matching flow is in progress."""
|
||||
if not (flows := self._handler_progress_index.get(flow.handler)):
|
||||
return False
|
||||
for other_flow in flows:
|
||||
if other_flow is not flow and flow.is_matching(other_flow): # type: ignore[arg-type]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ConfigEntryItems(UserDict[str, ConfigEntry]):
|
||||
"""Container for config items, maps config_entry_id -> entry.
|
||||
|
@ -2693,6 +2722,10 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
|||
self.hass.config_entries.async_schedule_reload(entry.entry_id)
|
||||
return self.async_abort(reason=reason)
|
||||
|
||||
def is_matching(self, other_flow: Self) -> bool:
|
||||
"""Return True if other_flow is matching this flow."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||
"""Flow to set options for a configuration entry."""
|
||||
|
|
|
@ -237,25 +237,6 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
|||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
@callback
|
||||
def async_has_matching_flow(
|
||||
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
|
||||
) -> bool:
|
||||
"""Check if an existing matching flow is in progress.
|
||||
|
||||
A flow with the same handler, context, and data.
|
||||
|
||||
If match_context is passed, only return flows with a context that is a
|
||||
superset of match_context.
|
||||
"""
|
||||
if not (flows := self._handler_progress_index.get(handler)):
|
||||
return False
|
||||
match_items = match_context.items()
|
||||
for progress in flows:
|
||||
if match_items <= progress.context.items() and progress.init_data == data:
|
||||
return True
|
||||
return False
|
||||
|
||||
@callback
|
||||
def async_get(self, flow_id: str) -> _FlowResultT:
|
||||
"""Return a flow in progress as a partial FlowResult."""
|
||||
|
|
|
@ -78,7 +78,9 @@ def _async_init_flow(
|
|||
# which can overload devices since zeroconf/ssdp updates can happen
|
||||
# multiple times in the same minute
|
||||
if (
|
||||
hass.config_entries.flow.async_has_matching_flow(domain, context, data)
|
||||
hass.config_entries.flow.async_has_matching_discovery_flow(
|
||||
domain, context, data
|
||||
)
|
||||
or hass.is_stopping
|
||||
):
|
||||
return None
|
||||
|
|
|
@ -91,7 +91,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
|
|||
"""Test existing flows prevent an identical ones from being after startup."""
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
with patch(
|
||||
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
|
||||
"homeassistant.config_entries.ConfigEntriesFlowManager.async_has_matching_discovery_flow",
|
||||
return_value=True,
|
||||
):
|
||||
discovery_flow.async_create_flow(
|
||||
|
|
|
@ -7,7 +7,7 @@ from collections.abc import Generator
|
|||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
|
@ -6180,3 +6180,204 @@ async def test_async_loaded_entries(
|
|||
assert await hass.config_entries.async_unload(entry1.entry_id)
|
||||
|
||||
assert hass.config_entries.async_loaded_entries("comp") == []
|
||||
|
||||
|
||||
async def test_async_has_matching_discovery_flow(
|
||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||
) -> None:
|
||||
"""Test we can check for matching discovery flows."""
|
||||
assert (
|
||||
manager.flow.async_has_matching_discovery_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action="task_one",
|
||||
)
|
||||
|
||||
async def async_step_homekit(self, discovery_info=None):
|
||||
return await self.async_step_init(discovery_info)
|
||||
|
||||
with mock_config_flow("test", TestFlow):
|
||||
result = await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_one"
|
||||
assert len(manager.flow.async_progress()) == 1
|
||||
assert len(manager.flow.async_progress_by_handler("test")) == 1
|
||||
assert (
|
||||
len(
|
||||
manager.flow.async_progress_by_handler(
|
||||
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
manager.flow.async_progress_by_handler(
|
||||
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert manager.flow.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
assert (
|
||||
manager.flow.async_has_matching_discovery_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
manager.flow.async_has_matching_discovery_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_SSDP},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
manager.flow.async_has_matching_discovery_flow(
|
||||
"other",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_async_has_matching_flow(
|
||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||
) -> None:
|
||||
"""Test check for matching flows when there is no active flow."""
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action="task_one",
|
||||
)
|
||||
|
||||
async def async_step_homekit(self, discovery_info=None):
|
||||
return await self.async_step_init(discovery_info)
|
||||
|
||||
def is_matching(self, other_flow: Self) -> bool:
|
||||
"""Return True if other_flow is matching this flow."""
|
||||
return True
|
||||
|
||||
# Initiate a flow
|
||||
with mock_config_flow("test", TestFlow):
|
||||
await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||
|
||||
assert manager.flow.async_has_matching_flow(flow) is False
|
||||
|
||||
# Initiate another flow
|
||||
with mock_config_flow("test", TestFlow):
|
||||
await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
|
||||
assert manager.flow.async_has_matching_flow(flow) is True
|
||||
|
||||
|
||||
async def test_async_has_matching_flow_no_flows(
|
||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||
) -> None:
|
||||
"""Test check for matching flows when there is no active flow."""
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action="task_one",
|
||||
)
|
||||
|
||||
async def async_step_homekit(self, discovery_info=None):
|
||||
return await self.async_step_init(discovery_info)
|
||||
|
||||
with mock_config_flow("test", TestFlow):
|
||||
result = await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||
|
||||
# Abort the flow before checking for matching flows
|
||||
manager.flow.async_abort(result["flow_id"])
|
||||
|
||||
assert manager.flow.async_has_matching_flow(flow) is False
|
||||
|
||||
|
||||
async def test_async_has_matching_flow_not_implemented(
|
||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||
) -> None:
|
||||
"""Test check for matching flows when there is no active flow."""
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action="task_one",
|
||||
)
|
||||
|
||||
async def async_step_homekit(self, discovery_info=None):
|
||||
return await self.async_step_init(discovery_info)
|
||||
|
||||
# Initiate a flow
|
||||
with mock_config_flow("test", TestFlow):
|
||||
await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||
|
||||
# Initiate another flow
|
||||
with mock_config_flow("test", TestFlow):
|
||||
await manager.flow.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
|
||||
# The flow does not implement is_matching
|
||||
with pytest.raises(NotImplementedError):
|
||||
manager.flow.async_has_matching_flow(flow)
|
||||
|
|
|
@ -781,83 +781,6 @@ async def test_async_get_unknown_flow(manager: MockFlowManager) -> None:
|
|||
await manager.async_get("does_not_exist")
|
||||
|
||||
|
||||
async def test_async_has_matching_flow(
|
||||
hass: HomeAssistant, manager: MockFlowManager
|
||||
) -> None:
|
||||
"""Test we can check for matching flows."""
|
||||
manager.hass = hass
|
||||
assert (
|
||||
manager.async_has_matching_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action="task_one",
|
||||
)
|
||||
|
||||
result = await manager.async_init(
|
||||
"test",
|
||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_one"
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.async_progress_by_handler("test")) == 1
|
||||
assert (
|
||||
len(
|
||||
manager.async_progress_by_handler(
|
||||
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
manager.async_progress_by_handler(
|
||||
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
assert (
|
||||
manager.async_has_matching_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
manager.async_has_matching_flow(
|
||||
"test",
|
||||
{"source": config_entries.SOURCE_SSDP},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
manager.async_has_matching_flow(
|
||||
"other",
|
||||
{"source": config_entries.SOURCE_HOMEKIT},
|
||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
|
||||
manager: MockFlowManager,
|
||||
) -> None:
|
||||
|
|
Loading…
Add table
Reference in a new issue