Add area permission check (#21835)

This commit is contained in:
Paulus Schoutsen 2019-03-11 11:02:37 -07:00 committed by GitHub
parent 4f49bdf262
commit 4f5446ff02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 230 additions and 146 deletions

View file

@ -281,8 +281,9 @@ class AuthStore:
async def _async_load_task(self) -> None:
"""Load the users."""
[ent_reg, data] = await asyncio.gather(
[ent_reg, dev_reg, data] = await asyncio.gather(
self.hass.helpers.entity_registry.async_get_registry(),
self.hass.helpers.device_registry.async_get_registry(),
self._store.async_load(),
)
@ -291,7 +292,9 @@ class AuthStore:
if self._users is not None:
return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg)
self._perm_lookup = perm_lookup = PermissionLookup(
ent_reg, dev_reg
)
if data is None:
self._set_defaults()

View file

@ -1,12 +1,14 @@
"""Entity permissions."""
from functools import wraps
from typing import Callable, List, Union # noqa: F401
from collections import OrderedDict
from typing import Callable, Optional # noqa: F401
import voluptuous as vol
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
from .models import PermissionLookup
from .types import CategoryType, ValueType
from .types import CategoryType, SubCategoryDict, ValueType
# pylint: disable=unused-import
from .util import SubCatLookupType, lookup_all, compile_policy # noqa
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(POLICY_READ): True,
@ -15,6 +17,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
}))
ENTITY_DOMAINS = 'domains'
ENTITY_AREAS = 'area_ids'
ENTITY_DEVICE_IDS = 'device_ids'
ENTITY_ENTITY_IDS = 'entity_ids'
@ -24,148 +27,65 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
}))
def _entity_allowed(schema: ValueType, key: str) \
-> Union[bool, None]:
"""Test if an entity is allowed based on the keys."""
if schema is None or isinstance(schema, bool):
return schema
assert isinstance(schema, dict)
return schema.get(key)
def _lookup_domain(perm_lookup: PermissionLookup,
domains_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permissions by domain."""
return domains_dict.get(entity_id.split(".", 1)[0])
def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permissions by area."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
device_entry = perm_lookup.device_registry.async_get(
entity_entry.device_id
)
if device_entry is None or device_entry.area_id is None:
return None
return area_dict.get(device_entry.area_id)
def _lookup_device(perm_lookup: PermissionLookup,
devices_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permissions by device."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return devices_dict.get(entity_entry.device_id)
def _lookup_entity_id(perm_lookup: PermissionLookup,
entities_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permission by entity id."""
return entities_dict.get(entity_id)
def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \
-> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy."""
# None, Empty Dict, False
if not policy:
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
subcategories = OrderedDict() # type: SubCatLookupType
subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id
subcategories[ENTITY_DEVICE_IDS] = _lookup_device
subcategories[ENTITY_AREAS] = _lookup_area
subcategories[ENTITY_DOMAINS] = _lookup_domain
subcategories[SUBCAT_ALL] = lookup_all
return apply_policy_deny_all
if policy is True:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True
return apply_policy_allow_all
assert isinstance(policy, dict)
domains = policy.get(ENTITY_DOMAINS)
device_ids = policy.get(ENTITY_DEVICE_IDS)
entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL)
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
# The order of these functions matter. The more precise are at the top.
# If a function returns None, they cannot handle it.
# If a function returns a boolean, that's the result to return.
# Setting entity_ids to a boolean is final decision for permissions
# So return right away.
if isinstance(entity_ids, bool):
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
"""Test if allowed entity_id."""
return entity_ids # type: ignore
return allowed_entity_id_bool
if entity_ids is not None:
def allowed_entity_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed entity_id."""
return _entity_allowed(
entity_ids.get(entity_id), key) # type: ignore
funcs.append(allowed_entity_id_dict)
if isinstance(device_ids, bool):
def allowed_device_id_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
return device_ids
funcs.append(allowed_device_id_bool)
elif device_ids is not None:
def allowed_device_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return _entity_allowed(
device_ids.get(entity_entry.device_id), key # type: ignore
)
funcs.append(allowed_device_id_dict)
if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return domains
funcs.append(allowed_domain_bool)
elif domains is not None:
def allowed_domain_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
domain = entity_id.split(".", 1)[0]
return _entity_allowed(domains.get(domain), key) # type: ignore
funcs.append(allowed_domain_dict)
if isinstance(all_entities, bool):
def allowed_all_entities_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return all_entities
funcs.append(allowed_all_entities_bool)
elif all_entities is not None:
def allowed_all_entities_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return _entity_allowed(all_entities, key)
funcs.append(allowed_all_entities_dict)
# Can happen if no valid subcategories specified
if not funcs:
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
return apply_policy_deny_all_2
if len(funcs) == 1:
func = funcs[0]
@wraps(func)
def apply_policy_func(entity_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(entity_id, key) is True
return apply_policy_func
def apply_policy_funcs(entity_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(entity_id, key)
if result is not None:
return result
return False
return apply_policy_funcs
return compile_policy(policy, subcategories, perm_lookup)

View file

@ -8,6 +8,9 @@ if TYPE_CHECKING:
from homeassistant.helpers import ( # noqa
entity_registry as ent_reg,
)
from homeassistant.helpers import ( # noqa
device_registry as dev_reg,
)
@attr.s(slots=True)
@ -15,3 +18,4 @@ class PermissionLookup:
"""Class to hold data for permission lookups."""
entity_registry = attr.ib(type='ent_reg.EntityRegistry')
device_registry = attr.ib(type='dev_reg.DeviceRegistry')

View file

@ -10,9 +10,11 @@ ValueType = Union[
None
]
# Example: entities.domains = { light: … }
SubCategoryDict = Mapping[str, ValueType]
SubCategoryType = Union[
# Example: entities.domains = { light: … }
Mapping[str, ValueType],
SubCategoryDict,
bool,
None
]

View file

@ -0,0 +1,98 @@
"""Helpers to deal with permissions."""
from functools import wraps
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType
LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str],
Optional[ValueType]]
SubCatLookupType = Dict[str, LookupFunc]
def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict,
object_id: str) -> ValueType:
"""Look up permission for all."""
# In case of ALL category, lookup_dict IS the schema.
return cast(ValueType, lookup_dict)
def compile_policy(
policy: CategoryType, subcategories: SubCatLookupType,
perm_lookup: PermissionLookup
) -> Callable[[str, str], bool]: # noqa
"""Compile policy into a function that tests policy.
Subcategories are mapping key -> lookup function, ordered by highest
priority first.
"""
# None, False, empty dict
if not policy:
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
return apply_policy_deny_all
if policy is True:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True
return apply_policy_allow_all
assert isinstance(policy, dict)
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
for key, lookup_func in subcategories.items():
lookup_value = policy.get(key)
# If any lookup value is `True`, it will always be positive
if isinstance(lookup_value, bool):
return lambda object_id, key: True
if lookup_value is not None:
funcs.append(_gen_dict_test_func(
perm_lookup, lookup_func, lookup_value))
if len(funcs) == 1:
func = funcs[0]
@wraps(func)
def apply_policy_func(object_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(object_id, key) is True
return apply_policy_func
def apply_policy_funcs(object_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(object_id, key)
if result is not None:
return result
return False
return apply_policy_funcs
def _gen_dict_test_func(
perm_lookup: PermissionLookup,
lookup_func: LookupFunc,
lookup_dict: SubCategoryDict
) -> Callable[[str, str], Optional[bool]]: # noqa
"""Generate a lookup function."""
def test_value(object_id: str, key: str) -> Optional[bool]:
"""Test if permission is allowed based on the keys."""
schema = lookup_func(
perm_lookup, lookup_dict, object_id) # type: ValueType
if schema is None or isinstance(schema, bool):
return schema
assert isinstance(schema, dict)
return schema.get(key)
return test_value

View file

@ -1,7 +1,7 @@
"""Provide a way to connect entities belonging to one device."""
import logging
import uuid
from typing import List
from typing import List, Optional
from collections import OrderedDict
@ -71,6 +71,11 @@ class DeviceRegistry:
self.devices = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
"""Get device."""
return self.devices.get(device_id)
@callback
def async_get_device(self, identifiers: set, connections: set):
"""Check if device is registered."""

View file

@ -6,8 +6,9 @@ from homeassistant.auth.permissions.entities import (
compile_entities, ENTITY_POLICY_SCHEMA)
from homeassistant.auth.permissions.models import PermissionLookup
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.helpers.device_registry import DeviceEntry
from tests.common import mock_registry
from tests.common import mock_registry, mock_device_registry
def test_entities_none():
@ -193,7 +194,7 @@ def test_entities_all_control():
def test_entities_device_id_boolean(hass):
"""Test entity ID policy applying control on device id."""
registry = mock_registry(hass, {
entity_registry = mock_registry(hass, {
'test_domain.allowed': RegistryEntry(
entity_id='test_domain.allowed',
unique_id='1234',
@ -207,6 +208,7 @@ def test_entities_device_id_boolean(hass):
device_id='mock-not-allowed-dev-id'
),
})
device_registry = mock_device_registry(hass)
policy = {
'device_ids': {
@ -216,8 +218,55 @@ def test_entities_device_id_boolean(hass):
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(registry))
compiled = compile_entities(policy, PermissionLookup(
entity_registry, device_registry
))
assert compiled('test_domain.allowed', 'read') is True
assert compiled('test_domain.allowed', 'control') is False
assert compiled('test_domain.not_allowed', 'read') is False
assert compiled('test_domain.not_allowed', 'control') is False
def test_entities_areas_true():
"""Test entity ID policy for areas."""
policy = {
'area_ids': True
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
def test_entities_areas_area_true(hass):
"""Test entity ID policy for areas with specific area."""
entity_registry = mock_registry(hass, {
'light.kitchen': RegistryEntry(
entity_id='light.kitchen',
unique_id='1234',
platform='test_platform',
device_id='mock-dev-id'
),
})
device_registry = mock_device_registry(hass, {
'mock-dev-id': DeviceEntry(
id='mock-dev-id',
area_id='mock-area-id'
)
})
policy = {
'area_ids': {
'mock-area-id': {
'read': True,
'control': True,
}
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(
entity_registry, device_registry
))
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', 'edit') is False
assert compiled('switch.kitchen', 'read') is False

View file

@ -245,7 +245,9 @@ async def test_loading_race_condition(hass):
store = auth_store.AuthStore(hass)
with asynctest.patch(
'homeassistant.helpers.entity_registry.async_get_registry',
) as mock_registry, asynctest.patch(
) as mock_ent_registry, asynctest.patch(
'homeassistant.helpers.device_registry.async_get_registry',
) as mock_dev_registry, asynctest.patch(
'homeassistant.helpers.storage.Store.async_load',
) as mock_load:
results = await asyncio.gather(
@ -253,6 +255,7 @@ async def test_loading_race_condition(hass):
store.async_get_users(),
)
mock_registry.assert_called_once_with(hass)
mock_ent_registry.assert_called_once_with(hass)
mock_dev_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with()
assert results[0] == results[1]