Validate steps in Flowhandler (#102152)
* Validate steps in Flowhandler * Move validation to FlowManager._async_handle_step * Fix _raise_if_not_has_step * Fix config_entries tests * Fix tests * Rename * Add test
This commit is contained in:
parent
9857c0fa3a
commit
4498c2e8c4
7 changed files with 84 additions and 10 deletions
|
@ -382,14 +382,9 @@ class FlowManager(abc.ABC):
|
||||||
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
|
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Handle a step of a flow."""
|
"""Handle a step of a flow."""
|
||||||
|
self._raise_if_step_does_not_exist(flow, step_id)
|
||||||
|
|
||||||
method = f"async_step_{step_id}"
|
method = f"async_step_{step_id}"
|
||||||
|
|
||||||
if not hasattr(flow, method):
|
|
||||||
self._async_remove_flow_progress(flow.flow_id)
|
|
||||||
raise UnknownStep(
|
|
||||||
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result: FlowResult = await getattr(flow, method)(user_input)
|
result: FlowResult = await getattr(flow, method)(user_input)
|
||||||
except AbortFlow as err:
|
except AbortFlow as err:
|
||||||
|
@ -419,6 +414,7 @@ class FlowManager(abc.ABC):
|
||||||
FlowResultType.SHOW_PROGRESS_DONE,
|
FlowResultType.SHOW_PROGRESS_DONE,
|
||||||
FlowResultType.MENU,
|
FlowResultType.MENU,
|
||||||
):
|
):
|
||||||
|
self._raise_if_step_does_not_exist(flow, result["step_id"])
|
||||||
flow.cur_step = result
|
flow.cur_step = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -435,6 +431,16 @@ class FlowManager(abc.ABC):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None:
|
||||||
|
"""Raise if the step does not exist."""
|
||||||
|
method = f"async_step_{step_id}"
|
||||||
|
|
||||||
|
if not hasattr(flow, method):
|
||||||
|
self._async_remove_flow_progress(flow.flow_id)
|
||||||
|
raise UnknownStep(
|
||||||
|
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_setup_preview(self, flow: FlowHandler) -> None:
|
async def _async_setup_preview(self, flow: FlowHandler) -> None:
|
||||||
"""Set up preview for a flow handler."""
|
"""Set up preview for a flow handler."""
|
||||||
if flow.handler not in self._preview:
|
if flow.handler not in self._preview:
|
||||||
|
|
|
@ -77,7 +77,10 @@ async def test_invalid_credentials(hass: HomeAssistant) -> None:
|
||||||
),
|
),
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.abode.config_flow.AbodeFlowHandler.async_step_reauth",
|
"homeassistant.components.abode.config_flow.AbodeFlowHandler.async_step_reauth",
|
||||||
return_value={"type": data_entry_flow.FlowResultType.FORM},
|
return_value={
|
||||||
|
"type": data_entry_flow.FlowResultType.FORM,
|
||||||
|
"step_id": "reauth_confirm",
|
||||||
|
},
|
||||||
) as mock_async_step_reauth:
|
) as mock_async_step_reauth:
|
||||||
await setup_platform(hass, ALARM_DOMAIN)
|
await setup_platform(hass, ALARM_DOMAIN)
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,10 @@ async def test_auth_failure(hass: HomeAssistant) -> None:
|
||||||
"""Test init with an authentication failure."""
|
"""Test init with an authentication failure."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.aussie_broadband.config_flow.ConfigFlow.async_step_reauth",
|
"homeassistant.components.aussie_broadband.config_flow.ConfigFlow.async_step_reauth",
|
||||||
return_value={"type": data_entry_flow.FlowResultType.FORM},
|
return_value={
|
||||||
|
"type": data_entry_flow.FlowResultType.FORM,
|
||||||
|
"step_id": "reauth_confirm",
|
||||||
|
},
|
||||||
) as mock_async_step_reauth:
|
) as mock_async_step_reauth:
|
||||||
await setup_platform(hass, side_effect=AuthenticationException())
|
await setup_platform(hass, side_effect=AuthenticationException())
|
||||||
mock_async_step_reauth.assert_called_once()
|
mock_async_step_reauth.assert_called_once()
|
||||||
|
|
|
@ -798,6 +798,9 @@ async def test_options_flow(hass: HomeAssistant, client) -> None:
|
||||||
description_placeholders={"enabled": "Set to true to be true"},
|
description_placeholders={"enabled": "Set to true to be true"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_step_user(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
return OptionsFlowHandler()
|
return OptionsFlowHandler()
|
||||||
|
|
||||||
mock_integration(hass, MockModule("test"))
|
mock_integration(hass, MockModule("test"))
|
||||||
|
@ -1271,6 +1274,9 @@ async def test_ignore_flow(
|
||||||
await self.async_set_unique_id("mock-unique-id")
|
await self.async_set_unique_id("mock-unique-id")
|
||||||
return self.async_show_form(step_id="account")
|
return self.async_show_form(step_id="account")
|
||||||
|
|
||||||
|
async def async_step_account(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
ws_client = await hass_ws_client(hass)
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {"test": TestFlow}):
|
with patch.dict(HANDLERS, {"test": TestFlow}):
|
||||||
|
|
|
@ -50,7 +50,10 @@ async def test_reauth_triggered(hass: HomeAssistant) -> None:
|
||||||
side_effect=SynologyDSMLoginInvalidException(USERNAME),
|
side_effect=SynologyDSMLoginInvalidException(USERNAME),
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.synology_dsm.config_flow.SynologyDSMFlowHandler.async_step_reauth",
|
"homeassistant.components.synology_dsm.config_flow.SynologyDSMFlowHandler.async_step_reauth",
|
||||||
return_value={"type": data_entry_flow.FlowResultType.FORM},
|
return_value={
|
||||||
|
"type": data_entry_flow.FlowResultType.FORM,
|
||||||
|
"step_id": "reauth_confirm",
|
||||||
|
},
|
||||||
) as mock_async_step_reauth:
|
) as mock_async_step_reauth:
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
|
|
|
@ -2171,6 +2171,9 @@ async def test_manual_add_overrides_ignored_entry(
|
||||||
)
|
)
|
||||||
return self.async_show_form(step_id="step2")
|
return self.async_show_form(step_id="step2")
|
||||||
|
|
||||||
|
async def async_step_step2(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}), patch(
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}), patch(
|
||||||
"homeassistant.config_entries.ConfigEntries.async_reload"
|
"homeassistant.config_entries.ConfigEntries.async_reload"
|
||||||
) as async_reload:
|
) as async_reload:
|
||||||
|
@ -2500,6 +2503,9 @@ async def test_partial_flows_hidden(
|
||||||
await pause_discovery.wait()
|
await pause_discovery.wait()
|
||||||
return self.async_show_form(step_id="someform")
|
return self.async_show_form(step_id="someform")
|
||||||
|
|
||||||
|
async def async_step_someform(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
# Start a config entry flow and wait for it to be blocked
|
# Start a config entry flow and wait for it to be blocked
|
||||||
init_task = asyncio.ensure_future(
|
init_task = asyncio.ensure_future(
|
||||||
|
@ -2788,6 +2794,9 @@ async def test_flow_with_default_discovery_with_unique_id(
|
||||||
await self._async_handle_discovery_without_unique_id()
|
await self._async_handle_discovery_without_unique_id()
|
||||||
return self.async_show_form(step_id="mock")
|
return self.async_show_form(step_id="mock")
|
||||||
|
|
||||||
|
async def async_step_mock(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
result = await manager.flow.async_init(
|
result = await manager.flow.async_init(
|
||||||
"comp", context={"source": config_entries.SOURCE_DISCOVERY}
|
"comp", context={"source": config_entries.SOURCE_DISCOVERY}
|
||||||
|
@ -2841,6 +2850,9 @@ async def test_default_discovery_in_progress(
|
||||||
await self._async_handle_discovery_without_unique_id()
|
await self._async_handle_discovery_without_unique_id()
|
||||||
return self.async_show_form(step_id="mock")
|
return self.async_show_form(step_id="mock")
|
||||||
|
|
||||||
|
async def async_step_mock(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
result = await manager.flow.async_init(
|
result = await manager.flow.async_init(
|
||||||
"comp",
|
"comp",
|
||||||
|
@ -2878,6 +2890,9 @@ async def test_default_discovery_abort_on_new_unique_flow(
|
||||||
await self._async_handle_discovery_without_unique_id()
|
await self._async_handle_discovery_without_unique_id()
|
||||||
return self.async_show_form(step_id="mock")
|
return self.async_show_form(step_id="mock")
|
||||||
|
|
||||||
|
async def async_step_mock(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
# First discovery with default, no unique ID
|
# First discovery with default, no unique ID
|
||||||
result2 = await manager.flow.async_init(
|
result2 = await manager.flow.async_init(
|
||||||
|
@ -2922,6 +2937,9 @@ async def test_default_discovery_abort_on_user_flow_complete(
|
||||||
await self._async_handle_discovery_without_unique_id()
|
await self._async_handle_discovery_without_unique_id()
|
||||||
return self.async_show_form(step_id="mock")
|
return self.async_show_form(step_id="mock")
|
||||||
|
|
||||||
|
async def async_step_mock(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||||
# First discovery with default, no unique ID
|
# First discovery with default, no unique ID
|
||||||
flow1 = await manager.flow.async_init(
|
flow1 = await manager.flow.async_init(
|
||||||
|
@ -3968,6 +3986,9 @@ async def test_preview_supported(
|
||||||
"""Mock Reauth."""
|
"""Mock Reauth."""
|
||||||
return self.async_show_form(step_id="next", preview="test")
|
return self.async_show_form(step_id="next", preview="test")
|
||||||
|
|
||||||
|
async def async_step_next(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_setup_preview(hass: HomeAssistant) -> None:
|
async def async_setup_preview(hass: HomeAssistant) -> None:
|
||||||
"""Set up preview."""
|
"""Set up preview."""
|
||||||
|
@ -4006,6 +4027,9 @@ async def test_preview_not_supported(
|
||||||
"""Mock Reauth."""
|
"""Mock Reauth."""
|
||||||
return self.async_show_form(step_id="user_confirm")
|
return self.async_show_form(step_id="user_confirm")
|
||||||
|
|
||||||
|
async def async_step_user_confirm(self, user_input=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
mock_integration(hass, MockModule("test"))
|
mock_integration(hass, MockModule("test"))
|
||||||
mock_entity_platform(hass, "config_flow.test", None)
|
mock_entity_platform(hass, "config_flow.test", None)
|
||||||
|
|
||||||
|
|
|
@ -621,6 +621,35 @@ async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
|
||||||
assert manager.async_progress() == []
|
assert manager.async_progress() == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("result_type", "params"),
|
||||||
|
[
|
||||||
|
("async_external_step_done", {"next_step_id": "does_not_exist"}),
|
||||||
|
("async_external_step", {"step_id": "does_not_exist", "url": "blah"}),
|
||||||
|
("async_show_form", {"step_id": "does_not_exist"}),
|
||||||
|
("async_show_menu", {"step_id": "does_not_exist", "menu_options": []}),
|
||||||
|
("async_show_progress_done", {"next_step_id": "does_not_exist"}),
|
||||||
|
("async_show_progress", {"step_id": "does_not_exist", "progress_action": ""}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_next_step_unknown_step_raises_and_removes_from_in_progress(
|
||||||
|
manager, result_type: str, params: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Test that moving to an unknown step raises and removes the flow from in progress."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return getattr(self, result_type)(**params)
|
||||||
|
|
||||||
|
with pytest.raises(data_entry_flow.UnknownStep):
|
||||||
|
await manager.async_init("test", context={"init_step": "init"})
|
||||||
|
|
||||||
|
assert manager.async_progress() == []
|
||||||
|
|
||||||
|
|
||||||
async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None:
|
async def test_configure_raises_unknown_flow_if_not_in_progress(manager) -> None:
|
||||||
"""Test configure raises UnknownFlow if the flow is not in progress."""
|
"""Test configure raises UnknownFlow if the flow is not in progress."""
|
||||||
with pytest.raises(data_entry_flow.UnknownFlow):
|
with pytest.raises(data_entry_flow.UnknownFlow):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue