diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index c02d909733a..782f349f9f2 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -19,42 +19,42 @@ from .common import ( ) +class MockFlowManager(data_entry_flow.FlowManager): + """Test flow manager.""" + + def __init__(self) -> None: + """Initialize the flow manager.""" + super().__init__(None) + self._handlers = Registry() + self.mock_reg_handler = self._handlers.register + self.mock_created_entries = [] + + async def async_create_flow(self, handler_key, *, context, data): + """Test create flow.""" + handler = self._handlers.get(handler_key) + + if handler is None: + raise data_entry_flow.UnknownHandler + + flow = handler() + flow.init_step = context.get("init_step", "init") + return flow + + async def async_finish_flow(self, flow, result): + """Test finish flow.""" + if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: + result["source"] = flow.context.get("source") + self.mock_created_entries.append(result) + return result + + @pytest.fixture -def manager(): +def manager() -> MockFlowManager: """Return a flow manager.""" - handlers = Registry() - entries = [] - - class FlowManager(data_entry_flow.FlowManager): - """Test flow manager.""" - - async def async_create_flow(self, handler_key, *, context, data): - """Test create flow.""" - handler = handlers.get(handler_key) - - if handler is None: - raise data_entry_flow.UnknownHandler - - flow = handler() - flow.init_step = context.get("init_step", "init") - return flow - - async def async_finish_flow(self, flow, result): - """Test finish flow.""" - if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: - result["source"] = flow.context.get("source") - entries.append(result) - return result - - mgr = FlowManager(None) - # pylint: disable-next=attribute-defined-outside-init - mgr.mock_created_entries = entries - # pylint: disable-next=attribute-defined-outside-init - mgr.mock_reg_handler = handlers.register - return mgr + return MockFlowManager() -async def test_configure_reuses_handler_instance(manager) -> None: +async def test_configure_reuses_handler_instance(manager: MockFlowManager) -> None: """Test that we reuse instances.""" @manager.mock_reg_handler("test") @@ -82,7 +82,7 @@ async def test_configure_reuses_handler_instance(manager) -> None: assert len(manager.mock_created_entries) == 0 -async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None: +async def test_configure_two_steps(manager: MockFlowManager) -> None: """Test that we reuse instances.""" @manager.mock_reg_handler("test") @@ -117,7 +117,7 @@ async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None assert result["data"] == ["INIT-DATA", "SECOND-DATA"] -async def test_show_form(manager) -> None: +async def test_show_form(manager: MockFlowManager) -> None: """Test that we can show a form.""" schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str}) @@ -136,7 +136,7 @@ async def test_show_form(manager) -> None: assert form["errors"] == {"username": "Should be unique."} -async def test_abort_removes_instance(manager) -> None: +async def test_abort_removes_instance(manager: MockFlowManager) -> None: """Test that abort removes the flow from progress.""" @manager.mock_reg_handler("test") @@ -158,7 +158,7 @@ async def test_abort_removes_instance(manager) -> None: assert len(manager.mock_created_entries) == 0 -async def test_abort_calls_async_remove(manager) -> None: +async def test_abort_calls_async_remove(manager: MockFlowManager) -> None: """Test abort calling the async_remove FlowHandler method.""" @manager.mock_reg_handler("test") @@ -177,7 +177,7 @@ async def test_abort_calls_async_remove(manager) -> None: async def test_abort_calls_async_remove_with_exception( - manager, caplog: pytest.LogCaptureFixture + manager: MockFlowManager, caplog: pytest.LogCaptureFixture ) -> None: """Test abort calling the async_remove FlowHandler method, with an exception.""" @@ -199,7 +199,7 @@ async def test_abort_calls_async_remove_with_exception( assert len(manager.mock_created_entries) == 0 -async def test_create_saves_data(manager) -> None: +async def test_create_saves_data(manager: MockFlowManager) -> None: """Test creating a config entry.""" @manager.mock_reg_handler("test") @@ -220,7 +220,7 @@ async def test_create_saves_data(manager) -> None: assert entry["source"] is None -async def test_discovery_init_flow(manager) -> None: +async def test_discovery_init_flow(manager: MockFlowManager) -> None: """Test a flow initialized by discovery.""" @manager.mock_reg_handler("test") @@ -290,7 +290,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None: assert result["result"] == 2 -async def test_external_step(hass: HomeAssistant, manager) -> None: +async def test_external_step(hass: HomeAssistant, manager: MockFlowManager) -> None: """Test external step logic.""" manager.hass = hass @@ -340,7 +340,7 @@ async def test_external_step(hass: HomeAssistant, manager) -> None: assert result["title"] == "Hello" -async def test_show_progress(hass: HomeAssistant, manager) -> None: +async def test_show_progress(hass: HomeAssistant, manager: MockFlowManager) -> None: """Test show progress logic.""" manager.hass = hass events = [] @@ -443,7 +443,9 @@ async def test_show_progress(hass: HomeAssistant, manager) -> None: assert result["title"] == "Hello" -async def test_show_progress_error(hass: HomeAssistant, manager) -> None: +async def test_show_progress_error( + hass: HomeAssistant, manager: MockFlowManager +) -> None: """Test show progress logic.""" manager.hass = hass events = [] @@ -506,7 +508,9 @@ async def test_show_progress_error(hass: HomeAssistant, manager) -> None: assert result["reason"] == "error" -async def test_show_progress_hidden_from_frontend(hass: HomeAssistant, manager) -> None: +async def test_show_progress_hidden_from_frontend( + hass: HomeAssistant, manager: MockFlowManager +) -> None: """Test show progress done is not sent to frontend.""" manager.hass = hass async_show_progress_done_called = False @@ -557,7 +561,7 @@ async def test_show_progress_hidden_from_frontend(hass: HomeAssistant, manager) async def test_show_progress_legacy( - hass: HomeAssistant, manager, caplog: pytest.LogCaptureFixture + hass: HomeAssistant, manager: MockFlowManager, caplog: pytest.LogCaptureFixture ) -> None: """Test show progress logic. @@ -659,7 +663,7 @@ async def test_show_progress_legacy( async def test_show_progress_fires_only_when_changed( - hass: HomeAssistant, manager + hass: HomeAssistant, manager: MockFlowManager ) -> None: """Test show progress change logic.""" manager.hass = hass @@ -745,7 +749,7 @@ async def test_show_progress_fires_only_when_changed( ) # change (description placeholder) -async def test_abort_flow_exception(manager) -> None: +async def test_abort_flow_exception(manager: MockFlowManager) -> None: """Test that the AbortFlow exception works.""" @manager.mock_reg_handler("test") @@ -759,7 +763,7 @@ async def test_abort_flow_exception(manager) -> None: assert form["description_placeholders"] == {"placeholder": "yo"} -async def test_init_unknown_flow(manager) -> None: +async def test_init_unknown_flow(manager: MockFlowManager) -> None: """Test that UnknownFlow is raised when async_create_flow returns None.""" with ( @@ -769,7 +773,7 @@ async def test_init_unknown_flow(manager) -> None: await manager.async_init("test") -async def test_async_get_unknown_flow(manager) -> None: +async def test_async_get_unknown_flow(manager: MockFlowManager) -> None: """Test that UnknownFlow is raised when async_get is called with a flow_id that does not exist.""" with pytest.raises(data_entry_flow.UnknownFlow): @@ -777,7 +781,7 @@ async def test_async_get_unknown_flow(manager) -> None: async def test_async_has_matching_flow( - hass: HomeAssistant, manager: data_entry_flow.FlowManager + hass: HomeAssistant, manager: MockFlowManager ) -> None: """Test we can check for matching flows.""" manager.hass = hass @@ -854,7 +858,7 @@ async def test_async_has_matching_flow( async def test_move_to_unknown_step_raises_and_removes_from_in_progress( - manager, + manager: MockFlowManager, ) -> None: """Test that moving to an unknown step raises and removes the flow from in progress.""" @@ -880,7 +884,7 @@ async def test_move_to_unknown_step_raises_and_removes_from_in_progress( ], ) async def test_next_step_unknown_step_raises_and_removes_from_in_progress( - manager, result_type: str, params: dict[str, str] + manager: MockFlowManager, result_type: str, params: dict[str, str] ) -> None: """Test that moving to an unknown step raises and removes the flow from in progress.""" @@ -897,13 +901,17 @@ async def test_next_step_unknown_step_raises_and_removes_from_in_progress( assert manager.async_progress() == [] -async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None: +async def test_configure_raises_unknown_flow_if_not_in_progress( + manager: MockFlowManager, +) -> None: """Test configure raises UnknownFlow if the flow is not in progress.""" with pytest.raises(data_entry_flow.UnknownFlow): await manager.async_configure("wrong_flow_id") -async def test_abort_raises_unknown_flow_if_not_in_progress(manager) -> None: +async def test_abort_raises_unknown_flow_if_not_in_progress( + manager: MockFlowManager, +) -> None: """Test abort raises UnknownFlow if the flow is not in progress.""" with pytest.raises(data_entry_flow.UnknownFlow): await manager.async_abort("wrong_flow_id") @@ -913,7 +921,11 @@ async def test_abort_raises_unknown_flow_if_not_in_progress(manager) -> None: "menu_options", [["target1", "target2"], {"target1": "Target 1", "target2": "Target 2"}], ) -async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None: +async def test_show_menu( + hass: HomeAssistant, + manager: MockFlowManager, + menu_options: list[str] | dict[str, str], +) -> None: """Test show menu.""" manager.hass = hass @@ -952,9 +964,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None: assert result["step_id"] == "target1" -async def test_find_flows_by_init_data_type( - manager: data_entry_flow.FlowManager, -) -> None: +async def test_find_flows_by_init_data_type(manager: MockFlowManager) -> None: """Test we can find flows by init data type.""" @dataclasses.dataclass