Add permission checks to Rest API (#18639)
* Add permission checks to Rest API * Clean up unnecessary method * Remove all the tuple stuff from entity check * Simplify perms * Correct param name for owner permission * Hass.io make/update user to be admin * Types
This commit is contained in:
parent
f387cdec59
commit
8b8629a5f4
15 changed files with 282 additions and 145 deletions
|
@ -132,13 +132,15 @@ class AuthManager:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_create_system_user(self, name: str) -> models.User:
|
async def async_create_system_user(
|
||||||
|
self, name: str,
|
||||||
|
group_ids: Optional[List[str]] = None) -> models.User:
|
||||||
"""Create a system user."""
|
"""Create a system user."""
|
||||||
user = await self._store.async_create_user(
|
user = await self._store.async_create_user(
|
||||||
name=name,
|
name=name,
|
||||||
system_generated=True,
|
system_generated=True,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
group_ids=[],
|
group_ids=group_ids or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_USER_ADDED, {
|
self.hass.bus.async_fire(EVENT_USER_ADDED, {
|
||||||
|
@ -217,6 +219,17 @@ class AuthManager:
|
||||||
'user_id': user.id
|
'user_id': user.id
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async def async_update_user(self, user: models.User,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
group_ids: Optional[List[str]] = None) -> None:
|
||||||
|
"""Update a user."""
|
||||||
|
kwargs = {} # type: Dict[str,Any]
|
||||||
|
if name is not None:
|
||||||
|
kwargs['name'] = name
|
||||||
|
if group_ids is not None:
|
||||||
|
kwargs['group_ids'] = group_ids
|
||||||
|
await self._store.async_update_user(user, **kwargs)
|
||||||
|
|
||||||
async def async_activate_user(self, user: models.User) -> None:
|
async def async_activate_user(self, user: models.User) -> None:
|
||||||
"""Activate a user."""
|
"""Activate a user."""
|
||||||
await self._store.async_activate_user(user)
|
await self._store.async_activate_user(user)
|
||||||
|
|
|
@ -133,6 +133,33 @@ class AuthStore:
|
||||||
self._users.pop(user.id)
|
self._users.pop(user.id)
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
|
||||||
|
async def async_update_user(
|
||||||
|
self, user: models.User, name: Optional[str] = None,
|
||||||
|
is_active: Optional[bool] = None,
|
||||||
|
group_ids: Optional[List[str]] = None) -> None:
|
||||||
|
"""Update a user."""
|
||||||
|
assert self._groups is not None
|
||||||
|
|
||||||
|
if group_ids is not None:
|
||||||
|
groups = []
|
||||||
|
for grid in group_ids:
|
||||||
|
group = self._groups.get(grid)
|
||||||
|
if group is None:
|
||||||
|
raise ValueError("Invalid group specified.")
|
||||||
|
groups.append(group)
|
||||||
|
|
||||||
|
user.groups = groups
|
||||||
|
user.invalidate_permission_cache()
|
||||||
|
|
||||||
|
for attr_name, value in (
|
||||||
|
('name', name),
|
||||||
|
('is_active', is_active),
|
||||||
|
):
|
||||||
|
if value is not None:
|
||||||
|
setattr(user, attr_name, value)
|
||||||
|
|
||||||
|
self._async_schedule_save()
|
||||||
|
|
||||||
async def async_activate_user(self, user: models.User) -> None:
|
async def async_activate_user(self, user: models.User) -> None:
|
||||||
"""Activate a user."""
|
"""Activate a user."""
|
||||||
user.is_active = True
|
user.is_active = True
|
||||||
|
|
|
@ -8,6 +8,7 @@ import attr
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import permissions as perm_mdl
|
from . import permissions as perm_mdl
|
||||||
|
from .const import GROUP_ID_ADMIN
|
||||||
from .util import generate_secret
|
from .util import generate_secret
|
||||||
|
|
||||||
TOKEN_TYPE_NORMAL = 'normal'
|
TOKEN_TYPE_NORMAL = 'normal'
|
||||||
|
@ -48,7 +49,7 @@ class User:
|
||||||
) # type: Dict[str, RefreshToken]
|
) # type: Dict[str, RefreshToken]
|
||||||
|
|
||||||
_permissions = attr.ib(
|
_permissions = attr.ib(
|
||||||
type=perm_mdl.PolicyPermissions,
|
type=Optional[perm_mdl.PolicyPermissions],
|
||||||
init=False,
|
init=False,
|
||||||
cmp=False,
|
cmp=False,
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -69,6 +70,19 @@ class User:
|
||||||
|
|
||||||
return self._permissions
|
return self._permissions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_admin(self) -> bool:
|
||||||
|
"""Return if user is part of the admin group."""
|
||||||
|
if self.is_owner:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.is_active and any(
|
||||||
|
gr.id == GROUP_ID_ADMIN for gr in self.groups)
|
||||||
|
|
||||||
|
def invalidate_permission_cache(self) -> None:
|
||||||
|
"""Invalidate permission cache."""
|
||||||
|
self._permissions = None
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class RefreshToken:
|
class RefreshToken:
|
||||||
|
|
|
@ -5,10 +5,8 @@ from typing import ( # noqa: F401
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import State
|
|
||||||
|
|
||||||
from .const import CAT_ENTITIES
|
from .const import CAT_ENTITIES
|
||||||
from .types import CategoryType, PolicyType
|
from .types import PolicyType
|
||||||
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
|
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
|
||||||
from .merge import merge_policies # noqa
|
from .merge import merge_policies # noqa
|
||||||
|
|
||||||
|
@ -22,13 +20,20 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
class AbstractPermissions:
|
class AbstractPermissions:
|
||||||
"""Default permissions class."""
|
"""Default permissions class."""
|
||||||
|
|
||||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
_cached_entity_func = None
|
||||||
"""Test if we can access entity."""
|
|
||||||
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
|
"""Return a function that can test entity access."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def filter_states(self, states: List[State]) -> List[State]:
|
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||||
"""Filter a list of states for what the user is allowed to see."""
|
"""Check if we can access entity."""
|
||||||
raise NotImplementedError
|
entity_func = self._cached_entity_func
|
||||||
|
|
||||||
|
if entity_func is None:
|
||||||
|
entity_func = self._cached_entity_func = self._entity_func()
|
||||||
|
|
||||||
|
return entity_func(entity_id, key)
|
||||||
|
|
||||||
|
|
||||||
class PolicyPermissions(AbstractPermissions):
|
class PolicyPermissions(AbstractPermissions):
|
||||||
|
@ -37,34 +42,10 @@ class PolicyPermissions(AbstractPermissions):
|
||||||
def __init__(self, policy: PolicyType) -> None:
|
def __init__(self, policy: PolicyType) -> None:
|
||||||
"""Initialize the permission class."""
|
"""Initialize the permission class."""
|
||||||
self._policy = policy
|
self._policy = policy
|
||||||
self._compiled = {} # type: Dict[str, Callable[..., bool]]
|
|
||||||
|
|
||||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
"""Test if we can access entity."""
|
"""Return a function that can test entity access."""
|
||||||
func = self._policy_func(CAT_ENTITIES, compile_entities)
|
return compile_entities(self._policy.get(CAT_ENTITIES))
|
||||||
return func(entity_id, (key,))
|
|
||||||
|
|
||||||
def filter_states(self, states: List[State]) -> List[State]:
|
|
||||||
"""Filter a list of states for what the user is allowed to see."""
|
|
||||||
func = self._policy_func(CAT_ENTITIES, compile_entities)
|
|
||||||
keys = ('read',)
|
|
||||||
return [entity for entity in states if func(entity.entity_id, keys)]
|
|
||||||
|
|
||||||
def _policy_func(self, category: str,
|
|
||||||
compile_func: Callable[[CategoryType], Callable]) \
|
|
||||||
-> Callable[..., bool]:
|
|
||||||
"""Get a policy function."""
|
|
||||||
func = self._compiled.get(category)
|
|
||||||
|
|
||||||
if func:
|
|
||||||
return func
|
|
||||||
|
|
||||||
func = self._compiled[category] = compile_func(
|
|
||||||
self._policy.get(category))
|
|
||||||
|
|
||||||
_LOGGER.debug("Compiled %s func: %s", category, func)
|
|
||||||
|
|
||||||
return func
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Equals check."""
|
"""Equals check."""
|
||||||
|
@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions):
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
|
|
||||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
"""Test if we can access entity."""
|
"""Return a function that can test entity access."""
|
||||||
return True
|
return lambda entity_id, key: True
|
||||||
|
|
||||||
def filter_states(self, states: List[State]) -> List[State]:
|
|
||||||
"""Filter a list of states for what the user is allowed to see."""
|
|
||||||
return states
|
|
||||||
|
|
||||||
|
|
||||||
OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name
|
OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name
|
||||||
|
|
|
@ -28,28 +28,28 @@ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
def _entity_allowed(schema: ValueType, keys: Tuple[str]) \
|
def _entity_allowed(schema: ValueType, key: str) \
|
||||||
-> Union[bool, None]:
|
-> Union[bool, None]:
|
||||||
"""Test if an entity is allowed based on the keys."""
|
"""Test if an entity is allowed based on the keys."""
|
||||||
if schema is None or isinstance(schema, bool):
|
if schema is None or isinstance(schema, bool):
|
||||||
return schema
|
return schema
|
||||||
assert isinstance(schema, dict)
|
assert isinstance(schema, dict)
|
||||||
return schema.get(keys[0])
|
return schema.get(key)
|
||||||
|
|
||||||
|
|
||||||
def compile_entities(policy: CategoryType) \
|
def compile_entities(policy: CategoryType) \
|
||||||
-> Callable[[str, Tuple[str]], bool]:
|
-> Callable[[str, str], bool]:
|
||||||
"""Compile policy into a function that tests policy."""
|
"""Compile policy into a function that tests policy."""
|
||||||
# None, Empty Dict, False
|
# None, Empty Dict, False
|
||||||
if not policy:
|
if not policy:
|
||||||
def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool:
|
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
|
||||||
"""Decline all."""
|
"""Decline all."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return apply_policy_deny_all
|
return apply_policy_deny_all
|
||||||
|
|
||||||
if policy is True:
|
if policy is True:
|
||||||
def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
|
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
|
||||||
"""Approve all."""
|
"""Approve all."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ def compile_entities(policy: CategoryType) \
|
||||||
entity_ids = policy.get(ENTITY_ENTITY_IDS)
|
entity_ids = policy.get(ENTITY_ENTITY_IDS)
|
||||||
all_entities = policy.get(SUBCAT_ALL)
|
all_entities = policy.get(SUBCAT_ALL)
|
||||||
|
|
||||||
funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]]
|
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
|
||||||
|
|
||||||
# The order of these functions matter. The more precise are at the top.
|
# The order of these functions matter. The more precise are at the top.
|
||||||
# If a function returns None, they cannot handle it.
|
# If a function returns None, they cannot handle it.
|
||||||
|
@ -70,23 +70,23 @@ def compile_entities(policy: CategoryType) \
|
||||||
# Setting entity_ids to a boolean is final decision for permissions
|
# Setting entity_ids to a boolean is final decision for permissions
|
||||||
# So return right away.
|
# So return right away.
|
||||||
if isinstance(entity_ids, bool):
|
if isinstance(entity_ids, bool):
|
||||||
def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool:
|
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
|
||||||
"""Test if allowed entity_id."""
|
"""Test if allowed entity_id."""
|
||||||
return entity_ids # type: ignore
|
return entity_ids # type: ignore
|
||||||
|
|
||||||
return allowed_entity_id_bool
|
return allowed_entity_id_bool
|
||||||
|
|
||||||
if entity_ids is not None:
|
if entity_ids is not None:
|
||||||
def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \
|
def allowed_entity_id_dict(entity_id: str, key: str) \
|
||||||
-> Union[None, bool]:
|
-> Union[None, bool]:
|
||||||
"""Test if allowed entity_id."""
|
"""Test if allowed entity_id."""
|
||||||
return _entity_allowed(
|
return _entity_allowed(
|
||||||
entity_ids.get(entity_id), keys) # type: ignore
|
entity_ids.get(entity_id), key) # type: ignore
|
||||||
|
|
||||||
funcs.append(allowed_entity_id_dict)
|
funcs.append(allowed_entity_id_dict)
|
||||||
|
|
||||||
if isinstance(domains, bool):
|
if isinstance(domains, bool):
|
||||||
def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \
|
def allowed_domain_bool(entity_id: str, key: str) \
|
||||||
-> Union[None, bool]:
|
-> Union[None, bool]:
|
||||||
"""Test if allowed domain."""
|
"""Test if allowed domain."""
|
||||||
return domains
|
return domains
|
||||||
|
@ -94,31 +94,31 @@ def compile_entities(policy: CategoryType) \
|
||||||
funcs.append(allowed_domain_bool)
|
funcs.append(allowed_domain_bool)
|
||||||
|
|
||||||
elif domains is not None:
|
elif domains is not None:
|
||||||
def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \
|
def allowed_domain_dict(entity_id: str, key: str) \
|
||||||
-> Union[None, bool]:
|
-> Union[None, bool]:
|
||||||
"""Test if allowed domain."""
|
"""Test if allowed domain."""
|
||||||
domain = entity_id.split(".", 1)[0]
|
domain = entity_id.split(".", 1)[0]
|
||||||
return _entity_allowed(domains.get(domain), keys) # type: ignore
|
return _entity_allowed(domains.get(domain), key) # type: ignore
|
||||||
|
|
||||||
funcs.append(allowed_domain_dict)
|
funcs.append(allowed_domain_dict)
|
||||||
|
|
||||||
if isinstance(all_entities, bool):
|
if isinstance(all_entities, bool):
|
||||||
def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \
|
def allowed_all_entities_bool(entity_id: str, key: str) \
|
||||||
-> Union[None, bool]:
|
-> Union[None, bool]:
|
||||||
"""Test if allowed domain."""
|
"""Test if allowed domain."""
|
||||||
return all_entities
|
return all_entities
|
||||||
funcs.append(allowed_all_entities_bool)
|
funcs.append(allowed_all_entities_bool)
|
||||||
|
|
||||||
elif all_entities is not None:
|
elif all_entities is not None:
|
||||||
def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \
|
def allowed_all_entities_dict(entity_id: str, key: str) \
|
||||||
-> Union[None, bool]:
|
-> Union[None, bool]:
|
||||||
"""Test if allowed domain."""
|
"""Test if allowed domain."""
|
||||||
return _entity_allowed(all_entities, keys)
|
return _entity_allowed(all_entities, key)
|
||||||
funcs.append(allowed_all_entities_dict)
|
funcs.append(allowed_all_entities_dict)
|
||||||
|
|
||||||
# Can happen if no valid subcategories specified
|
# Can happen if no valid subcategories specified
|
||||||
if not funcs:
|
if not funcs:
|
||||||
def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool:
|
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
|
||||||
"""Decline all."""
|
"""Decline all."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -128,16 +128,16 @@ def compile_entities(policy: CategoryType) \
|
||||||
func = funcs[0]
|
func = funcs[0]
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool:
|
def apply_policy_func(entity_id: str, key: str) -> bool:
|
||||||
"""Apply a single policy function."""
|
"""Apply a single policy function."""
|
||||||
return func(entity_id, keys) is True
|
return func(entity_id, key) is True
|
||||||
|
|
||||||
return apply_policy_func
|
return apply_policy_func
|
||||||
|
|
||||||
def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool:
|
def apply_policy_funcs(entity_id: str, key: str) -> bool:
|
||||||
"""Apply several policy functions."""
|
"""Apply several policy functions."""
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
result = func(entity_id, keys)
|
result = func(entity_id, key)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -20,7 +20,8 @@ from homeassistant.const import (
|
||||||
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
|
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
|
||||||
URL_API_TEMPLATE, __version__)
|
URL_API_TEMPLATE, __version__)
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.auth.permissions.const import POLICY_READ
|
||||||
|
from homeassistant.exceptions import TemplateError, Unauthorized
|
||||||
from homeassistant.helpers import template
|
from homeassistant.helpers import template
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.helpers.state import AsyncTrackStates
|
from homeassistant.helpers.state import AsyncTrackStates
|
||||||
|
@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView):
|
||||||
|
|
||||||
async def get(self, request):
|
async def get(self, request):
|
||||||
"""Provide a streaming interface for the event bus."""
|
"""Provide a streaming interface for the event bus."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized()
|
||||||
hass = request.app['hass']
|
hass = request.app['hass']
|
||||||
stop_obj = object()
|
stop_obj = object()
|
||||||
to_write = asyncio.Queue(loop=hass.loop)
|
to_write = asyncio.Queue(loop=hass.loop)
|
||||||
|
@ -185,7 +188,13 @@ class APIStatesView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Get current states."""
|
"""Get current states."""
|
||||||
return self.json(request.app['hass'].states.async_all())
|
user = request['hass_user']
|
||||||
|
entity_perm = user.permissions.check_entity
|
||||||
|
states = [
|
||||||
|
state for state in request.app['hass'].states.async_all()
|
||||||
|
if entity_perm(state.entity_id, 'read')
|
||||||
|
]
|
||||||
|
return self.json(states)
|
||||||
|
|
||||||
|
|
||||||
class APIEntityStateView(HomeAssistantView):
|
class APIEntityStateView(HomeAssistantView):
|
||||||
|
@ -197,6 +206,10 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def get(self, request, entity_id):
|
def get(self, request, entity_id):
|
||||||
"""Retrieve state of entity."""
|
"""Retrieve state of entity."""
|
||||||
|
user = request['hass_user']
|
||||||
|
if not user.permissions.check_entity(entity_id, POLICY_READ):
|
||||||
|
raise Unauthorized(entity_id=entity_id)
|
||||||
|
|
||||||
state = request.app['hass'].states.get(entity_id)
|
state = request.app['hass'].states.get(entity_id)
|
||||||
if state:
|
if state:
|
||||||
return self.json(state)
|
return self.json(state)
|
||||||
|
@ -204,6 +217,8 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
|
|
||||||
async def post(self, request, entity_id):
|
async def post(self, request, entity_id):
|
||||||
"""Update state of entity."""
|
"""Update state of entity."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized(entity_id=entity_id)
|
||||||
hass = request.app['hass']
|
hass = request.app['hass']
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
@ -236,6 +251,8 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def delete(self, request, entity_id):
|
def delete(self, request, entity_id):
|
||||||
"""Remove entity."""
|
"""Remove entity."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized(entity_id=entity_id)
|
||||||
if request.app['hass'].states.async_remove(entity_id):
|
if request.app['hass'].states.async_remove(entity_id):
|
||||||
return self.json_message("Entity removed.")
|
return self.json_message("Entity removed.")
|
||||||
return self.json_message("Entity not found.", HTTP_NOT_FOUND)
|
return self.json_message("Entity not found.", HTTP_NOT_FOUND)
|
||||||
|
@ -261,6 +278,8 @@ class APIEventView(HomeAssistantView):
|
||||||
|
|
||||||
async def post(self, request, event_type):
|
async def post(self, request, event_type):
|
||||||
"""Fire events."""
|
"""Fire events."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized()
|
||||||
body = await request.text()
|
body = await request.text()
|
||||||
try:
|
try:
|
||||||
event_data = json.loads(body) if body else None
|
event_data = json.loads(body) if body else None
|
||||||
|
@ -346,6 +365,8 @@ class APITemplateView(HomeAssistantView):
|
||||||
|
|
||||||
async def post(self, request):
|
async def post(self, request):
|
||||||
"""Render a template."""
|
"""Render a template."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized()
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
tpl = template.Template(data['template'], request.app['hass'])
|
tpl = template.Template(data['template'], request.app['hass'])
|
||||||
|
@ -363,6 +384,8 @@ class APIErrorLog(HomeAssistantView):
|
||||||
|
|
||||||
async def get(self, request):
|
async def get(self, request):
|
||||||
"""Retrieve API error log."""
|
"""Retrieve API error log."""
|
||||||
|
if not request['hass_user'].is_admin:
|
||||||
|
raise Unauthorized()
|
||||||
return web.FileResponse(request.app['hass'].data[DATA_LOGGING])
|
return web.FileResponse(request.app['hass'].data[DATA_LOGGING])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import os
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||||
from homeassistant.components import SERVICE_CHECK_CONFIG
|
from homeassistant.components import SERVICE_CHECK_CONFIG
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP)
|
ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP)
|
||||||
|
@ -181,8 +182,14 @@ async def async_setup(hass, config):
|
||||||
if user and user.refresh_tokens:
|
if user and user.refresh_tokens:
|
||||||
refresh_token = list(user.refresh_tokens.values())[0]
|
refresh_token = list(user.refresh_tokens.values())[0]
|
||||||
|
|
||||||
|
# Migrate old hass.io users to be admin.
|
||||||
|
if not user.is_admin:
|
||||||
|
await hass.auth.async_update_user(
|
||||||
|
user, group_ids=[GROUP_ID_ADMIN])
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
user = await hass.auth.async_create_system_user('Hass.io')
|
user = await hass.auth.async_create_system_user(
|
||||||
|
'Hass.io', [GROUP_ID_ADMIN])
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user)
|
refresh_token = await hass.auth.async_create_refresh_token(user)
|
||||||
data['hassio_user'] = user.id
|
data['hassio_user'] = user.id
|
||||||
await store.async_save(data)
|
await store.async_save(data)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
||||||
from homeassistant.components.http.ban import process_success_login
|
from homeassistant.components.http.ban import process_success_login
|
||||||
from homeassistant.core import Context, is_callback
|
from homeassistant.core import Context, is_callback
|
||||||
from homeassistant.const import CONTENT_TYPE_JSON
|
from homeassistant.const import CONTENT_TYPE_JSON
|
||||||
|
from homeassistant import exceptions
|
||||||
from homeassistant.helpers.json import JSONEncoder
|
from homeassistant.helpers.json import JSONEncoder
|
||||||
|
|
||||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||||
|
@ -107,10 +108,13 @@ def request_handler_factory(view, handler):
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||||
|
|
||||||
result = handler(request, **request.match_info)
|
try:
|
||||||
|
result = handler(request, **request.match_info)
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
result = await result
|
result = await result
|
||||||
|
except exceptions.Unauthorized:
|
||||||
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
if isinstance(result, web.StreamResponse):
|
if isinstance(result, web.StreamResponse):
|
||||||
# The method handler returned a ready-made Response, how nice of it
|
# The method handler returned a ready-made Response, how nice of it
|
||||||
|
|
|
@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call):
|
||||||
user = await hass.auth.async_get_user(call.context.user_id)
|
user = await hass.auth.async_get_user(call.context.user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise UnknownUser(context=call.context)
|
raise UnknownUser(context=call.context)
|
||||||
perms = user.permissions
|
entity_perms = user.permissions.check_entity
|
||||||
else:
|
else:
|
||||||
perms = None
|
entity_perms = None
|
||||||
|
|
||||||
# Are we trying to target all entities
|
# Are we trying to target all entities
|
||||||
target_all_entities = ATTR_ENTITY_ID not in call.data
|
target_all_entities = ATTR_ENTITY_ID not in call.data
|
||||||
|
@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call):
|
||||||
# the service on.
|
# the service on.
|
||||||
platforms_entities = []
|
platforms_entities = []
|
||||||
|
|
||||||
if perms is None:
|
if entity_perms is None:
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
if target_all_entities:
|
if target_all_entities:
|
||||||
platforms_entities.append(list(platform.entities.values()))
|
platforms_entities.append(list(platform.entities.values()))
|
||||||
|
@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call):
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
platforms_entities.append([
|
platforms_entities.append([
|
||||||
entity for entity in platform.entities.values()
|
entity for entity in platform.entities.values()
|
||||||
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
|
if entity_perms(entity.entity_id, POLICY_CONTROL)])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
|
@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call):
|
||||||
if entity.entity_id not in entity_ids:
|
if entity.entity_id not in entity_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
|
if not entity_perms(entity.entity_id, POLICY_CONTROL):
|
||||||
raise Unauthorized(
|
raise Unauthorized(
|
||||||
context=call.context,
|
context=call.context,
|
||||||
entity_id=entity.entity_id,
|
entity_id=entity.entity_id,
|
||||||
|
|
|
@ -10,7 +10,7 @@ def test_entities_none():
|
||||||
"""Test entity ID policy."""
|
"""Test entity ID policy."""
|
||||||
policy = None
|
policy = None
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is False
|
assert compiled('light.kitchen', 'read') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_empty():
|
def test_entities_empty():
|
||||||
|
@ -18,7 +18,7 @@ def test_entities_empty():
|
||||||
policy = {}
|
policy = {}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is False
|
assert compiled('light.kitchen', 'read') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_false():
|
def test_entities_false():
|
||||||
|
@ -33,7 +33,7 @@ def test_entities_true():
|
||||||
policy = True
|
policy = True
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
|
|
||||||
|
|
||||||
def test_entities_domains_true():
|
def test_entities_domains_true():
|
||||||
|
@ -43,7 +43,7 @@ def test_entities_domains_true():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
|
|
||||||
|
|
||||||
def test_entities_domains_domain_true():
|
def test_entities_domains_domain_true():
|
||||||
|
@ -55,8 +55,8 @@ def test_entities_domains_domain_true():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('switch.kitchen', ('read',)) is False
|
assert compiled('switch.kitchen', 'read') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_domains_domain_false():
|
def test_entities_domains_domain_false():
|
||||||
|
@ -77,7 +77,7 @@ def test_entities_entity_ids_true():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
|
|
||||||
|
|
||||||
def test_entities_entity_ids_false():
|
def test_entities_entity_ids_false():
|
||||||
|
@ -98,8 +98,8 @@ def test_entities_entity_ids_entity_id_true():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('switch.kitchen', ('read',)) is False
|
assert compiled('switch.kitchen', 'read') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_entity_ids_entity_id_false():
|
def test_entities_entity_ids_entity_id_false():
|
||||||
|
@ -124,9 +124,9 @@ def test_entities_control_only():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('light.kitchen', ('control',)) is False
|
assert compiled('light.kitchen', 'control') is False
|
||||||
assert compiled('light.kitchen', ('edit',)) is False
|
assert compiled('light.kitchen', 'edit') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_read_control():
|
def test_entities_read_control():
|
||||||
|
@ -141,9 +141,9 @@ def test_entities_read_control():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('light.kitchen', ('control',)) is True
|
assert compiled('light.kitchen', 'control') is True
|
||||||
assert compiled('light.kitchen', ('edit',)) is False
|
assert compiled('light.kitchen', 'edit') is False
|
||||||
|
|
||||||
|
|
||||||
def test_entities_all_allow():
|
def test_entities_all_allow():
|
||||||
|
@ -153,9 +153,9 @@ def test_entities_all_allow():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('light.kitchen', ('control',)) is True
|
assert compiled('light.kitchen', 'control') is True
|
||||||
assert compiled('switch.kitchen', ('read',)) is True
|
assert compiled('switch.kitchen', 'read') is True
|
||||||
|
|
||||||
|
|
||||||
def test_entities_all_read():
|
def test_entities_all_read():
|
||||||
|
@ -167,9 +167,9 @@ def test_entities_all_read():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is True
|
assert compiled('light.kitchen', 'read') is True
|
||||||
assert compiled('light.kitchen', ('control',)) is False
|
assert compiled('light.kitchen', 'control') is False
|
||||||
assert compiled('switch.kitchen', ('read',)) is True
|
assert compiled('switch.kitchen', 'read') is True
|
||||||
|
|
||||||
|
|
||||||
def test_entities_all_control():
|
def test_entities_all_control():
|
||||||
|
@ -181,7 +181,7 @@ def test_entities_all_control():
|
||||||
}
|
}
|
||||||
ENTITY_POLICY_SCHEMA(policy)
|
ENTITY_POLICY_SCHEMA(policy)
|
||||||
compiled = compile_entities(policy)
|
compiled = compile_entities(policy)
|
||||||
assert compiled('light.kitchen', ('read',)) is False
|
assert compiled('light.kitchen', 'read') is False
|
||||||
assert compiled('light.kitchen', ('control',)) is True
|
assert compiled('light.kitchen', 'control') is True
|
||||||
assert compiled('switch.kitchen', ('read',)) is False
|
assert compiled('switch.kitchen', 'read') is False
|
||||||
assert compiled('switch.kitchen', ('control',)) is True
|
assert compiled('switch.kitchen', 'control') is True
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
"""Tests for the auth permission system."""
|
|
||||||
from homeassistant.core import State
|
|
||||||
from homeassistant.auth import permissions
|
|
||||||
|
|
||||||
|
|
||||||
def test_policy_perm_filter_states():
|
|
||||||
"""Test filtering entitites."""
|
|
||||||
states = [
|
|
||||||
State('light.kitchen', 'on'),
|
|
||||||
State('light.living_room', 'off'),
|
|
||||||
State('light.balcony', 'on'),
|
|
||||||
]
|
|
||||||
perm = permissions.PolicyPermissions({
|
|
||||||
'entities': {
|
|
||||||
'entity_ids': {
|
|
||||||
'light.kitchen': True,
|
|
||||||
'light.balcony': True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
filtered = perm.filter_states(states)
|
|
||||||
assert len(filtered) == 2
|
|
||||||
assert filtered == [states[0], states[2]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_owner_permissions():
|
|
||||||
"""Test owner permissions access all."""
|
|
||||||
assert permissions.OwnerPermissions.check_entity('light.kitchen', 'write')
|
|
||||||
states = [
|
|
||||||
State('light.kitchen', 'on'),
|
|
||||||
State('light.living_room', 'off'),
|
|
||||||
State('light.balcony', 'on'),
|
|
||||||
]
|
|
||||||
assert permissions.OwnerPermissions.filter_states(states) == states
|
|
|
@ -14,7 +14,8 @@ from contextlib import contextmanager
|
||||||
|
|
||||||
from homeassistant import auth, core as ha, config_entries
|
from homeassistant import auth, core as ha, config_entries
|
||||||
from homeassistant.auth import (
|
from homeassistant.auth import (
|
||||||
models as auth_models, auth_store, providers as auth_providers)
|
models as auth_models, auth_store, providers as auth_providers,
|
||||||
|
permissions as auth_permissions)
|
||||||
from homeassistant.auth.permissions import system_policies
|
from homeassistant.auth.permissions import system_policies
|
||||||
from homeassistant.setup import setup_component, async_setup_component
|
from homeassistant.setup import setup_component, async_setup_component
|
||||||
from homeassistant.config import async_process_component_config
|
from homeassistant.config import async_process_component_config
|
||||||
|
@ -400,6 +401,10 @@ class MockUser(auth_models.User):
|
||||||
auth_mgr._store._users[self.id] = self
|
auth_mgr._store._users[self.id] = self
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def mock_policy(self, policy):
|
||||||
|
"""Mock a policy for a user."""
|
||||||
|
self._permissions = auth_permissions.PolicyPermissions(policy)
|
||||||
|
|
||||||
|
|
||||||
async def register_auth_provider(hass, config):
|
async def register_auth_provider(hass, config):
|
||||||
"""Register an auth provider."""
|
"""Register an auth provider."""
|
||||||
|
|
|
@ -80,11 +80,10 @@ def hass_ws_client(aiohttp_client):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hass_access_token(hass):
|
def hass_access_token(hass, hass_admin_user):
|
||||||
"""Return an access token to access Home Assistant."""
|
"""Return an access token to access Home Assistant."""
|
||||||
user = MockUser().add_to_hass(hass)
|
|
||||||
refresh_token = hass.loop.run_until_complete(
|
refresh_token = hass.loop.run_until_complete(
|
||||||
hass.auth.async_create_refresh_token(user, CLIENT_ID))
|
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID))
|
||||||
yield hass.auth.async_create_access_token(refresh_token)
|
yield hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import patch, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components.hassio import (
|
from homeassistant.components.hassio import (
|
||||||
STORAGE_KEY, async_check_config)
|
STORAGE_KEY, async_check_config)
|
||||||
|
@ -106,6 +107,8 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
||||||
)
|
)
|
||||||
assert hassio_user is not None
|
assert hassio_user is not None
|
||||||
assert hassio_user.system_generated
|
assert hassio_user.system_generated
|
||||||
|
assert len(hassio_user.groups) == 1
|
||||||
|
assert hassio_user.groups[0].id == GROUP_ID_ADMIN
|
||||||
for token in hassio_user.refresh_tokens.values():
|
for token in hassio_user.refresh_tokens.values():
|
||||||
if token.token == refresh_token:
|
if token.token == refresh_token:
|
||||||
break
|
break
|
||||||
|
@ -113,6 +116,31 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
||||||
assert False, 'refresh token not found'
|
assert False, 'refresh token not found'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_setup_adds_admin_group_to_user(hass, aioclient_mock,
|
||||||
|
hass_storage):
|
||||||
|
"""Test setup with API push default data."""
|
||||||
|
# Create user without admin
|
||||||
|
user = await hass.auth.async_create_system_user('Hass.io')
|
||||||
|
assert not user.is_admin
|
||||||
|
await hass.auth.async_create_refresh_token(user)
|
||||||
|
|
||||||
|
hass_storage[STORAGE_KEY] = {
|
||||||
|
'data': {'hassio_user': user.id},
|
||||||
|
'key': STORAGE_KEY,
|
||||||
|
'version': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, MOCK_ENVIRON), \
|
||||||
|
patch('homeassistant.auth.AuthManager.active', return_value=True):
|
||||||
|
result = await async_setup_component(hass, 'hassio', {
|
||||||
|
'http': {},
|
||||||
|
'hassio': {}
|
||||||
|
})
|
||||||
|
assert result
|
||||||
|
|
||||||
|
assert user.is_admin
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
||||||
hass_storage):
|
hass_storage):
|
||||||
"""Test setup with API push default data."""
|
"""Test setup with API push default data."""
|
||||||
|
|
|
@ -16,10 +16,12 @@ from tests.common import async_mock_service
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_client(hass, aiohttp_client):
|
def mock_api_client(hass, aiohttp_client, hass_access_token):
|
||||||
"""Start the Hass HTTP component."""
|
"""Start the Hass HTTP component and return admin API client."""
|
||||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
|
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
|
||||||
return hass.loop.run_until_complete(aiohttp_client(hass.http.app))
|
return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={
|
||||||
|
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -405,7 +407,8 @@ def _listen_count(hass):
|
||||||
return sum(hass.bus.async_listeners().values())
|
return sum(hass.bus.async_listeners().values())
|
||||||
|
|
||||||
|
|
||||||
async def test_api_error_log(hass, aiohttp_client):
|
async def test_api_error_log(hass, aiohttp_client, hass_access_token,
|
||||||
|
hass_admin_user):
|
||||||
"""Test if we can fetch the error log."""
|
"""Test if we can fetch the error log."""
|
||||||
hass.data[DATA_LOGGING] = '/some/path'
|
hass.data[DATA_LOGGING] = '/some/path'
|
||||||
await async_setup_component(hass, 'api', {
|
await async_setup_component(hass, 'api', {
|
||||||
|
@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client):
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
resp = await client.get(const.URL_API_ERROR_LOG)
|
resp = await client.get(const.URL_API_ERROR_LOG)
|
||||||
# Verufy auth required
|
# Verify auth required
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
|
@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client):
|
||||||
return_value=web.Response(status=200, text='Hello')
|
return_value=web.Response(status=200, text='Hello')
|
||||||
) as mock_file:
|
) as mock_file:
|
||||||
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
||||||
'x-ha-access': 'yolo'
|
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
})
|
})
|
||||||
|
|
||||||
assert len(mock_file.mock_calls) == 1
|
assert len(mock_file.mock_calls) == 1
|
||||||
|
@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client):
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert await resp.text() == 'Hello'
|
assert await resp.text() == 'Hello'
|
||||||
|
|
||||||
|
# Verify we require admin user
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
||||||
|
'Authorization': 'Bearer {}'.format(hass_access_token)
|
||||||
|
})
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
async def test_api_fire_event_context(hass, mock_api_client,
|
async def test_api_fire_event_context(hass, mock_api_client,
|
||||||
hass_access_token):
|
hass_access_token):
|
||||||
|
@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
||||||
|
|
||||||
state = hass.states.get('light.kitchen')
|
state = hass.states.get('light.kitchen')
|
||||||
assert state.context.user_id == refresh_token.user.id
|
assert state.context.user_id == refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_event_stream_requires_admin(hass, mock_api_client,
|
||||||
|
hass_admin_user):
|
||||||
|
"""Test user needs to be admin to access event stream."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await mock_api_client.get('/api/stream')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_states_view_filters(hass, mock_api_client, hass_admin_user):
|
||||||
|
"""Test filtering only visible states."""
|
||||||
|
hass_admin_user.mock_policy({
|
||||||
|
'entities': {
|
||||||
|
'entity_ids': {
|
||||||
|
'test.entity': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
hass.states.async_set('test.entity', 'hello')
|
||||||
|
hass.states.async_set('test.not_visible_entity', 'invisible')
|
||||||
|
resp = await mock_api_client.get(const.URL_API_STATES)
|
||||||
|
assert resp.status == 200
|
||||||
|
json = await resp.json()
|
||||||
|
assert len(json) == 1
|
||||||
|
assert json[0]['entity_id'] == 'test.entity'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_entity_state_read_perm(hass, mock_api_client,
|
||||||
|
hass_admin_user):
|
||||||
|
"""Test getting a state requires read permission."""
|
||||||
|
hass_admin_user.mock_policy({})
|
||||||
|
resp = await mock_api_client.get('/api/states/light.test')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user):
|
||||||
|
"""Test updating state requires admin."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await mock_api_client.post('/api/states/light.test')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_entity_state_admin(hass, mock_api_client,
|
||||||
|
hass_admin_user):
|
||||||
|
"""Test deleting entity requires admin."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await mock_api_client.delete('/api/states/light.test')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_post_event_admin(hass, mock_api_client, hass_admin_user):
|
||||||
|
"""Test sending event requires admin."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await mock_api_client.post('/api/events/state_changed')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rendering_template_admin(hass, mock_api_client,
|
||||||
|
hass_admin_user):
|
||||||
|
"""Test rendering a template requires admin."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
resp = await mock_api_client.post('/api/template')
|
||||||
|
assert resp.status == 401
|
||||||
|
|
Loading…
Add table
Reference in a new issue