Allow finish_flow callback to change data entry result type (#16100)
* Allow finish_flow callback to change data entry result type * Add unit test
This commit is contained in:
parent
b26506ad4a
commit
00c6f56cc8
4 changed files with 80 additions and 29 deletions
|
@ -2,7 +2,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
|
@ -256,20 +256,23 @@ class AuthManager:
|
||||||
return await auth_provider.async_credential_flow(context)
|
return await auth_provider.async_credential_flow(context)
|
||||||
|
|
||||||
async def _async_finish_login_flow(
|
async def _async_finish_login_flow(
|
||||||
self, context: Optional[Dict], result: Dict[str, Any]) \
|
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \
|
||||||
-> Optional[Union[models.User, models.Credentials]]:
|
-> Dict[str, Any]:
|
||||||
"""Return a user as result of login flow."""
|
"""Return a user as result of login flow."""
|
||||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
return None
|
return result
|
||||||
|
|
||||||
auth_provider = self._providers[result['handler']]
|
auth_provider = self._providers[result['handler']]
|
||||||
cred = await auth_provider.async_get_or_create_credentials(
|
credentials = await auth_provider.async_get_or_create_credentials(
|
||||||
result['data'])
|
result['data'])
|
||||||
|
|
||||||
if context is not None and context.get('credential_only'):
|
if flow.context is not None and flow.context.get('credential_only'):
|
||||||
return cred
|
result['result'] = credentials
|
||||||
|
return result
|
||||||
|
|
||||||
return await self.async_get_or_create_user(cred)
|
user = await self.async_get_or_create_user(credentials)
|
||||||
|
result['result'] = user
|
||||||
|
return result
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_auth_provider(
|
def _async_get_auth_provider(
|
||||||
|
|
|
@ -372,23 +372,24 @@ class ConfigEntries:
|
||||||
return await entry.async_unload(
|
return await entry.async_unload(
|
||||||
self.hass, component=getattr(self.hass.components, component))
|
self.hass, component=getattr(self.hass.components, component))
|
||||||
|
|
||||||
async def _async_finish_flow(self, context, result):
|
async def _async_finish_flow(self, flow, result):
|
||||||
"""Finish a config flow and add an entry."""
|
"""Finish a config flow and add an entry."""
|
||||||
# If no discovery config entries in progress, remove notification.
|
# Remove notification if no other discovery config entries in progress
|
||||||
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
|
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
|
||||||
in self.hass.config_entries.flow.async_progress()):
|
in self.hass.config_entries.flow.async_progress()
|
||||||
|
if ent['flow_id'] != flow.flow_id):
|
||||||
self.hass.components.persistent_notification.async_dismiss(
|
self.hass.components.persistent_notification.async_dismiss(
|
||||||
DISCOVERY_NOTIFICATION_ID)
|
DISCOVERY_NOTIFICATION_ID)
|
||||||
|
|
||||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
return None
|
return result
|
||||||
|
|
||||||
entry = ConfigEntry(
|
entry = ConfigEntry(
|
||||||
version=result['version'],
|
version=result['version'],
|
||||||
domain=result['handler'],
|
domain=result['handler'],
|
||||||
title=result['title'],
|
title=result['title'],
|
||||||
data=result['data'],
|
data=result['data'],
|
||||||
source=context['source'],
|
source=flow.context['source'],
|
||||||
)
|
)
|
||||||
self._entries.append(entry)
|
self._entries.append(entry)
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
@ -402,11 +403,8 @@ class ConfigEntries:
|
||||||
await async_setup_component(
|
await async_setup_component(
|
||||||
self.hass, entry.domain, self._hass_config)
|
self.hass, entry.domain, self._hass_config)
|
||||||
|
|
||||||
# Return Entry if they not from a discovery request
|
result['result'] = entry
|
||||||
if context['source'] not in DISCOVERY_SOURCES:
|
return result
|
||||||
return entry
|
|
||||||
|
|
||||||
return entry
|
|
||||||
|
|
||||||
async def _async_create_flow(self, handler_key, *, context, data):
|
async def _async_create_flow(self, handler_key, *, context, data):
|
||||||
"""Create a flow for specified handler.
|
"""Create a flow for specified handler.
|
||||||
|
|
|
@ -64,7 +64,7 @@ class FlowManager:
|
||||||
return await self._async_handle_step(flow, flow.init_step, data)
|
return await self._async_handle_step(flow, flow.init_step, data)
|
||||||
|
|
||||||
async def async_configure(
|
async def async_configure(
|
||||||
self, flow_id: str, user_input: Optional[str] = None) -> Any:
|
self, flow_id: str, user_input: Optional[Dict] = None) -> Any:
|
||||||
"""Continue a configuration flow."""
|
"""Continue a configuration flow."""
|
||||||
flow = self._progress.get(flow_id)
|
flow = self._progress.get(flow_id)
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ class FlowManager:
|
||||||
raise UnknownFlow
|
raise UnknownFlow
|
||||||
|
|
||||||
async def _async_handle_step(self, flow: Any, step_id: str,
|
async def _async_handle_step(self, flow: Any, step_id: str,
|
||||||
user_input: Optional[str]) -> Dict:
|
user_input: Optional[Dict]) -> Dict:
|
||||||
"""Handle a step of a flow."""
|
"""Handle a step of a flow."""
|
||||||
method = "async_step_{}".format(step_id)
|
method = "async_step_{}".format(step_id)
|
||||||
|
|
||||||
|
@ -106,14 +106,17 @@ class FlowManager:
|
||||||
flow.cur_step = (result['step_id'], result['data_schema'])
|
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# We pass a copy of the result because we're mutating our version
|
||||||
|
result = await self._async_finish_flow(flow, dict(result))
|
||||||
|
|
||||||
|
# _async_finish_flow may change result type, check it again
|
||||||
|
if result['type'] == RESULT_TYPE_FORM:
|
||||||
|
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||||
|
return result
|
||||||
|
|
||||||
# Abort and Success results both finish the flow
|
# Abort and Success results both finish the flow
|
||||||
self._progress.pop(flow.flow_id)
|
self._progress.pop(flow.flow_id)
|
||||||
|
|
||||||
# We pass a copy of the result because we're mutating our version
|
|
||||||
entry = await self._async_finish_flow(flow.context, dict(result))
|
|
||||||
|
|
||||||
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
|
|
||||||
result['result'] = entry
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,11 +25,12 @@ def manager():
|
||||||
if context is not None else 'user_input'
|
if context is not None else 'user_input'
|
||||||
return flow
|
return flow
|
||||||
|
|
||||||
async def async_add_entry(context, result):
|
async def async_add_entry(flow, result):
|
||||||
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
|
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
result['source'] = context.get('source') \
|
result['source'] = flow.context.get('source') \
|
||||||
if context is not None else 'user'
|
if flow.context is not None else 'user'
|
||||||
entries.append(result)
|
entries.append(result)
|
||||||
|
return result
|
||||||
|
|
||||||
manager = data_entry_flow.FlowManager(
|
manager = data_entry_flow.FlowManager(
|
||||||
None, async_create_flow, async_add_entry)
|
None, async_create_flow, async_add_entry)
|
||||||
|
@ -198,3 +199,49 @@ async def test_discovery_init_flow(manager):
|
||||||
assert entry['title'] == 'hello'
|
assert entry['title'] == 'hello'
|
||||||
assert entry['data'] == data
|
assert entry['data'] == data
|
||||||
assert entry['source'] == 'discovery'
|
assert entry['source'] == 'discovery'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_finish_callback_change_result_type(hass):
|
||||||
|
"""Test finish callback can change result type."""
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
async def async_step_init(self, input):
|
||||||
|
"""Return init form with one input field 'count'."""
|
||||||
|
if input is not None:
|
||||||
|
return self.async_create_entry(title='init', data=input)
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id='init',
|
||||||
|
data_schema=vol.Schema({'count': int}))
|
||||||
|
|
||||||
|
async def async_create_flow(handler_name, *, context, data):
|
||||||
|
"""Create a test flow."""
|
||||||
|
return TestFlow()
|
||||||
|
|
||||||
|
async def async_finish_flow(flow, result):
|
||||||
|
"""Redirect to init form if count <= 1."""
|
||||||
|
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
|
if (result['data'] is None or
|
||||||
|
result['data'].get('count', 0) <= 1):
|
||||||
|
return flow.async_show_form(
|
||||||
|
step_id='init',
|
||||||
|
data_schema=vol.Schema({'count': int}))
|
||||||
|
else:
|
||||||
|
result['result'] = result['data']['count']
|
||||||
|
return result
|
||||||
|
|
||||||
|
manager = data_entry_flow.FlowManager(
|
||||||
|
hass, async_create_flow, async_finish_flow)
|
||||||
|
|
||||||
|
result = await manager.async_init('test')
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['step_id'] == 'init'
|
||||||
|
|
||||||
|
result = await manager.async_configure(result['flow_id'], {'count': 0})
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result['step_id'] == 'init'
|
||||||
|
assert 'result' not in result
|
||||||
|
|
||||||
|
result = await manager.async_configure(result['flow_id'], {'count': 2})
|
||||||
|
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
assert result['result'] == 2
|
||||||
|
|
Loading…
Add table
Reference in a new issue