Add expiration of unused refresh tokens (#108428)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Michael 2024-01-25 00:24:22 +01:00 committed by GitHub
parent 0d22822ed0
commit f5d439799b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 243 additions and 7 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timedelta
from datetime import datetime, timedelta
from functools import partial
import time
from typing import Any, cast
@ -12,11 +12,19 @@ from typing import Any, cast
import jwt
from homeassistant import data_entry_flow
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.core import (
CALLBACK_TYPE,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
@ -75,7 +83,9 @@ async def auth_manager_from_config(
for module in modules:
module_hash[module.id] = module
return AuthManager(hass, store, provider_hash, module_hash)
manager = AuthManager(hass, store, provider_hash, module_hash)
manager.async_setup()
return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager):
@ -159,6 +169,21 @@ class AuthManager:
self._mfa_modules = mfa_modules
self.login_flow = AuthManagerFlowManager(hass, self)
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
self._expire_callback: CALLBACK_TYPE | None = None
self._remove_expired_job = HassJob(
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
)
@callback
def async_setup(self) -> None:
"""Set up the auth manager."""
hass = self.hass
hass.async_add_shutdown_job(
HassJob(
self._async_cancel_expiration_schedule, job_type=HassJobType.Callback
)
)
self._async_track_next_refresh_token_expiration()
@property
def auth_providers(self) -> list[AuthProvider]:
@ -424,6 +449,11 @@ class AuthManager:
else:
token_type = models.TOKEN_TYPE_NORMAL
if token_type is models.TOKEN_TYPE_NORMAL:
expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = None
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
raise ValueError(
"System generated users can only have system type refresh tokens"
@ -455,6 +485,7 @@ class AuthManager:
client_icon,
token_type,
access_token_expiration,
expire_at,
credential,
)
@ -479,6 +510,38 @@ class AuthManager:
for revoke_callback in callbacks:
revoke_callback()
@callback
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
"""Remove expired refresh tokens."""
now = time.time()
for token in self._store.async_get_refresh_tokens()[:]:
if (expire_at := token.expire_at) is not None and expire_at <= now:
self.async_remove_refresh_token(token)
self._async_track_next_refresh_token_expiration()
@callback
def _async_track_next_refresh_token_expiration(self) -> None:
"""Initialise all token expiration scheduled tasks."""
next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
for token in self._store.async_get_refresh_tokens():
if (
expire_at := token.expire_at
) is not None and expire_at < next_expiration:
next_expiration = expire_at
self._expire_callback = async_track_point_in_utc_time(
self.hass,
self._remove_expired_job,
dt_util.utc_from_timestamp(next_expiration),
)
@callback
def _async_cancel_expiration_schedule(self) -> None:
"""Cancel tracking of expired refresh tokens."""
if self._expire_callback:
self._expire_callback()
self._expire_callback = None
@callback
def _async_unregister(
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import timedelta
import hmac
import itertools
from logging import getLogger
from typing import Any
@ -17,6 +18,7 @@ from .const import (
GROUP_ID_ADMIN,
GROUP_ID_READ_ONLY,
GROUP_ID_USER,
REFRESH_TOKEN_EXPIRATION,
)
from .permissions import system_policies
from .permissions.models import PermissionLookup
@ -186,6 +188,7 @@ class AuthStore:
client_icon: str | None = None,
token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
expire_at: float | None = None,
credential: models.Credentials | None = None,
) -> models.RefreshToken:
"""Create a new token for a user."""
@ -194,6 +197,7 @@ class AuthStore:
"client_id": client_id,
"token_type": token_type,
"access_token_expiration": access_token_expiration,
"expire_at": expire_at,
"credential": credential,
}
if client_name:
@ -239,6 +243,15 @@ class AuthStore:
return found
@callback
def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
"""Get all refresh tokens."""
return list(
itertools.chain.from_iterable(
user.refresh_tokens.values() for user in self._users.values()
)
)
@callback
def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
@ -246,9 +259,13 @@ class AuthStore:
"""Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_ip = remote_ip
if refresh_token.expire_at:
refresh_token.expire_at = (
refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
)
self._async_schedule_save()
async def async_load(self) -> None:
async def async_load(self) -> None: # noqa: C901
"""Load the users."""
if self._loaded:
raise RuntimeError("Auth storage is already loaded")
@ -261,6 +278,8 @@ class AuthStore:
perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup
now_ts = dt_util.utcnow().timestamp()
if data is None or not isinstance(data, dict):
self._set_defaults()
return
@ -414,6 +433,14 @@ class AuthStore:
else:
last_used_at = None
if (
expire_at := rt_dict.get("expire_at")
) is None and token_type == models.TOKEN_TYPE_NORMAL:
if last_used_at:
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
token = models.RefreshToken(
id=rt_dict["id"],
user=users[rt_dict["user_id"]],
@ -430,6 +457,7 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"),
expire_at=expire_at,
version=rt_dict.get("version"),
)
if "credential_id" in rt_dict:
@ -439,6 +467,8 @@ class AuthStore:
self._groups = groups
self._users = users
self._async_schedule_save()
@callback
def _async_schedule_save(self) -> None:
"""Save users."""
@ -503,6 +533,7 @@ class AuthStore:
if refresh_token.last_used_at
else None,
"last_used_ip": refresh_token.last_used_ip,
"expire_at": refresh_token.expire_at,
"credential_id": refresh_token.credential.id
if refresh_token.credential
else None,

View file

@ -3,6 +3,7 @@ from datetime import timedelta
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
MFA_SESSION_EXPIRATION = timedelta(minutes=5)
REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds()
GROUP_ID_ADMIN = "system-admin"
GROUP_ID_USER = "system-users"

View file

@ -117,6 +117,8 @@ class RefreshToken:
last_used_at: datetime | None = attr.ib(default=None)
last_used_ip: str | None = attr.ib(default=None)
expire_at: float | None = attr.ib(default=None)
credential: Credentials | None = attr.ib(default=None)
version: str | None = attr.ib(default=__version__)

View file

@ -1,12 +1,15 @@
"""Tests for the auth store."""
import asyncio
from datetime import timedelta
from typing import Any
from unittest.mock import patch
from freezegun import freeze_time
import pytest
from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
async def test_loading_no_group_data_format(
@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
mock_dev_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with()
assert results[0] == results[1]
async def test_add_expire_at_property(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test we correctly add expired_at property if not existing."""
now = dt_util.utcnow()
with freeze_time(now):
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
{
"id": "system-id",
"is_active": True,
"is_owner": True,
"name": "Hass.io",
"system_generated": True,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"last_used_at": str(now - timedelta(days=10)),
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id2",
"jwt_key": "some-key2",
"token": "some-token",
"user_id": "user-id",
},
],
},
}
store = auth_store.AuthStore(hass)
await store.async_load()
users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 2
token1, token2 = users[0].refresh_tokens.values()
assert token1.expire_at
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
assert token2.expire_at
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()

View file

@ -26,6 +26,7 @@ from tests.common import (
CLIENT_ID,
MockUser,
async_capture_events,
async_fire_time_changed,
ensure_auth_manager_loaded,
flush_store,
)
@ -406,6 +407,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
assert not user.local_only
assert token is not None
assert token.client_id is None
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
assert token.expire_at is None
await hass.async_block_till_done()
assert len(events) == 1
@ -421,6 +424,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
assert user.local_only
assert token is not None
assert token.client_id is None
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
assert token.expire_at is None
await hass.async_block_till_done()
assert len(events) == 2
@ -474,6 +479,8 @@ async def test_refresh_token_with_specific_access_token_expiration(
assert token is not None
assert token.client_id == CLIENT_ID
assert token.access_token_expiration == timedelta(days=100)
assert token.token_type == auth.models.TOKEN_TYPE_NORMAL
assert token.expire_at is not None
async def test_refresh_token_type(hass: HomeAssistant) -> None:
@ -515,6 +522,7 @@ async def test_refresh_token_type_long_lived_access_token(hass: HomeAssistant) -
assert token.client_name == "GPS LOGGER"
assert token.client_icon == "mdi:home"
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
assert token.expire_at is None
async def test_refresh_token_provider_validation(mock_hass) -> None:
@ -565,9 +573,9 @@ async def test_cannot_deactive_owner(mock_hass) -> None:
await manager.async_deactivate_user(owner)
async def test_remove_refresh_token(mock_hass) -> None:
async def test_remove_refresh_token(hass: HomeAssistant) -> None:
"""Test that we can remove a refresh token."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])
manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
access_token = manager.async_create_access_token(refresh_token)
@ -578,6 +586,70 @@ async def test_remove_refresh_token(mock_hass) -> None:
assert manager.async_validate_access_token(access_token) is None
async def test_remove_expired_refresh_token(hass: HomeAssistant) -> None:
"""Test that expired refresh tokens are deleted."""
manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager)
now = dt_util.utcnow()
with freeze_time(now):
refresh_token1 = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token1.expire_at
== now.timestamp() + timedelta(days=90).total_seconds()
)
with freeze_time(now + timedelta(days=30)):
async_fire_time_changed(hass, now + timedelta(days=30))
refresh_token2 = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token2.expire_at
== now.timestamp() + timedelta(days=120).total_seconds()
)
with freeze_time(now + timedelta(days=89, hours=23)):
async_fire_time_changed(hass, now + timedelta(days=89, hours=23))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id)
assert manager.async_get_refresh_token(refresh_token2.id)
with freeze_time(now + timedelta(days=90, seconds=5)):
async_fire_time_changed(hass, now + timedelta(days=90, seconds=5))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id) is None
assert manager.async_get_refresh_token(refresh_token2.id)
with freeze_time(now + timedelta(days=120, seconds=5)):
async_fire_time_changed(hass, now + timedelta(days=120, seconds=5))
await hass.async_block_till_done()
assert manager.async_get_refresh_token(refresh_token1.id) is None
assert manager.async_get_refresh_token(refresh_token2.id) is None
async def test_update_expire_at_refresh_token(hass: HomeAssistant) -> None:
"""Test that expire at is updated when refresh token is used."""
manager = await auth.auth_manager_from_config(hass, [], [])
user = MockUser().add_to_auth_manager(manager)
now = dt_util.utcnow()
with freeze_time(now):
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
assert (
refresh_token.expire_at
== now.timestamp() + timedelta(days=90).total_seconds()
)
with freeze_time(now + timedelta(days=30)):
async_fire_time_changed(hass, now + timedelta(days=30))
await hass.async_block_till_done()
assert manager.async_create_access_token(refresh_token)
await hass.async_block_till_done()
assert (
refresh_token.expire_at
== now.timestamp()
+ timedelta(days=30).total_seconds()
+ timedelta(days=90).total_seconds()
)
async def test_register_revoke_token_callback(mock_hass) -> None:
"""Test that a registered revoke token callback is called."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])