Improve ergonomics of FlowManager.async_show_progress (#107668)
* Improve ergonomics of FlowManager.async_show_progress * Don't include progress coroutine in web response * Unconditionally reset progress task when show_progress finished * Fix race * Tweak, add tests * Address review comments * Improve error handling * Allow progress jobs to return anything * Add comment * Remove unneeded check * Change API according to discussion * Adjust typing
This commit is contained in:
parent
00b40c964a
commit
24cd6a8a52
4 changed files with 221 additions and 28 deletions
|
@ -2,7 +2,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import suppress
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiogithubapi import (
|
from aiogithubapi import (
|
||||||
|
@ -18,7 +17,7 @@ import voluptuous as vol
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import CONF_ACCESS_TOKEN
|
from homeassistant.const import CONF_ACCESS_TOKEN
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.data_entry_flow import FlowResult, UnknownFlow
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers.aiohttp_client import (
|
from homeassistant.helpers.aiohttp_client import (
|
||||||
SERVER_SOFTWARE,
|
SERVER_SOFTWARE,
|
||||||
async_get_clientsession,
|
async_get_clientsession,
|
||||||
|
@ -124,22 +123,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
assert self._device is not None
|
assert self._device is not None
|
||||||
assert self._login_device is not None
|
assert self._login_device is not None
|
||||||
|
|
||||||
try:
|
response = await self._device.activation(
|
||||||
response = await self._device.activation(
|
device_code=self._login_device.device_code
|
||||||
device_code=self._login_device.device_code
|
)
|
||||||
)
|
self._login = response.data
|
||||||
self._login = response.data
|
|
||||||
|
|
||||||
finally:
|
|
||||||
|
|
||||||
async def _progress():
|
|
||||||
# If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow
|
|
||||||
with suppress(UnknownFlow):
|
|
||||||
await self.hass.config_entries.flow.async_configure(
|
|
||||||
flow_id=self.flow_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hass.async_create_task(_progress())
|
|
||||||
|
|
||||||
if not self._device:
|
if not self._device:
|
||||||
self._device = GitHubDeviceAPI(
|
self._device = GitHubDeviceAPI(
|
||||||
|
@ -174,6 +161,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
"url": OAUTH_USER_LOGIN,
|
"url": OAUTH_USER_LOGIN,
|
||||||
"code": self._login_device.user_code,
|
"code": self._login_device.user_code,
|
||||||
},
|
},
|
||||||
|
progress_task=self.login_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_repositories(
|
async def async_step_repositories(
|
||||||
|
@ -220,13 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
"""Get the options flow for this handler."""
|
"""Get the options flow for this handler."""
|
||||||
return OptionsFlowHandler(config_entry)
|
return OptionsFlowHandler(config_entry)
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_remove(self) -> None:
|
|
||||||
"""Handle remove handler callback."""
|
|
||||||
if self.login_task and not self.login_task.done():
|
|
||||||
# Clean up login task if it's still running
|
|
||||||
self.login_task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
"""Handle a option flow for GitHub."""
|
"""Handle a option flow for GitHub."""
|
||||||
|
|
|
@ -2,7 +2,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
|
from contextlib import suppress
|
||||||
import copy
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
@ -124,6 +126,7 @@ class FlowResult(TypedDict, total=False):
|
||||||
options: Mapping[str, Any]
|
options: Mapping[str, Any]
|
||||||
preview: str | None
|
preview: str | None
|
||||||
progress_action: str
|
progress_action: str
|
||||||
|
progress_task: asyncio.Task[Any] | None
|
||||||
reason: str
|
reason: str
|
||||||
required: bool
|
required: bool
|
||||||
result: Any
|
result: Any
|
||||||
|
@ -402,6 +405,7 @@ class FlowManager(abc.ABC):
|
||||||
if (flow := self._progress.pop(flow_id, None)) is None:
|
if (flow := self._progress.pop(flow_id, None)) is None:
|
||||||
raise UnknownFlow
|
raise UnknownFlow
|
||||||
self._async_remove_flow_from_index(flow)
|
self._async_remove_flow_from_index(flow)
|
||||||
|
flow.async_cancel_progress_task()
|
||||||
try:
|
try:
|
||||||
flow.async_remove()
|
flow.async_remove()
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
@ -435,6 +439,25 @@ class FlowManager(abc.ABC):
|
||||||
error_if_core=False,
|
error_if_core=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
result["type"] == FlowResultType.SHOW_PROGRESS
|
||||||
|
and (progress_task := result.pop("progress_task", None))
|
||||||
|
and progress_task != flow.async_get_progress_task()
|
||||||
|
):
|
||||||
|
# The flow's progress task was changed, register a callback on it
|
||||||
|
async def call_configure() -> None:
|
||||||
|
with suppress(UnknownFlow):
|
||||||
|
await self.async_configure(flow.flow_id)
|
||||||
|
|
||||||
|
def schedule_configure(_: asyncio.Task) -> None:
|
||||||
|
self.hass.async_create_task(call_configure())
|
||||||
|
|
||||||
|
progress_task.add_done_callback(schedule_configure)
|
||||||
|
flow.async_set_progress_task(progress_task)
|
||||||
|
|
||||||
|
elif result["type"] != FlowResultType.SHOW_PROGRESS:
|
||||||
|
flow.async_cancel_progress_task()
|
||||||
|
|
||||||
if result["type"] in FLOW_NOT_COMPLETE_STEPS:
|
if result["type"] in FLOW_NOT_COMPLETE_STEPS:
|
||||||
self._raise_if_step_does_not_exist(flow, result["step_id"])
|
self._raise_if_step_does_not_exist(flow, result["step_id"])
|
||||||
flow.cur_step = result
|
flow.cur_step = result
|
||||||
|
@ -494,6 +517,8 @@ class FlowHandler:
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
MINOR_VERSION = 1
|
MINOR_VERSION = 1
|
||||||
|
|
||||||
|
__progress_task: asyncio.Task[Any] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def source(self) -> str | None:
|
def source(self) -> str | None:
|
||||||
"""Source that initialized the flow."""
|
"""Source that initialized the flow."""
|
||||||
|
@ -632,6 +657,7 @@ class FlowHandler:
|
||||||
step_id: str,
|
step_id: str,
|
||||||
progress_action: str,
|
progress_action: str,
|
||||||
description_placeholders: Mapping[str, str] | None = None,
|
description_placeholders: Mapping[str, str] | None = None,
|
||||||
|
progress_task: asyncio.Task[Any] | None = None,
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Show a progress message to the user, without user input allowed."""
|
"""Show a progress message to the user, without user input allowed."""
|
||||||
return FlowResult(
|
return FlowResult(
|
||||||
|
@ -641,6 +667,7 @@ class FlowHandler:
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
progress_action=progress_action,
|
progress_action=progress_action,
|
||||||
description_placeholders=description_placeholders,
|
description_placeholders=description_placeholders,
|
||||||
|
progress_task=progress_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -683,6 +710,26 @@ class FlowHandler:
|
||||||
async def async_setup_preview(hass: HomeAssistant) -> None:
|
async def async_setup_preview(hass: HomeAssistant) -> None:
|
||||||
"""Set up preview."""
|
"""Set up preview."""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_cancel_progress_task(self) -> None:
|
||||||
|
"""Cancel in progress task."""
|
||||||
|
if self.__progress_task and not self.__progress_task.done():
|
||||||
|
self.__progress_task.cancel()
|
||||||
|
self.__progress_task = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_progress_task(self) -> asyncio.Task[Any] | None:
|
||||||
|
"""Get in progress task."""
|
||||||
|
return self.__progress_task
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_set_progress_task(
|
||||||
|
self,
|
||||||
|
progress_task: asyncio.Task[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Set in progress task."""
|
||||||
|
self.__progress_task = progress_task
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _create_abort_data(
|
def _create_abort_data(
|
||||||
|
|
|
@ -121,10 +121,11 @@ async def test_flow_with_activation_failure(
|
||||||
)
|
)
|
||||||
assert result["step_id"] == "device"
|
assert result["step_id"] == "device"
|
||||||
assert result["type"] == FlowResultType.SHOW_PROGRESS
|
assert result["type"] == FlowResultType.SHOW_PROGRESS
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
|
assert result["type"] == FlowResultType.ABORT
|
||||||
assert result["step_id"] == "could_not_register"
|
assert result["reason"] == "could_not_register"
|
||||||
|
|
||||||
|
|
||||||
async def test_flow_with_remove_while_activating(
|
async def test_flow_with_remove_while_activating(
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test the flow classes."""
|
"""Test the flow classes."""
|
||||||
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
@ -7,7 +8,7 @@ import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries, data_entry_flow
|
from homeassistant import config_entries, data_entry_flow
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
from .common import (
|
from .common import (
|
||||||
|
@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
|
||||||
async def test_show_progress(hass: HomeAssistant, manager) -> None:
|
async def test_show_progress(hass: HomeAssistant, manager) -> None:
|
||||||
"""Test show progress logic."""
|
"""Test show progress logic."""
|
||||||
manager.hass = hass
|
manager.hass = hass
|
||||||
|
events = []
|
||||||
|
task_one_evt = asyncio.Event()
|
||||||
|
task_two_evt = asyncio.Event()
|
||||||
|
event_received_evt = asyncio.Event()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def capture_events(event: Event) -> None:
|
||||||
|
events.append(event)
|
||||||
|
event_received_evt.set()
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
VERSION = 5
|
||||||
|
data = None
|
||||||
|
start_task_two = False
|
||||||
|
progress_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
async def long_running_task_one() -> None:
|
||||||
|
await task_one_evt.wait()
|
||||||
|
self.start_task_two = True
|
||||||
|
|
||||||
|
async def long_running_task_two() -> None:
|
||||||
|
await task_two_evt.wait()
|
||||||
|
self.data = {"title": "Hello"}
|
||||||
|
|
||||||
|
if not task_one_evt.is_set():
|
||||||
|
progress_action = "task_one"
|
||||||
|
if not self.progress_task:
|
||||||
|
self.progress_task = hass.async_create_task(long_running_task_one())
|
||||||
|
elif not task_two_evt.is_set():
|
||||||
|
progress_action = "task_two"
|
||||||
|
if self.start_task_two:
|
||||||
|
self.progress_task = hass.async_create_task(long_running_task_two())
|
||||||
|
self.start_task_two = False
|
||||||
|
if not task_one_evt.is_set() or not task_two_evt.is_set():
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init",
|
||||||
|
progress_action=progress_action,
|
||||||
|
progress_task=self.progress_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_progress_done(next_step_id="finish")
|
||||||
|
|
||||||
|
async def async_step_finish(self, user_input=None):
|
||||||
|
return self.async_create_entry(title=self.data["title"], data=self.data)
|
||||||
|
|
||||||
|
hass.bus.async_listen(
|
||||||
|
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||||
|
capture_events,
|
||||||
|
run_immediately=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.async_init("test")
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["progress_action"] == "task_one"
|
||||||
|
assert len(manager.async_progress()) == 1
|
||||||
|
assert len(manager.async_progress_by_handler("test")) == 1
|
||||||
|
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||||
|
|
||||||
|
# Set task one done and wait for event
|
||||||
|
task_one_evt.set()
|
||||||
|
await event_received_evt.wait()
|
||||||
|
event_received_evt.clear()
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].data == {
|
||||||
|
"handler": "test",
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"refresh": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend refreshes the flow
|
||||||
|
result = await manager.async_configure(result["flow_id"])
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["progress_action"] == "task_two"
|
||||||
|
|
||||||
|
# Set task two done and wait for event
|
||||||
|
task_two_evt.set()
|
||||||
|
await event_received_evt.wait()
|
||||||
|
event_received_evt.clear()
|
||||||
|
assert len(events) == 2 # 1 for task one and 1 for task two
|
||||||
|
assert events[1].data == {
|
||||||
|
"handler": "test",
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"refresh": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend refreshes the flow
|
||||||
|
result = await manager.async_configure(result["flow_id"])
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["title"] == "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
|
||||||
|
"""Test show progress logic."""
|
||||||
|
manager.hass = hass
|
||||||
|
events = []
|
||||||
|
event_received_evt = asyncio.Event()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def capture_events(event: Event) -> None:
|
||||||
|
events.append(event)
|
||||||
|
event_received_evt.set()
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
VERSION = 5
|
||||||
|
data = None
|
||||||
|
progress_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
async def long_running_task() -> None:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
if not self.progress_task:
|
||||||
|
self.progress_task = hass.async_create_task(long_running_task())
|
||||||
|
if self.progress_task and self.progress_task.done():
|
||||||
|
if self.progress_task.exception():
|
||||||
|
return self.async_show_progress_done(next_step_id="error")
|
||||||
|
return self.async_show_progress_done(next_step_id="no_error")
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init", progress_action="task", progress_task=self.progress_task
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_error(self, user_input=None):
|
||||||
|
return self.async_abort(reason="error")
|
||||||
|
|
||||||
|
hass.bus.async_listen(
|
||||||
|
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||||
|
capture_events,
|
||||||
|
run_immediately=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.async_init("test")
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["progress_action"] == "task"
|
||||||
|
assert len(manager.async_progress()) == 1
|
||||||
|
assert len(manager.async_progress_by_handler("test")) == 1
|
||||||
|
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||||
|
|
||||||
|
# Set task one done and wait for event
|
||||||
|
await event_received_evt.wait()
|
||||||
|
event_received_evt.clear()
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].data == {
|
||||||
|
"handler": "test",
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"refresh": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend refreshes the flow
|
||||||
|
result = await manager.async_configure(result["flow_id"])
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None:
|
||||||
|
"""Test show progress logic.
|
||||||
|
|
||||||
|
This tests the deprecated version where the config flow is responsible for
|
||||||
|
resuming the flow.
|
||||||
|
"""
|
||||||
|
manager.hass = hass
|
||||||
|
|
||||||
@manager.mock_reg_handler("test")
|
@manager.mock_reg_handler("test")
|
||||||
class TestFlow(data_entry_flow.FlowHandler):
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue