Remove config entry specifics from FlowManager (#85565)
This commit is contained in:
parent
0f3221eac7
commit
3cd6bd87a7
5 changed files with 119 additions and 85 deletions
|
@ -761,6 +761,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
super().__init__(hass)
|
||||
self.config_entries = config_entries
|
||||
self._hass_config = hass_config
|
||||
self._initializing: dict[str, dict[str, asyncio.Future]] = {}
|
||||
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
"""Wait till all flows in progress are initialized."""
|
||||
if not (current := self._initializing.get(handler)):
|
||||
return
|
||||
|
||||
await asyncio.wait(current.values())
|
||||
|
||||
@callback
|
||||
def _async_has_other_discovery_flows(self, flow_id: str) -> bool:
|
||||
|
@ -770,12 +779,76 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
for flow in self._progress.values()
|
||||
)
|
||||
|
||||
async def async_init(
|
||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
) -> FlowResult:
|
||||
"""Start a configuration flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
|
||||
flow_id = uuid_util.random_uuid_hex()
|
||||
init_done: asyncio.Future = asyncio.Future()
|
||||
self._initializing.setdefault(handler, {})[flow_id] = init_done
|
||||
|
||||
task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
|
||||
self._initialize_tasks.setdefault(handler, []).append(task)
|
||||
|
||||
try:
|
||||
flow, result = await task
|
||||
finally:
|
||||
self._initialize_tasks[handler].remove(task)
|
||||
self._initializing[handler].pop(flow_id)
|
||||
|
||||
if result["type"] != data_entry_flow.FlowResultType.ABORT:
|
||||
await self.async_post_init(flow, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _async_init(
|
||||
self,
|
||||
flow_id: str,
|
||||
handler: str,
|
||||
context: dict,
|
||||
data: Any,
|
||||
) -> tuple[data_entry_flow.FlowHandler, FlowResult]:
|
||||
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise data_entry_flow.UnknownFlow("Flow was not created")
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
flow.flow_id = flow_id
|
||||
flow.context = context
|
||||
flow.init_data = data
|
||||
self._async_add_flow_progress(flow)
|
||||
try:
|
||||
result = await self._async_handle_step(flow, flow.init_step, data)
|
||||
finally:
|
||||
init_done = self._initializing[handler][flow_id]
|
||||
if not init_done.done():
|
||||
init_done.set_result(None)
|
||||
return flow, result
|
||||
|
||||
async def async_shutdown(self) -> None:
|
||||
"""Cancel any initializing flows."""
|
||||
for task_list in self._initialize_tasks.values():
|
||||
for task in task_list:
|
||||
task.cancel()
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Finish a config flow and add an entry."""
|
||||
flow = cast(ConfigFlow, flow)
|
||||
|
||||
# Mark the step as done.
|
||||
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
||||
# new entry, which needs the integration to be set up, which is waiting for
|
||||
# init to be done.
|
||||
init_done = self._initializing[flow.handler].get(flow.flow_id)
|
||||
if init_done and not init_done.done():
|
||||
init_done.set_result(None)
|
||||
|
||||
# Remove notification if no other discovery config entries in progress
|
||||
if not self._async_has_other_discovery_flows(flow.flow_id):
|
||||
persistent_notification.async_dismiss(self.hass, DISCOVERY_NOTIFICATION_ID)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Iterable, Mapping
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
@ -55,7 +54,7 @@ class BaseServiceInfo:
|
|||
|
||||
|
||||
class FlowError(HomeAssistantError):
|
||||
"""Error while configuring an account."""
|
||||
"""Base class for data entry errors."""
|
||||
|
||||
|
||||
class UnknownHandler(FlowError):
|
||||
|
@ -137,18 +136,9 @@ class FlowManager(abc.ABC):
|
|||
) -> None:
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._initializing: dict[str, list[asyncio.Future]] = {}
|
||||
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
self._progress: dict[str, FlowHandler] = {}
|
||||
self._handler_progress_index: dict[str, set[str]] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
"""Wait till all flows in progress are initialized."""
|
||||
if not (current := self._initializing.get(handler)):
|
||||
return
|
||||
|
||||
await asyncio.wait(current)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
self,
|
||||
|
@ -166,7 +156,7 @@ class FlowManager(abc.ABC):
|
|||
async def async_finish_flow(
|
||||
self, flow: FlowHandler, result: FlowResult
|
||||
) -> FlowResult:
|
||||
"""Finish a config flow and add an entry."""
|
||||
"""Finish a data entry flow."""
|
||||
|
||||
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
@ -219,35 +209,9 @@ class FlowManager(abc.ABC):
|
|||
async def async_init(
|
||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
) -> FlowResult:
|
||||
"""Start a configuration flow."""
|
||||
"""Start a data entry flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
|
||||
init_done: asyncio.Future = asyncio.Future()
|
||||
self._initializing.setdefault(handler, []).append(init_done)
|
||||
|
||||
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
|
||||
self._initialize_tasks.setdefault(handler, []).append(task)
|
||||
|
||||
try:
|
||||
flow, result = await task
|
||||
finally:
|
||||
self._initialize_tasks[handler].remove(task)
|
||||
self._initializing[handler].remove(init_done)
|
||||
|
||||
if result["type"] != FlowResultType.ABORT:
|
||||
await self.async_post_init(flow, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _async_init(
|
||||
self,
|
||||
init_done: asyncio.Future,
|
||||
handler: str,
|
||||
context: dict,
|
||||
data: Any,
|
||||
) -> tuple[FlowHandler, FlowResult]:
|
||||
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise UnknownFlow("Flow was not created")
|
||||
|
@ -257,19 +221,18 @@ class FlowManager(abc.ABC):
|
|||
flow.context = context
|
||||
flow.init_data = data
|
||||
self._async_add_flow_progress(flow)
|
||||
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
|
||||
return flow, result
|
||||
|
||||
async def async_shutdown(self) -> None:
|
||||
"""Cancel any initializing flows."""
|
||||
for task_list in self._initialize_tasks.values():
|
||||
for task in task_list:
|
||||
task.cancel()
|
||||
result = await self._async_handle_step(flow, flow.init_step, data)
|
||||
|
||||
if result["type"] != FlowResultType.ABORT:
|
||||
await self.async_post_init(flow, result)
|
||||
|
||||
return result
|
||||
|
||||
async def async_configure(
|
||||
self, flow_id: str, user_input: dict | None = None
|
||||
) -> FlowResult:
|
||||
"""Continue a configuration flow."""
|
||||
"""Continue a data entry flow."""
|
||||
if (flow := self._progress.get(flow_id)) is None:
|
||||
raise UnknownFlow
|
||||
|
||||
|
@ -354,22 +317,16 @@ class FlowManager(abc.ABC):
|
|||
try:
|
||||
flow.async_remove()
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error removing %s config flow: %s", flow.handler, err)
|
||||
_LOGGER.exception("Error removing %s flow: %s", flow.handler, err)
|
||||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: FlowHandler,
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
step_done: asyncio.Future | None = None,
|
||||
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
|
||||
) -> FlowResult:
|
||||
"""Handle a step of a flow."""
|
||||
method = f"async_step_{step_id}"
|
||||
|
||||
if not hasattr(flow, method):
|
||||
self._async_remove_flow_progress(flow.flow_id)
|
||||
if step_done:
|
||||
step_done.set_result(None)
|
||||
raise UnknownStep(
|
||||
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
|
||||
)
|
||||
|
@ -381,13 +338,6 @@ class FlowManager(abc.ABC):
|
|||
flow.flow_id, flow.handler, err.reason, err.description_placeholders
|
||||
)
|
||||
|
||||
# Mark the step as done.
|
||||
# We do this before calling async_finish_flow because config entries will hit a
|
||||
# circular dependency where async_finish_flow sets up new entry, which needs the
|
||||
# integration to be set up, which is waiting for init to be done.
|
||||
if step_done:
|
||||
step_done.set_result(None)
|
||||
|
||||
if not isinstance(result["type"], FlowResultType):
|
||||
result["type"] = FlowResultType(result["type"]) # type: ignore[unreachable]
|
||||
report(
|
||||
|
@ -424,7 +374,7 @@ class FlowManager(abc.ABC):
|
|||
|
||||
|
||||
class FlowHandler:
|
||||
"""Handle the configuration flow of a component."""
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
# Set by flow manager
|
||||
cur_step: FlowResult | None = None
|
||||
|
@ -519,7 +469,7 @@ class FlowHandler:
|
|||
description: str | None = None,
|
||||
description_placeholders: Mapping[str, str] | None = None,
|
||||
) -> FlowResult:
|
||||
"""Finish config flow and create a config entry."""
|
||||
"""Finish flow."""
|
||||
flow_result = FlowResult(
|
||||
version=self.VERSION,
|
||||
type=FlowResultType.CREATE_ENTRY,
|
||||
|
@ -541,7 +491,7 @@ class FlowHandler:
|
|||
reason: str,
|
||||
description_placeholders: Mapping[str, str] | None = None,
|
||||
) -> FlowResult:
|
||||
"""Abort the config flow."""
|
||||
"""Abort the flow."""
|
||||
return _create_abort_data(
|
||||
self.flow_id, self.handler, reason, description_placeholders
|
||||
)
|
||||
|
@ -626,7 +576,7 @@ class FlowHandler:
|
|||
|
||||
@callback
|
||||
def async_remove(self) -> None:
|
||||
"""Notification that the config flow has been removed."""
|
||||
"""Notification that the flow has been removed."""
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -92,7 +92,9 @@ async def test_discover_config_flow(hass):
|
|||
|
||||
with patch.dict(
|
||||
discovery.CONFIG_ENTRY_HANDLERS, {"mock-service": "mock-component"}
|
||||
), patch("homeassistant.data_entry_flow.FlowManager.async_init") as m_init:
|
||||
), patch(
|
||||
"homeassistant.config_entries.ConfigEntriesFlowManager.async_init"
|
||||
) as m_init:
|
||||
await mock_discovery(hass, discover)
|
||||
|
||||
assert len(m_init.mock_calls) == 1
|
||||
|
|
|
@ -3537,3 +3537,29 @@ async def test_options_flow_options_not_mutated() -> None:
|
|||
"sub_list": ["one", "two"],
|
||||
}
|
||||
assert entry.options == {"sub_dict": {"1": "one"}, "sub_list": ["one"]}
|
||||
|
||||
|
||||
async def test_initializing_flows_canceled_on_shutdown(hass: HomeAssistant, manager):
|
||||
"""Test that initializing flows are canceled on shutdown."""
|
||||
|
||||
class MockFlowHandler(config_entries.ConfigFlow):
|
||||
"""Define a mock flow handler."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_reauth(self, data):
|
||||
"""Mock Reauth."""
|
||||
await asyncio.sleep(1)
|
||||
|
||||
with patch.dict(
|
||||
config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler}
|
||||
):
|
||||
|
||||
task = asyncio.create_task(
|
||||
manager.flow.async_init("test", context={"source": "reauth"})
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
await manager.flow.async_shutdown()
|
||||
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
await task
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Test the flow classes."""
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
@ -181,7 +180,7 @@ async def test_abort_calls_async_remove_with_exception(manager, caplog):
|
|||
with caplog.at_level(logging.ERROR):
|
||||
await manager.async_init("test")
|
||||
|
||||
assert "Error removing test config flow: error" in caplog.text
|
||||
assert "Error removing test flow: error" in caplog.text
|
||||
|
||||
TestFlow.async_remove.assert_called_once()
|
||||
|
||||
|
@ -419,22 +418,6 @@ async def test_abort_flow_exception(manager):
|
|||
assert form["description_placeholders"] == {"placeholder": "yo"}
|
||||
|
||||
|
||||
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
|
||||
"""Test that initializing flows are canceled on shutdown."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
task = asyncio.create_task(manager.async_init("test"))
|
||||
await hass.async_block_till_done()
|
||||
await manager.async_shutdown()
|
||||
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
async def test_init_unknown_flow(manager):
|
||||
"""Test that UnknownFlow is raised when async_create_flow returns None."""
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue