Refactor User attribute caching to be safer and more efficient (#96723)

* Cache construction of is_admin

This has to be checked for a lot of api calls and the websocket
every time the call is made

* Cache construction of is_admin

This has to be checked for a lot of api calls and the websocket
every time the call is made

* Cache construction of is_admin

This has to be checked for a lot of api calls and the websocket
every time the call is made

* modernize

* coverage

* coverage

* verify caching

* verify caching

* fix type

* fix mocking
This commit is contained in:
J. Nick Koston 2024-01-13 10:10:50 -10:00 committed by GitHub
parent d7910841ef
commit b1d0c6a4f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 30 deletions

View file

@ -171,7 +171,6 @@ class AuthStore:
groups.append(group)
user.groups = groups
user.invalidate_permission_cache()
for attr_name, value in (
("name", name),

View file

@ -3,10 +3,12 @@ from __future__ import annotations
from datetime import datetime, timedelta
import secrets
from typing import NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple
import uuid
import attr
from attr import Attribute
from attr.setters import validate
from homeassistant.const import __version__
from homeassistant.util import dt as dt_util
@ -14,6 +16,12 @@ from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN
if TYPE_CHECKING:
from functools import cached_property
else:
from homeassistant.backports.functools import cached_property
TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
@ -29,19 +37,27 @@ class Group:
system_generated: bool = attr.ib(default=False)
@attr.s(slots=True)
def _handle_permissions_change(self: User, user_attr: Attribute, new: Any) -> Any:
"""Handle a change to a permissions."""
self.invalidate_cache()
return validate(self, user_attr, new)
@attr.s(slots=False)
class User:
"""A user."""
name: str | None = attr.ib()
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
is_owner: bool = attr.ib(default=False)
is_active: bool = attr.ib(default=False)
is_owner: bool = attr.ib(default=False, on_setattr=_handle_permissions_change)
is_active: bool = attr.ib(default=False, on_setattr=_handle_permissions_change)
system_generated: bool = attr.ib(default=False)
local_only: bool = attr.ib(default=False)
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
groups: list[Group] = attr.ib(
factory=list, eq=False, order=False, on_setattr=_handle_permissions_change
)
# List of credentials of a user.
credentials: list[Credentials] = attr.ib(factory=list, eq=False, order=False)
@ -51,40 +67,31 @@ class User:
factory=dict, eq=False, order=False
)
_permissions: perm_mdl.PolicyPermissions | None = attr.ib(
init=False,
eq=False,
order=False,
default=None,
)
@property
@cached_property
def permissions(self) -> perm_mdl.AbstractPermissions:
"""Return permissions object for user."""
if self.is_owner:
return perm_mdl.OwnerPermissions
if self._permissions is not None:
return self._permissions
self._permissions = perm_mdl.PolicyPermissions(
return perm_mdl.PolicyPermissions(
perm_mdl.merge_policies([group.policy for group in self.groups]),
self.perm_lookup,
)
return self._permissions
@property
@cached_property
def is_admin(self) -> bool:
"""Return if user is part of the admin group."""
if self.is_owner:
return True
return self.is_owner or (
self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups)
)
return self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups)
def invalidate_permission_cache(self) -> None:
"""Invalidate permission cache."""
self._permissions = None
def invalidate_cache(self) -> None:
"""Invalidate permission and is_admin cache."""
for attr_to_invalidate in ("permissions", "is_admin"):
# try is must more efficient than suppress
try: # noqa: SIM105
delattr(self, attr_to_invalidate)
except AttributeError:
pass
@attr.s(slots=True)

View file

@ -26,3 +26,37 @@ def test_permissions_merged() -> None:
assert user.permissions.check_entity("switch.bla", "read") is True
assert user.permissions.check_entity("light.kitchen", "read") is True
assert user.permissions.check_entity("light.not_kitchen", "read") is False
def test_cache_cleared_on_group_change() -> None:
"""Test we clear the cache when a group changes."""
group = models.Group(
name="Test Group", policy={"entities": {"domains": {"switch": True}}}
)
admin_group = models.Group(
name="Admin group", id=models.GROUP_ID_ADMIN, policy={"entities": {}}
)
user = models.User(
name="Test User", perm_lookup=None, groups=[group], is_active=True
)
# Make sure we cache instance
assert user.permissions is user.permissions
# Make sure we cache is_admin
assert user.is_admin is user.is_admin
assert user.is_active is True
user.groups = []
assert user.groups == []
assert user.is_admin is False
user.is_owner = True
assert user.is_admin is True
user.is_owner = False
assert user.is_admin is False
user.groups = [admin_group]
assert user.is_admin is True
user.is_active = False
assert user.is_admin is False

View file

@ -669,7 +669,7 @@ class MockUser(auth_models.User):
def mock_policy(self, policy):
"""Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
self.permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
async def register_auth_provider(

View file

@ -684,6 +684,8 @@ async def test_get_entity_state_read_perm(
) -> None:
"""Test getting a state requires read permission."""
hass_admin_user.mock_policy({})
hass_admin_user.groups = []
assert hass_admin_user.is_admin is False
resp = await mock_api_client.get("/api/states/light.test")
assert resp.status == HTTPStatus.UNAUTHORIZED