diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index e660832487a..ed5c544499e 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -231,14 +231,9 @@ class LoginFlowResourceView(HomeAssistantView): try: # do not allow change ip during login flow - for flow in self._flow_mgr.async_progress(): - if flow["flow_id"] == flow_id and flow["context"][ - "ip_address" - ] != ip_address(request.remote): - return self.json_message( - "IP address changed", HTTPStatus.BAD_REQUEST - ) - + flow = self._flow_mgr.async_get(flow_id) + if flow["context"]["ip_address"] != ip_address(request.remote): + return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST) result = await self._flow_mgr.async_configure(flow_id, data) except data_entry_flow.UnknownFlow: return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND) diff --git a/homeassistant/components/point/config_flow.py b/homeassistant/components/point/config_flow.py index 3b9ba84fab5..fbcbcd02a2b 100644 --- a/homeassistant/components/point/config_flow.py +++ b/homeassistant/components/point/config_flow.py @@ -131,7 +131,7 @@ class PointFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): _LOGGER.debug( "Should close all flows below %s", - self.hass.config_entries.flow.async_progress(), + self._async_in_progress(), ) # Remove notification if no other discovery config entries in progress diff --git a/homeassistant/components/smartthings/__init__.py b/homeassistant/components/smartthings/__init__.py index 44a360a2e55..5612cd65732 100644 --- a/homeassistant/components/smartthings/__init__.py +++ b/homeassistant/components/smartthings/__init__.py @@ -73,8 +73,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Remove the entry which will invoke the callback to delete the app. hass.async_create_task(hass.config_entries.async_remove(entry.entry_id)) # only create new flow if there isn't a pending one for SmartThings. - flows = hass.config_entries.flow.async_progress() - if not [flow for flow in flows if flow["handler"] == DOMAIN]: + if not hass.config_entries.flow.async_progress_by_handler(DOMAIN): hass.async_create_task( hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT} @@ -181,8 +180,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if remove_entry: hass.async_create_task(hass.config_entries.async_remove(entry.entry_id)) # only create new flow if there isn't a pending one for SmartThings. - flows = hass.config_entries.flow.async_progress() - if not [flow for flow in flows if flow["handler"] == DOMAIN]: + if not hass.config_entries.flow.async_progress_by_handler(DOMAIN): hass.async_create_task( hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT} diff --git a/homeassistant/components/smartthings/smartapp.py b/homeassistant/components/smartthings/smartapp.py index 36f2610b981..2086d564753 100644 --- a/homeassistant/components/smartthings/smartapp.py +++ b/homeassistant/components/smartthings/smartapp.py @@ -406,8 +406,8 @@ async def _continue_flow( flow = next( ( flow - for flow in hass.config_entries.flow.async_progress() - if flow["handler"] == DOMAIN and flow["context"]["unique_id"] == unique_id + for flow in hass.config_entries.flow.async_progress_by_handler(DOMAIN) + if flow["context"]["unique_id"] == unique_id ), None, ) diff --git a/homeassistant/components/withings/common.py b/homeassistant/components/withings/common.py index 9e4beff8c38..83c5653afdf 100644 --- a/homeassistant/components/withings/common.py +++ b/homeassistant/components/withings/common.py @@ -745,7 +745,9 @@ class DataManager: flow = next( iter( flow - for flow in self._hass.config_entries.flow.async_progress() + for flow in self._hass.config_entries.flow.async_progress_by_handler( + const.DOMAIN + ) if flow.context == context ), None, diff --git a/homeassistant/components/zha/config_flow.py b/homeassistant/components/zha/config_flow.py index 2b867366453..a3eaffdf1ba 100644 --- a/homeassistant/components/zha/config_flow.py +++ b/homeassistant/components/zha/config_flow.py @@ -120,9 +120,8 @@ class ZhaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): # If they already have a discovery for deconz # we ignore the usb discovery as they probably # want to use it there instead - for flow in self.hass.config_entries.flow.async_progress(): - if flow["handler"] == DECONZ_DOMAIN: - return self.async_abort(reason="not_zha_device") + if self.hass.config_entries.flow.async_progress_by_handler(DECONZ_DOMAIN): + return self.async_abort(reason="not_zha_device") for entry in self.hass.config_entries.async_entries(DECONZ_DOMAIN): if entry.source != config_entries.SOURCE_IGNORE: return self.async_abort(reason="not_zha_device") diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 03d7df740ba..17f8b1396ed 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -586,7 +586,7 @@ class ConfigEntry: "unique_id": self.unique_id, } - for flow in hass.config_entries.flow.async_progress(): + for flow in hass.config_entries.flow.async_progress_by_handler(self.domain): if flow["context"] == flow_context: return @@ -618,6 +618,14 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): self.config_entries = config_entries self._hass_config = hass_config + @callback + def _async_has_other_discovery_flows(self, flow_id: str) -> bool: + """Check if there are any other discovery flows in progress.""" + return any( + flow.context["source"] in DISCOVERY_SOURCES and flow.flow_id != flow_id + for flow in self._progress.values() + ) + async def async_finish_flow( self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult ) -> data_entry_flow.FlowResult: @@ -625,11 +633,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): flow = cast(ConfigFlow, flow) # Remove notification if no other discovery config entries in progress - if not any( - ent["context"]["source"] in DISCOVERY_SOURCES - for ent in self.hass.config_entries.flow.async_progress() - if ent["flow_id"] != flow.flow_id - ): + if not self._async_has_other_discovery_flows(flow.flow_id): self.hass.components.persistent_notification.async_dismiss( DISCOVERY_NOTIFICATION_ID ) @@ -642,15 +646,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): # Abort all flows in progress with same unique ID # or the default discovery ID - for progress_flow in self.async_progress(): + for progress_flow in self.async_progress_by_handler(flow.handler): progress_unique_id = progress_flow["context"].get("unique_id") - if ( - progress_flow["handler"] == flow.handler - and progress_flow["flow_id"] != flow.flow_id - and ( - (flow.unique_id and progress_unique_id == flow.unique_id) - or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID - ) + if progress_flow["flow_id"] != flow.flow_id and ( + (flow.unique_id and progress_unique_id == flow.unique_id) + or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID ): self.async_abort(progress_flow["flow_id"]) @@ -837,7 +837,9 @@ class ConfigEntries: # If the configuration entry is removed during reauth, it should # abort any reauth flow that is active for the removed entry. - for progress_flow in self.hass.config_entries.flow.async_progress(): + for progress_flow in self.hass.config_entries.flow.async_progress_by_handler( + entry.domain + ): context = progress_flow.get("context") if ( context @@ -1265,10 +1267,10 @@ class ConfigFlow(data_entry_flow.FlowHandler): """Return other in progress flows for current domain.""" return [ flw - for flw in self.hass.config_entries.flow.async_progress( - include_uninitialized=include_uninitialized + for flw in self.hass.config_entries.flow.async_progress_by_handler( + self.handler, include_uninitialized=include_uninitialized ) - if flw["handler"] == self.handler and flw["flow_id"] != self.flow_id + if flw["flow_id"] != self.flow_id ] async def async_step_ignore( @@ -1329,7 +1331,9 @@ class ConfigFlow(data_entry_flow.FlowHandler): # Remove reauth notification if no reauth flows are in progress if self.source == SOURCE_REAUTH and not any( ent["context"]["source"] == SOURCE_REAUTH - for ent in self.hass.config_entries.flow.async_progress() + for ent in self.hass.config_entries.flow.async_progress_by_handler( + self.handler + ) if ent["flow_id"] != self.flow_id ): self.hass.components.persistent_notification.async_dismiss( diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index c82ec3acfd7..c1f798fcc32 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc import asyncio -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from types import MappingProxyType from typing import Any, TypedDict import uuid @@ -78,6 +78,23 @@ class FlowResult(TypedDict, total=False): options: Mapping[str, Any] +@callback +def _async_flow_handler_to_flow_result( + flows: Iterable[FlowHandler], include_uninitialized: bool +) -> list[FlowResult]: + """Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" + return [ + { + "flow_id": flow.flow_id, + "handler": flow.handler, + "context": flow.context, + "step_id": flow.cur_step["step_id"] if flow.cur_step else None, + } + for flow in flows + if include_uninitialized or flow.cur_step is not None + ] + + class FlowManager(abc.ABC): """Manage all the flows that are in progress.""" @@ -89,7 +106,8 @@ class FlowManager(abc.ABC): self.hass = hass self._initializing: dict[str, list[asyncio.Future]] = {} self._initialize_tasks: dict[str, list[asyncio.Task]] = {} - self._progress: dict[str, Any] = {} + self._progress: dict[str, FlowHandler] = {} + self._handler_progress_index: dict[str, set[str]] = {} async def async_wait_init_flow_finish(self, handler: str) -> None: """Wait till all flows in progress are initialized.""" @@ -127,24 +145,39 @@ class FlowManager(abc.ABC): """Check if an existing matching flow is in progress with the same handler, context, and data.""" return any( flow - for flow in self._progress.values() - if flow.handler == handler - and flow.context["source"] == context["source"] - and flow.init_data == data + for flow in self._async_progress_by_handler(handler) + if flow.context["source"] == context["source"] and flow.init_data == data ) + @callback + def async_get(self, flow_id: str) -> FlowResult | None: + """Return a flow in progress as a partial FlowResult.""" + if (flow := self._progress.get(flow_id)) is None: + raise UnknownFlow + return _async_flow_handler_to_flow_result([flow], False)[0] + @callback def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]: - """Return the flows in progress.""" + """Return the flows in progress as a partial FlowResult.""" + return _async_flow_handler_to_flow_result( + self._progress.values(), include_uninitialized + ) + + @callback + def async_progress_by_handler( + self, handler: str, include_uninitialized: bool = False + ) -> list[FlowResult]: + """Return the flows in progress by handler as a partial FlowResult.""" + return _async_flow_handler_to_flow_result( + self._async_progress_by_handler(handler), include_uninitialized + ) + + @callback + def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]: + """Return the flows in progress by handler.""" return [ - { - "flow_id": flow.flow_id, - "handler": flow.handler, - "context": flow.context, - "step_id": flow.cur_step["step_id"] if flow.cur_step else None, - } - for flow in self._progress.values() - if include_uninitialized or flow.cur_step is not None + self._progress[flow_id] + for flow_id in self._handler_progress_index.get(handler, {}) ] async def async_init( @@ -187,7 +220,7 @@ class FlowManager(abc.ABC): flow.flow_id = uuid.uuid4().hex flow.context = context flow.init_data = data - self._progress[flow.flow_id] = flow + self._async_add_flow_progress(flow) result = await self._async_handle_step(flow, flow.init_step, data, init_done) return flow, result @@ -205,6 +238,7 @@ class FlowManager(abc.ABC): raise UnknownFlow cur_step = flow.cur_step + assert cur_step is not None if cur_step.get("data_schema") is not None and user_input is not None: user_input = cur_step["data_schema"](user_input) @@ -245,8 +279,24 @@ class FlowManager(abc.ABC): @callback def async_abort(self, flow_id: str) -> None: """Abort a flow.""" - if self._progress.pop(flow_id, None) is None: + self._async_remove_flow_progress(flow_id) + + @callback + def _async_add_flow_progress(self, flow: FlowHandler) -> None: + """Add a flow to in progress.""" + self._progress[flow.flow_id] = flow + self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id) + + @callback + def _async_remove_flow_progress(self, flow_id: str) -> None: + """Remove a flow from in progress.""" + flow = self._progress.pop(flow_id, None) + if flow is None: raise UnknownFlow + handler = flow.handler + self._handler_progress_index[handler].remove(flow.flow_id) + if not self._handler_progress_index[handler]: + del self._handler_progress_index[handler] async def _async_handle_step( self, @@ -259,7 +309,7 @@ class FlowManager(abc.ABC): method = f"async_step_{step_id}" if not hasattr(flow, method): - self._progress.pop(flow.flow_id) + self._async_remove_flow_progress(flow.flow_id) if step_done: step_done.set_result(None) raise UnknownStep( @@ -310,7 +360,7 @@ class FlowManager(abc.ABC): return result # Abort and Success results both finish the flow - self._progress.pop(flow.flow_id) + self._async_remove_flow_progress(flow.flow_id) return result @@ -319,7 +369,7 @@ class FlowHandler: """Handle the configuration flow of a component.""" # Set by flow manager - cur_step: dict[str, str] | None = None + cur_step: dict[str, Any] | None = None # While not purely typed, it makes typehinting more useful for us # and removes the need for constant None checks or asserts. diff --git a/tests/components/auth/test_login_flow.py b/tests/components/auth/test_login_flow.py index ce3d37598d7..72881023fe5 100644 --- a/tests/components/auth/test_login_flow.py +++ b/tests/components/auth/test_login_flow.py @@ -114,3 +114,43 @@ async def test_login_exist_user(hass, aiohttp_client): step = await resp.json() assert step["type"] == "create_entry" assert len(step["result"]) > 1 + + +async def test_login_exist_user_ip_changes(hass, aiohttp_client): + """Test logging in and the ip address changes results in an rejection.""" + client = await async_setup_auth(hass, aiohttp_client, setup_api=True) + cred = await hass.auth.auth_providers[0].async_get_or_create_credentials( + {"username": "test-user"} + ) + await hass.auth.async_get_or_create_user(cred) + + resp = await client.post( + "/auth/login_flow", + json={ + "client_id": CLIENT_ID, + "handler": ["insecure_example", None], + "redirect_uri": CLIENT_REDIRECT_URI, + }, + ) + assert resp.status == 200 + step = await resp.json() + + # + # Here we modify the ip_address in the context to make sure + # when ip address changes in the middle of the login flow we prevent logins. + # + # This method was chosen because it seemed less likely to break + # vs patching aiohttp internals to fake the ip address + # + for flow_id, flow in hass.auth.login_flow._progress.items(): + assert flow_id == step["flow_id"] + flow.context["ip_address"] = "10.2.3.1" + + resp = await client.post( + f"/auth/login_flow/{step['flow_id']}", + json={"client_id": CLIENT_ID, "username": "test-user", "password": "test-pass"}, + ) + + assert resp.status == 400 + response = await resp.json() + assert response == {"message": "IP address changed"} diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 0b146c2f612..85d64de70a2 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -349,7 +349,7 @@ async def test_remove_entry_cancels_reauth(hass, manager): await entry.async_setup(hass) await hass.async_block_till_done() - flows = hass.config_entries.flow.async_progress() + flows = hass.config_entries.flow.async_progress_by_handler("test") assert len(flows) == 1 assert flows[0]["context"]["entry_id"] == entry.entry_id assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH @@ -357,7 +357,7 @@ async def test_remove_entry_cancels_reauth(hass, manager): await manager.async_remove(entry.entry_id) - flows = hass.config_entries.flow.async_progress() + flows = hass.config_entries.flow.async_progress_by_handler("test") assert len(flows) == 0 @@ -2100,11 +2100,11 @@ async def test_unignore_step_form(hass, manager): # Right after removal there shouldn't be an entry or active flows assert len(hass.config_entries.async_entries("comp")) == 0 - assert len(hass.config_entries.flow.async_progress()) == 0 + assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0 # But after a 'tick' the unignore step has run and we can see an active flow again. await hass.async_block_till_done() - assert len(hass.config_entries.flow.async_progress()) == 1 + assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 1 # and still not config entries assert len(hass.config_entries.async_entries("comp")) == 0 @@ -2144,7 +2144,7 @@ async def test_unignore_create_entry(hass, manager): await manager.async_remove(entry.entry_id) # Right after removal there shouldn't be an entry or flow - assert len(hass.config_entries.flow.async_progress()) == 0 + assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0 assert len(hass.config_entries.async_entries("comp")) == 0 # But after a 'tick' the unignore step has run and we can see a config entry. @@ -2155,7 +2155,7 @@ async def test_unignore_create_entry(hass, manager): assert entry.title == "yo" # And still no active flow - assert len(hass.config_entries.flow.async_progress()) == 0 + assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0 async def test_unignore_default_impl(hass, manager): diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 0aa3c01d50f..b4b40b6b6c6 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -271,6 +271,8 @@ async def test_external_step(hass, manager): result = await manager.async_init("test") assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" # Mimic external step # Called by integrations: `hass.config_entries.flow.async_configure(…)` @@ -327,6 +329,8 @@ async def test_show_progress(hass, manager): assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS assert result["progress_action"] == "task_one" assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" # Mimic task one done and moving to task two # Called by integrations: `hass.config_entries.flow.async_configure(…)` @@ -400,6 +404,13 @@ async def test_init_unknown_flow(manager): await manager.async_init("test") +async def test_async_get_unknown_flow(manager): + """Test that UnknownFlow is raised when async_get is called with a flow_id that does not exist.""" + + with pytest.raises(data_entry_flow.UnknownFlow): + await manager.async_get("does_not_exist") + + async def test_async_has_matching_flow( hass: HomeAssistant, manager: data_entry_flow.FlowManager ): @@ -424,6 +435,8 @@ async def test_async_has_matching_flow( assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS assert result["progress_action"] == "task_one" assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" assert ( manager.async_has_matching_flow( @@ -449,3 +462,28 @@ async def test_async_has_matching_flow( ) is False ) + + +async def test_move_to_unknown_step_raises_and_removes_from_in_progress(manager): + """Test that moving to an unknown step raises and removes the flow from in progress.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 1 + + with pytest.raises(data_entry_flow.UnknownStep): + await manager.async_init("test", context={"init_step": "does_not_exist"}) + + assert manager.async_progress() == [] + + +async def test_configure_raises_unknown_flow_if_not_in_progress(manager): + """Test configure raises UnknownFlow if the flow is not in progress.""" + with pytest.raises(data_entry_flow.UnknownFlow): + await manager.async_configure("wrong_flow_id") + + +async def test_abort_raises_unknown_flow_if_not_in_progress(manager): + """Test abort raises UnknownFlow if the flow is not in progress.""" + with pytest.raises(data_entry_flow.UnknownFlow): + await manager.async_abort("wrong_flow_id")