Allow finish_flow callback to change data entry result type (#16100)

* Allow finish_flow callback to change data entry result type

* Add unit test
This commit is contained in:
Jason Hu 2018-08-21 10:48:24 -07:00 committed by GitHub
parent b26506ad4a
commit 00c6f56cc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 29 deletions

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast, Union from typing import Any, Dict, List, Optional, Tuple, cast
import jwt import jwt
@ -256,20 +256,23 @@ class AuthManager:
return await auth_provider.async_credential_flow(context) return await auth_provider.async_credential_flow(context)
async def _async_finish_login_flow( async def _async_finish_login_flow(
self, context: Optional[Dict], result: Dict[str, Any]) \ self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
-> Optional[Union[models.User, models.Credentials]]: -> Dict[str, Any]:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None return result
auth_provider = self._providers[result['handler']] auth_provider = self._providers[result['handler']]
cred = await auth_provider.async_get_or_create_credentials( credentials = await auth_provider.async_get_or_create_credentials(
result['data']) result['data'])
if context is not None and context.get('credential_only'): if flow.context is not None and flow.context.get('credential_only'):
return cred result['result'] = credentials
return result
return await self.async_get_or_create_user(cred) user = await self.async_get_or_create_user(credentials)
result['result'] = user
return result
@callback @callback
def _async_get_auth_provider( def _async_get_auth_provider(

View file

@ -372,23 +372,24 @@ class ConfigEntries:
return await entry.async_unload( return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component)) self.hass, component=getattr(self.hass.components, component))
async def _async_finish_flow(self, context, result): async def _async_finish_flow(self, flow, result):
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
# If no discovery config entries in progress, remove notification. # Remove notification if no other discovery config entries in progress
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
in self.hass.config_entries.flow.async_progress()): in self.hass.config_entries.flow.async_progress()
if ent['flow_id'] != flow.flow_id):
self.hass.components.persistent_notification.async_dismiss( self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID) DISCOVERY_NOTIFICATION_ID)
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None return result
entry = ConfigEntry( entry = ConfigEntry(
version=result['version'], version=result['version'],
domain=result['handler'], domain=result['handler'],
title=result['title'], title=result['title'],
data=result['data'], data=result['data'],
source=context['source'], source=flow.context['source'],
) )
self._entries.append(entry) self._entries.append(entry)
self._async_schedule_save() self._async_schedule_save()
@ -402,11 +403,8 @@ class ConfigEntries:
await async_setup_component( await async_setup_component(
self.hass, entry.domain, self._hass_config) self.hass, entry.domain, self._hass_config)
# Return Entry if they not from a discovery request result['result'] = entry
if context['source'] not in DISCOVERY_SOURCES: return result
return entry
return entry
async def _async_create_flow(self, handler_key, *, context, data): async def _async_create_flow(self, handler_key, *, context, data):
"""Create a flow for specified handler. """Create a flow for specified handler.

View file

@ -64,7 +64,7 @@ class FlowManager:
return await self._async_handle_step(flow, flow.init_step, data) return await self._async_handle_step(flow, flow.init_step, data)
async def async_configure( async def async_configure(
self, flow_id: str, user_input: Optional[str] = None) -> Any: self, flow_id: str, user_input: Optional[Dict] = None) -> Any:
"""Continue a configuration flow.""" """Continue a configuration flow."""
flow = self._progress.get(flow_id) flow = self._progress.get(flow_id)
@ -86,7 +86,7 @@ class FlowManager:
raise UnknownFlow raise UnknownFlow
async def _async_handle_step(self, flow: Any, step_id: str, async def _async_handle_step(self, flow: Any, step_id: str,
user_input: Optional[str]) -> Dict: user_input: Optional[Dict]) -> Dict:
"""Handle a step of a flow.""" """Handle a step of a flow."""
method = "async_step_{}".format(step_id) method = "async_step_{}".format(step_id)
@ -106,14 +106,17 @@ class FlowManager:
flow.cur_step = (result['step_id'], result['data_schema']) flow.cur_step = (result['step_id'], result['data_schema'])
return result return result
# We pass a copy of the result because we're mutating our version
result = await self._async_finish_flow(flow, dict(result))
# _async_finish_flow may change result type, check it again
if result['type'] == RESULT_TYPE_FORM:
flow.cur_step = (result['step_id'], result['data_schema'])
return result
# Abort and Success results both finish the flow # Abort and Success results both finish the flow
self._progress.pop(flow.flow_id) self._progress.pop(flow.flow_id)
# We pass a copy of the result because we're mutating our version
entry = await self._async_finish_flow(flow.context, dict(result))
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
result['result'] = entry
return result return result

View file

@ -25,11 +25,12 @@ def manager():
if context is not None else 'user_input' if context is not None else 'user_input'
return flow return flow
async def async_add_entry(context, result): async def async_add_entry(flow, result):
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY): if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result['source'] = context.get('source') \ result['source'] = flow.context.get('source') \
if context is not None else 'user' if flow.context is not None else 'user'
entries.append(result) entries.append(result)
return result
manager = data_entry_flow.FlowManager( manager = data_entry_flow.FlowManager(
None, async_create_flow, async_add_entry) None, async_create_flow, async_add_entry)
@ -198,3 +199,49 @@ async def test_discovery_init_flow(manager):
assert entry['title'] == 'hello' assert entry['title'] == 'hello'
assert entry['data'] == data assert entry['data'] == data
assert entry['source'] == 'discovery' assert entry['source'] == 'discovery'
async def test_finish_callback_change_result_type(hass):
"""Test finish callback can change result type."""
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, input):
"""Return init form with one input field 'count'."""
if input is not None:
return self.async_create_entry(title='init', data=input)
return self.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
async def async_create_flow(handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
async def async_finish_flow(flow, result):
"""Redirect to init form if count <= 1."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if (result['data'] is None or
result['data'].get('count', 0) <= 1):
return flow.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
else:
result['result'] = result['data']['count']
return result
manager = data_entry_flow.FlowManager(
hass, async_create_flow, async_finish_flow)
result = await manager.async_init('test')
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
result = await manager.async_configure(result['flow_id'], {'count': 0})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
assert 'result' not in result
result = await manager.async_configure(result['flow_id'], {'count': 2})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'] == 2