Add expiration of unused refresh tokens (#108428)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
0d22822ed0
commit
f5d439799b
6 changed files with 243 additions and 7 deletions
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import time
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
@ -12,11 +12,19 @@ from typing import Any, cast
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
from homeassistant.core import (
|
||||||
|
CALLBACK_TYPE,
|
||||||
|
HassJob,
|
||||||
|
HassJobType,
|
||||||
|
HomeAssistant,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
|
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import auth_store, jwt_wrapper, models
|
from . import auth_store, jwt_wrapper, models
|
||||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN
|
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
||||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||||
|
|
||||||
|
@ -75,7 +83,9 @@ async def auth_manager_from_config(
|
||||||
for module in modules:
|
for module in modules:
|
||||||
module_hash[module.id] = module
|
module_hash[module.id] = module
|
||||||
|
|
||||||
return AuthManager(hass, store, provider_hash, module_hash)
|
manager = AuthManager(hass, store, provider_hash, module_hash)
|
||||||
|
manager.async_setup()
|
||||||
|
return manager
|
||||||
|
|
||||||
|
|
||||||
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||||
|
@ -159,6 +169,21 @@ class AuthManager:
|
||||||
self._mfa_modules = mfa_modules
|
self._mfa_modules = mfa_modules
|
||||||
self.login_flow = AuthManagerFlowManager(hass, self)
|
self.login_flow = AuthManagerFlowManager(hass, self)
|
||||||
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
|
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
|
||||||
|
self._expire_callback: CALLBACK_TYPE | None = None
|
||||||
|
self._remove_expired_job = HassJob(
|
||||||
|
self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_setup(self) -> None:
|
||||||
|
"""Set up the auth manager."""
|
||||||
|
hass = self.hass
|
||||||
|
hass.async_add_shutdown_job(
|
||||||
|
HassJob(
|
||||||
|
self._async_cancel_expiration_schedule, job_type=HassJobType.Callback
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._async_track_next_refresh_token_expiration()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_providers(self) -> list[AuthProvider]:
|
def auth_providers(self) -> list[AuthProvider]:
|
||||||
|
@ -424,6 +449,11 @@ class AuthManager:
|
||||||
else:
|
else:
|
||||||
token_type = models.TOKEN_TYPE_NORMAL
|
token_type = models.TOKEN_TYPE_NORMAL
|
||||||
|
|
||||||
|
if token_type is models.TOKEN_TYPE_NORMAL:
|
||||||
|
expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
|
||||||
|
else:
|
||||||
|
expire_at = None
|
||||||
|
|
||||||
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
|
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"System generated users can only have system type refresh tokens"
|
"System generated users can only have system type refresh tokens"
|
||||||
|
@ -455,6 +485,7 @@ class AuthManager:
|
||||||
client_icon,
|
client_icon,
|
||||||
token_type,
|
token_type,
|
||||||
access_token_expiration,
|
access_token_expiration,
|
||||||
|
expire_at,
|
||||||
credential,
|
credential,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -479,6 +510,38 @@ class AuthManager:
|
||||||
for revoke_callback in callbacks:
|
for revoke_callback in callbacks:
|
||||||
revoke_callback()
|
revoke_callback()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
|
||||||
|
"""Remove expired refresh tokens."""
|
||||||
|
now = time.time()
|
||||||
|
for token in self._store.async_get_refresh_tokens()[:]:
|
||||||
|
if (expire_at := token.expire_at) is not None and expire_at <= now:
|
||||||
|
self.async_remove_refresh_token(token)
|
||||||
|
self._async_track_next_refresh_token_expiration()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_track_next_refresh_token_expiration(self) -> None:
|
||||||
|
"""Initialise all token expiration scheduled tasks."""
|
||||||
|
next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
|
||||||
|
for token in self._store.async_get_refresh_tokens():
|
||||||
|
if (
|
||||||
|
expire_at := token.expire_at
|
||||||
|
) is not None and expire_at < next_expiration:
|
||||||
|
next_expiration = expire_at
|
||||||
|
|
||||||
|
self._expire_callback = async_track_point_in_utc_time(
|
||||||
|
self.hass,
|
||||||
|
self._remove_expired_job,
|
||||||
|
dt_util.utc_from_timestamp(next_expiration),
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_cancel_expiration_schedule(self) -> None:
|
||||||
|
"""Cancel tracking of expired refresh tokens."""
|
||||||
|
if self._expire_callback:
|
||||||
|
self._expire_callback()
|
||||||
|
self._expire_callback = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_unregister(
|
def _async_unregister(
|
||||||
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
|
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
|
||||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import hmac
|
import hmac
|
||||||
|
import itertools
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ from .const import (
|
||||||
GROUP_ID_ADMIN,
|
GROUP_ID_ADMIN,
|
||||||
GROUP_ID_READ_ONLY,
|
GROUP_ID_READ_ONLY,
|
||||||
GROUP_ID_USER,
|
GROUP_ID_USER,
|
||||||
|
REFRESH_TOKEN_EXPIRATION,
|
||||||
)
|
)
|
||||||
from .permissions import system_policies
|
from .permissions import system_policies
|
||||||
from .permissions.models import PermissionLookup
|
from .permissions.models import PermissionLookup
|
||||||
|
@ -186,6 +188,7 @@ class AuthStore:
|
||||||
client_icon: str | None = None,
|
client_icon: str | None = None,
|
||||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||||
|
expire_at: float | None = None,
|
||||||
credential: models.Credentials | None = None,
|
credential: models.Credentials | None = None,
|
||||||
) -> models.RefreshToken:
|
) -> models.RefreshToken:
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
|
@ -194,6 +197,7 @@ class AuthStore:
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"token_type": token_type,
|
"token_type": token_type,
|
||||||
"access_token_expiration": access_token_expiration,
|
"access_token_expiration": access_token_expiration,
|
||||||
|
"expire_at": expire_at,
|
||||||
"credential": credential,
|
"credential": credential,
|
||||||
}
|
}
|
||||||
if client_name:
|
if client_name:
|
||||||
|
@ -239,6 +243,15 @@ class AuthStore:
|
||||||
|
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
|
||||||
|
"""Get all refresh tokens."""
|
||||||
|
return list(
|
||||||
|
itertools.chain.from_iterable(
|
||||||
|
user.refresh_tokens.values() for user in self._users.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_log_refresh_token_usage(
|
def async_log_refresh_token_usage(
|
||||||
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||||
|
@ -246,9 +259,13 @@ class AuthStore:
|
||||||
"""Update refresh token last used information."""
|
"""Update refresh token last used information."""
|
||||||
refresh_token.last_used_at = dt_util.utcnow()
|
refresh_token.last_used_at = dt_util.utcnow()
|
||||||
refresh_token.last_used_ip = remote_ip
|
refresh_token.last_used_ip = remote_ip
|
||||||
|
if refresh_token.expire_at:
|
||||||
|
refresh_token.expire_at = (
|
||||||
|
refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||||
|
)
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None: # noqa: C901
|
||||||
"""Load the users."""
|
"""Load the users."""
|
||||||
if self._loaded:
|
if self._loaded:
|
||||||
raise RuntimeError("Auth storage is already loaded")
|
raise RuntimeError("Auth storage is already loaded")
|
||||||
|
@ -261,6 +278,8 @@ class AuthStore:
|
||||||
perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
||||||
self._perm_lookup = perm_lookup
|
self._perm_lookup = perm_lookup
|
||||||
|
|
||||||
|
now_ts = dt_util.utcnow().timestamp()
|
||||||
|
|
||||||
if data is None or not isinstance(data, dict):
|
if data is None or not isinstance(data, dict):
|
||||||
self._set_defaults()
|
self._set_defaults()
|
||||||
return
|
return
|
||||||
|
@ -414,6 +433,14 @@ class AuthStore:
|
||||||
else:
|
else:
|
||||||
last_used_at = None
|
last_used_at = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
expire_at := rt_dict.get("expire_at")
|
||||||
|
) is None and token_type == models.TOKEN_TYPE_NORMAL:
|
||||||
|
if last_used_at:
|
||||||
|
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||||
|
else:
|
||||||
|
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
|
||||||
|
|
||||||
token = models.RefreshToken(
|
token = models.RefreshToken(
|
||||||
id=rt_dict["id"],
|
id=rt_dict["id"],
|
||||||
user=users[rt_dict["user_id"]],
|
user=users[rt_dict["user_id"]],
|
||||||
|
@ -430,6 +457,7 @@ class AuthStore:
|
||||||
jwt_key=rt_dict["jwt_key"],
|
jwt_key=rt_dict["jwt_key"],
|
||||||
last_used_at=last_used_at,
|
last_used_at=last_used_at,
|
||||||
last_used_ip=rt_dict.get("last_used_ip"),
|
last_used_ip=rt_dict.get("last_used_ip"),
|
||||||
|
expire_at=expire_at,
|
||||||
version=rt_dict.get("version"),
|
version=rt_dict.get("version"),
|
||||||
)
|
)
|
||||||
if "credential_id" in rt_dict:
|
if "credential_id" in rt_dict:
|
||||||
|
@ -439,6 +467,8 @@ class AuthStore:
|
||||||
self._groups = groups
|
self._groups = groups
|
||||||
self._users = users
|
self._users = users
|
||||||
|
|
||||||
|
self._async_schedule_save()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_schedule_save(self) -> None:
|
def _async_schedule_save(self) -> None:
|
||||||
"""Save users."""
|
"""Save users."""
|
||||||
|
@ -503,6 +533,7 @@ class AuthStore:
|
||||||
if refresh_token.last_used_at
|
if refresh_token.last_used_at
|
||||||
else None,
|
else None,
|
||||||
"last_used_ip": refresh_token.last_used_ip,
|
"last_used_ip": refresh_token.last_used_ip,
|
||||||
|
"expire_at": refresh_token.expire_at,
|
||||||
"credential_id": refresh_token.credential.id
|
"credential_id": refresh_token.credential.id
|
||||||
if refresh_token.credential
|
if refresh_token.credential
|
||||||
else None,
|
else None,
|
||||||
|
|
|
@ -3,6 +3,7 @@ from datetime import timedelta
|
||||||
|
|
||||||
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
||||||
MFA_SESSION_EXPIRATION = timedelta(minutes=5)
|
MFA_SESSION_EXPIRATION = timedelta(minutes=5)
|
||||||
|
REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds()
|
||||||
|
|
||||||
GROUP_ID_ADMIN = "system-admin"
|
GROUP_ID_ADMIN = "system-admin"
|
||||||
GROUP_ID_USER = "system-users"
|
GROUP_ID_USER = "system-users"
|
||||||
|
|
|
@ -117,6 +117,8 @@ class RefreshToken:
|
||||||
last_used_at: datetime | None = attr.ib(default=None)
|
last_used_at: datetime | None = attr.ib(default=None)
|
||||||
last_used_ip: str | None = attr.ib(default=None)
|
last_used_ip: str | None = attr.ib(default=None)
|
||||||
|
|
||||||
|
expire_at: float | None = attr.ib(default=None)
|
||||||
|
|
||||||
credential: Credentials | None = attr.ib(default=None)
|
credential: Credentials | None = attr.ib(default=None)
|
||||||
|
|
||||||
version: str | None = attr.ib(default=__version__)
|
version: str | None = attr.ib(default=__version__)
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
"""Tests for the auth store."""
|
"""Tests for the auth store."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.auth import auth_store
|
from homeassistant.auth import auth_store
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
|
||||||
async def test_loading_no_group_data_format(
|
async def test_loading_no_group_data_format(
|
||||||
|
@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
|
||||||
mock_dev_registry.assert_called_once_with(hass)
|
mock_dev_registry.assert_called_once_with(hass)
|
||||||
mock_load.assert_called_once_with()
|
mock_load.assert_called_once_with()
|
||||||
assert results[0] == results[1]
|
assert results[0] == results[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_expire_at_property(
|
||||||
|
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test we correctly add expired_at property if not existing."""
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
with freeze_time(now):
|
||||||
|
hass_storage[auth_store.STORAGE_KEY] = {
|
||||||
|
"version": 1,
|
||||||
|
"data": {
|
||||||
|
"credentials": [],
|
||||||
|
"users": [
|
||||||
|
{
|
||||||
|
"id": "user-id",
|
||||||
|
"is_active": True,
|
||||||
|
"is_owner": True,
|
||||||
|
"name": "Paulus",
|
||||||
|
"system_generated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "system-id",
|
||||||
|
"is_active": True,
|
||||||
|
"is_owner": True,
|
||||||
|
"name": "Hass.io",
|
||||||
|
"system_generated": True,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"refresh_tokens": [
|
||||||
|
{
|
||||||
|
"access_token_expiration": 1800.0,
|
||||||
|
"client_id": "http://localhost:8123/",
|
||||||
|
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||||
|
"id": "user-token-id",
|
||||||
|
"jwt_key": "some-key",
|
||||||
|
"last_used_at": str(now - timedelta(days=10)),
|
||||||
|
"token": "some-token",
|
||||||
|
"user_id": "user-id",
|
||||||
|
"version": "1.2.3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"access_token_expiration": 1800.0,
|
||||||
|
"client_id": "http://localhost:8123/",
|
||||||
|
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||||
|
"id": "user-token-id2",
|
||||||
|
"jwt_key": "some-key2",
|
||||||
|
"token": "some-token",
|
||||||
|
"user_id": "user-id",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
store = auth_store.AuthStore(hass)
|
||||||
|
await store.async_load()
|
||||||
|
|
||||||
|
users = await store.async_get_users()
|
||||||
|
|
||||||
|
assert len(users[0].refresh_tokens) == 2
|
||||||
|
token1, token2 = users[0].refresh_tokens.values()
|
||||||
|
assert token1.expire_at
|
||||||
|
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
|
||||||
|
assert token2.expire_at
|
||||||
|
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
|
||||||
|
|
|
@ -26,6 +26,7 @@ from tests.common import (
|
||||||
CLIENT_ID,
|
CLIENT_ID,
|
||||||
MockUser,
|
MockUser,
|
||||||
async_capture_events,
|
async_capture_events,
|
||||||
|
async_fire_time_changed,
|
||||||
ensure_auth_manager_loaded,
|
ensure_auth_manager_loaded,
|
||||||
flush_store,
|
flush_store,
|
||||||
)
|
)
|
||||||
|
@ -406,6 +407,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
|
||||||
assert not user.local_only
|
assert not user.local_only
|
||||||
assert token is not None
|
assert token is not None
|
||||||
assert token.client_id is None
|
assert token.client_id is None
|
||||||
|
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
|
||||||
|
assert token.expire_at is None
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
|
@ -421,6 +424,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None:
|
||||||
assert user.local_only
|
assert user.local_only
|
||||||
assert token is not None
|
assert token is not None
|
||||||
assert token.client_id is None
|
assert token.client_id is None
|
||||||
|
assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM
|
||||||
|
assert token.expire_at is None
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
@ -474,6 +479,8 @@ async def test_refresh_token_with_specific_access_token_expiration(
|
||||||
assert token is not None
|
assert token is not None
|
||||||
assert token.client_id == CLIENT_ID
|
assert token.client_id == CLIENT_ID
|
||||||
assert token.access_token_expiration == timedelta(days=100)
|
assert token.access_token_expiration == timedelta(days=100)
|
||||||
|
assert token.token_type == auth.models.TOKEN_TYPE_NORMAL
|
||||||
|
assert token.expire_at is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_type(hass: HomeAssistant) -> None:
|
async def test_refresh_token_type(hass: HomeAssistant) -> None:
|
||||||
|
@ -515,6 +522,7 @@ async def test_refresh_token_type_long_lived_access_token(hass: HomeAssistant) -
|
||||||
assert token.client_name == "GPS LOGGER"
|
assert token.client_name == "GPS LOGGER"
|
||||||
assert token.client_icon == "mdi:home"
|
assert token.client_icon == "mdi:home"
|
||||||
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
assert token.expire_at is None
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_provider_validation(mock_hass) -> None:
|
async def test_refresh_token_provider_validation(mock_hass) -> None:
|
||||||
|
@ -565,9 +573,9 @@ async def test_cannot_deactive_owner(mock_hass) -> None:
|
||||||
await manager.async_deactivate_user(owner)
|
await manager.async_deactivate_user(owner)
|
||||||
|
|
||||||
|
|
||||||
async def test_remove_refresh_token(mock_hass) -> None:
|
async def test_remove_refresh_token(hass: HomeAssistant) -> None:
|
||||||
"""Test that we can remove a refresh token."""
|
"""Test that we can remove a refresh token."""
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||||
user = MockUser().add_to_auth_manager(manager)
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
access_token = manager.async_create_access_token(refresh_token)
|
access_token = manager.async_create_access_token(refresh_token)
|
||||||
|
@ -578,6 +586,70 @@ async def test_remove_refresh_token(mock_hass) -> None:
|
||||||
assert manager.async_validate_access_token(access_token) is None
|
assert manager.async_validate_access_token(access_token) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_expired_refresh_token(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that expired refresh tokens are deleted."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
with freeze_time(now):
|
||||||
|
refresh_token1 = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
|
assert (
|
||||||
|
refresh_token1.expire_at
|
||||||
|
== now.timestamp() + timedelta(days=90).total_seconds()
|
||||||
|
)
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=30)):
|
||||||
|
async_fire_time_changed(hass, now + timedelta(days=30))
|
||||||
|
refresh_token2 = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
|
assert (
|
||||||
|
refresh_token2.expire_at
|
||||||
|
== now.timestamp() + timedelta(days=120).total_seconds()
|
||||||
|
)
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=89, hours=23)):
|
||||||
|
async_fire_time_changed(hass, now + timedelta(days=89, hours=23))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert manager.async_get_refresh_token(refresh_token1.id)
|
||||||
|
assert manager.async_get_refresh_token(refresh_token2.id)
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=90, seconds=5)):
|
||||||
|
async_fire_time_changed(hass, now + timedelta(days=90, seconds=5))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert manager.async_get_refresh_token(refresh_token1.id) is None
|
||||||
|
assert manager.async_get_refresh_token(refresh_token2.id)
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=120, seconds=5)):
|
||||||
|
async_fire_time_changed(hass, now + timedelta(days=120, seconds=5))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert manager.async_get_refresh_token(refresh_token1.id) is None
|
||||||
|
assert manager.async_get_refresh_token(refresh_token2.id) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_expire_at_refresh_token(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that expire at is updated when refresh token is used."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [], [])
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
with freeze_time(now):
|
||||||
|
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
|
assert (
|
||||||
|
refresh_token.expire_at
|
||||||
|
== now.timestamp() + timedelta(days=90).total_seconds()
|
||||||
|
)
|
||||||
|
|
||||||
|
with freeze_time(now + timedelta(days=30)):
|
||||||
|
async_fire_time_changed(hass, now + timedelta(days=30))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert manager.async_create_access_token(refresh_token)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert (
|
||||||
|
refresh_token.expire_at
|
||||||
|
== now.timestamp()
|
||||||
|
+ timedelta(days=30).total_seconds()
|
||||||
|
+ timedelta(days=90).total_seconds()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_register_revoke_token_callback(mock_hass) -> None:
|
async def test_register_revoke_token_callback(mock_hass) -> None:
|
||||||
"""Test that a registered revoke token callback is called."""
|
"""Test that a registered revoke token callback is called."""
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue