Always load auth storage at startup (#108543)

This commit is contained in:
J. Nick Koston 2024-01-20 16:16:43 -10:00 committed by GitHub
parent 4d46f5ec07
commit ec15b0def2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 43 additions and 81 deletions

View file

@ -47,6 +47,7 @@ async def auth_manager_from_config(
mfa modules exist in configs.
"""
store = auth_store.AuthStore(hass)
await store.async_load()
if provider_configs:
providers = await asyncio.gather(
*(
@ -73,8 +74,7 @@ async def auth_manager_from_config(
for module in modules:
module_hash[module.id] = module
manager = AuthManager(hass, store, provider_hash, module_hash)
return manager
return AuthManager(hass, store, provider_hash, module_hash)
class AuthManagerFlowManager(data_entry_flow.FlowManager):

View file

@ -1,7 +1,6 @@
"""Storage for auth models."""
from __future__ import annotations
import asyncio
from datetime import timedelta
import hmac
from logging import getLogger
@ -42,44 +41,28 @@ class AuthStore:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the auth store."""
self.hass = hass
self._users: dict[str, models.User] | None = None
self._groups: dict[str, models.Group] | None = None
self._perm_lookup: PermissionLookup | None = None
self._loaded = False
self._users: dict[str, models.User] = None # type: ignore[assignment]
self._groups: dict[str, models.Group] = None # type: ignore[assignment]
self._perm_lookup: PermissionLookup = None # type: ignore[assignment]
self._store = Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
)
self._lock = asyncio.Lock()
async def async_get_groups(self) -> list[models.Group]:
"""Retrieve all users."""
if self._groups is None:
await self._async_load()
assert self._groups is not None
return list(self._groups.values())
async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all users."""
if self._groups is None:
await self._async_load()
assert self._groups is not None
return self._groups.get(group_id)
async def async_get_users(self) -> list[models.User]:
"""Retrieve all users."""
if self._users is None:
await self._async_load()
assert self._users is not None
return list(self._users.values())
async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user by id."""
if self._users is None:
await self._async_load()
assert self._users is not None
return self._users.get(user_id)
async def async_create_user(
@ -93,12 +76,6 @@ class AuthStore:
local_only: bool | None = None,
) -> models.User:
"""Create a new user."""
if self._users is None:
await self._async_load()
assert self._users is not None
assert self._groups is not None
groups = []
for group_id in group_ids or []:
if (group := self._groups.get(group_id)) is None:
@ -144,10 +121,6 @@ class AuthStore:
async def async_remove_user(self, user: models.User) -> None:
"""Remove a user."""
if self._users is None:
await self._async_load()
assert self._users is not None
self._users.pop(user.id)
self._async_schedule_save()
@ -160,8 +133,6 @@ class AuthStore:
local_only: bool | None = None,
) -> None:
"""Update a user."""
assert self._groups is not None
if group_ids is not None:
groups = []
for grid in group_ids:
@ -193,10 +164,6 @@ class AuthStore:
async def async_remove_credentials(self, credentials: models.Credentials) -> None:
"""Remove credentials."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values():
found = None
@ -244,10 +211,6 @@ class AuthStore:
self, refresh_token: models.RefreshToken
) -> None:
"""Remove a refresh token."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values():
if user.refresh_tokens.pop(refresh_token.id, None):
self._async_schedule_save()
@ -257,10 +220,6 @@ class AuthStore:
self, token_id: str
) -> models.RefreshToken | None:
"""Get refresh token by id."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token_id)
if refresh_token is not None:
@ -272,10 +231,6 @@ class AuthStore:
self, token: str
) -> models.RefreshToken | None:
"""Get refresh token by token."""
if self._users is None:
await self._async_load()
assert self._users is not None
found = None
for user in self._users.values():
@ -294,25 +249,18 @@ class AuthStore:
refresh_token.last_used_ip = remote_ip
self._async_schedule_save()
async def _async_load(self) -> None:
async def async_load(self) -> None:
"""Load the users."""
async with self._lock:
if self._users is not None:
return
await self._async_load_task()
if self._loaded:
raise RuntimeError("Auth storage is already loaded")
self._loaded = True
async def _async_load_task(self) -> None:
"""Load the users."""
dev_reg = dr.async_get(self.hass)
ent_reg = er.async_get(self.hass)
data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self._users is not None:
return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)
perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup
if data is None or not isinstance(data, dict):
self._set_defaults()
@ -495,17 +443,11 @@ class AuthStore:
@callback
def _async_schedule_save(self) -> None:
"""Save users."""
if self._users is None:
return
self._store.async_delay_save(self._data_to_save, 1)
@callback
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
"""Return the data to store."""
assert self._users is not None
assert self._groups is not None
users = [
{
"id": user.id,

View file

@ -9,6 +9,7 @@ from homeassistant.auth import auth_manager_from_config
from homeassistant.auth.providers import homeassistant as hass_auth
from homeassistant.config import get_default_config_dir
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
# mypy: allow-untyped-calls, allow-untyped-defs
@ -51,6 +52,7 @@ def run(args):
async def run_command(args):
"""Run the command."""
hass = HomeAssistant(os.path.join(os.getcwd(), args.config))
await asyncio.gather(dr.async_load(hass), er.async_load(hass))
hass.auth = await auth_manager_from_config(hass, [{"type": "homeassistant"}], [])
provider = hass.auth.auth_providers[0]
await provider.async_initialize()

View file

@ -13,9 +13,11 @@ from homeassistant.const import CONF_TYPE
@pytest.fixture
def store(hass):
async def store(hass):
"""Mock store."""
return auth_store.AuthStore(hass)
store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture

View file

@ -9,9 +9,11 @@ from homeassistant.auth.providers import insecure_example
@pytest.fixture
def store(hass):
async def store(hass):
"""Mock store."""
return auth_store.AuthStore(hass)
store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture

View file

@ -14,9 +14,11 @@ CONFIG = {"type": "legacy_api_password", "api_password": "test-password"}
@pytest.fixture
def store(hass):
async def store(hass):
"""Mock store."""
return auth_store.AuthStore(hass)
store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture

View file

@ -16,9 +16,11 @@ from homeassistant.setup import async_setup_component
@pytest.fixture
def store(hass):
async def store(hass):
"""Mock store."""
return auth_store.AuthStore(hass)
store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture

View file

@ -3,6 +3,8 @@ import asyncio
from typing import Any
from unittest.mock import patch
import pytest
from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant
@ -67,6 +69,7 @@ async def test_loading_no_group_data_format(
}
store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups()
assert len(groups) == 3
admin_group = groups[0]
@ -165,6 +168,7 @@ async def test_loading_all_access_group_data_format(
}
store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups()
assert len(groups) == 3
admin_group = groups[0]
@ -205,6 +209,7 @@ async def test_loading_empty_data(
) -> None:
"""Test we correctly load with no existing data."""
store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups()
assert len(groups) == 3
admin_group = groups[0]
@ -232,7 +237,7 @@ async def test_system_groups_store_id_and_name(
Name is stored so that we remain backwards compat with < 0.82.
"""
store = auth_store.AuthStore(hass)
await store._async_load()
await store.async_load()
data = store._data_to_save()
assert len(data["users"]) == 0
assert data["groups"] == [
@ -242,8 +247,8 @@ async def test_system_groups_store_id_and_name(
]
async def test_loading_race_condition(hass: HomeAssistant) -> None:
"""Test only one storage load called when concurrent loading occurred ."""
async def test_loading_only_once(hass: HomeAssistant) -> None:
"""Test only one storage load is allowed."""
store = auth_store.AuthStore(hass)
with patch(
"homeassistant.helpers.entity_registry.async_get"
@ -252,6 +257,10 @@ async def test_loading_race_condition(hass: HomeAssistant) -> None:
) as mock_dev_registry, patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
) as mock_load:
await store.async_load()
with pytest.raises(RuntimeError, match="Auth storage is already loaded"):
await store.async_load()
results = await asyncio.gather(store.async_get_users(), store.async_get_users())
mock_ent_registry.assert_called_once_with(hass)

View file

@ -343,6 +343,7 @@ async def test_saving_loading(
await flush_store(manager._store._store)
store2 = auth_store.AuthStore(hass)
await store2.async_load()
users = await store2.async_get_users()
assert len(users) == 1
assert users[0].permissions == user.permissions