diff --git a/homeassistant/auth.py b/homeassistant/auth.py index e6760cd9096..ae191f24c61 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -186,16 +186,6 @@ class Credentials: is_new = attr.ib(type=bool, default=True) -@attr.s(slots=True) -class Client: - """Client that interacts with Home Assistant on behalf of a user.""" - - name = attr.ib(type=str) - id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) - secret = attr.ib(type=str, default=attr.Factory(generate_secret)) - redirect_uris = attr.ib(type=list, default=attr.Factory(list)) - - async def load_auth_provider_module(hass, provider): """Load an auth provider.""" try: @@ -356,20 +346,20 @@ class AuthManager: """Remove a user.""" await self._store.async_remove_user(user) - async def async_create_refresh_token(self, user, client=None): + async def async_create_refresh_token(self, user, client_id=None): """Create a new refresh token for a user.""" if not user.is_active: raise ValueError('User is not active') - if user.system_generated and client is not None: + if user.system_generated and client_id is not None: raise ValueError( 'System generated users cannot have refresh tokens connected ' 'to a client.') - if not user.system_generated and client is None: + if not user.system_generated and client_id is None: raise ValueError('Client is required to generate a refresh token.') - return await self._store.async_create_refresh_token(user, client) + return await self._store.async_create_refresh_token(user, client_id) async def async_get_refresh_token(self, token): """Get refresh token by token.""" @@ -396,26 +386,6 @@ class AuthManager: return tkn - async def async_create_client(self, name, *, redirect_uris=None, - no_secret=False): - """Create a new client.""" - return await self._store.async_create_client( - name, redirect_uris, no_secret) - - async def async_get_or_create_client(self, name, *, redirect_uris=None, - no_secret=False): - """Find a client, if not exists, create a new one.""" - for client in await self._store.async_get_clients(): - if client.name == name: - return client - - return await self._store.async_create_client( - name, redirect_uris, no_secret) - - async def async_get_client(self, client_id): - """Get a client.""" - return await self._store.async_get_client(client_id) - async def _async_create_login_flow(self, handler, *, source, data): """Create a login flow.""" auth_provider = self._providers[handler] @@ -456,7 +426,6 @@ class AuthStore: """Initialize the auth store.""" self.hass = hass self._users = None - self._clients = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) async def async_get_users(self): @@ -515,9 +484,8 @@ class AuthStore: self._users.pop(user.id) await self.async_save() - async def async_create_refresh_token(self, user, client=None): + async def async_create_refresh_token(self, user, client_id=None): """Create a new token for a user.""" - client_id = client.id if client is not None else None refresh_token = RefreshToken(user=user, client_id=client_id) user.refresh_tokens[refresh_token.token] = refresh_token await self.async_save() @@ -535,38 +503,6 @@ class AuthStore: return None - async def async_create_client(self, name, redirect_uris, no_secret): - """Create a new client.""" - if self._clients is None: - await self.async_load() - - kwargs = { - 'name': name, - 'redirect_uris': redirect_uris - } - - if no_secret: - kwargs['secret'] = None - - client = Client(**kwargs) - self._clients[client.id] = client - await self.async_save() - return client - - async def async_get_clients(self): - """Return all clients.""" - if self._clients is None: - await self.async_load() - - return list(self._clients.values()) - - async def async_get_client(self, client_id): - """Get a client.""" - if self._clients is None: - await self.async_load() - - return self._clients.get(client_id) - async def async_load(self): """Load the users.""" data = await self._store.async_load() @@ -578,7 +514,6 @@ class AuthStore: if data is None: self._users = {} - self._clients = {} return users = { @@ -618,12 +553,7 @@ class AuthStore: ) refresh_token.access_tokens.append(token) - clients = { - cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients'] - } - self._users = users - self._clients = clients async def async_save(self): """Save users.""" @@ -676,19 +606,8 @@ class AuthStore: for access_token in refresh_token.access_tokens ] - clients = [ - { - 'id': client.id, - 'name': client.name, - 'secret': client.secret, - 'redirect_uris': client.redirect_uris, - } - for client in self._clients.values() - ] - data = { 'users': users, - 'clients': clients, 'credentials': credentials, 'access_tokens': access_tokens, 'refresh_tokens': refresh_tokens, diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 511999c52ab..c41b417576e 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -115,7 +115,8 @@ from homeassistant.helpers.data_entry_flow import ( from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.http.data_validator import RequestDataValidator -from .client import verify_client +from . import indieauth + DOMAIN = 'auth' DEPENDENCIES = ['http'] @@ -143,8 +144,7 @@ class AuthProvidersView(HomeAssistantView): name = 'api:auth:providers' requires_auth = False - @verify_client - async def get(self, request, client): + async def get(self, request): """Get available auth providers.""" return self.json([{ 'name': provider.name, @@ -164,16 +164,16 @@ class LoginFlowIndexView(FlowManagerIndexView): """Do not allow index of flows in progress.""" return aiohttp.web.Response(status=405) - # pylint: disable=arguments-differ - @verify_client @RequestDataValidator(vol.Schema({ + vol.Required('client_id'): str, vol.Required('handler'): vol.Any(str, list), vol.Required('redirect_uri'): str, })) - async def post(self, request, client, data): + async def post(self, request, data): """Create a new login flow.""" - if data['redirect_uri'] not in client.redirect_uris: - return self.json_message('invalid redirect uri', ) + if not indieauth.verify_redirect_uri(data['client_id'], + data['redirect_uri']): + return self.json_message('invalid client id or redirect uri', 400) # pylint: disable=no-value-for-parameter return await super().post(request) @@ -191,16 +191,20 @@ class LoginFlowResourceView(FlowManagerResourceView): super().__init__(flow_mgr) self._store_credentials = store_credentials - # pylint: disable=arguments-differ - async def get(self, request): + async def get(self, request, flow_id): """Do not allow getting status of a flow in progress.""" return self.json_message('Invalid flow specified', 404) - # pylint: disable=arguments-differ - @verify_client - @RequestDataValidator(vol.Schema(dict), allow_empty=True) - async def post(self, request, client, flow_id, data): + @RequestDataValidator(vol.Schema({ + 'client_id': str + }, extra=vol.ALLOW_EXTRA)) + async def post(self, request, flow_id, data): """Handle progressing a login flow request.""" + client_id = data.pop('client_id') + + if not indieauth.verify_client_id(client_id): + return self.json_message('Invalid client id', 400) + try: result = await self._flow_mgr.async_configure(flow_id, data) except data_entry_flow.UnknownFlow: @@ -212,7 +216,7 @@ class LoginFlowResourceView(FlowManagerResourceView): return self.json(self._prepare_result_json(result)) result.pop('data') - result['result'] = self._store_credentials(client.id, result['result']) + result['result'] = self._store_credentials(client_id, result['result']) return self.json(result) @@ -228,24 +232,31 @@ class GrantTokenView(HomeAssistantView): """Initialize the grant token view.""" self._retrieve_credentials = retrieve_credentials - @verify_client - async def post(self, request, client): + async def post(self, request): """Grant a token.""" hass = request.app['hass'] data = await request.post() + + client_id = data.get('client_id') + if client_id is None or not indieauth.verify_client_id(client_id): + return self.json({ + 'error': 'invalid_request', + }, status_code=400) + grant_type = data.get('grant_type') if grant_type == 'authorization_code': - return await self._async_handle_auth_code(hass, client, data) + return await self._async_handle_auth_code(hass, client_id, data) elif grant_type == 'refresh_token': - return await self._async_handle_refresh_token(hass, client, data) + return await self._async_handle_refresh_token( + hass, client_id, data) return self.json({ 'error': 'unsupported_grant_type', }, status_code=400) - async def _async_handle_auth_code(self, hass, client, data): + async def _async_handle_auth_code(self, hass, client_id, data): """Handle authorization code request.""" code = data.get('code') @@ -254,7 +265,7 @@ class GrantTokenView(HomeAssistantView): 'error': 'invalid_request', }, status_code=400) - credentials = self._retrieve_credentials(client.id, code) + credentials = self._retrieve_credentials(client_id, code) if credentials is None: return self.json({ @@ -263,7 +274,7 @@ class GrantTokenView(HomeAssistantView): user = await hass.auth.async_get_or_create_user(credentials) refresh_token = await hass.auth.async_create_refresh_token(user, - client) + client_id) access_token = hass.auth.async_create_access_token(refresh_token) return self.json({ @@ -274,7 +285,7 @@ class GrantTokenView(HomeAssistantView): int(refresh_token.access_token_expiration.total_seconds()), }) - async def _async_handle_refresh_token(self, hass, client, data): + async def _async_handle_refresh_token(self, hass, client_id, data): """Handle authorization code request.""" token = data.get('refresh_token') @@ -285,7 +296,7 @@ class GrantTokenView(HomeAssistantView): refresh_token = await hass.auth.async_get_refresh_token(token) - if refresh_token is None or refresh_token.client_id != client.id: + if refresh_token is None or refresh_token.client_id != client_id: return self.json({ 'error': 'invalid_grant', }, status_code=400) diff --git a/homeassistant/components/auth/client.py b/homeassistant/components/auth/client.py deleted file mode 100644 index 122c3032188..00000000000 --- a/homeassistant/components/auth/client.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Helpers to resolve client ID/secret.""" -import base64 -from functools import wraps -import hmac - -import aiohttp.hdrs - - -def verify_client(method): - """Decorator to verify client id/secret on requests.""" - @wraps(method) - async def wrapper(view, request, *args, **kwargs): - """Verify client id/secret before doing request.""" - client = await _verify_client(request) - - if client is None: - return view.json({ - 'error': 'invalid_client', - }, status_code=401) - - return await method( - view, request, *args, **kwargs, client=client) - - return wrapper - - -async def _verify_client(request): - """Method to verify the client id/secret in consistent time. - - By using a consistent time for looking up client id and comparing the - secret, we prevent attacks by malicious actors trying different client ids - and are able to derive from the time it takes to process the request if - they guessed the client id correctly. - """ - if aiohttp.hdrs.AUTHORIZATION not in request.headers: - return None - - auth_type, auth_value = \ - request.headers.get(aiohttp.hdrs.AUTHORIZATION).split(' ', 1) - - if auth_type != 'Basic': - return None - - decoded = base64.b64decode(auth_value).decode('utf-8') - try: - client_id, client_secret = decoded.split(':', 1) - except ValueError: - # If no ':' in decoded - client_id, client_secret = decoded, None - - return await async_secure_get_client( - request.app['hass'], client_id, client_secret) - - -async def async_secure_get_client(hass, client_id, client_secret): - """Get a client id/secret in consistent time.""" - client = await hass.auth.async_get_client(client_id) - - if client is None: - if client_secret is not None: - # Still do a compare so we run same time as if a client was found. - hmac.compare_digest(client_secret.encode('utf-8'), - client_secret.encode('utf-8')) - return None - - if client.secret is None: - return client - - elif client_secret is None: - # Still do a compare so we run same time as if a secret was passed. - hmac.compare_digest(client.secret.encode('utf-8'), - client.secret.encode('utf-8')) - return None - - elif hmac.compare_digest(client_secret.encode('utf-8'), - client.secret.encode('utf-8')): - return client - - return None diff --git a/homeassistant/components/auth/indieauth.py b/homeassistant/components/auth/indieauth.py new file mode 100644 index 00000000000..ef7f8a9b292 --- /dev/null +++ b/homeassistant/components/auth/indieauth.py @@ -0,0 +1,130 @@ +"""Helpers to resolve client ID/secret.""" +from ipaddress import ip_address, ip_network +from urllib.parse import urlparse + +# IP addresses of loopback interfaces +ALLOWED_IPS = ( + ip_address('127.0.0.1'), + ip_address('::1'), +) + +# RFC1918 - Address allocation for Private Internets +ALLOWED_NETWORKS = ( + ip_network('10.0.0.0/8'), + ip_network('172.16.0.0/12'), + ip_network('192.168.0.0/16'), +) + + +def verify_redirect_uri(client_id, redirect_uri): + """Verify that the client and redirect uri match.""" + try: + client_id_parts = _parse_client_id(client_id) + except ValueError: + return False + + redirect_parts = _parse_url(redirect_uri) + + # IndieAuth 4.2.2 allows for redirect_uri to be on different domain + # but needs to be specified in link tag when fetching `client_id`. + # This is not implemented. + + # Verify redirect url and client url have same scheme and domain. + return ( + client_id_parts.scheme == redirect_parts.scheme and + client_id_parts.netloc == redirect_parts.netloc + ) + + +def verify_client_id(client_id): + """Verify that the client id is valid.""" + try: + _parse_client_id(client_id) + return True + except ValueError: + return False + + +def _parse_url(url): + """Parse a url in parts and canonicalize according to IndieAuth.""" + parts = urlparse(url) + + # Canonicalize a url according to IndieAuth 3.2. + + # SHOULD convert the hostname to lowercase + parts = parts._replace(netloc=parts.netloc.lower()) + + # If a URL with no path component is ever encountered, + # it MUST be treated as if it had the path /. + if parts.path == '': + parts = parts._replace(path='/') + + return parts + + +def _parse_client_id(client_id): + """Test if client id is a valid URL according to IndieAuth section 3.2. + + https://indieauth.spec.indieweb.org/#client-identifier + """ + parts = _parse_url(client_id) + + # Client identifier URLs + # MUST have either an https or http scheme + if parts.scheme not in ('http', 'https'): + raise ValueError() + + # MUST contain a path component + # Handled by url canonicalization. + + # MUST NOT contain single-dot or double-dot path segments + if any(segment in ('.', '..') for segment in parts.path.split('/')): + raise ValueError( + 'Client ID cannot contain single-dot or double-dot path segments') + + # MUST NOT contain a fragment component + if parts.fragment != '': + raise ValueError('Client ID cannot contain a fragment') + + # MUST NOT contain a username or password component + if parts.username is not None: + raise ValueError('Client ID cannot contain username') + + if parts.password is not None: + raise ValueError('Client ID cannot contain password') + + # MAY contain a port + try: + # parts raises ValueError when port cannot be parsed as int + parts.port + except ValueError: + raise ValueError('Client ID contains invalid port') + + # Additionally, hostnames + # MUST be domain names or a loopback interface and + # MUST NOT be IPv4 or IPv6 addresses except for IPv4 127.0.0.1 + # or IPv6 [::1] + + # We are not goint to follow the spec here. We are going to allow + # any internal network IP to be used inside a client id. + + address = None + + try: + netloc = parts.netloc + + # Strip the [, ] from ipv6 addresses before parsing + if netloc[0] == '[' and netloc[-1] == ']': + netloc = netloc[1:-1] + + address = ip_address(netloc) + except ValueError: + # Not an ip address + pass + + if (address is None or + address in ALLOWED_IPS or + any(address in network for network in ALLOWED_NETWORKS)): + return parts + + raise ValueError('Hostname should be a domain name or local IP address') diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index a6fb8735a66..4304742021f 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -200,15 +200,6 @@ def add_manifest_json_key(key, val): async def async_setup(hass, config): """Set up the serving of the frontend.""" - if hass.auth.active: - client = await hass.auth.async_get_or_create_client( - 'Home Assistant Frontend', - redirect_uris=['/'], - no_secret=True, - ) - else: - client = None - hass.components.websocket_api.async_register_command( WS_TYPE_GET_PANELS, websocket_get_panels, SCHEMA_GET_PANELS) hass.components.websocket_api.async_register_command( @@ -255,7 +246,7 @@ async def async_setup(hass, config): if os.path.isdir(local): hass.http.register_static_path("/local", local, not is_dev) - index_view = IndexView(repo_path, js_version, client) + index_view = IndexView(repo_path, js_version, hass.auth.active) hass.http.register_view(index_view) @callback @@ -350,11 +341,11 @@ class IndexView(HomeAssistantView): requires_auth = False extra_urls = ['/states', '/states/{extra}'] - def __init__(self, repo_path, js_option, client): + def __init__(self, repo_path, js_option, auth_active): """Initialize the frontend view.""" self.repo_path = repo_path self.js_option = js_option - self.client = client + self.auth_active = auth_active self._template_cache = {} def get_template(self, latest): @@ -399,11 +390,9 @@ class IndexView(HomeAssistantView): no_auth=no_auth, theme_color=MANIFEST_JSON['theme_color'], extra_urls=hass.data[extra_key], + client_id=self.auth_active ) - if self.client is not None: - template_params['client_id'] = self.client.id - return web.Response(text=template.render(**template_params), content_type='text/html') diff --git a/tests/common.py b/tests/common.py index ccb8f49ea97..98a3b0a6074 100644 --- a/tests/common.py +++ b/tests/common.py @@ -31,6 +31,8 @@ from homeassistant.util.async_ import ( _TEST_INSTANCE_PORT = SERVER_PORT _LOGGER = logging.getLogger(__name__) INSTANCES = [] +CLIENT_ID = 'https://example.com/app' +CLIENT_REDIRECT_URI = 'https://example.com/app/callback' def threadsafe_callback_factory(func): @@ -330,8 +332,6 @@ class MockUser(auth.User): def ensure_auth_manager_loaded(auth_mgr): """Ensure an auth manager is considered loaded.""" store = auth_mgr._store - if store._clients is None: - store._clients = {} if store._users is None: store._users = {} diff --git a/tests/components/auth/__init__.py b/tests/components/auth/__init__.py index 21719c12569..ce94d1ecbfa 100644 --- a/tests/components/auth/__init__.py +++ b/tests/components/auth/__init__.py @@ -1,6 +1,4 @@ """Tests for the auth component.""" -from aiohttp.helpers import BasicAuth - from homeassistant import auth from homeassistant.setup import async_setup_component @@ -16,10 +14,6 @@ BASE_CONFIG = [{ 'name': 'Test Name' }] }] -CLIENT_ID = 'test-id' -CLIENT_SECRET = 'test-secret' -CLIENT_AUTH = BasicAuth(CLIENT_ID, CLIENT_SECRET) -CLIENT_REDIRECT_URI = 'http://example.com/callback' async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG, @@ -32,9 +26,6 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG, 'api_password': 'bla' } }) - client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET, - redirect_uris=[CLIENT_REDIRECT_URI]) - hass.auth._store._clients[client.id] = client if setup_api: await async_setup_component(hass, 'api', {}) return await aiohttp_client(hass.http.app) diff --git a/tests/components/auth/test_client.py b/tests/components/auth/test_client.py deleted file mode 100644 index 65ad22efae2..00000000000 --- a/tests/components/auth/test_client.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Tests for the client validator.""" -from aiohttp.helpers import BasicAuth -import pytest - -from homeassistant.setup import async_setup_component -from homeassistant.components.auth.client import verify_client -from homeassistant.components.http.view import HomeAssistantView - -from . import async_setup_auth - - -@pytest.fixture -def mock_view(hass): - """Register a view that verifies client id/secret.""" - hass.loop.run_until_complete(async_setup_component(hass, 'http', {})) - - clients = [] - - class ClientView(HomeAssistantView): - url = '/' - name = 'bla' - - @verify_client - async def get(self, request, client): - """Handle GET request.""" - clients.append(client) - - hass.http.register_view(ClientView) - return clients - - -async def test_verify_client(hass, aiohttp_client, mock_view): - """Test that verify client can extract client auth from a request.""" - http_client = await async_setup_auth(hass, aiohttp_client) - client = await hass.auth.async_create_client('Hello') - - resp = await http_client.get('/', auth=BasicAuth(client.id, client.secret)) - assert resp.status == 200 - assert mock_view[0] is client - - -async def test_verify_client_no_auth_header(hass, aiohttp_client, mock_view): - """Test that verify client will decline unknown client id.""" - http_client = await async_setup_auth(hass, aiohttp_client) - - resp = await http_client.get('/') - assert resp.status == 401 - assert mock_view == [] - - -async def test_verify_client_invalid_client_id(hass, aiohttp_client, - mock_view): - """Test that verify client will decline unknown client id.""" - http_client = await async_setup_auth(hass, aiohttp_client) - client = await hass.auth.async_create_client('Hello') - - resp = await http_client.get('/', auth=BasicAuth('invalid', client.secret)) - assert resp.status == 401 - assert mock_view == [] - - -async def test_verify_client_invalid_client_secret(hass, aiohttp_client, - mock_view): - """Test that verify client will decline incorrect client secret.""" - http_client = await async_setup_auth(hass, aiohttp_client) - client = await hass.auth.async_create_client('Hello') - - resp = await http_client.get('/', auth=BasicAuth(client.id, 'invalid')) - assert resp.status == 401 - assert mock_view == [] diff --git a/tests/components/auth/test_indieauth.py b/tests/components/auth/test_indieauth.py new file mode 100644 index 00000000000..7bd720ddf70 --- /dev/null +++ b/tests/components/auth/test_indieauth.py @@ -0,0 +1,110 @@ +"""Tests for the client validator.""" +from homeassistant.components.auth import indieauth + +import pytest + + +def test_client_id_scheme(): + """Test we enforce valid scheme.""" + assert indieauth._parse_client_id('http://ex.com/') + assert indieauth._parse_client_id('https://ex.com/') + + with pytest.raises(ValueError): + indieauth._parse_client_id('ftp://ex.com') + + +def test_client_id_path(): + """Test we enforce valid path.""" + assert indieauth._parse_client_id('http://ex.com').path == '/' + assert indieauth._parse_client_id('http://ex.com/hello').path == '/hello' + assert indieauth._parse_client_id( + 'http://ex.com/hello/.world').path == '/hello/.world' + assert indieauth._parse_client_id( + 'http://ex.com/hello./.world').path == '/hello./.world' + + with pytest.raises(ValueError): + indieauth._parse_client_id('http://ex.com/.') + + with pytest.raises(ValueError): + indieauth._parse_client_id('http://ex.com/hello/./yo') + + with pytest.raises(ValueError): + indieauth._parse_client_id('http://ex.com/hello/../yo') + + +def test_client_id_fragment(): + """Test we enforce valid fragment.""" + with pytest.raises(ValueError): + indieauth._parse_client_id('http://ex.com/#yoo') + + +def test_client_id_user_pass(): + """Test we enforce valid username/password.""" + with pytest.raises(ValueError): + indieauth._parse_client_id('http://user@ex.com/') + + with pytest.raises(ValueError): + indieauth._parse_client_id('http://user:pass@ex.com/') + + +def test_client_id_hostname(): + """Test we enforce valid hostname.""" + assert indieauth._parse_client_id('http://www.home-assistant.io/') + assert indieauth._parse_client_id('http://[::1]') + assert indieauth._parse_client_id('http://127.0.0.1') + assert indieauth._parse_client_id('http://10.0.0.0') + assert indieauth._parse_client_id('http://10.255.255.255') + assert indieauth._parse_client_id('http://172.16.0.0') + assert indieauth._parse_client_id('http://172.31.255.255') + assert indieauth._parse_client_id('http://192.168.0.0') + assert indieauth._parse_client_id('http://192.168.255.255') + + with pytest.raises(ValueError): + assert indieauth._parse_client_id('http://255.255.255.255/') + with pytest.raises(ValueError): + assert indieauth._parse_client_id('http://11.0.0.0/') + with pytest.raises(ValueError): + assert indieauth._parse_client_id('http://172.32.0.0/') + with pytest.raises(ValueError): + assert indieauth._parse_client_id('http://192.167.0.0/') + + +def test_parse_url_lowercase_host(): + """Test we update empty paths.""" + assert indieauth._parse_url('http://ex.com/hello').path == '/hello' + assert indieauth._parse_url('http://EX.COM/hello').hostname == 'ex.com' + + parts = indieauth._parse_url('http://EX.COM:123/HELLO') + assert parts.netloc == 'ex.com:123' + assert parts.path == '/HELLO' + + +def test_parse_url_path(): + """Test we update empty paths.""" + assert indieauth._parse_url('http://ex.com').path == '/' + + +def test_verify_redirect_uri(): + """Test that we verify redirect uri correctly.""" + assert indieauth.verify_redirect_uri( + 'http://ex.com', + 'http://ex.com/callback' + ) + + # Different domain + assert not indieauth.verify_redirect_uri( + 'http://ex.com', + 'http://different.com/callback' + ) + + # Different scheme + assert not indieauth.verify_redirect_uri( + 'http://ex.com', + 'https://ex.com/callback' + ) + + # Different subdomain + assert not indieauth.verify_redirect_uri( + 'https://sub1.ex.com', + 'https://sub2.ex.com/callback' + ) diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 7cff04327b8..68a77d18d56 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -1,22 +1,26 @@ """Integration tests for the auth component.""" -from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI +from . import async_setup_auth + +from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI async def test_login_new_user_and_refresh_token(hass, aiohttp_client): """Test logging in with new user and refreshing tokens.""" client = await async_setup_auth(hass, aiohttp_client, setup_api=True) resp = await client.post('/auth/login_flow', json={ + 'client_id': CLIENT_ID, 'handler': ['insecure_example', None], 'redirect_uri': CLIENT_REDIRECT_URI, - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() resp = await client.post( '/auth/login_flow/{}'.format(step['flow_id']), json={ + 'client_id': CLIENT_ID, 'username': 'test-user', 'password': 'test-pass', - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() @@ -24,9 +28,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client): # Exchange code for tokens resp = await client.post('/auth/token', data={ + 'client_id': CLIENT_ID, 'grant_type': 'authorization_code', 'code': code - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 tokens = await resp.json() @@ -35,9 +40,10 @@ async def test_login_new_user_and_refresh_token(hass, aiohttp_client): # Use refresh token to get more tokens. resp = await client.post('/auth/token', data={ + 'client_id': CLIENT_ID, 'grant_type': 'refresh_token', 'refresh_token': tokens['refresh_token'] - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 tokens = await resp.json() diff --git a/tests/components/auth/test_init_link_user.py b/tests/components/auth/test_init_link_user.py index 853c002ba46..28a924bb43a 100644 --- a/tests/components/auth/test_init_link_user.py +++ b/tests/components/auth/test_init_link_user.py @@ -1,5 +1,7 @@ """Tests for the link user flow.""" -from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID, CLIENT_REDIRECT_URI +from . import async_setup_auth + +from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI async def async_get_code(hass, aiohttp_client): @@ -25,17 +27,19 @@ async def async_get_code(hass, aiohttp_client): client = await async_setup_auth(hass, aiohttp_client, config) resp = await client.post('/auth/login_flow', json={ + 'client_id': CLIENT_ID, 'handler': ['insecure_example', None], 'redirect_uri': CLIENT_REDIRECT_URI, - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() resp = await client.post( '/auth/login_flow/{}'.format(step['flow_id']), json={ + 'client_id': CLIENT_ID, 'username': 'test-user', 'password': 'test-pass', - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() @@ -43,9 +47,10 @@ async def async_get_code(hass, aiohttp_client): # Exchange code for tokens resp = await client.post('/auth/token', data={ + 'client_id': CLIENT_ID, 'grant_type': 'authorization_code', 'code': code - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 tokens = await resp.json() @@ -57,17 +62,19 @@ async def async_get_code(hass, aiohttp_client): # Now authenticate with the 2nd flow resp = await client.post('/auth/login_flow', json={ + 'client_id': CLIENT_ID, 'handler': ['insecure_example', '2nd auth'], 'redirect_uri': CLIENT_REDIRECT_URI, - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() resp = await client.post( '/auth/login_flow/{}'.format(step['flow_id']), json={ + 'client_id': CLIENT_ID, 'username': '2nd-user', 'password': '2nd-pass', - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() diff --git a/tests/components/auth/test_init_login_flow.py b/tests/components/auth/test_init_login_flow.py index ad39fba3997..50bd03d6ced 100644 --- a/tests/components/auth/test_init_login_flow.py +++ b/tests/components/auth/test_init_login_flow.py @@ -1,13 +1,13 @@ """Tests for the login flow.""" -from aiohttp.helpers import BasicAuth +from . import async_setup_auth -from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI +from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI async def test_fetch_auth_providers(hass, aiohttp_client): """Test fetching auth providers.""" client = await async_setup_auth(hass, aiohttp_client) - resp = await client.get('/auth/providers', auth=CLIENT_AUTH) + resp = await client.get('/auth/providers') assert await resp.json() == [{ 'name': 'Example', 'type': 'insecure_example', @@ -15,14 +15,6 @@ async def test_fetch_auth_providers(hass, aiohttp_client): }] -async def test_fetch_auth_providers_require_valid_client(hass, aiohttp_client): - """Test fetching auth providers.""" - client = await async_setup_auth(hass, aiohttp_client) - resp = await client.get('/auth/providers', - auth=BasicAuth('invalid', 'bla')) - assert resp.status == 401 - - async def test_cannot_get_flows_in_progress(hass, aiohttp_client): """Test we cannot get flows in progress.""" client = await async_setup_auth(hass, aiohttp_client, []) @@ -34,18 +26,20 @@ async def test_invalid_username_password(hass, aiohttp_client): """Test we cannot get flows in progress.""" client = await async_setup_auth(hass, aiohttp_client) resp = await client.post('/auth/login_flow', json={ + 'client_id': CLIENT_ID, 'handler': ['insecure_example', None], 'redirect_uri': CLIENT_REDIRECT_URI - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() # Incorrect username resp = await client.post( '/auth/login_flow/{}'.format(step['flow_id']), json={ + 'client_id': CLIENT_ID, 'username': 'wrong-user', 'password': 'test-pass', - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() @@ -56,9 +50,10 @@ async def test_invalid_username_password(hass, aiohttp_client): # Incorrect password resp = await client.post( '/auth/login_flow/{}'.format(step['flow_id']), json={ + 'client_id': CLIENT_ID, 'username': 'test-user', 'password': 'wrong-pass', - }, auth=CLIENT_AUTH) + }) assert resp.status == 200 step = await resp.json() diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 00e3ee88d16..843866cbfbd 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -3,7 +3,7 @@ import pytest from homeassistant.setup import async_setup_component -from tests.common import MockUser +from tests.common import MockUser, CLIENT_ID @pytest.fixture @@ -28,11 +28,6 @@ def hass_ws_client(aiohttp_client): def hass_access_token(hass): """Return an access token to access Home Assistant.""" user = MockUser().add_to_hass(hass) - client = hass.loop.run_until_complete(hass.auth.async_create_client( - 'Access Token Fixture', - redirect_uris=['/'], - no_secret=True, - )) refresh_token = hass.loop.run_until_complete( - hass.auth.async_create_refresh_token(user, client)) + hass.auth.async_create_refresh_token(user, CLIENT_ID)) yield hass.auth.async_create_access_token(refresh_token) diff --git a/tests/test_auth.py b/tests/test_auth.py index 8096a081679..3119c3d8d71 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -6,7 +6,8 @@ import pytest from homeassistant import auth, data_entry_flow from homeassistant.util import dt as dt_util -from tests.common import MockUser, ensure_auth_manager_loaded, flush_store +from tests.common import ( + MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID) @pytest.fixture @@ -181,10 +182,7 @@ async def test_saving_loading(hass, hass_storage): }) user = await manager.async_get_or_create_user(step['result']) - client = await manager.async_create_client( - 'test', redirect_uris=['https://example.com']) - - refresh_token = await manager.async_create_refresh_token(user, client) + refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) manager.async_create_access_token(refresh_token) @@ -195,10 +193,6 @@ async def test_saving_loading(hass, hass_storage): assert len(users) == 1 assert users[0] == user - clients = await store2.async_get_clients() - assert len(clients) == 1 - assert clients[0] == client - def test_access_token_expired(): """Test that the expired property on access tokens work.""" @@ -225,11 +219,10 @@ def test_access_token_expired(): async def test_cannot_retrieve_expired_access_token(hass): """Test that we cannot retrieve expired access tokens.""" manager = await auth.auth_manager_from_config(hass, []) - client = await manager.async_create_client('test') user = MockUser().add_to_auth_manager(manager) - refresh_token = await manager.async_create_refresh_token(user, client) + refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) assert refresh_token.user.id is user.id - assert refresh_token.client_id is client.id + assert refresh_token.client_id == CLIENT_ID access_token = manager.async_create_access_token(refresh_token) assert manager.async_get_access_token(access_token.token) is access_token @@ -242,19 +235,6 @@ async def test_cannot_retrieve_expired_access_token(hass): assert manager.async_get_access_token(access_token.token) is None -async def test_get_or_create_client(hass): - """Test that get_or_create_client works.""" - manager = await auth.auth_manager_from_config(hass, []) - - client1 = await manager.async_get_or_create_client( - 'Test Client', redirect_uris=['https://test.com/1']) - assert client1.name is 'Test Client' - - client2 = await manager.async_get_or_create_client( - 'Test Client', redirect_uris=['https://test.com/1']) - assert client2.id is client1.id - - async def test_generating_system_user(hass): """Test that we can add a system user.""" manager = await auth.auth_manager_from_config(hass, []) @@ -274,10 +254,9 @@ async def test_refresh_token_requires_client_for_user(hass): with pytest.raises(ValueError): await manager.async_create_refresh_token(user) - client = await manager.async_get_or_create_client('Test client') - token = await manager.async_create_refresh_token(user, client) + token = await manager.async_create_refresh_token(user, CLIENT_ID) assert token is not None - assert token.client_id == client.id + assert token.client_id == CLIENT_ID async def test_refresh_token_not_requires_client_for_system_user(hass): @@ -285,10 +264,9 @@ async def test_refresh_token_not_requires_client_for_system_user(hass): manager = await auth.auth_manager_from_config(hass, []) user = await manager.async_create_system_user('Hass.io') assert user.system_generated is True - client = await manager.async_get_or_create_client('Test client') with pytest.raises(ValueError): - await manager.async_create_refresh_token(user, client) + await manager.async_create_refresh_token(user, CLIENT_ID) token = await manager.async_create_refresh_token(user) assert token is not None