Always load auth storage at startup (#108543)

This commit is contained in:
J. Nick Koston 2024-01-20 16:16:43 -10:00 committed by GitHub
parent 4d46f5ec07
commit ec15b0def2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 43 additions and 81 deletions

View file

@ -47,6 +47,7 @@ async def auth_manager_from_config(
mfa modules exist in configs. mfa modules exist in configs.
""" """
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
if provider_configs: if provider_configs:
providers = await asyncio.gather( providers = await asyncio.gather(
*( *(
@ -73,8 +74,7 @@ async def auth_manager_from_config(
for module in modules: for module in modules:
module_hash[module.id] = module module_hash[module.id] = module
manager = AuthManager(hass, store, provider_hash, module_hash) return AuthManager(hass, store, provider_hash, module_hash)
return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager): class AuthManagerFlowManager(data_entry_flow.FlowManager):

View file

@ -1,7 +1,6 @@
"""Storage for auth models.""" """Storage for auth models."""
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import timedelta from datetime import timedelta
import hmac import hmac
from logging import getLogger from logging import getLogger
@ -42,44 +41,28 @@ class AuthStore:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the auth store.""" """Initialize the auth store."""
self.hass = hass self.hass = hass
self._users: dict[str, models.User] | None = None self._loaded = False
self._groups: dict[str, models.Group] | None = None self._users: dict[str, models.User] = None # type: ignore[assignment]
self._perm_lookup: PermissionLookup | None = None self._groups: dict[str, models.Group] = None # type: ignore[assignment]
self._perm_lookup: PermissionLookup = None # type: ignore[assignment]
self._store = Store[dict[str, list[dict[str, Any]]]]( self._store = Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._lock = asyncio.Lock()
async def async_get_groups(self) -> list[models.Group]: async def async_get_groups(self) -> list[models.Group]:
"""Retrieve all users.""" """Retrieve all users."""
if self._groups is None:
await self._async_load()
assert self._groups is not None
return list(self._groups.values()) return list(self._groups.values())
async def async_get_group(self, group_id: str) -> models.Group | None: async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all users.""" """Retrieve all users."""
if self._groups is None:
await self._async_load()
assert self._groups is not None
return self._groups.get(group_id) return self._groups.get(group_id)
async def async_get_users(self) -> list[models.User]: async def async_get_users(self) -> list[models.User]:
"""Retrieve all users.""" """Retrieve all users."""
if self._users is None:
await self._async_load()
assert self._users is not None
return list(self._users.values()) return list(self._users.values())
async def async_get_user(self, user_id: str) -> models.User | None: async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user by id.""" """Retrieve a user by id."""
if self._users is None:
await self._async_load()
assert self._users is not None
return self._users.get(user_id) return self._users.get(user_id)
async def async_create_user( async def async_create_user(
@ -93,12 +76,6 @@ class AuthStore:
local_only: bool | None = None, local_only: bool | None = None,
) -> models.User: ) -> models.User:
"""Create a new user.""" """Create a new user."""
if self._users is None:
await self._async_load()
assert self._users is not None
assert self._groups is not None
groups = [] groups = []
for group_id in group_ids or []: for group_id in group_ids or []:
if (group := self._groups.get(group_id)) is None: if (group := self._groups.get(group_id)) is None:
@ -144,10 +121,6 @@ class AuthStore:
async def async_remove_user(self, user: models.User) -> None: async def async_remove_user(self, user: models.User) -> None:
"""Remove a user.""" """Remove a user."""
if self._users is None:
await self._async_load()
assert self._users is not None
self._users.pop(user.id) self._users.pop(user.id)
self._async_schedule_save() self._async_schedule_save()
@ -160,8 +133,6 @@ class AuthStore:
local_only: bool | None = None, local_only: bool | None = None,
) -> None: ) -> None:
"""Update a user.""" """Update a user."""
assert self._groups is not None
if group_ids is not None: if group_ids is not None:
groups = [] groups = []
for grid in group_ids: for grid in group_ids:
@ -193,10 +164,6 @@ class AuthStore:
async def async_remove_credentials(self, credentials: models.Credentials) -> None: async def async_remove_credentials(self, credentials: models.Credentials) -> None:
"""Remove credentials.""" """Remove credentials."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values(): for user in self._users.values():
found = None found = None
@ -244,10 +211,6 @@ class AuthStore:
self, refresh_token: models.RefreshToken self, refresh_token: models.RefreshToken
) -> None: ) -> None:
"""Remove a refresh token.""" """Remove a refresh token."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values(): for user in self._users.values():
if user.refresh_tokens.pop(refresh_token.id, None): if user.refresh_tokens.pop(refresh_token.id, None):
self._async_schedule_save() self._async_schedule_save()
@ -257,10 +220,6 @@ class AuthStore:
self, token_id: str self, token_id: str
) -> models.RefreshToken | None: ) -> models.RefreshToken | None:
"""Get refresh token by id.""" """Get refresh token by id."""
if self._users is None:
await self._async_load()
assert self._users is not None
for user in self._users.values(): for user in self._users.values():
refresh_token = user.refresh_tokens.get(token_id) refresh_token = user.refresh_tokens.get(token_id)
if refresh_token is not None: if refresh_token is not None:
@ -272,10 +231,6 @@ class AuthStore:
self, token: str self, token: str
) -> models.RefreshToken | None: ) -> models.RefreshToken | None:
"""Get refresh token by token.""" """Get refresh token by token."""
if self._users is None:
await self._async_load()
assert self._users is not None
found = None found = None
for user in self._users.values(): for user in self._users.values():
@ -294,25 +249,18 @@ class AuthStore:
refresh_token.last_used_ip = remote_ip refresh_token.last_used_ip = remote_ip
self._async_schedule_save() self._async_schedule_save()
async def _async_load(self) -> None: async def async_load(self) -> None:
"""Load the users.""" """Load the users."""
async with self._lock: if self._loaded:
if self._users is not None: raise RuntimeError("Auth storage is already loaded")
return self._loaded = True
await self._async_load_task()
async def _async_load_task(self) -> None:
"""Load the users."""
dev_reg = dr.async_get(self.hass) dev_reg = dr.async_get(self.hass)
ent_reg = er.async_get(self.hass) ent_reg = er.async_get(self.hass)
data = await self._store.async_load() data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the perm_lookup = PermissionLookup(ent_reg, dev_reg)
# same time self._perm_lookup = perm_lookup
if self._users is not None:
return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)
if data is None or not isinstance(data, dict): if data is None or not isinstance(data, dict):
self._set_defaults() self._set_defaults()
@ -495,17 +443,11 @@ class AuthStore:
@callback @callback
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
"""Save users.""" """Save users."""
if self._users is None:
return
self._store.async_delay_save(self._data_to_save, 1) self._store.async_delay_save(self._data_to_save, 1)
@callback @callback
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]: def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
"""Return the data to store.""" """Return the data to store."""
assert self._users is not None
assert self._groups is not None
users = [ users = [
{ {
"id": user.id, "id": user.id,

View file

@ -9,6 +9,7 @@ from homeassistant.auth import auth_manager_from_config
from homeassistant.auth.providers import homeassistant as hass_auth from homeassistant.auth.providers import homeassistant as hass_auth
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
@ -51,6 +52,7 @@ def run(args):
async def run_command(args): async def run_command(args):
"""Run the command.""" """Run the command."""
hass = HomeAssistant(os.path.join(os.getcwd(), args.config)) hass = HomeAssistant(os.path.join(os.getcwd(), args.config))
await asyncio.gather(dr.async_load(hass), er.async_load(hass))
hass.auth = await auth_manager_from_config(hass, [{"type": "homeassistant"}], []) hass.auth = await auth_manager_from_config(hass, [{"type": "homeassistant"}], [])
provider = hass.auth.auth_providers[0] provider = hass.auth.auth_providers[0]
await provider.async_initialize() await provider.async_initialize()

View file

@ -13,9 +13,11 @@ from homeassistant.const import CONF_TYPE
@pytest.fixture @pytest.fixture
def store(hass): async def store(hass):
"""Mock store.""" """Mock store."""
return auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture @pytest.fixture

View file

@ -9,9 +9,11 @@ from homeassistant.auth.providers import insecure_example
@pytest.fixture @pytest.fixture
def store(hass): async def store(hass):
"""Mock store.""" """Mock store."""
return auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture @pytest.fixture

View file

@ -14,9 +14,11 @@ CONFIG = {"type": "legacy_api_password", "api_password": "test-password"}
@pytest.fixture @pytest.fixture
def store(hass): async def store(hass):
"""Mock store.""" """Mock store."""
return auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture @pytest.fixture

View file

@ -16,9 +16,11 @@ from homeassistant.setup import async_setup_component
@pytest.fixture @pytest.fixture
def store(hass): async def store(hass):
"""Mock store.""" """Mock store."""
return auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
return store
@pytest.fixture @pytest.fixture

View file

@ -3,6 +3,8 @@ import asyncio
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.auth import auth_store from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -67,6 +69,7 @@ async def test_loading_no_group_data_format(
} }
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups() groups = await store.async_get_groups()
assert len(groups) == 3 assert len(groups) == 3
admin_group = groups[0] admin_group = groups[0]
@ -165,6 +168,7 @@ async def test_loading_all_access_group_data_format(
} }
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups() groups = await store.async_get_groups()
assert len(groups) == 3 assert len(groups) == 3
admin_group = groups[0] admin_group = groups[0]
@ -205,6 +209,7 @@ async def test_loading_empty_data(
) -> None: ) -> None:
"""Test we correctly load with no existing data.""" """Test we correctly load with no existing data."""
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load()
groups = await store.async_get_groups() groups = await store.async_get_groups()
assert len(groups) == 3 assert len(groups) == 3
admin_group = groups[0] admin_group = groups[0]
@ -232,7 +237,7 @@ async def test_system_groups_store_id_and_name(
Name is stored so that we remain backwards compat with < 0.82. Name is stored so that we remain backwards compat with < 0.82.
""" """
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store._async_load() await store.async_load()
data = store._data_to_save() data = store._data_to_save()
assert len(data["users"]) == 0 assert len(data["users"]) == 0
assert data["groups"] == [ assert data["groups"] == [
@ -242,8 +247,8 @@ async def test_system_groups_store_id_and_name(
] ]
async def test_loading_race_condition(hass: HomeAssistant) -> None: async def test_loading_only_once(hass: HomeAssistant) -> None:
"""Test only one storage load called when concurrent loading occurred .""" """Test only one storage load is allowed."""
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
with patch( with patch(
"homeassistant.helpers.entity_registry.async_get" "homeassistant.helpers.entity_registry.async_get"
@ -252,6 +257,10 @@ async def test_loading_race_condition(hass: HomeAssistant) -> None:
) as mock_dev_registry, patch( ) as mock_dev_registry, patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None "homeassistant.helpers.storage.Store.async_load", return_value=None
) as mock_load: ) as mock_load:
await store.async_load()
with pytest.raises(RuntimeError, match="Auth storage is already loaded"):
await store.async_load()
results = await asyncio.gather(store.async_get_users(), store.async_get_users()) results = await asyncio.gather(store.async_get_users(), store.async_get_users())
mock_ent_registry.assert_called_once_with(hass) mock_ent_registry.assert_called_once_with(hass)

View file

@ -343,6 +343,7 @@ async def test_saving_loading(
await flush_store(manager._store._store) await flush_store(manager._store._store)
store2 = auth_store.AuthStore(hass) store2 = auth_store.AuthStore(hass)
await store2.async_load()
users = await store2.async_get_users() users = await store2.async_get_users()
assert len(users) == 1 assert len(users) == 1
assert users[0].permissions == user.permissions assert users[0].permissions == user.permissions