Create decorator to check service permissions (#22667)

* Create decorator to check service permissions

* Typing

* Linting

* Member comments

* Linting

* Member comments

* Updated import

* Owner comments

* Linting

* Linting

* More work

* Fixed tests

* Removed service helper tests in RainMachine

* Linting

* Owner comments

* Linting

* Owner comments

Co-Authored-By: bachya <bachya1208@gmail.com>
This commit is contained in:
Aaron Bach 2019-04-13 13:54:29 -06:00 committed by GitHub
parent 7a6950fd72
commit fc481133e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 177 additions and 118 deletions

View file

@ -2,22 +2,20 @@
import asyncio
import logging
from datetime import timedelta
from functools import wraps
import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import (
ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD,
CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL,
CONF_MONITORED_CONDITIONS, CONF_SWITCHES)
from homeassistant.exceptions import (
ConfigEntryNotReady, Unauthorized, UnknownUser)
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import aiohttp_client, config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.service import verify_domain_control
from .config_flow import configured_instances
from .const import (
@ -131,44 +129,6 @@ CONFIG_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA)
def _check_valid_user(hass):
"""Ensure the user of a service call has proper permissions."""
def decorator(service):
"""Decorate."""
@wraps(service)
async def check_permissions(call):
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL
)
# RainMachine services don't interact with specific entities.
# Therefore, we examine _all_ RainMachine entities and if the user
# has permission to control _any_ of them, the user has permission
# to call the service:
en_reg = await hass.helpers.entity_registry.async_get_registry()
rainmachine_entities = [
entity.entity_id for entity in en_reg.entities.values()
if entity.platform == DOMAIN
]
for entity_id in rainmachine_entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
return await service(call)
raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
)
return check_permissions
return decorator
async def async_setup(hass, config):
"""Set up the RainMachine component."""
hass.data[DOMAIN] = {}
@ -198,6 +158,8 @@ async def async_setup_entry(hass, config_entry):
from regenmaschine import login
from regenmaschine.errors import RainMachineError
_verify_domain_control = verify_domain_control(hass, DOMAIN)
websession = aiohttp_client.async_get_clientsession(hass)
try:
@ -238,69 +200,69 @@ async def async_setup_entry(hass, config_entry):
refresh,
timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL]))
@_check_valid_user(hass)
@_verify_domain_control
async def disable_program(call):
"""Disable a program."""
await rainmachine.client.programs.disable(
call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def disable_zone(call):
"""Disable a zone."""
await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def enable_program(call):
"""Enable a program."""
await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def enable_zone(call):
"""Enable a zone."""
await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def pause_watering(call):
"""Pause watering for a set number of seconds."""
await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def start_program(call):
"""Start a particular program."""
await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def start_zone(call):
"""Start a particular zone for a certain amount of time."""
await rainmachine.client.zones.start(
call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def stop_all(call):
"""Stop all watering."""
await rainmachine.client.watering.stop_all()
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def stop_program(call):
"""Stop a program."""
await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def stop_zone(call):
"""Stop a zone."""
await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass)
@_verify_domain_control
async def unpause_watering(call):
"""Unpause watering."""
await rainmachine.client.watering.unpause_all()

View file

@ -6,7 +6,7 @@ from typing import Callable
import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha
@ -19,6 +19,8 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.util.async_ import run_coroutine_threadsafe
from homeassistant.helpers.typing import HomeAssistantType
from .typing import HomeAssistantType
CONF_SERVICE = 'service'
CONF_SERVICE_TEMPLATE = 'service_template'
CONF_SERVICE_ENTITY_ID = 'entity_id'
@ -369,3 +371,47 @@ def async_register_admin_service(
hass.services.async_register(
domain, service, admin_handler, schema
)
@bind_hass
@ha.callback
def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
"""Ensure permission to access any entity under domain in service call."""
def decorator(service_handler: Callable) -> Callable:
"""Decorate."""
if not asyncio.iscoroutinefunction(service_handler):
raise HomeAssistantError(
'Can only decorate async functions.')
async def check_permissions(call):
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL,
user_id=call.context.user_id)
reg = await hass.helpers.entity_registry.async_get_registry()
entities = [
entity.entity_id for entity in reg.entities.values()
if entity.platform == domain
]
for entity_id in entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
return await service_handler(call)
raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
user_id=call.context.user_id,
perm_category=CAT_ENTITIES
)
return check_permissions
return decorator

View file

@ -1,23 +0,0 @@
"""Configuration for Rainmachine tests."""
import pytest
from homeassistant.components.rainmachine.const import DOMAIN
from homeassistant.const import (
CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SSL)
from tests.common import MockConfigEntry
@pytest.fixture(name="config_entry")
def config_entry_fixture():
"""Create a mock RainMachine config entry."""
return MockConfigEntry(
domain=DOMAIN,
title='192.168.1.101',
data={
CONF_IP_ADDRESS: '192.168.1.101',
CONF_PASSWORD: '12345',
CONF_PORT: 8080,
CONF_SSL: True,
CONF_SCAN_INTERVAL: 60,
})

View file

@ -1,41 +0,0 @@
"""Define tests for permissions on RainMachine service calls."""
import asynctest
import pytest
from homeassistant.components.rainmachine.const import DOMAIN
from homeassistant.core import Context
from homeassistant.exceptions import Unauthorized, UnknownUser
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
async def setup_platform(hass, config_entry):
"""Set up the media player platform for testing."""
with asynctest.mock.patch('regenmaschine.login') as mock_login:
mock_client = mock_login.return_value
mock_client.restrictions.current.return_value = mock_coro()
mock_client.restrictions.universal.return_value = mock_coro()
config_entry.add_to_hass(hass)
assert await async_setup_component(hass, DOMAIN)
await hass.async_block_till_done()
async def test_services_authorization(
hass, config_entry, hass_read_only_user):
"""Test that a RainMachine service is halted on incorrect permissions."""
await setup_platform(hass, config_entry)
with pytest.raises(UnknownUser):
await hass.services.async_call(
'rainmachine',
'unpause_watering', {},
blocking=True,
context=Context(user_id='fake_user_id'))
with pytest.raises(Unauthorized):
await hass.services.async_call(
'rainmachine',
'unpause_watering', {},
blocking=True,
context=Context(user_id=hass_read_only_user.id))

View file

@ -38,12 +38,14 @@ def mock_entities():
available=True,
should_poll=False,
supported_features=1,
platform='test_domain',
)
living_room = Mock(
entity_id='light.living_room',
available=True,
should_poll=False,
supported_features=0,
platform='test_domain',
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
@ -461,3 +463,116 @@ async def test_register_admin_service(hass, hass_read_only_user,
))
assert len(calls) == 1
assert calls[0].context.user_id == hass_admin_user.id
async def test_domain_control_not_async(hass, mock_entities):
"""Test domain verification in a service call with an unknown user."""
calls = []
def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with pytest.raises(exceptions.HomeAssistantError):
hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
async def test_domain_control_unknown(hass, mock_entities):
"""Test domain verification in a service call with an unknown user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
with pytest.raises(exceptions.UnknownUser):
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id='fake_user_id'))
assert len(calls) == 0
async def test_domain_control_unauthorized(
hass, hass_read_only_user, mock_entities):
"""Test domain verification in a service call with an unauthorized user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id))
async def test_domain_control_admin(hass, hass_admin_user, mock_entities):
"""Test domain verification in a service call with an admin user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=hass_admin_user.id))
assert len(calls) == 1
async def test_domain_control_no_user(hass, mock_entities):
"""Test domain verification in a service call with no user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=None))
assert len(calls) == 1