Index auth token ids to avoid linear search (#116583)
* Index auth token ids to avoid linear search * async_remove_refresh_token * coverage
This commit is contained in:
parent
c8e6292cb7
commit
a57f4b8f42
2 changed files with 47 additions and 12 deletions
|
@ -63,6 +63,7 @@ class AuthStore:
|
||||||
self._store = Store[dict[str, list[dict[str, Any]]]](
|
self._store = Store[dict[str, list[dict[str, Any]]]](
|
||||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
|
self._token_id_to_user_id: dict[str, str] = {}
|
||||||
|
|
||||||
async def async_get_groups(self) -> list[models.Group]:
|
async def async_get_groups(self) -> list[models.Group]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
|
@ -136,7 +137,10 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_remove_user(self, user: models.User) -> None:
|
async def async_remove_user(self, user: models.User) -> None:
|
||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
self._users.pop(user.id)
|
user = self._users.pop(user.id)
|
||||||
|
for refresh_token_id in user.refresh_tokens:
|
||||||
|
del self._token_id_to_user_id[refresh_token_id]
|
||||||
|
user.refresh_tokens.clear()
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
|
||||||
async def async_update_user(
|
async def async_update_user(
|
||||||
|
@ -219,7 +223,9 @@ class AuthStore:
|
||||||
kwargs["client_icon"] = client_icon
|
kwargs["client_icon"] = client_icon
|
||||||
|
|
||||||
refresh_token = models.RefreshToken(**kwargs)
|
refresh_token = models.RefreshToken(**kwargs)
|
||||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
token_id = refresh_token.id
|
||||||
|
user.refresh_tokens[token_id] = refresh_token
|
||||||
|
self._token_id_to_user_id[token_id] = user.id
|
||||||
|
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
@ -227,19 +233,17 @@ class AuthStore:
|
||||||
@callback
|
@callback
|
||||||
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||||
"""Remove a refresh token."""
|
"""Remove a refresh token."""
|
||||||
for user in self._users.values():
|
refresh_token_id = refresh_token.id
|
||||||
if user.refresh_tokens.pop(refresh_token.id, None):
|
if user_id := self._token_id_to_user_id.get(refresh_token_id):
|
||||||
self._async_schedule_save()
|
del self._users[user_id].refresh_tokens[refresh_token_id]
|
||||||
break
|
del self._token_id_to_user_id[refresh_token_id]
|
||||||
|
self._async_schedule_save()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
for user in self._users.values():
|
if user_id := self._token_id_to_user_id.get(token_id):
|
||||||
refresh_token = user.refresh_tokens.get(token_id)
|
return self._users[user_id].refresh_tokens.get(token_id)
|
||||||
if refresh_token is not None:
|
|
||||||
return refresh_token
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -479,9 +483,18 @@ class AuthStore:
|
||||||
|
|
||||||
self._groups = groups
|
self._groups = groups
|
||||||
self._users = users
|
self._users = users
|
||||||
|
self._build_token_id_to_user_id()
|
||||||
self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY)
|
self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _build_token_id_to_user_id(self) -> None:
|
||||||
|
"""Build a map of token id to user id."""
|
||||||
|
self._token_id_to_user_id = {
|
||||||
|
token_id: user_id
|
||||||
|
for user_id, user in self._users.items()
|
||||||
|
for token_id in user.refresh_tokens
|
||||||
|
}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None:
|
def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None:
|
||||||
"""Save users."""
|
"""Save users."""
|
||||||
|
@ -575,6 +588,7 @@ class AuthStore:
|
||||||
read_only_group = _system_read_only_group()
|
read_only_group = _system_read_only_group()
|
||||||
groups[read_only_group.id] = read_only_group
|
groups[read_only_group.id] = read_only_group
|
||||||
self._groups = groups
|
self._groups = groups
|
||||||
|
self._build_token_id_to_user_id()
|
||||||
|
|
||||||
|
|
||||||
def _system_admin_group() -> models.Group:
|
def _system_admin_group() -> models.Group:
|
||||||
|
|
|
@ -305,3 +305,24 @@ async def test_loading_does_not_write_right_away(
|
||||||
# Once for the task
|
# Once for the task
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert hass_storage[auth_store.STORAGE_KEY] != {}
|
assert hass_storage[auth_store.STORAGE_KEY] != {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_remove_user_affects_tokens(
|
||||||
|
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test adding and removing a user removes the tokens."""
|
||||||
|
store = auth_store.AuthStore(hass)
|
||||||
|
await store.async_load()
|
||||||
|
user = await store.async_create_user("Test User")
|
||||||
|
assert user.name == "Test User"
|
||||||
|
refresh_token = await store.async_create_refresh_token(
|
||||||
|
user, "client_id", "access_token_expiration"
|
||||||
|
)
|
||||||
|
assert user.refresh_tokens == {refresh_token.id: refresh_token}
|
||||||
|
assert await store.async_get_user(user.id) == user
|
||||||
|
assert store.async_get_refresh_token(refresh_token.id) == refresh_token
|
||||||
|
assert store.async_get_refresh_token_by_token(refresh_token.token) == refresh_token
|
||||||
|
await store.async_remove_user(user)
|
||||||
|
assert store.async_get_refresh_token(refresh_token.id) is None
|
||||||
|
assert store.async_get_refresh_token_by_token(refresh_token.token) is None
|
||||||
|
assert user.refresh_tokens == {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue