Refactor data entry flow (#15883)

* Refactoring data_entry_flow and config_entry_flow

Move SOURCE_* to config_entries
Change data_entry_flow.FlowManager.async_init() source param default
 to None
Change this first step_id as source or init if source is None
_BaseFlowManagerView pass in SOURCE_USER as default source

* First step of data entry flow decided by _async_create_flow() now

* Lint

* Change helpers.config_entry_flow.DiscoveryFlowHandler default step

* Change FlowManager.async_init source param to context dict param
This commit is contained in:
Jason Hu 2018-08-09 04:24:14 -07:00 committed by Paulus Schoutsen
parent 39d19f2183
commit f58425dd3c
25 changed files with 128 additions and 79 deletions

View file

@ -211,7 +211,7 @@ class AuthManager:
return tkn
async def _async_create_login_flow(self, handler, *, source, data):
async def _async_create_login_flow(self, handler, *, context, data):
"""Create a login flow."""
auth_provider = self._providers[handler]

View file

@ -1,5 +1,5 @@
"""Component to embed Google Cast."""
from homeassistant import data_entry_flow
from homeassistant import config_entries
from homeassistant.helpers import config_entry_flow
@ -15,7 +15,7 @@ async def async_setup(hass, config):
if conf is not None:
hass.async_create_task(hass.config_entries.flow.async_init(
DOMAIN, source=data_entry_flow.SOURCE_IMPORT))
DOMAIN, context={'source': config_entries.SOURCE_IMPORT}))
return True

View file

@ -96,7 +96,7 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
return self.json([
flw for flw in hass.config_entries.flow.async_progress()
if flw['source'] != data_entry_flow.SOURCE_USER])
if flw['source'] != config_entries.SOURCE_USER])
class ConfigManagerFlowResourceView(FlowManagerResourceView):

View file

@ -6,6 +6,7 @@ https://home-assistant.io/components/deconz/
"""
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import (
CONF_API_KEY, CONF_EVENT, CONF_HOST,
CONF_ID, CONF_PORT, EVENT_HOMEASSISTANT_STOP)
@ -60,7 +61,9 @@ async def async_setup(hass, config):
deconz_config = config[DOMAIN]
if deconz_config and not configured_hosts(hass):
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source='import', data=deconz_config
DOMAIN,
context={'source': config_entries.SOURCE_IMPORT},
data=deconz_config
))
return True

View file

@ -33,6 +33,10 @@ class DeconzFlowHandler(data_entry_flow.FlowHandler):
self.bridges = []
self.deconz_config = {}
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_init(user_input)
async def async_step_init(self, user_input=None):
"""Handle a deCONZ config flow start.

View file

@ -13,7 +13,7 @@ import os
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
import homeassistant.helpers.config_validation as cv
@ -138,7 +138,7 @@ async def async_setup(hass, config):
if service in CONFIG_ENTRY_HANDLERS:
await hass.config_entries.flow.async_init(
CONFIG_ENTRY_HANDLERS[service],
source=data_entry_flow.SOURCE_DISCOVERY,
context={'source': config_entries.SOURCE_DISCOVERY},
data=info
)
return

View file

@ -10,6 +10,7 @@ import logging
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant import config_entries
from .const import (
DOMAIN, HMIPC_HAPID, HMIPC_AUTHTOKEN, HMIPC_NAME,
@ -41,7 +42,8 @@ async def async_setup(hass, config):
for conf in accesspoints:
if conf[CONF_ACCESSPOINT] not in configured_haps(hass):
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source='import', data={
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data={
HMIPC_HAPID: conf[CONF_ACCESSPOINT],
HMIPC_AUTHTOKEN: conf[CONF_AUTHTOKEN],
HMIPC_NAME: conf[CONF_NAME],

View file

@ -27,6 +27,10 @@ class HomematicipCloudFlowHandler(data_entry_flow.FlowHandler):
"""Initialize HomematicIP Cloud config flow."""
self.auth = None
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_init(user_input)
async def async_step_init(self, user_input=None):
"""Handle a flow start."""
errors = {}

View file

@ -9,7 +9,7 @@ import logging
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant import config_entries
from homeassistant.const import CONF_FILENAME, CONF_HOST
from homeassistant.helpers import aiohttp_client, config_validation as cv
@ -108,7 +108,8 @@ async def async_setup(hass, config):
# deadlock: creating a config entry will set up the component but the
# setup would block till the entry is created!
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source=data_entry_flow.SOURCE_IMPORT, data={
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data={
'host': bridge_conf[CONF_HOST],
'path': bridge_conf[CONF_FILENAME],
}

View file

@ -51,7 +51,8 @@ class HueBridge:
# linking procedure. When linking succeeds, it will remove the
# old config entry.
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source='import', data={
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data={
'host': host,
}
))

View file

@ -50,6 +50,10 @@ class HueFlowHandler(data_entry_flow.FlowHandler):
"""Initialize the Hue flow."""
self.host = None
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_init(user_input)
async def async_step_init(self, user_input=None):
"""Handle a flow start."""
from aiohue.discovery import discover_nupnp

View file

@ -11,6 +11,7 @@ from datetime import datetime, timedelta
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import (
CONF_STRUCTURE, CONF_FILENAME, CONF_BINARY_SENSORS, CONF_SENSORS,
CONF_MONITORED_CONDITIONS,
@ -103,7 +104,8 @@ async def async_setup(hass, config):
access_token_cache_file = hass.config.path(filename)
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source='import', data={
DOMAIN, context={'source': config_entries.SOURCE_IMPORT},
data={
'nest_conf_path': access_token_cache_file,
}
))

View file

@ -58,6 +58,10 @@ class NestFlowHandler(data_entry_flow.FlowHandler):
"""Initialize the Nest config flow."""
self.flow_impl = None
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_init(user_input)
async def async_step_init(self, user_input=None):
"""Handle a flow start."""
flows = self.hass.data.get(DATA_FLOW_IMPL, {})

View file

@ -1,5 +1,5 @@
"""Component to embed Sonos."""
from homeassistant import data_entry_flow
from homeassistant import config_entries
from homeassistant.helpers import config_entry_flow
@ -15,7 +15,7 @@ async def async_setup(hass, config):
if conf is not None:
hass.async_create_task(hass.config_entries.flow.async_init(
DOMAIN, source=data_entry_flow.SOURCE_IMPORT))
DOMAIN, context={'source': config_entries.SOURCE_IMPORT}))
return True

View file

@ -29,6 +29,10 @@ class ZoneFlowHandler(data_entry_flow.FlowHandler):
"""Initialize zone configuration flow."""
pass
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
return await self.async_step_init(user_input)
async def async_step_init(self, user_input=None):
"""Handle a flow start."""
errors = {}

View file

@ -24,20 +24,24 @@ Before instantiating the handler, Home Assistant will make sure to load all
dependencies and install the requirements of the component.
At a minimum, each config flow will have to define a version number and the
'init' step.
'user' step.
@config_entries.HANDLERS.register(DOMAIN)
class ExampleConfigFlow(config_entries.FlowHandler):
class ExampleConfigFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, user_input=None):
async def async_step_user(self, user_input=None):
The 'init' step is the first step of a flow and is called when a user
The 'user' step is the first step of a flow and is called when a user
starts a new flow. Each step has three different possible results: "Show Form",
"Abort" and "Create Entry".
> Note: prior 0.76, the default step is 'init' step, some config flows still
keep 'init' step to avoid break localization. All new config flow should use
'user' step.
### Show Form
This will show a form to the user to fill in. You define the current step,
@ -50,7 +54,7 @@ a title, a description and the schema of the data that needs to be returned.
data_schema[vol.Required('password')] = str
return self.async_show_form(
step_id='init',
step_id='user',
title='Account Info',
data_schema=vol.Schema(data_schema)
)
@ -97,10 +101,10 @@ Assistant, a success message is shown to the user and the flow is finished.
You might want to initialize a config flow programmatically. For example, if
we discover a device on the network that requires user interaction to finish
setup. To do so, pass a source parameter and optional user input to the init
step:
method:
await hass.config_entries.flow.async_init(
'hue', source='discovery', data=discovery_info)
'hue', context={'source': 'discovery'}, data=discovery_info)
The config flow handler will need to add a step to support the source. The step
should follow the same return values as a normal step.
@ -123,6 +127,11 @@ from homeassistant.util.decorator import Registry
_LOGGER = logging.getLogger(__name__)
SOURCE_USER = 'user'
SOURCE_DISCOVERY = 'discovery'
SOURCE_IMPORT = 'import'
HANDLERS = Registry()
# Components that have config flows. In future we will auto-generate this list.
FLOWS = [
@ -151,8 +160,8 @@ ENTRY_STATE_FAILED_UNLOAD = 'failed_unload'
DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery'
DISCOVERY_SOURCES = (
data_entry_flow.SOURCE_DISCOVERY,
data_entry_flow.SOURCE_IMPORT,
SOURCE_DISCOVERY,
SOURCE_IMPORT,
)
EVENT_FLOW_DISCOVERED = 'config_entry_discovered'
@ -374,12 +383,15 @@ class ConfigEntries:
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
source = result['source']
if source is None:
source = SOURCE_USER
entry = ConfigEntry(
version=result['version'],
domain=result['handler'],
title=result['title'],
data=result['data'],
source=result['source'],
source=source,
)
self._entries.append(entry)
await self._async_schedule_save()
@ -399,17 +411,22 @@ class ConfigEntries:
return entry
async def _async_create_flow(self, handler, *, source, data):
async def _async_create_flow(self, handler_key, *, context, data):
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to setup.
"""
component = getattr(self.hass.components, handler)
handler = HANDLERS.get(handler)
component = getattr(self.hass.components, handler_key)
handler = HANDLERS.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
if context is not None:
source = context.get('source', SOURCE_USER)
else:
source = SOURCE_USER
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
self.hass, self._hass_config, handler, component)
@ -424,7 +441,10 @@ class ConfigEntries:
notification_id=DISCOVERY_NOTIFICATION_ID
)
return handler()
flow = handler()
flow.source = source
flow.init_step = source
return flow
async def _async_schedule_save(self):
"""Save the entity registry to a file."""

View file

@ -8,10 +8,6 @@ from .exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__)
SOURCE_USER = 'user'
SOURCE_DISCOVERY = 'discovery'
SOURCE_IMPORT = 'import'
RESULT_TYPE_FORM = 'form'
RESULT_TYPE_CREATE_ENTRY = 'create_entry'
RESULT_TYPE_ABORT = 'abort'
@ -53,22 +49,17 @@ class FlowManager:
'source': flow.source,
} for flow in self._progress.values()]
async def async_init(self, handler: Callable, *, source: str = SOURCE_USER,
data: str = None) -> Any:
async def async_init(self, handler: Callable, *, context: Dict = None,
data: Any = None) -> Any:
"""Start a configuration flow."""
flow = await self._async_create_flow(handler, source=source, data=data)
flow = await self._async_create_flow(
handler, context=context, data=data)
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.source = source
self._progress[flow.flow_id] = flow
if source == SOURCE_USER:
step = 'init'
else:
step = source
return await self._async_handle_step(flow, step, data)
return await self._async_handle_step(flow, flow.init_step, data)
async def async_configure(
self, flow_id: str, user_input: str = None) -> Any:
@ -131,9 +122,12 @@ class FlowHandler:
flow_id = None
hass = None
handler = None
source = SOURCE_USER
source = None
cur_step = None
# Set by _async_create_flow callback
init_step = 'init'
# Set by developer
VERSION = 1

View file

@ -22,7 +22,7 @@ class DiscoveryFlowHandler(data_entry_flow.FlowHandler):
self._title = title
self._discovery_function = discovery_function
async def async_step_init(self, user_input=None):
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
if self._async_current_entries():
return self.async_abort(

View file

@ -2,7 +2,7 @@
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant import data_entry_flow, config_entries
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
@ -53,7 +53,8 @@ class FlowManagerIndexView(_BaseFlowManagerView):
handler = data['handler']
try:
result = await self._flow_mgr.async_init(handler)
result = await self._flow_mgr.async_init(
handler, context={'source': config_entries.SOURCE_USER})
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:

View file

@ -12,7 +12,7 @@ import logging
import threading
from contextlib import contextmanager
from homeassistant import auth, core as ha, data_entry_flow, config_entries
from homeassistant import auth, core as ha, config_entries
from homeassistant.auth import (
models as auth_models, auth_store, providers as auth_providers)
from homeassistant.setup import setup_component, async_setup_component
@ -509,7 +509,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
"""Helper for creating config entries that adds some defaults."""
def __init__(self, *, domain='test', data=None, version=0, entry_id=None,
source=data_entry_flow.SOURCE_USER, title='Mock Title',
source=config_entries.SOURCE_USER, title='Mock Title',
state=None):
"""Initialize a mock config entry."""
kwargs = {

View file

@ -102,13 +102,13 @@ def test_initialize_flow(hass, client):
"""Test we can initialize a flow."""
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
schema = OrderedDict()
schema[vol.Required('username')] = str
schema[vol.Required('password')] = str
return self.async_show_form(
step_id='init',
step_id='user',
data_schema=schema,
description_placeholders={
'url': 'https://example.com',
@ -130,7 +130,7 @@ def test_initialize_flow(hass, client):
assert data == {
'type': 'form',
'handler': 'test',
'step_id': 'init',
'step_id': 'user',
'data_schema': [
{
'name': 'username',
@ -157,7 +157,7 @@ def test_abort(hass, client):
"""Test a flow that aborts."""
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_abort(reason='bla')
with patch.dict(HANDLERS, {'test': TestFlow}):
@ -185,7 +185,7 @@ def test_create_account(hass, client):
VERSION = 1
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_create_entry(
title='Test Entry',
data={'secret': 'account_token'}
@ -218,7 +218,7 @@ def test_two_step_flow(hass, client):
VERSION = 1
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_show_form(
step_id='account',
data_schema=vol.Schema({
@ -286,7 +286,7 @@ def test_get_progress_index(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
form = yield from hass.config_entries.flow.async_init(
'test', source='hassio')
'test', context={'source': 'hassio'})
resp = yield from client.get('/api/config/config_entries/flow')
assert resp.status == 200
@ -305,13 +305,13 @@ def test_get_progress_flow(hass, client):
"""Test we can query the API for same result as we get from init a flow."""
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
schema = OrderedDict()
schema[vol.Required('username')] = str
schema[vol.Required('password')] = str
return self.async_show_form(
step_id='init',
step_id='user',
data_schema=schema,
errors={
'username': 'Should be unique.'

View file

@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock
import pytest
from homeassistant import data_entry_flow
from homeassistant import config_entries
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import discovery
from homeassistant.util.dt import utcnow
@ -175,5 +175,5 @@ async def test_discover_config_flow(hass):
assert len(m_init.mock_calls) == 1
args, kwargs = m_init.mock_calls[0][1:]
assert args == ('mock-component',)
assert kwargs['source'] == data_entry_flow.SOURCE_DISCOVERY
assert kwargs['context']['source'] == config_entries.SOURCE_DISCOVERY
assert kwargs['data'] == discovery_info

View file

@ -31,7 +31,7 @@ async def test_single_entry_allowed(hass, flow_conf):
flow.hass = hass
MockConfigEntry(domain='test').add_to_hass(hass)
result = await flow.async_step_init()
result = await flow.async_step_user()
assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT
assert result['reason'] == 'single_instance_allowed'
@ -42,7 +42,7 @@ async def test_user_no_devices_found(hass, flow_conf):
flow = config_entries.HANDLERS['test']()
flow.hass = hass
result = await flow.async_step_init()
result = await flow.async_step_user()
assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT
assert result['reason'] == 'no_devices_found'
@ -54,7 +54,7 @@ async def test_user_no_confirmation(hass, flow_conf):
flow.hass = hass
flow_conf['discovered'] = True
result = await flow.async_step_init()
result = await flow.async_step_user()
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
@ -90,12 +90,12 @@ async def test_multiple_discoveries(hass, flow_conf):
loader.set_component(hass, 'test', MockModule('test'))
result = await hass.config_entries.flow.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY, data={})
'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
# Second discovery
result = await hass.config_entries.flow.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY, data={})
'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={})
assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT
@ -105,7 +105,7 @@ async def test_user_init_trumps_discovery(hass, flow_conf):
# Discovery starts flow
result = await hass.config_entries.flow.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY, data={})
'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
# User starts flow

View file

@ -108,7 +108,7 @@ def test_add_entry_calls_setup_entry(hass, manager):
VERSION = 1
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_create_entry(
title='title',
data={
@ -162,7 +162,7 @@ async def test_saving_and_loading(hass):
VERSION = 5
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_create_entry(
title='Test Title',
data={
@ -177,7 +177,7 @@ async def test_saving_and_loading(hass):
VERSION = 3
@asyncio.coroutine
def async_step_init(self, user_input=None):
def async_step_user(self, user_input=None):
return self.async_create_entry(
title='Test 2 Title',
data={
@ -266,7 +266,7 @@ async def test_discovery_notification(hass):
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
result = await hass.config_entries.flow.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY)
'test', context={'source': config_entries.SOURCE_DISCOVERY})
await hass.async_block_till_done()
state = hass.states.get('persistent_notification.config_entry_discovery')
@ -294,7 +294,7 @@ async def test_discovery_notification_not_created(hass):
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
await hass.config_entries.flow.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY)
'test', context={'source': config_entries.SOURCE_DISCOVERY})
await hass.async_block_till_done()
state = hass.states.get('persistent_notification.config_entry_discovery')

View file

@ -12,13 +12,18 @@ def manager():
handlers = Registry()
entries = []
async def async_create_flow(handler_name, *, source, data):
async def async_create_flow(handler_name, *, context, data):
handler = handlers.get(handler_name)
if handler is None:
raise data_entry_flow.UnknownHandler
return handler()
flow = handler()
flow.init_step = context.get('init_step', 'init') \
if context is not None else 'init'
flow.source = context.get('source') \
if context is not None else 'user_input'
return flow
async def async_add_entry(result):
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
@ -57,12 +62,12 @@ async def test_configure_two_steps(manager):
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, user_input=None):
async def async_step_first(self, user_input=None):
if user_input is not None:
self.init_data = user_input
return await self.async_step_second()
return self.async_show_form(
step_id='init',
step_id='first',
data_schema=vol.Schema([str])
)
@ -77,7 +82,7 @@ async def test_configure_two_steps(manager):
data_schema=vol.Schema([str])
)
form = await manager.async_init('test')
form = await manager.async_init('test', context={'init_step': 'first'})
with pytest.raises(vol.Invalid):
form = await manager.async_configure(
@ -163,7 +168,7 @@ async def test_create_saves_data(manager):
assert entry['handler'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == data_entry_flow.SOURCE_USER
assert entry['source'] == 'user_input'
async def test_discovery_init_flow(manager):
@ -172,7 +177,7 @@ async def test_discovery_init_flow(manager):
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_discovery(self, info):
async def async_step_init(self, info):
return self.async_create_entry(title=info['id'], data=info)
data = {
@ -181,7 +186,7 @@ async def test_discovery_init_flow(manager):
}
await manager.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY, data=data)
'test', context={'source': 'discovery'}, data=data)
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
@ -190,4 +195,4 @@ async def test_discovery_init_flow(manager):
assert entry['handler'] == 'test'
assert entry['title'] == 'hello'
assert entry['data'] == data
assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY
assert entry['source'] == 'discovery'