Add a service require_admin wrapper (#21953)
* Add a service require_admin wrapper * Allow it to be used as a decorator * Lint * Add comment * Add docstring * Update syntax
This commit is contained in:
parent
bf839687ad
commit
c15f433c3e
2 changed files with 61 additions and 3 deletions
|
@ -1,7 +1,9 @@
|
|||
"""Service calling related helpers."""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
import logging
|
||||
from os import path
|
||||
from typing import Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -10,7 +12,7 @@ from homeassistant.const import (
|
|||
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.helpers import template, typing
|
||||
from homeassistant.loader import get_component, bind_hass
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
@ -335,3 +337,25 @@ async def _handle_service_platform_call(func, data, entities, context):
|
|||
assert not pending
|
||||
for future in done:
|
||||
future.result() # pop exception if have
|
||||
|
||||
|
||||
@bind_hass
|
||||
@ha.callback
|
||||
def async_register_admin_service(hass: typing.HomeAssistantType, domain: str,
|
||||
service: str, service_func: Callable,
|
||||
schema: vol.Schema) -> None:
|
||||
"""Register a service that requires admin access."""
|
||||
@wraps(service_func)
|
||||
async def admin_handler(call):
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
if not user.is_admin:
|
||||
raise Unauthorized(context=call.context)
|
||||
|
||||
await hass.async_add_job(service_func, call)
|
||||
|
||||
hass.services.async_register(
|
||||
domain, service, admin_handler, schema
|
||||
)
|
||||
|
|
|
@ -5,18 +5,18 @@ from copy import deepcopy
|
|||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import voluptuous as vol
|
||||
import pytest
|
||||
|
||||
# To prevent circular import when running just this file
|
||||
import homeassistant.components # noqa
|
||||
from homeassistant import core as ha, loader, exceptions
|
||||
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
|
||||
from homeassistant.helpers import service, template
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.auth.permissions import PolicyPermissions
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dev_reg, entity_registry as ent_reg)
|
||||
service, template, device_registry as dev_reg, entity_registry as ent_reg)
|
||||
from tests.common import (
|
||||
get_test_home_assistant, mock_service, mock_coro, mock_registry,
|
||||
mock_device_registry)
|
||||
|
@ -395,3 +395,37 @@ async def test_call_with_omit_entity_id(hass, mock_service_platform_call,
|
|||
mock_entities['light.kitchen'], mock_entities['light.living_room']]
|
||||
assert ('Not passing an entity ID to a service to target '
|
||||
'all entities is deprecated') in caplog.text
|
||||
|
||||
|
||||
async def test_register_admin_service(hass, hass_read_only_user,
|
||||
hass_admin_user):
|
||||
"""Test the register admin service."""
|
||||
calls = []
|
||||
|
||||
async def mock_service(call):
|
||||
calls.append(call)
|
||||
|
||||
hass.helpers.service.async_register_admin_service(
|
||||
'test', 'test', mock_service, vol.Schema({})
|
||||
)
|
||||
|
||||
with pytest.raises(exceptions.UnknownUser):
|
||||
await hass.services.async_call(
|
||||
'test', 'test', {}, blocking=True, context=ha.Context(
|
||||
user_id='non-existing'
|
||||
))
|
||||
assert len(calls) == 0
|
||||
|
||||
with pytest.raises(exceptions.Unauthorized):
|
||||
await hass.services.async_call(
|
||||
'test', 'test', {}, blocking=True, context=ha.Context(
|
||||
user_id=hass_read_only_user.id
|
||||
))
|
||||
assert len(calls) == 0
|
||||
|
||||
await hass.services.async_call(
|
||||
'test', 'test', {}, blocking=True, context=ha.Context(
|
||||
user_id=hass_admin_user.id
|
||||
))
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context.user_id == hass_admin_user.id
|
||||
|
|
Loading…
Add table
Reference in a new issue