Add data entry flow helper (#13935)
* Extract data entry flows HTTP views into helper * Remove use of domain * Lint * Fix tests * Update doc
This commit is contained in:
parent
6e9669c18d
commit
534aa0e4b5
5 changed files with 132 additions and 84 deletions
|
@ -1,11 +1,10 @@
|
||||||
"""Http views to control the config manager."""
|
"""Http views to control the config manager."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import voluptuous as vol
|
|
||||||
|
|
||||||
from homeassistant import config_entries, data_entry_flow
|
from homeassistant import config_entries, data_entry_flow
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.helpers.data_entry_flow import (
|
||||||
|
FlowManagerIndexView, FlowManagerResourceView)
|
||||||
|
|
||||||
|
|
||||||
REQUIREMENTS = ['voluptuous-serialize==1']
|
REQUIREMENTS = ['voluptuous-serialize==1']
|
||||||
|
@ -16,8 +15,10 @@ def async_setup(hass):
|
||||||
"""Enable the Home Assistant views."""
|
"""Enable the Home Assistant views."""
|
||||||
hass.http.register_view(ConfigManagerEntryIndexView)
|
hass.http.register_view(ConfigManagerEntryIndexView)
|
||||||
hass.http.register_view(ConfigManagerEntryResourceView)
|
hass.http.register_view(ConfigManagerEntryResourceView)
|
||||||
hass.http.register_view(ConfigManagerFlowIndexView)
|
hass.http.register_view(
|
||||||
hass.http.register_view(ConfigManagerFlowResourceView)
|
ConfigManagerFlowIndexView(hass.config_entries.flow))
|
||||||
|
hass.http.register_view(
|
||||||
|
ConfigManagerFlowResourceView(hass.config_entries.flow))
|
||||||
hass.http.register_view(ConfigManagerAvailableFlowView)
|
hass.http.register_view(ConfigManagerAvailableFlowView)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -78,7 +79,7 @@ class ConfigManagerEntryResourceView(HomeAssistantView):
|
||||||
return self.json(result)
|
return self.json(result)
|
||||||
|
|
||||||
|
|
||||||
class ConfigManagerFlowIndexView(HomeAssistantView):
|
class ConfigManagerFlowIndexView(FlowManagerIndexView):
|
||||||
"""View to create config flows."""
|
"""View to create config flows."""
|
||||||
|
|
||||||
url = '/api/config/config_entries/flow'
|
url = '/api/config/config_entries/flow'
|
||||||
|
@ -97,78 +98,13 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
|
||||||
flw for flw in hass.config_entries.flow.async_progress()
|
flw for flw in hass.config_entries.flow.async_progress()
|
||||||
if flw['source'] != data_entry_flow.SOURCE_USER])
|
if flw['source'] != data_entry_flow.SOURCE_USER])
|
||||||
|
|
||||||
@RequestDataValidator(vol.Schema({
|
|
||||||
vol.Required('domain'): str,
|
|
||||||
}))
|
|
||||||
@asyncio.coroutine
|
|
||||||
def post(self, request, data):
|
|
||||||
"""Handle a POST request."""
|
|
||||||
hass = request.app['hass']
|
|
||||||
|
|
||||||
try:
|
class ConfigManagerFlowResourceView(FlowManagerResourceView):
|
||||||
result = yield from hass.config_entries.flow.async_init(
|
|
||||||
data['domain'])
|
|
||||||
except data_entry_flow.UnknownHandler:
|
|
||||||
return self.json_message('Invalid handler specified', 404)
|
|
||||||
except data_entry_flow.UnknownStep:
|
|
||||||
return self.json_message('Handler does not support init', 400)
|
|
||||||
|
|
||||||
result = _prepare_json(result)
|
|
||||||
|
|
||||||
return self.json(result)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigManagerFlowResourceView(HomeAssistantView):
|
|
||||||
"""View to interact with the flow manager."""
|
"""View to interact with the flow manager."""
|
||||||
|
|
||||||
url = '/api/config/config_entries/flow/{flow_id}'
|
url = '/api/config/config_entries/flow/{flow_id}'
|
||||||
name = 'api:config:config_entries:flow:resource'
|
name = 'api:config:config_entries:flow:resource'
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def get(self, request, flow_id):
|
|
||||||
"""Get the current state of a data_entry_flow."""
|
|
||||||
hass = request.app['hass']
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = yield from hass.config_entries.flow.async_configure(
|
|
||||||
flow_id)
|
|
||||||
except data_entry_flow.UnknownFlow:
|
|
||||||
return self.json_message('Invalid flow specified', 404)
|
|
||||||
|
|
||||||
result = _prepare_json(result)
|
|
||||||
|
|
||||||
return self.json(result)
|
|
||||||
|
|
||||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
|
||||||
@asyncio.coroutine
|
|
||||||
def post(self, request, flow_id, data):
|
|
||||||
"""Handle a POST request."""
|
|
||||||
hass = request.app['hass']
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = yield from hass.config_entries.flow.async_configure(
|
|
||||||
flow_id, data)
|
|
||||||
except data_entry_flow.UnknownFlow:
|
|
||||||
return self.json_message('Invalid flow specified', 404)
|
|
||||||
except vol.Invalid:
|
|
||||||
return self.json_message('User input malformed', 400)
|
|
||||||
|
|
||||||
result = _prepare_json(result)
|
|
||||||
|
|
||||||
return self.json(result)
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def delete(self, request, flow_id):
|
|
||||||
"""Cancel a flow in progress."""
|
|
||||||
hass = request.app['hass']
|
|
||||||
|
|
||||||
try:
|
|
||||||
hass.config_entries.flow.async_abort(flow_id)
|
|
||||||
except data_entry_flow.UnknownFlow:
|
|
||||||
return self.json_message('Invalid flow specified', 404)
|
|
||||||
|
|
||||||
return self.json_message('Flow aborted')
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigManagerAvailableFlowView(HomeAssistantView):
|
class ConfigManagerAvailableFlowView(HomeAssistantView):
|
||||||
"""View to query available flows."""
|
"""View to query available flows."""
|
||||||
|
|
|
@ -338,7 +338,7 @@ class ConfigEntries:
|
||||||
if component not in self.hass.config.components:
|
if component not in self.hass.config.components:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
await entry.async_unload(
|
return await entry.async_unload(
|
||||||
self.hass, component=getattr(self.hass.components, component))
|
self.hass, component=getattr(self.hass.components, component))
|
||||||
|
|
||||||
async def _async_save_entry(self, result):
|
async def _async_save_entry(self, result):
|
||||||
|
@ -362,6 +362,8 @@ class ConfigEntries:
|
||||||
await async_setup_component(
|
await async_setup_component(
|
||||||
self.hass, entry.domain, self._hass_config)
|
self.hass, entry.domain, self._hass_config)
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
async def _async_create_flow(self, handler):
|
async def _async_create_flow(self, handler):
|
||||||
"""Create a flow for specified handler.
|
"""Create a flow for specified handler.
|
||||||
|
|
||||||
|
|
|
@ -34,12 +34,12 @@ class UnknownStep(FlowError):
|
||||||
class FlowManager:
|
class FlowManager:
|
||||||
"""Manage all the flows that are in progress."""
|
"""Manage all the flows that are in progress."""
|
||||||
|
|
||||||
def __init__(self, hass, async_create_flow, async_save_entry):
|
def __init__(self, hass, async_create_flow, async_finish_flow):
|
||||||
"""Initialize the flow manager."""
|
"""Initialize the flow manager."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._progress = {}
|
self._progress = {}
|
||||||
self._async_create_flow = async_create_flow
|
self._async_create_flow = async_create_flow
|
||||||
self._async_save_entry = async_save_entry
|
self._async_finish_flow = async_finish_flow
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_progress(self):
|
def async_progress(self):
|
||||||
|
@ -113,10 +113,8 @@ class FlowManager:
|
||||||
if result['type'] == RESULT_TYPE_ABORT:
|
if result['type'] == RESULT_TYPE_ABORT:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# We pass a copy of the result because we're going to mutate our
|
# We pass a copy of the result because we're mutating our version
|
||||||
# version afterwards and don't want to cause unexpected bugs.
|
result['result'] = await self._async_finish_flow(dict(result))
|
||||||
await self._async_save_entry(dict(result))
|
|
||||||
result.pop('data')
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
106
homeassistant/helpers/data_entry_flow.py
Normal file
106
homeassistant/helpers/data_entry_flow.py
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
"""Helpers for the data entry flow."""
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_json(result):
|
||||||
|
"""Convert result for JSON."""
|
||||||
|
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
|
data = result.copy()
|
||||||
|
data.pop('result')
|
||||||
|
data.pop('data')
|
||||||
|
return data
|
||||||
|
|
||||||
|
elif result['type'] != data_entry_flow.RESULT_TYPE_FORM:
|
||||||
|
return result
|
||||||
|
|
||||||
|
import voluptuous_serialize
|
||||||
|
|
||||||
|
data = result.copy()
|
||||||
|
|
||||||
|
schema = data['data_schema']
|
||||||
|
if schema is None:
|
||||||
|
data['data_schema'] = []
|
||||||
|
else:
|
||||||
|
data['data_schema'] = voluptuous_serialize.convert(schema)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class FlowManagerIndexView(HomeAssistantView):
|
||||||
|
"""View to create config flows."""
|
||||||
|
|
||||||
|
def __init__(self, flow_mgr):
|
||||||
|
"""Initialize the flow manager index view."""
|
||||||
|
self._flow_mgr = flow_mgr
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
"""List flows that are in progress."""
|
||||||
|
return self.json(self._flow_mgr.async_progress())
|
||||||
|
|
||||||
|
@RequestDataValidator(vol.Schema({
|
||||||
|
vol.Required('handler'): vol.Any(str, list),
|
||||||
|
}))
|
||||||
|
async def post(self, request, data):
|
||||||
|
"""Handle a POST request."""
|
||||||
|
if isinstance(data['handler'], list):
|
||||||
|
handler = tuple(data['handler'])
|
||||||
|
else:
|
||||||
|
handler = data['handler']
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._flow_mgr.async_init(handler)
|
||||||
|
except data_entry_flow.UnknownHandler:
|
||||||
|
return self.json_message('Invalid handler specified', 404)
|
||||||
|
except data_entry_flow.UnknownStep:
|
||||||
|
return self.json_message('Handler does not support init', 400)
|
||||||
|
|
||||||
|
result = _prepare_json(result)
|
||||||
|
|
||||||
|
return self.json(result)
|
||||||
|
|
||||||
|
|
||||||
|
class FlowManagerResourceView(HomeAssistantView):
|
||||||
|
"""View to interact with the flow manager."""
|
||||||
|
|
||||||
|
def __init__(self, flow_mgr):
|
||||||
|
"""Initialize the flow manager resource view."""
|
||||||
|
self._flow_mgr = flow_mgr
|
||||||
|
|
||||||
|
async def get(self, request, flow_id):
|
||||||
|
"""Get the current state of a data_entry_flow."""
|
||||||
|
try:
|
||||||
|
result = await self._flow_mgr.async_configure(flow_id)
|
||||||
|
except data_entry_flow.UnknownFlow:
|
||||||
|
return self.json_message('Invalid flow specified', 404)
|
||||||
|
|
||||||
|
result = _prepare_json(result)
|
||||||
|
|
||||||
|
return self.json(result)
|
||||||
|
|
||||||
|
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||||
|
async def post(self, request, flow_id, data):
|
||||||
|
"""Handle a POST request."""
|
||||||
|
try:
|
||||||
|
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||||
|
except data_entry_flow.UnknownFlow:
|
||||||
|
return self.json_message('Invalid flow specified', 404)
|
||||||
|
except vol.Invalid:
|
||||||
|
return self.json_message('User input malformed', 400)
|
||||||
|
|
||||||
|
result = _prepare_json(result)
|
||||||
|
|
||||||
|
return self.json(result)
|
||||||
|
|
||||||
|
async def delete(self, request, flow_id):
|
||||||
|
"""Cancel a flow in progress."""
|
||||||
|
try:
|
||||||
|
self._flow_mgr.async_abort(flow_id)
|
||||||
|
except data_entry_flow.UnknownFlow:
|
||||||
|
return self.json_message('Invalid flow specified', 404)
|
||||||
|
|
||||||
|
return self.json_message('Flow aborted')
|
|
@ -17,6 +17,12 @@ from homeassistant.loader import set_component
|
||||||
from tests.common import MockConfigEntry, MockModule, mock_coro_func
|
from tests.common import MockConfigEntry, MockModule, mock_coro_func
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='session', autouse=True)
|
||||||
|
def mock_test_component():
|
||||||
|
"""Ensure a component called 'test' exists."""
|
||||||
|
set_component('test', MockModule('test'))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(hass, aiohttp_client):
|
def client(hass, aiohttp_client):
|
||||||
"""Fixture that can interact with the config manager API."""
|
"""Fixture that can interact with the config manager API."""
|
||||||
|
@ -111,7 +117,7 @@ def test_initialize_flow(hass, client):
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {'test': TestFlow}):
|
with patch.dict(HANDLERS, {'test': TestFlow}):
|
||||||
resp = yield from client.post('/api/config/config_entries/flow',
|
resp = yield from client.post('/api/config/config_entries/flow',
|
||||||
json={'domain': 'test'})
|
json={'handler': 'test'})
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = yield from resp.json()
|
data = yield from resp.json()
|
||||||
|
@ -150,7 +156,7 @@ def test_abort(hass, client):
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {'test': TestFlow}):
|
with patch.dict(HANDLERS, {'test': TestFlow}):
|
||||||
resp = yield from client.post('/api/config/config_entries/flow',
|
resp = yield from client.post('/api/config/config_entries/flow',
|
||||||
json={'domain': 'test'})
|
json={'handler': 'test'})
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = yield from resp.json()
|
data = yield from resp.json()
|
||||||
|
@ -180,7 +186,7 @@ def test_create_account(hass, client):
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {'test': TestFlow}):
|
with patch.dict(HANDLERS, {'test': TestFlow}):
|
||||||
resp = yield from client.post('/api/config/config_entries/flow',
|
resp = yield from client.post('/api/config/config_entries/flow',
|
||||||
json={'domain': 'test'})
|
json={'handler': 'test'})
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = yield from resp.json()
|
data = yield from resp.json()
|
||||||
|
@ -220,7 +226,7 @@ def test_two_step_flow(hass, client):
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {'test': TestFlow}):
|
with patch.dict(HANDLERS, {'test': TestFlow}):
|
||||||
resp = yield from client.post('/api/config/config_entries/flow',
|
resp = yield from client.post('/api/config/config_entries/flow',
|
||||||
json={'domain': 'test'})
|
json={'handler': 'test'})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = yield from resp.json()
|
data = yield from resp.json()
|
||||||
flow_id = data.pop('flow_id')
|
flow_id = data.pop('flow_id')
|
||||||
|
@ -305,7 +311,7 @@ def test_get_progress_flow(hass, client):
|
||||||
|
|
||||||
with patch.dict(HANDLERS, {'test': TestFlow}):
|
with patch.dict(HANDLERS, {'test': TestFlow}):
|
||||||
resp = yield from client.post('/api/config/config_entries/flow',
|
resp = yield from client.post('/api/config/config_entries/flow',
|
||||||
json={'domain': 'test'})
|
json={'handler': 'test'})
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = yield from resp.json()
|
data = yield from resp.json()
|
||||||
|
|
Loading…
Add table
Reference in a new issue