Add permission checks to Rest API (#18639)

* Add permission checks to Rest API

* Clean up unnecessary method

* Remove all the tuple stuff from entity check

* Simplify perms

* Correct param name for owner permission

* Hass.io make/update user to be admin

* Types
This commit is contained in:
Paulus Schoutsen 2018-11-25 18:04:48 +01:00 committed by Pascal Vizeli
parent f387cdec59
commit 8b8629a5f4
15 changed files with 282 additions and 145 deletions

View file

@ -132,13 +132,15 @@ class AuthManager:
return None return None
async def async_create_system_user(self, name: str) -> models.User: async def async_create_system_user(
self, name: str,
group_ids: Optional[List[str]] = None) -> models.User:
"""Create a system user.""" """Create a system user."""
user = await self._store.async_create_user( user = await self._store.async_create_user(
name=name, name=name,
system_generated=True, system_generated=True,
is_active=True, is_active=True,
group_ids=[], group_ids=group_ids or [],
) )
self.hass.bus.async_fire(EVENT_USER_ADDED, { self.hass.bus.async_fire(EVENT_USER_ADDED, {
@ -217,6 +219,17 @@ class AuthManager:
'user_id': user.id 'user_id': user.id
}) })
async def async_update_user(self, user: models.User,
name: Optional[str] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
kwargs = {} # type: Dict[str,Any]
if name is not None:
kwargs['name'] = name
if group_ids is not None:
kwargs['group_ids'] = group_ids
await self._store.async_update_user(user, **kwargs)
async def async_activate_user(self, user: models.User) -> None: async def async_activate_user(self, user: models.User) -> None:
"""Activate a user.""" """Activate a user."""
await self._store.async_activate_user(user) await self._store.async_activate_user(user)

View file

@ -133,6 +133,33 @@ class AuthStore:
self._users.pop(user.id) self._users.pop(user.id)
self._async_schedule_save() self._async_schedule_save()
async def async_update_user(
self, user: models.User, name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
assert self._groups is not None
if group_ids is not None:
groups = []
for grid in group_ids:
group = self._groups.get(grid)
if group is None:
raise ValueError("Invalid group specified.")
groups.append(group)
user.groups = groups
user.invalidate_permission_cache()
for attr_name, value in (
('name', name),
('is_active', is_active),
):
if value is not None:
setattr(user, attr_name, value)
self._async_schedule_save()
async def async_activate_user(self, user: models.User) -> None: async def async_activate_user(self, user: models.User) -> None:
"""Activate a user.""" """Activate a user."""
user.is_active = True user.is_active = True

View file

@ -8,6 +8,7 @@ import attr
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN
from .util import generate_secret from .util import generate_secret
TOKEN_TYPE_NORMAL = 'normal' TOKEN_TYPE_NORMAL = 'normal'
@ -48,7 +49,7 @@ class User:
) # type: Dict[str, RefreshToken] ) # type: Dict[str, RefreshToken]
_permissions = attr.ib( _permissions = attr.ib(
type=perm_mdl.PolicyPermissions, type=Optional[perm_mdl.PolicyPermissions],
init=False, init=False,
cmp=False, cmp=False,
default=None, default=None,
@ -69,6 +70,19 @@ class User:
return self._permissions return self._permissions
@property
def is_admin(self) -> bool:
"""Return if user is part of the admin group."""
if self.is_owner:
return True
return self.is_active and any(
gr.id == GROUP_ID_ADMIN for gr in self.groups)
def invalidate_permission_cache(self) -> None:
"""Invalidate permission cache."""
self._permissions = None
@attr.s(slots=True) @attr.s(slots=True)
class RefreshToken: class RefreshToken:

View file

@ -5,10 +5,8 @@ from typing import ( # noqa: F401
import voluptuous as vol import voluptuous as vol
from homeassistant.core import State
from .const import CAT_ENTITIES from .const import CAT_ENTITIES
from .types import CategoryType, PolicyType from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa from .merge import merge_policies # noqa
@ -22,13 +20,20 @@ _LOGGER = logging.getLogger(__name__)
class AbstractPermissions: class AbstractPermissions:
"""Default permissions class.""" """Default permissions class."""
def check_entity(self, entity_id: str, key: str) -> bool: _cached_entity_func = None
"""Test if we can access entity."""
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
raise NotImplementedError raise NotImplementedError
def filter_states(self, states: List[State]) -> List[State]: def check_entity(self, entity_id: str, key: str) -> bool:
"""Filter a list of states for what the user is allowed to see.""" """Check if we can access entity."""
raise NotImplementedError entity_func = self._cached_entity_func
if entity_func is None:
entity_func = self._cached_entity_func = self._entity_func()
return entity_func(entity_id, key)
class PolicyPermissions(AbstractPermissions): class PolicyPermissions(AbstractPermissions):
@ -37,34 +42,10 @@ class PolicyPermissions(AbstractPermissions):
def __init__(self, policy: PolicyType) -> None: def __init__(self, policy: PolicyType) -> None:
"""Initialize the permission class.""" """Initialize the permission class."""
self._policy = policy self._policy = policy
self._compiled = {} # type: Dict[str, Callable[..., bool]]
def check_entity(self, entity_id: str, key: str) -> bool: def _entity_func(self) -> Callable[[str, str], bool]:
"""Test if we can access entity.""" """Return a function that can test entity access."""
func = self._policy_func(CAT_ENTITIES, compile_entities) return compile_entities(self._policy.get(CAT_ENTITIES))
return func(entity_id, (key,))
def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
keys = ('read',)
return [entity for entity in states if func(entity.entity_id, keys)]
def _policy_func(self, category: str,
compile_func: Callable[[CategoryType], Callable]) \
-> Callable[..., bool]:
"""Get a policy function."""
func = self._compiled.get(category)
if func:
return func
func = self._compiled[category] = compile_func(
self._policy.get(category))
_LOGGER.debug("Compiled %s func: %s", category, func)
return func
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Equals check.""" """Equals check."""
@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions):
# pylint: disable=no-self-use # pylint: disable=no-self-use
def check_entity(self, entity_id: str, key: str) -> bool: def _entity_func(self) -> Callable[[str, str], bool]:
"""Test if we can access entity.""" """Return a function that can test entity access."""
return True return lambda entity_id, key: True
def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
return states
OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name

View file

@ -28,28 +28,28 @@ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
})) }))
def _entity_allowed(schema: ValueType, keys: Tuple[str]) \ def _entity_allowed(schema: ValueType, key: str) \
-> Union[bool, None]: -> Union[bool, None]:
"""Test if an entity is allowed based on the keys.""" """Test if an entity is allowed based on the keys."""
if schema is None or isinstance(schema, bool): if schema is None or isinstance(schema, bool):
return schema return schema
assert isinstance(schema, dict) assert isinstance(schema, dict)
return schema.get(keys[0]) return schema.get(key)
def compile_entities(policy: CategoryType) \ def compile_entities(policy: CategoryType) \
-> Callable[[str, Tuple[str]], bool]: -> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy.""" """Compile policy into a function that tests policy."""
# None, Empty Dict, False # None, Empty Dict, False
if not policy: if not policy:
def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool: def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all.""" """Decline all."""
return False return False
return apply_policy_deny_all return apply_policy_deny_all
if policy is True: if policy is True:
def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool: def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all.""" """Approve all."""
return True return True
@ -61,7 +61,7 @@ def compile_entities(policy: CategoryType) \
entity_ids = policy.get(ENTITY_ENTITY_IDS) entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL) all_entities = policy.get(SUBCAT_ALL)
funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]] funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
# The order of these functions matter. The more precise are at the top. # The order of these functions matter. The more precise are at the top.
# If a function returns None, they cannot handle it. # If a function returns None, they cannot handle it.
@ -70,23 +70,23 @@ def compile_entities(policy: CategoryType) \
# Setting entity_ids to a boolean is final decision for permissions # Setting entity_ids to a boolean is final decision for permissions
# So return right away. # So return right away.
if isinstance(entity_ids, bool): if isinstance(entity_ids, bool):
def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool: def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
"""Test if allowed entity_id.""" """Test if allowed entity_id."""
return entity_ids # type: ignore return entity_ids # type: ignore
return allowed_entity_id_bool return allowed_entity_id_bool
if entity_ids is not None: if entity_ids is not None:
def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \ def allowed_entity_id_dict(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:
"""Test if allowed entity_id.""" """Test if allowed entity_id."""
return _entity_allowed( return _entity_allowed(
entity_ids.get(entity_id), keys) # type: ignore entity_ids.get(entity_id), key) # type: ignore
funcs.append(allowed_entity_id_dict) funcs.append(allowed_entity_id_dict)
if isinstance(domains, bool): if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \ def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:
"""Test if allowed domain.""" """Test if allowed domain."""
return domains return domains
@ -94,31 +94,31 @@ def compile_entities(policy: CategoryType) \
funcs.append(allowed_domain_bool) funcs.append(allowed_domain_bool)
elif domains is not None: elif domains is not None:
def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \ def allowed_domain_dict(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:
"""Test if allowed domain.""" """Test if allowed domain."""
domain = entity_id.split(".", 1)[0] domain = entity_id.split(".", 1)[0]
return _entity_allowed(domains.get(domain), keys) # type: ignore return _entity_allowed(domains.get(domain), key) # type: ignore
funcs.append(allowed_domain_dict) funcs.append(allowed_domain_dict)
if isinstance(all_entities, bool): if isinstance(all_entities, bool):
def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \ def allowed_all_entities_bool(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:
"""Test if allowed domain.""" """Test if allowed domain."""
return all_entities return all_entities
funcs.append(allowed_all_entities_bool) funcs.append(allowed_all_entities_bool)
elif all_entities is not None: elif all_entities is not None:
def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \ def allowed_all_entities_dict(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:
"""Test if allowed domain.""" """Test if allowed domain."""
return _entity_allowed(all_entities, keys) return _entity_allowed(all_entities, key)
funcs.append(allowed_all_entities_dict) funcs.append(allowed_all_entities_dict)
# Can happen if no valid subcategories specified # Can happen if no valid subcategories specified
if not funcs: if not funcs:
def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool: def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
"""Decline all.""" """Decline all."""
return False return False
@ -128,16 +128,16 @@ def compile_entities(policy: CategoryType) \
func = funcs[0] func = funcs[0]
@wraps(func) @wraps(func)
def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool: def apply_policy_func(entity_id: str, key: str) -> bool:
"""Apply a single policy function.""" """Apply a single policy function."""
return func(entity_id, keys) is True return func(entity_id, key) is True
return apply_policy_func return apply_policy_func
def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool: def apply_policy_funcs(entity_id: str, key: str) -> bool:
"""Apply several policy functions.""" """Apply several policy functions."""
for func in funcs: for func in funcs:
result = func(entity_id, keys) result = func(entity_id, key)
if result is not None: if result is not None:
return result return result
return False return False

View file

@ -20,7 +20,8 @@ from homeassistant.const import (
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM, URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
URL_API_TEMPLATE, __version__) URL_API_TEMPLATE, __version__)
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import TemplateError from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.state import AsyncTrackStates from homeassistant.helpers.state import AsyncTrackStates
@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView):
async def get(self, request): async def get(self, request):
"""Provide a streaming interface for the event bus.""" """Provide a streaming interface for the event bus."""
if not request['hass_user'].is_admin:
raise Unauthorized()
hass = request.app['hass'] hass = request.app['hass']
stop_obj = object() stop_obj = object()
to_write = asyncio.Queue(loop=hass.loop) to_write = asyncio.Queue(loop=hass.loop)
@ -185,7 +188,13 @@ class APIStatesView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request): def get(self, request):
"""Get current states.""" """Get current states."""
return self.json(request.app['hass'].states.async_all()) user = request['hass_user']
entity_perm = user.permissions.check_entity
states = [
state for state in request.app['hass'].states.async_all()
if entity_perm(state.entity_id, 'read')
]
return self.json(states)
class APIEntityStateView(HomeAssistantView): class APIEntityStateView(HomeAssistantView):
@ -197,6 +206,10 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback @ha.callback
def get(self, request, entity_id): def get(self, request, entity_id):
"""Retrieve state of entity.""" """Retrieve state of entity."""
user = request['hass_user']
if not user.permissions.check_entity(entity_id, POLICY_READ):
raise Unauthorized(entity_id=entity_id)
state = request.app['hass'].states.get(entity_id) state = request.app['hass'].states.get(entity_id)
if state: if state:
return self.json(state) return self.json(state)
@ -204,6 +217,8 @@ class APIEntityStateView(HomeAssistantView):
async def post(self, request, entity_id): async def post(self, request, entity_id):
"""Update state of entity.""" """Update state of entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
hass = request.app['hass'] hass = request.app['hass']
try: try:
data = await request.json() data = await request.json()
@ -236,6 +251,8 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback @ha.callback
def delete(self, request, entity_id): def delete(self, request, entity_id):
"""Remove entity.""" """Remove entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
if request.app['hass'].states.async_remove(entity_id): if request.app['hass'].states.async_remove(entity_id):
return self.json_message("Entity removed.") return self.json_message("Entity removed.")
return self.json_message("Entity not found.", HTTP_NOT_FOUND) return self.json_message("Entity not found.", HTTP_NOT_FOUND)
@ -261,6 +278,8 @@ class APIEventView(HomeAssistantView):
async def post(self, request, event_type): async def post(self, request, event_type):
"""Fire events.""" """Fire events."""
if not request['hass_user'].is_admin:
raise Unauthorized()
body = await request.text() body = await request.text()
try: try:
event_data = json.loads(body) if body else None event_data = json.loads(body) if body else None
@ -346,6 +365,8 @@ class APITemplateView(HomeAssistantView):
async def post(self, request): async def post(self, request):
"""Render a template.""" """Render a template."""
if not request['hass_user'].is_admin:
raise Unauthorized()
try: try:
data = await request.json() data = await request.json()
tpl = template.Template(data['template'], request.app['hass']) tpl = template.Template(data['template'], request.app['hass'])
@ -363,6 +384,8 @@ class APIErrorLog(HomeAssistantView):
async def get(self, request): async def get(self, request):
"""Retrieve API error log.""" """Retrieve API error log."""
if not request['hass_user'].is_admin:
raise Unauthorized()
return web.FileResponse(request.app['hass'].data[DATA_LOGGING]) return web.FileResponse(request.app['hass'].data[DATA_LOGGING])

View file

@ -10,6 +10,7 @@ import os
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.components import SERVICE_CHECK_CONFIG from homeassistant.components import SERVICE_CHECK_CONFIG
from homeassistant.const import ( from homeassistant.const import (
ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP) ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP)
@ -181,8 +182,14 @@ async def async_setup(hass, config):
if user and user.refresh_tokens: if user and user.refresh_tokens:
refresh_token = list(user.refresh_tokens.values())[0] refresh_token = list(user.refresh_tokens.values())[0]
# Migrate old hass.io users to be admin.
if not user.is_admin:
await hass.auth.async_update_user(
user, group_ids=[GROUP_ID_ADMIN])
if refresh_token is None: if refresh_token is None:
user = await hass.auth.async_create_system_user('Hass.io') user = await hass.auth.async_create_system_user(
'Hass.io', [GROUP_ID_ADMIN])
refresh_token = await hass.auth.async_create_refresh_token(user) refresh_token = await hass.auth.async_create_refresh_token(user)
data['hassio_user'] = user.id data['hassio_user'] = user.id
await store.async_save(data) await store.async_save(data)

View file

@ -14,6 +14,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
from homeassistant.components.http.ban import process_success_login from homeassistant.components.http.ban import process_success_login
from homeassistant.core import Context, is_callback from homeassistant.core import Context, is_callback
from homeassistant.const import CONTENT_TYPE_JSON from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant import exceptions
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from .const import KEY_AUTHENTICATED, KEY_REAL_IP from .const import KEY_AUTHENTICATED, KEY_REAL_IP
@ -107,10 +108,13 @@ def request_handler_factory(view, handler):
_LOGGER.info('Serving %s to %s (auth: %s)', _LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated) request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info) try:
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
result = await result result = await result
except exceptions.Unauthorized:
raise HTTPUnauthorized()
if isinstance(result, web.StreamResponse): if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it # The method handler returned a ready-made Response, how nice of it

View file

@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call):
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None:
raise UnknownUser(context=call.context) raise UnknownUser(context=call.context)
perms = user.permissions entity_perms = user.permissions.check_entity
else: else:
perms = None entity_perms = None
# Are we trying to target all entities # Are we trying to target all entities
target_all_entities = ATTR_ENTITY_ID not in call.data target_all_entities = ATTR_ENTITY_ID not in call.data
@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call):
# the service on. # the service on.
platforms_entities = [] platforms_entities = []
if perms is None: if entity_perms is None:
for platform in platforms: for platform in platforms:
if target_all_entities: if target_all_entities:
platforms_entities.append(list(platform.entities.values())) platforms_entities.append(list(platform.entities.values()))
@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call):
for platform in platforms: for platform in platforms:
platforms_entities.append([ platforms_entities.append([
entity for entity in platform.entities.values() entity for entity in platform.entities.values()
if perms.check_entity(entity.entity_id, POLICY_CONTROL)]) if entity_perms(entity.entity_id, POLICY_CONTROL)])
else: else:
for platform in platforms: for platform in platforms:
@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call):
if entity.entity_id not in entity_ids: if entity.entity_id not in entity_ids:
continue continue
if not perms.check_entity(entity.entity_id, POLICY_CONTROL): if not entity_perms(entity.entity_id, POLICY_CONTROL):
raise Unauthorized( raise Unauthorized(
context=call.context, context=call.context,
entity_id=entity.entity_id, entity_id=entity.entity_id,

View file

@ -10,7 +10,7 @@ def test_entities_none():
"""Test entity ID policy.""" """Test entity ID policy."""
policy = None policy = None
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False assert compiled('light.kitchen', 'read') is False
def test_entities_empty(): def test_entities_empty():
@ -18,7 +18,7 @@ def test_entities_empty():
policy = {} policy = {}
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False assert compiled('light.kitchen', 'read') is False
def test_entities_false(): def test_entities_false():
@ -33,7 +33,7 @@ def test_entities_true():
policy = True policy = True
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
def test_entities_domains_true(): def test_entities_domains_true():
@ -43,7 +43,7 @@ def test_entities_domains_true():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
def test_entities_domains_domain_true(): def test_entities_domains_domain_true():
@ -55,8 +55,8 @@ def test_entities_domains_domain_true():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', ('read',)) is False assert compiled('switch.kitchen', 'read') is False
def test_entities_domains_domain_false(): def test_entities_domains_domain_false():
@ -77,7 +77,7 @@ def test_entities_entity_ids_true():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
def test_entities_entity_ids_false(): def test_entities_entity_ids_false():
@ -98,8 +98,8 @@ def test_entities_entity_ids_entity_id_true():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', ('read',)) is False assert compiled('switch.kitchen', 'read') is False
def test_entities_entity_ids_entity_id_false(): def test_entities_entity_ids_entity_id_false():
@ -124,9 +124,9 @@ def test_entities_control_only():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', ('control',)) is False assert compiled('light.kitchen', 'control') is False
assert compiled('light.kitchen', ('edit',)) is False assert compiled('light.kitchen', 'edit') is False
def test_entities_read_control(): def test_entities_read_control():
@ -141,9 +141,9 @@ def test_entities_read_control():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', ('control',)) is True assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', ('edit',)) is False assert compiled('light.kitchen', 'edit') is False
def test_entities_all_allow(): def test_entities_all_allow():
@ -153,9 +153,9 @@ def test_entities_all_allow():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', ('control',)) is True assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', ('read',)) is True assert compiled('switch.kitchen', 'read') is True
def test_entities_all_read(): def test_entities_all_read():
@ -167,9 +167,9 @@ def test_entities_all_read():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', ('control',)) is False assert compiled('light.kitchen', 'control') is False
assert compiled('switch.kitchen', ('read',)) is True assert compiled('switch.kitchen', 'read') is True
def test_entities_all_control(): def test_entities_all_control():
@ -181,7 +181,7 @@ def test_entities_all_control():
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False assert compiled('light.kitchen', 'read') is False
assert compiled('light.kitchen', ('control',)) is True assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', ('read',)) is False assert compiled('switch.kitchen', 'read') is False
assert compiled('switch.kitchen', ('control',)) is True assert compiled('switch.kitchen', 'control') is True

View file

@ -1,34 +0,0 @@
"""Tests for the auth permission system."""
from homeassistant.core import State
from homeassistant.auth import permissions
def test_policy_perm_filter_states():
"""Test filtering entitites."""
states = [
State('light.kitchen', 'on'),
State('light.living_room', 'off'),
State('light.balcony', 'on'),
]
perm = permissions.PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True,
'light.balcony': True,
}
}
})
filtered = perm.filter_states(states)
assert len(filtered) == 2
assert filtered == [states[0], states[2]]
def test_owner_permissions():
"""Test owner permissions access all."""
assert permissions.OwnerPermissions.check_entity('light.kitchen', 'write')
states = [
State('light.kitchen', 'on'),
State('light.living_room', 'off'),
State('light.balcony', 'on'),
]
assert permissions.OwnerPermissions.filter_states(states) == states

View file

@ -14,7 +14,8 @@ from contextlib import contextmanager
from homeassistant import auth, core as ha, config_entries from homeassistant import auth, core as ha, config_entries
from homeassistant.auth import ( from homeassistant.auth import (
models as auth_models, auth_store, providers as auth_providers) models as auth_models, auth_store, providers as auth_providers,
permissions as auth_permissions)
from homeassistant.auth.permissions import system_policies from homeassistant.auth.permissions import system_policies
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
@ -400,6 +401,10 @@ class MockUser(auth_models.User):
auth_mgr._store._users[self.id] = self auth_mgr._store._users[self.id] = self
return self return self
def mock_policy(self, policy):
"""Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy)
async def register_auth_provider(hass, config): async def register_auth_provider(hass, config):
"""Register an auth provider.""" """Register an auth provider."""

View file

@ -80,11 +80,10 @@ def hass_ws_client(aiohttp_client):
@pytest.fixture @pytest.fixture
def hass_access_token(hass): def hass_access_token(hass, hass_admin_user):
"""Return an access token to access Home Assistant.""" """Return an access token to access Home Assistant."""
user = MockUser().add_to_hass(hass)
refresh_token = hass.loop.run_until_complete( refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, CLIENT_ID)) hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID))
yield hass.auth.async_create_access_token(refresh_token) yield hass.auth.async_create_access_token(refresh_token)

View file

@ -5,6 +5,7 @@ from unittest.mock import patch, Mock
import pytest import pytest
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components.hassio import ( from homeassistant.components.hassio import (
STORAGE_KEY, async_check_config) STORAGE_KEY, async_check_config)
@ -106,6 +107,8 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
) )
assert hassio_user is not None assert hassio_user is not None
assert hassio_user.system_generated assert hassio_user.system_generated
assert len(hassio_user.groups) == 1
assert hassio_user.groups[0].id == GROUP_ID_ADMIN
for token in hassio_user.refresh_tokens.values(): for token in hassio_user.refresh_tokens.values():
if token.token == refresh_token: if token.token == refresh_token:
break break
@ -113,6 +116,31 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
assert False, 'refresh token not found' assert False, 'refresh token not found'
async def test_setup_adds_admin_group_to_user(hass, aioclient_mock,
hass_storage):
"""Test setup with API push default data."""
# Create user without admin
user = await hass.auth.async_create_system_user('Hass.io')
assert not user.is_admin
await hass.auth.async_create_refresh_token(user)
hass_storage[STORAGE_KEY] = {
'data': {'hassio_user': user.id},
'key': STORAGE_KEY,
'version': 1
}
with patch.dict(os.environ, MOCK_ENVIRON), \
patch('homeassistant.auth.AuthManager.active', return_value=True):
result = await async_setup_component(hass, 'hassio', {
'http': {},
'hassio': {}
})
assert result
assert user.is_admin
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock, async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
hass_storage): hass_storage):
"""Test setup with API push default data.""" """Test setup with API push default data."""

View file

@ -16,10 +16,12 @@ from tests.common import async_mock_service
@pytest.fixture @pytest.fixture
def mock_api_client(hass, aiohttp_client): def mock_api_client(hass, aiohttp_client, hass_access_token):
"""Start the Hass HTTP component.""" """Start the Hass HTTP component and return admin API client."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {})) hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
return hass.loop.run_until_complete(aiohttp_client(hass.http.app)) return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
}))
@asyncio.coroutine @asyncio.coroutine
@ -405,7 +407,8 @@ def _listen_count(hass):
return sum(hass.bus.async_listeners().values()) return sum(hass.bus.async_listeners().values())
async def test_api_error_log(hass, aiohttp_client): async def test_api_error_log(hass, aiohttp_client, hass_access_token,
hass_admin_user):
"""Test if we can fetch the error log.""" """Test if we can fetch the error log."""
hass.data[DATA_LOGGING] = '/some/path' hass.data[DATA_LOGGING] = '/some/path'
await async_setup_component(hass, 'api', { await async_setup_component(hass, 'api', {
@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
resp = await client.get(const.URL_API_ERROR_LOG) resp = await client.get(const.URL_API_ERROR_LOG)
# Verufy auth required # Verify auth required
assert resp.status == 401 assert resp.status == 401
with patch( with patch(
@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client):
return_value=web.Response(status=200, text='Hello') return_value=web.Response(status=200, text='Hello')
) as mock_file: ) as mock_file:
resp = await client.get(const.URL_API_ERROR_LOG, headers={ resp = await client.get(const.URL_API_ERROR_LOG, headers={
'x-ha-access': 'yolo' 'Authorization': 'Bearer {}'.format(hass_access_token)
}) })
assert len(mock_file.mock_calls) == 1 assert len(mock_file.mock_calls) == 1
@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client):
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == 'Hello' assert await resp.text() == 'Hello'
# Verify we require admin user
hass_admin_user.groups = []
resp = await client.get(const.URL_API_ERROR_LOG, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
})
assert resp.status == 401
async def test_api_fire_event_context(hass, mock_api_client, async def test_api_fire_event_context(hass, mock_api_client,
hass_access_token): hass_access_token):
@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
state = hass.states.get('light.kitchen') state = hass.states.get('light.kitchen')
assert state.context.user_id == refresh_token.user.id assert state.context.user_id == refresh_token.user.id
async def test_event_stream_requires_admin(hass, mock_api_client,
hass_admin_user):
"""Test user needs to be admin to access event stream."""
hass_admin_user.groups = []
resp = await mock_api_client.get('/api/stream')
assert resp.status == 401
async def test_states_view_filters(hass, mock_api_client, hass_admin_user):
"""Test filtering only visible states."""
hass_admin_user.mock_policy({
'entities': {
'entity_ids': {
'test.entity': True
}
}
})
hass.states.async_set('test.entity', 'hello')
hass.states.async_set('test.not_visible_entity', 'invisible')
resp = await mock_api_client.get(const.URL_API_STATES)
assert resp.status == 200
json = await resp.json()
assert len(json) == 1
assert json[0]['entity_id'] == 'test.entity'
async def test_get_entity_state_read_perm(hass, mock_api_client,
hass_admin_user):
"""Test getting a state requires read permission."""
hass_admin_user.mock_policy({})
resp = await mock_api_client.get('/api/states/light.test')
assert resp.status == 401
async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user):
"""Test updating state requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/states/light.test')
assert resp.status == 401
async def test_delete_entity_state_admin(hass, mock_api_client,
hass_admin_user):
"""Test deleting entity requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.delete('/api/states/light.test')
assert resp.status == 401
async def test_post_event_admin(hass, mock_api_client, hass_admin_user):
"""Test sending event requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/events/state_changed')
assert resp.status == 401
async def test_rendering_template_admin(hass, mock_api_client,
hass_admin_user):
"""Test rendering a template requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/template')
assert resp.status == 401