Add expiration of unused refresh tokens (#108428)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
0d22822ed0
commit
f5d439799b
6 changed files with 243 additions and 7 deletions
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
@ -12,11 +12,19 @@ from typing import Any, cast
|
|||
import jwt
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
HassJob,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import auth_store, jwt_wrapper, models
|
||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
|
||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||
|
||||
|
@ -75,7 +83,9 @@ async def auth_manager_from_config(
|
|||
for module in modules:
|
||||
module_hash[module.id] = module
|
||||
|
||||
return AuthManager(hass, store, provider_hash, module_hash)
|
||||
manager = AuthManager(hass, store, provider_hash, module_hash)
|
||||
manager.async_setup()
|
||||
return manager
|
||||
|
||||
|
||||
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||
|
@ -159,6 +169,21 @@ class AuthManager:
|
|||
self._mfa_modules = mfa_modules
|
||||
self.login_flow = AuthManagerFlowManager(hass, self)
|
||||
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
|
||||
self._expire_callback: CALLBACK_TYPE | None = None
|
||||
self._remove_expired_job = HassJob(
|
||||
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_setup(self) -> None:
|
||||
"""Set up the auth manager."""
|
||||
hass = self.hass
|
||||
hass.async_add_shutdown_job(
|
||||
HassJob(
|
||||
self._async_cancel_expiration_schedule, job_type=HassJobType.Callback
|
||||
)
|
||||
)
|
||||
self._async_track_next_refresh_token_expiration()
|
||||
|
||||
@property
|
||||
def auth_providers(self) -> list[AuthProvider]:
|
||||
|
@ -424,6 +449,11 @@ class AuthManager:
|
|||
else:
|
||||
token_type = models.TOKEN_TYPE_NORMAL
|
||||
|
||||
if token_type is models.TOKEN_TYPE_NORMAL:
|
||||
expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
|
||||
else:
|
||||
expire_at = None
|
||||
|
||||
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
|
||||
raise ValueError(
|
||||
"System generated users can only have system type refresh tokens"
|
||||
|
@ -455,6 +485,7 @@ class AuthManager:
|
|||
client_icon,
|
||||
token_type,
|
||||
access_token_expiration,
|
||||
expire_at,
|
||||
credential,
|
||||
)
|
||||
|
||||
|
@ -479,6 +510,38 @@ class AuthManager:
|
|||
for revoke_callback in callbacks:
|
||||
revoke_callback()
|
||||
|
||||
@callback
|
||||
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
|
||||
"""Remove expired refresh tokens."""
|
||||
now = time.time()
|
||||
for token in self._store.async_get_refresh_tokens()[:]:
|
||||
if (expire_at := token.expire_at) is not None and expire_at <= now:
|
||||
self.async_remove_refresh_token(token)
|
||||
self._async_track_next_refresh_token_expiration()
|
||||
|
||||
@callback
|
||||
def _async_track_next_refresh_token_expiration(self) -> None:
|
||||
"""Initialise all token expiration scheduled tasks."""
|
||||
next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
|
||||
for token in self._store.async_get_refresh_tokens():
|
||||
if (
|
||||
expire_at := token.expire_at
|
||||
) is not None and expire_at < next_expiration:
|
||||
next_expiration = expire_at
|
||||
|
||||
self._expire_callback = async_track_point_in_utc_time(
|
||||
self.hass,
|
||||
self._remove_expired_job,
|
||||
dt_util.utc_from_timestamp(next_expiration),
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_cancel_expiration_schedule(self) -> None:
|
||||
"""Cancel tracking of expired refresh tokens."""
|
||||
if self._expire_callback:
|
||||
self._expire_callback()
|
||||
self._expire_callback = None
|
||||
|
||||
@callback
|
||||
def _async_unregister(
|
||||
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from datetime import timedelta
|
||||
import hmac
|
||||
import itertools
|
||||
from logging import getLogger
|
||||
from typing import Any
|
||||
|
||||
|
@ -17,6 +18,7 @@ from .const import (
|
|||
GROUP_ID_ADMIN,
|
||||
GROUP_ID_READ_ONLY,
|
||||
GROUP_ID_USER,
|
||||
REFRESH_TOKEN_EXPIRATION,
|
||||
)
|
||||
from .permissions import system_policies
|
||||
from .permissions.models import PermissionLookup
|
||||
|
@ -186,6 +188,7 @@ class AuthStore:
|
|||
client_icon: str | None = None,
|
||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||
expire_at: float | None = None,
|
||||
credential: models.Credentials | None = None,
|
||||
) -> models.RefreshToken:
|
||||
"""Create a new token for a user."""
|
||||
|
@ -194,6 +197,7 @@ class AuthStore:
|
|||
"client_id": client_id,
|
||||
"token_type": token_type,
|
||||
"access_token_expiration": access_token_expiration,
|
||||
"expire_at": expire_at,
|
||||
"credential": credential,
|
||||
}
|
||||
if client_name:
|
||||
|
@ -239,6 +243,15 @@ class AuthStore:
|
|||
|
||||
return found
|
||||
|
||||
@callback
|
||||
def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
|
||||
"""Get all refresh tokens."""
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
user.refresh_tokens.values() for user in self._users.values()
|
||||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_log_refresh_token_usage(
|
||||
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||
|
@ -246,9 +259,13 @@ class AuthStore:
|
|||
"""Update refresh token last used information."""
|
||||
refresh_token.last_used_at = dt_util.utcnow()
|
||||
refresh_token.last_used_ip = remote_ip
|
||||
if refresh_token.expire_at:
|
||||
refresh_token.expire_at = (
|
||||
refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||
)
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_load(self) -> None:
|
||||
async def async_load(self) -> None: # noqa: C901
|
||||
"""Load the users."""
|
||||
if self._loaded:
|
||||
raise RuntimeError("Auth storage is already loaded")
|
||||
|
@ -261,6 +278,8 @@ class AuthStore:
|
|||
perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
||||
self._perm_lookup = perm_lookup
|
||||
|
||||
now_ts = dt_util.utcnow().timestamp()
|
||||
|
||||
if data is None or not isinstance(data, dict):
|
||||
self._set_defaults()
|
||||
return
|
||||
|
@ -414,6 +433,14 @@ class AuthStore:
|
|||
else:
|
||||
last_used_at = None
|
||||
|
||||
if (
|
||||
expire_at := rt_dict.get("expire_at")
|
||||
) is None and token_type == models.TOKEN_TYPE_NORMAL:
|
||||
if last_used_at:
|
||||
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||
else:
|
||||
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
|
||||
|
||||
token = models.RefreshToken(
|
||||
id=rt_dict["id"],
|
||||
user=users[rt_dict["user_id"]],
|
||||
|
@ -430,6 +457,7 @@ class AuthStore:
|
|||
jwt_key=rt_dict["jwt_key"],
|
||||
last_used_at=last_used_at,
|
||||
last_used_ip=rt_dict.get("last_used_ip"),
|
||||
expire_at=expire_at,
|
||||
version=rt_dict.get("version"),
|
||||
)
|
||||
if "credential_id" in rt_dict:
|
||||
|
@ -439,6 +467,8 @@ class AuthStore:
|
|||
self._groups = groups
|
||||
self._users = users
|
||||
|
||||
self._async_schedule_save()
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Save users."""
|
||||
|
@ -503,6 +533,7 @@ class AuthStore:
|
|||
if refresh_token.last_used_at
|
||||
else None,
|
||||
"last_used_ip": refresh_token.last_used_ip,
|
||||
"expire_at": refresh_token.expire_at,
|
||||
"credential_id": refresh_token.credential.id
|
||||
if refresh_token.credential
|
||||
else None,
|
||||
|
|
|
@ -3,6 +3,7 @@ from datetime import timedelta
|
|||
|
||||
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
||||
MFA_SESSION_EXPIRATION = timedelta(minutes=5)
|
||||
REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds()
|
||||
|
||||
GROUP_ID_ADMIN = "system-admin"
|
||||
GROUP_ID_USER = "system-users"
|
||||
|
|
|
@ -117,6 +117,8 @@ class RefreshToken:
|
|||
last_used_at: datetime | None = attr.ib(default=None)
|
||||
last_used_ip: str | None = attr.ib(default=None)
|
||||
|
||||
expire_at: float | None = attr.ib(default=None)
|
||||
|
||||
credential: Credentials | None = attr.ib(default=None)
|
||||
|
||||
version: str | None = attr.ib(default=__version__)
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
"""Tests for the auth store."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth import auth_store
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
||||
async def test_loading_no_group_data_format(
|
||||
|
@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
|
|||
mock_dev_registry.assert_called_once_with(hass)
|
||||
mock_load.assert_called_once_with()
|
||||
assert results[0] == results[1]
|
||||
|
||||
|
||||
async def test_add_expire_at_property(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test we correctly add expired_at property if not existing."""
|
||||
now = dt_util.utcnow()
|
||||
with freeze_time(now):
|
||||
hass_storage[auth_store.STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {
|
||||
"credentials": [],
|
||||
"users": [
|
||||
{
|
||||
"id": "user-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Paulus",
|
||||
"system_generated": False,
|
||||
},
|
||||
{
|
||||
"id": "system-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Hass.io",
|
||||
"system_generated": True,
|
||||
},
|
||||
],
|
||||
"refresh_tokens": [
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id",
|
||||
"jwt_key": "some-key",
|
||||
"last_used_at": str(now - timedelta(days=10)),
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id2",
|
||||
"jwt_key": "some-key2",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
|
||||
users = await store.async_get_users()
|
||||
|
||||
assert len(users[0].refresh_tokens) == 2
|
||||
token1, token2 = users[0].refresh_tokens.values()
|
||||
assert token1.expire_at
|
||||
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
|
||||
assert token2.expire_at
|
||||
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
|
||||
|
|
|
@ -26,6 +26,7 @@ from tests.common import (
|
|||
CLIENT_ID,
|
||||
MockUser,
|
||||
async_capture_events,
|
||||
async_fire_time_changed,
|
||||
ensure_auth_manager_loaded,
|
||||
flush_store,
|
||||
)
|
||||
|
@ -406,6 +407,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
|
|||
assert not user.local_only
|
||||
assert token is not None
|
||||
assert token.client_id is None
|
||||
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
|
||||
assert token.expire_at is None
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 1
|
||||
|
@ -421,6 +424,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
|
|||
assert user.local_only
|
||||
assert token is not None
|
||||
assert token.client_id is None
|
||||
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
|
||||
assert token.expire_at is None
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 2
|
||||
|
@ -474,6 +479,8 @@ async def test_refresh_token_with_specific_access_token_expiration(
|
|||
assert token is not None
|
||||
assert token.client_id == CLIENT_ID
|
||||
assert token.access_token_expiration == timedelta(days=100)
|
||||
assert token.token_type == auth.models.TOKEN_TYPE_NORMAL
|
||||
assert token.expire_at is not None
|
||||
|
||||
|
||||
async def test_refresh_token_type(hass: HomeAssistant) -> None:
|
||||
|
@ -515,6 +522,7 @@ async def test_refresh_token_type_long_lived_access_token(hass: HomeAssistant) -
|
|||
assert token.client_name == "GPS LOGGER"
|
||||
assert token.client_icon == "mdi:home"
|
||||
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
assert token.expire_at is None
|
||||
|
||||
|
||||
async def test_refresh_token_provider_validation(mock_hass) -> None:
|
||||
|
@ -565,9 +573,9 @@ async def test_cannot_deactive_owner(mock_hass) -> None:
|
|||
await manager.async_deactivate_user(owner)
|
||||
|
||||
|
||||
async def test_remove_refresh_token(mock_hass) -> None:
|
||||
async def test_remove_refresh_token(hass: HomeAssistant) -> None:
|
||||
"""Test that we can remove a refresh token."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
@ -578,6 +586,70 @@ async def test_remove_refresh_token(mock_hass) -> None:
|
|||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_remove_expired_refresh_token(hass: HomeAssistant) -> None:
|
||||
"""Test that expired refresh tokens are deleted."""
|
||||
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
now = dt_util.utcnow()
|
||||
with freeze_time(now):
|
||||
refresh_token1 = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
assert (
|
||||
refresh_token1.expire_at
|
||||
== now.timestamp() + timedelta(days=90).total_seconds()
|
||||
)
|
||||
|
||||
with freeze_time(now + timedelta(days=30)):
|
||||
async_fire_time_changed(hass, now + timedelta(days=30))
|
||||
refresh_token2 = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
assert (
|
||||
refresh_token2.expire_at
|
||||
== now.timestamp() + timedelta(days=120).total_seconds()
|
||||
)
|
||||
|
||||
with freeze_time(now + timedelta(days=89, hours=23)):
|
||||
async_fire_time_changed(hass, now + timedelta(days=89, hours=23))
|
||||
await hass.async_block_till_done()
|
||||
assert manager.async_get_refresh_token(refresh_token1.id)
|
||||
assert manager.async_get_refresh_token(refresh_token2.id)
|
||||
|
||||
with freeze_time(now + timedelta(days=90, seconds=5)):
|
||||
async_fire_time_changed(hass, now + timedelta(days=90, seconds=5))
|
||||
await hass.async_block_till_done()
|
||||
assert manager.async_get_refresh_token(refresh_token1.id) is None
|
||||
assert manager.async_get_refresh_token(refresh_token2.id)
|
||||
|
||||
with freeze_time(now + timedelta(days=120, seconds=5)):
|
||||
async_fire_time_changed(hass, now + timedelta(days=120, seconds=5))
|
||||
await hass.async_block_till_done()
|
||||
assert manager.async_get_refresh_token(refresh_token1.id) is None
|
||||
assert manager.async_get_refresh_token(refresh_token2.id) is None
|
||||
|
||||
|
||||
async def test_update_expire_at_refresh_token(hass: HomeAssistant) -> None:
|
||||
"""Test that expire at is updated when refresh token is used."""
|
||||
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||
user = MockUser().add_to_auth_manager(manager)
|
||||
now = dt_util.utcnow()
|
||||
with freeze_time(now):
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
assert (
|
||||
refresh_token.expire_at
|
||||
== now.timestamp() + timedelta(days=90).total_seconds()
|
||||
)
|
||||
|
||||
with freeze_time(now + timedelta(days=30)):
|
||||
async_fire_time_changed(hass, now + timedelta(days=30))
|
||||
await hass.async_block_till_done()
|
||||
assert manager.async_create_access_token(refresh_token)
|
||||
await hass.async_block_till_done()
|
||||
assert (
|
||||
refresh_token.expire_at
|
||||
== now.timestamp()
|
||||
+ timedelta(days=30).total_seconds()
|
||||
+ timedelta(days=90).total_seconds()
|
||||
)
|
||||
|
||||
|
||||
async def test_register_revoke_token_callback(mock_hass) -> None:
|
||||
"""Test that a registered revoke token callback is called."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
|
|
Loading…
Add table
Reference in a new issue