Use JWT for access tokens (#15972)

* Use JWT for access tokens

* Update requirements

* Improvements
This commit is contained in:
Paulus Schoutsen 2018-08-14 21:14:12 +02:00 committed by GitHub
parent ee5d49a033
commit e776f88eec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 203 additions and 155 deletions

View file

@ -4,10 +4,12 @@ import logging
from collections import OrderedDict from collections import OrderedDict
from typing import List, Awaitable from typing import List, Awaitable
import jwt
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.core import callback, HomeAssistant from homeassistant.core import callback, HomeAssistant
from homeassistant.util import dt as dt_util
from . import models
from . import auth_store from . import auth_store
from .providers import auth_provider_from_config from .providers import auth_provider_from_config
@ -54,7 +56,6 @@ class AuthManager:
self.login_flow = data_entry_flow.FlowManager( self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow, hass, self._async_create_login_flow,
self._async_finish_login_flow) self._async_finish_login_flow)
self._access_tokens = OrderedDict()
@property @property
def active(self): def active(self):
@ -181,35 +182,56 @@ class AuthManager:
return await self._store.async_create_refresh_token(user, client_id) return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token): async def async_get_refresh_token(self, token_id):
"""Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token(self, token):
"""Get refresh token by token.""" """Get refresh token by token."""
return await self._store.async_get_refresh_token(token) return await self._store.async_get_refresh_token_by_token(token)
@callback @callback
def async_create_access_token(self, refresh_token): def async_create_access_token(self, refresh_token):
"""Create a new access token.""" """Create a new access token."""
access_token = models.AccessToken(refresh_token=refresh_token) # pylint: disable=no-self-use
self._access_tokens[access_token.token] = access_token return jwt.encode({
return access_token 'iss': refresh_token.id,
'iat': dt_util.utcnow(),
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
}, refresh_token.jwt_key, algorithm='HS256').decode()
@callback async def async_validate_access_token(self, token):
def async_get_access_token(self, token): """Return if an access token is valid."""
"""Get an access token.""" try:
tkn = self._access_tokens.get(token) unverif_claims = jwt.decode(token, verify=False)
except jwt.InvalidTokenError:
if tkn is None:
_LOGGER.debug('Attempt to get non-existing access token')
return None return None
if tkn.expired or not tkn.refresh_token.user.is_active: refresh_token = await self.async_get_refresh_token(
if tkn.expired: unverif_claims.get('iss'))
_LOGGER.debug('Attempt to get expired access token')
else: if refresh_token is None:
_LOGGER.debug('Attempt to get access token for inactive user') jwt_key = ''
self._access_tokens.pop(token) issuer = ''
else:
jwt_key = refresh_token.jwt_key
issuer = refresh_token.id
try:
jwt.decode(
token,
jwt_key,
leeway=10,
issuer=issuer,
algorithms=['HS256']
)
except jwt.InvalidTokenError:
return None return None
return tkn if not refresh_token.user.is_active:
return None
return refresh_token
async def _async_create_login_flow(self, handler, *, context, data): async def _async_create_login_flow(self, handler, *, context, data):
"""Create a login flow.""" """Create a login flow."""

View file

@ -1,6 +1,7 @@
"""Storage for auth models.""" """Storage for auth models."""
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
import hmac
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -110,22 +111,36 @@ class AuthStore:
async def async_create_refresh_token(self, user, client_id=None): async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user.""" """Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id) refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token user.refresh_tokens[refresh_token.id] = refresh_token
await self.async_save() await self.async_save()
return refresh_token return refresh_token
async def async_get_refresh_token(self, token): async def async_get_refresh_token(self, token_id):
"""Get refresh token by token.""" """Get refresh token by id."""
if self._users is None: if self._users is None:
await self.async_load() await self.async_load()
for user in self._users.values(): for user in self._users.values():
refresh_token = user.refresh_tokens.get(token) refresh_token = user.refresh_tokens.get(token_id)
if refresh_token is not None: if refresh_token is not None:
return refresh_token return refresh_token
return None return None
async def async_get_refresh_token_by_token(self, token):
"""Get refresh token by token."""
if self._users is None:
await self.async_load()
found = None
for user in self._users.values():
for refresh_token in user.refresh_tokens.values():
if hmac.compare_digest(refresh_token.token, token):
found = refresh_token
return found
async def async_load(self): async def async_load(self):
"""Load the users.""" """Load the users."""
data = await self._store.async_load() data = await self._store.async_load()
@ -153,9 +168,11 @@ class AuthStore:
data=cred_dict['data'], data=cred_dict['data'],
)) ))
refresh_tokens = OrderedDict()
for rt_dict in data['refresh_tokens']: for rt_dict in data['refresh_tokens']:
# Filter out the old keys that don't have jwt_key (pre-0.76)
if 'jwt_key' not in rt_dict:
continue
token = models.RefreshToken( token = models.RefreshToken(
id=rt_dict['id'], id=rt_dict['id'],
user=users[rt_dict['user_id']], user=users[rt_dict['user_id']],
@ -164,18 +181,9 @@ class AuthStore:
access_token_expiration=timedelta( access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']), seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'], token=rt_dict['token'],
jwt_key=rt_dict['jwt_key']
) )
refresh_tokens[token.id] = token users[rt_dict['user_id']].refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token
for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = models.AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
self._users = users self._users = users
@ -213,27 +221,15 @@ class AuthStore:
'access_token_expiration': 'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(), refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token, 'token': refresh_token.token,
'jwt_key': refresh_token.jwt_key,
} }
for user in self._users.values() for user in self._users.values()
for refresh_token in user.refresh_tokens.values() for refresh_token in user.refresh_tokens.values()
] ]
access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
data = { data = {
'users': users, 'users': users,
'credentials': credentials, 'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens, 'refresh_tokens': refresh_tokens,
} }

View file

@ -39,26 +39,8 @@ class RefreshToken:
default=ACCESS_TOKEN_EXPIRATION) default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str, token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64))) default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False) jwt_key = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
@attr.s(slots=True)
class AccessToken:
"""Access token to access the API.
These will only ever be stored in memory and not be persisted.
"""
refresh_token = attr.ib(type=RefreshToken)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
token = attr.ib(type=str,
default=attr.Factory(generate_secret))
@property
def expired(self):
"""Return if this token has expired."""
expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
@attr.s(slots=True) @attr.s(slots=True)

View file

@ -155,7 +155,7 @@ class GrantTokenView(HomeAssistantView):
access_token = hass.auth.async_create_access_token(refresh_token) access_token = hass.auth.async_create_access_token(refresh_token)
return self.json({ return self.json({
'access_token': access_token.token, 'access_token': access_token,
'token_type': 'Bearer', 'token_type': 'Bearer',
'refresh_token': refresh_token.token, 'refresh_token': refresh_token.token,
'expires_in': 'expires_in':
@ -178,7 +178,7 @@ class GrantTokenView(HomeAssistantView):
'error': 'invalid_request', 'error': 'invalid_request',
}, status_code=400) }, status_code=400)
refresh_token = await hass.auth.async_get_refresh_token(token) refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
if refresh_token is None: if refresh_token is None:
return self.json({ return self.json({
@ -193,7 +193,7 @@ class GrantTokenView(HomeAssistantView):
access_token = hass.auth.async_create_access_token(refresh_token) access_token = hass.auth.async_create_access_token(refresh_token)
return self.json({ return self.json({
'access_token': access_token.token, 'access_token': access_token,
'token_type': 'Bearer', 'token_type': 'Bearer',
'expires_in': 'expires_in':
int(refresh_token.access_token_expiration.total_seconds()), int(refresh_token.access_token_expiration.total_seconds()),

View file

@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
if auth_type == 'Bearer': if auth_type == 'Bearer':
hass = request.app['hass'] hass = request.app['hass']
access_token = hass.auth.async_get_access_token(auth_val) refresh_token = await hass.auth.async_validate_access_token(auth_val)
if access_token is None: if refresh_token is None:
return False return False
request['hass_user'] = access_token.refresh_token.user request['hass_user'] = refresh_token.user
return True return True
if auth_type == 'Basic' and api_password is not None: if auth_type == 'Basic' and api_password is not None:

View file

@ -355,11 +355,12 @@ class ActiveConnection:
if self.hass.auth.active and 'access_token' in msg: if self.hass.auth.active and 'access_token' in msg:
self.debug("Received access_token") self.debug("Received access_token")
token = self.hass.auth.async_get_access_token( refresh_token = \
msg['access_token']) await self.hass.auth.async_validate_access_token(
authenticated = token is not None msg['access_token'])
authenticated = refresh_token is not None
if authenticated: if authenticated:
request['hass_user'] = token.refresh_token.user request['hass_user'] = refresh_token.user
elif ((not self.hass.auth.active or elif ((not self.hass.auth.active or
self.hass.auth.support_legacy) and self.hass.auth.support_legacy) and

View file

@ -4,6 +4,7 @@ async_timeout==3.0.0
attrs==18.1.0 attrs==18.1.0
certifi>=2018.04.16 certifi>=2018.04.16
jinja2>=2.10 jinja2>=2.10
PyJWT==1.6.4
pip>=8.0.3 pip>=8.0.3
pytz>=2018.04 pytz>=2018.04
pyyaml>=3.13,<4 pyyaml>=3.13,<4

View file

@ -5,6 +5,7 @@ async_timeout==3.0.0
attrs==18.1.0 attrs==18.1.0
certifi>=2018.04.16 certifi>=2018.04.16
jinja2>=2.10 jinja2>=2.10
PyJWT==1.6.4
pip>=8.0.3 pip>=8.0.3
pytz>=2018.04 pytz>=2018.04
pyyaml>=3.13,<4 pyyaml>=3.13,<4

View file

@ -38,6 +38,7 @@ REQUIRES = [
'attrs==18.1.0', 'attrs==18.1.0',
'certifi>=2018.04.16', 'certifi>=2018.04.16',
'jinja2>=2.10', 'jinja2>=2.10',
'PyJWT==1.6.4',
'pip>=8.0.3', 'pip>=8.0.3',
'pytz>=2018.04', 'pytz>=2018.04',
'pyyaml>=3.13,<4', 'pyyaml>=3.13,<4',

View file

@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage):
}) })
user = await manager.async_get_or_create_user(step['result']) user = await manager.async_get_or_create_user(step['result'])
await manager.async_activate_user(user) await manager.async_activate_user(user)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) await manager.async_create_refresh_token(user, CLIENT_ID)
manager.async_create_access_token(refresh_token)
await flush_store(manager._store._store) await flush_store(manager._store._store)
@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage):
assert users[0] == user assert users[0] == user
def test_access_token_expired():
"""Test that the expired property on access tokens work."""
refresh_token = auth_models.RefreshToken(
user=None,
client_id='bla'
)
access_token = auth_models.AccessToken(
refresh_token=refresh_token
)
assert access_token.expired is False
with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() +
auth_const.ACCESS_TOKEN_EXPIRATION):
assert access_token.expired is True
almost_exp = \
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
assert access_token.expired is False
async def test_cannot_retrieve_expired_access_token(hass): async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens.""" """Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, []) manager = await auth.auth_manager_from_config(hass, [])
@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass):
assert refresh_token.client_id == CLIENT_ID assert refresh_token.client_id == CLIENT_ID
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token assert (
await manager.async_validate_access_token(access_token)
is refresh_token
)
with patch('homeassistant.util.dt.utcnow', with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() + return_value=dt_util.utcnow() -
auth_const.ACCESS_TOKEN_EXPIRATION): auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)):
assert manager.async_get_access_token(access_token.token) is None access_token = manager.async_create_access_token(refresh_token)
# Even with unpatched time, it should have been removed from manager assert (
assert manager.async_get_access_token(access_token.token) is None await manager.async_validate_access_token(access_token)
is None
)
async def test_generating_system_user(hass): async def test_generating_system_user(hass):

View file

@ -314,12 +314,18 @@ def mock_registry(hass, mock_entries=None):
class MockUser(auth_models.User): class MockUser(auth_models.User):
"""Mock a user in Home Assistant.""" """Mock a user in Home Assistant."""
def __init__(self, id='mock-id', is_owner=False, is_active=True, def __init__(self, id=None, is_owner=False, is_active=True,
name='Mock User', system_generated=False): name='Mock User', system_generated=False):
"""Initialize mock user.""" """Initialize mock user."""
super().__init__( kwargs = {
id=id, is_owner=is_owner, is_active=is_active, name=name, 'is_owner': is_owner,
system_generated=system_generated) 'is_active': is_active,
'name': name,
'system_generated': system_generated
}
if id is not None:
kwargs['id'] = id
super().__init__(**kwargs)
def add_to_hass(self, hass): def add_to_hass(self, hass):
"""Test helper to add entry to hass.""" """Test helper to add entry to hass."""

View file

@ -44,7 +44,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
assert resp.status == 200 assert resp.status == 200
tokens = await resp.json() tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None assert (
await hass.auth.async_validate_access_token(tokens['access_token'])
is not None
)
# Use refresh token to get more tokens. # Use refresh token to get more tokens.
resp = await client.post('/auth/token', data={ resp = await client.post('/auth/token', data={
@ -56,7 +59,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
assert resp.status == 200 assert resp.status == 200
tokens = await resp.json() tokens = await resp.json()
assert 'refresh_token' not in tokens assert 'refresh_token' not in tokens
assert hass.auth.async_get_access_token(tokens['access_token']) is not None assert (
await hass.auth.async_validate_access_token(tokens['access_token'])
is not None
)
# Test using access token to hit API. # Test using access token to hit API.
resp = await client.get('/api/') resp = await client.get('/api/')
@ -98,7 +104,9 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
} }
}) })
user = hass_access_token.refresh_token.user refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
user = refresh_token.user
credential = Credentials(auth_provider_type='homeassistant', credential = Credentials(auth_provider_type='homeassistant',
auth_provider_id=None, auth_provider_id=None,
data={}, id='test-id') data={}, id='test-id')
@ -169,7 +177,10 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
assert resp.status == 200 assert resp.status == 200
tokens = await resp.json() tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None assert (
await hass.auth.async_validate_access_token(tokens['access_token'])
is not None
)
async def test_refresh_token_different_client_id(hass, aiohttp_client): async def test_refresh_token_different_client_id(hass, aiohttp_client):
@ -208,4 +219,7 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
assert resp.status == 200 assert resp.status == 200
tokens = await resp.json() tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None assert (
await hass.auth.async_validate_access_token(tokens['access_token'])
is not None
)

View file

@ -52,7 +52,7 @@ async def async_get_code(hass, aiohttp_client):
'user': user, 'user': user,
'code': step['result'], 'code': step['result'],
'client': client, 'client': client,
'access_token': access_token.token, 'access_token': access_token,
} }

View file

@ -122,11 +122,13 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
hass_access_token): hass_access_token):
"""Test we cannot delete our own account.""" """Test we cannot delete our own account."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
await client.send_json({ await client.send_json({
'id': 5, 'id': 5,
'type': auth_config.WS_TYPE_DELETE, 'type': auth_config.WS_TYPE_DELETE,
'user_id': hass_access_token.refresh_token.user.id, 'user_id': refresh_token.user.id,
}) })
result = await client.receive_json() result = await client.receive_json()
@ -137,7 +139,9 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token): async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
"""Test we cannot delete an unknown user.""" """Test we cannot delete an unknown user."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
await client.send_json({ await client.send_json({
'id': 5, 'id': 5,
@ -153,7 +157,9 @@ async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
async def test_delete(hass, hass_ws_client, hass_access_token): async def test_delete(hass, hass_ws_client, hass_access_token):
"""Test delete command works.""" """Test delete command works."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
test_user = MockUser( test_user = MockUser(
id='efg', id='efg',
).add_to_hass(hass) ).add_to_hass(hass)
@ -174,7 +180,9 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
async def test_create(hass, hass_ws_client, hass_access_token): async def test_create(hass, hass_ws_client, hass_access_token):
"""Test create command works.""" """Test create command works."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
assert len(await hass.auth.async_get_users()) == 1 assert len(await hass.auth.async_get_users()) == 1

View file

@ -9,7 +9,7 @@ from tests.common import MockUser, register_auth_provider
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_config(hass, aiohttp_client): def setup_config(hass):
"""Fixture that sets up the auth provider homeassistant module.""" """Fixture that sets up the auth provider homeassistant module."""
hass.loop.run_until_complete(register_auth_provider(hass, { hass.loop.run_until_complete(register_auth_provider(hass, {
'type': 'homeassistant' 'type': 'homeassistant'
@ -22,7 +22,9 @@ async def test_create_auth_system_generated_user(hass, hass_access_token,
"""Test we can't add auth to system generated users.""" """Test we can't add auth to system generated users."""
system_user = MockUser(system_generated=True).add_to_hass(hass) system_user = MockUser(system_generated=True).add_to_hass(hass)
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
await client.send_json({ await client.send_json({
'id': 5, 'id': 5,
@ -47,7 +49,9 @@ async def test_create_auth_unknown_user(hass_ws_client, hass,
hass_access_token): hass_access_token):
"""Test create pointing at unknown user.""" """Test create pointing at unknown user."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
await client.send_json({ await client.send_json({
'id': 5, 'id': 5,
@ -86,7 +90,9 @@ async def test_create_auth(hass, hass_ws_client, hass_access_token,
"""Test create auth command works.""" """Test create auth command works."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
user = MockUser().add_to_hass(hass) user = MockUser().add_to_hass(hass)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
assert len(user.credentials) == 0 assert len(user.credentials) == 0
@ -117,7 +123,9 @@ async def test_create_auth_duplicate_username(hass, hass_ws_client,
"""Test we can't create auth with a duplicate username.""" """Test we can't create auth with a duplicate username."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
user = MockUser().add_to_hass(hass) user = MockUser().add_to_hass(hass)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
hass_storage[prov_ha.STORAGE_KEY] = { hass_storage[prov_ha.STORAGE_KEY] = {
'version': 1, 'version': 1,
@ -145,7 +153,9 @@ async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
hass_access_token): hass_access_token):
"""Test deleting an auth without being connected to a user.""" """Test deleting an auth without being connected to a user."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
hass_storage[prov_ha.STORAGE_KEY] = { hass_storage[prov_ha.STORAGE_KEY] = {
'version': 1, 'version': 1,
@ -171,7 +181,9 @@ async def test_delete_removes_credential(hass, hass_ws_client,
hass_access_token, hass_storage): hass_access_token, hass_storage):
"""Test deleting auth that is connected to a user.""" """Test deleting auth that is connected to a user."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
user = MockUser().add_to_hass(hass) user = MockUser().add_to_hass(hass)
user.credentials.append( user.credentials.append(
@ -216,7 +228,9 @@ async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token): async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
"""Test trying to delete an unknown auth username.""" """Test trying to delete an unknown auth username."""
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
hass_access_token.refresh_token.user.is_owner = True refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_owner = True
await client.send_json({ await client.send_json({
'id': 5, 'id': 5,
@ -240,7 +254,9 @@ async def test_change_password(hass, hass_ws_client, hass_access_token):
'username': 'test-user' 'username': 'test-user'
}) })
user = hass_access_token.refresh_token.user refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
user = refresh_token.user
await hass.auth.async_link_user(user, credentials) await hass.auth.async_link_user(user, credentials)
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)
@ -268,7 +284,9 @@ async def test_change_password_wrong_pw(hass, hass_ws_client,
'username': 'test-user' 'username': 'test-user'
}) })
user = hass_access_token.refresh_token.user refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
user = refresh_token.user
await hass.auth.async_link_user(user, credentials) await hass.auth.async_link_user(user, credentials)
client = await hass_ws_client(hass, hass_access_token) client = await hass_ws_client(hass, hass_access_token)

View file

@ -28,7 +28,7 @@ def hass_ws_client(aiohttp_client):
await websocket.send_json({ await websocket.send_json({
'type': websocket_api.TYPE_AUTH, 'type': websocket_api.TYPE_AUTH,
'access_token': access_token.token 'access_token': access_token
}) })
auth_ok = await websocket.receive_json() auth_ok = await websocket.receive_json()

View file

@ -106,7 +106,11 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
) )
assert hassio_user is not None assert hassio_user is not None
assert hassio_user.system_generated assert hassio_user.system_generated
assert refresh_token in hassio_user.refresh_tokens for token in hassio_user.refresh_tokens.values():
if token.token == refresh_token:
break
else:
assert False, 'refresh token not found'
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock, async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,

View file

@ -156,9 +156,9 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
async def test_auth_active_access_with_access_token_in_header( async def test_auth_active_access_with_access_token_in_header(
app, aiohttp_client, hass_access_token): hass, app, aiohttp_client, hass_access_token):
"""Test access with access token in header.""" """Test access with access token in header."""
token = hass_access_token.token token = hass_access_token
setup_auth(app, [], True, api_password=None) setup_auth(app, [], True, api_password=None)
client = await aiohttp_client(app) client = await aiohttp_client(app)
@ -182,7 +182,9 @@ async def test_auth_active_access_with_access_token_in_header(
'/', headers={'Authorization': 'BEARER {}'.format(token)}) '/', headers={'Authorization': 'BEARER {}'.format(token)})
assert req.status == 401 assert req.status == 401
hass_access_token.refresh_token.user.is_active = False refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
req = await client.get( req = await client.get(
'/', headers={'Authorization': 'Bearer {}'.format(token)}) '/', headers={'Authorization': 'Bearer {}'.format(token)})
assert req.status == 401 assert req.status == 401

View file

@ -448,13 +448,15 @@ async def test_api_fire_event_context(hass, mock_api_client,
await mock_api_client.post( await mock_api_client.post(
const.URL_API_EVENTS_EVENT.format("test.event"), const.URL_API_EVENTS_EVENT.format("test.event"),
headers={ headers={
'authorization': 'Bearer {}'.format(hass_access_token.token) 'authorization': 'Bearer {}'.format(hass_access_token)
}) })
await hass.async_block_till_done() await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(test_value) == 1 assert len(test_value) == 1
assert test_value[0].context.user_id == \ assert test_value[0].context.user_id == refresh_token.user.id
hass_access_token.refresh_token.user.id
async def test_api_call_service_context(hass, mock_api_client, async def test_api_call_service_context(hass, mock_api_client,
@ -465,12 +467,15 @@ async def test_api_call_service_context(hass, mock_api_client,
await mock_api_client.post( await mock_api_client.post(
'/api/services/test_domain/test_service', '/api/services/test_domain/test_service',
headers={ headers={
'authorization': 'Bearer {}'.format(hass_access_token.token) 'authorization': 'Bearer {}'.format(hass_access_token)
}) })
await hass.async_block_till_done() await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context.user_id == hass_access_token.refresh_token.user.id assert calls[0].context.user_id == refresh_token.user.id
async def test_api_set_state_context(hass, mock_api_client, hass_access_token): async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
@ -481,8 +486,11 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
'state': 'on' 'state': 'on'
}, },
headers={ headers={
'authorization': 'Bearer {}'.format(hass_access_token.token) 'authorization': 'Bearer {}'.format(hass_access_token)
}) })
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
state = hass.states.get('light.kitchen') state = hass.states.get('light.kitchen')
assert state.context.user_id == hass_access_token.refresh_token.user.id assert state.context.user_id == refresh_token.user.id

View file

@ -334,7 +334,7 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': wapi.TYPE_AUTH,
'access_token': hass_access_token.token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
@ -344,7 +344,9 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
async def test_auth_active_user_inactive(hass, aiohttp_client, async def test_auth_active_user_inactive(hass, aiohttp_client,
hass_access_token): hass_access_token):
"""Test authenticating with a token.""" """Test authenticating with a token."""
hass_access_token.refresh_token.user.is_active = False refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
assert await async_setup_component(hass, 'websocket_api', { assert await async_setup_component(hass, 'websocket_api', {
'http': { 'http': {
'api_password': API_PASSWORD 'api_password': API_PASSWORD
@ -361,7 +363,7 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': wapi.TYPE_AUTH,
'access_token': hass_access_token.token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
@ -465,7 +467,7 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': wapi.TYPE_AUTH,
'access_token': hass_access_token.token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
@ -484,12 +486,15 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
msg = await ws.receive_json() msg = await ws.receive_json()
assert msg['success'] assert msg['success']
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1 assert len(calls) == 1
call = calls[0] call = calls[0]
assert call.domain == 'domain_test' assert call.domain == 'domain_test'
assert call.service == 'test_service' assert call.service == 'test_service'
assert call.data == {'hello': 'world'} assert call.data == {'hello': 'world'}
assert call.context.user_id == hass_access_token.refresh_token.user.id assert call.context.user_id == refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client): async def test_call_service_context_no_user(hass, aiohttp_client):