diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 34afc77e528..9df6dff8316 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -792,6 +792,7 @@ class ConfigEntries: await asyncio.gather( *[entry.async_shutdown() for entry in self._entries.values()] ) + await self.flow.async_shutdown() async def async_initialize(self) -> None: """Initialize config entry config.""" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 46ec967bd94..a9a78337b17 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -61,6 +61,7 @@ class FlowManager(abc.ABC): """Initialize the flow manager.""" self.hass = hass self._initializing: dict[str, list[asyncio.Future]] = {} + self._initialize_tasks: dict[str, list[asyncio.Task]] = {} self._progress: dict[str, Any] = {} async def async_wait_init_flow_finish(self, handler: str) -> None: @@ -118,21 +119,13 @@ class FlowManager(abc.ABC): init_done: asyncio.Future = asyncio.Future() self._initializing.setdefault(handler, []).append(init_done) - flow = await self.async_create_flow(handler, context=context, data=data) - if not flow: - self._initializing[handler].remove(init_done) - raise UnknownFlow("Flow was not created") - flow.hass = self.hass - flow.handler = handler - flow.flow_id = uuid.uuid4().hex - flow.context = context - self._progress[flow.flow_id] = flow + task = asyncio.create_task(self._async_init(init_done, handler, context, data)) + self._initialize_tasks.setdefault(handler, []).append(task) try: - result = await self._async_handle_step( - flow, flow.init_step, data, init_done - ) + flow, result = await task finally: + self._initialize_tasks[handler].remove(task) self._initializing[handler].remove(init_done) if result["type"] != RESULT_TYPE_ABORT: @@ -140,6 +133,31 @@ class FlowManager(abc.ABC): return result + async def _async_init( + self, + init_done: asyncio.Future, + handler: str, + context: dict, + data: Any, + ) -> tuple[FlowHandler, Any]: + """Run the init in a task to allow it to be canceled at shutdown.""" + flow = await self.async_create_flow(handler, context=context, data=data) + if not flow: + raise UnknownFlow("Flow was not created") + flow.hass = self.hass + flow.handler = handler + flow.flow_id = uuid.uuid4().hex + flow.context = context + self._progress[flow.flow_id] = flow + result = await self._async_handle_step(flow, flow.init_step, data, init_done) + return flow, result + + async def async_shutdown(self) -> None: + """Cancel any initializing flows.""" + for task_list in self._initialize_tasks.values(): + for task in task_list: + task.cancel() + async def async_configure( self, flow_id: str, user_input: dict | None = None ) -> Any: diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index b2fd9c8e34b..47b21793656 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,4 +1,7 @@ """Test the flow classes.""" +import asyncio +from unittest.mock import patch + import pytest import voluptuous as vol @@ -367,3 +370,28 @@ async def test_abort_flow_exception(manager): assert form["type"] == "abort" assert form["reason"] == "mock-reason" assert form["description_placeholders"] == {"placeholder": "yo"} + + +async def test_initializing_flows_canceled_on_shutdown(hass, manager): + """Test that initializing flows are canceled on shutdown.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + await asyncio.sleep(1) + + task = asyncio.create_task(manager.async_init("test")) + await hass.async_block_till_done() + await manager.async_shutdown() + + with pytest.raises(asyncio.exceptions.CancelledError): + await task + + +async def test_init_unknown_flow(manager): + """Test that UnknownFlow is raised when async_create_flow returns None.""" + + with pytest.raises(data_entry_flow.UnknownFlow), patch.object( + manager, "async_create_flow", return_value=None + ): + await manager.async_init("test")