Add a service require_admin wrapper ()

* Add a service require_admin wrapper

* Allow it to be used as a decorator

* Lint

* Add comment

* Add docstring

* Update syntax
This commit is contained in:
Paulus Schoutsen 2019-03-12 22:09:50 -07:00 committed by GitHub
parent bf839687ad
commit c15f433c3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 3 deletions
homeassistant/helpers
tests/helpers

View file

@ -1,7 +1,9 @@
"""Service calling related helpers."""
import asyncio
from functools import wraps
import logging
from os import path
from typing import Callable
import voluptuous as vol
@ -10,7 +12,7 @@ from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template
from homeassistant.helpers import template, typing
from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml
import homeassistant.helpers.config_validation as cv
@ -335,3 +337,25 @@ async def _handle_service_platform_call(func, data, entities, context):
assert not pending
for future in done:
future.result() # pop exception if have
@bind_hass
@ha.callback
def async_register_admin_service(hass: typing.HomeAssistantType, domain: str,
service: str, service_func: Callable,
schema: vol.Schema) -> None:
"""Register a service that requires admin access."""
@wraps(service_func)
async def admin_handler(call):
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
await hass.async_add_job(service_func, call)
hass.services.async_register(
domain, service, admin_handler, schema
)

View file

@ -5,18 +5,18 @@ from copy import deepcopy
import unittest
from unittest.mock import Mock, patch
import voluptuous as vol
import pytest
# To prevent circular import when running just this file
import homeassistant.components # noqa
from homeassistant import core as ha, loader, exceptions
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv
from homeassistant.auth.permissions import PolicyPermissions
from homeassistant.helpers import (
device_registry as dev_reg, entity_registry as ent_reg)
service, template, device_registry as dev_reg, entity_registry as ent_reg)
from tests.common import (
get_test_home_assistant, mock_service, mock_coro, mock_registry,
mock_device_registry)
@ -395,3 +395,37 @@ async def test_call_with_omit_entity_id(hass, mock_service_platform_call,
mock_entities['light.kitchen'], mock_entities['light.living_room']]
assert ('Not passing an entity ID to a service to target '
'all entities is deprecated') in caplog.text
async def test_register_admin_service(hass, hass_read_only_user,
hass_admin_user):
"""Test the register admin service."""
calls = []
async def mock_service(call):
calls.append(call)
hass.helpers.service.async_register_admin_service(
'test', 'test', mock_service, vol.Schema({})
)
with pytest.raises(exceptions.UnknownUser):
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id='non-existing'
))
assert len(calls) == 0
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id=hass_read_only_user.id
))
assert len(calls) == 0
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id=hass_admin_user.id
))
assert len(calls) == 1
assert calls[0].context.user_id == hass_admin_user.id