diff --git a/homeassistant/components/apple_tv/config_flow.py b/homeassistant/components/apple_tv/config_flow.py index 71c26244203..b0741cc9c61 100644 --- a/homeassistant/components/apple_tv/config_flow.py +++ b/homeassistant/components/apple_tv/config_flow.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable, Mapping from ipaddress import ip_address import logging from random import randrange -from typing import Any +from typing import Any, Self from pyatv import exceptions, pair, scan from pyatv.const import DeviceModel, PairingRequirement, Protocol @@ -98,8 +98,11 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 scan_filter: str | None = None + all_identifiers: set[str] atv: BaseConfig | None = None atv_identifiers: list[str] | None = None + _host: str # host in zeroconf discovery info, should not be accessed by other flows + host: str | None = None # set by _async_aggregate_discoveries, for other flows protocol: Protocol | None = None pairing: PairingHandler | None = None protocols_to_pair: deque[Protocol] | None = None @@ -157,7 +160,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): "type": "Apple TV", } self.scan_filter = self.unique_id - self.context["identifier"] = self.unique_id return await self.async_step_restore_device() async def async_step_restore_device( @@ -192,7 +194,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): self.device_identifier, raise_on_progress=False ) assert self.atv - self.context["all_identifiers"] = self.atv.all_identifiers + self.all_identifiers = set(self.atv.all_identifiers) return await self.async_step_confirm() return self.async_show_form( @@ -207,7 +209,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): """Handle device found via zeroconf.""" if discovery_info.ip_address.version == 6: return self.async_abort(reason="ipv6_not_supported") - host = discovery_info.host + self._host = host = discovery_info.host service_type = discovery_info.type[:-1] # Remove leading . name = discovery_info.name.replace(f".{service_type}.", "") properties = discovery_info.properties @@ -255,7 +257,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): # as two separate flows. # # To solve this, all identifiers are stored as - # "all_identifiers" in the flow context. When a new service is discovered, the + # "all_identifiers" in the flow. When a new service is discovered, the # code below will check these identifiers for all active flows and abort if a # match is found. Before aborting, the original flow is updated with any # potentially new identifiers. In the example above, when service C is @@ -277,32 +279,32 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): self._async_check_and_update_in_progress(host, unique_id) # Host must only be set AFTER checking and updating in progress # flows or we will have a race condition where no flows move forward. - self.context[CONF_ADDRESS] = host + self.host = host @callback def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None: """Check for in-progress flows and update them with identifiers if needed.""" - for flow in self._async_in_progress(include_uninitialized=True): - context = flow["context"] - if ( - context.get("source") != SOURCE_ZEROCONF - or context.get(CONF_ADDRESS) != host - ): - continue - if ( - "all_identifiers" in context - and unique_id not in context["all_identifiers"] - ): - # Add potentially new identifiers from this device to the existing flow - context["all_identifiers"].append(unique_id) + if self.hass.config_entries.flow.async_has_matching_flow(self): raise AbortFlow("already_in_progress") + def is_matching(self, other_flow: Self) -> bool: + """Return True if other_flow is matching this flow.""" + if ( + other_flow.context.get("source") != SOURCE_ZEROCONF + or other_flow.host != self._host + ): + return False + if self.unique_id is not None: + # Add potentially new identifiers from this device to the existing flow + other_flow.all_identifiers.add(self.unique_id) + return True + async def async_found_zeroconf_device( self, user_input: dict[str, str] | None = None ) -> ConfigFlowResult: """Handle device found after Zeroconf discovery.""" assert self.atv - self.context["all_identifiers"] = self.atv.all_identifiers + self.all_identifiers = set(self.atv.all_identifiers) # Also abort if an integration with this identifier already exists await self.async_set_unique_id(self.device_identifier) # but be sure to update the address if its changed so the scanner @@ -310,7 +312,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): self._abort_if_unique_id_configured( updates={CONF_ADDRESS: str(self.atv.address)} ) - self.context["identifier"] = self.unique_id return await self.async_step_confirm() async def async_find_device_wrapper( @@ -390,7 +391,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN): """Handle user-confirmation of discovered node.""" assert self.atv if user_input is not None: - expected_identifier_count = len(self.context["all_identifiers"]) + expected_identifier_count = len(self.all_identifiers) # If number of services found during device scan mismatch number of # identifiers collected during Zeroconf discovery, then trigger a new scan # with hopes of finding all services. diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 404ae1c91dd..ac96b83f61d 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1544,6 +1544,35 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): notification_id=DISCOVERY_NOTIFICATION_ID, ) + @callback + def async_has_matching_discovery_flow( + self, handler: str, match_context: dict[str, Any], data: Any + ) -> bool: + """Check if an existing matching discovery flow is in progress. + + A flow with the same handler, context, and data. + + If match_context is passed, only return flows with a context that is a + superset of match_context. + """ + if not (flows := self._handler_progress_index.get(handler)): + return False + match_items = match_context.items() + for progress in flows: + if match_items <= progress.context.items() and progress.init_data == data: + return True + return False + + @callback + def async_has_matching_flow(self, flow: ConfigFlow) -> bool: + """Check if an existing matching flow is in progress.""" + if not (flows := self._handler_progress_index.get(flow.handler)): + return False + for other_flow in flows: + if other_flow is not flow and flow.is_matching(other_flow): # type: ignore[arg-type] + return True + return False + class ConfigEntryItems(UserDict[str, ConfigEntry]): """Container for config items, maps config_entry_id -> entry. @@ -2693,6 +2722,10 @@ class ConfigFlow(ConfigEntryBaseFlow): self.hass.config_entries.async_schedule_reload(entry.entry_id) return self.async_abort(reason=reason) + def is_matching(self, other_flow: Self) -> bool: + """Return True if other_flow is matching this flow.""" + raise NotImplementedError + class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): """Flow to set options for a configuration entry.""" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index dff7ebee03c..de08a178a70 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -237,25 +237,6 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): ) -> None: """Entry has finished executing its first step asynchronously.""" - @callback - def async_has_matching_flow( - self, handler: _HandlerT, match_context: dict[str, Any], data: Any - ) -> bool: - """Check if an existing matching flow is in progress. - - A flow with the same handler, context, and data. - - If match_context is passed, only return flows with a context that is a - superset of match_context. - """ - if not (flows := self._handler_progress_index.get(handler)): - return False - match_items = match_context.items() - for progress in flows: - if match_items <= progress.context.items() and progress.init_data == data: - return True - return False - @callback def async_get(self, flow_id: str) -> _FlowResultT: """Return a flow in progress as a partial FlowResult.""" diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py index 8112be3dde4..e6596a496e0 100644 --- a/homeassistant/helpers/discovery_flow.py +++ b/homeassistant/helpers/discovery_flow.py @@ -78,7 +78,9 @@ def _async_init_flow( # which can overload devices since zeroconf/ssdp updates can happen # multiple times in the same minute if ( - hass.config_entries.flow.async_has_matching_flow(domain, context, data) + hass.config_entries.flow.async_has_matching_discovery_flow( + domain, context, data + ) or hass.is_stopping ): return None diff --git a/tests/helpers/test_discovery_flow.py b/tests/helpers/test_discovery_flow.py index 2bb58f86c9a..dde0f209706 100644 --- a/tests/helpers/test_discovery_flow.py +++ b/tests/helpers/test_discovery_flow.py @@ -91,7 +91,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup( """Test existing flows prevent an identical ones from being after startup.""" hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) with patch( - "homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", + "homeassistant.config_entries.ConfigEntriesFlowManager.async_has_matching_discovery_flow", return_value=True, ): discovery_flow.async_create_flow( diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 9cba19ef3b1..e16e0a0ace5 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -7,7 +7,7 @@ from collections.abc import Generator from datetime import timedelta from functools import cached_property import logging -from typing import Any +from typing import Any, Self from unittest.mock import ANY, AsyncMock, Mock, patch from freezegun import freeze_time @@ -6180,3 +6180,204 @@ async def test_async_loaded_entries( assert await hass.config_entries.async_unload(entry1.entry_id) assert hass.config_entries.async_loaded_entries("comp") == [] + + +async def test_async_has_matching_discovery_flow( + hass: HomeAssistant, manager: config_entries.ConfigEntries +) -> None: + """Test we can check for matching discovery flows.""" + assert ( + manager.flow.async_has_matching_discovery_flow( + "test", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is False + ) + + mock_integration(hass, MockModule("test")) + mock_platform(hass, "test.config_flow", None) + + class TestFlow(config_entries.ConfigFlow): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_show_progress( + step_id="init", + progress_action="task_one", + ) + + async def async_step_homekit(self, discovery_info=None): + return await self.async_step_init(discovery_info) + + with mock_config_flow("test", TestFlow): + result = await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS + assert result["progress_action"] == "task_one" + assert len(manager.flow.async_progress()) == 1 + assert len(manager.flow.async_progress_by_handler("test")) == 1 + assert ( + len( + manager.flow.async_progress_by_handler( + "test", match_context={"source": config_entries.SOURCE_HOMEKIT} + ) + ) + == 1 + ) + assert ( + len( + manager.flow.async_progress_by_handler( + "test", match_context={"source": config_entries.SOURCE_BLUETOOTH} + ) + ) + == 0 + ) + assert manager.flow.async_get(result["flow_id"])["handler"] == "test" + + assert ( + manager.flow.async_has_matching_discovery_flow( + "test", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is True + ) + assert ( + manager.flow.async_has_matching_discovery_flow( + "test", + {"source": config_entries.SOURCE_SSDP}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is False + ) + assert ( + manager.flow.async_has_matching_discovery_flow( + "other", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is False + ) + + +async def test_async_has_matching_flow( + hass: HomeAssistant, manager: config_entries.ConfigEntries +) -> None: + """Test check for matching flows when there is no active flow.""" + mock_integration(hass, MockModule("test")) + mock_platform(hass, "test.config_flow", None) + + class TestFlow(config_entries.ConfigFlow): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_show_progress( + step_id="init", + progress_action="task_one", + ) + + async def async_step_homekit(self, discovery_info=None): + return await self.async_step_init(discovery_info) + + def is_matching(self, other_flow: Self) -> bool: + """Return True if other_flow is matching this flow.""" + return True + + # Initiate a flow + with mock_config_flow("test", TestFlow): + await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + flow = list(manager.flow._handler_progress_index.get("test"))[0] + + assert manager.flow.async_has_matching_flow(flow) is False + + # Initiate another flow + with mock_config_flow("test", TestFlow): + await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + + assert manager.flow.async_has_matching_flow(flow) is True + + +async def test_async_has_matching_flow_no_flows( + hass: HomeAssistant, manager: config_entries.ConfigEntries +) -> None: + """Test check for matching flows when there is no active flow.""" + mock_integration(hass, MockModule("test")) + mock_platform(hass, "test.config_flow", None) + + class TestFlow(config_entries.ConfigFlow): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_show_progress( + step_id="init", + progress_action="task_one", + ) + + async def async_step_homekit(self, discovery_info=None): + return await self.async_step_init(discovery_info) + + with mock_config_flow("test", TestFlow): + result = await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + flow = list(manager.flow._handler_progress_index.get("test"))[0] + + # Abort the flow before checking for matching flows + manager.flow.async_abort(result["flow_id"]) + + assert manager.flow.async_has_matching_flow(flow) is False + + +async def test_async_has_matching_flow_not_implemented( + hass: HomeAssistant, manager: config_entries.ConfigEntries +) -> None: + """Test check for matching flows when there is no active flow.""" + mock_integration(hass, MockModule("test")) + mock_platform(hass, "test.config_flow", None) + + class TestFlow(config_entries.ConfigFlow): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_show_progress( + step_id="init", + progress_action="task_one", + ) + + async def async_step_homekit(self, discovery_info=None): + return await self.async_step_init(discovery_info) + + # Initiate a flow + with mock_config_flow("test", TestFlow): + await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + flow = list(manager.flow._handler_progress_index.get("test"))[0] + + # Initiate another flow + with mock_config_flow("test", TestFlow): + await manager.flow.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + + # The flow does not implement is_matching + with pytest.raises(NotImplementedError): + manager.flow.async_has_matching_flow(flow) diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 01b6a530105..32020ac0d76 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -781,83 +781,6 @@ async def test_async_get_unknown_flow(manager: MockFlowManager) -> None: await manager.async_get("does_not_exist") -async def test_async_has_matching_flow( - hass: HomeAssistant, manager: MockFlowManager -) -> None: - """Test we can check for matching flows.""" - manager.hass = hass - assert ( - manager.async_has_matching_flow( - "test", - {"source": config_entries.SOURCE_HOMEKIT}, - {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, - ) - is False - ) - - @manager.mock_reg_handler("test") - class TestFlow(data_entry_flow.FlowHandler): - VERSION = 5 - - async def async_step_init(self, user_input=None): - return self.async_show_progress( - step_id="init", - progress_action="task_one", - ) - - result = await manager.async_init( - "test", - context={"source": config_entries.SOURCE_HOMEKIT}, - data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, - ) - assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS - assert result["progress_action"] == "task_one" - assert len(manager.async_progress()) == 1 - assert len(manager.async_progress_by_handler("test")) == 1 - assert ( - len( - manager.async_progress_by_handler( - "test", match_context={"source": config_entries.SOURCE_HOMEKIT} - ) - ) - == 1 - ) - assert ( - len( - manager.async_progress_by_handler( - "test", match_context={"source": config_entries.SOURCE_BLUETOOTH} - ) - ) - == 0 - ) - assert manager.async_get(result["flow_id"])["handler"] == "test" - - assert ( - manager.async_has_matching_flow( - "test", - {"source": config_entries.SOURCE_HOMEKIT}, - {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, - ) - is True - ) - assert ( - manager.async_has_matching_flow( - "test", - {"source": config_entries.SOURCE_SSDP}, - {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, - ) - is False - ) - assert ( - manager.async_has_matching_flow( - "other", - {"source": config_entries.SOURCE_HOMEKIT}, - {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, - ) - is False - ) - - async def test_move_to_unknown_step_raises_and_removes_from_in_progress( manager: MockFlowManager, ) -> None: