Add a callback for data flow handler removal (#77394)
* Add a callback for when data flows are removed * Call `async_remove` at the very end * Handle and log exceptions caught during flow removal * Log the error as an exception, with a traceback * Adjust test's expected logging output to match updated format specifier
This commit is contained in:
parent
a6c61cf339
commit
2224d0f43a
2 changed files with 53 additions and 1 deletions
|
@ -5,6 +5,7 @@ import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
@ -16,6 +17,8 @@ from .exceptions import HomeAssistantError
|
||||||
from .helpers.frame import report
|
from .helpers.frame import report
|
||||||
from .util import uuid as uuid_util
|
from .util import uuid as uuid_util
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FlowResultType(StrEnum):
|
class FlowResultType(StrEnum):
|
||||||
"""Result type for a data entry flow."""
|
"""Result type for a data entry flow."""
|
||||||
|
@ -337,6 +340,11 @@ class FlowManager(abc.ABC):
|
||||||
if not self._handler_progress_index[handler]:
|
if not self._handler_progress_index[handler]:
|
||||||
del self._handler_progress_index[handler]
|
del self._handler_progress_index[handler]
|
||||||
|
|
||||||
|
try:
|
||||||
|
flow.async_remove()
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Error removing %s config flow: %s", flow.handler, err)
|
||||||
|
|
||||||
async def _async_handle_step(
|
async def _async_handle_step(
|
||||||
self,
|
self,
|
||||||
flow: Any,
|
flow: Any,
|
||||||
|
@ -568,6 +576,10 @@ class FlowHandler:
|
||||||
description_placeholders=description_placeholders,
|
description_placeholders=description_placeholders,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_remove(self) -> None:
|
||||||
|
"""Notification that the config flow has been removed."""
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _create_abort_data(
|
def _create_abort_data(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test the flow classes."""
|
"""Test the flow classes."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import patch
|
import logging
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -149,6 +150,45 @@ async def test_abort_removes_instance(manager):
|
||||||
assert len(manager.mock_created_entries) == 0
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abort_calls_async_remove(manager):
|
||||||
|
"""Test abort calling the async_remove FlowHandler method."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_abort(reason="reason")
|
||||||
|
|
||||||
|
async_remove = Mock()
|
||||||
|
|
||||||
|
await manager.async_init("test")
|
||||||
|
|
||||||
|
TestFlow.async_remove.assert_called_once()
|
||||||
|
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abort_calls_async_remove_with_exception(manager, caplog):
|
||||||
|
"""Test abort calling the async_remove FlowHandler method, with an exception."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_abort(reason="reason")
|
||||||
|
|
||||||
|
async_remove = Mock(side_effect=[RuntimeError("error")])
|
||||||
|
|
||||||
|
with caplog.at_level(logging.ERROR):
|
||||||
|
await manager.async_init("test")
|
||||||
|
|
||||||
|
assert "Error removing test config flow: error" in caplog.text
|
||||||
|
|
||||||
|
TestFlow.async_remove.assert_called_once()
|
||||||
|
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_create_saves_data(manager):
|
async def test_create_saves_data(manager):
|
||||||
"""Test creating a config entry."""
|
"""Test creating a config entry."""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue