Make sure we check access token expiration (#15207)

* Make sure we check access token expiration

* Use correct access token websocket
This commit is contained in:
Paulus Schoutsen 2018-06-29 00:02:33 -04:00 committed by GitHub
parent 2205090795
commit 39971ee919
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 11 deletions

View file

@ -159,9 +159,10 @@ class AccessToken:
default=attr.Factory(generate_secret)) default=attr.Factory(generate_secret))
@property @property
def expires(self): def expired(self):
"""Return datetime when this token expires.""" """Return if this token has expired."""
return self.created_at + self.refresh_token.access_token_expiration expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
@attr.s(slots=True) @attr.s(slots=True)
@ -272,7 +273,12 @@ class AuthManager:
self.login_flow = data_entry_flow.FlowManager( self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow, hass, self._async_create_login_flow,
self._async_finish_login_flow) self._async_finish_login_flow)
self.access_tokens = {} self._access_tokens = {}
@property
def active(self):
"""Return if any auth providers are registered."""
return bool(self._providers)
@property @property
def async_auth_providers(self): def async_auth_providers(self):
@ -308,13 +314,22 @@ class AuthManager:
def async_create_access_token(self, refresh_token): def async_create_access_token(self, refresh_token):
"""Create a new access token.""" """Create a new access token."""
access_token = AccessToken(refresh_token) access_token = AccessToken(refresh_token)
self.access_tokens[access_token.token] = access_token self._access_tokens[access_token.token] = access_token
return access_token return access_token
@callback @callback
def async_get_access_token(self, token): def async_get_access_token(self, token):
"""Get an access token.""" """Get an access token."""
return self.access_tokens.get(token) tkn = self._access_tokens.get(token)
if tkn is None:
return None
if tkn.expired:
self._access_tokens.pop(token)
return None
return tkn
async def async_create_client(self, name, *, redirect_uris=None, async def async_create_client(self, name, *, redirect_uris=None,
no_secret=False): no_secret=False):

View file

@ -200,7 +200,7 @@ def add_manifest_json_key(key, val):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the serving of the frontend.""" """Set up the serving of the frontend."""
if list(hass.auth.async_auth_providers): if hass.auth.active:
client = await hass.auth.async_create_client( client = await hass.auth.async_create_client(
'Home Assistant Frontend', 'Home Assistant Frontend',
redirect_uris=['/'], redirect_uris=['/'],

View file

@ -324,8 +324,9 @@ class ActiveConnection:
request, msg['api_password']) request, msg['api_password'])
elif 'access_token' in msg: elif 'access_token' in msg:
authenticated = \ token = self.hass.auth.async_get_access_token(
msg['access_token'] in self.hass.auth.access_tokens msg['access_token'])
authenticated = token is not None
if not authenticated: if not authenticated:
self.debug("Invalid password") self.debug("Invalid password")

View file

@ -320,6 +320,7 @@ class MockUser(auth.User):
def add_to_auth_manager(self, auth_mgr): def add_to_auth_manager(self, auth_mgr):
"""Test helper to add entry to hass.""" """Test helper to add entry to hass."""
ensure_auth_manager_loaded(auth_mgr)
auth_mgr._store.users[self.id] = self auth_mgr._store.users[self.id] = self
return self return self

View file

@ -1,14 +1,16 @@
"""Tests for the Home Assistant auth module.""" """Tests for the Home Assistant auth module."""
from unittest.mock import Mock from datetime import timedelta
from unittest.mock import Mock, patch
import pytest import pytest
from homeassistant import auth, data_entry_flow from homeassistant import auth, data_entry_flow
from homeassistant.util import dt as dt_util
from tests.common import MockUser, ensure_auth_manager_loaded, flush_store from tests.common import MockUser, ensure_auth_manager_loaded, flush_store
@pytest.fixture @pytest.fixture
def mock_hass(): def mock_hass(loop):
"""Hass mock with minimum amount of data set to make it work with auth.""" """Hass mock with minimum amount of data set to make it work with auth."""
hass = Mock() hass = Mock()
hass.config.skip_pip = True hass.config.skip_pip = True
@ -195,3 +197,47 @@ async def test_saving_loading(hass, hass_storage):
assert len(store2.clients) == 1 assert len(store2.clients) == 1
assert store2.clients[client.id] == client assert store2.clients[client.id] == client
def test_access_token_expired():
"""Test that the expired property on access tokens work."""
refresh_token = auth.RefreshToken(
user=None,
client_id='bla'
)
access_token = auth.AccessToken(
refresh_token=refresh_token
)
assert access_token.expired is False
with patch('homeassistant.auth.dt_util.utcnow',
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
assert access_token.expired is True
almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1)
with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp):
assert access_token.expired is False
async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, [])
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, 'bla')
access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token
with patch('homeassistant.auth.dt_util.utcnow',
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
assert manager.async_get_access_token(access_token.token) is None
# Even with unpatched time, it should have been removed from manager
assert manager.async_get_access_token(access_token.token) is None