From 3cd6bd87a7e63e6c4479a3d475449478beb83c20 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 17 Jan 2023 15:26:17 +0100 Subject: [PATCH] Remove config entry specifics from FlowManager (#85565) --- homeassistant/config_entries.py | 73 ++++++++++++++++++++++ homeassistant/data_entry_flow.py | 82 +++++-------------------- tests/components/discovery/test_init.py | 4 +- tests/test_config_entries.py | 26 ++++++++ tests/test_data_entry_flow.py | 19 +----- 5 files changed, 119 insertions(+), 85 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 252f01eceef..2d4774024be 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -761,6 +761,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): super().__init__(hass) self.config_entries = config_entries self._hass_config = hass_config + self._initializing: dict[str, dict[str, asyncio.Future]] = {} + self._initialize_tasks: dict[str, list[asyncio.Task]] = {} + + async def async_wait_init_flow_finish(self, handler: str) -> None: + """Wait till all flows in progress are initialized.""" + if not (current := self._initializing.get(handler)): + return + + await asyncio.wait(current.values()) @callback def _async_has_other_discovery_flows(self, flow_id: str) -> bool: @@ -770,12 +779,76 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): for flow in self._progress.values() ) + async def async_init( + self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None + ) -> FlowResult: + """Start a configuration flow.""" + if context is None: + context = {} + + flow_id = uuid_util.random_uuid_hex() + init_done: asyncio.Future = asyncio.Future() + self._initializing.setdefault(handler, {})[flow_id] = init_done + + task = asyncio.create_task(self._async_init(flow_id, handler, context, data)) + self._initialize_tasks.setdefault(handler, []).append(task) + + try: + flow, result = await task + finally: + self._initialize_tasks[handler].remove(task) + self._initializing[handler].pop(flow_id) + + if result["type"] != data_entry_flow.FlowResultType.ABORT: + await self.async_post_init(flow, result) + + return result + + async def _async_init( + self, + flow_id: str, + handler: str, + context: dict, + data: Any, + ) -> tuple[data_entry_flow.FlowHandler, FlowResult]: + """Run the init in a task to allow it to be canceled at shutdown.""" + flow = await self.async_create_flow(handler, context=context, data=data) + if not flow: + raise data_entry_flow.UnknownFlow("Flow was not created") + flow.hass = self.hass + flow.handler = handler + flow.flow_id = flow_id + flow.context = context + flow.init_data = data + self._async_add_flow_progress(flow) + try: + result = await self._async_handle_step(flow, flow.init_step, data) + finally: + init_done = self._initializing[handler][flow_id] + if not init_done.done(): + init_done.set_result(None) + return flow, result + + async def async_shutdown(self) -> None: + """Cancel any initializing flows.""" + for task_list in self._initialize_tasks.values(): + for task in task_list: + task.cancel() + async def async_finish_flow( self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult ) -> data_entry_flow.FlowResult: """Finish a config flow and add an entry.""" flow = cast(ConfigFlow, flow) + # Mark the step as done. + # We do this to avoid a circular dependency where async_finish_flow sets up a + # new entry, which needs the integration to be set up, which is waiting for + # init to be done. + init_done = self._initializing[flow.handler].get(flow.flow_id) + if init_done and not init_done.done(): + init_done.set_result(None) + # Remove notification if no other discovery config entries in progress if not self._async_has_other_discovery_flows(flow.flow_id): persistent_notification.async_dismiss(self.hass, DISCOVERY_NOTIFICATION_ID) diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 59f76d90da8..ebe67e47103 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -2,7 +2,6 @@ from __future__ import annotations import abc -import asyncio from collections.abc import Iterable, Mapping import copy from dataclasses import dataclass @@ -55,7 +54,7 @@ class BaseServiceInfo: class FlowError(HomeAssistantError): - """Error while configuring an account.""" + """Base class for data entry errors.""" class UnknownHandler(FlowError): @@ -137,18 +136,9 @@ class FlowManager(abc.ABC): ) -> None: """Initialize the flow manager.""" self.hass = hass - self._initializing: dict[str, list[asyncio.Future]] = {} - self._initialize_tasks: dict[str, list[asyncio.Task]] = {} self._progress: dict[str, FlowHandler] = {} self._handler_progress_index: dict[str, set[str]] = {} - async def async_wait_init_flow_finish(self, handler: str) -> None: - """Wait till all flows in progress are initialized.""" - if not (current := self._initializing.get(handler)): - return - - await asyncio.wait(current) - @abc.abstractmethod async def async_create_flow( self, @@ -166,7 +156,7 @@ class FlowManager(abc.ABC): async def async_finish_flow( self, flow: FlowHandler, result: FlowResult ) -> FlowResult: - """Finish a config flow and add an entry.""" + """Finish a data entry flow.""" async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None: """Entry has finished executing its first step asynchronously.""" @@ -219,35 +209,9 @@ class FlowManager(abc.ABC): async def async_init( self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None ) -> FlowResult: - """Start a configuration flow.""" + """Start a data entry flow.""" if context is None: context = {} - - init_done: asyncio.Future = asyncio.Future() - self._initializing.setdefault(handler, []).append(init_done) - - task = asyncio.create_task(self._async_init(init_done, handler, context, data)) - self._initialize_tasks.setdefault(handler, []).append(task) - - try: - flow, result = await task - finally: - self._initialize_tasks[handler].remove(task) - self._initializing[handler].remove(init_done) - - if result["type"] != FlowResultType.ABORT: - await self.async_post_init(flow, result) - - return result - - async def _async_init( - self, - init_done: asyncio.Future, - handler: str, - context: dict, - data: Any, - ) -> tuple[FlowHandler, FlowResult]: - """Run the init in a task to allow it to be canceled at shutdown.""" flow = await self.async_create_flow(handler, context=context, data=data) if not flow: raise UnknownFlow("Flow was not created") @@ -257,19 +221,18 @@ class FlowManager(abc.ABC): flow.context = context flow.init_data = data self._async_add_flow_progress(flow) - result = await self._async_handle_step(flow, flow.init_step, data, init_done) - return flow, result - async def async_shutdown(self) -> None: - """Cancel any initializing flows.""" - for task_list in self._initialize_tasks.values(): - for task in task_list: - task.cancel() + result = await self._async_handle_step(flow, flow.init_step, data) + + if result["type"] != FlowResultType.ABORT: + await self.async_post_init(flow, result) + + return result async def async_configure( self, flow_id: str, user_input: dict | None = None ) -> FlowResult: - """Continue a configuration flow.""" + """Continue a data entry flow.""" if (flow := self._progress.get(flow_id)) is None: raise UnknownFlow @@ -354,22 +317,16 @@ class FlowManager(abc.ABC): try: flow.async_remove() except Exception as err: # pylint: disable=broad-except - _LOGGER.exception("Error removing %s config flow: %s", flow.handler, err) + _LOGGER.exception("Error removing %s flow: %s", flow.handler, err) async def _async_handle_step( - self, - flow: FlowHandler, - step_id: str, - user_input: dict | BaseServiceInfo | None, - step_done: asyncio.Future | None = None, + self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None ) -> FlowResult: """Handle a step of a flow.""" method = f"async_step_{step_id}" if not hasattr(flow, method): self._async_remove_flow_progress(flow.flow_id) - if step_done: - step_done.set_result(None) raise UnknownStep( f"Handler {flow.__class__.__name__} doesn't support step {step_id}" ) @@ -381,13 +338,6 @@ class FlowManager(abc.ABC): flow.flow_id, flow.handler, err.reason, err.description_placeholders ) - # Mark the step as done. - # We do this before calling async_finish_flow because config entries will hit a - # circular dependency where async_finish_flow sets up new entry, which needs the - # integration to be set up, which is waiting for init to be done. - if step_done: - step_done.set_result(None) - if not isinstance(result["type"], FlowResultType): result["type"] = FlowResultType(result["type"]) # type: ignore[unreachable] report( @@ -424,7 +374,7 @@ class FlowManager(abc.ABC): class FlowHandler: - """Handle the configuration flow of a component.""" + """Handle a data entry flow.""" # Set by flow manager cur_step: FlowResult | None = None @@ -519,7 +469,7 @@ class FlowHandler: description: str | None = None, description_placeholders: Mapping[str, str] | None = None, ) -> FlowResult: - """Finish config flow and create a config entry.""" + """Finish flow.""" flow_result = FlowResult( version=self.VERSION, type=FlowResultType.CREATE_ENTRY, @@ -541,7 +491,7 @@ class FlowHandler: reason: str, description_placeholders: Mapping[str, str] | None = None, ) -> FlowResult: - """Abort the config flow.""" + """Abort the flow.""" return _create_abort_data( self.flow_id, self.handler, reason, description_placeholders ) @@ -626,7 +576,7 @@ class FlowHandler: @callback def async_remove(self) -> None: - """Notification that the config flow has been removed.""" + """Notification that the flow has been removed.""" @callback diff --git a/tests/components/discovery/test_init.py b/tests/components/discovery/test_init.py index 9bc1e9a6812..df1a67245db 100644 --- a/tests/components/discovery/test_init.py +++ b/tests/components/discovery/test_init.py @@ -92,7 +92,9 @@ async def test_discover_config_flow(hass): with patch.dict( discovery.CONFIG_ENTRY_HANDLERS, {"mock-service": "mock-component"} - ), patch("homeassistant.data_entry_flow.FlowManager.async_init") as m_init: + ), patch( + "homeassistant.config_entries.ConfigEntriesFlowManager.async_init" + ) as m_init: await mock_discovery(hass, discover) assert len(m_init.mock_calls) == 1 diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 994c220adc4..198e79ec189 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3537,3 +3537,29 @@ async def test_options_flow_options_not_mutated() -> None: "sub_list": ["one", "two"], } assert entry.options == {"sub_dict": {"1": "one"}, "sub_list": ["one"]} + + +async def test_initializing_flows_canceled_on_shutdown(hass: HomeAssistant, manager): + """Test that initializing flows are canceled on shutdown.""" + + class MockFlowHandler(config_entries.ConfigFlow): + """Define a mock flow handler.""" + + VERSION = 1 + + async def async_step_reauth(self, data): + """Mock Reauth.""" + await asyncio.sleep(1) + + with patch.dict( + config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler} + ): + + task = asyncio.create_task( + manager.flow.async_init("test", context={"source": "reauth"}) + ) + await hass.async_block_till_done() + await manager.flow.async_shutdown() + + with pytest.raises(asyncio.exceptions.CancelledError): + await task diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index f0bcd2b5fd6..b39635e0ca5 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,5 +1,4 @@ """Test the flow classes.""" -import asyncio import logging from unittest.mock import Mock, patch @@ -181,7 +180,7 @@ async def test_abort_calls_async_remove_with_exception(manager, caplog): with caplog.at_level(logging.ERROR): await manager.async_init("test") - assert "Error removing test config flow: error" in caplog.text + assert "Error removing test flow: error" in caplog.text TestFlow.async_remove.assert_called_once() @@ -419,22 +418,6 @@ async def test_abort_flow_exception(manager): assert form["description_placeholders"] == {"placeholder": "yo"} -async def test_initializing_flows_canceled_on_shutdown(hass, manager): - """Test that initializing flows are canceled on shutdown.""" - - @manager.mock_reg_handler("test") - class TestFlow(data_entry_flow.FlowHandler): - async def async_step_init(self, user_input=None): - await asyncio.sleep(1) - - task = asyncio.create_task(manager.async_init("test")) - await hass.async_block_till_done() - await manager.async_shutdown() - - with pytest.raises(asyncio.exceptions.CancelledError): - await task - - async def test_init_unknown_flow(manager): """Test that UnknownFlow is raised when async_create_flow returns None."""