Add area permission check (#21835)
This commit is contained in:
parent
4f49bdf262
commit
4f5446ff02
8 changed files with 230 additions and 146 deletions
|
@ -281,8 +281,9 @@ class AuthStore:
|
|||
|
||||
async def _async_load_task(self) -> None:
|
||||
"""Load the users."""
|
||||
[ent_reg, data] = await asyncio.gather(
|
||||
[ent_reg, dev_reg, data] = await asyncio.gather(
|
||||
self.hass.helpers.entity_registry.async_get_registry(),
|
||||
self.hass.helpers.device_registry.async_get_registry(),
|
||||
self._store.async_load(),
|
||||
)
|
||||
|
||||
|
@ -291,7 +292,9 @@ class AuthStore:
|
|||
if self._users is not None:
|
||||
return
|
||||
|
||||
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg)
|
||||
self._perm_lookup = perm_lookup = PermissionLookup(
|
||||
ent_reg, dev_reg
|
||||
)
|
||||
|
||||
if data is None:
|
||||
self._set_defaults()
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"""Entity permissions."""
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Union # noqa: F401
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Optional # noqa: F401
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
|
||||
from .models import PermissionLookup
|
||||
from .types import CategoryType, ValueType
|
||||
from .types import CategoryType, SubCategoryDict, ValueType
|
||||
# pylint: disable=unused-import
|
||||
from .util import SubCatLookupType, lookup_all, compile_policy # noqa
|
||||
|
||||
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
|
||||
vol.Optional(POLICY_READ): True,
|
||||
|
@ -15,6 +17,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
|
|||
}))
|
||||
|
||||
ENTITY_DOMAINS = 'domains'
|
||||
ENTITY_AREAS = 'area_ids'
|
||||
ENTITY_DEVICE_IDS = 'device_ids'
|
||||
ENTITY_ENTITY_IDS = 'entity_ids'
|
||||
|
||||
|
@ -24,148 +27,65 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
|
|||
|
||||
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
|
||||
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
|
||||
vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA,
|
||||
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
|
||||
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
|
||||
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
|
||||
}))
|
||||
|
||||
|
||||
def _entity_allowed(schema: ValueType, key: str) \
|
||||
-> Union[bool, None]:
|
||||
"""Test if an entity is allowed based on the keys."""
|
||||
if schema is None or isinstance(schema, bool):
|
||||
return schema
|
||||
assert isinstance(schema, dict)
|
||||
return schema.get(key)
|
||||
def _lookup_domain(perm_lookup: PermissionLookup,
|
||||
domains_dict: SubCategoryDict,
|
||||
entity_id: str) -> Optional[ValueType]:
|
||||
"""Look up entity permissions by domain."""
|
||||
return domains_dict.get(entity_id.split(".", 1)[0])
|
||||
|
||||
|
||||
def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict,
|
||||
entity_id: str) -> Optional[ValueType]:
|
||||
"""Look up entity permissions by area."""
|
||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||
|
||||
if entity_entry is None or entity_entry.device_id is None:
|
||||
return None
|
||||
|
||||
device_entry = perm_lookup.device_registry.async_get(
|
||||
entity_entry.device_id
|
||||
)
|
||||
|
||||
if device_entry is None or device_entry.area_id is None:
|
||||
return None
|
||||
|
||||
return area_dict.get(device_entry.area_id)
|
||||
|
||||
|
||||
def _lookup_device(perm_lookup: PermissionLookup,
|
||||
devices_dict: SubCategoryDict,
|
||||
entity_id: str) -> Optional[ValueType]:
|
||||
"""Look up entity permissions by device."""
|
||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||
|
||||
if entity_entry is None or entity_entry.device_id is None:
|
||||
return None
|
||||
|
||||
return devices_dict.get(entity_entry.device_id)
|
||||
|
||||
|
||||
def _lookup_entity_id(perm_lookup: PermissionLookup,
|
||||
entities_dict: SubCategoryDict,
|
||||
entity_id: str) -> Optional[ValueType]:
|
||||
"""Look up entity permission by entity id."""
|
||||
return entities_dict.get(entity_id)
|
||||
|
||||
|
||||
def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \
|
||||
-> Callable[[str, str], bool]:
|
||||
"""Compile policy into a function that tests policy."""
|
||||
# None, Empty Dict, False
|
||||
if not policy:
|
||||
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
|
||||
"""Decline all."""
|
||||
return False
|
||||
subcategories = OrderedDict() # type: SubCatLookupType
|
||||
subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id
|
||||
subcategories[ENTITY_DEVICE_IDS] = _lookup_device
|
||||
subcategories[ENTITY_AREAS] = _lookup_area
|
||||
subcategories[ENTITY_DOMAINS] = _lookup_domain
|
||||
subcategories[SUBCAT_ALL] = lookup_all
|
||||
|
||||
return apply_policy_deny_all
|
||||
|
||||
if policy is True:
|
||||
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
|
||||
"""Approve all."""
|
||||
return True
|
||||
|
||||
return apply_policy_allow_all
|
||||
|
||||
assert isinstance(policy, dict)
|
||||
|
||||
domains = policy.get(ENTITY_DOMAINS)
|
||||
device_ids = policy.get(ENTITY_DEVICE_IDS)
|
||||
entity_ids = policy.get(ENTITY_ENTITY_IDS)
|
||||
all_entities = policy.get(SUBCAT_ALL)
|
||||
|
||||
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
|
||||
|
||||
# The order of these functions matter. The more precise are at the top.
|
||||
# If a function returns None, they cannot handle it.
|
||||
# If a function returns a boolean, that's the result to return.
|
||||
|
||||
# Setting entity_ids to a boolean is final decision for permissions
|
||||
# So return right away.
|
||||
if isinstance(entity_ids, bool):
|
||||
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
|
||||
"""Test if allowed entity_id."""
|
||||
return entity_ids # type: ignore
|
||||
|
||||
return allowed_entity_id_bool
|
||||
|
||||
if entity_ids is not None:
|
||||
def allowed_entity_id_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed entity_id."""
|
||||
return _entity_allowed(
|
||||
entity_ids.get(entity_id), key) # type: ignore
|
||||
|
||||
funcs.append(allowed_entity_id_dict)
|
||||
|
||||
if isinstance(device_ids, bool):
|
||||
def allowed_device_id_bool(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed device_id."""
|
||||
return device_ids
|
||||
|
||||
funcs.append(allowed_device_id_bool)
|
||||
|
||||
elif device_ids is not None:
|
||||
def allowed_device_id_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed device_id."""
|
||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||
|
||||
if entity_entry is None or entity_entry.device_id is None:
|
||||
return None
|
||||
|
||||
return _entity_allowed(
|
||||
device_ids.get(entity_entry.device_id), key # type: ignore
|
||||
)
|
||||
|
||||
funcs.append(allowed_device_id_dict)
|
||||
|
||||
if isinstance(domains, bool):
|
||||
def allowed_domain_bool(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return domains
|
||||
|
||||
funcs.append(allowed_domain_bool)
|
||||
|
||||
elif domains is not None:
|
||||
def allowed_domain_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
domain = entity_id.split(".", 1)[0]
|
||||
return _entity_allowed(domains.get(domain), key) # type: ignore
|
||||
|
||||
funcs.append(allowed_domain_dict)
|
||||
|
||||
if isinstance(all_entities, bool):
|
||||
def allowed_all_entities_bool(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return all_entities
|
||||
funcs.append(allowed_all_entities_bool)
|
||||
|
||||
elif all_entities is not None:
|
||||
def allowed_all_entities_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return _entity_allowed(all_entities, key)
|
||||
funcs.append(allowed_all_entities_dict)
|
||||
|
||||
# Can happen if no valid subcategories specified
|
||||
if not funcs:
|
||||
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
|
||||
"""Decline all."""
|
||||
return False
|
||||
|
||||
return apply_policy_deny_all_2
|
||||
|
||||
if len(funcs) == 1:
|
||||
func = funcs[0]
|
||||
|
||||
@wraps(func)
|
||||
def apply_policy_func(entity_id: str, key: str) -> bool:
|
||||
"""Apply a single policy function."""
|
||||
return func(entity_id, key) is True
|
||||
|
||||
return apply_policy_func
|
||||
|
||||
def apply_policy_funcs(entity_id: str, key: str) -> bool:
|
||||
"""Apply several policy functions."""
|
||||
for func in funcs:
|
||||
result = func(entity_id, key)
|
||||
if result is not None:
|
||||
return result
|
||||
return False
|
||||
|
||||
return apply_policy_funcs
|
||||
return compile_policy(policy, subcategories, perm_lookup)
|
||||
|
|
|
@ -8,6 +8,9 @@ if TYPE_CHECKING:
|
|||
from homeassistant.helpers import ( # noqa
|
||||
entity_registry as ent_reg,
|
||||
)
|
||||
from homeassistant.helpers import ( # noqa
|
||||
device_registry as dev_reg,
|
||||
)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -15,3 +18,4 @@ class PermissionLookup:
|
|||
"""Class to hold data for permission lookups."""
|
||||
|
||||
entity_registry = attr.ib(type='ent_reg.EntityRegistry')
|
||||
device_registry = attr.ib(type='dev_reg.DeviceRegistry')
|
||||
|
|
|
@ -10,9 +10,11 @@ ValueType = Union[
|
|||
None
|
||||
]
|
||||
|
||||
# Example: entities.domains = { light: … }
|
||||
SubCategoryDict = Mapping[str, ValueType]
|
||||
|
||||
SubCategoryType = Union[
|
||||
# Example: entities.domains = { light: … }
|
||||
Mapping[str, ValueType],
|
||||
SubCategoryDict,
|
||||
bool,
|
||||
None
|
||||
]
|
||||
|
|
98
homeassistant/auth/permissions/util.py
Normal file
98
homeassistant/auth/permissions/util.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
"""Helpers to deal with permissions."""
|
||||
from functools import wraps
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
|
||||
|
||||
from .models import PermissionLookup
|
||||
from .types import CategoryType, SubCategoryDict, ValueType
|
||||
|
||||
LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str],
|
||||
Optional[ValueType]]
|
||||
SubCatLookupType = Dict[str, LookupFunc]
|
||||
|
||||
|
||||
def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict,
|
||||
object_id: str) -> ValueType:
|
||||
"""Look up permission for all."""
|
||||
# In case of ALL category, lookup_dict IS the schema.
|
||||
return cast(ValueType, lookup_dict)
|
||||
|
||||
|
||||
def compile_policy(
|
||||
policy: CategoryType, subcategories: SubCatLookupType,
|
||||
perm_lookup: PermissionLookup
|
||||
) -> Callable[[str, str], bool]: # noqa
|
||||
"""Compile policy into a function that tests policy.
|
||||
Subcategories are mapping key -> lookup function, ordered by highest
|
||||
priority first.
|
||||
"""
|
||||
# None, False, empty dict
|
||||
if not policy:
|
||||
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
|
||||
"""Decline all."""
|
||||
return False
|
||||
|
||||
return apply_policy_deny_all
|
||||
|
||||
if policy is True:
|
||||
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
|
||||
"""Approve all."""
|
||||
return True
|
||||
|
||||
return apply_policy_allow_all
|
||||
|
||||
assert isinstance(policy, dict)
|
||||
|
||||
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
|
||||
|
||||
for key, lookup_func in subcategories.items():
|
||||
lookup_value = policy.get(key)
|
||||
|
||||
# If any lookup value is `True`, it will always be positive
|
||||
if isinstance(lookup_value, bool):
|
||||
return lambda object_id, key: True
|
||||
|
||||
if lookup_value is not None:
|
||||
funcs.append(_gen_dict_test_func(
|
||||
perm_lookup, lookup_func, lookup_value))
|
||||
|
||||
if len(funcs) == 1:
|
||||
func = funcs[0]
|
||||
|
||||
@wraps(func)
|
||||
def apply_policy_func(object_id: str, key: str) -> bool:
|
||||
"""Apply a single policy function."""
|
||||
return func(object_id, key) is True
|
||||
|
||||
return apply_policy_func
|
||||
|
||||
def apply_policy_funcs(object_id: str, key: str) -> bool:
|
||||
"""Apply several policy functions."""
|
||||
for func in funcs:
|
||||
result = func(object_id, key)
|
||||
if result is not None:
|
||||
return result
|
||||
return False
|
||||
|
||||
return apply_policy_funcs
|
||||
|
||||
|
||||
def _gen_dict_test_func(
|
||||
perm_lookup: PermissionLookup,
|
||||
lookup_func: LookupFunc,
|
||||
lookup_dict: SubCategoryDict
|
||||
) -> Callable[[str, str], Optional[bool]]: # noqa
|
||||
"""Generate a lookup function."""
|
||||
def test_value(object_id: str, key: str) -> Optional[bool]:
|
||||
"""Test if permission is allowed based on the keys."""
|
||||
schema = lookup_func(
|
||||
perm_lookup, lookup_dict, object_id) # type: ValueType
|
||||
|
||||
if schema is None or isinstance(schema, bool):
|
||||
return schema
|
||||
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
return schema.get(key)
|
||||
|
||||
return test_value
|
|
@ -1,7 +1,7 @@
|
|||
"""Provide a way to connect entities belonging to one device."""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
@ -71,6 +71,11 @@ class DeviceRegistry:
|
|||
self.devices = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
|
||||
"""Get device."""
|
||||
return self.devices.get(device_id)
|
||||
|
||||
@callback
|
||||
def async_get_device(self, identifiers: set, connections: set):
|
||||
"""Check if device is registered."""
|
||||
|
|
|
@ -6,8 +6,9 @@ from homeassistant.auth.permissions.entities import (
|
|||
compile_entities, ENTITY_POLICY_SCHEMA)
|
||||
from homeassistant.auth.permissions.models import PermissionLookup
|
||||
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||
from homeassistant.helpers.device_registry import DeviceEntry
|
||||
|
||||
from tests.common import mock_registry
|
||||
from tests.common import mock_registry, mock_device_registry
|
||||
|
||||
|
||||
def test_entities_none():
|
||||
|
@ -193,7 +194,7 @@ def test_entities_all_control():
|
|||
|
||||
def test_entities_device_id_boolean(hass):
|
||||
"""Test entity ID policy applying control on device id."""
|
||||
registry = mock_registry(hass, {
|
||||
entity_registry = mock_registry(hass, {
|
||||
'test_domain.allowed': RegistryEntry(
|
||||
entity_id='test_domain.allowed',
|
||||
unique_id='1234',
|
||||
|
@ -207,6 +208,7 @@ def test_entities_device_id_boolean(hass):
|
|||
device_id='mock-not-allowed-dev-id'
|
||||
),
|
||||
})
|
||||
device_registry = mock_device_registry(hass)
|
||||
|
||||
policy = {
|
||||
'device_ids': {
|
||||
|
@ -216,8 +218,55 @@ def test_entities_device_id_boolean(hass):
|
|||
}
|
||||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy, PermissionLookup(registry))
|
||||
compiled = compile_entities(policy, PermissionLookup(
|
||||
entity_registry, device_registry
|
||||
))
|
||||
assert compiled('test_domain.allowed', 'read') is True
|
||||
assert compiled('test_domain.allowed', 'control') is False
|
||||
assert compiled('test_domain.not_allowed', 'read') is False
|
||||
assert compiled('test_domain.not_allowed', 'control') is False
|
||||
|
||||
|
||||
def test_entities_areas_true():
|
||||
"""Test entity ID policy for areas."""
|
||||
policy = {
|
||||
'area_ids': True
|
||||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy, None)
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_areas_area_true(hass):
|
||||
"""Test entity ID policy for areas with specific area."""
|
||||
entity_registry = mock_registry(hass, {
|
||||
'light.kitchen': RegistryEntry(
|
||||
entity_id='light.kitchen',
|
||||
unique_id='1234',
|
||||
platform='test_platform',
|
||||
device_id='mock-dev-id'
|
||||
),
|
||||
})
|
||||
device_registry = mock_device_registry(hass, {
|
||||
'mock-dev-id': DeviceEntry(
|
||||
id='mock-dev-id',
|
||||
area_id='mock-area-id'
|
||||
)
|
||||
})
|
||||
|
||||
policy = {
|
||||
'area_ids': {
|
||||
'mock-area-id': {
|
||||
'read': True,
|
||||
'control': True,
|
||||
}
|
||||
}
|
||||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy, PermissionLookup(
|
||||
entity_registry, device_registry
|
||||
))
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('light.kitchen', 'control') is True
|
||||
assert compiled('light.kitchen', 'edit') is False
|
||||
assert compiled('switch.kitchen', 'read') is False
|
||||
|
|
|
@ -245,7 +245,9 @@ async def test_loading_race_condition(hass):
|
|||
store = auth_store.AuthStore(hass)
|
||||
with asynctest.patch(
|
||||
'homeassistant.helpers.entity_registry.async_get_registry',
|
||||
) as mock_registry, asynctest.patch(
|
||||
) as mock_ent_registry, asynctest.patch(
|
||||
'homeassistant.helpers.device_registry.async_get_registry',
|
||||
) as mock_dev_registry, asynctest.patch(
|
||||
'homeassistant.helpers.storage.Store.async_load',
|
||||
) as mock_load:
|
||||
results = await asyncio.gather(
|
||||
|
@ -253,6 +255,7 @@ async def test_loading_race_condition(hass):
|
|||
store.async_get_users(),
|
||||
)
|
||||
|
||||
mock_registry.assert_called_once_with(hass)
|
||||
mock_ent_registry.assert_called_once_with(hass)
|
||||
mock_dev_registry.assert_called_once_with(hass)
|
||||
mock_load.assert_called_once_with()
|
||||
assert results[0] == results[1]
|
||||
|
|
Loading…
Add table
Reference in a new issue