Allow checking entity permissions based on devices (#19007)

* Allow checking entity permissions based on devices

* Fix tests
This commit is contained in:
Paulus Schoutsen 2018-12-05 11:41:00 +01:00 committed by Pascal Vizeli
parent 2680bf8a61
commit 3928d034a3
11 changed files with 143 additions and 27 deletions

View file

@ -1,4 +1,5 @@
"""Storage for auth models."""
import asyncio
from collections import OrderedDict
from datetime import timedelta
import hmac
@ -11,7 +12,7 @@ from homeassistant.util import dt as dt_util
from . import models
from .const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from .permissions import system_policies
from .permissions import PermissionLookup, system_policies
from .permissions.types import PolicyType # noqa: F401
STORAGE_VERSION = 1
@ -34,6 +35,7 @@ class AuthStore:
self.hass = hass
self._users = None # type: Optional[Dict[str, models.User]]
self._groups = None # type: Optional[Dict[str, models.Group]]
self._perm_lookup = None # type: Optional[PermissionLookup]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY,
private=True)
@ -94,6 +96,7 @@ class AuthStore:
# Until we get group management, we just put everyone in the
# same group.
'groups': groups,
'perm_lookup': self._perm_lookup,
} # type: Dict[str, Any]
if is_owner is not None:
@ -269,13 +272,18 @@ class AuthStore:
async def _async_load(self) -> None:
"""Load the users."""
data = await self._store.async_load()
[ent_reg, data] = await asyncio.gather(
self.hass.helpers.entity_registry.async_get_registry(),
self._store.async_load(),
)
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self._users is not None:
return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg)
if data is None:
self._set_defaults()
return
@ -374,6 +382,7 @@ class AuthStore:
is_owner=user_dict['is_owner'],
is_active=user_dict['is_active'],
system_generated=user_dict['system_generated'],
perm_lookup=perm_lookup,
)
for cred_dict in data['credentials']:

View file

@ -31,6 +31,9 @@ class User:
"""A user."""
name = attr.ib(type=str) # type: Optional[str]
perm_lookup = attr.ib(
type=perm_mdl.PermissionLookup, cmp=False,
) # type: perm_mdl.PermissionLookup
id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
@ -66,7 +69,8 @@ class User:
self._permissions = perm_mdl.PolicyPermissions(
perm_mdl.merge_policies([
group.policy for group in self.groups]))
group.policy for group in self.groups]),
self.perm_lookup)
return self._permissions

View file

@ -1,15 +1,18 @@
"""Permissions for Home Assistant."""
import logging
from typing import ( # noqa: F401
cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union)
cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union,
TYPE_CHECKING)
import voluptuous as vol
from .const import CAT_ENTITIES
from .models import PermissionLookup
from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa
POLICY_SCHEMA = vol.Schema({
vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA
})
@ -39,13 +42,16 @@ class AbstractPermissions:
class PolicyPermissions(AbstractPermissions):
"""Handle permissions."""
def __init__(self, policy: PolicyType) -> None:
def __init__(self, policy: PolicyType,
perm_lookup: PermissionLookup) -> None:
"""Initialize the permission class."""
self._policy = policy
self._perm_lookup = perm_lookup
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES))
return compile_entities(self._policy.get(CAT_ENTITIES),
self._perm_lookup)
def __eq__(self, other: Any) -> bool:
"""Equals check."""

View file

@ -5,6 +5,7 @@ from typing import Callable, List, Union # noqa: F401
import voluptuous as vol
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
from .models import PermissionLookup
from .types import CategoryType, ValueType
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
@ -14,6 +15,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
}))
ENTITY_DOMAINS = 'domains'
ENTITY_DEVICE_IDS = 'device_ids'
ENTITY_ENTITY_IDS = 'entity_ids'
ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
@ -22,6 +24,7 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
}))
@ -36,7 +39,7 @@ def _entity_allowed(schema: ValueType, key: str) \
return schema.get(key)
def compile_entities(policy: CategoryType) \
def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \
-> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy."""
# None, Empty Dict, False
@ -57,6 +60,7 @@ def compile_entities(policy: CategoryType) \
assert isinstance(policy, dict)
domains = policy.get(ENTITY_DOMAINS)
device_ids = policy.get(ENTITY_DEVICE_IDS)
entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL)
@ -84,6 +88,29 @@ def compile_entities(policy: CategoryType) \
funcs.append(allowed_entity_id_dict)
if isinstance(device_ids, bool):
def allowed_device_id_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
return device_ids
funcs.append(allowed_device_id_bool)
elif device_ids is not None:
def allowed_device_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return _entity_allowed(
device_ids.get(entity_entry.device_id), key # type: ignore
)
funcs.append(allowed_device_id_dict)
if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]:

View file

@ -0,0 +1,17 @@
"""Models for permissions."""
from typing import TYPE_CHECKING
import attr
if TYPE_CHECKING:
# pylint: disable=unused-import
from homeassistant.helpers import ( # noqa
entity_registry as ent_reg,
)
@attr.s(slots=True)
class PermissionLookup:
"""Class to hold data for permission lookups."""
entity_registry = attr.ib(type='ent_reg.EntityRegistry')

View file

@ -10,6 +10,7 @@ timer.
from collections import OrderedDict
from itertools import chain
import logging
from typing import Optional
import weakref
import attr
@ -85,6 +86,11 @@ class EntityRegistry:
"""Check if an entity_id is currently registered."""
return entity_id in self.entities
@callback
def async_get(self, entity_id: str) -> Optional[RegistryEntry]:
"""Get EntityEntry for an entity_id."""
return self.entities.get(entity_id)
@callback
def async_get_entity_id(self, domain: str, platform: str, unique_id: str):
"""Check if an entity_id is currently registered."""

View file

@ -4,12 +4,16 @@ import voluptuous as vol
from homeassistant.auth.permissions.entities import (
compile_entities, ENTITY_POLICY_SCHEMA)
from homeassistant.auth.permissions.models import PermissionLookup
from homeassistant.helpers.entity_registry import RegistryEntry
from tests.common import mock_registry
def test_entities_none():
"""Test entity ID policy."""
policy = None
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False
@ -17,7 +21,7 @@ def test_entities_empty():
"""Test entity ID policy."""
policy = {}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False
@ -32,7 +36,7 @@ def test_entities_true():
"""Test entity ID policy."""
policy = True
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
@ -42,7 +46,7 @@ def test_entities_domains_true():
'domains': True
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
@ -54,7 +58,7 @@ def test_entities_domains_domain_true():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False
@ -76,7 +80,7 @@ def test_entities_entity_ids_true():
'entity_ids': True
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
@ -97,7 +101,7 @@ def test_entities_entity_ids_entity_id_true():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False
@ -123,7 +127,7 @@ def test_entities_control_only():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False
assert compiled('light.kitchen', 'edit') is False
@ -140,7 +144,7 @@ def test_entities_read_control():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', 'edit') is False
@ -152,7 +156,7 @@ def test_entities_all_allow():
'all': True
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is True
@ -166,7 +170,7 @@ def test_entities_all_read():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False
assert compiled('switch.kitchen', 'read') is True
@ -180,8 +184,40 @@ def test_entities_all_control():
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False
assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is False
assert compiled('switch.kitchen', 'control') is True
def test_entities_device_id_boolean(hass):
"""Test entity ID policy applying control on device id."""
registry = mock_registry(hass, {
'test_domain.allowed': RegistryEntry(
entity_id='test_domain.allowed',
unique_id='1234',
platform='test_platform',
device_id='mock-allowed-dev-id'
),
'test_domain.not_allowed': RegistryEntry(
entity_id='test_domain.not_allowed',
unique_id='5678',
platform='test_platform',
device_id='mock-not-allowed-dev-id'
),
})
policy = {
'device_ids': {
'mock-allowed-dev-id': {
'read': True,
}
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(registry))
assert compiled('test_domain.allowed', 'read') is True
assert compiled('test_domain.allowed', 'control') is False
assert compiled('test_domain.not_allowed', 'read') is False
assert compiled('test_domain.not_allowed', 'control') is False

View file

@ -8,7 +8,7 @@ def test_admin_policy():
# Make sure it's valid
POLICY_SCHEMA(system_policies.ADMIN_POLICY)
perms = PolicyPermissions(system_policies.ADMIN_POLICY)
perms = PolicyPermissions(system_policies.ADMIN_POLICY, None)
assert perms.check_entity('light.kitchen', 'read')
assert perms.check_entity('light.kitchen', 'control')
assert perms.check_entity('light.kitchen', 'edit')
@ -19,7 +19,7 @@ def test_read_only_policy():
# Make sure it's valid
POLICY_SCHEMA(system_policies.READ_ONLY_POLICY)
perms = PolicyPermissions(system_policies.READ_ONLY_POLICY)
perms = PolicyPermissions(system_policies.READ_ONLY_POLICY, None)
assert perms.check_entity('light.kitchen', 'read')
assert not perms.check_entity('light.kitchen', 'control')
assert not perms.check_entity('light.kitchen', 'edit')

View file

@ -5,7 +5,12 @@ from homeassistant.auth import models, permissions
def test_owner_fetching_owner_permissions():
"""Test we fetch the owner permissions for an owner user."""
group = models.Group(name="Test Group", policy={})
owner = models.User(name="Test User", groups=[group], is_owner=True)
owner = models.User(
name="Test User",
perm_lookup=None,
groups=[group],
is_owner=True
)
assert owner.permissions is permissions.OwnerPermissions
@ -25,7 +30,11 @@ def test_permissions_merged():
}
}
})
user = models.User(name="Test User", groups=[group, group2])
user = models.User(
name="Test User",
perm_lookup=None,
groups=[group, group2]
)
# Make sure we cache instance
assert user.permissions is user.permissions

View file

@ -384,6 +384,7 @@ class MockUser(auth_models.User):
'name': name,
'system_generated': system_generated,
'groups': groups or [],
'perm_lookup': None,
}
if id is not None:
kwargs['id'] = id
@ -401,7 +402,8 @@ class MockUser(auth_models.User):
def mock_policy(self, policy):
"""Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy)
self._permissions = auth_permissions.PolicyPermissions(
policy, self.perm_lookup)
async def register_auth_provider(hass, config):

View file

@ -232,7 +232,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call,
'light.kitchen': True
}
}
})))):
}, None)))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service',
@ -253,7 +253,7 @@ async def test_call_context_target_specific(hass, mock_service_platform_call,
'light.kitchen': True
}
}
})))):
}, None)))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
@ -271,7 +271,7 @@ async def test_call_context_target_specific_no_auth(
with pytest.raises(exceptions.Unauthorized) as err:
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(
permissions=PolicyPermissions({})))):
permissions=PolicyPermissions({}, None)))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {