Cancel discovery flows that are initializing at shutdown (#49241)
This commit is contained in:
parent
a529a12745
commit
dafc7a072c
3 changed files with 59 additions and 12 deletions
|
@ -792,6 +792,7 @@ class ConfigEntries:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[entry.async_shutdown() for entry in self._entries.values()]
|
*[entry.async_shutdown() for entry in self._entries.values()]
|
||||||
)
|
)
|
||||||
|
await self.flow.async_shutdown()
|
||||||
|
|
||||||
async def async_initialize(self) -> None:
|
async def async_initialize(self) -> None:
|
||||||
"""Initialize config entry config."""
|
"""Initialize config entry config."""
|
||||||
|
|
|
@ -61,6 +61,7 @@ class FlowManager(abc.ABC):
|
||||||
"""Initialize the flow manager."""
|
"""Initialize the flow manager."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._initializing: dict[str, list[asyncio.Future]] = {}
|
self._initializing: dict[str, list[asyncio.Future]] = {}
|
||||||
|
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||||
self._progress: dict[str, Any] = {}
|
self._progress: dict[str, Any] = {}
|
||||||
|
|
||||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||||
|
@ -118,21 +119,13 @@ class FlowManager(abc.ABC):
|
||||||
init_done: asyncio.Future = asyncio.Future()
|
init_done: asyncio.Future = asyncio.Future()
|
||||||
self._initializing.setdefault(handler, []).append(init_done)
|
self._initializing.setdefault(handler, []).append(init_done)
|
||||||
|
|
||||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
|
||||||
if not flow:
|
self._initialize_tasks.setdefault(handler, []).append(task)
|
||||||
self._initializing[handler].remove(init_done)
|
|
||||||
raise UnknownFlow("Flow was not created")
|
|
||||||
flow.hass = self.hass
|
|
||||||
flow.handler = handler
|
|
||||||
flow.flow_id = uuid.uuid4().hex
|
|
||||||
flow.context = context
|
|
||||||
self._progress[flow.flow_id] = flow
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self._async_handle_step(
|
flow, result = await task
|
||||||
flow, flow.init_step, data, init_done
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
|
self._initialize_tasks[handler].remove(task)
|
||||||
self._initializing[handler].remove(init_done)
|
self._initializing[handler].remove(init_done)
|
||||||
|
|
||||||
if result["type"] != RESULT_TYPE_ABORT:
|
if result["type"] != RESULT_TYPE_ABORT:
|
||||||
|
@ -140,6 +133,31 @@ class FlowManager(abc.ABC):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def _async_init(
|
||||||
|
self,
|
||||||
|
init_done: asyncio.Future,
|
||||||
|
handler: str,
|
||||||
|
context: dict,
|
||||||
|
data: Any,
|
||||||
|
) -> tuple[FlowHandler, Any]:
|
||||||
|
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||||
|
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||||
|
if not flow:
|
||||||
|
raise UnknownFlow("Flow was not created")
|
||||||
|
flow.hass = self.hass
|
||||||
|
flow.handler = handler
|
||||||
|
flow.flow_id = uuid.uuid4().hex
|
||||||
|
flow.context = context
|
||||||
|
self._progress[flow.flow_id] = flow
|
||||||
|
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
|
||||||
|
return flow, result
|
||||||
|
|
||||||
|
async def async_shutdown(self) -> None:
|
||||||
|
"""Cancel any initializing flows."""
|
||||||
|
for task_list in self._initialize_tasks.values():
|
||||||
|
for task in task_list:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
async def async_configure(
|
async def async_configure(
|
||||||
self, flow_id: str, user_input: dict | None = None
|
self, flow_id: str, user_input: dict | None = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
"""Test the flow classes."""
|
"""Test the flow classes."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -367,3 +370,28 @@ async def test_abort_flow_exception(manager):
|
||||||
assert form["type"] == "abort"
|
assert form["type"] == "abort"
|
||||||
assert form["reason"] == "mock-reason"
|
assert form["reason"] == "mock-reason"
|
||||||
assert form["description_placeholders"] == {"placeholder": "yo"}
|
assert form["description_placeholders"] == {"placeholder": "yo"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
|
||||||
|
"""Test that initializing flows are canceled on shutdown."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
task = asyncio.create_task(manager.async_init("test"))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await manager.async_shutdown()
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
async def test_init_unknown_flow(manager):
|
||||||
|
"""Test that UnknownFlow is raised when async_create_flow returns None."""
|
||||||
|
|
||||||
|
with pytest.raises(data_entry_flow.UnknownFlow), patch.object(
|
||||||
|
manager, "async_create_flow", return_value=None
|
||||||
|
):
|
||||||
|
await manager.async_init("test")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue