From 534aa0e4b54b992ac55de0ae576276e94089ce49 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 17 Apr 2018 05:44:32 -0400 Subject: [PATCH] Add data entry flow helper (#13935) * Extract data entry flows HTTP views into helper * Remove use of domain * Lint * Fix tests * Update doc --- .../components/config/config_entries.py | 80 ++----------- homeassistant/config_entries.py | 4 +- homeassistant/data_entry_flow.py | 10 +- homeassistant/helpers/data_entry_flow.py | 106 ++++++++++++++++++ .../components/config/test_config_entries.py | 16 ++- 5 files changed, 132 insertions(+), 84 deletions(-) create mode 100644 homeassistant/helpers/data_entry_flow.py diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 967317134c2..d2aa918eda2 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -1,11 +1,10 @@ """Http views to control the config manager.""" import asyncio -import voluptuous as vol - from homeassistant import config_entries, data_entry_flow from homeassistant.components.http import HomeAssistantView -from homeassistant.components.http.data_validator import RequestDataValidator +from homeassistant.helpers.data_entry_flow import ( + FlowManagerIndexView, FlowManagerResourceView) REQUIREMENTS = ['voluptuous-serialize==1'] @@ -16,8 +15,10 @@ def async_setup(hass): """Enable the Home Assistant views.""" hass.http.register_view(ConfigManagerEntryIndexView) hass.http.register_view(ConfigManagerEntryResourceView) - hass.http.register_view(ConfigManagerFlowIndexView) - hass.http.register_view(ConfigManagerFlowResourceView) + hass.http.register_view( + ConfigManagerFlowIndexView(hass.config_entries.flow)) + hass.http.register_view( + ConfigManagerFlowResourceView(hass.config_entries.flow)) hass.http.register_view(ConfigManagerAvailableFlowView) return True @@ -78,7 +79,7 @@ class ConfigManagerEntryResourceView(HomeAssistantView): return self.json(result) -class ConfigManagerFlowIndexView(HomeAssistantView): +class ConfigManagerFlowIndexView(FlowManagerIndexView): """View to create config flows.""" url = '/api/config/config_entries/flow' @@ -97,78 +98,13 @@ class ConfigManagerFlowIndexView(HomeAssistantView): flw for flw in hass.config_entries.flow.async_progress() if flw['source'] != data_entry_flow.SOURCE_USER]) - @RequestDataValidator(vol.Schema({ - vol.Required('domain'): str, - })) - @asyncio.coroutine - def post(self, request, data): - """Handle a POST request.""" - hass = request.app['hass'] - try: - result = yield from hass.config_entries.flow.async_init( - data['domain']) - except data_entry_flow.UnknownHandler: - return self.json_message('Invalid handler specified', 404) - except data_entry_flow.UnknownStep: - return self.json_message('Handler does not support init', 400) - - result = _prepare_json(result) - - return self.json(result) - - -class ConfigManagerFlowResourceView(HomeAssistantView): +class ConfigManagerFlowResourceView(FlowManagerResourceView): """View to interact with the flow manager.""" url = '/api/config/config_entries/flow/{flow_id}' name = 'api:config:config_entries:flow:resource' - @asyncio.coroutine - def get(self, request, flow_id): - """Get the current state of a data_entry_flow.""" - hass = request.app['hass'] - - try: - result = yield from hass.config_entries.flow.async_configure( - flow_id) - except data_entry_flow.UnknownFlow: - return self.json_message('Invalid flow specified', 404) - - result = _prepare_json(result) - - return self.json(result) - - @RequestDataValidator(vol.Schema(dict), allow_empty=True) - @asyncio.coroutine - def post(self, request, flow_id, data): - """Handle a POST request.""" - hass = request.app['hass'] - - try: - result = yield from hass.config_entries.flow.async_configure( - flow_id, data) - except data_entry_flow.UnknownFlow: - return self.json_message('Invalid flow specified', 404) - except vol.Invalid: - return self.json_message('User input malformed', 400) - - result = _prepare_json(result) - - return self.json(result) - - @asyncio.coroutine - def delete(self, request, flow_id): - """Cancel a flow in progress.""" - hass = request.app['hass'] - - try: - hass.config_entries.flow.async_abort(flow_id) - except data_entry_flow.UnknownFlow: - return self.json_message('Invalid flow specified', 404) - - return self.json_message('Flow aborted') - class ConfigManagerAvailableFlowView(HomeAssistantView): """View to query available flows.""" diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index e143c94197e..46bb2f7bfe2 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -338,7 +338,7 @@ class ConfigEntries: if component not in self.hass.config.components: return True - await entry.async_unload( + return await entry.async_unload( self.hass, component=getattr(self.hass.components, component)) async def _async_save_entry(self, result): @@ -362,6 +362,8 @@ class ConfigEntries: await async_setup_component( self.hass, entry.domain, self._hass_config) + return entry + async def _async_create_flow(self, handler): """Create a flow for specified handler. diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 361b6653cfd..cadec3f3d69 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -34,12 +34,12 @@ class UnknownStep(FlowError): class FlowManager: """Manage all the flows that are in progress.""" - def __init__(self, hass, async_create_flow, async_save_entry): + def __init__(self, hass, async_create_flow, async_finish_flow): """Initialize the flow manager.""" self.hass = hass self._progress = {} self._async_create_flow = async_create_flow - self._async_save_entry = async_save_entry + self._async_finish_flow = async_finish_flow @callback def async_progress(self): @@ -113,10 +113,8 @@ class FlowManager: if result['type'] == RESULT_TYPE_ABORT: return result - # We pass a copy of the result because we're going to mutate our - # version afterwards and don't want to cause unexpected bugs. - await self._async_save_entry(dict(result)) - result.pop('data') + # We pass a copy of the result because we're mutating our version + result['result'] = await self._async_finish_flow(dict(result)) return result diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py new file mode 100644 index 00000000000..a8aca2fd2e9 --- /dev/null +++ b/homeassistant/helpers/data_entry_flow.py @@ -0,0 +1,106 @@ +"""Helpers for the data entry flow.""" + +import voluptuous as vol + +from homeassistant import data_entry_flow +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.data_validator import RequestDataValidator + + +def _prepare_json(result): + """Convert result for JSON.""" + if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + data = result.copy() + data.pop('result') + data.pop('data') + return data + + elif result['type'] != data_entry_flow.RESULT_TYPE_FORM: + return result + + import voluptuous_serialize + + data = result.copy() + + schema = data['data_schema'] + if schema is None: + data['data_schema'] = [] + else: + data['data_schema'] = voluptuous_serialize.convert(schema) + + return data + + +class FlowManagerIndexView(HomeAssistantView): + """View to create config flows.""" + + def __init__(self, flow_mgr): + """Initialize the flow manager index view.""" + self._flow_mgr = flow_mgr + + async def get(self, request): + """List flows that are in progress.""" + return self.json(self._flow_mgr.async_progress()) + + @RequestDataValidator(vol.Schema({ + vol.Required('handler'): vol.Any(str, list), + })) + async def post(self, request, data): + """Handle a POST request.""" + if isinstance(data['handler'], list): + handler = tuple(data['handler']) + else: + handler = data['handler'] + + try: + result = await self._flow_mgr.async_init(handler) + except data_entry_flow.UnknownHandler: + return self.json_message('Invalid handler specified', 404) + except data_entry_flow.UnknownStep: + return self.json_message('Handler does not support init', 400) + + result = _prepare_json(result) + + return self.json(result) + + +class FlowManagerResourceView(HomeAssistantView): + """View to interact with the flow manager.""" + + def __init__(self, flow_mgr): + """Initialize the flow manager resource view.""" + self._flow_mgr = flow_mgr + + async def get(self, request, flow_id): + """Get the current state of a data_entry_flow.""" + try: + result = await self._flow_mgr.async_configure(flow_id) + except data_entry_flow.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + + result = _prepare_json(result) + + return self.json(result) + + @RequestDataValidator(vol.Schema(dict), allow_empty=True) + async def post(self, request, flow_id, data): + """Handle a POST request.""" + try: + result = await self._flow_mgr.async_configure(flow_id, data) + except data_entry_flow.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + except vol.Invalid: + return self.json_message('User input malformed', 400) + + result = _prepare_json(result) + + return self.json(result) + + async def delete(self, request, flow_id): + """Cancel a flow in progress.""" + try: + self._flow_mgr.async_abort(flow_id) + except data_entry_flow.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + + return self.json_message('Flow aborted') diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 70cb6c3fbaa..f53be8818a3 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -17,6 +17,12 @@ from homeassistant.loader import set_component from tests.common import MockConfigEntry, MockModule, mock_coro_func +@pytest.fixture(scope='session', autouse=True) +def mock_test_component(): + """Ensure a component called 'test' exists.""" + set_component('test', MockModule('test')) + + @pytest.fixture def client(hass, aiohttp_client): """Fixture that can interact with the config manager API.""" @@ -111,7 +117,7 @@ def test_initialize_flow(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): resp = yield from client.post('/api/config/config_entries/flow', - json={'domain': 'test'}) + json={'handler': 'test'}) assert resp.status == 200 data = yield from resp.json() @@ -150,7 +156,7 @@ def test_abort(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): resp = yield from client.post('/api/config/config_entries/flow', - json={'domain': 'test'}) + json={'handler': 'test'}) assert resp.status == 200 data = yield from resp.json() @@ -180,7 +186,7 @@ def test_create_account(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): resp = yield from client.post('/api/config/config_entries/flow', - json={'domain': 'test'}) + json={'handler': 'test'}) assert resp.status == 200 data = yield from resp.json() @@ -220,7 +226,7 @@ def test_two_step_flow(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): resp = yield from client.post('/api/config/config_entries/flow', - json={'domain': 'test'}) + json={'handler': 'test'}) assert resp.status == 200 data = yield from resp.json() flow_id = data.pop('flow_id') @@ -305,7 +311,7 @@ def test_get_progress_flow(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): resp = yield from client.post('/api/config/config_entries/flow', - json={'domain': 'test'}) + json={'handler': 'test'}) assert resp.status == 200 data = yield from resp.json()