Improve ergonomics of FlowManager.async_show_progress (#107668)

* Improve ergonomics of FlowManager.async_show_progress

* Don't include progress coroutine in web response

* Unconditionally reset progress task when show_progress finished

* Fix race

* Tweak, add tests

* Address review comments

* Improve error handling

* Allow progress jobs to return anything

* Add comment

* Remove unneeded check

* Change API according to discussion

* Adjust typing
This commit is contained in:
Erik Montnemery 2024-01-11 12:00:12 +01:00 committed by GitHub
parent 00b40c964a
commit 24cd6a8a52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 221 additions and 28 deletions

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from contextlib import suppress
from typing import TYPE_CHECKING, Any
from aiogithubapi import (
@ -18,7 +17,7 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult, UnknownFlow
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import (
SERVER_SOFTWARE,
async_get_clientsession,
@ -124,22 +123,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
assert self._device is not None
assert self._login_device is not None
try:
response = await self._device.activation(
device_code=self._login_device.device_code
)
self._login = response.data
finally:
async def _progress():
# If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow
with suppress(UnknownFlow):
await self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id
)
self.hass.async_create_task(_progress())
response = await self._device.activation(
device_code=self._login_device.device_code
)
self._login = response.data
if not self._device:
self._device = GitHubDeviceAPI(
@ -174,6 +161,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"url": OAUTH_USER_LOGIN,
"code": self._login_device.user_code,
},
progress_task=self.login_task,
)
async def async_step_repositories(
@ -220,13 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)
@callback
def async_remove(self) -> None:
"""Handle remove handler callback."""
if self.login_task and not self.login_task.done():
# Clean up login task if it's still running
self.login_task.cancel()
class OptionsFlowHandler(config_entries.OptionsFlow):
"""Handle a option flow for GitHub."""

View file

@ -2,7 +2,9 @@
from __future__ import annotations
import abc
import asyncio
from collections.abc import Callable, Iterable, Mapping
from contextlib import suppress
import copy
from dataclasses import dataclass
from enum import StrEnum
@ -124,6 +126,7 @@ class FlowResult(TypedDict, total=False):
options: Mapping[str, Any]
preview: str | None
progress_action: str
progress_task: asyncio.Task[Any] | None
reason: str
required: bool
result: Any
@ -402,6 +405,7 @@ class FlowManager(abc.ABC):
if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow
self._async_remove_flow_from_index(flow)
flow.async_cancel_progress_task()
try:
flow.async_remove()
except Exception as err: # pylint: disable=broad-except
@ -435,6 +439,25 @@ class FlowManager(abc.ABC):
error_if_core=False,
)
if (
result["type"] == FlowResultType.SHOW_PROGRESS
and (progress_task := result.pop("progress_task", None))
and progress_task != flow.async_get_progress_task()
):
# The flow's progress task was changed, register a callback on it
async def call_configure() -> None:
with suppress(UnknownFlow):
await self.async_configure(flow.flow_id)
def schedule_configure(_: asyncio.Task) -> None:
self.hass.async_create_task(call_configure())
progress_task.add_done_callback(schedule_configure)
flow.async_set_progress_task(progress_task)
elif result["type"] != FlowResultType.SHOW_PROGRESS:
flow.async_cancel_progress_task()
if result["type"] in FLOW_NOT_COMPLETE_STEPS:
self._raise_if_step_does_not_exist(flow, result["step_id"])
flow.cur_step = result
@ -494,6 +517,8 @@ class FlowHandler:
VERSION = 1
MINOR_VERSION = 1
__progress_task: asyncio.Task[Any] | None = None
@property
def source(self) -> str | None:
"""Source that initialized the flow."""
@ -632,6 +657,7 @@ class FlowHandler:
step_id: str,
progress_action: str,
description_placeholders: Mapping[str, str] | None = None,
progress_task: asyncio.Task[Any] | None = None,
) -> FlowResult:
"""Show a progress message to the user, without user input allowed."""
return FlowResult(
@ -641,6 +667,7 @@ class FlowHandler:
step_id=step_id,
progress_action=progress_action,
description_placeholders=description_placeholders,
progress_task=progress_task,
)
@callback
@ -683,6 +710,26 @@ class FlowHandler:
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview."""
@callback
def async_cancel_progress_task(self) -> None:
"""Cancel in progress task."""
if self.__progress_task and not self.__progress_task.done():
self.__progress_task.cancel()
self.__progress_task = None
@callback
def async_get_progress_task(self) -> asyncio.Task[Any] | None:
"""Get in progress task."""
return self.__progress_task
@callback
def async_set_progress_task(
self,
progress_task: asyncio.Task[Any],
) -> None:
"""Set in progress task."""
self.__progress_task = progress_task
@callback
def _create_abort_data(

View file

@ -121,10 +121,11 @@ async def test_flow_with_activation_failure(
)
assert result["step_id"] == "device"
assert result["type"] == FlowResultType.SHOW_PROGRESS
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "could_not_register"
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "could_not_register"
async def test_flow_with_remove_while_activating(

View file

@ -1,4 +1,5 @@
"""Test the flow classes."""
import asyncio
import dataclasses
import logging
from unittest.mock import Mock, patch
@ -7,7 +8,7 @@ import pytest
import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.core import HomeAssistant
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.decorator import Registry
from .common import (
@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
async def test_show_progress(hass: HomeAssistant, manager) -> None:
"""Test show progress logic."""
manager.hass = hass
events = []
task_one_evt = asyncio.Event()
task_two_evt = asyncio.Event()
event_received_evt = asyncio.Event()
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
data = None
start_task_two = False
progress_task: asyncio.Task[None] | None = None
async def async_step_init(self, user_input=None):
async def long_running_task_one() -> None:
await task_one_evt.wait()
self.start_task_two = True
async def long_running_task_two() -> None:
await task_two_evt.wait()
self.data = {"title": "Hello"}
if not task_one_evt.is_set():
progress_action = "task_one"
if not self.progress_task:
self.progress_task = hass.async_create_task(long_running_task_one())
elif not task_two_evt.is_set():
progress_action = "task_two"
if self.start_task_two:
self.progress_task = hass.async_create_task(long_running_task_two())
self.start_task_two = False
if not task_one_evt.is_set() or not task_two_evt.is_set():
return self.async_show_progress(
step_id="init",
progress_action=progress_action,
progress_task=self.progress_task,
)
return self.async_show_progress_done(next_step_id="finish")
async def async_step_finish(self, user_input=None):
return self.async_create_entry(title=self.data["title"], data=self.data)
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
run_immediately=True,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
task_one_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_two"
# Set task two done and wait for event
task_two_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 2 # 1 for task one and 1 for task two
assert events[1].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "Hello"
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
"""Test show progress logic."""
manager.hass = hass
events = []
event_received_evt = asyncio.Event()
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
data = None
progress_task: asyncio.Task[None] | None = None
async def async_step_init(self, user_input=None):
async def long_running_task() -> None:
raise TypeError
if not self.progress_task:
self.progress_task = hass.async_create_task(long_running_task())
if self.progress_task and self.progress_task.done():
if self.progress_task.exception():
return self.async_show_progress_done(next_step_id="error")
return self.async_show_progress_done(next_step_id="no_error")
return self.async_show_progress(
step_id="init", progress_action="task", progress_task=self.progress_task
)
async def async_step_error(self, user_input=None):
return self.async_abort(reason="error")
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
run_immediately=True,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "error"
async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None:
"""Test show progress logic.
This tests the deprecated version where the config flow is responsible for
resuming the flow.
"""
manager.hass = hass
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):