Remove config entry specifics from FlowManager (#85565)

This commit is contained in:
Erik Montnemery 2023-01-17 15:26:17 +01:00 committed by GitHub
parent 0f3221eac7
commit 3cd6bd87a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 119 additions and 85 deletions

View file

@ -761,6 +761,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
super().__init__(hass) super().__init__(hass)
self.config_entries = config_entries self.config_entries = config_entries
self._hass_config = hass_config self._hass_config = hass_config
self._initializing: dict[str, dict[str, asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
if not (current := self._initializing.get(handler)):
return
await asyncio.wait(current.values())
@callback @callback
def _async_has_other_discovery_flows(self, flow_id: str) -> bool: def _async_has_other_discovery_flows(self, flow_id: str) -> bool:
@ -770,12 +779,76 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
for flow in self._progress.values() for flow in self._progress.values()
) )
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
"""Start a configuration flow."""
if context is None:
context = {}
flow_id = uuid_util.random_uuid_hex()
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, {})[flow_id] = init_done
task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
try:
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].pop(flow_id)
if result["type"] != data_entry_flow.FlowResultType.ABORT:
await self.async_post_init(flow, result)
return result
async def _async_init(
self,
flow_id: str,
handler: str,
context: dict,
data: Any,
) -> tuple[data_entry_flow.FlowHandler, FlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise data_entry_flow.UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
flow.flow_id = flow_id
flow.context = context
flow.init_data = data
self._async_add_flow_progress(flow)
try:
result = await self._async_handle_step(flow, flow.init_step, data)
finally:
init_done = self._initializing[handler][flow_id]
if not init_done.done():
init_done.set_result(None)
return flow, result
async def async_shutdown(self) -> None:
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values():
for task in task_list:
task.cancel()
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow) flow = cast(ConfigFlow, flow)
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
init_done = self._initializing[flow.handler].get(flow.flow_id)
if init_done and not init_done.done():
init_done.set_result(None)
# Remove notification if no other discovery config entries in progress # Remove notification if no other discovery config entries in progress
if not self._async_has_other_discovery_flows(flow.flow_id): if not self._async_has_other_discovery_flows(flow.flow_id):
persistent_notification.async_dismiss(self.hass, DISCOVERY_NOTIFICATION_ID) persistent_notification.async_dismiss(self.hass, DISCOVERY_NOTIFICATION_ID)

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import asyncio
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
@ -55,7 +54,7 @@ class BaseServiceInfo:
class FlowError(HomeAssistantError): class FlowError(HomeAssistantError):
"""Error while configuring an account.""" """Base class for data entry errors."""
class UnknownHandler(FlowError): class UnknownHandler(FlowError):
@ -137,18 +136,9 @@ class FlowManager(abc.ABC):
) -> None: ) -> None:
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, FlowHandler] = {} self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {} self._handler_progress_index: dict[str, set[str]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
if not (current := self._initializing.get(handler)):
return
await asyncio.wait(current)
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
self, self,
@ -166,7 +156,7 @@ class FlowManager(abc.ABC):
async def async_finish_flow( async def async_finish_flow(
self, flow: FlowHandler, result: FlowResult self, flow: FlowHandler, result: FlowResult
) -> FlowResult: ) -> FlowResult:
"""Finish a config flow and add an entry.""" """Finish a data entry flow."""
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None: async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@ -219,35 +209,9 @@ class FlowManager(abc.ABC):
async def async_init( async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult: ) -> FlowResult:
"""Start a configuration flow.""" """Start a data entry flow."""
if context is None: if context is None:
context = {} context = {}
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, []).append(init_done)
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
try:
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].remove(init_done)
if result["type"] != FlowResultType.ABORT:
await self.async_post_init(flow, result)
return result
async def _async_init(
self,
init_done: asyncio.Future,
handler: str,
context: dict,
data: Any,
) -> tuple[FlowHandler, FlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data) flow = await self.async_create_flow(handler, context=context, data=data)
if not flow: if not flow:
raise UnknownFlow("Flow was not created") raise UnknownFlow("Flow was not created")
@ -257,19 +221,18 @@ class FlowManager(abc.ABC):
flow.context = context flow.context = context
flow.init_data = data flow.init_data = data
self._async_add_flow_progress(flow) self._async_add_flow_progress(flow)
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
async def async_shutdown(self) -> None: result = await self._async_handle_step(flow, flow.init_step, data)
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values(): if result["type"] != FlowResultType.ABORT:
for task in task_list: await self.async_post_init(flow, result)
task.cancel()
return result
async def async_configure( async def async_configure(
self, flow_id: str, user_input: dict | None = None self, flow_id: str, user_input: dict | None = None
) -> FlowResult: ) -> FlowResult:
"""Continue a configuration flow.""" """Continue a data entry flow."""
if (flow := self._progress.get(flow_id)) is None: if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow raise UnknownFlow
@ -354,22 +317,16 @@ class FlowManager(abc.ABC):
try: try:
flow.async_remove() flow.async_remove()
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error removing %s config flow: %s", flow.handler, err) _LOGGER.exception("Error removing %s flow: %s", flow.handler, err)
async def _async_handle_step( async def _async_handle_step(
self, self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
flow: FlowHandler,
step_id: str,
user_input: dict | BaseServiceInfo | None,
step_done: asyncio.Future | None = None,
) -> FlowResult: ) -> FlowResult:
"""Handle a step of a flow.""" """Handle a step of a flow."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
if not hasattr(flow, method): if not hasattr(flow, method):
self._async_remove_flow_progress(flow.flow_id) self._async_remove_flow_progress(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep( raise UnknownStep(
f"Handler {flow.__class__.__name__} doesn't support step {step_id}" f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
) )
@ -381,13 +338,6 @@ class FlowManager(abc.ABC):
flow.flow_id, flow.handler, err.reason, err.description_placeholders flow.flow_id, flow.handler, err.reason, err.description_placeholders
) )
# Mark the step as done.
# We do this before calling async_finish_flow because config entries will hit a
# circular dependency where async_finish_flow sets up new entry, which needs the
# integration to be set up, which is waiting for init to be done.
if step_done:
step_done.set_result(None)
if not isinstance(result["type"], FlowResultType): if not isinstance(result["type"], FlowResultType):
result["type"] = FlowResultType(result["type"]) # type: ignore[unreachable] result["type"] = FlowResultType(result["type"]) # type: ignore[unreachable]
report( report(
@ -424,7 +374,7 @@ class FlowManager(abc.ABC):
class FlowHandler: class FlowHandler:
"""Handle the configuration flow of a component.""" """Handle a data entry flow."""
# Set by flow manager # Set by flow manager
cur_step: FlowResult | None = None cur_step: FlowResult | None = None
@ -519,7 +469,7 @@ class FlowHandler:
description: str | None = None, description: str | None = None,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> FlowResult:
"""Finish config flow and create a config entry.""" """Finish flow."""
flow_result = FlowResult( flow_result = FlowResult(
version=self.VERSION, version=self.VERSION,
type=FlowResultType.CREATE_ENTRY, type=FlowResultType.CREATE_ENTRY,
@ -541,7 +491,7 @@ class FlowHandler:
reason: str, reason: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> FlowResult:
"""Abort the config flow.""" """Abort the flow."""
return _create_abort_data( return _create_abort_data(
self.flow_id, self.handler, reason, description_placeholders self.flow_id, self.handler, reason, description_placeholders
) )
@ -626,7 +576,7 @@ class FlowHandler:
@callback @callback
def async_remove(self) -> None: def async_remove(self) -> None:
"""Notification that the config flow has been removed.""" """Notification that the flow has been removed."""
@callback @callback

View file

@ -92,7 +92,9 @@ async def test_discover_config_flow(hass):
with patch.dict( with patch.dict(
discovery.CONFIG_ENTRY_HANDLERS, {"mock-service": "mock-component"} discovery.CONFIG_ENTRY_HANDLERS, {"mock-service": "mock-component"}
), patch("homeassistant.data_entry_flow.FlowManager.async_init") as m_init: ), patch(
"homeassistant.config_entries.ConfigEntriesFlowManager.async_init"
) as m_init:
await mock_discovery(hass, discover) await mock_discovery(hass, discover)
assert len(m_init.mock_calls) == 1 assert len(m_init.mock_calls) == 1

View file

@ -3537,3 +3537,29 @@ async def test_options_flow_options_not_mutated() -> None:
"sub_list": ["one", "two"], "sub_list": ["one", "two"],
} }
assert entry.options == {"sub_dict": {"1": "one"}, "sub_list": ["one"]} assert entry.options == {"sub_dict": {"1": "one"}, "sub_list": ["one"]}
async def test_initializing_flows_canceled_on_shutdown(hass: HomeAssistant, manager):
"""Test that initializing flows are canceled on shutdown."""
class MockFlowHandler(config_entries.ConfigFlow):
"""Define a mock flow handler."""
VERSION = 1
async def async_step_reauth(self, data):
"""Mock Reauth."""
await asyncio.sleep(1)
with patch.dict(
config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler}
):
task = asyncio.create_task(
manager.flow.async_init("test", context={"source": "reauth"})
)
await hass.async_block_till_done()
await manager.flow.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError):
await task

View file

@ -1,5 +1,4 @@
"""Test the flow classes.""" """Test the flow classes."""
import asyncio
import logging import logging
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -181,7 +180,7 @@ async def test_abort_calls_async_remove_with_exception(manager, caplog):
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
await manager.async_init("test") await manager.async_init("test")
assert "Error removing test config flow: error" in caplog.text assert "Error removing test flow: error" in caplog.text
TestFlow.async_remove.assert_called_once() TestFlow.async_remove.assert_called_once()
@ -419,22 +418,6 @@ async def test_abort_flow_exception(manager):
assert form["description_placeholders"] == {"placeholder": "yo"} assert form["description_placeholders"] == {"placeholder": "yo"}
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
"""Test that initializing flows are canceled on shutdown."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
await asyncio.sleep(1)
task = asyncio.create_task(manager.async_init("test"))
await hass.async_block_till_done()
await manager.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError):
await task
async def test_init_unknown_flow(manager): async def test_init_unknown_flow(manager):
"""Test that UnknownFlow is raised when async_create_flow returns None.""" """Test that UnknownFlow is raised when async_create_flow returns None."""