diff --git a/homeassistant/components/github/config_flow.py b/homeassistant/components/github/config_flow.py index c90caf0fc89..aa7ec7b6f86 100644 --- a/homeassistant/components/github/config_flow.py +++ b/homeassistant/components/github/config_flow.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio -from contextlib import suppress from typing import TYPE_CHECKING, Any from aiogithubapi import ( @@ -18,7 +17,7 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult, UnknownFlow +from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.aiohttp_client import ( SERVER_SOFTWARE, async_get_clientsession, @@ -124,22 +123,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): assert self._device is not None assert self._login_device is not None - try: - response = await self._device.activation( - device_code=self._login_device.device_code - ) - self._login = response.data - - finally: - - async def _progress(): - # If the user closes the dialog the flow will no longer exist and it will raise UnknownFlow - with suppress(UnknownFlow): - await self.hass.config_entries.flow.async_configure( - flow_id=self.flow_id - ) - - self.hass.async_create_task(_progress()) + response = await self._device.activation( + device_code=self._login_device.device_code + ) + self._login = response.data if not self._device: self._device = GitHubDeviceAPI( @@ -174,6 +161,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): "url": OAUTH_USER_LOGIN, "code": self._login_device.user_code, }, + progress_task=self.login_task, ) async def async_step_repositories( @@ -220,13 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Get the options flow for this handler.""" return OptionsFlowHandler(config_entry) - @callback - def async_remove(self) -> None: - """Handle remove handler callback.""" - if self.login_task and not self.login_task.done(): - # Clean up login task if it's still running - self.login_task.cancel() - class OptionsFlowHandler(config_entries.OptionsFlow): """Handle a option flow for GitHub.""" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index c017744689c..aa9df89de5c 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -2,7 +2,9 @@ from __future__ import annotations import abc +import asyncio from collections.abc import Callable, Iterable, Mapping +from contextlib import suppress import copy from dataclasses import dataclass from enum import StrEnum @@ -124,6 +126,7 @@ class FlowResult(TypedDict, total=False): options: Mapping[str, Any] preview: str | None progress_action: str + progress_task: asyncio.Task[Any] | None reason: str required: bool result: Any @@ -402,6 +405,7 @@ class FlowManager(abc.ABC): if (flow := self._progress.pop(flow_id, None)) is None: raise UnknownFlow self._async_remove_flow_from_index(flow) + flow.async_cancel_progress_task() try: flow.async_remove() except Exception as err: # pylint: disable=broad-except @@ -435,6 +439,25 @@ class FlowManager(abc.ABC): error_if_core=False, ) + if ( + result["type"] == FlowResultType.SHOW_PROGRESS + and (progress_task := result.pop("progress_task", None)) + and progress_task != flow.async_get_progress_task() + ): + # The flow's progress task was changed, register a callback on it + async def call_configure() -> None: + with suppress(UnknownFlow): + await self.async_configure(flow.flow_id) + + def schedule_configure(_: asyncio.Task) -> None: + self.hass.async_create_task(call_configure()) + + progress_task.add_done_callback(schedule_configure) + flow.async_set_progress_task(progress_task) + + elif result["type"] != FlowResultType.SHOW_PROGRESS: + flow.async_cancel_progress_task() + if result["type"] in FLOW_NOT_COMPLETE_STEPS: self._raise_if_step_does_not_exist(flow, result["step_id"]) flow.cur_step = result @@ -494,6 +517,8 @@ class FlowHandler: VERSION = 1 MINOR_VERSION = 1 + __progress_task: asyncio.Task[Any] | None = None + @property def source(self) -> str | None: """Source that initialized the flow.""" @@ -632,6 +657,7 @@ class FlowHandler: step_id: str, progress_action: str, description_placeholders: Mapping[str, str] | None = None, + progress_task: asyncio.Task[Any] | None = None, ) -> FlowResult: """Show a progress message to the user, without user input allowed.""" return FlowResult( @@ -641,6 +667,7 @@ class FlowHandler: step_id=step_id, progress_action=progress_action, description_placeholders=description_placeholders, + progress_task=progress_task, ) @callback @@ -683,6 +710,26 @@ class FlowHandler: async def async_setup_preview(hass: HomeAssistant) -> None: """Set up preview.""" + @callback + def async_cancel_progress_task(self) -> None: + """Cancel in progress task.""" + if self.__progress_task and not self.__progress_task.done(): + self.__progress_task.cancel() + self.__progress_task = None + + @callback + def async_get_progress_task(self) -> asyncio.Task[Any] | None: + """Get in progress task.""" + return self.__progress_task + + @callback + def async_set_progress_task( + self, + progress_task: asyncio.Task[Any], + ) -> None: + """Set in progress task.""" + self.__progress_task = progress_task + @callback def _create_abort_data( diff --git a/tests/components/github/test_config_flow.py b/tests/components/github/test_config_flow.py index 8d61eca1ab1..32388fb65d1 100644 --- a/tests/components/github/test_config_flow.py +++ b/tests/components/github/test_config_flow.py @@ -121,10 +121,11 @@ async def test_flow_with_activation_failure( ) assert result["step_id"] == "device" assert result["type"] == FlowResultType.SHOW_PROGRESS + await hass.async_block_till_done() result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE - assert result["step_id"] == "could_not_register" + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "could_not_register" async def test_flow_with_remove_while_activating( diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 155d78e2c64..aedf3e40c15 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,4 +1,5 @@ """Test the flow classes.""" +import asyncio import dataclasses import logging from unittest.mock import Mock, patch @@ -7,7 +8,7 @@ import pytest import voluptuous as vol from homeassistant import config_entries, data_entry_flow -from homeassistant.core import HomeAssistant +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.util.decorator import Registry from .common import ( @@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None: async def test_show_progress(hass: HomeAssistant, manager) -> None: """Test show progress logic.""" manager.hass = hass + events = [] + task_one_evt = asyncio.Event() + task_two_evt = asyncio.Event() + event_received_evt = asyncio.Event() + + @callback + def capture_events(event: Event) -> None: + events.append(event) + event_received_evt.set() + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + data = None + start_task_two = False + progress_task: asyncio.Task[None] | None = None + + async def async_step_init(self, user_input=None): + async def long_running_task_one() -> None: + await task_one_evt.wait() + self.start_task_two = True + + async def long_running_task_two() -> None: + await task_two_evt.wait() + self.data = {"title": "Hello"} + + if not task_one_evt.is_set(): + progress_action = "task_one" + if not self.progress_task: + self.progress_task = hass.async_create_task(long_running_task_one()) + elif not task_two_evt.is_set(): + progress_action = "task_two" + if self.start_task_two: + self.progress_task = hass.async_create_task(long_running_task_two()) + self.start_task_two = False + if not task_one_evt.is_set() or not task_two_evt.is_set(): + return self.async_show_progress( + step_id="init", + progress_action=progress_action, + progress_task=self.progress_task, + ) + + return self.async_show_progress_done(next_step_id="finish") + + async def async_step_finish(self, user_input=None): + return self.async_create_entry(title=self.data["title"], data=self.data) + + hass.bus.async_listen( + data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED, + capture_events, + run_immediately=True, + ) + + result = await manager.async_init("test") + assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS + assert result["progress_action"] == "task_one" + assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" + + # Set task one done and wait for event + task_one_evt.set() + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == 1 + assert events[0].data == { + "handler": "test", + "flow_id": result["flow_id"], + "refresh": True, + } + + # Frontend refreshes the flow + result = await manager.async_configure(result["flow_id"]) + assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS + assert result["progress_action"] == "task_two" + + # Set task two done and wait for event + task_two_evt.set() + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == 2 # 1 for task one and 1 for task two + assert events[1].data == { + "handler": "test", + "flow_id": result["flow_id"], + "refresh": True, + } + + # Frontend refreshes the flow + result = await manager.async_configure(result["flow_id"]) + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert result["title"] == "Hello" + + +async def test_show_progress_error(hass: HomeAssistant, manager) -> None: + """Test show progress logic.""" + manager.hass = hass + events = [] + event_received_evt = asyncio.Event() + + @callback + def capture_events(event: Event) -> None: + events.append(event) + event_received_evt.set() + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + data = None + progress_task: asyncio.Task[None] | None = None + + async def async_step_init(self, user_input=None): + async def long_running_task() -> None: + raise TypeError + + if not self.progress_task: + self.progress_task = hass.async_create_task(long_running_task()) + if self.progress_task and self.progress_task.done(): + if self.progress_task.exception(): + return self.async_show_progress_done(next_step_id="error") + return self.async_show_progress_done(next_step_id="no_error") + return self.async_show_progress( + step_id="init", progress_action="task", progress_task=self.progress_task + ) + + async def async_step_error(self, user_input=None): + return self.async_abort(reason="error") + + hass.bus.async_listen( + data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED, + capture_events, + run_immediately=True, + ) + + result = await manager.async_init("test") + assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS + assert result["progress_action"] == "task" + assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" + + # Set task one done and wait for event + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == 1 + assert events[0].data == { + "handler": "test", + "flow_id": result["flow_id"], + "refresh": True, + } + + # Frontend refreshes the flow + result = await manager.async_configure(result["flow_id"]) + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "error" + + +async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None: + """Test show progress logic. + + This tests the deprecated version where the config flow is responsible for + resuming the flow. + """ + manager.hass = hass @manager.mock_reg_handler("test") class TestFlow(data_entry_flow.FlowHandler):