From 39971ee9190b616fc3149c53912f9f8b2976c46a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 29 Jun 2018 00:02:33 -0400 Subject: [PATCH] Make sure we check access token expiration (#15207) * Make sure we check access token expiration * Use correct access token websocket --- homeassistant/auth.py | 27 +++++++--- homeassistant/components/frontend/__init__.py | 2 +- homeassistant/components/websocket_api.py | 5 +- tests/common.py | 1 + tests/test_auth.py | 50 ++++++++++++++++++- 5 files changed, 74 insertions(+), 11 deletions(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 0c8346607ca..22abcdf213c 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -159,9 +159,10 @@ class AccessToken: default=attr.Factory(generate_secret)) @property - def expires(self): - """Return datetime when this token expires.""" - return self.created_at + self.refresh_token.access_token_expiration + def expired(self): + """Return if this token has expired.""" + expires = self.created_at + self.refresh_token.access_token_expiration + return dt_util.utcnow() > expires @attr.s(slots=True) @@ -272,7 +273,12 @@ class AuthManager: self.login_flow = data_entry_flow.FlowManager( hass, self._async_create_login_flow, self._async_finish_login_flow) - self.access_tokens = {} + self._access_tokens = {} + + @property + def active(self): + """Return if any auth providers are registered.""" + return bool(self._providers) @property def async_auth_providers(self): @@ -308,13 +314,22 @@ class AuthManager: def async_create_access_token(self, refresh_token): """Create a new access token.""" access_token = AccessToken(refresh_token) - self.access_tokens[access_token.token] = access_token + self._access_tokens[access_token.token] = access_token return access_token @callback def async_get_access_token(self, token): """Get an access token.""" - return self.access_tokens.get(token) + tkn = self._access_tokens.get(token) + + if tkn is None: + return None + + if tkn.expired: + self._access_tokens.pop(token) + return None + + return tkn async def async_create_client(self, name, *, redirect_uris=None, no_secret=False): diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index ffdd3160b2e..0e9d7612669 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -200,7 +200,7 @@ def add_manifest_json_key(key, val): async def async_setup(hass, config): """Set up the serving of the frontend.""" - if list(hass.auth.async_auth_providers): + if hass.auth.active: client = await hass.auth.async_create_client( 'Home Assistant Frontend', redirect_uris=['/'], diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index aacef4547b7..bf472348bab 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -324,8 +324,9 @@ class ActiveConnection: request, msg['api_password']) elif 'access_token' in msg: - authenticated = \ - msg['access_token'] in self.hass.auth.access_tokens + token = self.hass.auth.async_get_access_token( + msg['access_token']) + authenticated = token is not None if not authenticated: self.debug("Invalid password") diff --git a/tests/common.py b/tests/common.py index 8eaee686b22..1b8eabaa0db 100644 --- a/tests/common.py +++ b/tests/common.py @@ -320,6 +320,7 @@ class MockUser(auth.User): def add_to_auth_manager(self, auth_mgr): """Test helper to add entry to hass.""" + ensure_auth_manager_loaded(auth_mgr) auth_mgr._store.users[self.id] = self return self diff --git a/tests/test_auth.py b/tests/test_auth.py index 116f92ca817..4c0db71466e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,14 +1,16 @@ """Tests for the Home Assistant auth module.""" -from unittest.mock import Mock +from datetime import timedelta +from unittest.mock import Mock, patch import pytest from homeassistant import auth, data_entry_flow +from homeassistant.util import dt as dt_util from tests.common import MockUser, ensure_auth_manager_loaded, flush_store @pytest.fixture -def mock_hass(): +def mock_hass(loop): """Hass mock with minimum amount of data set to make it work with auth.""" hass = Mock() hass.config.skip_pip = True @@ -195,3 +197,47 @@ async def test_saving_loading(hass, hass_storage): assert len(store2.clients) == 1 assert store2.clients[client.id] == client + + +def test_access_token_expired(): + """Test that the expired property on access tokens work.""" + refresh_token = auth.RefreshToken( + user=None, + client_id='bla' + ) + + access_token = auth.AccessToken( + refresh_token=refresh_token + ) + + assert access_token.expired is False + + with patch('homeassistant.auth.dt_util.utcnow', + return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION): + assert access_token.expired is True + + almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1) + with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp): + assert access_token.expired is False + + +async def test_cannot_retrieve_expired_access_token(hass): + """Test that we cannot retrieve expired access tokens.""" + manager = await auth.auth_manager_from_config(hass, []) + user = MockUser( + id='mock-user', + is_owner=False, + is_active=False, + name='Paulus', + ).add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token(user, 'bla') + access_token = manager.async_create_access_token(refresh_token) + + assert manager.async_get_access_token(access_token.token) is access_token + + with patch('homeassistant.auth.dt_util.utcnow', + return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION): + assert manager.async_get_access_token(access_token.token) is None + + # Even with unpatched time, it should have been removed from manager + assert manager.async_get_access_token(access_token.token) is None