Add a callback for data flow handler removal (#77394)

* Add a callback for when data flows are removed

* Call `async_remove` at the very end

* Handle and log exceptions caught during flow removal

* Log the error as an exception, with a traceback

* Adjust test's expected logging output to match updated format specifier
This commit is contained in:
puddly 2022-08-29 16:25:34 -04:00 committed by GitHub
parent a6c61cf339
commit 2224d0f43a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 1 deletions

View file

@ -5,6 +5,7 @@ import abc
import asyncio import asyncio
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, TypedDict from typing import Any, TypedDict
@ -16,6 +17,8 @@ from .exceptions import HomeAssistantError
from .helpers.frame import report from .helpers.frame import report
from .util import uuid as uuid_util from .util import uuid as uuid_util
_LOGGER = logging.getLogger(__name__)
class FlowResultType(StrEnum): class FlowResultType(StrEnum):
"""Result type for a data entry flow.""" """Result type for a data entry flow."""
@ -337,6 +340,11 @@ class FlowManager(abc.ABC):
if not self._handler_progress_index[handler]: if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler] del self._handler_progress_index[handler]
try:
flow.async_remove()
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error removing %s config flow: %s", flow.handler, err)
async def _async_handle_step( async def _async_handle_step(
self, self,
flow: Any, flow: Any,
@ -568,6 +576,10 @@ class FlowHandler:
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
) )
@callback
def async_remove(self) -> None:
"""Notification that the config flow has been removed."""
@callback @callback
def _create_abort_data( def _create_abort_data(

View file

@ -1,6 +1,7 @@
"""Test the flow classes.""" """Test the flow classes."""
import asyncio import asyncio
from unittest.mock import patch import logging
from unittest.mock import Mock, patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -149,6 +150,45 @@ async def test_abort_removes_instance(manager):
assert len(manager.mock_created_entries) == 0 assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_remove(manager):
"""Test abort calling the async_remove FlowHandler method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_abort(reason="reason")
async_remove = Mock()
await manager.async_init("test")
TestFlow.async_remove.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_remove_with_exception(manager, caplog):
"""Test abort calling the async_remove FlowHandler method, with an exception."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_abort(reason="reason")
async_remove = Mock(side_effect=[RuntimeError("error")])
with caplog.at_level(logging.ERROR):
await manager.async_init("test")
assert "Error removing test config flow: error" in caplog.text
TestFlow.async_remove.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_create_saves_data(manager): async def test_create_saves_data(manager):
"""Test creating a config entry.""" """Test creating a config entry."""