Improve ergonomics of FlowManager.async_show_progress (#107668)
* Improve ergonomics of FlowManager.async_show_progress * Don't include progress coroutine in web response * Unconditionally reset progress task when show_progress finished * Fix race * Tweak, add tests * Address review comments * Improve error handling * Allow progress jobs to return anything * Add comment * Remove unneeded check * Change API according to discussion * Adjust typing
This commit is contained in:
parent
00b40c964a
commit
24cd6a8a52
4 changed files with 221 additions and 28 deletions
|
@ -2,7 +2,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiogithubapi import (
|
||||
|
@ -18,7 +17,7 @@ import voluptuous as vol
|
|||
from homeassistant import config_entries
|
||||
from homeassistant.const import CONF_ACCESS_TOKEN
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult, UnknownFlow
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.aiohttp_client import (
|
||||
SERVER_SOFTWARE,
|
||||
async_get_clientsession,
|
||||
|
@ -124,22 +123,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
assert self._device is not None
|
||||
assert self._login_device is not None
|
||||
|
||||
try:
|
||||
response = await self._device.activation(
|
||||
device_code=self._login_device.device_code
|
||||
)
|
||||
self._login = response.data
|
||||
|
||||
finally:
|
||||
|
||||
async def _progress():
|
||||
# If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow
|
||||
with suppress(UnknownFlow):
|
||||
await self.hass.config_entries.flow.async_configure(
|
||||
flow_id=self.flow_id
|
||||
)
|
||||
|
||||
self.hass.async_create_task(_progress())
|
||||
response = await self._device.activation(
|
||||
device_code=self._login_device.device_code
|
||||
)
|
||||
self._login = response.data
|
||||
|
||||
if not self._device:
|
||||
self._device = GitHubDeviceAPI(
|
||||
|
@ -174,6 +161,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"url": OAUTH_USER_LOGIN,
|
||||
"code": self._login_device.user_code,
|
||||
},
|
||||
progress_task=self.login_task,
|
||||
)
|
||||
|
||||
async def async_step_repositories(
|
||||
|
@ -220,13 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"""Get the options flow for this handler."""
|
||||
return OptionsFlowHandler(config_entry)
|
||||
|
||||
@callback
|
||||
def async_remove(self) -> None:
|
||||
"""Handle remove handler callback."""
|
||||
if self.login_task and not self.login_task.done():
|
||||
# Clean up login task if it's still running
|
||||
self.login_task.cancel()
|
||||
|
||||
|
||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle a option flow for GitHub."""
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from contextlib import suppress
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
@ -124,6 +126,7 @@ class FlowResult(TypedDict, total=False):
|
|||
options: Mapping[str, Any]
|
||||
preview: str | None
|
||||
progress_action: str
|
||||
progress_task: asyncio.Task[Any] | None
|
||||
reason: str
|
||||
required: bool
|
||||
result: Any
|
||||
|
@ -402,6 +405,7 @@ class FlowManager(abc.ABC):
|
|||
if (flow := self._progress.pop(flow_id, None)) is None:
|
||||
raise UnknownFlow
|
||||
self._async_remove_flow_from_index(flow)
|
||||
flow.async_cancel_progress_task()
|
||||
try:
|
||||
flow.async_remove()
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
|
@ -435,6 +439,25 @@ class FlowManager(abc.ABC):
|
|||
error_if_core=False,
|
||||
)
|
||||
|
||||
if (
|
||||
result["type"] == FlowResultType.SHOW_PROGRESS
|
||||
and (progress_task := result.pop("progress_task", None))
|
||||
and progress_task != flow.async_get_progress_task()
|
||||
):
|
||||
# The flow's progress task was changed, register a callback on it
|
||||
async def call_configure() -> None:
|
||||
with suppress(UnknownFlow):
|
||||
await self.async_configure(flow.flow_id)
|
||||
|
||||
def schedule_configure(_: asyncio.Task) -> None:
|
||||
self.hass.async_create_task(call_configure())
|
||||
|
||||
progress_task.add_done_callback(schedule_configure)
|
||||
flow.async_set_progress_task(progress_task)
|
||||
|
||||
elif result["type"] != FlowResultType.SHOW_PROGRESS:
|
||||
flow.async_cancel_progress_task()
|
||||
|
||||
if result["type"] in FLOW_NOT_COMPLETE_STEPS:
|
||||
self._raise_if_step_does_not_exist(flow, result["step_id"])
|
||||
flow.cur_step = result
|
||||
|
@ -494,6 +517,8 @@ class FlowHandler:
|
|||
VERSION = 1
|
||||
MINOR_VERSION = 1
|
||||
|
||||
__progress_task: asyncio.Task[Any] | None = None
|
||||
|
||||
@property
|
||||
def source(self) -> str | None:
|
||||
"""Source that initialized the flow."""
|
||||
|
@ -632,6 +657,7 @@ class FlowHandler:
|
|||
step_id: str,
|
||||
progress_action: str,
|
||||
description_placeholders: Mapping[str, str] | None = None,
|
||||
progress_task: asyncio.Task[Any] | None = None,
|
||||
) -> FlowResult:
|
||||
"""Show a progress message to the user, without user input allowed."""
|
||||
return FlowResult(
|
||||
|
@ -641,6 +667,7 @@ class FlowHandler:
|
|||
step_id=step_id,
|
||||
progress_action=progress_action,
|
||||
description_placeholders=description_placeholders,
|
||||
progress_task=progress_task,
|
||||
)
|
||||
|
||||
@callback
|
||||
|
@ -683,6 +710,26 @@ class FlowHandler:
|
|||
async def async_setup_preview(hass: HomeAssistant) -> None:
|
||||
"""Set up preview."""
|
||||
|
||||
@callback
|
||||
def async_cancel_progress_task(self) -> None:
|
||||
"""Cancel in progress task."""
|
||||
if self.__progress_task and not self.__progress_task.done():
|
||||
self.__progress_task.cancel()
|
||||
self.__progress_task = None
|
||||
|
||||
@callback
|
||||
def async_get_progress_task(self) -> asyncio.Task[Any] | None:
|
||||
"""Get in progress task."""
|
||||
return self.__progress_task
|
||||
|
||||
@callback
|
||||
def async_set_progress_task(
|
||||
self,
|
||||
progress_task: asyncio.Task[Any],
|
||||
) -> None:
|
||||
"""Set in progress task."""
|
||||
self.__progress_task = progress_task
|
||||
|
||||
|
||||
@callback
|
||||
def _create_abort_data(
|
||||
|
|
|
@ -121,10 +121,11 @@ async def test_flow_with_activation_failure(
|
|||
)
|
||||
assert result["step_id"] == "device"
|
||||
assert result["type"] == FlowResultType.SHOW_PROGRESS
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
|
||||
assert result["step_id"] == "could_not_register"
|
||||
assert result["type"] == FlowResultType.ABORT
|
||||
assert result["reason"] == "could_not_register"
|
||||
|
||||
|
||||
async def test_flow_with_remove_while_activating(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Test the flow classes."""
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
@ -7,7 +8,7 @@ import pytest
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from .common import (
|
||||
|
@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
|
|||
async def test_show_progress(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic."""
|
||||
manager.hass = hass
|
||||
events = []
|
||||
task_one_evt = asyncio.Event()
|
||||
task_two_evt = asyncio.Event()
|
||||
event_received_evt = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
events.append(event)
|
||||
event_received_evt.set()
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
data = None
|
||||
start_task_two = False
|
||||
progress_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def long_running_task_one() -> None:
|
||||
await task_one_evt.wait()
|
||||
self.start_task_two = True
|
||||
|
||||
async def long_running_task_two() -> None:
|
||||
await task_two_evt.wait()
|
||||
self.data = {"title": "Hello"}
|
||||
|
||||
if not task_one_evt.is_set():
|
||||
progress_action = "task_one"
|
||||
if not self.progress_task:
|
||||
self.progress_task = hass.async_create_task(long_running_task_one())
|
||||
elif not task_two_evt.is_set():
|
||||
progress_action = "task_two"
|
||||
if self.start_task_two:
|
||||
self.progress_task = hass.async_create_task(long_running_task_two())
|
||||
self.start_task_two = False
|
||||
if not task_one_evt.is_set() or not task_two_evt.is_set():
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action=progress_action,
|
||||
progress_task=self.progress_task,
|
||||
)
|
||||
|
||||
return self.async_show_progress_done(next_step_id="finish")
|
||||
|
||||
async def async_step_finish(self, user_input=None):
|
||||
return self.async_create_entry(title=self.data["title"], data=self.data)
|
||||
|
||||
hass.bus.async_listen(
|
||||
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||
capture_events,
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_one"
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.async_progress_by_handler("test")) == 1
|
||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
# Set task one done and wait for event
|
||||
task_one_evt.set()
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 1
|
||||
assert events[0].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_two"
|
||||
|
||||
# Set task two done and wait for event
|
||||
task_two_evt.set()
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 2 # 1 for task one and 1 for task two
|
||||
assert events[1].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Hello"
|
||||
|
||||
|
||||
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic."""
|
||||
manager.hass = hass
|
||||
events = []
|
||||
event_received_evt = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
events.append(event)
|
||||
event_received_evt.set()
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
data = None
|
||||
progress_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def long_running_task() -> None:
|
||||
raise TypeError
|
||||
|
||||
if not self.progress_task:
|
||||
self.progress_task = hass.async_create_task(long_running_task())
|
||||
if self.progress_task and self.progress_task.done():
|
||||
if self.progress_task.exception():
|
||||
return self.async_show_progress_done(next_step_id="error")
|
||||
return self.async_show_progress_done(next_step_id="no_error")
|
||||
return self.async_show_progress(
|
||||
step_id="init", progress_action="task", progress_task=self.progress_task
|
||||
)
|
||||
|
||||
async def async_step_error(self, user_input=None):
|
||||
return self.async_abort(reason="error")
|
||||
|
||||
hass.bus.async_listen(
|
||||
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||
capture_events,
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task"
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.async_progress_by_handler("test")) == 1
|
||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
# Set task one done and wait for event
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 1
|
||||
assert events[0].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
assert result["reason"] == "error"
|
||||
|
||||
|
||||
async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic.
|
||||
|
||||
This tests the deprecated version where the config flow is responsible for
|
||||
resuming the flow.
|
||||
"""
|
||||
manager.hass = hass
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue