diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index bad1bdcf913..c6078e03f63 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -1,4 +1,5 @@ """Storage for auth models.""" +import asyncio from collections import OrderedDict from datetime import timedelta import hmac @@ -11,7 +12,7 @@ from homeassistant.util import dt as dt_util from . import models from .const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY -from .permissions import system_policies +from .permissions import PermissionLookup, system_policies from .permissions.types import PolicyType # noqa: F401 STORAGE_VERSION = 1 @@ -34,6 +35,7 @@ class AuthStore: self.hass = hass self._users = None # type: Optional[Dict[str, models.User]] self._groups = None # type: Optional[Dict[str, models.Group]] + self._perm_lookup = None # type: Optional[PermissionLookup] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY, private=True) @@ -94,6 +96,7 @@ class AuthStore: # Until we get group management, we just put everyone in the # same group. 'groups': groups, + 'perm_lookup': self._perm_lookup, } # type: Dict[str, Any] if is_owner is not None: @@ -269,13 +272,18 @@ class AuthStore: async def _async_load(self) -> None: """Load the users.""" - data = await self._store.async_load() + [ent_reg, data] = await asyncio.gather( + self.hass.helpers.entity_registry.async_get_registry(), + self._store.async_load(), + ) # Make sure that we're not overriding data if 2 loads happened at the # same time if self._users is not None: return + self._perm_lookup = perm_lookup = PermissionLookup(ent_reg) + if data is None: self._set_defaults() return @@ -374,6 +382,7 @@ class AuthStore: is_owner=user_dict['is_owner'], is_active=user_dict['is_active'], system_generated=user_dict['system_generated'], + perm_lookup=perm_lookup, ) for cred_dict in data['credentials']: diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 4b192c35898..588d80047be 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -31,6 +31,9 @@ class User: """A user.""" name = attr.ib(type=str) # type: Optional[str] + perm_lookup = attr.ib( + type=perm_mdl.PermissionLookup, cmp=False, + ) # type: perm_mdl.PermissionLookup id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex) is_owner = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False) @@ -66,7 +69,8 @@ class User: self._permissions = perm_mdl.PolicyPermissions( perm_mdl.merge_policies([ - group.policy for group in self.groups])) + group.policy for group in self.groups]), + self.perm_lookup) return self._permissions diff --git a/homeassistant/auth/permissions/__init__.py b/homeassistant/auth/permissions/__init__.py index 9113f2b03a9..63e76dd2496 100644 --- a/homeassistant/auth/permissions/__init__.py +++ b/homeassistant/auth/permissions/__init__.py @@ -1,15 +1,18 @@ """Permissions for Home Assistant.""" import logging from typing import ( # noqa: F401 - cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union) + cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union, + TYPE_CHECKING) import voluptuous as vol from .const import CAT_ENTITIES +from .models import PermissionLookup from .types import PolicyType from .entities import ENTITY_POLICY_SCHEMA, compile_entities from .merge import merge_policies # noqa + POLICY_SCHEMA = vol.Schema({ vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA }) @@ -39,13 +42,16 @@ class AbstractPermissions: class PolicyPermissions(AbstractPermissions): """Handle permissions.""" - def __init__(self, policy: PolicyType) -> None: + def __init__(self, policy: PolicyType, + perm_lookup: PermissionLookup) -> None: """Initialize the permission class.""" self._policy = policy + self._perm_lookup = perm_lookup def _entity_func(self) -> Callable[[str, str], bool]: """Return a function that can test entity access.""" - return compile_entities(self._policy.get(CAT_ENTITIES)) + return compile_entities(self._policy.get(CAT_ENTITIES), + self._perm_lookup) def __eq__(self, other: Any) -> bool: """Equals check.""" diff --git a/homeassistant/auth/permissions/entities.py b/homeassistant/auth/permissions/entities.py index 59bba468a59..0073c952648 100644 --- a/homeassistant/auth/permissions/entities.py +++ b/homeassistant/auth/permissions/entities.py @@ -5,6 +5,7 @@ from typing import Callable, List, Union # noqa: F401 import voluptuous as vol from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT +from .models import PermissionLookup from .types import CategoryType, ValueType SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ @@ -14,6 +15,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ })) ENTITY_DOMAINS = 'domains' +ENTITY_DEVICE_IDS = 'device_ids' ENTITY_ENTITY_IDS = 'entity_ids' ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({ @@ -22,6 +24,7 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA, + vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA, })) @@ -36,7 +39,7 @@ def _entity_allowed(schema: ValueType, key: str) \ return schema.get(key) -def compile_entities(policy: CategoryType) \ +def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \ -> Callable[[str, str], bool]: """Compile policy into a function that tests policy.""" # None, Empty Dict, False @@ -57,6 +60,7 @@ def compile_entities(policy: CategoryType) \ assert isinstance(policy, dict) domains = policy.get(ENTITY_DOMAINS) + device_ids = policy.get(ENTITY_DEVICE_IDS) entity_ids = policy.get(ENTITY_ENTITY_IDS) all_entities = policy.get(SUBCAT_ALL) @@ -84,6 +88,29 @@ def compile_entities(policy: CategoryType) \ funcs.append(allowed_entity_id_dict) + if isinstance(device_ids, bool): + def allowed_device_id_bool(entity_id: str, key: str) \ + -> Union[None, bool]: + """Test if allowed device_id.""" + return device_ids + + funcs.append(allowed_device_id_bool) + + elif device_ids is not None: + def allowed_device_id_dict(entity_id: str, key: str) \ + -> Union[None, bool]: + """Test if allowed device_id.""" + entity_entry = perm_lookup.entity_registry.async_get(entity_id) + + if entity_entry is None or entity_entry.device_id is None: + return None + + return _entity_allowed( + device_ids.get(entity_entry.device_id), key # type: ignore + ) + + funcs.append(allowed_device_id_dict) + if isinstance(domains, bool): def allowed_domain_bool(entity_id: str, key: str) \ -> Union[None, bool]: diff --git a/homeassistant/auth/permissions/models.py b/homeassistant/auth/permissions/models.py new file mode 100644 index 00000000000..7ad7d5521c5 --- /dev/null +++ b/homeassistant/auth/permissions/models.py @@ -0,0 +1,17 @@ +"""Models for permissions.""" +from typing import TYPE_CHECKING + +import attr + +if TYPE_CHECKING: + # pylint: disable=unused-import + from homeassistant.helpers import ( # noqa + entity_registry as ent_reg, + ) + + +@attr.s(slots=True) +class PermissionLookup: + """Class to hold data for permission lookups.""" + + entity_registry = attr.ib(type='ent_reg.EntityRegistry') diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index c40d14652ad..57c8bcf0af8 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -10,6 +10,7 @@ timer. from collections import OrderedDict from itertools import chain import logging +from typing import Optional import weakref import attr @@ -85,6 +86,11 @@ class EntityRegistry: """Check if an entity_id is currently registered.""" return entity_id in self.entities + @callback + def async_get(self, entity_id: str) -> Optional[RegistryEntry]: + """Get EntityEntry for an entity_id.""" + return self.entities.get(entity_id) + @callback def async_get_entity_id(self, domain: str, platform: str, unique_id: str): """Check if an entity_id is currently registered.""" diff --git a/tests/auth/permissions/test_entities.py b/tests/auth/permissions/test_entities.py index 40de5ca7334..1fd70668f8b 100644 --- a/tests/auth/permissions/test_entities.py +++ b/tests/auth/permissions/test_entities.py @@ -4,12 +4,16 @@ import voluptuous as vol from homeassistant.auth.permissions.entities import ( compile_entities, ENTITY_POLICY_SCHEMA) +from homeassistant.auth.permissions.models import PermissionLookup +from homeassistant.helpers.entity_registry import RegistryEntry + +from tests.common import mock_registry def test_entities_none(): """Test entity ID policy.""" policy = None - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is False @@ -17,7 +21,7 @@ def test_entities_empty(): """Test entity ID policy.""" policy = {} ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is False @@ -32,7 +36,7 @@ def test_entities_true(): """Test entity ID policy.""" policy = True ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True @@ -42,7 +46,7 @@ def test_entities_domains_true(): 'domains': True } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True @@ -54,7 +58,7 @@ def test_entities_domains_domain_true(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('switch.kitchen', 'read') is False @@ -76,7 +80,7 @@ def test_entities_entity_ids_true(): 'entity_ids': True } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True @@ -97,7 +101,7 @@ def test_entities_entity_ids_entity_id_true(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('switch.kitchen', 'read') is False @@ -123,7 +127,7 @@ def test_entities_control_only(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'control') is False assert compiled('light.kitchen', 'edit') is False @@ -140,7 +144,7 @@ def test_entities_read_control(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'control') is True assert compiled('light.kitchen', 'edit') is False @@ -152,7 +156,7 @@ def test_entities_all_allow(): 'all': True } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'control') is True assert compiled('switch.kitchen', 'read') is True @@ -166,7 +170,7 @@ def test_entities_all_read(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'control') is False assert compiled('switch.kitchen', 'read') is True @@ -180,8 +184,40 @@ def test_entities_all_control(): } } ENTITY_POLICY_SCHEMA(policy) - compiled = compile_entities(policy) + compiled = compile_entities(policy, None) assert compiled('light.kitchen', 'read') is False assert compiled('light.kitchen', 'control') is True assert compiled('switch.kitchen', 'read') is False assert compiled('switch.kitchen', 'control') is True + + +def test_entities_device_id_boolean(hass): + """Test entity ID policy applying control on device id.""" + registry = mock_registry(hass, { + 'test_domain.allowed': RegistryEntry( + entity_id='test_domain.allowed', + unique_id='1234', + platform='test_platform', + device_id='mock-allowed-dev-id' + ), + 'test_domain.not_allowed': RegistryEntry( + entity_id='test_domain.not_allowed', + unique_id='5678', + platform='test_platform', + device_id='mock-not-allowed-dev-id' + ), + }) + + policy = { + 'device_ids': { + 'mock-allowed-dev-id': { + 'read': True, + } + } + } + ENTITY_POLICY_SCHEMA(policy) + compiled = compile_entities(policy, PermissionLookup(registry)) + assert compiled('test_domain.allowed', 'read') is True + assert compiled('test_domain.allowed', 'control') is False + assert compiled('test_domain.not_allowed', 'read') is False + assert compiled('test_domain.not_allowed', 'control') is False diff --git a/tests/auth/permissions/test_system_policies.py b/tests/auth/permissions/test_system_policies.py index ba6fe214146..f6a68f0865a 100644 --- a/tests/auth/permissions/test_system_policies.py +++ b/tests/auth/permissions/test_system_policies.py @@ -8,7 +8,7 @@ def test_admin_policy(): # Make sure it's valid POLICY_SCHEMA(system_policies.ADMIN_POLICY) - perms = PolicyPermissions(system_policies.ADMIN_POLICY) + perms = PolicyPermissions(system_policies.ADMIN_POLICY, None) assert perms.check_entity('light.kitchen', 'read') assert perms.check_entity('light.kitchen', 'control') assert perms.check_entity('light.kitchen', 'edit') @@ -19,7 +19,7 @@ def test_read_only_policy(): # Make sure it's valid POLICY_SCHEMA(system_policies.READ_ONLY_POLICY) - perms = PolicyPermissions(system_policies.READ_ONLY_POLICY) + perms = PolicyPermissions(system_policies.READ_ONLY_POLICY, None) assert perms.check_entity('light.kitchen', 'read') assert not perms.check_entity('light.kitchen', 'control') assert not perms.check_entity('light.kitchen', 'edit') diff --git a/tests/auth/test_models.py b/tests/auth/test_models.py index b02111e8d02..329124bc979 100644 --- a/tests/auth/test_models.py +++ b/tests/auth/test_models.py @@ -5,7 +5,12 @@ from homeassistant.auth import models, permissions def test_owner_fetching_owner_permissions(): """Test we fetch the owner permissions for an owner user.""" group = models.Group(name="Test Group", policy={}) - owner = models.User(name="Test User", groups=[group], is_owner=True) + owner = models.User( + name="Test User", + perm_lookup=None, + groups=[group], + is_owner=True + ) assert owner.permissions is permissions.OwnerPermissions @@ -25,7 +30,11 @@ def test_permissions_merged(): } } }) - user = models.User(name="Test User", groups=[group, group2]) + user = models.User( + name="Test User", + perm_lookup=None, + groups=[group, group2] + ) # Make sure we cache instance assert user.permissions is user.permissions diff --git a/tests/common.py b/tests/common.py index db7ce6e3a17..d7b28b3039a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -384,6 +384,7 @@ class MockUser(auth_models.User): 'name': name, 'system_generated': system_generated, 'groups': groups or [], + 'perm_lookup': None, } if id is not None: kwargs['id'] = id @@ -401,7 +402,8 @@ class MockUser(auth_models.User): def mock_policy(self, policy): """Mock a policy for a user.""" - self._permissions = auth_permissions.PolicyPermissions(policy) + self._permissions = auth_permissions.PolicyPermissions( + policy, self.perm_lookup) async def register_auth_provider(hass, config): diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index a4e9a571943..8fca7df69c1 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -232,7 +232,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call, 'light.kitchen': True } } - })))): + }, None)))): await service.entity_service_call(hass, [ Mock(entities=mock_entities) ], Mock(), ha.ServiceCall('test_domain', 'test_service', @@ -253,7 +253,7 @@ async def test_call_context_target_specific(hass, mock_service_platform_call, 'light.kitchen': True } } - })))): + }, None)))): await service.entity_service_call(hass, [ Mock(entities=mock_entities) ], Mock(), ha.ServiceCall('test_domain', 'test_service', { @@ -271,7 +271,7 @@ async def test_call_context_target_specific_no_auth( with pytest.raises(exceptions.Unauthorized) as err: with patch('homeassistant.auth.AuthManager.async_get_user', return_value=mock_coro(Mock( - permissions=PolicyPermissions({})))): + permissions=PolicyPermissions({}, None)))): await service.entity_service_call(hass, [ Mock(entities=mock_entities) ], Mock(), ha.ServiceCall('test_domain', 'test_service', {