Allow finish_flow callback to change data entry result type (#16100)

* Allow finish_flow callback to change data entry result type

* Add unit test
This commit is contained in:
Jason Hu 2018-08-21 10:48:24 -07:00 committed by GitHub
parent b26506ad4a
commit 00c6f56cc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 29 deletions

View file

@ -2,7 +2,7 @@
import asyncio
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast, Union
from typing import Any, Dict, List, Optional, Tuple, cast
import jwt
@ -256,20 +256,23 @@ class AuthManager:
return await auth_provider.async_credential_flow(context)
async def _async_finish_login_flow(
self, context: Optional[Dict], result: Dict[str, Any]) \
-> Optional[Union[models.User, models.Credentials]]:
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
-> Dict[str, Any]:
"""Return a user as result of login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
return result
auth_provider = self._providers[result['handler']]
cred = await auth_provider.async_get_or_create_credentials(
credentials = await auth_provider.async_get_or_create_credentials(
result['data'])
if context is not None and context.get('credential_only'):
return cred
if flow.context is not None and flow.context.get('credential_only'):
result['result'] = credentials
return result
return await self.async_get_or_create_user(cred)
user = await self.async_get_or_create_user(credentials)
result['result'] = user
return result
@callback
def _async_get_auth_provider(

View file

@ -372,23 +372,24 @@ class ConfigEntries:
return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))
async def _async_finish_flow(self, context, result):
async def _async_finish_flow(self, flow, result):
"""Finish a config flow and add an entry."""
# If no discovery config entries in progress, remove notification.
# Remove notification if no other discovery config entries in progress
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
in self.hass.config_entries.flow.async_progress()):
in self.hass.config_entries.flow.async_progress()
if ent['flow_id'] != flow.flow_id):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID)
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
return result
entry = ConfigEntry(
version=result['version'],
domain=result['handler'],
title=result['title'],
data=result['data'],
source=context['source'],
source=flow.context['source'],
)
self._entries.append(entry)
self._async_schedule_save()
@ -402,11 +403,8 @@ class ConfigEntries:
await async_setup_component(
self.hass, entry.domain, self._hass_config)
# Return Entry if they not from a discovery request
if context['source'] not in DISCOVERY_SOURCES:
return entry
return entry
result['result'] = entry
return result
async def _async_create_flow(self, handler_key, *, context, data):
"""Create a flow for specified handler.

View file

@ -64,7 +64,7 @@ class FlowManager:
return await self._async_handle_step(flow, flow.init_step, data)
async def async_configure(
self, flow_id: str, user_input: Optional[str] = None) -> Any:
self, flow_id: str, user_input: Optional[Dict] = None) -> Any:
"""Continue a configuration flow."""
flow = self._progress.get(flow_id)
@ -86,7 +86,7 @@ class FlowManager:
raise UnknownFlow
async def _async_handle_step(self, flow: Any, step_id: str,
user_input: Optional[str]) -> Dict:
user_input: Optional[Dict]) -> Dict:
"""Handle a step of a flow."""
method = "async_step_{}".format(step_id)
@ -106,14 +106,17 @@ class FlowManager:
flow.cur_step = (result['step_id'], result['data_schema'])
return result
# We pass a copy of the result because we're mutating our version
result = await self._async_finish_flow(flow, dict(result))
# _async_finish_flow may change result type, check it again
if result['type'] == RESULT_TYPE_FORM:
flow.cur_step = (result['step_id'], result['data_schema'])
return result
# Abort and Success results both finish the flow
self._progress.pop(flow.flow_id)
# We pass a copy of the result because we're mutating our version
entry = await self._async_finish_flow(flow.context, dict(result))
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
result['result'] = entry
return result

View file

@ -25,11 +25,12 @@ def manager():
if context is not None else 'user_input'
return flow
async def async_add_entry(context, result):
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
result['source'] = context.get('source') \
if context is not None else 'user'
async def async_add_entry(flow, result):
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result['source'] = flow.context.get('source') \
if flow.context is not None else 'user'
entries.append(result)
return result
manager = data_entry_flow.FlowManager(
None, async_create_flow, async_add_entry)
@ -198,3 +199,49 @@ async def test_discovery_init_flow(manager):
assert entry['title'] == 'hello'
assert entry['data'] == data
assert entry['source'] == 'discovery'
async def test_finish_callback_change_result_type(hass):
"""Test finish callback can change result type."""
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, input):
"""Return init form with one input field 'count'."""
if input is not None:
return self.async_create_entry(title='init', data=input)
return self.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
async def async_create_flow(handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
async def async_finish_flow(flow, result):
"""Redirect to init form if count <= 1."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if (result['data'] is None or
result['data'].get('count', 0) <= 1):
return flow.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
else:
result['result'] = result['data']['count']
return result
manager = data_entry_flow.FlowManager(
hass, async_create_flow, async_finish_flow)
result = await manager.async_init('test')
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
result = await manager.async_configure(result['flow_id'], {'count': 0})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
assert 'result' not in result
result = await manager.async_configure(result['flow_id'], {'count': 2})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'] == 2