diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 545b799c467..e0ea195a3ff 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -382,14 +382,9 @@ class FlowManager(abc.ABC): self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None ) -> FlowResult: """Handle a step of a flow.""" + self._raise_if_step_does_not_exist(flow, step_id) + method = f"async_step_{step_id}" - - if not hasattr(flow, method): - self._async_remove_flow_progress(flow.flow_id) - raise UnknownStep( - f"Handler {flow.__class__.__name__} doesn't support step {step_id}" - ) - try: result: FlowResult = await getattr(flow, method)(user_input) except AbortFlow as err: @@ -419,6 +414,7 @@ class FlowManager(abc.ABC): FlowResultType.SHOW_PROGRESS_DONE, FlowResultType.MENU, ): + self._raise_if_step_does_not_exist(flow, result["step_id"]) flow.cur_step = result return result @@ -435,6 +431,16 @@ class FlowManager(abc.ABC): return result + def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None: + """Raise if the step does not exist.""" + method = f"async_step_{step_id}" + + if not hasattr(flow, method): + self._async_remove_flow_progress(flow.flow_id) + raise UnknownStep( + f"Handler {self.__class__.__name__} doesn't support step {step_id}" + ) + async def _async_setup_preview(self, flow: FlowHandler) -> None: """Set up preview for a flow handler.""" if flow.handler not in self._preview: diff --git a/tests/components/abode/test_init.py b/tests/components/abode/test_init.py index 17039235f37..d208b6302bc 100644 --- a/tests/components/abode/test_init.py +++ b/tests/components/abode/test_init.py @@ -77,7 +77,10 @@ async def test_invalid_credentials(hass: HomeAssistant) -> None: ), ), patch( "homeassistant.components.abode.config_flow.AbodeFlowHandler.async_step_reauth", - return_value={"type": data_entry_flow.FlowResultType.FORM}, + return_value={ + "type": data_entry_flow.FlowResultType.FORM, + "step_id": "reauth_confirm", + }, ) as mock_async_step_reauth: await setup_platform(hass, ALARM_DOMAIN) diff --git a/tests/components/aussie_broadband/test_init.py b/tests/components/aussie_broadband/test_init.py index 3eb1972011c..1430eca3a26 100644 --- a/tests/components/aussie_broadband/test_init.py +++ b/tests/components/aussie_broadband/test_init.py @@ -23,7 +23,10 @@ async def test_auth_failure(hass: HomeAssistant) -> None: """Test init with an authentication failure.""" with patch( "homeassistant.components.aussie_broadband.config_flow.ConfigFlow.async_step_reauth", - return_value={"type": data_entry_flow.FlowResultType.FORM}, + return_value={ + "type": data_entry_flow.FlowResultType.FORM, + "step_id": "reauth_confirm", + }, ) as mock_async_step_reauth: await setup_platform(hass, side_effect=AuthenticationException()) mock_async_step_reauth.assert_called_once() diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 4239e031893..3cc7ada49ba 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -798,6 +798,9 @@ async def test_options_flow(hass: HomeAssistant, client) -> None: description_placeholders={"enabled": "Set to true to be true"}, ) + async def async_step_user(self, user_input=None): + raise NotImplementedError + return OptionsFlowHandler() mock_integration(hass, MockModule("test")) @@ -1271,6 +1274,9 @@ async def test_ignore_flow( await self.async_set_unique_id("mock-unique-id") return self.async_show_form(step_id="account") + async def async_step_account(self, user_input=None): + raise NotImplementedError + ws_client = await hass_ws_client(hass) with patch.dict(HANDLERS, {"test": TestFlow}): diff --git a/tests/components/synology_dsm/test_init.py b/tests/components/synology_dsm/test_init.py index bfc3daf0aa2..91556f459ba 100644 --- a/tests/components/synology_dsm/test_init.py +++ b/tests/components/synology_dsm/test_init.py @@ -50,7 +50,10 @@ async def test_reauth_triggered(hass: HomeAssistant) -> None: side_effect=SynologyDSMLoginInvalidException(USERNAME), ), patch( "homeassistant.components.synology_dsm.config_flow.SynologyDSMFlowHandler.async_step_reauth", - return_value={"type": data_entry_flow.FlowResultType.FORM}, + return_value={ + "type": data_entry_flow.FlowResultType.FORM, + "step_id": "reauth_confirm", + }, ) as mock_async_step_reauth: entry = MockConfigEntry( domain=DOMAIN, diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 52caa1ae275..d17c724cb2a 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -2171,6 +2171,9 @@ async def test_manual_add_overrides_ignored_entry( ) return self.async_show_form(step_id="step2") + async def async_step_step2(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}), patch( "homeassistant.config_entries.ConfigEntries.async_reload" ) as async_reload: @@ -2500,6 +2503,9 @@ async def test_partial_flows_hidden( await pause_discovery.wait() return self.async_show_form(step_id="someform") + async def async_step_someform(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): # Start a config entry flow and wait for it to be blocked init_task = asyncio.ensure_future( @@ -2788,6 +2794,9 @@ async def test_flow_with_default_discovery_with_unique_id( await self._async_handle_discovery_without_unique_id() return self.async_show_form(step_id="mock") + async def async_step_mock(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): result = await manager.flow.async_init( "comp", context={"source": config_entries.SOURCE_DISCOVERY} @@ -2841,6 +2850,9 @@ async def test_default_discovery_in_progress( await self._async_handle_discovery_without_unique_id() return self.async_show_form(step_id="mock") + async def async_step_mock(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): result = await manager.flow.async_init( "comp", @@ -2878,6 +2890,9 @@ async def test_default_discovery_abort_on_new_unique_flow( await self._async_handle_discovery_without_unique_id() return self.async_show_form(step_id="mock") + async def async_step_mock(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): # First discovery with default, no unique ID result2 = await manager.flow.async_init( @@ -2922,6 +2937,9 @@ async def test_default_discovery_abort_on_user_flow_complete( await self._async_handle_discovery_without_unique_id() return self.async_show_form(step_id="mock") + async def async_step_mock(self, user_input=None): + raise NotImplementedError + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): # First discovery with default, no unique ID flow1 = await manager.flow.async_init( @@ -3968,6 +3986,9 @@ async def test_preview_supported( """Mock Reauth.""" return self.async_show_form(step_id="next", preview="test") + async def async_step_next(self, user_input=None): + raise NotImplementedError + @staticmethod async def async_setup_preview(hass: HomeAssistant) -> None: """Set up preview.""" @@ -4006,6 +4027,9 @@ async def test_preview_not_supported( """Mock Reauth.""" return self.async_show_form(step_id="user_confirm") + async def async_step_user_confirm(self, user_input=None): + raise NotImplementedError + mock_integration(hass, MockModule("test")) mock_entity_platform(hass, "config_flow.test", None) diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index e6a28fc2e4f..98380890e41 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -621,6 +621,35 @@ async def test_move_to_unknown_step_raises_and_removes_from_in_progress( assert manager.async_progress() == [] +@pytest.mark.parametrize( + ("result_type", "params"), + [ + ("async_external_step_done", {"next_step_id": "does_not_exist"}), + ("async_external_step", {"step_id": "does_not_exist", "url": "blah"}), + ("async_show_form", {"step_id": "does_not_exist"}), + ("async_show_menu", {"step_id": "does_not_exist", "menu_options": []}), + ("async_show_progress_done", {"next_step_id": "does_not_exist"}), + ("async_show_progress", {"step_id": "does_not_exist", "progress_action": ""}), + ], +) +async def test_next_step_unknown_step_raises_and_removes_from_in_progress( + manager, result_type: str, params: dict[str, str] +) -> None: + """Test that moving to an unknown step raises and removes the flow from in progress.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 1 + + async def async_step_init(self, user_input=None): + return getattr(self, result_type)(**params) + + with pytest.raises(data_entry_flow.UnknownStep): + await manager.async_init("test", context={"init_step": "init"}) + + assert manager.async_progress() == [] + + async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None: """Test configure raises UnknownFlow if the flow is not in progress.""" with pytest.raises(data_entry_flow.UnknownFlow):