Improve ergonomics of FlowManager.async_show_progress (#107668)

* Improve ergonomics of FlowManager.async_show_progress

* Don't include progress coroutine in web response

* Unconditionally reset progress task when show_progress finished

* Fix race

* Tweak, add tests

* Address review comments

* Improve error handling

* Allow progress jobs to return anything

* Add comment

* Remove unneeded check

* Change API according to discussion

* Adjust typing
This commit is contained in:
Erik Montnemery 2024-01-11 12:00:12 +01:00 committed by GitHub
parent 00b40c964a
commit 24cd6a8a52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 221 additions and 28 deletions

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from contextlib import suppress
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from aiogithubapi import ( from aiogithubapi import (
@ -18,7 +17,7 @@ import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult, UnknownFlow from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import ( from homeassistant.helpers.aiohttp_client import (
SERVER_SOFTWARE, SERVER_SOFTWARE,
async_get_clientsession, async_get_clientsession,
@ -124,22 +123,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
assert self._device is not None assert self._device is not None
assert self._login_device is not None assert self._login_device is not None
try: response = await self._device.activation(
response = await self._device.activation( device_code=self._login_device.device_code
device_code=self._login_device.device_code )
) self._login = response.data
self._login = response.data
finally:
async def _progress():
# If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow
with suppress(UnknownFlow):
await self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id
)
self.hass.async_create_task(_progress())
if not self._device: if not self._device:
self._device = GitHubDeviceAPI( self._device = GitHubDeviceAPI(
@ -174,6 +161,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"url": OAUTH_USER_LOGIN, "url": OAUTH_USER_LOGIN,
"code": self._login_device.user_code, "code": self._login_device.user_code,
}, },
progress_task=self.login_task,
) )
async def async_step_repositories( async def async_step_repositories(
@ -220,13 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
return OptionsFlowHandler(config_entry) return OptionsFlowHandler(config_entry)
@callback
def async_remove(self) -> None:
"""Handle remove handler callback."""
if self.login_task and not self.login_task.done():
# Clean up login task if it's still running
self.login_task.cancel()
class OptionsFlowHandler(config_entries.OptionsFlow): class OptionsFlowHandler(config_entries.OptionsFlow):
"""Handle a option flow for GitHub.""" """Handle a option flow for GitHub."""

View file

@ -2,7 +2,9 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import asyncio
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from contextlib import suppress
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
@ -124,6 +126,7 @@ class FlowResult(TypedDict, total=False):
options: Mapping[str, Any] options: Mapping[str, Any]
preview: str | None preview: str | None
progress_action: str progress_action: str
progress_task: asyncio.Task[Any] | None
reason: str reason: str
required: bool required: bool
result: Any result: Any
@ -402,6 +405,7 @@ class FlowManager(abc.ABC):
if (flow := self._progress.pop(flow_id, None)) is None: if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow raise UnknownFlow
self._async_remove_flow_from_index(flow) self._async_remove_flow_from_index(flow)
flow.async_cancel_progress_task()
try: try:
flow.async_remove() flow.async_remove()
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
@ -435,6 +439,25 @@ class FlowManager(abc.ABC):
error_if_core=False, error_if_core=False,
) )
if (
result["type"] == FlowResultType.SHOW_PROGRESS
and (progress_task := result.pop("progress_task", None))
and progress_task != flow.async_get_progress_task()
):
# The flow's progress task was changed, register a callback on it
async def call_configure() -> None:
with suppress(UnknownFlow):
await self.async_configure(flow.flow_id)
def schedule_configure(_: asyncio.Task) -> None:
self.hass.async_create_task(call_configure())
progress_task.add_done_callback(schedule_configure)
flow.async_set_progress_task(progress_task)
elif result["type"] != FlowResultType.SHOW_PROGRESS:
flow.async_cancel_progress_task()
if result["type"] in FLOW_NOT_COMPLETE_STEPS: if result["type"] in FLOW_NOT_COMPLETE_STEPS:
self._raise_if_step_does_not_exist(flow, result["step_id"]) self._raise_if_step_does_not_exist(flow, result["step_id"])
flow.cur_step = result flow.cur_step = result
@ -494,6 +517,8 @@ class FlowHandler:
VERSION = 1 VERSION = 1
MINOR_VERSION = 1 MINOR_VERSION = 1
__progress_task: asyncio.Task[Any] | None = None
@property @property
def source(self) -> str | None: def source(self) -> str | None:
"""Source that initialized the flow.""" """Source that initialized the flow."""
@ -632,6 +657,7 @@ class FlowHandler:
step_id: str, step_id: str,
progress_action: str, progress_action: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
progress_task: asyncio.Task[Any] | None = None,
) -> FlowResult: ) -> FlowResult:
"""Show a progress message to the user, without user input allowed.""" """Show a progress message to the user, without user input allowed."""
return FlowResult( return FlowResult(
@ -641,6 +667,7 @@ class FlowHandler:
step_id=step_id, step_id=step_id,
progress_action=progress_action, progress_action=progress_action,
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
progress_task=progress_task,
) )
@callback @callback
@ -683,6 +710,26 @@ class FlowHandler:
async def async_setup_preview(hass: HomeAssistant) -> None: async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview.""" """Set up preview."""
@callback
def async_cancel_progress_task(self) -> None:
"""Cancel in progress task."""
if self.__progress_task and not self.__progress_task.done():
self.__progress_task.cancel()
self.__progress_task = None
@callback
def async_get_progress_task(self) -> asyncio.Task[Any] | None:
"""Get in progress task."""
return self.__progress_task
@callback
def async_set_progress_task(
self,
progress_task: asyncio.Task[Any],
) -> None:
"""Set in progress task."""
self.__progress_task = progress_task
@callback @callback
def _create_abort_data( def _create_abort_data(

View file

@ -121,10 +121,11 @@ async def test_flow_with_activation_failure(
) )
assert result["step_id"] == "device" assert result["step_id"] == "device"
assert result["type"] == FlowResultType.SHOW_PROGRESS assert result["type"] == FlowResultType.SHOW_PROGRESS
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE assert result["type"] == FlowResultType.ABORT
assert result["step_id"] == "could_not_register" assert result["reason"] == "could_not_register"
async def test_flow_with_remove_while_activating( async def test_flow_with_remove_while_activating(

View file

@ -1,4 +1,5 @@
"""Test the flow classes.""" """Test the flow classes."""
import asyncio
import dataclasses import dataclasses
import logging import logging
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -7,7 +8,7 @@ import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, data_entry_flow from homeassistant import config_entries, data_entry_flow
from homeassistant.core import HomeAssistant from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from .common import ( from .common import (
@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
async def test_show_progress(hass: HomeAssistant, manager) -> None: async def test_show_progress(hass: HomeAssistant, manager) -> None:
"""Test show progress logic.""" """Test show progress logic."""
manager.hass = hass manager.hass = hass
events = []
task_one_evt = asyncio.Event()
task_two_evt = asyncio.Event()
event_received_evt = asyncio.Event()
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
data = None
start_task_two = False
progress_task: asyncio.Task[None] | None = None
async def async_step_init(self, user_input=None):
async def long_running_task_one() -> None:
await task_one_evt.wait()
self.start_task_two = True
async def long_running_task_two() -> None:
await task_two_evt.wait()
self.data = {"title": "Hello"}
if not task_one_evt.is_set():
progress_action = "task_one"
if not self.progress_task:
self.progress_task = hass.async_create_task(long_running_task_one())
elif not task_two_evt.is_set():
progress_action = "task_two"
if self.start_task_two:
self.progress_task = hass.async_create_task(long_running_task_two())
self.start_task_two = False
if not task_one_evt.is_set() or not task_two_evt.is_set():
return self.async_show_progress(
step_id="init",
progress_action=progress_action,
progress_task=self.progress_task,
)
return self.async_show_progress_done(next_step_id="finish")
async def async_step_finish(self, user_input=None):
return self.async_create_entry(title=self.data["title"], data=self.data)
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
run_immediately=True,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
task_one_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task_two"
# Set task two done and wait for event
task_two_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 2 # 1 for task one and 1 for task two
assert events[1].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "Hello"
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
"""Test show progress logic."""
manager.hass = hass
events = []
event_received_evt = asyncio.Event()
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
data = None
progress_task: asyncio.Task[None] | None = None
async def async_step_init(self, user_input=None):
async def long_running_task() -> None:
raise TypeError
if not self.progress_task:
self.progress_task = hass.async_create_task(long_running_task())
if self.progress_task and self.progress_task.done():
if self.progress_task.exception():
return self.async_show_progress_done(next_step_id="error")
return self.async_show_progress_done(next_step_id="no_error")
return self.async_show_progress(
step_id="init", progress_action="task", progress_task=self.progress_task
)
async def async_step_error(self, user_input=None):
return self.async_abort(reason="error")
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
run_immediately=True,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "task"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "error"
async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None:
"""Test show progress logic.
This tests the deprecated version where the config flow is responsible for
resuming the flow.
"""
manager.hass = hass
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler): class TestFlow(data_entry_flow.FlowHandler):