Refactor User attribute caching to be safer and more efficient (#96723)
* Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * modernize * coverage * coverage * verify caching * verify caching * fix type * fix mocking
This commit is contained in:
parent
d7910841ef
commit
b1d0c6a4f1
5 changed files with 72 additions and 30 deletions
|
@ -171,7 +171,6 @@ class AuthStore:
|
|||
groups.append(group)
|
||||
|
||||
user.groups = groups
|
||||
user.invalidate_permission_cache()
|
||||
|
||||
for attr_name, value in (
|
||||
("name", name),
|
||||
|
|
|
@ -3,10 +3,12 @@ from __future__ import annotations
|
|||
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
import uuid
|
||||
|
||||
import attr
|
||||
from attr import Attribute
|
||||
from attr.setters import validate
|
||||
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
@ -14,6 +16,12 @@ from homeassistant.util import dt as dt_util
|
|||
from . import permissions as perm_mdl
|
||||
from .const import GROUP_ID_ADMIN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from functools import cached_property
|
||||
else:
|
||||
from homeassistant.backports.functools import cached_property
|
||||
|
||||
|
||||
TOKEN_TYPE_NORMAL = "normal"
|
||||
TOKEN_TYPE_SYSTEM = "system"
|
||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
||||
|
@ -29,19 +37,27 @@ class Group:
|
|||
system_generated: bool = attr.ib(default=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
def _handle_permissions_change(self: User, user_attr: Attribute, new: Any) -> Any:
|
||||
"""Handle a change to a permissions."""
|
||||
self.invalidate_cache()
|
||||
return validate(self, user_attr, new)
|
||||
|
||||
|
||||
@attr.s(slots=False)
|
||||
class User:
|
||||
"""A user."""
|
||||
|
||||
name: str | None = attr.ib()
|
||||
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
|
||||
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
||||
is_owner: bool = attr.ib(default=False)
|
||||
is_active: bool = attr.ib(default=False)
|
||||
is_owner: bool = attr.ib(default=False, on_setattr=_handle_permissions_change)
|
||||
is_active: bool = attr.ib(default=False, on_setattr=_handle_permissions_change)
|
||||
system_generated: bool = attr.ib(default=False)
|
||||
local_only: bool = attr.ib(default=False)
|
||||
|
||||
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
|
||||
groups: list[Group] = attr.ib(
|
||||
factory=list, eq=False, order=False, on_setattr=_handle_permissions_change
|
||||
)
|
||||
|
||||
# List of credentials of a user.
|
||||
credentials: list[Credentials] = attr.ib(factory=list, eq=False, order=False)
|
||||
|
@ -51,40 +67,31 @@ class User:
|
|||
factory=dict, eq=False, order=False
|
||||
)
|
||||
|
||||
_permissions: perm_mdl.PolicyPermissions | None = attr.ib(
|
||||
init=False,
|
||||
eq=False,
|
||||
order=False,
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def permissions(self) -> perm_mdl.AbstractPermissions:
|
||||
"""Return permissions object for user."""
|
||||
if self.is_owner:
|
||||
return perm_mdl.OwnerPermissions
|
||||
|
||||
if self._permissions is not None:
|
||||
return self._permissions
|
||||
|
||||
self._permissions = perm_mdl.PolicyPermissions(
|
||||
return perm_mdl.PolicyPermissions(
|
||||
perm_mdl.merge_policies([group.policy for group in self.groups]),
|
||||
self.perm_lookup,
|
||||
)
|
||||
|
||||
return self._permissions
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def is_admin(self) -> bool:
|
||||
"""Return if user is part of the admin group."""
|
||||
if self.is_owner:
|
||||
return True
|
||||
return self.is_owner or (
|
||||
self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups)
|
||||
)
|
||||
|
||||
return self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups)
|
||||
|
||||
def invalidate_permission_cache(self) -> None:
|
||||
"""Invalidate permission cache."""
|
||||
self._permissions = None
|
||||
def invalidate_cache(self) -> None:
|
||||
"""Invalidate permission and is_admin cache."""
|
||||
for attr_to_invalidate in ("permissions", "is_admin"):
|
||||
# try is must more efficient than suppress
|
||||
try: # noqa: SIM105
|
||||
delattr(self, attr_to_invalidate)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
|
|
@ -26,3 +26,37 @@ def test_permissions_merged() -> None:
|
|||
assert user.permissions.check_entity("switch.bla", "read") is True
|
||||
assert user.permissions.check_entity("light.kitchen", "read") is True
|
||||
assert user.permissions.check_entity("light.not_kitchen", "read") is False
|
||||
|
||||
|
||||
def test_cache_cleared_on_group_change() -> None:
|
||||
"""Test we clear the cache when a group changes."""
|
||||
group = models.Group(
|
||||
name="Test Group", policy={"entities": {"domains": {"switch": True}}}
|
||||
)
|
||||
admin_group = models.Group(
|
||||
name="Admin group", id=models.GROUP_ID_ADMIN, policy={"entities": {}}
|
||||
)
|
||||
user = models.User(
|
||||
name="Test User", perm_lookup=None, groups=[group], is_active=True
|
||||
)
|
||||
# Make sure we cache instance
|
||||
assert user.permissions is user.permissions
|
||||
|
||||
# Make sure we cache is_admin
|
||||
assert user.is_admin is user.is_admin
|
||||
assert user.is_active is True
|
||||
|
||||
user.groups = []
|
||||
assert user.groups == []
|
||||
assert user.is_admin is False
|
||||
|
||||
user.is_owner = True
|
||||
assert user.is_admin is True
|
||||
user.is_owner = False
|
||||
|
||||
assert user.is_admin is False
|
||||
user.groups = [admin_group]
|
||||
assert user.is_admin is True
|
||||
|
||||
user.is_active = False
|
||||
assert user.is_admin is False
|
||||
|
|
|
@ -669,7 +669,7 @@ class MockUser(auth_models.User):
|
|||
|
||||
def mock_policy(self, policy):
|
||||
"""Mock a policy for a user."""
|
||||
self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
|
||||
self.permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
|
||||
|
||||
|
||||
async def register_auth_provider(
|
||||
|
|
|
@ -684,6 +684,8 @@ async def test_get_entity_state_read_perm(
|
|||
) -> None:
|
||||
"""Test getting a state requires read permission."""
|
||||
hass_admin_user.mock_policy({})
|
||||
hass_admin_user.groups = []
|
||||
assert hass_admin_user.is_admin is False
|
||||
resp = await mock_api_client.get("/api/states/light.test")
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue