diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 24e34a2d555..0b749766263 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -516,6 +516,13 @@ class AuthManager: for revoke_callback in callbacks: revoke_callback() + @callback + def async_set_expiry( + self, refresh_token: models.RefreshToken, *, enable_expiry: bool + ) -> None: + """Enable or disable expiry of a refresh token.""" + self._store.async_set_expiry(refresh_token, enable_expiry=enable_expiry) + @callback def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None: """Remove expired refresh tokens.""" diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index bf93011355c..3bf025c058c 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -6,7 +6,6 @@ from datetime import timedelta import hmac import itertools from logging import getLogger -import time from typing import Any from homeassistant.core import HomeAssistant, callback @@ -282,6 +281,21 @@ class AuthStore: ) self._async_schedule_save() + @callback + def async_set_expiry( + self, refresh_token: models.RefreshToken, *, enable_expiry: bool + ) -> None: + """Enable or disable expiry of a refresh token.""" + if enable_expiry: + if refresh_token.expire_at is None: + refresh_token.expire_at = ( + refresh_token.last_used_at or dt_util.utcnow() + ).timestamp() + REFRESH_TOKEN_EXPIRATION + self._async_schedule_save() + else: + refresh_token.expire_at = None + self._async_schedule_save() + async def async_load(self) -> None: # noqa: C901 """Load the users.""" if self._loaded: @@ -295,8 +309,6 @@ class AuthStore: perm_lookup = PermissionLookup(ent_reg, dev_reg) self._perm_lookup = perm_lookup - now_ts = time.time() - if data is None or not isinstance(data, dict): self._set_defaults() return @@ -450,14 +462,6 @@ class AuthStore: else: last_used_at = None - if ( - expire_at := rt_dict.get("expire_at") - ) is None and token_type == models.TOKEN_TYPE_NORMAL: - if last_used_at: - expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION - else: - expire_at = now_ts + REFRESH_TOKEN_EXPIRATION - token = models.RefreshToken( id=rt_dict["id"], user=users[rt_dict["user_id"]], @@ -474,7 +478,7 @@ class AuthStore: jwt_key=rt_dict["jwt_key"], last_used_at=last_used_at, last_used_ip=rt_dict.get("last_used_ip"), - expire_at=expire_at, + expire_at=rt_dict.get("expire_at"), version=rt_dict.get("version"), ) if "credential_id" in rt_dict: diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 8d9b47fdd06..6e4bbac8b63 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -197,6 +197,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, websocket_delete_refresh_token) websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens) websocket_api.async_register_command(hass, websocket_sign_path) + websocket_api.async_register_command(hass, websocket_refresh_token_set_expiry) login_flow.async_setup(hass, store_result) mfa_setup_flow.async_setup(hass) @@ -565,18 +566,23 @@ def websocket_refresh_tokens( else: auth_provider_type = None + expire_at = None + if refresh.expire_at: + expire_at = dt_util.utc_from_timestamp(refresh.expire_at) + tokens.append( { - "id": refresh.id, + "auth_provider_type": auth_provider_type, + "client_icon": refresh.client_icon, "client_id": refresh.client_id, "client_name": refresh.client_name, - "client_icon": refresh.client_icon, - "type": refresh.token_type, "created_at": refresh.created_at, + "expire_at": expire_at, + "id": refresh.id, "is_current": refresh.id == current_id, "last_used_at": refresh.last_used_at, "last_used_ip": refresh.last_used_ip, - "auth_provider_type": auth_provider_type, + "type": refresh.token_type, } ) @@ -702,3 +708,26 @@ def websocket_sign_path( }, ) ) + + +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "auth/refresh_token_set_expiry", + vol.Required("refresh_token_id"): str, + vol.Required("enable_expiry"): bool, + } +) +@websocket_api.ws_require_user() +def websocket_refresh_token_set_expiry( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle a set expiry of a refresh token request.""" + refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"]) + + if refresh_token is None: + connection.send_error(msg["id"], "invalid_token_id", "Received invalid token") + return + + hass.auth.async_set_expiry(refresh_token, enable_expiry=msg["enable_expiry"]) + connection.send_result(msg["id"], {}) diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 8ef8a4e3946..65bc35a5ff8 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -1,17 +1,14 @@ """Tests for the auth store.""" import asyncio -from datetime import timedelta from typing import Any from unittest.mock import patch -from freezegun import freeze_time from freezegun.api import FrozenDateTimeFactory import pytest from homeassistant.auth import auth_store from homeassistant.core import HomeAssistant -from homeassistant.util import dt as dt_util MOCK_STORAGE_DATA = { "version": 1, @@ -220,68 +217,64 @@ async def test_loading_only_once(hass: HomeAssistant) -> None: assert results[0] == results[1] -async def test_add_expire_at_property( +async def test_dont_change_expire_at_on_load( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: - """Test we correctly add expired_at property if not existing.""" - now = dt_util.utcnow() - with freeze_time(now): - hass_storage[auth_store.STORAGE_KEY] = { - "version": 1, - "data": { - "credentials": [], - "users": [ - { - "id": "user-id", - "is_active": True, - "is_owner": True, - "name": "Paulus", - "system_generated": False, - }, - { - "id": "system-id", - "is_active": True, - "is_owner": True, - "name": "Hass.io", - "system_generated": True, - }, - ], - "refresh_tokens": [ - { - "access_token_expiration": 1800.0, - "client_id": "http://localhost:8123/", - "created_at": "2018-10-03T13:43:19.774637+00:00", - "id": "user-token-id", - "jwt_key": "some-key", - "last_used_at": str(now - timedelta(days=10)), - "token": "some-token", - "user_id": "user-id", - "version": "1.2.3", - }, - { - "access_token_expiration": 1800.0, - "client_id": "http://localhost:8123/", - "created_at": "2018-10-03T13:43:19.774637+00:00", - "id": "user-token-id2", - "jwt_key": "some-key2", - "token": "some-token", - "user_id": "user-id", - }, - ], - }, - } + """Test we correctly don't modify expired_at store load.""" + hass_storage[auth_store.STORAGE_KEY] = { + "version": 1, + "data": { + "credentials": [], + "users": [ + { + "id": "user-id", + "is_active": True, + "is_owner": True, + "name": "Paulus", + "system_generated": False, + }, + { + "id": "system-id", + "is_active": True, + "is_owner": True, + "name": "Hass.io", + "system_generated": True, + }, + ], + "refresh_tokens": [ + { + "access_token_expiration": 1800.0, + "client_id": "http://localhost:8123/", + "created_at": "2018-10-03T13:43:19.774637+00:00", + "id": "user-token-id", + "jwt_key": "some-key", + "token": "some-token", + "user_id": "user-id", + "version": "1.2.3", + }, + { + "access_token_expiration": 1800.0, + "client_id": "http://localhost:8123/", + "created_at": "2018-10-03T13:43:19.774637+00:00", + "id": "user-token-id2", + "jwt_key": "some-key2", + "token": "some-token", + "user_id": "user-id", + "expire_at": 1724133771.079745, + }, + ], + }, + } - store = auth_store.AuthStore(hass) - await store.async_load() + store = auth_store.AuthStore(hass) + await store.async_load() users = await store.async_get_users() assert len(users[0].refresh_tokens) == 2 token1, token2 = users[0].refresh_tokens.values() - assert token1.expire_at - assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds() - assert token2.expire_at - assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds() + assert not token1.expire_at + assert token2.expire_at == 1724133771.079745 async def test_loading_does_not_write_right_away( @@ -326,3 +319,63 @@ async def test_add_remove_user_affects_tokens( assert store.async_get_refresh_token(refresh_token.id) is None assert store.async_get_refresh_token_by_token(refresh_token.token) is None assert user.refresh_tokens == {} + + +async def test_set_expiry_date( + hass: HomeAssistant, hass_storage: dict[str, Any], freezer: FrozenDateTimeFactory +) -> None: + """Test set expiry date of a refresh token.""" + hass_storage[auth_store.STORAGE_KEY] = { + "version": 1, + "data": { + "credentials": [], + "users": [ + { + "id": "user-id", + "is_active": True, + "is_owner": True, + "name": "Paulus", + "system_generated": False, + }, + ], + "refresh_tokens": [ + { + "access_token_expiration": 1800.0, + "client_id": "http://localhost:8123/", + "created_at": "2018-10-03T13:43:19.774637+00:00", + "id": "user-token-id", + "jwt_key": "some-key", + "token": "some-token", + "user_id": "user-id", + "expire_at": 1724133771.079745, + }, + ], + }, + } + + store = auth_store.AuthStore(hass) + await store.async_load() + + users = await store.async_get_users() + + assert len(users[0].refresh_tokens) == 1 + (token,) = users[0].refresh_tokens.values() + assert token.expire_at == 1724133771.079745 + + store.async_set_expiry(token, enable_expiry=False) + assert token.expire_at is None + + freezer.tick(auth_store.DEFAULT_SAVE_DELAY * 2) + # Once for scheduling the task + await hass.async_block_till_done() + # Once for the task + await hass.async_block_till_done() + + # verify token is saved without expire_at + assert ( + hass_storage[auth_store.STORAGE_KEY]["data"]["refresh_tokens"][0]["expire_at"] + is None + ) + + store.async_set_expiry(token, enable_expiry=True) + assert token.expire_at is not None diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 09079337e07..d0ca4699e0e 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -690,3 +690,72 @@ async def test_ws_sign_path( hass, path, expires = mock_sign.mock_calls[0][1] assert path == "/api/hello" assert expires.total_seconds() == 20 + + +async def test_ws_refresh_token_set_expiry( + hass: HomeAssistant, + hass_admin_user: MockUser, + hass_admin_credential: Credentials, + hass_ws_client: WebSocketGenerator, + hass_access_token: str, +) -> None: + """Test setting expiry of a refresh token.""" + assert await async_setup_component(hass, "auth", {"http": {}}) + + refresh_token = await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID, credential=hass_admin_credential + ) + assert refresh_token.expire_at is not None + ws_client = await hass_ws_client(hass, hass_access_token) + + await ws_client.send_json_auto_id( + { + "type": "auth/refresh_token_set_expiry", + "refresh_token_id": refresh_token.id, + "enable_expiry": False, + } + ) + + result = await ws_client.receive_json() + assert result["success"], result + refresh_token = hass.auth.async_get_refresh_token(refresh_token.id) + assert refresh_token.expire_at is None + + await ws_client.send_json_auto_id( + { + "type": "auth/refresh_token_set_expiry", + "refresh_token_id": refresh_token.id, + "enable_expiry": True, + } + ) + + result = await ws_client.receive_json() + assert result["success"], result + refresh_token = hass.auth.async_get_refresh_token(refresh_token.id) + assert refresh_token.expire_at is not None + + +async def test_ws_refresh_token_set_expiry_error( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + hass_access_token: str, +) -> None: + """Test setting expiry of a invalid refresh token returns error.""" + assert await async_setup_component(hass, "auth", {"http": {}}) + + ws_client = await hass_ws_client(hass, hass_access_token) + + await ws_client.send_json_auto_id( + { + "type": "auth/refresh_token_set_expiry", + "refresh_token_id": "invalid", + "enable_expiry": False, + } + ) + + result = await ws_client.receive_json() + assert result, result["success"] is False + assert result["error"] == { + "code": "invalid_token_id", + "message": "Received invalid token", + }