From 2a76a0852f969bfcfbb77a9826715884a4ab41d4 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 19 Jul 2018 08:37:00 +0200 Subject: [PATCH] Allow CORS requests to token endpoint (#15519) * Allow CORS requests to token endpoint * Tests * Fuck emulated hue * Clean up * Only cors existing methods --- homeassistant/components/auth/__init__.py | 1 + .../components/emulated_hue/__init__.py | 12 ++++++------ homeassistant/components/http/__init__.py | 5 ++--- homeassistant/components/http/cors.py | 14 ++++++++++++++ homeassistant/components/http/view.py | 19 ++++++++++++------- tests/components/auth/test_init.py | 17 +++++++++++++++++ tests/components/emulated_hue/test_hue_api.py | 8 ++++---- tests/components/http/test_cors.py | 4 ++-- tests/components/http/test_data_validator.py | 2 +- 9 files changed, 59 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 84287c2e425..435555c2e31 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -241,6 +241,7 @@ class GrantTokenView(HomeAssistantView): url = '/auth/token' name = 'api:auth:token' requires_auth = False + cors_allowed = True def __init__(self, retrieve_credentials): """Initialize the grant token view.""" diff --git a/homeassistant/components/emulated_hue/__init__.py b/homeassistant/components/emulated_hue/__init__.py index ce94a560dae..36ce1c392f9 100644 --- a/homeassistant/components/emulated_hue/__init__.py +++ b/homeassistant/components/emulated_hue/__init__.py @@ -90,12 +90,12 @@ def setup(hass, yaml_config): handler = None server = None - DescriptionXmlView(config).register(app.router) - HueUsernameView().register(app.router) - HueAllLightsStateView(config).register(app.router) - HueOneLightStateView(config).register(app.router) - HueOneLightChangeView(config).register(app.router) - HueGroupView(config).register(app.router) + DescriptionXmlView(config).register(app, app.router) + HueUsernameView().register(app, app.router) + HueAllLightsStateView(config).register(app, app.router) + HueOneLightStateView(config).register(app, app.router) + HueOneLightChangeView(config).register(app, app.router) + HueGroupView(config).register(app, app.router) upnp_listener = UPNPResponderThread( config.host_ip_addr, config.listen_port, diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index c8eba41e66b..0cbee628a8a 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -187,8 +187,7 @@ class HomeAssistantHTTP(object): support_legacy=hass.auth.support_legacy, api_password=api_password) - if cors_origins: - setup_cors(app, cors_origins) + setup_cors(app, cors_origins) app['hass'] = hass @@ -226,7 +225,7 @@ class HomeAssistantHTTP(object): '{0} missing required attribute "name"'.format(class_name) ) - view.register(self.app.router) + view.register(self.app, self.app.router) def register_redirect(self, url, redirect_to): """Register a redirect with the server. diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py index 0a37f22867e..b01e68f701d 100644 --- a/homeassistant/components/http/cors.py +++ b/homeassistant/components/http/cors.py @@ -27,6 +27,20 @@ def setup_cors(app, origins): ) for host in origins }) + def allow_cors(route, methods): + """Allow cors on a route.""" + cors.add(route, { + '*': aiohttp_cors.ResourceOptions( + allow_headers=ALLOWED_CORS_HEADERS, + allow_methods=methods, + ) + }) + + app['allow_cors'] = allow_cors + + if not origins: + return + async def cors_startup(app): """Initialize cors when app starts up.""" cors_added = set() diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index 3de276564eb..23698af8101 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -26,7 +26,9 @@ class HomeAssistantView(object): url = None extra_urls = [] - requires_auth = True # Views inheriting from this class can override this + # Views inheriting from this class can override this + requires_auth = True + cors_allowed = False # pylint: disable=no-self-use def json(self, result, status_code=200, headers=None): @@ -51,10 +53,11 @@ class HomeAssistantView(object): data['code'] = message_code return self.json(data, status_code, headers=headers) - def register(self, router): + def register(self, app, router): """Register the view with a router.""" assert self.url is not None, 'No url set for view' urls = [self.url] + self.extra_urls + routes = [] for method in ('get', 'post', 'delete', 'put'): handler = getattr(self, method, None) @@ -65,13 +68,15 @@ class HomeAssistantView(object): handler = request_handler_factory(self, handler) for url in urls: - router.add_route(method, url, handler) + routes.append( + (method, router.add_route(method, url, handler)) + ) - # aiohttp_cors does not work with class based views - # self.app.router.add_route('*', self.url, self, name=self.name) + if not self.cors_allowed: + return - # for url in self.extra_urls: - # self.app.router.add_route('*', url, self) + for method, route in routes: + app['allow_cors'](route, [method.upper()]) def request_handler_factory(view, handler): diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 46b88e46b4d..1d3719b8c66 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -93,3 +93,20 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token): assert user_dict['name'] == user.name assert user_dict['id'] == user.id assert user_dict['is_owner'] == user.is_owner + + +async def test_cors_on_token(hass, aiohttp_client): + """Test logging in with new user and refreshing tokens.""" + client = await async_setup_auth(hass, aiohttp_client) + + resp = await client.options('/auth/token', headers={ + 'origin': 'http://example.com', + 'Access-Control-Request-Method': 'POST', + }) + assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com' + assert resp.headers['Access-Control-Allow-Methods'] == 'POST' + + resp = await client.post('/auth/token', headers={ + 'origin': 'http://example.com' + }) + assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com' diff --git a/tests/components/emulated_hue/test_hue_api.py b/tests/components/emulated_hue/test_hue_api.py index 1617f327d27..c99d273a458 100644 --- a/tests/components/emulated_hue/test_hue_api.py +++ b/tests/components/emulated_hue/test_hue_api.py @@ -130,10 +130,10 @@ def hue_client(loop, hass_hue, aiohttp_client): } }) - HueUsernameView().register(web_app.router) - HueAllLightsStateView(config).register(web_app.router) - HueOneLightStateView(config).register(web_app.router) - HueOneLightChangeView(config).register(web_app.router) + HueUsernameView().register(web_app, web_app.router) + HueAllLightsStateView(config).register(web_app, web_app.router) + HueOneLightStateView(config).register(web_app, web_app.router) + HueOneLightChangeView(config).register(web_app, web_app.router) return loop.run_until_complete(aiohttp_client(web_app)) diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py index 27367b4173e..523d4943ba0 100644 --- a/tests/components/http/test_cors.py +++ b/tests/components/http/test_cors.py @@ -19,14 +19,14 @@ from homeassistant.components.http.cors import setup_cors TRUSTED_ORIGIN = 'https://home-assistant.io' -async def test_cors_middleware_not_loaded_by_default(hass): +async def test_cors_middleware_loaded_by_default(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_cors') as mock_setup: await async_setup_component(hass, 'http', { 'http': {} }) - assert len(mock_setup.mock_calls) == 0 + assert len(mock_setup.mock_calls) == 1 async def test_cors_middleware_loaded_from_config(hass): diff --git a/tests/components/http/test_data_validator.py b/tests/components/http/test_data_validator.py index 2b966daff6c..b5eed19eb61 100644 --- a/tests/components/http/test_data_validator.py +++ b/tests/components/http/test_data_validator.py @@ -23,7 +23,7 @@ async def get_client(aiohttp_client, validator): """Test method.""" return b'' - TestView().register(app.router) + TestView().register(app, app.router) client = await aiohttp_client(app) return client