Validate steps in Flowhandler (#102152)

* Validate steps in Flowhandler

* Move validation to FlowManager._async_handle_step

* Fix _raise_if_not_has_step

* Fix config_entries tests

* Fix tests

* Rename

* Add test
This commit is contained in:
Erik Montnemery 2023-10-19 13:34:10 +02:00 committed by GitHub
parent 9857c0fa3a
commit 4498c2e8c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 84 additions and 10 deletions

View file

@ -382,14 +382,9 @@ class FlowManager(abc.ABC):
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
) -> FlowResult:
"""Handle a step of a flow."""
self._raise_if_step_does_not_exist(flow, step_id)
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._async_remove_flow_progress(flow.flow_id)
raise UnknownStep(
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
)
try:
result: FlowResult = await getattr(flow, method)(user_input)
except AbortFlow as err:
@ -419,6 +414,7 @@ class FlowManager(abc.ABC):
FlowResultType.SHOW_PROGRESS_DONE,
FlowResultType.MENU,
):
self._raise_if_step_does_not_exist(flow, result["step_id"])
flow.cur_step = result
return result
@ -435,6 +431,16 @@ class FlowManager(abc.ABC):
return result
def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._async_remove_flow_progress(flow.flow_id)
raise UnknownStep(
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: FlowHandler) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:

View file

@ -77,7 +77,10 @@ async def test_invalid_credentials(hass: HomeAssistant) -> None:
),
), patch(
"homeassistant.components.abode.config_flow.AbodeFlowHandler.async_step_reauth",
return_value={"type": data_entry_flow.FlowResultType.FORM},
return_value={
"type": data_entry_flow.FlowResultType.FORM,
"step_id": "reauth_confirm",
},
) as mock_async_step_reauth:
await setup_platform(hass, ALARM_DOMAIN)

View file

@ -23,7 +23,10 @@ async def test_auth_failure(hass: HomeAssistant) -> None:
"""Test init with an authentication failure."""
with patch(
"homeassistant.components.aussie_broadband.config_flow.ConfigFlow.async_step_reauth",
return_value={"type": data_entry_flow.FlowResultType.FORM},
return_value={
"type": data_entry_flow.FlowResultType.FORM,
"step_id": "reauth_confirm",
},
) as mock_async_step_reauth:
await setup_platform(hass, side_effect=AuthenticationException())
mock_async_step_reauth.assert_called_once()

View file

@ -798,6 +798,9 @@ async def test_options_flow(hass: HomeAssistant, client) -> None:
description_placeholders={"enabled": "Set to true to be true"},
)
async def async_step_user(self, user_input=None):
raise NotImplementedError
return OptionsFlowHandler()
mock_integration(hass, MockModule("test"))
@ -1271,6 +1274,9 @@ async def test_ignore_flow(
await self.async_set_unique_id("mock-unique-id")
return self.async_show_form(step_id="account")
async def async_step_account(self, user_input=None):
raise NotImplementedError
ws_client = await hass_ws_client(hass)
with patch.dict(HANDLERS, {"test": TestFlow}):

View file

@ -50,7 +50,10 @@ async def test_reauth_triggered(hass: HomeAssistant) -> None:
side_effect=SynologyDSMLoginInvalidException(USERNAME),
), patch(
"homeassistant.components.synology_dsm.config_flow.SynologyDSMFlowHandler.async_step_reauth",
return_value={"type": data_entry_flow.FlowResultType.FORM},
return_value={
"type": data_entry_flow.FlowResultType.FORM,
"step_id": "reauth_confirm",
},
) as mock_async_step_reauth:
entry = MockConfigEntry(
domain=DOMAIN,

View file

@ -2171,6 +2171,9 @@ async def test_manual_add_overrides_ignored_entry(
)
return self.async_show_form(step_id="step2")
async def async_step_step2(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}), patch(
"homeassistant.config_entries.ConfigEntries.async_reload"
) as async_reload:
@ -2500,6 +2503,9 @@ async def test_partial_flows_hidden(
await pause_discovery.wait()
return self.async_show_form(step_id="someform")
async def async_step_someform(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Start a config entry flow and wait for it to be blocked
init_task = asyncio.ensure_future(
@ -2788,6 +2794,9 @@ async def test_flow_with_default_discovery_with_unique_id(
await self._async_handle_discovery_without_unique_id()
return self.async_show_form(step_id="mock")
async def async_step_mock(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_DISCOVERY}
@ -2841,6 +2850,9 @@ async def test_default_discovery_in_progress(
await self._async_handle_discovery_without_unique_id()
return self.async_show_form(step_id="mock")
async def async_step_mock(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
result = await manager.flow.async_init(
"comp",
@ -2878,6 +2890,9 @@ async def test_default_discovery_abort_on_new_unique_flow(
await self._async_handle_discovery_without_unique_id()
return self.async_show_form(step_id="mock")
async def async_step_mock(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# First discovery with default, no unique ID
result2 = await manager.flow.async_init(
@ -2922,6 +2937,9 @@ async def test_default_discovery_abort_on_user_flow_complete(
await self._async_handle_discovery_without_unique_id()
return self.async_show_form(step_id="mock")
async def async_step_mock(self, user_input=None):
raise NotImplementedError
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# First discovery with default, no unique ID
flow1 = await manager.flow.async_init(
@ -3968,6 +3986,9 @@ async def test_preview_supported(
"""Mock Reauth."""
return self.async_show_form(step_id="next", preview="test")
async def async_step_next(self, user_input=None):
raise NotImplementedError
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview."""
@ -4006,6 +4027,9 @@ async def test_preview_not_supported(
"""Mock Reauth."""
return self.async_show_form(step_id="user_confirm")
async def async_step_user_confirm(self, user_input=None):
raise NotImplementedError
mock_integration(hass, MockModule("test"))
mock_entity_platform(hass, "config_flow.test", None)

View file

@ -621,6 +621,35 @@ async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
assert manager.async_progress() == []
@pytest.mark.parametrize(
("result_type", "params"),
[
("async_external_step_done", {"next_step_id": "does_not_exist"}),
("async_external_step", {"step_id": "does_not_exist", "url": "blah"}),
("async_show_form", {"step_id": "does_not_exist"}),
("async_show_menu", {"step_id": "does_not_exist", "menu_options": []}),
("async_show_progress_done", {"next_step_id": "does_not_exist"}),
("async_show_progress", {"step_id": "does_not_exist", "progress_action": ""}),
],
)
async def test_next_step_unknown_step_raises_and_removes_from_in_progress(
manager, result_type: str, params: dict[str, str]
) -> None:
"""Test that moving to an unknown step raises and removes the flow from in progress."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, user_input=None):
return getattr(self, result_type)(**params)
with pytest.raises(data_entry_flow.UnknownStep):
await manager.async_init("test", context={"init_step": "init"})
assert manager.async_progress() == []
async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None:
"""Test configure raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):