Create decorator to check service permissions (#22667)

* Create decorator to check service permissions

* Typing

* Linting

* Member comments

* Linting

* Member comments

* Updated import

* Owner comments

* Linting

* Linting

* More work

* Fixed tests

* Removed service helper tests in RainMachine

* Linting

* Owner comments

* Linting

* Owner comments

Co-Authored-By: bachya <bachya1208@gmail.com>
This commit is contained in:
Aaron Bach 2019-04-13 13:54:29 -06:00 committed by GitHub
parent 7a6950fd72
commit fc481133e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 177 additions and 118 deletions

View file

@ -2,22 +2,20 @@
import asyncio import asyncio
import logging import logging
from datetime import timedelta from datetime import timedelta
from functools import wraps
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD, ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD,
CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL,
CONF_MONITORED_CONDITIONS, CONF_SWITCHES) CONF_MONITORED_CONDITIONS, CONF_SWITCHES)
from homeassistant.exceptions import ( from homeassistant.exceptions import ConfigEntryNotReady
ConfigEntryNotReady, Unauthorized, UnknownUser)
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.service import verify_domain_control
from .config_flow import configured_instances from .config_flow import configured_instances
from .const import ( from .const import (
@ -131,44 +129,6 @@ CONFIG_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
def _check_valid_user(hass):
"""Ensure the user of a service call has proper permissions."""
def decorator(service):
"""Decorate."""
@wraps(service)
async def check_permissions(call):
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL
)
# RainMachine services don't interact with specific entities.
# Therefore, we examine _all_ RainMachine entities and if the user
# has permission to control _any_ of them, the user has permission
# to call the service:
en_reg = await hass.helpers.entity_registry.async_get_registry()
rainmachine_entities = [
entity.entity_id for entity in en_reg.entities.values()
if entity.platform == DOMAIN
]
for entity_id in rainmachine_entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
return await service(call)
raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
)
return check_permissions
return decorator
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the RainMachine component.""" """Set up the RainMachine component."""
hass.data[DOMAIN] = {} hass.data[DOMAIN] = {}
@ -198,6 +158,8 @@ async def async_setup_entry(hass, config_entry):
from regenmaschine import login from regenmaschine import login
from regenmaschine.errors import RainMachineError from regenmaschine.errors import RainMachineError
_verify_domain_control = verify_domain_control(hass, DOMAIN)
websession = aiohttp_client.async_get_clientsession(hass) websession = aiohttp_client.async_get_clientsession(hass)
try: try:
@ -238,69 +200,69 @@ async def async_setup_entry(hass, config_entry):
refresh, refresh,
timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL])) timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL]))
@_check_valid_user(hass) @_verify_domain_control
async def disable_program(call): async def disable_program(call):
"""Disable a program.""" """Disable a program."""
await rainmachine.client.programs.disable( await rainmachine.client.programs.disable(
call.data[CONF_PROGRAM_ID]) call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def disable_zone(call): async def disable_zone(call):
"""Disable a zone.""" """Disable a zone."""
await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID]) await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def enable_program(call): async def enable_program(call):
"""Enable a program.""" """Enable a program."""
await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID]) await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def enable_zone(call): async def enable_zone(call):
"""Enable a zone.""" """Enable a zone."""
await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID]) await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def pause_watering(call): async def pause_watering(call):
"""Pause watering for a set number of seconds.""" """Pause watering for a set number of seconds."""
await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS]) await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def start_program(call): async def start_program(call):
"""Start a particular program.""" """Start a particular program."""
await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID]) await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def start_zone(call): async def start_zone(call):
"""Start a particular zone for a certain amount of time.""" """Start a particular zone for a certain amount of time."""
await rainmachine.client.zones.start( await rainmachine.client.zones.start(
call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME]) call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def stop_all(call): async def stop_all(call):
"""Stop all watering.""" """Stop all watering."""
await rainmachine.client.watering.stop_all() await rainmachine.client.watering.stop_all()
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def stop_program(call): async def stop_program(call):
"""Stop a program.""" """Stop a program."""
await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID]) await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID])
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def stop_zone(call): async def stop_zone(call):
"""Stop a zone.""" """Stop a zone."""
await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID]) await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID])
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
@_check_valid_user(hass) @_verify_domain_control
async def unpause_watering(call): async def unpause_watering(call):
"""Unpause watering.""" """Unpause watering."""
await rainmachine.client.watering.unpause_all() await rainmachine.client.watering.unpause_all()

View file

@ -6,7 +6,7 @@ from typing import Callable
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID) ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha import homeassistant.core as ha
@ -19,6 +19,8 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.util.async_ import run_coroutine_threadsafe from homeassistant.util.async_ import run_coroutine_threadsafe
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from .typing import HomeAssistantType
CONF_SERVICE = 'service' CONF_SERVICE = 'service'
CONF_SERVICE_TEMPLATE = 'service_template' CONF_SERVICE_TEMPLATE = 'service_template'
CONF_SERVICE_ENTITY_ID = 'entity_id' CONF_SERVICE_ENTITY_ID = 'entity_id'
@ -369,3 +371,47 @@ def async_register_admin_service(
hass.services.async_register( hass.services.async_register(
domain, service, admin_handler, schema domain, service, admin_handler, schema
) )
@bind_hass
@ha.callback
def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
"""Ensure permission to access any entity under domain in service call."""
def decorator(service_handler: Callable) -> Callable:
"""Decorate."""
if not asyncio.iscoroutinefunction(service_handler):
raise HomeAssistantError(
'Can only decorate async functions.')
async def check_permissions(call):
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL,
user_id=call.context.user_id)
reg = await hass.helpers.entity_registry.async_get_registry()
entities = [
entity.entity_id for entity in reg.entities.values()
if entity.platform == domain
]
for entity_id in entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
return await service_handler(call)
raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
user_id=call.context.user_id,
perm_category=CAT_ENTITIES
)
return check_permissions
return decorator

View file

@ -1,23 +0,0 @@
"""Configuration for Rainmachine tests."""
import pytest
from homeassistant.components.rainmachine.const import DOMAIN
from homeassistant.const import (
CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SSL)
from tests.common import MockConfigEntry
@pytest.fixture(name="config_entry")
def config_entry_fixture():
"""Create a mock RainMachine config entry."""
return MockConfigEntry(
domain=DOMAIN,
title='192.168.1.101',
data={
CONF_IP_ADDRESS: '192.168.1.101',
CONF_PASSWORD: '12345',
CONF_PORT: 8080,
CONF_SSL: True,
CONF_SCAN_INTERVAL: 60,
})

View file

@ -1,41 +0,0 @@
"""Define tests for permissions on RainMachine service calls."""
import asynctest
import pytest
from homeassistant.components.rainmachine.const import DOMAIN
from homeassistant.core import Context
from homeassistant.exceptions import Unauthorized, UnknownUser
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
async def setup_platform(hass, config_entry):
"""Set up the media player platform for testing."""
with asynctest.mock.patch('regenmaschine.login') as mock_login:
mock_client = mock_login.return_value
mock_client.restrictions.current.return_value = mock_coro()
mock_client.restrictions.universal.return_value = mock_coro()
config_entry.add_to_hass(hass)
assert await async_setup_component(hass, DOMAIN)
await hass.async_block_till_done()
async def test_services_authorization(
hass, config_entry, hass_read_only_user):
"""Test that a RainMachine service is halted on incorrect permissions."""
await setup_platform(hass, config_entry)
with pytest.raises(UnknownUser):
await hass.services.async_call(
'rainmachine',
'unpause_watering', {},
blocking=True,
context=Context(user_id='fake_user_id'))
with pytest.raises(Unauthorized):
await hass.services.async_call(
'rainmachine',
'unpause_watering', {},
blocking=True,
context=Context(user_id=hass_read_only_user.id))

View file

@ -38,12 +38,14 @@ def mock_entities():
available=True, available=True,
should_poll=False, should_poll=False,
supported_features=1, supported_features=1,
platform='test_domain',
) )
living_room = Mock( living_room = Mock(
entity_id='light.living_room', entity_id='light.living_room',
available=True, available=True,
should_poll=False, should_poll=False,
supported_features=0, supported_features=0,
platform='test_domain',
) )
entities = OrderedDict() entities = OrderedDict()
entities[kitchen.entity_id] = kitchen entities[kitchen.entity_id] = kitchen
@ -461,3 +463,116 @@ async def test_register_admin_service(hass, hass_read_only_user,
)) ))
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context.user_id == hass_admin_user.id assert calls[0].context.user_id == hass_admin_user.id
async def test_domain_control_not_async(hass, mock_entities):
"""Test domain verification in a service call with an unknown user."""
calls = []
def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with pytest.raises(exceptions.HomeAssistantError):
hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
async def test_domain_control_unknown(hass, mock_entities):
"""Test domain verification in a service call with an unknown user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
with pytest.raises(exceptions.UnknownUser):
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id='fake_user_id'))
assert len(calls) == 0
async def test_domain_control_unauthorized(
hass, hass_read_only_user, mock_entities):
"""Test domain verification in a service call with an unauthorized user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id))
async def test_domain_control_admin(hass, hass_admin_user, mock_entities):
"""Test domain verification in a service call with an admin user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=hass_admin_user.id))
assert len(calls) == 1
async def test_domain_control_no_user(hass, mock_entities):
"""Test domain verification in a service call with no user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch('homeassistant.helpers.entity_registry.async_get_registry',
return_value=mock_coro(Mock(entities=mock_entities))):
protected_mock_service = hass.helpers.service.verify_domain_control(
'test_domain')(mock_service_log)
hass.services.async_register(
'test_domain', 'test_service', protected_mock_service, schema=None)
await hass.services.async_call(
'test_domain',
'test_service', {},
blocking=True,
context=ha.Context(user_id=None))
assert len(calls) == 1