Use JWT for access tokens (#15972)
* Use JWT for access tokens * Update requirements * Improvements
This commit is contained in:
parent
ee5d49a033
commit
e776f88eec
20 changed files with 203 additions and 155 deletions
|
@ -4,10 +4,12 @@ import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Awaitable
|
from typing import List, Awaitable
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import callback, HomeAssistant
|
from homeassistant.core import callback, HomeAssistant
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import models
|
|
||||||
from . import auth_store
|
from . import auth_store
|
||||||
from .providers import auth_provider_from_config
|
from .providers import auth_provider_from_config
|
||||||
|
|
||||||
|
@ -54,7 +56,6 @@ class AuthManager:
|
||||||
self.login_flow = data_entry_flow.FlowManager(
|
self.login_flow = data_entry_flow.FlowManager(
|
||||||
hass, self._async_create_login_flow,
|
hass, self._async_create_login_flow,
|
||||||
self._async_finish_login_flow)
|
self._async_finish_login_flow)
|
||||||
self._access_tokens = OrderedDict()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active(self):
|
def active(self):
|
||||||
|
@ -181,35 +182,56 @@ class AuthManager:
|
||||||
|
|
||||||
return await self._store.async_create_refresh_token(user, client_id)
|
return await self._store.async_create_refresh_token(user, client_id)
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
async def async_get_refresh_token(self, token_id):
|
||||||
|
"""Get refresh token by id."""
|
||||||
|
return await self._store.async_get_refresh_token(token_id)
|
||||||
|
|
||||||
|
async def async_get_refresh_token_by_token(self, token):
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
return await self._store.async_get_refresh_token(token)
|
return await self._store.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_access_token(self, refresh_token):
|
def async_create_access_token(self, refresh_token):
|
||||||
"""Create a new access token."""
|
"""Create a new access token."""
|
||||||
access_token = models.AccessToken(refresh_token=refresh_token)
|
# pylint: disable=no-self-use
|
||||||
self._access_tokens[access_token.token] = access_token
|
return jwt.encode({
|
||||||
return access_token
|
'iss': refresh_token.id,
|
||||||
|
'iat': dt_util.utcnow(),
|
||||||
|
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
|
||||||
|
}, refresh_token.jwt_key, algorithm='HS256').decode()
|
||||||
|
|
||||||
@callback
|
async def async_validate_access_token(self, token):
|
||||||
def async_get_access_token(self, token):
|
"""Return if an access token is valid."""
|
||||||
"""Get an access token."""
|
try:
|
||||||
tkn = self._access_tokens.get(token)
|
unverif_claims = jwt.decode(token, verify=False)
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
if tkn is None:
|
|
||||||
_LOGGER.debug('Attempt to get non-existing access token')
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if tkn.expired or not tkn.refresh_token.user.is_active:
|
refresh_token = await self.async_get_refresh_token(
|
||||||
if tkn.expired:
|
unverif_claims.get('iss'))
|
||||||
_LOGGER.debug('Attempt to get expired access token')
|
|
||||||
else:
|
if refresh_token is None:
|
||||||
_LOGGER.debug('Attempt to get access token for inactive user')
|
jwt_key = ''
|
||||||
self._access_tokens.pop(token)
|
issuer = ''
|
||||||
|
else:
|
||||||
|
jwt_key = refresh_token.jwt_key
|
||||||
|
issuer = refresh_token.id
|
||||||
|
|
||||||
|
try:
|
||||||
|
jwt.decode(
|
||||||
|
token,
|
||||||
|
jwt_key,
|
||||||
|
leeway=10,
|
||||||
|
issuer=issuer,
|
||||||
|
algorithms=['HS256']
|
||||||
|
)
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return tkn
|
if not refresh_token.user.is_active:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return refresh_token
|
||||||
|
|
||||||
async def _async_create_login_flow(self, handler, *, context, data):
|
async def _async_create_login_flow(self, handler, *, context, data):
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Storage for auth models."""
|
"""Storage for auth models."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
import hmac
|
||||||
|
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
@ -110,22 +111,36 @@ class AuthStore:
|
||||||
async def async_create_refresh_token(self, user, client_id=None):
|
async def async_create_refresh_token(self, user, client_id=None):
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
async def async_get_refresh_token(self, token_id):
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
for user in self._users.values():
|
for user in self._users.values():
|
||||||
refresh_token = user.refresh_tokens.get(token)
|
refresh_token = user.refresh_tokens.get(token_id)
|
||||||
if refresh_token is not None:
|
if refresh_token is not None:
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def async_get_refresh_token_by_token(self, token):
|
||||||
|
"""Get refresh token by token."""
|
||||||
|
if self._users is None:
|
||||||
|
await self.async_load()
|
||||||
|
|
||||||
|
found = None
|
||||||
|
|
||||||
|
for user in self._users.values():
|
||||||
|
for refresh_token in user.refresh_tokens.values():
|
||||||
|
if hmac.compare_digest(refresh_token.token, token):
|
||||||
|
found = refresh_token
|
||||||
|
|
||||||
|
return found
|
||||||
|
|
||||||
async def async_load(self):
|
async def async_load(self):
|
||||||
"""Load the users."""
|
"""Load the users."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
@ -153,9 +168,11 @@ class AuthStore:
|
||||||
data=cred_dict['data'],
|
data=cred_dict['data'],
|
||||||
))
|
))
|
||||||
|
|
||||||
refresh_tokens = OrderedDict()
|
|
||||||
|
|
||||||
for rt_dict in data['refresh_tokens']:
|
for rt_dict in data['refresh_tokens']:
|
||||||
|
# Filter out the old keys that don't have jwt_key (pre-0.76)
|
||||||
|
if 'jwt_key' not in rt_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
token = models.RefreshToken(
|
token = models.RefreshToken(
|
||||||
id=rt_dict['id'],
|
id=rt_dict['id'],
|
||||||
user=users[rt_dict['user_id']],
|
user=users[rt_dict['user_id']],
|
||||||
|
@ -164,18 +181,9 @@ class AuthStore:
|
||||||
access_token_expiration=timedelta(
|
access_token_expiration=timedelta(
|
||||||
seconds=rt_dict['access_token_expiration']),
|
seconds=rt_dict['access_token_expiration']),
|
||||||
token=rt_dict['token'],
|
token=rt_dict['token'],
|
||||||
|
jwt_key=rt_dict['jwt_key']
|
||||||
)
|
)
|
||||||
refresh_tokens[token.id] = token
|
users[rt_dict['user_id']].refresh_tokens[token.id] = token
|
||||||
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
|
||||||
|
|
||||||
for ac_dict in data['access_tokens']:
|
|
||||||
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
|
||||||
token = models.AccessToken(
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
|
||||||
token=ac_dict['token'],
|
|
||||||
)
|
|
||||||
refresh_token.access_tokens.append(token)
|
|
||||||
|
|
||||||
self._users = users
|
self._users = users
|
||||||
|
|
||||||
|
@ -213,27 +221,15 @@ class AuthStore:
|
||||||
'access_token_expiration':
|
'access_token_expiration':
|
||||||
refresh_token.access_token_expiration.total_seconds(),
|
refresh_token.access_token_expiration.total_seconds(),
|
||||||
'token': refresh_token.token,
|
'token': refresh_token.token,
|
||||||
|
'jwt_key': refresh_token.jwt_key,
|
||||||
}
|
}
|
||||||
for user in self._users.values()
|
for user in self._users.values()
|
||||||
for refresh_token in user.refresh_tokens.values()
|
for refresh_token in user.refresh_tokens.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
access_tokens = [
|
|
||||||
{
|
|
||||||
'id': user.id,
|
|
||||||
'refresh_token_id': refresh_token.id,
|
|
||||||
'created_at': access_token.created_at.isoformat(),
|
|
||||||
'token': access_token.token,
|
|
||||||
}
|
|
||||||
for user in self._users.values()
|
|
||||||
for refresh_token in user.refresh_tokens.values()
|
|
||||||
for access_token in refresh_token.access_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'users': users,
|
'users': users,
|
||||||
'credentials': credentials,
|
'credentials': credentials,
|
||||||
'access_tokens': access_tokens,
|
|
||||||
'refresh_tokens': refresh_tokens,
|
'refresh_tokens': refresh_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,26 +39,8 @@ class RefreshToken:
|
||||||
default=ACCESS_TOKEN_EXPIRATION)
|
default=ACCESS_TOKEN_EXPIRATION)
|
||||||
token = attr.ib(type=str,
|
token = attr.ib(type=str,
|
||||||
default=attr.Factory(lambda: generate_secret(64)))
|
default=attr.Factory(lambda: generate_secret(64)))
|
||||||
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
jwt_key = attr.ib(type=str,
|
||||||
|
default=attr.Factory(lambda: generate_secret(64)))
|
||||||
|
|
||||||
@attr.s(slots=True)
|
|
||||||
class AccessToken:
|
|
||||||
"""Access token to access the API.
|
|
||||||
|
|
||||||
These will only ever be stored in memory and not be persisted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
refresh_token = attr.ib(type=RefreshToken)
|
|
||||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
|
||||||
token = attr.ib(type=str,
|
|
||||||
default=attr.Factory(generate_secret))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def expired(self):
|
|
||||||
"""Return if this token has expired."""
|
|
||||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
|
||||||
return dt_util.utcnow() > expires
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
|
|
|
@ -155,7 +155,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
return self.json({
|
return self.json({
|
||||||
'access_token': access_token.token,
|
'access_token': access_token,
|
||||||
'token_type': 'Bearer',
|
'token_type': 'Bearer',
|
||||||
'refresh_token': refresh_token.token,
|
'refresh_token': refresh_token.token,
|
||||||
'expires_in':
|
'expires_in':
|
||||||
|
@ -178,7 +178,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(token)
|
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return self.json({
|
return self.json({
|
||||||
|
@ -193,7 +193,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
return self.json({
|
return self.json({
|
||||||
'access_token': access_token.token,
|
'access_token': access_token,
|
||||||
'token_type': 'Bearer',
|
'token_type': 'Bearer',
|
||||||
'expires_in':
|
'expires_in':
|
||||||
int(refresh_token.access_token_expiration.total_seconds()),
|
int(refresh_token.access_token_expiration.total_seconds()),
|
||||||
|
|
|
@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
|
||||||
|
|
||||||
if auth_type == 'Bearer':
|
if auth_type == 'Bearer':
|
||||||
hass = request.app['hass']
|
hass = request.app['hass']
|
||||||
access_token = hass.auth.async_get_access_token(auth_val)
|
refresh_token = await hass.auth.async_validate_access_token(auth_val)
|
||||||
if access_token is None:
|
if refresh_token is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
request['hass_user'] = access_token.refresh_token.user
|
request['hass_user'] = refresh_token.user
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if auth_type == 'Basic' and api_password is not None:
|
if auth_type == 'Basic' and api_password is not None:
|
||||||
|
|
|
@ -355,11 +355,12 @@ class ActiveConnection:
|
||||||
|
|
||||||
if self.hass.auth.active and 'access_token' in msg:
|
if self.hass.auth.active and 'access_token' in msg:
|
||||||
self.debug("Received access_token")
|
self.debug("Received access_token")
|
||||||
token = self.hass.auth.async_get_access_token(
|
refresh_token = \
|
||||||
msg['access_token'])
|
await self.hass.auth.async_validate_access_token(
|
||||||
authenticated = token is not None
|
msg['access_token'])
|
||||||
|
authenticated = refresh_token is not None
|
||||||
if authenticated:
|
if authenticated:
|
||||||
request['hass_user'] = token.refresh_token.user
|
request['hass_user'] = refresh_token.user
|
||||||
|
|
||||||
elif ((not self.hass.auth.active or
|
elif ((not self.hass.auth.active or
|
||||||
self.hass.auth.support_legacy) and
|
self.hass.auth.support_legacy) and
|
||||||
|
|
|
@ -4,6 +4,7 @@ async_timeout==3.0.0
|
||||||
attrs==18.1.0
|
attrs==18.1.0
|
||||||
certifi>=2018.04.16
|
certifi>=2018.04.16
|
||||||
jinja2>=2.10
|
jinja2>=2.10
|
||||||
|
PyJWT==1.6.4
|
||||||
pip>=8.0.3
|
pip>=8.0.3
|
||||||
pytz>=2018.04
|
pytz>=2018.04
|
||||||
pyyaml>=3.13,<4
|
pyyaml>=3.13,<4
|
||||||
|
|
|
@ -5,6 +5,7 @@ async_timeout==3.0.0
|
||||||
attrs==18.1.0
|
attrs==18.1.0
|
||||||
certifi>=2018.04.16
|
certifi>=2018.04.16
|
||||||
jinja2>=2.10
|
jinja2>=2.10
|
||||||
|
PyJWT==1.6.4
|
||||||
pip>=8.0.3
|
pip>=8.0.3
|
||||||
pytz>=2018.04
|
pytz>=2018.04
|
||||||
pyyaml>=3.13,<4
|
pyyaml>=3.13,<4
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -38,6 +38,7 @@ REQUIRES = [
|
||||||
'attrs==18.1.0',
|
'attrs==18.1.0',
|
||||||
'certifi>=2018.04.16',
|
'certifi>=2018.04.16',
|
||||||
'jinja2>=2.10',
|
'jinja2>=2.10',
|
||||||
|
'PyJWT==1.6.4',
|
||||||
'pip>=8.0.3',
|
'pip>=8.0.3',
|
||||||
'pytz>=2018.04',
|
'pytz>=2018.04',
|
||||||
'pyyaml>=3.13,<4',
|
'pyyaml>=3.13,<4',
|
||||||
|
|
|
@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage):
|
||||||
})
|
})
|
||||||
user = await manager.async_get_or_create_user(step['result'])
|
user = await manager.async_get_or_create_user(step['result'])
|
||||||
await manager.async_activate_user(user)
|
await manager.async_activate_user(user)
|
||||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
|
|
||||||
manager.async_create_access_token(refresh_token)
|
|
||||||
|
|
||||||
await flush_store(manager._store._store)
|
await flush_store(manager._store._store)
|
||||||
|
|
||||||
|
@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage):
|
||||||
assert users[0] == user
|
assert users[0] == user
|
||||||
|
|
||||||
|
|
||||||
def test_access_token_expired():
|
|
||||||
"""Test that the expired property on access tokens work."""
|
|
||||||
refresh_token = auth_models.RefreshToken(
|
|
||||||
user=None,
|
|
||||||
client_id='bla'
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token = auth_models.AccessToken(
|
|
||||||
refresh_token=refresh_token
|
|
||||||
)
|
|
||||||
|
|
||||||
assert access_token.expired is False
|
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow',
|
|
||||||
return_value=dt_util.utcnow() +
|
|
||||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
|
||||||
assert access_token.expired is True
|
|
||||||
|
|
||||||
almost_exp = \
|
|
||||||
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
|
||||||
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
|
|
||||||
assert access_token.expired is False
|
|
||||||
|
|
||||||
|
|
||||||
async def test_cannot_retrieve_expired_access_token(hass):
|
async def test_cannot_retrieve_expired_access_token(hass):
|
||||||
"""Test that we cannot retrieve expired access tokens."""
|
"""Test that we cannot retrieve expired access tokens."""
|
||||||
manager = await auth.auth_manager_from_config(hass, [])
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
||||||
assert refresh_token.client_id == CLIENT_ID
|
assert refresh_token.client_id == CLIENT_ID
|
||||||
|
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
assert manager.async_get_access_token(access_token.token) is access_token
|
assert (
|
||||||
|
await manager.async_validate_access_token(access_token)
|
||||||
|
is refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow',
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
return_value=dt_util.utcnow() +
|
return_value=dt_util.utcnow() -
|
||||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)):
|
||||||
assert manager.async_get_access_token(access_token.token) is None
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
# Even with unpatched time, it should have been removed from manager
|
assert (
|
||||||
assert manager.async_get_access_token(access_token.token) is None
|
await manager.async_validate_access_token(access_token)
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_generating_system_user(hass):
|
async def test_generating_system_user(hass):
|
||||||
|
|
|
@ -314,12 +314,18 @@ def mock_registry(hass, mock_entries=None):
|
||||||
class MockUser(auth_models.User):
|
class MockUser(auth_models.User):
|
||||||
"""Mock a user in Home Assistant."""
|
"""Mock a user in Home Assistant."""
|
||||||
|
|
||||||
def __init__(self, id='mock-id', is_owner=False, is_active=True,
|
def __init__(self, id=None, is_owner=False, is_active=True,
|
||||||
name='Mock User', system_generated=False):
|
name='Mock User', system_generated=False):
|
||||||
"""Initialize mock user."""
|
"""Initialize mock user."""
|
||||||
super().__init__(
|
kwargs = {
|
||||||
id=id, is_owner=is_owner, is_active=is_active, name=name,
|
'is_owner': is_owner,
|
||||||
system_generated=system_generated)
|
'is_active': is_active,
|
||||||
|
'name': name,
|
||||||
|
'system_generated': system_generated
|
||||||
|
}
|
||||||
|
if id is not None:
|
||||||
|
kwargs['id'] = id
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def add_to_hass(self, hass):
|
def add_to_hass(self, hass):
|
||||||
"""Test helper to add entry to hass."""
|
"""Test helper to add entry to hass."""
|
||||||
|
|
|
@ -44,7 +44,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
|
|
||||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
assert (
|
||||||
|
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Use refresh token to get more tokens.
|
# Use refresh token to get more tokens.
|
||||||
resp = await client.post('/auth/token', data={
|
resp = await client.post('/auth/token', data={
|
||||||
|
@ -56,7 +59,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert 'refresh_token' not in tokens
|
assert 'refresh_token' not in tokens
|
||||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
assert (
|
||||||
|
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Test using access token to hit API.
|
# Test using access token to hit API.
|
||||||
resp = await client.get('/api/')
|
resp = await client.get('/api/')
|
||||||
|
@ -98,7 +104,9 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
user = hass_access_token.refresh_token.user
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
user = refresh_token.user
|
||||||
credential = Credentials(auth_provider_type='homeassistant',
|
credential = Credentials(auth_provider_type='homeassistant',
|
||||||
auth_provider_id=None,
|
auth_provider_id=None,
|
||||||
data={}, id='test-id')
|
data={}, id='test-id')
|
||||||
|
@ -169,7 +177,10 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
assert (
|
||||||
|
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||||
|
@ -208,4 +219,7 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
tokens = await resp.json()
|
tokens = await resp.json()
|
||||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
assert (
|
||||||
|
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
|
@ -52,7 +52,7 @@ async def async_get_code(hass, aiohttp_client):
|
||||||
'user': user,
|
'user': user,
|
||||||
'code': step['result'],
|
'code': step['result'],
|
||||||
'client': client,
|
'client': client,
|
||||||
'access_token': access_token.token,
|
'access_token': access_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -122,11 +122,13 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
|
||||||
hass_access_token):
|
hass_access_token):
|
||||||
"""Test we cannot delete our own account."""
|
"""Test we cannot delete our own account."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
|
||||||
await client.send_json({
|
await client.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
'type': auth_config.WS_TYPE_DELETE,
|
'type': auth_config.WS_TYPE_DELETE,
|
||||||
'user_id': hass_access_token.refresh_token.user.id,
|
'user_id': refresh_token.user.id,
|
||||||
})
|
})
|
||||||
|
|
||||||
result = await client.receive_json()
|
result = await client.receive_json()
|
||||||
|
@ -137,7 +139,9 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
|
||||||
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
||||||
"""Test we cannot delete an unknown user."""
|
"""Test we cannot delete an unknown user."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
await client.send_json({
|
await client.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
@ -153,7 +157,9 @@ async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
||||||
async def test_delete(hass, hass_ws_client, hass_access_token):
|
async def test_delete(hass, hass_ws_client, hass_access_token):
|
||||||
"""Test delete command works."""
|
"""Test delete command works."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
test_user = MockUser(
|
test_user = MockUser(
|
||||||
id='efg',
|
id='efg',
|
||||||
).add_to_hass(hass)
|
).add_to_hass(hass)
|
||||||
|
@ -174,7 +180,9 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
|
||||||
async def test_create(hass, hass_ws_client, hass_access_token):
|
async def test_create(hass, hass_ws_client, hass_access_token):
|
||||||
"""Test create command works."""
|
"""Test create command works."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
assert len(await hass.auth.async_get_users()) == 1
|
assert len(await hass.auth.async_get_users()) == 1
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from tests.common import MockUser, register_auth_provider
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_config(hass, aiohttp_client):
|
def setup_config(hass):
|
||||||
"""Fixture that sets up the auth provider homeassistant module."""
|
"""Fixture that sets up the auth provider homeassistant module."""
|
||||||
hass.loop.run_until_complete(register_auth_provider(hass, {
|
hass.loop.run_until_complete(register_auth_provider(hass, {
|
||||||
'type': 'homeassistant'
|
'type': 'homeassistant'
|
||||||
|
@ -22,7 +22,9 @@ async def test_create_auth_system_generated_user(hass, hass_access_token,
|
||||||
"""Test we can't add auth to system generated users."""
|
"""Test we can't add auth to system generated users."""
|
||||||
system_user = MockUser(system_generated=True).add_to_hass(hass)
|
system_user = MockUser(system_generated=True).add_to_hass(hass)
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
await client.send_json({
|
await client.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
@ -47,7 +49,9 @@ async def test_create_auth_unknown_user(hass_ws_client, hass,
|
||||||
hass_access_token):
|
hass_access_token):
|
||||||
"""Test create pointing at unknown user."""
|
"""Test create pointing at unknown user."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
await client.send_json({
|
await client.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
@ -86,7 +90,9 @@ async def test_create_auth(hass, hass_ws_client, hass_access_token,
|
||||||
"""Test create auth command works."""
|
"""Test create auth command works."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
user = MockUser().add_to_hass(hass)
|
user = MockUser().add_to_hass(hass)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
assert len(user.credentials) == 0
|
assert len(user.credentials) == 0
|
||||||
|
|
||||||
|
@ -117,7 +123,9 @@ async def test_create_auth_duplicate_username(hass, hass_ws_client,
|
||||||
"""Test we can't create auth with a duplicate username."""
|
"""Test we can't create auth with a duplicate username."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
user = MockUser().add_to_hass(hass)
|
user = MockUser().add_to_hass(hass)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||||
'version': 1,
|
'version': 1,
|
||||||
|
@ -145,7 +153,9 @@ async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
|
||||||
hass_access_token):
|
hass_access_token):
|
||||||
"""Test deleting an auth without being connected to a user."""
|
"""Test deleting an auth without being connected to a user."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||||
'version': 1,
|
'version': 1,
|
||||||
|
@ -171,7 +181,9 @@ async def test_delete_removes_credential(hass, hass_ws_client,
|
||||||
hass_access_token, hass_storage):
|
hass_access_token, hass_storage):
|
||||||
"""Test deleting auth that is connected to a user."""
|
"""Test deleting auth that is connected to a user."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
user = MockUser().add_to_hass(hass)
|
user = MockUser().add_to_hass(hass)
|
||||||
user.credentials.append(
|
user.credentials.append(
|
||||||
|
@ -216,7 +228,9 @@ async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
|
||||||
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
|
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
|
||||||
"""Test trying to delete an unknown auth username."""
|
"""Test trying to delete an unknown auth username."""
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
hass_access_token.refresh_token.user.is_owner = True
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_owner = True
|
||||||
|
|
||||||
await client.send_json({
|
await client.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
@ -240,7 +254,9 @@ async def test_change_password(hass, hass_ws_client, hass_access_token):
|
||||||
'username': 'test-user'
|
'username': 'test-user'
|
||||||
})
|
})
|
||||||
|
|
||||||
user = hass_access_token.refresh_token.user
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
user = refresh_token.user
|
||||||
await hass.auth.async_link_user(user, credentials)
|
await hass.auth.async_link_user(user, credentials)
|
||||||
|
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
@ -268,7 +284,9 @@ async def test_change_password_wrong_pw(hass, hass_ws_client,
|
||||||
'username': 'test-user'
|
'username': 'test-user'
|
||||||
})
|
})
|
||||||
|
|
||||||
user = hass_access_token.refresh_token.user
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
user = refresh_token.user
|
||||||
await hass.auth.async_link_user(user, credentials)
|
await hass.auth.async_link_user(user, credentials)
|
||||||
|
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
|
|
@ -28,7 +28,7 @@ def hass_ws_client(aiohttp_client):
|
||||||
|
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
'type': websocket_api.TYPE_AUTH,
|
'type': websocket_api.TYPE_AUTH,
|
||||||
'access_token': access_token.token
|
'access_token': access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_ok = await websocket.receive_json()
|
auth_ok = await websocket.receive_json()
|
||||||
|
|
|
@ -106,7 +106,11 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
||||||
)
|
)
|
||||||
assert hassio_user is not None
|
assert hassio_user is not None
|
||||||
assert hassio_user.system_generated
|
assert hassio_user.system_generated
|
||||||
assert refresh_token in hassio_user.refresh_tokens
|
for token in hassio_user.refresh_tokens.values():
|
||||||
|
if token.token == refresh_token:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
assert False, 'refresh token not found'
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
||||||
|
|
|
@ -156,9 +156,9 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_active_access_with_access_token_in_header(
|
async def test_auth_active_access_with_access_token_in_header(
|
||||||
app, aiohttp_client, hass_access_token):
|
hass, app, aiohttp_client, hass_access_token):
|
||||||
"""Test access with access token in header."""
|
"""Test access with access token in header."""
|
||||||
token = hass_access_token.token
|
token = hass_access_token
|
||||||
setup_auth(app, [], True, api_password=None)
|
setup_auth(app, [], True, api_password=None)
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
@ -182,7 +182,9 @@ async def test_auth_active_access_with_access_token_in_header(
|
||||||
'/', headers={'Authorization': 'BEARER {}'.format(token)})
|
'/', headers={'Authorization': 'BEARER {}'.format(token)})
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
||||||
hass_access_token.refresh_token.user.is_active = False
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_active = False
|
||||||
req = await client.get(
|
req = await client.get(
|
||||||
'/', headers={'Authorization': 'Bearer {}'.format(token)})
|
'/', headers={'Authorization': 'Bearer {}'.format(token)})
|
||||||
assert req.status == 401
|
assert req.status == 401
|
||||||
|
|
|
@ -448,13 +448,15 @@ async def test_api_fire_event_context(hass, mock_api_client,
|
||||||
await mock_api_client.post(
|
await mock_api_client.post(
|
||||||
const.URL_API_EVENTS_EVENT.format("test.event"),
|
const.URL_API_EVENTS_EVENT.format("test.event"),
|
||||||
headers={
|
headers={
|
||||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
})
|
})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
|
||||||
assert len(test_value) == 1
|
assert len(test_value) == 1
|
||||||
assert test_value[0].context.user_id == \
|
assert test_value[0].context.user_id == refresh_token.user.id
|
||||||
hass_access_token.refresh_token.user.id
|
|
||||||
|
|
||||||
|
|
||||||
async def test_api_call_service_context(hass, mock_api_client,
|
async def test_api_call_service_context(hass, mock_api_client,
|
||||||
|
@ -465,12 +467,15 @@ async def test_api_call_service_context(hass, mock_api_client,
|
||||||
await mock_api_client.post(
|
await mock_api_client.post(
|
||||||
'/api/services/test_domain/test_service',
|
'/api/services/test_domain/test_service',
|
||||||
headers={
|
headers={
|
||||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
})
|
})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert calls[0].context.user_id == hass_access_token.refresh_token.user.id
|
assert calls[0].context.user_id == refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
||||||
|
@ -481,8 +486,11 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
||||||
'state': 'on'
|
'state': 'on'
|
||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
|
||||||
state = hass.states.get('light.kitchen')
|
state = hass.states.get('light.kitchen')
|
||||||
assert state.context.user_id == hass_access_token.refresh_token.user.id
|
assert state.context.user_id == refresh_token.user.id
|
||||||
|
|
|
@ -334,7 +334,7 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': wapi.TYPE_AUTH,
|
||||||
'access_token': hass_access_token.token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
|
@ -344,7 +344,9 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||||
hass_access_token):
|
hass_access_token):
|
||||||
"""Test authenticating with a token."""
|
"""Test authenticating with a token."""
|
||||||
hass_access_token.refresh_token.user.is_active = False
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
refresh_token.user.is_active = False
|
||||||
assert await async_setup_component(hass, 'websocket_api', {
|
assert await async_setup_component(hass, 'websocket_api', {
|
||||||
'http': {
|
'http': {
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
|
@ -361,7 +363,7 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': wapi.TYPE_AUTH,
|
||||||
'access_token': hass_access_token.token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
|
@ -465,7 +467,7 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': wapi.TYPE_AUTH,
|
||||||
'access_token': hass_access_token.token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
|
@ -484,12 +486,15 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||||
msg = await ws.receive_json()
|
msg = await ws.receive_json()
|
||||||
assert msg['success']
|
assert msg['success']
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(
|
||||||
|
hass_access_token)
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
call = calls[0]
|
call = calls[0]
|
||||||
assert call.domain == 'domain_test'
|
assert call.domain == 'domain_test'
|
||||||
assert call.service == 'test_service'
|
assert call.service == 'test_service'
|
||||||
assert call.data == {'hello': 'world'}
|
assert call.data == {'hello': 'world'}
|
||||||
assert call.context.user_id == hass_access_token.refresh_token.user.id
|
assert call.context.user_id == refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
async def test_call_service_context_no_user(hass, aiohttp_client):
|
async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue