Always load auth storage at startup (#108543)
This commit is contained in:
parent
4d46f5ec07
commit
ec15b0def2
9 changed files with 43 additions and 81 deletions
|
@ -47,6 +47,7 @@ async def auth_manager_from_config(
|
|||
mfa modules exist in configs.
|
||||
"""
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
if provider_configs:
|
||||
providers = await asyncio.gather(
|
||||
*(
|
||||
|
@ -73,8 +74,7 @@ async def auth_manager_from_config(
|
|||
for module in modules:
|
||||
module_hash[module.id] = module
|
||||
|
||||
manager = AuthManager(hass, store, provider_hash, module_hash)
|
||||
return manager
|
||||
return AuthManager(hass, store, provider_hash, module_hash)
|
||||
|
||||
|
||||
class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Storage for auth models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import hmac
|
||||
from logging import getLogger
|
||||
|
@ -42,44 +41,28 @@ class AuthStore:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users: dict[str, models.User] | None = None
|
||||
self._groups: dict[str, models.Group] | None = None
|
||||
self._perm_lookup: PermissionLookup | None = None
|
||||
self._loaded = False
|
||||
self._users: dict[str, models.User] = None # type: ignore[assignment]
|
||||
self._groups: dict[str, models.Group] = None # type: ignore[assignment]
|
||||
self._perm_lookup: PermissionLookup = None # type: ignore[assignment]
|
||||
self._store = Store[dict[str, list[dict[str, Any]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def async_get_groups(self) -> list[models.Group]:
|
||||
"""Retrieve all users."""
|
||||
if self._groups is None:
|
||||
await self._async_load()
|
||||
assert self._groups is not None
|
||||
|
||||
return list(self._groups.values())
|
||||
|
||||
async def async_get_group(self, group_id: str) -> models.Group | None:
|
||||
"""Retrieve all users."""
|
||||
if self._groups is None:
|
||||
await self._async_load()
|
||||
assert self._groups is not None
|
||||
|
||||
return self._groups.get(group_id)
|
||||
|
||||
async def async_get_users(self) -> list[models.User]:
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id: str) -> models.User | None:
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_create_user(
|
||||
|
@ -93,12 +76,6 @@ class AuthStore:
|
|||
local_only: bool | None = None,
|
||||
) -> models.User:
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
||||
assert self._users is not None
|
||||
assert self._groups is not None
|
||||
|
||||
groups = []
|
||||
for group_id in group_ids or []:
|
||||
if (group := self._groups.get(group_id)) is None:
|
||||
|
@ -144,10 +121,6 @@ class AuthStore:
|
|||
|
||||
async def async_remove_user(self, user: models.User) -> None:
|
||||
"""Remove a user."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
self._users.pop(user.id)
|
||||
self._async_schedule_save()
|
||||
|
||||
|
@ -160,8 +133,6 @@ class AuthStore:
|
|||
local_only: bool | None = None,
|
||||
) -> None:
|
||||
"""Update a user."""
|
||||
assert self._groups is not None
|
||||
|
||||
if group_ids is not None:
|
||||
groups = []
|
||||
for grid in group_ids:
|
||||
|
@ -193,10 +164,6 @@ class AuthStore:
|
|||
|
||||
async def async_remove_credentials(self, credentials: models.Credentials) -> None:
|
||||
"""Remove credentials."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
found = None
|
||||
|
||||
|
@ -244,10 +211,6 @@ class AuthStore:
|
|||
self, refresh_token: models.RefreshToken
|
||||
) -> None:
|
||||
"""Remove a refresh token."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
if user.refresh_tokens.pop(refresh_token.id, None):
|
||||
self._async_schedule_save()
|
||||
|
@ -257,10 +220,6 @@ class AuthStore:
|
|||
self, token_id: str
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token_id)
|
||||
if refresh_token is not None:
|
||||
|
@ -272,10 +231,6 @@ class AuthStore:
|
|||
self, token: str
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
found = None
|
||||
|
||||
for user in self._users.values():
|
||||
|
@ -294,25 +249,18 @@ class AuthStore:
|
|||
refresh_token.last_used_ip = remote_ip
|
||||
self._async_schedule_save()
|
||||
|
||||
async def _async_load(self) -> None:
|
||||
async def async_load(self) -> None:
|
||||
"""Load the users."""
|
||||
async with self._lock:
|
||||
if self._users is not None:
|
||||
return
|
||||
await self._async_load_task()
|
||||
if self._loaded:
|
||||
raise RuntimeError("Auth storage is already loaded")
|
||||
self._loaded = True
|
||||
|
||||
async def _async_load_task(self) -> None:
|
||||
"""Load the users."""
|
||||
dev_reg = dr.async_get(self.hass)
|
||||
ent_reg = er.async_get(self.hass)
|
||||
data = await self._store.async_load()
|
||||
|
||||
# Make sure that we're not overriding data if 2 loads happened at the
|
||||
# same time
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
||||
perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
||||
self._perm_lookup = perm_lookup
|
||||
|
||||
if data is None or not isinstance(data, dict):
|
||||
self._set_defaults()
|
||||
|
@ -495,17 +443,11 @@ class AuthStore:
|
|||
@callback
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Save users."""
|
||||
if self._users is None:
|
||||
return
|
||||
|
||||
self._store.async_delay_save(self._data_to_save, 1)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Return the data to store."""
|
||||
assert self._users is not None
|
||||
assert self._groups is not None
|
||||
|
||||
users = [
|
||||
{
|
||||
"id": user.id,
|
||||
|
|
|
@ -9,6 +9,7 @@ from homeassistant.auth import auth_manager_from_config
|
|||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
from homeassistant.config import get_default_config_dir
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
|
@ -51,6 +52,7 @@ def run(args):
|
|||
async def run_command(args):
|
||||
"""Run the command."""
|
||||
hass = HomeAssistant(os.path.join(os.getcwd(), args.config))
|
||||
await asyncio.gather(dr.async_load(hass), er.async_load(hass))
|
||||
hass.auth = await auth_manager_from_config(hass, [{"type": "homeassistant"}], [])
|
||||
provider = hass.auth.auth_providers[0]
|
||||
await provider.async_initialize()
|
||||
|
|
|
@ -13,9 +13,11 @@ from homeassistant.const import CONF_TYPE
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
async def store(hass):
|
||||
"""Mock store."""
|
||||
return auth_store.AuthStore(hass)
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -9,9 +9,11 @@ from homeassistant.auth.providers import insecure_example
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
async def store(hass):
|
||||
"""Mock store."""
|
||||
return auth_store.AuthStore(hass)
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -14,9 +14,11 @@ CONFIG = {"type": "legacy_api_password", "api_password": "test-password"}
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
async def store(hass):
|
||||
"""Mock store."""
|
||||
return auth_store.AuthStore(hass)
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -16,9 +16,11 @@ from homeassistant.setup import async_setup_component
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
async def store(hass):
|
||||
"""Mock store."""
|
||||
return auth_store.AuthStore(hass)
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -3,6 +3,8 @@ import asyncio
|
|||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth import auth_store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
@ -67,6 +69,7 @@ async def test_loading_no_group_data_format(
|
|||
}
|
||||
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
groups = await store.async_get_groups()
|
||||
assert len(groups) == 3
|
||||
admin_group = groups[0]
|
||||
|
@ -165,6 +168,7 @@ async def test_loading_all_access_group_data_format(
|
|||
}
|
||||
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
groups = await store.async_get_groups()
|
||||
assert len(groups) == 3
|
||||
admin_group = groups[0]
|
||||
|
@ -205,6 +209,7 @@ async def test_loading_empty_data(
|
|||
) -> None:
|
||||
"""Test we correctly load with no existing data."""
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
groups = await store.async_get_groups()
|
||||
assert len(groups) == 3
|
||||
admin_group = groups[0]
|
||||
|
@ -232,7 +237,7 @@ async def test_system_groups_store_id_and_name(
|
|||
Name is stored so that we remain backwards compat with < 0.82.
|
||||
"""
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store._async_load()
|
||||
await store.async_load()
|
||||
data = store._data_to_save()
|
||||
assert len(data["users"]) == 0
|
||||
assert data["groups"] == [
|
||||
|
@ -242,8 +247,8 @@ async def test_system_groups_store_id_and_name(
|
|||
]
|
||||
|
||||
|
||||
async def test_loading_race_condition(hass: HomeAssistant) -> None:
|
||||
"""Test only one storage load called when concurrent loading occurred ."""
|
||||
async def test_loading_only_once(hass: HomeAssistant) -> None:
|
||||
"""Test only one storage load is allowed."""
|
||||
store = auth_store.AuthStore(hass)
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_registry.async_get"
|
||||
|
@ -252,6 +257,10 @@ async def test_loading_race_condition(hass: HomeAssistant) -> None:
|
|||
) as mock_dev_registry, patch(
|
||||
"homeassistant.helpers.storage.Store.async_load", return_value=None
|
||||
) as mock_load:
|
||||
await store.async_load()
|
||||
with pytest.raises(RuntimeError, match="Auth storage is already loaded"):
|
||||
await store.async_load()
|
||||
|
||||
results = await asyncio.gather(store.async_get_users(), store.async_get_users())
|
||||
|
||||
mock_ent_registry.assert_called_once_with(hass)
|
||||
|
|
|
@ -343,6 +343,7 @@ async def test_saving_loading(
|
|||
await flush_store(manager._store._store)
|
||||
|
||||
store2 = auth_store.AuthStore(hass)
|
||||
await store2.async_load()
|
||||
users = await store2.async_get_users()
|
||||
assert len(users) == 1
|
||||
assert users[0].permissions == user.permissions
|
||||
|
|
Loading…
Add table
Reference in a new issue