diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 9695e77f6f1..148f97702e3 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -4,10 +4,12 @@ import logging from collections import OrderedDict from typing import List, Awaitable +import jwt + from homeassistant import data_entry_flow from homeassistant.core import callback, HomeAssistant +from homeassistant.util import dt as dt_util -from . import models from . import auth_store from .providers import auth_provider_from_config @@ -54,7 +56,6 @@ class AuthManager: self.login_flow = data_entry_flow.FlowManager( hass, self._async_create_login_flow, self._async_finish_login_flow) - self._access_tokens = OrderedDict() @property def active(self): @@ -181,35 +182,56 @@ class AuthManager: return await self._store.async_create_refresh_token(user, client_id) - async def async_get_refresh_token(self, token): + async def async_get_refresh_token(self, token_id): + """Get refresh token by id.""" + return await self._store.async_get_refresh_token(token_id) + + async def async_get_refresh_token_by_token(self, token): """Get refresh token by token.""" - return await self._store.async_get_refresh_token(token) + return await self._store.async_get_refresh_token_by_token(token) @callback def async_create_access_token(self, refresh_token): """Create a new access token.""" - access_token = models.AccessToken(refresh_token=refresh_token) - self._access_tokens[access_token.token] = access_token - return access_token + # pylint: disable=no-self-use + return jwt.encode({ + 'iss': refresh_token.id, + 'iat': dt_util.utcnow(), + 'exp': dt_util.utcnow() + refresh_token.access_token_expiration, + }, refresh_token.jwt_key, algorithm='HS256').decode() - @callback - def async_get_access_token(self, token): - """Get an access token.""" - tkn = self._access_tokens.get(token) - - if tkn is None: - _LOGGER.debug('Attempt to get non-existing access token') + async def async_validate_access_token(self, token): + """Return if an access token is valid.""" + try: + unverif_claims = jwt.decode(token, verify=False) + except jwt.InvalidTokenError: return None - if tkn.expired or not tkn.refresh_token.user.is_active: - if tkn.expired: - _LOGGER.debug('Attempt to get expired access token') - else: - _LOGGER.debug('Attempt to get access token for inactive user') - self._access_tokens.pop(token) + refresh_token = await self.async_get_refresh_token( + unverif_claims.get('iss')) + + if refresh_token is None: + jwt_key = '' + issuer = '' + else: + jwt_key = refresh_token.jwt_key + issuer = refresh_token.id + + try: + jwt.decode( + token, + jwt_key, + leeway=10, + issuer=issuer, + algorithms=['HS256'] + ) + except jwt.InvalidTokenError: return None - return tkn + if not refresh_token.user.is_active: + return None + + return refresh_token async def _async_create_login_flow(self, handler, *, context, data): """Create a login flow.""" diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 8fd66d4bbb7..806cd109d78 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -1,6 +1,7 @@ """Storage for auth models.""" from collections import OrderedDict from datetime import timedelta +import hmac from homeassistant.util import dt as dt_util @@ -110,22 +111,36 @@ class AuthStore: async def async_create_refresh_token(self, user, client_id=None): """Create a new token for a user.""" refresh_token = models.RefreshToken(user=user, client_id=client_id) - user.refresh_tokens[refresh_token.token] = refresh_token + user.refresh_tokens[refresh_token.id] = refresh_token await self.async_save() return refresh_token - async def async_get_refresh_token(self, token): - """Get refresh token by token.""" + async def async_get_refresh_token(self, token_id): + """Get refresh token by id.""" if self._users is None: await self.async_load() for user in self._users.values(): - refresh_token = user.refresh_tokens.get(token) + refresh_token = user.refresh_tokens.get(token_id) if refresh_token is not None: return refresh_token return None + async def async_get_refresh_token_by_token(self, token): + """Get refresh token by token.""" + if self._users is None: + await self.async_load() + + found = None + + for user in self._users.values(): + for refresh_token in user.refresh_tokens.values(): + if hmac.compare_digest(refresh_token.token, token): + found = refresh_token + + return found + async def async_load(self): """Load the users.""" data = await self._store.async_load() @@ -153,9 +168,11 @@ class AuthStore: data=cred_dict['data'], )) - refresh_tokens = OrderedDict() - for rt_dict in data['refresh_tokens']: + # Filter out the old keys that don't have jwt_key (pre-0.76) + if 'jwt_key' not in rt_dict: + continue + token = models.RefreshToken( id=rt_dict['id'], user=users[rt_dict['user_id']], @@ -164,18 +181,9 @@ class AuthStore: access_token_expiration=timedelta( seconds=rt_dict['access_token_expiration']), token=rt_dict['token'], + jwt_key=rt_dict['jwt_key'] ) - refresh_tokens[token.id] = token - users[rt_dict['user_id']].refresh_tokens[token.token] = token - - for ac_dict in data['access_tokens']: - refresh_token = refresh_tokens[ac_dict['refresh_token_id']] - token = models.AccessToken( - refresh_token=refresh_token, - created_at=dt_util.parse_datetime(ac_dict['created_at']), - token=ac_dict['token'], - ) - refresh_token.access_tokens.append(token) + users[rt_dict['user_id']].refresh_tokens[token.id] = token self._users = users @@ -213,27 +221,15 @@ class AuthStore: 'access_token_expiration': refresh_token.access_token_expiration.total_seconds(), 'token': refresh_token.token, + 'jwt_key': refresh_token.jwt_key, } for user in self._users.values() for refresh_token in user.refresh_tokens.values() ] - access_tokens = [ - { - 'id': user.id, - 'refresh_token_id': refresh_token.id, - 'created_at': access_token.created_at.isoformat(), - 'token': access_token.token, - } - for user in self._users.values() - for refresh_token in user.refresh_tokens.values() - for access_token in refresh_token.access_tokens - ] - data = { 'users': users, 'credentials': credentials, - 'access_tokens': access_tokens, 'refresh_tokens': refresh_tokens, } diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 38e054dc7cf..3f49c56bce6 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -39,26 +39,8 @@ class RefreshToken: default=ACCESS_TOKEN_EXPIRATION) token = attr.ib(type=str, default=attr.Factory(lambda: generate_secret(64))) - access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False) - - -@attr.s(slots=True) -class AccessToken: - """Access token to access the API. - - These will only ever be stored in memory and not be persisted. - """ - - refresh_token = attr.ib(type=RefreshToken) - created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow)) - token = attr.ib(type=str, - default=attr.Factory(generate_secret)) - - @property - def expired(self): - """Return if this token has expired.""" - expires = self.created_at + self.refresh_token.access_token_expiration - return dt_util.utcnow() > expires + jwt_key = attr.ib(type=str, + default=attr.Factory(lambda: generate_secret(64))) @attr.s(slots=True) diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 0b2b4fb1a2e..102bfe58b55 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -155,7 +155,7 @@ class GrantTokenView(HomeAssistantView): access_token = hass.auth.async_create_access_token(refresh_token) return self.json({ - 'access_token': access_token.token, + 'access_token': access_token, 'token_type': 'Bearer', 'refresh_token': refresh_token.token, 'expires_in': @@ -178,7 +178,7 @@ class GrantTokenView(HomeAssistantView): 'error': 'invalid_request', }, status_code=400) - refresh_token = await hass.auth.async_get_refresh_token(token) + refresh_token = await hass.auth.async_get_refresh_token_by_token(token) if refresh_token is None: return self.json({ @@ -193,7 +193,7 @@ class GrantTokenView(HomeAssistantView): access_token = hass.auth.async_create_access_token(refresh_token) return self.json({ - 'access_token': access_token.token, + 'access_token': access_token, 'token_type': 'Bearer', 'expires_in': int(refresh_token.access_token_expiration.total_seconds()), diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 77621e3bc7c..d01d1b50c5a 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None): if auth_type == 'Bearer': hass = request.app['hass'] - access_token = hass.auth.async_get_access_token(auth_val) - if access_token is None: + refresh_token = await hass.auth.async_validate_access_token(auth_val) + if refresh_token is None: return False - request['hass_user'] = access_token.refresh_token.user + request['hass_user'] = refresh_token.user return True if auth_type == 'Basic' and api_password is not None: diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index d9c92fa357f..532f3672df4 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -352,11 +352,12 @@ class ActiveConnection: if self.hass.auth.active and 'access_token' in msg: self.debug("Received access_token") - token = self.hass.auth.async_get_access_token( - msg['access_token']) - authenticated = token is not None + refresh_token = \ + await self.hass.auth.async_validate_access_token( + msg['access_token']) + authenticated = refresh_token is not None if authenticated: - request['hass_user'] = token.refresh_token.user + request['hass_user'] = refresh_token.user elif ((not self.hass.auth.active or self.hass.auth.support_legacy) and diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 29e10838f21..3aa1e3643c6 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -4,6 +4,7 @@ async_timeout==3.0.0 attrs==18.1.0 certifi>=2018.04.16 jinja2>=2.10 +PyJWT==1.6.4 pip>=8.0.3 pytz>=2018.04 pyyaml>=3.13,<4 diff --git a/requirements_all.txt b/requirements_all.txt index 0e6d7e1ac07..3f50e50d19a 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -5,6 +5,7 @@ async_timeout==3.0.0 attrs==18.1.0 certifi>=2018.04.16 jinja2>=2.10 +PyJWT==1.6.4 pip>=8.0.3 pytz>=2018.04 pyyaml>=3.13,<4 diff --git a/setup.py b/setup.py index b319df9067d..bd1e70aa8ae 100755 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ REQUIRES = [ 'attrs==18.1.0', 'certifi>=2018.04.16', 'jinja2>=2.10', + 'PyJWT==1.6.4', 'pip>=8.0.3', 'pytz>=2018.04', 'pyyaml>=3.13,<4', diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index cad4bbdbd71..da5daca7cf6 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage): }) user = await manager.async_get_or_create_user(step['result']) await manager.async_activate_user(user) - refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) - - manager.async_create_access_token(refresh_token) + await manager.async_create_refresh_token(user, CLIENT_ID) await flush_store(manager._store._store) @@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage): assert users[0] == user -def test_access_token_expired(): - """Test that the expired property on access tokens work.""" - refresh_token = auth_models.RefreshToken( - user=None, - client_id='bla' - ) - - access_token = auth_models.AccessToken( - refresh_token=refresh_token - ) - - assert access_token.expired is False - - with patch('homeassistant.util.dt.utcnow', - return_value=dt_util.utcnow() + - auth_const.ACCESS_TOKEN_EXPIRATION): - assert access_token.expired is True - - almost_exp = \ - dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1) - with patch('homeassistant.util.dt.utcnow', return_value=almost_exp): - assert access_token.expired is False - - async def test_cannot_retrieve_expired_access_token(hass): """Test that we cannot retrieve expired access tokens.""" manager = await auth.auth_manager_from_config(hass, []) @@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass): assert refresh_token.client_id == CLIENT_ID access_token = manager.async_create_access_token(refresh_token) - assert manager.async_get_access_token(access_token.token) is access_token + assert ( + await manager.async_validate_access_token(access_token) + is refresh_token + ) with patch('homeassistant.util.dt.utcnow', - return_value=dt_util.utcnow() + - auth_const.ACCESS_TOKEN_EXPIRATION): - assert manager.async_get_access_token(access_token.token) is None + return_value=dt_util.utcnow() - + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)): + access_token = manager.async_create_access_token(refresh_token) - # Even with unpatched time, it should have been removed from manager - assert manager.async_get_access_token(access_token.token) is None + assert ( + await manager.async_validate_access_token(access_token) + is None + ) async def test_generating_system_user(hass): diff --git a/tests/common.py b/tests/common.py index df333cca735..81e4774ccd4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -314,12 +314,18 @@ def mock_registry(hass, mock_entries=None): class MockUser(auth_models.User): """Mock a user in Home Assistant.""" - def __init__(self, id='mock-id', is_owner=False, is_active=True, + def __init__(self, id=None, is_owner=False, is_active=True, name='Mock User', system_generated=False): """Initialize mock user.""" - super().__init__( - id=id, is_owner=is_owner, is_active=is_active, name=name, - system_generated=system_generated) + kwargs = { + 'is_owner': is_owner, + 'is_active': is_active, + 'name': name, + 'system_generated': system_generated + } + if id is not None: + kwargs['id'] = id + super().__init__(**kwargs) def add_to_hass(self, hass): """Test helper to add entry to hass.""" diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index eea768c96a0..f1a1bb5bd3c 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -44,7 +44,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): assert resp.status == 200 tokens = await resp.json() - assert hass.auth.async_get_access_token(tokens['access_token']) is not None + assert ( + await hass.auth.async_validate_access_token(tokens['access_token']) + is not None + ) # Use refresh token to get more tokens. resp = await client.post('/auth/token', data={ @@ -56,7 +59,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): assert resp.status == 200 tokens = await resp.json() assert 'refresh_token' not in tokens - assert hass.auth.async_get_access_token(tokens['access_token']) is not None + assert ( + await hass.auth.async_validate_access_token(tokens['access_token']) + is not None + ) # Test using access token to hit API. resp = await client.get('/api/') @@ -98,7 +104,9 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token): } }) - user = hass_access_token.refresh_token.user + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + user = refresh_token.user credential = Credentials(auth_provider_type='homeassistant', auth_provider_id=None, data={}, id='test-id') @@ -169,7 +177,10 @@ async def test_refresh_token_system_generated(hass, aiohttp_client): assert resp.status == 200 tokens = await resp.json() - assert hass.auth.async_get_access_token(tokens['access_token']) is not None + assert ( + await hass.auth.async_validate_access_token(tokens['access_token']) + is not None + ) async def test_refresh_token_different_client_id(hass, aiohttp_client): @@ -208,4 +219,7 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client): assert resp.status == 200 tokens = await resp.json() - assert hass.auth.async_get_access_token(tokens['access_token']) is not None + assert ( + await hass.auth.async_validate_access_token(tokens['access_token']) + is not None + ) diff --git a/tests/components/auth/test_init_link_user.py b/tests/components/auth/test_init_link_user.py index 13515db87fa..e209e0ee856 100644 --- a/tests/components/auth/test_init_link_user.py +++ b/tests/components/auth/test_init_link_user.py @@ -52,7 +52,7 @@ async def async_get_code(hass, aiohttp_client): 'user': user, 'code': step['result'], 'client': client, - 'access_token': access_token.token, + 'access_token': access_token, } diff --git a/tests/components/config/test_auth.py b/tests/components/config/test_auth.py index fe8f351955f..cd04eedf08e 100644 --- a/tests/components/config/test_auth.py +++ b/tests/components/config/test_auth.py @@ -122,11 +122,13 @@ async def test_delete_unable_self_account(hass, hass_ws_client, hass_access_token): """Test we cannot delete our own account.""" client = await hass_ws_client(hass, hass_access_token) + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) await client.send_json({ 'id': 5, 'type': auth_config.WS_TYPE_DELETE, - 'user_id': hass_access_token.refresh_token.user.id, + 'user_id': refresh_token.user.id, }) result = await client.receive_json() @@ -137,7 +139,9 @@ async def test_delete_unable_self_account(hass, hass_ws_client, async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token): """Test we cannot delete an unknown user.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True await client.send_json({ 'id': 5, @@ -153,7 +157,9 @@ async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token): async def test_delete(hass, hass_ws_client, hass_access_token): """Test delete command works.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True test_user = MockUser( id='efg', ).add_to_hass(hass) @@ -174,7 +180,9 @@ async def test_delete(hass, hass_ws_client, hass_access_token): async def test_create(hass, hass_ws_client, hass_access_token): """Test create command works.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True assert len(await hass.auth.async_get_users()) == 1 diff --git a/tests/components/config/test_auth_provider_homeassistant.py b/tests/components/config/test_auth_provider_homeassistant.py index cd2cbc44539..a374083c2ab 100644 --- a/tests/components/config/test_auth_provider_homeassistant.py +++ b/tests/components/config/test_auth_provider_homeassistant.py @@ -9,7 +9,7 @@ from tests.common import MockUser, register_auth_provider @pytest.fixture(autouse=True) -def setup_config(hass, aiohttp_client): +def setup_config(hass): """Fixture that sets up the auth provider homeassistant module.""" hass.loop.run_until_complete(register_auth_provider(hass, { 'type': 'homeassistant' @@ -22,7 +22,9 @@ async def test_create_auth_system_generated_user(hass, hass_access_token, """Test we can't add auth to system generated users.""" system_user = MockUser(system_generated=True).add_to_hass(hass) client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True await client.send_json({ 'id': 5, @@ -47,7 +49,9 @@ async def test_create_auth_unknown_user(hass_ws_client, hass, hass_access_token): """Test create pointing at unknown user.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True await client.send_json({ 'id': 5, @@ -86,7 +90,9 @@ async def test_create_auth(hass, hass_ws_client, hass_access_token, """Test create auth command works.""" client = await hass_ws_client(hass, hass_access_token) user = MockUser().add_to_hass(hass) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True assert len(user.credentials) == 0 @@ -117,7 +123,9 @@ async def test_create_auth_duplicate_username(hass, hass_ws_client, """Test we can't create auth with a duplicate username.""" client = await hass_ws_client(hass, hass_access_token) user = MockUser().add_to_hass(hass) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True hass_storage[prov_ha.STORAGE_KEY] = { 'version': 1, @@ -145,7 +153,9 @@ async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage, hass_access_token): """Test deleting an auth without being connected to a user.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True hass_storage[prov_ha.STORAGE_KEY] = { 'version': 1, @@ -171,7 +181,9 @@ async def test_delete_removes_credential(hass, hass_ws_client, hass_access_token, hass_storage): """Test deleting auth that is connected to a user.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True user = MockUser().add_to_hass(hass) user.credentials.append( @@ -216,7 +228,9 @@ async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token): async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token): """Test trying to delete an unknown auth username.""" client = await hass_ws_client(hass, hass_access_token) - hass_access_token.refresh_token.user.is_owner = True + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_owner = True await client.send_json({ 'id': 5, @@ -240,7 +254,9 @@ async def test_change_password(hass, hass_ws_client, hass_access_token): 'username': 'test-user' }) - user = hass_access_token.refresh_token.user + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + user = refresh_token.user await hass.auth.async_link_user(user, credentials) client = await hass_ws_client(hass, hass_access_token) @@ -268,7 +284,9 @@ async def test_change_password_wrong_pw(hass, hass_ws_client, 'username': 'test-user' }) - user = hass_access_token.refresh_token.user + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + user = refresh_token.user await hass.auth.async_link_user(user, credentials) client = await hass_ws_client(hass, hass_access_token) diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 5f6a17a4101..bb9b643296e 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -28,7 +28,7 @@ def hass_ws_client(aiohttp_client): await websocket.send_json({ 'type': websocket_api.TYPE_AUTH, - 'access_token': access_token.token + 'access_token': access_token }) auth_ok = await websocket.receive_json() diff --git a/tests/components/hassio/test_init.py b/tests/components/hassio/test_init.py index b1975669731..4fd59dd3f7a 100644 --- a/tests/components/hassio/test_init.py +++ b/tests/components/hassio/test_init.py @@ -106,7 +106,11 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, ) assert hassio_user is not None assert hassio_user.system_generated - assert refresh_token in hassio_user.refresh_tokens + for token in hassio_user.refresh_tokens.values(): + if token.token == refresh_token: + break + else: + assert False, 'refresh token not found' async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock, diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index 31cba79a6c8..8e7a62e2e9f 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -156,9 +156,9 @@ async def test_access_with_trusted_ip(app2, aiohttp_client): async def test_auth_active_access_with_access_token_in_header( - app, aiohttp_client, hass_access_token): + hass, app, aiohttp_client, hass_access_token): """Test access with access token in header.""" - token = hass_access_token.token + token = hass_access_token setup_auth(app, [], True, api_password=None) client = await aiohttp_client(app) @@ -182,7 +182,9 @@ async def test_auth_active_access_with_access_token_in_header( '/', headers={'Authorization': 'BEARER {}'.format(token)}) assert req.status == 401 - hass_access_token.refresh_token.user.is_active = False + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_active = False req = await client.get( '/', headers={'Authorization': 'Bearer {}'.format(token)}) assert req.status == 401 diff --git a/tests/components/test_api.py b/tests/components/test_api.py index 09dc27e97c1..2be1168b86a 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -448,13 +448,15 @@ async def test_api_fire_event_context(hass, mock_api_client, await mock_api_client.post( const.URL_API_EVENTS_EVENT.format("test.event"), headers={ - 'authorization': 'Bearer {}'.format(hass_access_token.token) + 'authorization': 'Bearer {}'.format(hass_access_token) }) await hass.async_block_till_done() + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + assert len(test_value) == 1 - assert test_value[0].context.user_id == \ - hass_access_token.refresh_token.user.id + assert test_value[0].context.user_id == refresh_token.user.id async def test_api_call_service_context(hass, mock_api_client, @@ -465,12 +467,15 @@ async def test_api_call_service_context(hass, mock_api_client, await mock_api_client.post( '/api/services/test_domain/test_service', headers={ - 'authorization': 'Bearer {}'.format(hass_access_token.token) + 'authorization': 'Bearer {}'.format(hass_access_token) }) await hass.async_block_till_done() + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + assert len(calls) == 1 - assert calls[0].context.user_id == hass_access_token.refresh_token.user.id + assert calls[0].context.user_id == refresh_token.user.id async def test_api_set_state_context(hass, mock_api_client, hass_access_token): @@ -481,8 +486,11 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token): 'state': 'on' }, headers={ - 'authorization': 'Bearer {}'.format(hass_access_token.token) + 'authorization': 'Bearer {}'.format(hass_access_token) }) + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + state = hass.states.get('light.kitchen') - assert state.context.user_id == hass_access_token.refresh_token.user.id + assert state.context.user_id == refresh_token.user.id diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py index 1fac1af9f64..199a9d804f8 100644 --- a/tests/components/test_websocket_api.py +++ b/tests/components/test_websocket_api.py @@ -334,7 +334,7 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): await ws.send_json({ 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token.token + 'access_token': hass_access_token }) auth_msg = await ws.receive_json() @@ -344,7 +344,9 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): async def test_auth_active_user_inactive(hass, aiohttp_client, hass_access_token): """Test authenticating with a token.""" - hass_access_token.refresh_token.user.is_active = False + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_active = False assert await async_setup_component(hass, 'websocket_api', { 'http': { 'api_password': API_PASSWORD @@ -361,7 +363,7 @@ async def test_auth_active_user_inactive(hass, aiohttp_client, await ws.send_json({ 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token.token + 'access_token': hass_access_token }) auth_msg = await ws.receive_json() @@ -465,7 +467,7 @@ async def test_call_service_context_with_user(hass, aiohttp_client, await ws.send_json({ 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token.token + 'access_token': hass_access_token }) auth_msg = await ws.receive_json() @@ -484,12 +486,15 @@ async def test_call_service_context_with_user(hass, aiohttp_client, msg = await ws.receive_json() assert msg['success'] + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + assert len(calls) == 1 call = calls[0] assert call.domain == 'domain_test' assert call.service == 'test_service' assert call.data == {'hello': 'world'} - assert call.context.user_id == hass_access_token.refresh_token.user.id + assert call.context.user_id == refresh_token.user.id async def test_call_service_context_no_user(hass, aiohttp_client):