Allow finish_flow callback to change data entry result type (#16100)
* Allow finish_flow callback to change data entry result type * Add unit test
This commit is contained in:
parent
b26506ad4a
commit
00c6f56cc8
4 changed files with 80 additions and 29 deletions
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -256,20 +256,23 @@ class AuthManager:
|
|||
return await auth_provider.async_credential_flow(context)
|
||||
|
||||
async def _async_finish_login_flow(
|
||||
self, context: Optional[Dict], result: Dict[str, Any]) \
|
||||
-> Optional[Union[models.User, models.Credentials]]:
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
|
||||
-> Dict[str, Any]:
|
||||
"""Return a user as result of login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
return result
|
||||
|
||||
auth_provider = self._providers[result['handler']]
|
||||
cred = await auth_provider.async_get_or_create_credentials(
|
||||
credentials = await auth_provider.async_get_or_create_credentials(
|
||||
result['data'])
|
||||
|
||||
if context is not None and context.get('credential_only'):
|
||||
return cred
|
||||
if flow.context is not None and flow.context.get('credential_only'):
|
||||
result['result'] = credentials
|
||||
return result
|
||||
|
||||
return await self.async_get_or_create_user(cred)
|
||||
user = await self.async_get_or_create_user(credentials)
|
||||
result['result'] = user
|
||||
return result
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(
|
||||
|
|
|
@ -372,23 +372,24 @@ class ConfigEntries:
|
|||
return await entry.async_unload(
|
||||
self.hass, component=getattr(self.hass.components, component))
|
||||
|
||||
async def _async_finish_flow(self, context, result):
|
||||
async def _async_finish_flow(self, flow, result):
|
||||
"""Finish a config flow and add an entry."""
|
||||
# If no discovery config entries in progress, remove notification.
|
||||
# Remove notification if no other discovery config entries in progress
|
||||
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
|
||||
in self.hass.config_entries.flow.async_progress()):
|
||||
in self.hass.config_entries.flow.async_progress()
|
||||
if ent['flow_id'] != flow.flow_id):
|
||||
self.hass.components.persistent_notification.async_dismiss(
|
||||
DISCOVERY_NOTIFICATION_ID)
|
||||
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
return result
|
||||
|
||||
entry = ConfigEntry(
|
||||
version=result['version'],
|
||||
domain=result['handler'],
|
||||
title=result['title'],
|
||||
data=result['data'],
|
||||
source=context['source'],
|
||||
source=flow.context['source'],
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._async_schedule_save()
|
||||
|
@ -402,11 +403,8 @@ class ConfigEntries:
|
|||
await async_setup_component(
|
||||
self.hass, entry.domain, self._hass_config)
|
||||
|
||||
# Return Entry if they not from a discovery request
|
||||
if context['source'] not in DISCOVERY_SOURCES:
|
||||
return entry
|
||||
|
||||
return entry
|
||||
result['result'] = entry
|
||||
return result
|
||||
|
||||
async def _async_create_flow(self, handler_key, *, context, data):
|
||||
"""Create a flow for specified handler.
|
||||
|
|
|
@ -64,7 +64,7 @@ class FlowManager:
|
|||
return await self._async_handle_step(flow, flow.init_step, data)
|
||||
|
||||
async def async_configure(
|
||||
self, flow_id: str, user_input: Optional[str] = None) -> Any:
|
||||
self, flow_id: str, user_input: Optional[Dict] = None) -> Any:
|
||||
"""Continue a configuration flow."""
|
||||
flow = self._progress.get(flow_id)
|
||||
|
||||
|
@ -86,7 +86,7 @@ class FlowManager:
|
|||
raise UnknownFlow
|
||||
|
||||
async def _async_handle_step(self, flow: Any, step_id: str,
|
||||
user_input: Optional[str]) -> Dict:
|
||||
user_input: Optional[Dict]) -> Dict:
|
||||
"""Handle a step of a flow."""
|
||||
method = "async_step_{}".format(step_id)
|
||||
|
||||
|
@ -106,14 +106,17 @@ class FlowManager:
|
|||
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||
return result
|
||||
|
||||
# We pass a copy of the result because we're mutating our version
|
||||
result = await self._async_finish_flow(flow, dict(result))
|
||||
|
||||
# _async_finish_flow may change result type, check it again
|
||||
if result['type'] == RESULT_TYPE_FORM:
|
||||
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||
return result
|
||||
|
||||
# Abort and Success results both finish the flow
|
||||
self._progress.pop(flow.flow_id)
|
||||
|
||||
# We pass a copy of the result because we're mutating our version
|
||||
entry = await self._async_finish_flow(flow.context, dict(result))
|
||||
|
||||
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
|
||||
result['result'] = entry
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -25,11 +25,12 @@ def manager():
|
|||
if context is not None else 'user_input'
|
||||
return flow
|
||||
|
||||
async def async_add_entry(context, result):
|
||||
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
|
||||
result['source'] = context.get('source') \
|
||||
if context is not None else 'user'
|
||||
async def async_add_entry(flow, result):
|
||||
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
result['source'] = flow.context.get('source') \
|
||||
if flow.context is not None else 'user'
|
||||
entries.append(result)
|
||||
return result
|
||||
|
||||
manager = data_entry_flow.FlowManager(
|
||||
None, async_create_flow, async_add_entry)
|
||||
|
@ -198,3 +199,49 @@ async def test_discovery_init_flow(manager):
|
|||
assert entry['title'] == 'hello'
|
||||
assert entry['data'] == data
|
||||
assert entry['source'] == 'discovery'
|
||||
|
||||
|
||||
async def test_finish_callback_change_result_type(hass):
|
||||
"""Test finish callback can change result type."""
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_init(self, input):
|
||||
"""Return init form with one input field 'count'."""
|
||||
if input is not None:
|
||||
return self.async_create_entry(title='init', data=input)
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema({'count': int}))
|
||||
|
||||
async def async_create_flow(handler_name, *, context, data):
|
||||
"""Create a test flow."""
|
||||
return TestFlow()
|
||||
|
||||
async def async_finish_flow(flow, result):
|
||||
"""Redirect to init form if count <= 1."""
|
||||
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
if (result['data'] is None or
|
||||
result['data'].get('count', 0) <= 1):
|
||||
return flow.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema({'count': int}))
|
||||
else:
|
||||
result['result'] = result['data']['count']
|
||||
return result
|
||||
|
||||
manager = data_entry_flow.FlowManager(
|
||||
hass, async_create_flow, async_finish_flow)
|
||||
|
||||
result = await manager.async_init('test')
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['step_id'] == 'init'
|
||||
|
||||
result = await manager.async_configure(result['flow_id'], {'count': 0})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result['step_id'] == 'init'
|
||||
assert 'result' not in result
|
||||
|
||||
result = await manager.async_configure(result['flow_id'], {'count': 2})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result['result'] == 2
|
||||
|
|
Loading…
Add table
Reference in a new issue