Add expiration of unused refresh tokens (#108428)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Michael 2024-01-25 00:24:22 +01:00 committed by GitHub
parent 0d22822ed0
commit f5d439799b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 243 additions and 7 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
from datetime import timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import time import time
from typing import Any, cast from typing import Any, cast
@ -12,11 +12,19 @@ from typing import Any, cast
import jwt import jwt
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import (
CALLBACK_TYPE,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .providers import AuthProvider, LoginFlow, auth_provider_from_config from .providers import AuthProvider, LoginFlow, auth_provider_from_config
@ -75,7 +83,9 @@ async def auth_manager_from_config(
for module in modules: for module in modules:
module_hash[module.id] = module module_hash[module.id] = module
return AuthManager(hass, store, provider_hash, module_hash) manager = AuthManager(hass, store, provider_hash, module_hash)
manager.async_setup()
return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager): class AuthManagerFlowManager(data_entry_flow.FlowManager):
@ -159,6 +169,21 @@ class AuthManager:
self._mfa_modules = mfa_modules self._mfa_modules = mfa_modules
self.login_flow = AuthManagerFlowManager(hass, self) self.login_flow = AuthManagerFlowManager(hass, self)
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {} self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
self._expire_callback: CALLBACK_TYPE | None = None
self._remove_expired_job = HassJob(
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
)
@callback
def async_setup(self) -> None:
"""Set up the auth manager."""
hass = self.hass
hass.async_add_shutdown_job(
HassJob(
self._async_cancel_expiration_schedule, job_type=HassJobType.Callback
)
)
self._async_track_next_refresh_token_expiration()
@property @property
def auth_providers(self) -> list[AuthProvider]: def auth_providers(self) -> list[AuthProvider]:
@ -424,6 +449,11 @@ class AuthManager:
else: else:
token_type = models.TOKEN_TYPE_NORMAL token_type = models.TOKEN_TYPE_NORMAL
if token_type is models.TOKEN_TYPE_NORMAL:
expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = None
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM): if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
raise ValueError( raise ValueError(
"System generated users can only have system type refresh tokens" "System generated users can only have system type refresh tokens"
@ -455,6 +485,7 @@ class AuthManager:
client_icon, client_icon,
token_type, token_type,
access_token_expiration, access_token_expiration,
expire_at,
credential, credential,
) )
@ -479,6 +510,38 @@ class AuthManager:
for revoke_callback in callbacks: for revoke_callback in callbacks:
revoke_callback() revoke_callback()
@callback
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
"""Remove expired refresh tokens."""
now = time.time()
for token in self._store.async_get_refresh_tokens()[:]:
if (expire_at := token.expire_at) is not None and expire_at <= now:
self.async_remove_refresh_token(token)
self._async_track_next_refresh_token_expiration()
@callback
def _async_track_next_refresh_token_expiration(self) -> None:
"""Initialise all token expiration scheduled tasks."""
next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
for token in self._store.async_get_refresh_tokens():
if (
expire_at := token.expire_at
) is not None and expire_at < next_expiration:
next_expiration = expire_at
self._expire_callback = async_track_point_in_utc_time(
self.hass,
self._remove_expired_job,
dt_util.utc_from_timestamp(next_expiration),
)
@callback
def _async_cancel_expiration_schedule(self) -> None:
"""Cancel tracking of expired refresh tokens."""
if self._expire_callback:
self._expire_callback()
self._expire_callback = None
@callback @callback
def _async_unregister( def _async_unregister(
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
import hmac import hmac
import itertools
from logging import getLogger from logging import getLogger
from typing import Any from typing import Any
@ -17,6 +18,7 @@ from .const import (
GROUP_ID_ADMIN, GROUP_ID_ADMIN,
GROUP_ID_READ_ONLY, GROUP_ID_READ_ONLY,
GROUP_ID_USER, GROUP_ID_USER,
REFRESH_TOKEN_EXPIRATION,
) )
from .permissions import system_policies from .permissions import system_policies
from .permissions.models import PermissionLookup from .permissions.models import PermissionLookup
@ -186,6 +188,7 @@ class AuthStore:
client_icon: str | None = None, client_icon: str | None = None,
token_type: str = models.TOKEN_TYPE_NORMAL, token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION, access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
expire_at: float | None = None,
credential: models.Credentials | None = None, credential: models.Credentials | None = None,
) -> models.RefreshToken: ) -> models.RefreshToken:
"""Create a new token for a user.""" """Create a new token for a user."""
@ -194,6 +197,7 @@ class AuthStore:
"client_id": client_id, "client_id": client_id,
"token_type": token_type, "token_type": token_type,
"access_token_expiration": access_token_expiration, "access_token_expiration": access_token_expiration,
"expire_at": expire_at,
"credential": credential, "credential": credential,
} }
if client_name: if client_name:
@ -239,6 +243,15 @@ class AuthStore:
return found return found
@callback
def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
"""Get all refresh tokens."""
return list(
itertools.chain.from_iterable(
user.refresh_tokens.values() for user in self._users.values()
)
)
@callback @callback
def async_log_refresh_token_usage( def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, remote_ip: str | None = None self, refresh_token: models.RefreshToken, remote_ip: str | None = None
@ -246,9 +259,13 @@ class AuthStore:
"""Update refresh token last used information.""" """Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow() refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_ip = remote_ip refresh_token.last_used_ip = remote_ip
if refresh_token.expire_at:
refresh_token.expire_at = (
refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
)
self._async_schedule_save() self._async_schedule_save()
async def async_load(self) -> None: async def async_load(self) -> None: # noqa: C901
"""Load the users.""" """Load the users."""
if self._loaded: if self._loaded:
raise RuntimeError("Auth storage is already loaded") raise RuntimeError("Auth storage is already loaded")
@ -261,6 +278,8 @@ class AuthStore:
perm_lookup = PermissionLookup(ent_reg, dev_reg) perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup self._perm_lookup = perm_lookup
now_ts = dt_util.utcnow().timestamp()
if data is None or not isinstance(data, dict): if data is None or not isinstance(data, dict):
self._set_defaults() self._set_defaults()
return return
@ -414,6 +433,14 @@ class AuthStore:
else: else:
last_used_at = None last_used_at = None
if (
expire_at := rt_dict.get("expire_at")
) is None and token_type == models.TOKEN_TYPE_NORMAL:
if last_used_at:
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
token = models.RefreshToken( token = models.RefreshToken(
id=rt_dict["id"], id=rt_dict["id"],
user=users[rt_dict["user_id"]], user=users[rt_dict["user_id"]],
@ -430,6 +457,7 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"], jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at, last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"), last_used_ip=rt_dict.get("last_used_ip"),
expire_at=expire_at,
version=rt_dict.get("version"), version=rt_dict.get("version"),
) )
if "credential_id" in rt_dict: if "credential_id" in rt_dict:
@ -439,6 +467,8 @@ class AuthStore:
self._groups = groups self._groups = groups
self._users = users self._users = users
self._async_schedule_save()
@callback @callback
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
"""Save users.""" """Save users."""
@ -503,6 +533,7 @@ class AuthStore:
if refresh_token.last_used_at if refresh_token.last_used_at
else None, else None,
"last_used_ip": refresh_token.last_used_ip, "last_used_ip": refresh_token.last_used_ip,
"expire_at": refresh_token.expire_at,
"credential_id": refresh_token.credential.id "credential_id": refresh_token.credential.id
if refresh_token.credential if refresh_token.credential
else None, else None,

View file

@ -3,6 +3,7 @@ from datetime import timedelta
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30) ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
MFA_SESSION_EXPIRATION = timedelta(minutes=5) MFA_SESSION_EXPIRATION = timedelta(minutes=5)
REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds()
GROUP_ID_ADMIN = "system-admin" GROUP_ID_ADMIN = "system-admin"
GROUP_ID_USER = "system-users" GROUP_ID_USER = "system-users"

View file

@ -117,6 +117,8 @@ class RefreshToken:
last_used_at: datetime | None = attr.ib(default=None) last_used_at: datetime | None = attr.ib(default=None)
last_used_ip: str | None = attr.ib(default=None) last_used_ip: str | None = attr.ib(default=None)
expire_at: float | None = attr.ib(default=None)
credential: Credentials | None = attr.ib(default=None) credential: Credentials | None = attr.ib(default=None)
version: str | None = attr.ib(default=__version__) version: str | None = attr.ib(default=__version__)

View file

@ -1,12 +1,15 @@
"""Tests for the auth store.""" """Tests for the auth store."""
import asyncio import asyncio
from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
from freezegun import freeze_time
import pytest import pytest
from homeassistant.auth import auth_store from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
async def test_loading_no_group_data_format( async def test_loading_no_group_data_format(
@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
mock_dev_registry.assert_called_once_with(hass) mock_dev_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with() mock_load.assert_called_once_with()
assert results[0] == results[1] assert results[0] == results[1]
async def test_add_expire_at_property(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test we correctly add expired_at property if not existing."""
now = dt_util.utcnow()
with freeze_time(now):
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
{
"id": "system-id",
"is_active": True,
"is_owner": True,
"name": "Hass.io",
"system_generated": True,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"last_used_at": str(now - timedelta(days=10)),
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id2",
"jwt_key": "some-key2",
"token": "some-token",
"user_id": "user-id",
},
],
},
}
store = auth_store.AuthStore(hass)
await store.async_load()
users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 2
token1, token2 = users[0].refresh_tokens.values()
assert token1.expire_at
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
assert token2.expire_at
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()

View file

@ -26,6 +26,7 @@ from tests.common import (
CLIENT_ID, CLIENT_ID,
MockUser, MockUser,
async_capture_events, async_capture_events,
async_fire_time_changed,
ensure_auth_manager_loaded, ensure_auth_manager_loaded,
flush_store, flush_store,
) )
@ -406,6 +407,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
assert not user.local_only assert not user.local_only
assert token is not None assert token is not None
assert token.client_id is None assert token.client_id is None
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
assert token.expire_at is None
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(events) == 1 assert len(events) == 1
@ -421,6 +424,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
assert user.local_only assert user.local_only
assert token is not None assert token is not None
assert token.client_id is None assert token.client_id is None
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
assert token.expire_at is None
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(events) == 2 assert len(events) == 2
@ -474,6 +479,8 @@ async def test_refresh_token_with_specific_access_token_expiration(
assert token is not None assert token is not None
assert token.client_id == CLIENT_ID assert token.client_id == CLIENT_ID
assert token.access_token_expiration == timedelta(days=100) assert token.access_token_expiration == timedelta(days=100)
assert token.token_type == auth.models.TOKEN_TYPE_NORMAL
assert token.expire_at is not None
async def test_refresh_token_type(hass: HomeAssistant) -> None: async def test_refresh_token_type(hass: HomeAssistant) -> None:
@ -515,6 +522,7 @@ async def test_refresh_token_type_long_lived_access_token(hass: HomeAssistant) -
assert token.client_name == "GPS LOGGER" assert token.client_name == "GPS LOGGER"
assert token.client_icon == "mdi:home" assert token.client_icon == "mdi:home"
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
assert token.expire_at is None
async def test_refresh_token_provider_validation(mock_hass) -> None: async def test_refresh_token_provider_validation(mock_hass) -> None:
@ -565,9 +573,9 @@ async def test_cannot_deactive_owner(mock_hass) -> None:
await manager.async_deactivate_user(owner) await manager.async_deactivate_user(owner)
async def test_remove_refresh_token(mock_hass) -> None: async def test_remove_refresh_token(hass: HomeAssistant) -> None:
"""Test that we can remove a refresh token.""" """Test that we can remove a refresh token."""
manager = await auth.auth_manager_from_config(mock_hass, [], []) manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager) user = MockUser().add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
access_token = manager.async_create_access_token(refresh_token) access_token = manager.async_create_access_token(refresh_token)
@ -578,6 +586,70 @@ async def test_remove_refresh_token(mock_hass) -> None:
assert manager.async_validate_access_token(access_token) is None assert manager.async_validate_access_token(access_token) is None
async def test_remove_expired_refresh_token(hass: HomeAssistant) -> None:
"""Test that expired refresh tokens are deleted."""
manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager)
now = dt_util.utcnow()
with freeze_time(now):
refresh_token1 = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token1.expire_at
== now.timestamp() + timedelta(days=90).total_seconds()
)
with freeze_time(now + timedelta(days=30)):
async_fire_time_changed(hass, now + timedelta(days=30))
refresh_token2 = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token2.expire_at
== now.timestamp() + timedelta(days=120).total_seconds()
)
with freeze_time(now + timedelta(days=89, hours=23)):
async_fire_time_changed(hass, now + timedelta(days=89, hours=23))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id)
assert manager.async_get_refresh_token(refresh_token2.id)
with freeze_time(now + timedelta(days=90, seconds=5)):
async_fire_time_changed(hass, now + timedelta(days=90, seconds=5))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id) is None
assert manager.async_get_refresh_token(refresh_token2.id)
with freeze_time(now + timedelta(days=120, seconds=5)):
async_fire_time_changed(hass, now + timedelta(days=120, seconds=5))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id) is None
assert manager.async_get_refresh_token(refresh_token2.id) is None
async def test_update_expire_at_refresh_token(hass: HomeAssistant) -> None:
"""Test that expire at is updated when refresh token is used."""
manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager)
now = dt_util.utcnow()
with freeze_time(now):
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token.expire_at
== now.timestamp() + timedelta(days=90).total_seconds()
)
with freeze_time(now + timedelta(days=30)):
async_fire_time_changed(hass, now + timedelta(days=30))
await hass.async_block_till_done()
assert manager.async_create_access_token(refresh_token)
await hass.async_block_till_done()
assert (
refresh_token.expire_at
== now.timestamp()
+ timedelta(days=30).total_seconds()
+ timedelta(days=90).total_seconds()
)
async def test_register_revoke_token_callback(mock_hass) -> None: async def test_register_revoke_token_callback(mock_hass) -> None:
"""Test that a registered revoke token callback is called.""" """Test that a registered revoke token callback is called."""
manager = await auth.auth_manager_from_config(mock_hass, [], []) manager = await auth.auth_manager_from_config(mock_hass, [], [])