From 8b8629a5f416e6f04bd246f71f13250a75451033 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 25 Nov 2018 18:04:48 +0100 Subject: [PATCH] Add permission checks to Rest API (#18639) * Add permission checks to Rest API * Clean up unnecessary method * Remove all the tuple stuff from entity check * Simplify perms * Correct param name for owner permission * Hass.io make/update user to be admin * Types --- homeassistant/auth/__init__.py | 17 +++- homeassistant/auth/auth_store.py | 27 +++++++ homeassistant/auth/models.py | 16 +++- homeassistant/auth/permissions/__init__.py | 61 +++++---------- homeassistant/auth/permissions/entities.py | 40 +++++----- homeassistant/components/api.py | 27 ++++++- homeassistant/components/hassio/__init__.py | 9 ++- homeassistant/components/http/view.py | 10 ++- homeassistant/helpers/service.py | 10 +-- tests/auth/permissions/test_entities.py | 50 ++++++------ tests/auth/permissions/test_init.py | 34 -------- tests/common.py | 7 +- tests/components/conftest.py | 5 +- tests/components/hassio/test_init.py | 28 +++++++ tests/components/test_api.py | 86 +++++++++++++++++++-- 15 files changed, 282 insertions(+), 145 deletions(-) delete mode 100644 tests/auth/permissions/test_init.py diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index e69dec37df2..7d8ef13d2bb 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -132,13 +132,15 @@ class AuthManager: return None - async def async_create_system_user(self, name: str) -> models.User: + async def async_create_system_user( + self, name: str, + group_ids: Optional[List[str]] = None) -> models.User: """Create a system user.""" user = await self._store.async_create_user( name=name, system_generated=True, is_active=True, - group_ids=[], + group_ids=group_ids or [], ) self.hass.bus.async_fire(EVENT_USER_ADDED, { @@ -217,6 +219,17 @@ class AuthManager: 'user_id': user.id }) + async def async_update_user(self, user: models.User, + name: Optional[str] = None, + group_ids: Optional[List[str]] = None) -> None: + """Update a user.""" + kwargs = {} # type: Dict[str,Any] + if name is not None: + kwargs['name'] = name + if group_ids is not None: + kwargs['group_ids'] = group_ids + await self._store.async_update_user(user, **kwargs) + async def async_activate_user(self, user: models.User) -> None: """Activate a user.""" await self._store.async_activate_user(user) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 867d5357a58..cf82c40a4d3 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -133,6 +133,33 @@ class AuthStore: self._users.pop(user.id) self._async_schedule_save() + async def async_update_user( + self, user: models.User, name: Optional[str] = None, + is_active: Optional[bool] = None, + group_ids: Optional[List[str]] = None) -> None: + """Update a user.""" + assert self._groups is not None + + if group_ids is not None: + groups = [] + for grid in group_ids: + group = self._groups.get(grid) + if group is None: + raise ValueError("Invalid group specified.") + groups.append(group) + + user.groups = groups + user.invalidate_permission_cache() + + for attr_name, value in ( + ('name', name), + ('is_active', is_active), + ): + if value is not None: + setattr(user, attr_name, value) + + self._async_schedule_save() + async def async_activate_user(self, user: models.User) -> None: """Activate a user.""" user.is_active = True diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index cefaabe7521..4b192c35898 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -8,6 +8,7 @@ import attr from homeassistant.util import dt as dt_util from . import permissions as perm_mdl +from .const import GROUP_ID_ADMIN from .util import generate_secret TOKEN_TYPE_NORMAL = 'normal' @@ -48,7 +49,7 @@ class User: ) # type: Dict[str, RefreshToken] _permissions = attr.ib( - type=perm_mdl.PolicyPermissions, + type=Optional[perm_mdl.PolicyPermissions], init=False, cmp=False, default=None, @@ -69,6 +70,19 @@ class User: return self._permissions + @property + def is_admin(self) -> bool: + """Return if user is part of the admin group.""" + if self.is_owner: + return True + + return self.is_active and any( + gr.id == GROUP_ID_ADMIN for gr in self.groups) + + def invalidate_permission_cache(self) -> None: + """Invalidate permission cache.""" + self._permissions = None + @attr.s(slots=True) class RefreshToken: diff --git a/homeassistant/auth/permissions/__init__.py b/homeassistant/auth/permissions/__init__.py index fd3cf81f029..9113f2b03a9 100644 --- a/homeassistant/auth/permissions/__init__.py +++ b/homeassistant/auth/permissions/__init__.py @@ -5,10 +5,8 @@ from typing import ( # noqa: F401 import voluptuous as vol -from homeassistant.core import State - from .const import CAT_ENTITIES -from .types import CategoryType, PolicyType +from .types import PolicyType from .entities import ENTITY_POLICY_SCHEMA, compile_entities from .merge import merge_policies # noqa @@ -22,13 +20,20 @@ _LOGGER = logging.getLogger(__name__) class AbstractPermissions: """Default permissions class.""" - def check_entity(self, entity_id: str, key: str) -> bool: - """Test if we can access entity.""" + _cached_entity_func = None + + def _entity_func(self) -> Callable[[str, str], bool]: + """Return a function that can test entity access.""" raise NotImplementedError - def filter_states(self, states: List[State]) -> List[State]: - """Filter a list of states for what the user is allowed to see.""" - raise NotImplementedError + def check_entity(self, entity_id: str, key: str) -> bool: + """Check if we can access entity.""" + entity_func = self._cached_entity_func + + if entity_func is None: + entity_func = self._cached_entity_func = self._entity_func() + + return entity_func(entity_id, key) class PolicyPermissions(AbstractPermissions): @@ -37,34 +42,10 @@ class PolicyPermissions(AbstractPermissions): def __init__(self, policy: PolicyType) -> None: """Initialize the permission class.""" self._policy = policy - self._compiled = {} # type: Dict[str, Callable[..., bool]] - def check_entity(self, entity_id: str, key: str) -> bool: - """Test if we can access entity.""" - func = self._policy_func(CAT_ENTITIES, compile_entities) - return func(entity_id, (key,)) - - def filter_states(self, states: List[State]) -> List[State]: - """Filter a list of states for what the user is allowed to see.""" - func = self._policy_func(CAT_ENTITIES, compile_entities) - keys = ('read',) - return [entity for entity in states if func(entity.entity_id, keys)] - - def _policy_func(self, category: str, - compile_func: Callable[[CategoryType], Callable]) \ - -> Callable[..., bool]: - """Get a policy function.""" - func = self._compiled.get(category) - - if func: - return func - - func = self._compiled[category] = compile_func( - self._policy.get(category)) - - _LOGGER.debug("Compiled %s func: %s", category, func) - - return func + def _entity_func(self) -> Callable[[str, str], bool]: + """Return a function that can test entity access.""" + return compile_entities(self._policy.get(CAT_ENTITIES)) def __eq__(self, other: Any) -> bool: """Equals check.""" @@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions): # pylint: disable=no-self-use - def check_entity(self, entity_id: str, key: str) -> bool: - """Test if we can access entity.""" - return True - - def filter_states(self, states: List[State]) -> List[State]: - """Filter a list of states for what the user is allowed to see.""" - return states + def _entity_func(self) -> Callable[[str, str], bool]: + """Return a function that can test entity access.""" + return lambda entity_id, key: True OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name diff --git a/homeassistant/auth/permissions/entities.py b/homeassistant/auth/permissions/entities.py index 89b9398628c..74a43246fd1 100644 --- a/homeassistant/auth/permissions/entities.py +++ b/homeassistant/auth/permissions/entities.py @@ -28,28 +28,28 @@ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ })) -def _entity_allowed(schema: ValueType, keys: Tuple[str]) \ +def _entity_allowed(schema: ValueType, key: str) \ -> Union[bool, None]: """Test if an entity is allowed based on the keys.""" if schema is None or isinstance(schema, bool): return schema assert isinstance(schema, dict) - return schema.get(keys[0]) + return schema.get(key) def compile_entities(policy: CategoryType) \ - -> Callable[[str, Tuple[str]], bool]: + -> Callable[[str, str], bool]: """Compile policy into a function that tests policy.""" # None, Empty Dict, False if not policy: - def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool: + def apply_policy_deny_all(entity_id: str, key: str) -> bool: """Decline all.""" return False return apply_policy_deny_all if policy is True: - def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool: + def apply_policy_allow_all(entity_id: str, key: str) -> bool: """Approve all.""" return True @@ -61,7 +61,7 @@ def compile_entities(policy: CategoryType) \ entity_ids = policy.get(ENTITY_ENTITY_IDS) all_entities = policy.get(SUBCAT_ALL) - funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]] + funcs = [] # type: List[Callable[[str, str], Union[None, bool]]] # The order of these functions matter. The more precise are at the top. # If a function returns None, they cannot handle it. @@ -70,23 +70,23 @@ def compile_entities(policy: CategoryType) \ # Setting entity_ids to a boolean is final decision for permissions # So return right away. if isinstance(entity_ids, bool): - def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool: + def allowed_entity_id_bool(entity_id: str, key: str) -> bool: """Test if allowed entity_id.""" return entity_ids # type: ignore return allowed_entity_id_bool if entity_ids is not None: - def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \ + def allowed_entity_id_dict(entity_id: str, key: str) \ -> Union[None, bool]: """Test if allowed entity_id.""" return _entity_allowed( - entity_ids.get(entity_id), keys) # type: ignore + entity_ids.get(entity_id), key) # type: ignore funcs.append(allowed_entity_id_dict) if isinstance(domains, bool): - def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \ + def allowed_domain_bool(entity_id: str, key: str) \ -> Union[None, bool]: """Test if allowed domain.""" return domains @@ -94,31 +94,31 @@ def compile_entities(policy: CategoryType) \ funcs.append(allowed_domain_bool) elif domains is not None: - def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \ + def allowed_domain_dict(entity_id: str, key: str) \ -> Union[None, bool]: """Test if allowed domain.""" domain = entity_id.split(".", 1)[0] - return _entity_allowed(domains.get(domain), keys) # type: ignore + return _entity_allowed(domains.get(domain), key) # type: ignore funcs.append(allowed_domain_dict) if isinstance(all_entities, bool): - def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \ + def allowed_all_entities_bool(entity_id: str, key: str) \ -> Union[None, bool]: """Test if allowed domain.""" return all_entities funcs.append(allowed_all_entities_bool) elif all_entities is not None: - def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \ + def allowed_all_entities_dict(entity_id: str, key: str) \ -> Union[None, bool]: """Test if allowed domain.""" - return _entity_allowed(all_entities, keys) + return _entity_allowed(all_entities, key) funcs.append(allowed_all_entities_dict) # Can happen if no valid subcategories specified if not funcs: - def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool: + def apply_policy_deny_all_2(entity_id: str, key: str) -> bool: """Decline all.""" return False @@ -128,16 +128,16 @@ def compile_entities(policy: CategoryType) \ func = funcs[0] @wraps(func) - def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool: + def apply_policy_func(entity_id: str, key: str) -> bool: """Apply a single policy function.""" - return func(entity_id, keys) is True + return func(entity_id, key) is True return apply_policy_func - def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool: + def apply_policy_funcs(entity_id: str, key: str) -> bool: """Apply several policy functions.""" for func in funcs: - result = func(entity_id, keys) + result = func(entity_id, key) if result is not None: return result return False diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index cbe404537eb..b001bcd0437 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -20,7 +20,8 @@ from homeassistant.const import ( URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM, URL_API_TEMPLATE, __version__) import homeassistant.core as ha -from homeassistant.exceptions import TemplateError +from homeassistant.auth.permissions.const import POLICY_READ +from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.helpers import template from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.state import AsyncTrackStates @@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView): async def get(self, request): """Provide a streaming interface for the event bus.""" + if not request['hass_user'].is_admin: + raise Unauthorized() hass = request.app['hass'] stop_obj = object() to_write = asyncio.Queue(loop=hass.loop) @@ -185,7 +188,13 @@ class APIStatesView(HomeAssistantView): @ha.callback def get(self, request): """Get current states.""" - return self.json(request.app['hass'].states.async_all()) + user = request['hass_user'] + entity_perm = user.permissions.check_entity + states = [ + state for state in request.app['hass'].states.async_all() + if entity_perm(state.entity_id, 'read') + ] + return self.json(states) class APIEntityStateView(HomeAssistantView): @@ -197,6 +206,10 @@ class APIEntityStateView(HomeAssistantView): @ha.callback def get(self, request, entity_id): """Retrieve state of entity.""" + user = request['hass_user'] + if not user.permissions.check_entity(entity_id, POLICY_READ): + raise Unauthorized(entity_id=entity_id) + state = request.app['hass'].states.get(entity_id) if state: return self.json(state) @@ -204,6 +217,8 @@ class APIEntityStateView(HomeAssistantView): async def post(self, request, entity_id): """Update state of entity.""" + if not request['hass_user'].is_admin: + raise Unauthorized(entity_id=entity_id) hass = request.app['hass'] try: data = await request.json() @@ -236,6 +251,8 @@ class APIEntityStateView(HomeAssistantView): @ha.callback def delete(self, request, entity_id): """Remove entity.""" + if not request['hass_user'].is_admin: + raise Unauthorized(entity_id=entity_id) if request.app['hass'].states.async_remove(entity_id): return self.json_message("Entity removed.") return self.json_message("Entity not found.", HTTP_NOT_FOUND) @@ -261,6 +278,8 @@ class APIEventView(HomeAssistantView): async def post(self, request, event_type): """Fire events.""" + if not request['hass_user'].is_admin: + raise Unauthorized() body = await request.text() try: event_data = json.loads(body) if body else None @@ -346,6 +365,8 @@ class APITemplateView(HomeAssistantView): async def post(self, request): """Render a template.""" + if not request['hass_user'].is_admin: + raise Unauthorized() try: data = await request.json() tpl = template.Template(data['template'], request.app['hass']) @@ -363,6 +384,8 @@ class APIErrorLog(HomeAssistantView): async def get(self, request): """Retrieve API error log.""" + if not request['hass_user'].is_admin: + raise Unauthorized() return web.FileResponse(request.app['hass'].data[DATA_LOGGING]) diff --git a/homeassistant/components/hassio/__init__.py b/homeassistant/components/hassio/__init__.py index 4c13cb799a6..6bfcaaa5d85 100644 --- a/homeassistant/components/hassio/__init__.py +++ b/homeassistant/components/hassio/__init__.py @@ -10,6 +10,7 @@ import os import voluptuous as vol +from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.components import SERVICE_CHECK_CONFIG from homeassistant.const import ( ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP) @@ -181,8 +182,14 @@ async def async_setup(hass, config): if user and user.refresh_tokens: refresh_token = list(user.refresh_tokens.values())[0] + # Migrate old hass.io users to be admin. + if not user.is_admin: + await hass.auth.async_update_user( + user, group_ids=[GROUP_ID_ADMIN]) + if refresh_token is None: - user = await hass.auth.async_create_system_user('Hass.io') + user = await hass.auth.async_create_system_user( + 'Hass.io', [GROUP_ID_ADMIN]) refresh_token = await hass.auth.async_create_refresh_token(user) data['hassio_user'] = user.id await store.async_save(data) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index b3b2587fc45..30d4ed0ab8d 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -14,6 +14,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError from homeassistant.components.http.ban import process_success_login from homeassistant.core import Context, is_callback from homeassistant.const import CONTENT_TYPE_JSON +from homeassistant import exceptions from homeassistant.helpers.json import JSONEncoder from .const import KEY_AUTHENTICATED, KEY_REAL_IP @@ -107,10 +108,13 @@ def request_handler_factory(view, handler): _LOGGER.info('Serving %s to %s (auth: %s)', request.path, request.get(KEY_REAL_IP), authenticated) - result = handler(request, **request.match_info) + try: + result = handler(request, **request.match_info) - if asyncio.iscoroutine(result): - result = await result + if asyncio.iscoroutine(result): + result = await result + except exceptions.Unauthorized: + raise HTTPUnauthorized() if isinstance(result, web.StreamResponse): # The method handler returned a ready-made Response, how nice of it diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 5e0d9c7e88a..e8068f57286 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call): user = await hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser(context=call.context) - perms = user.permissions + entity_perms = user.permissions.check_entity else: - perms = None + entity_perms = None # Are we trying to target all entities target_all_entities = ATTR_ENTITY_ID not in call.data @@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call): # the service on. platforms_entities = [] - if perms is None: + if entity_perms is None: for platform in platforms: if target_all_entities: platforms_entities.append(list(platform.entities.values())) @@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call): for platform in platforms: platforms_entities.append([ entity for entity in platform.entities.values() - if perms.check_entity(entity.entity_id, POLICY_CONTROL)]) + if entity_perms(entity.entity_id, POLICY_CONTROL)]) else: for platform in platforms: @@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call): if entity.entity_id not in entity_ids: continue - if not perms.check_entity(entity.entity_id, POLICY_CONTROL): + if not entity_perms(entity.entity_id, POLICY_CONTROL): raise Unauthorized( context=call.context, entity_id=entity.entity_id, diff --git a/tests/auth/permissions/test_entities.py b/tests/auth/permissions/test_entities.py index 33c164d12b4..40de5ca7334 100644 --- a/tests/auth/permissions/test_entities.py +++ b/tests/auth/permissions/test_entities.py @@ -10,7 +10,7 @@ def test_entities_none(): """Test entity ID policy.""" policy = None compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is False + assert compiled('light.kitchen', 'read') is False def test_entities_empty(): @@ -18,7 +18,7 @@ def test_entities_empty(): policy = {} ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is False + assert compiled('light.kitchen', 'read') is False def test_entities_false(): @@ -33,7 +33,7 @@ def test_entities_true(): policy = True ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True + assert compiled('light.kitchen', 'read') is True def test_entities_domains_true(): @@ -43,7 +43,7 @@ def test_entities_domains_true(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True + assert compiled('light.kitchen', 'read') is True def test_entities_domains_domain_true(): @@ -55,8 +55,8 @@ def test_entities_domains_domain_true(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('switch.kitchen', ('read',)) is False + assert compiled('light.kitchen', 'read') is True + assert compiled('switch.kitchen', 'read') is False def test_entities_domains_domain_false(): @@ -77,7 +77,7 @@ def test_entities_entity_ids_true(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True + assert compiled('light.kitchen', 'read') is True def test_entities_entity_ids_false(): @@ -98,8 +98,8 @@ def test_entities_entity_ids_entity_id_true(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('switch.kitchen', ('read',)) is False + assert compiled('light.kitchen', 'read') is True + assert compiled('switch.kitchen', 'read') is False def test_entities_entity_ids_entity_id_false(): @@ -124,9 +124,9 @@ def test_entities_control_only(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('light.kitchen', ('control',)) is False - assert compiled('light.kitchen', ('edit',)) is False + assert compiled('light.kitchen', 'read') is True + assert compiled('light.kitchen', 'control') is False + assert compiled('light.kitchen', 'edit') is False def test_entities_read_control(): @@ -141,9 +141,9 @@ def test_entities_read_control(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('light.kitchen', ('control',)) is True - assert compiled('light.kitchen', ('edit',)) is False + assert compiled('light.kitchen', 'read') is True + assert compiled('light.kitchen', 'control') is True + assert compiled('light.kitchen', 'edit') is False def test_entities_all_allow(): @@ -153,9 +153,9 @@ def test_entities_all_allow(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('light.kitchen', ('control',)) is True - assert compiled('switch.kitchen', ('read',)) is True + assert compiled('light.kitchen', 'read') is True + assert compiled('light.kitchen', 'control') is True + assert compiled('switch.kitchen', 'read') is True def test_entities_all_read(): @@ -167,9 +167,9 @@ def test_entities_all_read(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is True - assert compiled('light.kitchen', ('control',)) is False - assert compiled('switch.kitchen', ('read',)) is True + assert compiled('light.kitchen', 'read') is True + assert compiled('light.kitchen', 'control') is False + assert compiled('switch.kitchen', 'read') is True def test_entities_all_control(): @@ -181,7 +181,7 @@ def test_entities_all_control(): } ENTITY_POLICY_SCHEMA(policy) compiled = compile_entities(policy) - assert compiled('light.kitchen', ('read',)) is False - assert compiled('light.kitchen', ('control',)) is True - assert compiled('switch.kitchen', ('read',)) is False - assert compiled('switch.kitchen', ('control',)) is True + assert compiled('light.kitchen', 'read') is False + assert compiled('light.kitchen', 'control') is True + assert compiled('switch.kitchen', 'read') is False + assert compiled('switch.kitchen', 'control') is True diff --git a/tests/auth/permissions/test_init.py b/tests/auth/permissions/test_init.py deleted file mode 100644 index fdc5440a9d5..00000000000 --- a/tests/auth/permissions/test_init.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Tests for the auth permission system.""" -from homeassistant.core import State -from homeassistant.auth import permissions - - -def test_policy_perm_filter_states(): - """Test filtering entitites.""" - states = [ - State('light.kitchen', 'on'), - State('light.living_room', 'off'), - State('light.balcony', 'on'), - ] - perm = permissions.PolicyPermissions({ - 'entities': { - 'entity_ids': { - 'light.kitchen': True, - 'light.balcony': True, - } - } - }) - filtered = perm.filter_states(states) - assert len(filtered) == 2 - assert filtered == [states[0], states[2]] - - -def test_owner_permissions(): - """Test owner permissions access all.""" - assert permissions.OwnerPermissions.check_entity('light.kitchen', 'write') - states = [ - State('light.kitchen', 'on'), - State('light.living_room', 'off'), - State('light.balcony', 'on'), - ] - assert permissions.OwnerPermissions.filter_states(states) == states diff --git a/tests/common.py b/tests/common.py index c6a75fcb63d..d5056e220f0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,7 +14,8 @@ from contextlib import contextmanager from homeassistant import auth, core as ha, config_entries from homeassistant.auth import ( - models as auth_models, auth_store, providers as auth_providers) + models as auth_models, auth_store, providers as auth_providers, + permissions as auth_permissions) from homeassistant.auth.permissions import system_policies from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config @@ -400,6 +401,10 @@ class MockUser(auth_models.User): auth_mgr._store._users[self.id] = self return self + def mock_policy(self, policy): + """Mock a policy for a user.""" + self._permissions = auth_permissions.PolicyPermissions(policy) + async def register_auth_provider(hass, config): """Register an auth provider.""" diff --git a/tests/components/conftest.py b/tests/components/conftest.py index b519b8e936d..46d75a56ad6 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -80,11 +80,10 @@ def hass_ws_client(aiohttp_client): @pytest.fixture -def hass_access_token(hass): +def hass_access_token(hass, hass_admin_user): """Return an access token to access Home Assistant.""" - user = MockUser().add_to_hass(hass) refresh_token = hass.loop.run_until_complete( - hass.auth.async_create_refresh_token(user, CLIENT_ID)) + hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)) yield hass.auth.async_create_access_token(refresh_token) diff --git a/tests/components/hassio/test_init.py b/tests/components/hassio/test_init.py index 4fd59dd3f7a..51fca931faa 100644 --- a/tests/components/hassio/test_init.py +++ b/tests/components/hassio/test_init.py @@ -5,6 +5,7 @@ from unittest.mock import patch, Mock import pytest +from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.setup import async_setup_component from homeassistant.components.hassio import ( STORAGE_KEY, async_check_config) @@ -106,6 +107,8 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, ) assert hassio_user is not None assert hassio_user.system_generated + assert len(hassio_user.groups) == 1 + assert hassio_user.groups[0].id == GROUP_ID_ADMIN for token in hassio_user.refresh_tokens.values(): if token.token == refresh_token: break @@ -113,6 +116,31 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, assert False, 'refresh token not found' +async def test_setup_adds_admin_group_to_user(hass, aioclient_mock, + hass_storage): + """Test setup with API push default data.""" + # Create user without admin + user = await hass.auth.async_create_system_user('Hass.io') + assert not user.is_admin + await hass.auth.async_create_refresh_token(user) + + hass_storage[STORAGE_KEY] = { + 'data': {'hassio_user': user.id}, + 'key': STORAGE_KEY, + 'version': 1 + } + + with patch.dict(os.environ, MOCK_ENVIRON), \ + patch('homeassistant.auth.AuthManager.active', return_value=True): + result = await async_setup_component(hass, 'hassio', { + 'http': {}, + 'hassio': {} + }) + assert result + + assert user.is_admin + + async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock, hass_storage): """Test setup with API push default data.""" diff --git a/tests/components/test_api.py b/tests/components/test_api.py index 6f6b4e93068..3ebfa05a3d3 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -16,10 +16,12 @@ from tests.common import async_mock_service @pytest.fixture -def mock_api_client(hass, aiohttp_client): - """Start the Hass HTTP component.""" +def mock_api_client(hass, aiohttp_client, hass_access_token): + """Start the Hass HTTP component and return admin API client.""" hass.loop.run_until_complete(async_setup_component(hass, 'api', {})) - return hass.loop.run_until_complete(aiohttp_client(hass.http.app)) + return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={ + 'Authorization': 'Bearer {}'.format(hass_access_token) + })) @asyncio.coroutine @@ -405,7 +407,8 @@ def _listen_count(hass): return sum(hass.bus.async_listeners().values()) -async def test_api_error_log(hass, aiohttp_client): +async def test_api_error_log(hass, aiohttp_client, hass_access_token, + hass_admin_user): """Test if we can fetch the error log.""" hass.data[DATA_LOGGING] = '/some/path' await async_setup_component(hass, 'api', { @@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client): client = await aiohttp_client(hass.http.app) resp = await client.get(const.URL_API_ERROR_LOG) - # Verufy auth required + # Verify auth required assert resp.status == 401 with patch( @@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client): return_value=web.Response(status=200, text='Hello') ) as mock_file: resp = await client.get(const.URL_API_ERROR_LOG, headers={ - 'x-ha-access': 'yolo' + 'Authorization': 'Bearer {}'.format(hass_access_token) }) assert len(mock_file.mock_calls) == 1 @@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client): assert resp.status == 200 assert await resp.text() == 'Hello' + # Verify we require admin user + hass_admin_user.groups = [] + resp = await client.get(const.URL_API_ERROR_LOG, headers={ + 'Authorization': 'Bearer {}'.format(hass_access_token) + }) + assert resp.status == 401 + async def test_api_fire_event_context(hass, mock_api_client, hass_access_token): @@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token): state = hass.states.get('light.kitchen') assert state.context.user_id == refresh_token.user.id + + +async def test_event_stream_requires_admin(hass, mock_api_client, + hass_admin_user): + """Test user needs to be admin to access event stream.""" + hass_admin_user.groups = [] + resp = await mock_api_client.get('/api/stream') + assert resp.status == 401 + + +async def test_states_view_filters(hass, mock_api_client, hass_admin_user): + """Test filtering only visible states.""" + hass_admin_user.mock_policy({ + 'entities': { + 'entity_ids': { + 'test.entity': True + } + } + }) + hass.states.async_set('test.entity', 'hello') + hass.states.async_set('test.not_visible_entity', 'invisible') + resp = await mock_api_client.get(const.URL_API_STATES) + assert resp.status == 200 + json = await resp.json() + assert len(json) == 1 + assert json[0]['entity_id'] == 'test.entity' + + +async def test_get_entity_state_read_perm(hass, mock_api_client, + hass_admin_user): + """Test getting a state requires read permission.""" + hass_admin_user.mock_policy({}) + resp = await mock_api_client.get('/api/states/light.test') + assert resp.status == 401 + + +async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user): + """Test updating state requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/states/light.test') + assert resp.status == 401 + + +async def test_delete_entity_state_admin(hass, mock_api_client, + hass_admin_user): + """Test deleting entity requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.delete('/api/states/light.test') + assert resp.status == 401 + + +async def test_post_event_admin(hass, mock_api_client, hass_admin_user): + """Test sending event requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/events/state_changed') + assert resp.status == 401 + + +async def test_rendering_template_admin(hass, mock_api_client, + hass_admin_user): + """Test rendering a template requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/template') + assert resp.status == 401