From 2224d0f43a048052cfc4572df95c7afcccdf3a57 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Aug 2022 16:25:34 -0400 Subject: [PATCH] Add a callback for data flow handler removal (#77394) * Add a callback for when data flows are removed * Call `async_remove` at the very end * Handle and log exceptions caught during flow removal * Log the error as an exception, with a traceback * Adjust test's expected logging output to match updated format specifier --- homeassistant/data_entry_flow.py | 12 +++++++++ tests/test_data_entry_flow.py | 42 +++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 64750b2ff50..cdc4023f32c 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -5,6 +5,7 @@ import abc import asyncio from collections.abc import Iterable, Mapping from dataclasses import dataclass +import logging from types import MappingProxyType from typing import Any, TypedDict @@ -16,6 +17,8 @@ from .exceptions import HomeAssistantError from .helpers.frame import report from .util import uuid as uuid_util +_LOGGER = logging.getLogger(__name__) + class FlowResultType(StrEnum): """Result type for a data entry flow.""" @@ -337,6 +340,11 @@ class FlowManager(abc.ABC): if not self._handler_progress_index[handler]: del self._handler_progress_index[handler] + try: + flow.async_remove() + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error removing %s config flow: %s", flow.handler, err) + async def _async_handle_step( self, flow: Any, @@ -568,6 +576,10 @@ class FlowHandler: description_placeholders=description_placeholders, ) + @callback + def async_remove(self) -> None: + """Notification that the config flow has been removed.""" + @callback def _create_abort_data( diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 136c97808d3..1d60e20a3f0 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,6 +1,7 @@ """Test the flow classes.""" import asyncio -from unittest.mock import patch +import logging +from unittest.mock import Mock, patch import pytest import voluptuous as vol @@ -149,6 +150,45 @@ async def test_abort_removes_instance(manager): assert len(manager.mock_created_entries) == 0 +async def test_abort_calls_async_remove(manager): + """Test abort calling the async_remove FlowHandler method.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_abort(reason="reason") + + async_remove = Mock() + + await manager.async_init("test") + + TestFlow.async_remove.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + +async def test_abort_calls_async_remove_with_exception(manager, caplog): + """Test abort calling the async_remove FlowHandler method, with an exception.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_abort(reason="reason") + + async_remove = Mock(side_effect=[RuntimeError("error")]) + + with caplog.at_level(logging.ERROR): + await manager.async_init("test") + + assert "Error removing test config flow: error" in caplog.text + + TestFlow.async_remove.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + async def test_create_saves_data(manager): """Test creating a config entry."""