diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 08ff2d7bb52..a64c14454a6 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -281,8 +281,9 @@ class AuthStore: async def _async_load_task(self) -> None: """Load the users.""" - [ent_reg, data] = await asyncio.gather( + [ent_reg, dev_reg, data] = await asyncio.gather( self.hass.helpers.entity_registry.async_get_registry(), + self.hass.helpers.device_registry.async_get_registry(), self._store.async_load(), ) @@ -291,7 +292,9 @@ class AuthStore: if self._users is not None: return - self._perm_lookup = perm_lookup = PermissionLookup(ent_reg) + self._perm_lookup = perm_lookup = PermissionLookup( + ent_reg, dev_reg + ) if data is None: self._set_defaults() diff --git a/homeassistant/auth/permissions/entities.py b/homeassistant/auth/permissions/entities.py index 0073c952648..3d7fc80307e 100644 --- a/homeassistant/auth/permissions/entities.py +++ b/homeassistant/auth/permissions/entities.py @@ -1,12 +1,14 @@ """Entity permissions.""" -from functools import wraps -from typing import Callable, List, Union # noqa: F401 +from collections import OrderedDict +from typing import Callable, Optional # noqa: F401 import voluptuous as vol from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT from .models import PermissionLookup -from .types import CategoryType, ValueType +from .types import CategoryType, SubCategoryDict, ValueType +# pylint: disable=unused-import +from .util import SubCatLookupType, lookup_all, compile_policy # noqa SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ vol.Optional(POLICY_READ): True, @@ -15,6 +17,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ })) ENTITY_DOMAINS = 'domains' +ENTITY_AREAS = 'area_ids' ENTITY_DEVICE_IDS = 'device_ids' ENTITY_ENTITY_IDS = 'entity_ids' @@ -24,148 +27,65 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA, + vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA, })) -def _entity_allowed(schema: ValueType, key: str) \ - -> Union[bool, None]: - """Test if an entity is allowed based on the keys.""" - if schema is None or isinstance(schema, bool): - return schema - assert isinstance(schema, dict) - return schema.get(key) +def _lookup_domain(perm_lookup: PermissionLookup, + domains_dict: SubCategoryDict, + entity_id: str) -> Optional[ValueType]: + """Look up entity permissions by domain.""" + return domains_dict.get(entity_id.split(".", 1)[0]) + + +def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict, + entity_id: str) -> Optional[ValueType]: + """Look up entity permissions by area.""" + entity_entry = perm_lookup.entity_registry.async_get(entity_id) + + if entity_entry is None or entity_entry.device_id is None: + return None + + device_entry = perm_lookup.device_registry.async_get( + entity_entry.device_id + ) + + if device_entry is None or device_entry.area_id is None: + return None + + return area_dict.get(device_entry.area_id) + + +def _lookup_device(perm_lookup: PermissionLookup, + devices_dict: SubCategoryDict, + entity_id: str) -> Optional[ValueType]: + """Look up entity permissions by device.""" + entity_entry = perm_lookup.entity_registry.async_get(entity_id) + + if entity_entry is None or entity_entry.device_id is None: + return None + + return devices_dict.get(entity_entry.device_id) + + +def _lookup_entity_id(perm_lookup: PermissionLookup, + entities_dict: SubCategoryDict, + entity_id: str) -> Optional[ValueType]: + """Look up entity permission by entity id.""" + return entities_dict.get(entity_id) def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \ -> Callable[[str, str], bool]: """Compile policy into a function that tests policy.""" - # None, Empty Dict, False - if not policy: - def apply_policy_deny_all(entity_id: str, key: str) -> bool: - """Decline all.""" - return False + subcategories = OrderedDict() # type: SubCatLookupType + subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id + subcategories[ENTITY_DEVICE_IDS] = _lookup_device + subcategories[ENTITY_AREAS] = _lookup_area + subcategories[ENTITY_DOMAINS] = _lookup_domain + subcategories[SUBCAT_ALL] = lookup_all - return apply_policy_deny_all - - if policy is True: - def apply_policy_allow_all(entity_id: str, key: str) -> bool: - """Approve all.""" - return True - - return apply_policy_allow_all - - assert isinstance(policy, dict) - - domains = policy.get(ENTITY_DOMAINS) - device_ids = policy.get(ENTITY_DEVICE_IDS) - entity_ids = policy.get(ENTITY_ENTITY_IDS) - all_entities = policy.get(SUBCAT_ALL) - - funcs = [] # type: List[Callable[[str, str], Union[None, bool]]] - - # The order of these functions matter. The more precise are at the top. - # If a function returns None, they cannot handle it. - # If a function returns a boolean, that's the result to return. - - # Setting entity_ids to a boolean is final decision for permissions - # So return right away. - if isinstance(entity_ids, bool): - def allowed_entity_id_bool(entity_id: str, key: str) -> bool: - """Test if allowed entity_id.""" - return entity_ids # type: ignore - - return allowed_entity_id_bool - - if entity_ids is not None: - def allowed_entity_id_dict(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed entity_id.""" - return _entity_allowed( - entity_ids.get(entity_id), key) # type: ignore - - funcs.append(allowed_entity_id_dict) - - if isinstance(device_ids, bool): - def allowed_device_id_bool(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed device_id.""" - return device_ids - - funcs.append(allowed_device_id_bool) - - elif device_ids is not None: - def allowed_device_id_dict(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed device_id.""" - entity_entry = perm_lookup.entity_registry.async_get(entity_id) - - if entity_entry is None or entity_entry.device_id is None: - return None - - return _entity_allowed( - device_ids.get(entity_entry.device_id), key # type: ignore - ) - - funcs.append(allowed_device_id_dict) - - if isinstance(domains, bool): - def allowed_domain_bool(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed domain.""" - return domains - - funcs.append(allowed_domain_bool) - - elif domains is not None: - def allowed_domain_dict(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed domain.""" - domain = entity_id.split(".", 1)[0] - return _entity_allowed(domains.get(domain), key) # type: ignore - - funcs.append(allowed_domain_dict) - - if isinstance(all_entities, bool): - def allowed_all_entities_bool(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed domain.""" - return all_entities - funcs.append(allowed_all_entities_bool) - - elif all_entities is not None: - def allowed_all_entities_dict(entity_id: str, key: str) \ - -> Union[None, bool]: - """Test if allowed domain.""" - return _entity_allowed(all_entities, key) - funcs.append(allowed_all_entities_dict) - - # Can happen if no valid subcategories specified - if not funcs: - def apply_policy_deny_all_2(entity_id: str, key: str) -> bool: - """Decline all.""" - return False - - return apply_policy_deny_all_2 - - if len(funcs) == 1: - func = funcs[0] - - @wraps(func) - def apply_policy_func(entity_id: str, key: str) -> bool: - """Apply a single policy function.""" - return func(entity_id, key) is True - - return apply_policy_func - - def apply_policy_funcs(entity_id: str, key: str) -> bool: - """Apply several policy functions.""" - for func in funcs: - result = func(entity_id, key) - if result is not None: - return result - return False - - return apply_policy_funcs + return compile_policy(policy, subcategories, perm_lookup) diff --git a/homeassistant/auth/permissions/models.py b/homeassistant/auth/permissions/models.py index 7ad7d5521c5..10a76a4ec73 100644 --- a/homeassistant/auth/permissions/models.py +++ b/homeassistant/auth/permissions/models.py @@ -8,6 +8,9 @@ if TYPE_CHECKING: from homeassistant.helpers import ( # noqa entity_registry as ent_reg, ) + from homeassistant.helpers import ( # noqa + device_registry as dev_reg, + ) @attr.s(slots=True) @@ -15,3 +18,4 @@ class PermissionLookup: """Class to hold data for permission lookups.""" entity_registry = attr.ib(type='ent_reg.EntityRegistry') + device_registry = attr.ib(type='dev_reg.DeviceRegistry') diff --git a/homeassistant/auth/permissions/types.py b/homeassistant/auth/permissions/types.py index 78d13b9679f..5479e59dcb6 100644 --- a/homeassistant/auth/permissions/types.py +++ b/homeassistant/auth/permissions/types.py @@ -10,9 +10,11 @@ ValueType = Union[ None ] +# Example: entities.domains = { light: … } +SubCategoryDict = Mapping[str, ValueType] + SubCategoryType = Union[ - # Example: entities.domains = { light: … } - Mapping[str, ValueType], + SubCategoryDict, bool, None ] diff --git a/homeassistant/auth/permissions/util.py b/homeassistant/auth/permissions/util.py new file mode 100644 index 00000000000..d2d259fb32e --- /dev/null +++ b/homeassistant/auth/permissions/util.py @@ -0,0 +1,98 @@ +"""Helpers to deal with permissions.""" +from functools import wraps + +from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401 + +from .models import PermissionLookup +from .types import CategoryType, SubCategoryDict, ValueType + +LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str], + Optional[ValueType]] +SubCatLookupType = Dict[str, LookupFunc] + + +def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict, + object_id: str) -> ValueType: + """Look up permission for all.""" + # In case of ALL category, lookup_dict IS the schema. + return cast(ValueType, lookup_dict) + + +def compile_policy( + policy: CategoryType, subcategories: SubCatLookupType, + perm_lookup: PermissionLookup + ) -> Callable[[str, str], bool]: # noqa + """Compile policy into a function that tests policy. + Subcategories are mapping key -> lookup function, ordered by highest + priority first. + """ + # None, False, empty dict + if not policy: + def apply_policy_deny_all(entity_id: str, key: str) -> bool: + """Decline all.""" + return False + + return apply_policy_deny_all + + if policy is True: + def apply_policy_allow_all(entity_id: str, key: str) -> bool: + """Approve all.""" + return True + + return apply_policy_allow_all + + assert isinstance(policy, dict) + + funcs = [] # type: List[Callable[[str, str], Union[None, bool]]] + + for key, lookup_func in subcategories.items(): + lookup_value = policy.get(key) + + # If any lookup value is `True`, it will always be positive + if isinstance(lookup_value, bool): + return lambda object_id, key: True + + if lookup_value is not None: + funcs.append(_gen_dict_test_func( + perm_lookup, lookup_func, lookup_value)) + + if len(funcs) == 1: + func = funcs[0] + + @wraps(func) + def apply_policy_func(object_id: str, key: str) -> bool: + """Apply a single policy function.""" + return func(object_id, key) is True + + return apply_policy_func + + def apply_policy_funcs(object_id: str, key: str) -> bool: + """Apply several policy functions.""" + for func in funcs: + result = func(object_id, key) + if result is not None: + return result + return False + + return apply_policy_funcs + + +def _gen_dict_test_func( + perm_lookup: PermissionLookup, + lookup_func: LookupFunc, + lookup_dict: SubCategoryDict + ) -> Callable[[str, str], Optional[bool]]: # noqa + """Generate a lookup function.""" + def test_value(object_id: str, key: str) -> Optional[bool]: + """Test if permission is allowed based on the keys.""" + schema = lookup_func( + perm_lookup, lookup_dict, object_id) # type: ValueType + + if schema is None or isinstance(schema, bool): + return schema + + assert isinstance(schema, dict) + + return schema.get(key) + + return test_value diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 9c8ee27d0d2..1ea6c400208 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -1,7 +1,7 @@ """Provide a way to connect entities belonging to one device.""" import logging import uuid -from typing import List +from typing import List, Optional from collections import OrderedDict @@ -71,6 +71,11 @@ class DeviceRegistry: self.devices = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + @callback + def async_get(self, device_id: str) -> Optional[DeviceEntry]: + """Get device.""" + return self.devices.get(device_id) + @callback def async_get_device(self, identifiers: set, connections: set): """Check if device is registered.""" diff --git a/tests/auth/permissions/test_entities.py b/tests/auth/permissions/test_entities.py index 1fd70668f8b..119deac3311 100644 --- a/tests/auth/permissions/test_entities.py +++ b/tests/auth/permissions/test_entities.py @@ -6,8 +6,9 @@ from homeassistant.auth.permissions.entities import ( compile_entities, ENTITY_POLICY_SCHEMA) from homeassistant.auth.permissions.models import PermissionLookup from homeassistant.helpers.entity_registry import RegistryEntry +from homeassistant.helpers.device_registry import DeviceEntry -from tests.common import mock_registry +from tests.common import mock_registry, mock_device_registry def test_entities_none(): @@ -193,7 +194,7 @@ def test_entities_all_control(): def test_entities_device_id_boolean(hass): """Test entity ID policy applying control on device id.""" - registry = mock_registry(hass, { + entity_registry = mock_registry(hass, { 'test_domain.allowed': RegistryEntry( entity_id='test_domain.allowed', unique_id='1234', @@ -207,6 +208,7 @@ def test_entities_device_id_boolean(hass): device_id='mock-not-allowed-dev-id' ), }) + device_registry = mock_device_registry(hass) policy = { 'device_ids': { @@ -216,8 +218,55 @@ def test_entities_device_id_boolean(hass): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy, PermissionLookup(registry)) + compiled = compile_entities(policy, PermissionLookup( + entity_registry, device_registry + )) assert compiled('test_domain.allowed', 'read') is True assert compiled('test_domain.allowed', 'control') is False assert compiled('test_domain.not_allowed', 'read') is False assert compiled('test_domain.not_allowed', 'control') is False + + +def test_entities_areas_true(): + """Test entity ID policy for areas.""" + policy = { + 'area_ids': True + } + ENTITY_POLICY_SCHEMA(policy) + compiled = compile_entities(policy, None) + assert compiled('light.kitchen', 'read') is True + + +def test_entities_areas_area_true(hass): + """Test entity ID policy for areas with specific area.""" + entity_registry = mock_registry(hass, { + 'light.kitchen': RegistryEntry( + entity_id='light.kitchen', + unique_id='1234', + platform='test_platform', + device_id='mock-dev-id' + ), + }) + device_registry = mock_device_registry(hass, { + 'mock-dev-id': DeviceEntry( + id='mock-dev-id', + area_id='mock-area-id' + ) + }) + + policy = { + 'area_ids': { + 'mock-area-id': { + 'read': True, + 'control': True, + } + } + } + ENTITY_POLICY_SCHEMA(policy) + compiled = compile_entities(policy, PermissionLookup( + entity_registry, device_registry + )) + assert compiled('light.kitchen', 'read') is True + assert compiled('light.kitchen', 'control') is True + assert compiled('light.kitchen', 'edit') is False + assert compiled('switch.kitchen', 'read') is False diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 136bc3d62ec..32c314b56d6 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -245,7 +245,9 @@ async def test_loading_race_condition(hass): store = auth_store.AuthStore(hass) with asynctest.patch( 'homeassistant.helpers.entity_registry.async_get_registry', - ) as mock_registry, asynctest.patch( + ) as mock_ent_registry, asynctest.patch( + 'homeassistant.helpers.device_registry.async_get_registry', + ) as mock_dev_registry, asynctest.patch( 'homeassistant.helpers.storage.Store.async_load', ) as mock_load: results = await asyncio.gather( @@ -253,6 +255,7 @@ async def test_loading_race_condition(hass): store.async_get_users(), ) - mock_registry.assert_called_once_with(hass) + mock_ent_registry.assert_called_once_with(hass) + mock_dev_registry.assert_called_once_with(hass) mock_load.assert_called_once_with() assert results[0] == results[1]