Make sure we check access token expiration (#15207)
* Make sure we check access token expiration * Use correct access token websocket
This commit is contained in:
parent
2205090795
commit
39971ee919
5 changed files with 74 additions and 11 deletions
|
@ -159,9 +159,10 @@ class AccessToken:
|
|||
default=attr.Factory(generate_secret))
|
||||
|
||||
@property
|
||||
def expires(self):
|
||||
"""Return datetime when this token expires."""
|
||||
return self.created_at + self.refresh_token.access_token_expiration
|
||||
def expired(self):
|
||||
"""Return if this token has expired."""
|
||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
||||
return dt_util.utcnow() > expires
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -272,7 +273,12 @@ class AuthManager:
|
|||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow,
|
||||
self._async_finish_login_flow)
|
||||
self.access_tokens = {}
|
||||
self._access_tokens = {}
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
"""Return if any auth providers are registered."""
|
||||
return bool(self._providers)
|
||||
|
||||
@property
|
||||
def async_auth_providers(self):
|
||||
|
@ -308,13 +314,22 @@ class AuthManager:
|
|||
def async_create_access_token(self, refresh_token):
|
||||
"""Create a new access token."""
|
||||
access_token = AccessToken(refresh_token)
|
||||
self.access_tokens[access_token.token] = access_token
|
||||
self._access_tokens[access_token.token] = access_token
|
||||
return access_token
|
||||
|
||||
@callback
|
||||
def async_get_access_token(self, token):
|
||||
"""Get an access token."""
|
||||
return self.access_tokens.get(token)
|
||||
tkn = self._access_tokens.get(token)
|
||||
|
||||
if tkn is None:
|
||||
return None
|
||||
|
||||
if tkn.expired:
|
||||
self._access_tokens.pop(token)
|
||||
return None
|
||||
|
||||
return tkn
|
||||
|
||||
async def async_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
|
|
|
@ -200,7 +200,7 @@ def add_manifest_json_key(key, val):
|
|||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the serving of the frontend."""
|
||||
if list(hass.auth.async_auth_providers):
|
||||
if hass.auth.active:
|
||||
client = await hass.auth.async_create_client(
|
||||
'Home Assistant Frontend',
|
||||
redirect_uris=['/'],
|
||||
|
|
|
@ -324,8 +324,9 @@ class ActiveConnection:
|
|||
request, msg['api_password'])
|
||||
|
||||
elif 'access_token' in msg:
|
||||
authenticated = \
|
||||
msg['access_token'] in self.hass.auth.access_tokens
|
||||
token = self.hass.auth.async_get_access_token(
|
||||
msg['access_token'])
|
||||
authenticated = token is not None
|
||||
|
||||
if not authenticated:
|
||||
self.debug("Invalid password")
|
||||
|
|
|
@ -320,6 +320,7 @@ class MockUser(auth.User):
|
|||
|
||||
def add_to_auth_manager(self, auth_mgr):
|
||||
"""Test helper to add entry to hass."""
|
||||
ensure_auth_manager_loaded(auth_mgr)
|
||||
auth_mgr._store.users[self.id] = self
|
||||
return self
|
||||
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
"""Tests for the Home Assistant auth module."""
|
||||
from unittest.mock import Mock
|
||||
from datetime import timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant.util import dt as dt_util
|
||||
from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hass():
|
||||
def mock_hass(loop):
|
||||
"""Hass mock with minimum amount of data set to make it work with auth."""
|
||||
hass = Mock()
|
||||
hass.config.skip_pip = True
|
||||
|
@ -195,3 +197,47 @@ async def test_saving_loading(hass, hass_storage):
|
|||
|
||||
assert len(store2.clients) == 1
|
||||
assert store2.clients[client.id] == client
|
||||
|
||||
|
||||
def test_access_token_expired():
|
||||
"""Test that the expired property on access tokens work."""
|
||||
refresh_token = auth.RefreshToken(
|
||||
user=None,
|
||||
client_id='bla'
|
||||
)
|
||||
|
||||
access_token = auth.AccessToken(
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
assert access_token.expired is False
|
||||
|
||||
with patch('homeassistant.auth.dt_util.utcnow',
|
||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
||||
assert access_token.expired is True
|
||||
|
||||
almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
||||
with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp):
|
||||
assert access_token.expired is False
|
||||
|
||||
|
||||
async def test_cannot_retrieve_expired_access_token(hass):
|
||||
"""Test that we cannot retrieve expired access tokens."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
user = MockUser(
|
||||
id='mock-user',
|
||||
is_owner=False,
|
||||
is_active=False,
|
||||
name='Paulus',
|
||||
).add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(user, 'bla')
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
assert manager.async_get_access_token(access_token.token) is access_token
|
||||
|
||||
with patch('homeassistant.auth.dt_util.utcnow',
|
||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
|
||||
# Even with unpatched time, it should have been removed from manager
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue