Add permission checks to Rest API (#18639)
* Add permission checks to Rest API * Clean up unnecessary method * Remove all the tuple stuff from entity check * Simplify perms * Correct param name for owner permission * Hass.io make/update user to be admin * Types
This commit is contained in:
parent
f387cdec59
commit
8b8629a5f4
15 changed files with 282 additions and 145 deletions
|
@ -132,13 +132,15 @@ class AuthManager:
|
|||
|
||||
return None
|
||||
|
||||
async def async_create_system_user(self, name: str) -> models.User:
|
||||
async def async_create_system_user(
|
||||
self, name: str,
|
||||
group_ids: Optional[List[str]] = None) -> models.User:
|
||||
"""Create a system user."""
|
||||
user = await self._store.async_create_user(
|
||||
name=name,
|
||||
system_generated=True,
|
||||
is_active=True,
|
||||
group_ids=[],
|
||||
group_ids=group_ids or [],
|
||||
)
|
||||
|
||||
self.hass.bus.async_fire(EVENT_USER_ADDED, {
|
||||
|
@ -217,6 +219,17 @@ class AuthManager:
|
|||
'user_id': user.id
|
||||
})
|
||||
|
||||
async def async_update_user(self, user: models.User,
|
||||
name: Optional[str] = None,
|
||||
group_ids: Optional[List[str]] = None) -> None:
|
||||
"""Update a user."""
|
||||
kwargs = {} # type: Dict[str,Any]
|
||||
if name is not None:
|
||||
kwargs['name'] = name
|
||||
if group_ids is not None:
|
||||
kwargs['group_ids'] = group_ids
|
||||
await self._store.async_update_user(user, **kwargs)
|
||||
|
||||
async def async_activate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
await self._store.async_activate_user(user)
|
||||
|
|
|
@ -133,6 +133,33 @@ class AuthStore:
|
|||
self._users.pop(user.id)
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_update_user(
|
||||
self, user: models.User, name: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
group_ids: Optional[List[str]] = None) -> None:
|
||||
"""Update a user."""
|
||||
assert self._groups is not None
|
||||
|
||||
if group_ids is not None:
|
||||
groups = []
|
||||
for grid in group_ids:
|
||||
group = self._groups.get(grid)
|
||||
if group is None:
|
||||
raise ValueError("Invalid group specified.")
|
||||
groups.append(group)
|
||||
|
||||
user.groups = groups
|
||||
user.invalidate_permission_cache()
|
||||
|
||||
for attr_name, value in (
|
||||
('name', name),
|
||||
('is_active', is_active),
|
||||
):
|
||||
if value is not None:
|
||||
setattr(user, attr_name, value)
|
||||
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_activate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
user.is_active = True
|
||||
|
|
|
@ -8,6 +8,7 @@ import attr
|
|||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import permissions as perm_mdl
|
||||
from .const import GROUP_ID_ADMIN
|
||||
from .util import generate_secret
|
||||
|
||||
TOKEN_TYPE_NORMAL = 'normal'
|
||||
|
@ -48,7 +49,7 @@ class User:
|
|||
) # type: Dict[str, RefreshToken]
|
||||
|
||||
_permissions = attr.ib(
|
||||
type=perm_mdl.PolicyPermissions,
|
||||
type=Optional[perm_mdl.PolicyPermissions],
|
||||
init=False,
|
||||
cmp=False,
|
||||
default=None,
|
||||
|
@ -69,6 +70,19 @@ class User:
|
|||
|
||||
return self._permissions
|
||||
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
"""Return if user is part of the admin group."""
|
||||
if self.is_owner:
|
||||
return True
|
||||
|
||||
return self.is_active and any(
|
||||
gr.id == GROUP_ID_ADMIN for gr in self.groups)
|
||||
|
||||
def invalidate_permission_cache(self) -> None:
|
||||
"""Invalidate permission cache."""
|
||||
self._permissions = None
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class RefreshToken:
|
||||
|
|
|
@ -5,10 +5,8 @@ from typing import ( # noqa: F401
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import State
|
||||
|
||||
from .const import CAT_ENTITIES
|
||||
from .types import CategoryType, PolicyType
|
||||
from .types import PolicyType
|
||||
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
|
||||
from .merge import merge_policies # noqa
|
||||
|
||||
|
@ -22,13 +20,20 @@ _LOGGER = logging.getLogger(__name__)
|
|||
class AbstractPermissions:
|
||||
"""Default permissions class."""
|
||||
|
||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||
"""Test if we can access entity."""
|
||||
_cached_entity_func = None
|
||||
|
||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||
"""Return a function that can test entity access."""
|
||||
raise NotImplementedError
|
||||
|
||||
def filter_states(self, states: List[State]) -> List[State]:
|
||||
"""Filter a list of states for what the user is allowed to see."""
|
||||
raise NotImplementedError
|
||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||
"""Check if we can access entity."""
|
||||
entity_func = self._cached_entity_func
|
||||
|
||||
if entity_func is None:
|
||||
entity_func = self._cached_entity_func = self._entity_func()
|
||||
|
||||
return entity_func(entity_id, key)
|
||||
|
||||
|
||||
class PolicyPermissions(AbstractPermissions):
|
||||
|
@ -37,34 +42,10 @@ class PolicyPermissions(AbstractPermissions):
|
|||
def __init__(self, policy: PolicyType) -> None:
|
||||
"""Initialize the permission class."""
|
||||
self._policy = policy
|
||||
self._compiled = {} # type: Dict[str, Callable[..., bool]]
|
||||
|
||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||
"""Test if we can access entity."""
|
||||
func = self._policy_func(CAT_ENTITIES, compile_entities)
|
||||
return func(entity_id, (key,))
|
||||
|
||||
def filter_states(self, states: List[State]) -> List[State]:
|
||||
"""Filter a list of states for what the user is allowed to see."""
|
||||
func = self._policy_func(CAT_ENTITIES, compile_entities)
|
||||
keys = ('read',)
|
||||
return [entity for entity in states if func(entity.entity_id, keys)]
|
||||
|
||||
def _policy_func(self, category: str,
|
||||
compile_func: Callable[[CategoryType], Callable]) \
|
||||
-> Callable[..., bool]:
|
||||
"""Get a policy function."""
|
||||
func = self._compiled.get(category)
|
||||
|
||||
if func:
|
||||
return func
|
||||
|
||||
func = self._compiled[category] = compile_func(
|
||||
self._policy.get(category))
|
||||
|
||||
_LOGGER.debug("Compiled %s func: %s", category, func)
|
||||
|
||||
return func
|
||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||
"""Return a function that can test entity access."""
|
||||
return compile_entities(self._policy.get(CAT_ENTITIES))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Equals check."""
|
||||
|
@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions):
|
|||
|
||||
# pylint: disable=no-self-use
|
||||
|
||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||
"""Test if we can access entity."""
|
||||
return True
|
||||
|
||||
def filter_states(self, states: List[State]) -> List[State]:
|
||||
"""Filter a list of states for what the user is allowed to see."""
|
||||
return states
|
||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||
"""Return a function that can test entity access."""
|
||||
return lambda entity_id, key: True
|
||||
|
||||
|
||||
OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name
|
||||
|
|
|
@ -28,28 +28,28 @@ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
|
|||
}))
|
||||
|
||||
|
||||
def _entity_allowed(schema: ValueType, keys: Tuple[str]) \
|
||||
def _entity_allowed(schema: ValueType, key: str) \
|
||||
-> Union[bool, None]:
|
||||
"""Test if an entity is allowed based on the keys."""
|
||||
if schema is None or isinstance(schema, bool):
|
||||
return schema
|
||||
assert isinstance(schema, dict)
|
||||
return schema.get(keys[0])
|
||||
return schema.get(key)
|
||||
|
||||
|
||||
def compile_entities(policy: CategoryType) \
|
||||
-> Callable[[str, Tuple[str]], bool]:
|
||||
-> Callable[[str, str], bool]:
|
||||
"""Compile policy into a function that tests policy."""
|
||||
# None, Empty Dict, False
|
||||
if not policy:
|
||||
def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
|
||||
"""Decline all."""
|
||||
return False
|
||||
|
||||
return apply_policy_deny_all
|
||||
|
||||
if policy is True:
|
||||
def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
|
||||
"""Approve all."""
|
||||
return True
|
||||
|
||||
|
@ -61,7 +61,7 @@ def compile_entities(policy: CategoryType) \
|
|||
entity_ids = policy.get(ENTITY_ENTITY_IDS)
|
||||
all_entities = policy.get(SUBCAT_ALL)
|
||||
|
||||
funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]]
|
||||
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
|
||||
|
||||
# The order of these functions matter. The more precise are at the top.
|
||||
# If a function returns None, they cannot handle it.
|
||||
|
@ -70,23 +70,23 @@ def compile_entities(policy: CategoryType) \
|
|||
# Setting entity_ids to a boolean is final decision for permissions
|
||||
# So return right away.
|
||||
if isinstance(entity_ids, bool):
|
||||
def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
|
||||
"""Test if allowed entity_id."""
|
||||
return entity_ids # type: ignore
|
||||
|
||||
return allowed_entity_id_bool
|
||||
|
||||
if entity_ids is not None:
|
||||
def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \
|
||||
def allowed_entity_id_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed entity_id."""
|
||||
return _entity_allowed(
|
||||
entity_ids.get(entity_id), keys) # type: ignore
|
||||
entity_ids.get(entity_id), key) # type: ignore
|
||||
|
||||
funcs.append(allowed_entity_id_dict)
|
||||
|
||||
if isinstance(domains, bool):
|
||||
def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \
|
||||
def allowed_domain_bool(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return domains
|
||||
|
@ -94,31 +94,31 @@ def compile_entities(policy: CategoryType) \
|
|||
funcs.append(allowed_domain_bool)
|
||||
|
||||
elif domains is not None:
|
||||
def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \
|
||||
def allowed_domain_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
domain = entity_id.split(".", 1)[0]
|
||||
return _entity_allowed(domains.get(domain), keys) # type: ignore
|
||||
return _entity_allowed(domains.get(domain), key) # type: ignore
|
||||
|
||||
funcs.append(allowed_domain_dict)
|
||||
|
||||
if isinstance(all_entities, bool):
|
||||
def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \
|
||||
def allowed_all_entities_bool(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return all_entities
|
||||
funcs.append(allowed_all_entities_bool)
|
||||
|
||||
elif all_entities is not None:
|
||||
def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \
|
||||
def allowed_all_entities_dict(entity_id: str, key: str) \
|
||||
-> Union[None, bool]:
|
||||
"""Test if allowed domain."""
|
||||
return _entity_allowed(all_entities, keys)
|
||||
return _entity_allowed(all_entities, key)
|
||||
funcs.append(allowed_all_entities_dict)
|
||||
|
||||
# Can happen if no valid subcategories specified
|
||||
if not funcs:
|
||||
def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
|
||||
"""Decline all."""
|
||||
return False
|
||||
|
||||
|
@ -128,16 +128,16 @@ def compile_entities(policy: CategoryType) \
|
|||
func = funcs[0]
|
||||
|
||||
@wraps(func)
|
||||
def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def apply_policy_func(entity_id: str, key: str) -> bool:
|
||||
"""Apply a single policy function."""
|
||||
return func(entity_id, keys) is True
|
||||
return func(entity_id, key) is True
|
||||
|
||||
return apply_policy_func
|
||||
|
||||
def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool:
|
||||
def apply_policy_funcs(entity_id: str, key: str) -> bool:
|
||||
"""Apply several policy functions."""
|
||||
for func in funcs:
|
||||
result = func(entity_id, keys)
|
||||
result = func(entity_id, key)
|
||||
if result is not None:
|
||||
return result
|
||||
return False
|
||||
|
|
|
@ -20,7 +20,8 @@ from homeassistant.const import (
|
|||
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
|
||||
URL_API_TEMPLATE, __version__)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.auth.permissions.const import POLICY_READ
|
||||
from homeassistant.exceptions import TemplateError, Unauthorized
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.helpers.state import AsyncTrackStates
|
||||
|
@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView):
|
|||
|
||||
async def get(self, request):
|
||||
"""Provide a streaming interface for the event bus."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized()
|
||||
hass = request.app['hass']
|
||||
stop_obj = object()
|
||||
to_write = asyncio.Queue(loop=hass.loop)
|
||||
|
@ -185,7 +188,13 @@ class APIStatesView(HomeAssistantView):
|
|||
@ha.callback
|
||||
def get(self, request):
|
||||
"""Get current states."""
|
||||
return self.json(request.app['hass'].states.async_all())
|
||||
user = request['hass_user']
|
||||
entity_perm = user.permissions.check_entity
|
||||
states = [
|
||||
state for state in request.app['hass'].states.async_all()
|
||||
if entity_perm(state.entity_id, 'read')
|
||||
]
|
||||
return self.json(states)
|
||||
|
||||
|
||||
class APIEntityStateView(HomeAssistantView):
|
||||
|
@ -197,6 +206,10 @@ class APIEntityStateView(HomeAssistantView):
|
|||
@ha.callback
|
||||
def get(self, request, entity_id):
|
||||
"""Retrieve state of entity."""
|
||||
user = request['hass_user']
|
||||
if not user.permissions.check_entity(entity_id, POLICY_READ):
|
||||
raise Unauthorized(entity_id=entity_id)
|
||||
|
||||
state = request.app['hass'].states.get(entity_id)
|
||||
if state:
|
||||
return self.json(state)
|
||||
|
@ -204,6 +217,8 @@ class APIEntityStateView(HomeAssistantView):
|
|||
|
||||
async def post(self, request, entity_id):
|
||||
"""Update state of entity."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized(entity_id=entity_id)
|
||||
hass = request.app['hass']
|
||||
try:
|
||||
data = await request.json()
|
||||
|
@ -236,6 +251,8 @@ class APIEntityStateView(HomeAssistantView):
|
|||
@ha.callback
|
||||
def delete(self, request, entity_id):
|
||||
"""Remove entity."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized(entity_id=entity_id)
|
||||
if request.app['hass'].states.async_remove(entity_id):
|
||||
return self.json_message("Entity removed.")
|
||||
return self.json_message("Entity not found.", HTTP_NOT_FOUND)
|
||||
|
@ -261,6 +278,8 @@ class APIEventView(HomeAssistantView):
|
|||
|
||||
async def post(self, request, event_type):
|
||||
"""Fire events."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized()
|
||||
body = await request.text()
|
||||
try:
|
||||
event_data = json.loads(body) if body else None
|
||||
|
@ -346,6 +365,8 @@ class APITemplateView(HomeAssistantView):
|
|||
|
||||
async def post(self, request):
|
||||
"""Render a template."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized()
|
||||
try:
|
||||
data = await request.json()
|
||||
tpl = template.Template(data['template'], request.app['hass'])
|
||||
|
@ -363,6 +384,8 @@ class APIErrorLog(HomeAssistantView):
|
|||
|
||||
async def get(self, request):
|
||||
"""Retrieve API error log."""
|
||||
if not request['hass_user'].is_admin:
|
||||
raise Unauthorized()
|
||||
return web.FileResponse(request.app['hass'].data[DATA_LOGGING])
|
||||
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import os
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.components import SERVICE_CHECK_CONFIG
|
||||
from homeassistant.const import (
|
||||
ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP)
|
||||
|
@ -181,8 +182,14 @@ async def async_setup(hass, config):
|
|||
if user and user.refresh_tokens:
|
||||
refresh_token = list(user.refresh_tokens.values())[0]
|
||||
|
||||
# Migrate old hass.io users to be admin.
|
||||
if not user.is_admin:
|
||||
await hass.auth.async_update_user(
|
||||
user, group_ids=[GROUP_ID_ADMIN])
|
||||
|
||||
if refresh_token is None:
|
||||
user = await hass.auth.async_create_system_user('Hass.io')
|
||||
user = await hass.auth.async_create_system_user(
|
||||
'Hass.io', [GROUP_ID_ADMIN])
|
||||
refresh_token = await hass.auth.async_create_refresh_token(user)
|
||||
data['hassio_user'] = user.id
|
||||
await store.async_save(data)
|
||||
|
|
|
@ -14,6 +14,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
|||
from homeassistant.components.http.ban import process_success_login
|
||||
from homeassistant.core import Context, is_callback
|
||||
from homeassistant.const import CONTENT_TYPE_JSON
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||
|
@ -107,10 +108,13 @@ def request_handler_factory(view, handler):
|
|||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||
|
||||
result = handler(request, **request.match_info)
|
||||
try:
|
||||
result = handler(request, **request.match_info)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
except exceptions.Unauthorized:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
|
|
|
@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call):
|
|||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
perms = user.permissions
|
||||
entity_perms = user.permissions.check_entity
|
||||
else:
|
||||
perms = None
|
||||
entity_perms = None
|
||||
|
||||
# Are we trying to target all entities
|
||||
target_all_entities = ATTR_ENTITY_ID not in call.data
|
||||
|
@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call):
|
|||
# the service on.
|
||||
platforms_entities = []
|
||||
|
||||
if perms is None:
|
||||
if entity_perms is None:
|
||||
for platform in platforms:
|
||||
if target_all_entities:
|
||||
platforms_entities.append(list(platform.entities.values()))
|
||||
|
@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call):
|
|||
for platform in platforms:
|
||||
platforms_entities.append([
|
||||
entity for entity in platform.entities.values()
|
||||
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
|
||||
if entity_perms(entity.entity_id, POLICY_CONTROL)])
|
||||
|
||||
else:
|
||||
for platform in platforms:
|
||||
|
@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call):
|
|||
if entity.entity_id not in entity_ids:
|
||||
continue
|
||||
|
||||
if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
|
||||
if not entity_perms(entity.entity_id, POLICY_CONTROL):
|
||||
raise Unauthorized(
|
||||
context=call.context,
|
||||
entity_id=entity.entity_id,
|
||||
|
|
|
@ -10,7 +10,7 @@ def test_entities_none():
|
|||
"""Test entity ID policy."""
|
||||
policy = None
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is False
|
||||
assert compiled('light.kitchen', 'read') is False
|
||||
|
||||
|
||||
def test_entities_empty():
|
||||
|
@ -18,7 +18,7 @@ def test_entities_empty():
|
|||
policy = {}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is False
|
||||
assert compiled('light.kitchen', 'read') is False
|
||||
|
||||
|
||||
def test_entities_false():
|
||||
|
@ -33,7 +33,7 @@ def test_entities_true():
|
|||
policy = True
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_domains_true():
|
||||
|
@ -43,7 +43,7 @@ def test_entities_domains_true():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_domains_domain_true():
|
||||
|
@ -55,8 +55,8 @@ def test_entities_domains_domain_true():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('switch.kitchen', ('read',)) is False
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('switch.kitchen', 'read') is False
|
||||
|
||||
|
||||
def test_entities_domains_domain_false():
|
||||
|
@ -77,7 +77,7 @@ def test_entities_entity_ids_true():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_entity_ids_false():
|
||||
|
@ -98,8 +98,8 @@ def test_entities_entity_ids_entity_id_true():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('switch.kitchen', ('read',)) is False
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('switch.kitchen', 'read') is False
|
||||
|
||||
|
||||
def test_entities_entity_ids_entity_id_false():
|
||||
|
@ -124,9 +124,9 @@ def test_entities_control_only():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', ('control',)) is False
|
||||
assert compiled('light.kitchen', ('edit',)) is False
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('light.kitchen', 'control') is False
|
||||
assert compiled('light.kitchen', 'edit') is False
|
||||
|
||||
|
||||
def test_entities_read_control():
|
||||
|
@ -141,9 +141,9 @@ def test_entities_read_control():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', ('control',)) is True
|
||||
assert compiled('light.kitchen', ('edit',)) is False
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('light.kitchen', 'control') is True
|
||||
assert compiled('light.kitchen', 'edit') is False
|
||||
|
||||
|
||||
def test_entities_all_allow():
|
||||
|
@ -153,9 +153,9 @@ def test_entities_all_allow():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', ('control',)) is True
|
||||
assert compiled('switch.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('light.kitchen', 'control') is True
|
||||
assert compiled('switch.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_all_read():
|
||||
|
@ -167,9 +167,9 @@ def test_entities_all_read():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', ('control',)) is False
|
||||
assert compiled('switch.kitchen', ('read',)) is True
|
||||
assert compiled('light.kitchen', 'read') is True
|
||||
assert compiled('light.kitchen', 'control') is False
|
||||
assert compiled('switch.kitchen', 'read') is True
|
||||
|
||||
|
||||
def test_entities_all_control():
|
||||
|
@ -181,7 +181,7 @@ def test_entities_all_control():
|
|||
}
|
||||
ENTITY_POLICY_SCHEMA(policy)
|
||||
compiled = compile_entities(policy)
|
||||
assert compiled('light.kitchen', ('read',)) is False
|
||||
assert compiled('light.kitchen', ('control',)) is True
|
||||
assert compiled('switch.kitchen', ('read',)) is False
|
||||
assert compiled('switch.kitchen', ('control',)) is True
|
||||
assert compiled('light.kitchen', 'read') is False
|
||||
assert compiled('light.kitchen', 'control') is True
|
||||
assert compiled('switch.kitchen', 'read') is False
|
||||
assert compiled('switch.kitchen', 'control') is True
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
"""Tests for the auth permission system."""
|
||||
from homeassistant.core import State
|
||||
from homeassistant.auth import permissions
|
||||
|
||||
|
||||
def test_policy_perm_filter_states():
|
||||
"""Test filtering entitites."""
|
||||
states = [
|
||||
State('light.kitchen', 'on'),
|
||||
State('light.living_room', 'off'),
|
||||
State('light.balcony', 'on'),
|
||||
]
|
||||
perm = permissions.PolicyPermissions({
|
||||
'entities': {
|
||||
'entity_ids': {
|
||||
'light.kitchen': True,
|
||||
'light.balcony': True,
|
||||
}
|
||||
}
|
||||
})
|
||||
filtered = perm.filter_states(states)
|
||||
assert len(filtered) == 2
|
||||
assert filtered == [states[0], states[2]]
|
||||
|
||||
|
||||
def test_owner_permissions():
|
||||
"""Test owner permissions access all."""
|
||||
assert permissions.OwnerPermissions.check_entity('light.kitchen', 'write')
|
||||
states = [
|
||||
State('light.kitchen', 'on'),
|
||||
State('light.living_room', 'off'),
|
||||
State('light.balcony', 'on'),
|
||||
]
|
||||
assert permissions.OwnerPermissions.filter_states(states) == states
|
|
@ -14,7 +14,8 @@ from contextlib import contextmanager
|
|||
|
||||
from homeassistant import auth, core as ha, config_entries
|
||||
from homeassistant.auth import (
|
||||
models as auth_models, auth_store, providers as auth_providers)
|
||||
models as auth_models, auth_store, providers as auth_providers,
|
||||
permissions as auth_permissions)
|
||||
from homeassistant.auth.permissions import system_policies
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
|
@ -400,6 +401,10 @@ class MockUser(auth_models.User):
|
|||
auth_mgr._store._users[self.id] = self
|
||||
return self
|
||||
|
||||
def mock_policy(self, policy):
|
||||
"""Mock a policy for a user."""
|
||||
self._permissions = auth_permissions.PolicyPermissions(policy)
|
||||
|
||||
|
||||
async def register_auth_provider(hass, config):
|
||||
"""Register an auth provider."""
|
||||
|
|
|
@ -80,11 +80,10 @@ def hass_ws_client(aiohttp_client):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_access_token(hass):
|
||||
def hass_access_token(hass, hass_admin_user):
|
||||
"""Return an access token to access Home Assistant."""
|
||||
user = MockUser().add_to_hass(hass)
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(user, CLIENT_ID))
|
||||
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID))
|
||||
yield hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import patch, Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.hassio import (
|
||||
STORAGE_KEY, async_check_config)
|
||||
|
@ -106,6 +107,8 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
|||
)
|
||||
assert hassio_user is not None
|
||||
assert hassio_user.system_generated
|
||||
assert len(hassio_user.groups) == 1
|
||||
assert hassio_user.groups[0].id == GROUP_ID_ADMIN
|
||||
for token in hassio_user.refresh_tokens.values():
|
||||
if token.token == refresh_token:
|
||||
break
|
||||
|
@ -113,6 +116,31 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
|||
assert False, 'refresh token not found'
|
||||
|
||||
|
||||
async def test_setup_adds_admin_group_to_user(hass, aioclient_mock,
|
||||
hass_storage):
|
||||
"""Test setup with API push default data."""
|
||||
# Create user without admin
|
||||
user = await hass.auth.async_create_system_user('Hass.io')
|
||||
assert not user.is_admin
|
||||
await hass.auth.async_create_refresh_token(user)
|
||||
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
'data': {'hassio_user': user.id},
|
||||
'key': STORAGE_KEY,
|
||||
'version': 1
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, MOCK_ENVIRON), \
|
||||
patch('homeassistant.auth.AuthManager.active', return_value=True):
|
||||
result = await async_setup_component(hass, 'hassio', {
|
||||
'http': {},
|
||||
'hassio': {}
|
||||
})
|
||||
assert result
|
||||
|
||||
assert user.is_admin
|
||||
|
||||
|
||||
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
||||
hass_storage):
|
||||
"""Test setup with API push default data."""
|
||||
|
|
|
@ -16,10 +16,12 @@ from tests.common import async_mock_service
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_client(hass, aiohttp_client):
|
||||
"""Start the Hass HTTP component."""
|
||||
def mock_api_client(hass, aiohttp_client, hass_access_token):
|
||||
"""Start the Hass HTTP component and return admin API client."""
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
|
||||
return hass.loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||
return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={
|
||||
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||
}))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -405,7 +407,8 @@ def _listen_count(hass):
|
|||
return sum(hass.bus.async_listeners().values())
|
||||
|
||||
|
||||
async def test_api_error_log(hass, aiohttp_client):
|
||||
async def test_api_error_log(hass, aiohttp_client, hass_access_token,
|
||||
hass_admin_user):
|
||||
"""Test if we can fetch the error log."""
|
||||
hass.data[DATA_LOGGING] = '/some/path'
|
||||
await async_setup_component(hass, 'api', {
|
||||
|
@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client):
|
|||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
resp = await client.get(const.URL_API_ERROR_LOG)
|
||||
# Verufy auth required
|
||||
# Verify auth required
|
||||
assert resp.status == 401
|
||||
|
||||
with patch(
|
||||
|
@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client):
|
|||
return_value=web.Response(status=200, text='Hello')
|
||||
) as mock_file:
|
||||
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
||||
'x-ha-access': 'yolo'
|
||||
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||
})
|
||||
|
||||
assert len(mock_file.mock_calls) == 1
|
||||
|
@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client):
|
|||
assert resp.status == 200
|
||||
assert await resp.text() == 'Hello'
|
||||
|
||||
# Verify we require admin user
|
||||
hass_admin_user.groups = []
|
||||
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
||||
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||
})
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_api_fire_event_context(hass, mock_api_client,
|
||||
hass_access_token):
|
||||
|
@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
|||
|
||||
state = hass.states.get('light.kitchen')
|
||||
assert state.context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_event_stream_requires_admin(hass, mock_api_client,
|
||||
hass_admin_user):
|
||||
"""Test user needs to be admin to access event stream."""
|
||||
hass_admin_user.groups = []
|
||||
resp = await mock_api_client.get('/api/stream')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_states_view_filters(hass, mock_api_client, hass_admin_user):
|
||||
"""Test filtering only visible states."""
|
||||
hass_admin_user.mock_policy({
|
||||
'entities': {
|
||||
'entity_ids': {
|
||||
'test.entity': True
|
||||
}
|
||||
}
|
||||
})
|
||||
hass.states.async_set('test.entity', 'hello')
|
||||
hass.states.async_set('test.not_visible_entity', 'invisible')
|
||||
resp = await mock_api_client.get(const.URL_API_STATES)
|
||||
assert resp.status == 200
|
||||
json = await resp.json()
|
||||
assert len(json) == 1
|
||||
assert json[0]['entity_id'] == 'test.entity'
|
||||
|
||||
|
||||
async def test_get_entity_state_read_perm(hass, mock_api_client,
|
||||
hass_admin_user):
|
||||
"""Test getting a state requires read permission."""
|
||||
hass_admin_user.mock_policy({})
|
||||
resp = await mock_api_client.get('/api/states/light.test')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user):
|
||||
"""Test updating state requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
resp = await mock_api_client.post('/api/states/light.test')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_delete_entity_state_admin(hass, mock_api_client,
|
||||
hass_admin_user):
|
||||
"""Test deleting entity requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
resp = await mock_api_client.delete('/api/states/light.test')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_post_event_admin(hass, mock_api_client, hass_admin_user):
|
||||
"""Test sending event requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
resp = await mock_api_client.post('/api/events/state_changed')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_rendering_template_admin(hass, mock_api_client,
|
||||
hass_admin_user):
|
||||
"""Test rendering a template requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
resp = await mock_api_client.post('/api/template')
|
||||
assert resp.status == 401
|
||||
|
|
Loading…
Add table
Reference in a new issue