Create decorator to check service permissions (#22667)
* Create decorator to check service permissions * Typing * Linting * Member comments * Linting * Member comments * Updated import * Owner comments * Linting * Linting * More work * Fixed tests * Removed service helper tests in RainMachine * Linting * Owner comments * Linting * Owner comments Co-Authored-By: bachya <bachya1208@gmail.com>
This commit is contained in:
parent
7a6950fd72
commit
fc481133e7
5 changed files with 177 additions and 118 deletions
|
@ -2,22 +2,20 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from functools import wraps
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.permissions.const import POLICY_CONTROL
|
||||
from homeassistant.config_entries import SOURCE_IMPORT
|
||||
from homeassistant.const import (
|
||||
ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD,
|
||||
CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL,
|
||||
CONF_MONITORED_CONDITIONS, CONF_SWITCHES)
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryNotReady, Unauthorized, UnknownUser)
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import aiohttp_client, config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.helpers.service import verify_domain_control
|
||||
|
||||
from .config_flow import configured_instances
|
||||
from .const import (
|
||||
|
@ -131,44 +129,6 @@ CONFIG_SCHEMA = vol.Schema({
|
|||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
def _check_valid_user(hass):
|
||||
"""Ensure the user of a service call has proper permissions."""
|
||||
def decorator(service):
|
||||
"""Decorate."""
|
||||
@wraps(service)
|
||||
async def check_permissions(call):
|
||||
"""Check user permission and raise before call if unauthorized."""
|
||||
if not call.context.user_id:
|
||||
return
|
||||
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(
|
||||
context=call.context,
|
||||
permission=POLICY_CONTROL
|
||||
)
|
||||
|
||||
# RainMachine services don't interact with specific entities.
|
||||
# Therefore, we examine _all_ RainMachine entities and if the user
|
||||
# has permission to control _any_ of them, the user has permission
|
||||
# to call the service:
|
||||
en_reg = await hass.helpers.entity_registry.async_get_registry()
|
||||
rainmachine_entities = [
|
||||
entity.entity_id for entity in en_reg.entities.values()
|
||||
if entity.platform == DOMAIN
|
||||
]
|
||||
for entity_id in rainmachine_entities:
|
||||
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
|
||||
return await service(call)
|
||||
|
||||
raise Unauthorized(
|
||||
context=call.context,
|
||||
permission=POLICY_CONTROL,
|
||||
)
|
||||
return check_permissions
|
||||
return decorator
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the RainMachine component."""
|
||||
hass.data[DOMAIN] = {}
|
||||
|
@ -198,6 +158,8 @@ async def async_setup_entry(hass, config_entry):
|
|||
from regenmaschine import login
|
||||
from regenmaschine.errors import RainMachineError
|
||||
|
||||
_verify_domain_control = verify_domain_control(hass, DOMAIN)
|
||||
|
||||
websession = aiohttp_client.async_get_clientsession(hass)
|
||||
|
||||
try:
|
||||
|
@ -238,69 +200,69 @@ async def async_setup_entry(hass, config_entry):
|
|||
refresh,
|
||||
timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL]))
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def disable_program(call):
|
||||
"""Disable a program."""
|
||||
await rainmachine.client.programs.disable(
|
||||
call.data[CONF_PROGRAM_ID])
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def disable_zone(call):
|
||||
"""Disable a zone."""
|
||||
await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID])
|
||||
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def enable_program(call):
|
||||
"""Enable a program."""
|
||||
await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID])
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def enable_zone(call):
|
||||
"""Enable a zone."""
|
||||
await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID])
|
||||
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def pause_watering(call):
|
||||
"""Pause watering for a set number of seconds."""
|
||||
await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS])
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def start_program(call):
|
||||
"""Start a particular program."""
|
||||
await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID])
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def start_zone(call):
|
||||
"""Start a particular zone for a certain amount of time."""
|
||||
await rainmachine.client.zones.start(
|
||||
call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME])
|
||||
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def stop_all(call):
|
||||
"""Stop all watering."""
|
||||
await rainmachine.client.watering.stop_all()
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def stop_program(call):
|
||||
"""Stop a program."""
|
||||
await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID])
|
||||
async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def stop_zone(call):
|
||||
"""Stop a zone."""
|
||||
await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID])
|
||||
async_dispatcher_send(hass, ZONE_UPDATE_TOPIC)
|
||||
|
||||
@_check_valid_user(hass)
|
||||
@_verify_domain_control
|
||||
async def unpause_watering(call):
|
||||
"""Unpause watering."""
|
||||
await rainmachine.client.watering.unpause_all()
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Callable
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.permissions.const import POLICY_CONTROL
|
||||
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
|
||||
import homeassistant.core as ha
|
||||
|
@ -19,6 +19,8 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.util.async_ import run_coroutine_threadsafe
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
CONF_SERVICE = 'service'
|
||||
CONF_SERVICE_TEMPLATE = 'service_template'
|
||||
CONF_SERVICE_ENTITY_ID = 'entity_id'
|
||||
|
@ -369,3 +371,47 @@ def async_register_admin_service(
|
|||
hass.services.async_register(
|
||||
domain, service, admin_handler, schema
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
@ha.callback
|
||||
def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
|
||||
"""Ensure permission to access any entity under domain in service call."""
|
||||
def decorator(service_handler: Callable) -> Callable:
|
||||
"""Decorate."""
|
||||
if not asyncio.iscoroutinefunction(service_handler):
|
||||
raise HomeAssistantError(
|
||||
'Can only decorate async functions.')
|
||||
|
||||
async def check_permissions(call):
|
||||
"""Check user permission and raise before call if unauthorized."""
|
||||
if not call.context.user_id:
|
||||
return await service_handler(call)
|
||||
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(
|
||||
context=call.context,
|
||||
permission=POLICY_CONTROL,
|
||||
user_id=call.context.user_id)
|
||||
|
||||
reg = await hass.helpers.entity_registry.async_get_registry()
|
||||
entities = [
|
||||
entity.entity_id for entity in reg.entities.values()
|
||||
if entity.platform == domain
|
||||
]
|
||||
|
||||
for entity_id in entities:
|
||||
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
|
||||
return await service_handler(call)
|
||||
|
||||
raise Unauthorized(
|
||||
context=call.context,
|
||||
permission=POLICY_CONTROL,
|
||||
user_id=call.context.user_id,
|
||||
perm_category=CAT_ENTITIES
|
||||
)
|
||||
|
||||
return check_permissions
|
||||
|
||||
return decorator
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
"""Configuration for Rainmachine tests."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.rainmachine.const import DOMAIN
|
||||
from homeassistant.const import (
|
||||
CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SSL)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture(name="config_entry")
|
||||
def config_entry_fixture():
|
||||
"""Create a mock RainMachine config entry."""
|
||||
return MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
title='192.168.1.101',
|
||||
data={
|
||||
CONF_IP_ADDRESS: '192.168.1.101',
|
||||
CONF_PASSWORD: '12345',
|
||||
CONF_PORT: 8080,
|
||||
CONF_SSL: True,
|
||||
CONF_SCAN_INTERVAL: 60,
|
||||
})
|
|
@ -1,41 +0,0 @@
|
|||
"""Define tests for permissions on RainMachine service calls."""
|
||||
import asynctest
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.rainmachine.const import DOMAIN
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.exceptions import Unauthorized, UnknownUser
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
async def setup_platform(hass, config_entry):
|
||||
"""Set up the media player platform for testing."""
|
||||
with asynctest.mock.patch('regenmaschine.login') as mock_login:
|
||||
mock_client = mock_login.return_value
|
||||
mock_client.restrictions.current.return_value = mock_coro()
|
||||
mock_client.restrictions.universal.return_value = mock_coro()
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await async_setup_component(hass, DOMAIN)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_services_authorization(
|
||||
hass, config_entry, hass_read_only_user):
|
||||
"""Test that a RainMachine service is halted on incorrect permissions."""
|
||||
await setup_platform(hass, config_entry)
|
||||
|
||||
with pytest.raises(UnknownUser):
|
||||
await hass.services.async_call(
|
||||
'rainmachine',
|
||||
'unpause_watering', {},
|
||||
blocking=True,
|
||||
context=Context(user_id='fake_user_id'))
|
||||
|
||||
with pytest.raises(Unauthorized):
|
||||
await hass.services.async_call(
|
||||
'rainmachine',
|
||||
'unpause_watering', {},
|
||||
blocking=True,
|
||||
context=Context(user_id=hass_read_only_user.id))
|
|
@ -38,12 +38,14 @@ def mock_entities():
|
|||
available=True,
|
||||
should_poll=False,
|
||||
supported_features=1,
|
||||
platform='test_domain',
|
||||
)
|
||||
living_room = Mock(
|
||||
entity_id='light.living_room',
|
||||
available=True,
|
||||
should_poll=False,
|
||||
supported_features=0,
|
||||
platform='test_domain',
|
||||
)
|
||||
entities = OrderedDict()
|
||||
entities[kitchen.entity_id] = kitchen
|
||||
|
@ -461,3 +463,116 @@ async def test_register_admin_service(hass, hass_read_only_user,
|
|||
))
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context.user_id == hass_admin_user.id
|
||||
|
||||
|
||||
async def test_domain_control_not_async(hass, mock_entities):
|
||||
"""Test domain verification in a service call with an unknown user."""
|
||||
calls = []
|
||||
|
||||
def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with pytest.raises(exceptions.HomeAssistantError):
|
||||
hass.helpers.service.verify_domain_control(
|
||||
'test_domain')(mock_service_log)
|
||||
|
||||
|
||||
async def test_domain_control_unknown(hass, mock_entities):
|
||||
"""Test domain verification in a service call with an unknown user."""
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch('homeassistant.helpers.entity_registry.async_get_registry',
|
||||
return_value=mock_coro(Mock(entities=mock_entities))):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
'test_domain')(mock_service_log)
|
||||
|
||||
hass.services.async_register(
|
||||
'test_domain', 'test_service', protected_mock_service, schema=None)
|
||||
|
||||
with pytest.raises(exceptions.UnknownUser):
|
||||
await hass.services.async_call(
|
||||
'test_domain',
|
||||
'test_service', {},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id='fake_user_id'))
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
async def test_domain_control_unauthorized(
|
||||
hass, hass_read_only_user, mock_entities):
|
||||
"""Test domain verification in a service call with an unauthorized user."""
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch('homeassistant.helpers.entity_registry.async_get_registry',
|
||||
return_value=mock_coro(Mock(entities=mock_entities))):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
'test_domain')(mock_service_log)
|
||||
|
||||
hass.services.async_register(
|
||||
'test_domain', 'test_service', protected_mock_service, schema=None)
|
||||
|
||||
with pytest.raises(exceptions.Unauthorized):
|
||||
await hass.services.async_call(
|
||||
'test_domain',
|
||||
'test_service', {},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_read_only_user.id))
|
||||
|
||||
|
||||
async def test_domain_control_admin(hass, hass_admin_user, mock_entities):
|
||||
"""Test domain verification in a service call with an admin user."""
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch('homeassistant.helpers.entity_registry.async_get_registry',
|
||||
return_value=mock_coro(Mock(entities=mock_entities))):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
'test_domain')(mock_service_log)
|
||||
|
||||
hass.services.async_register(
|
||||
'test_domain', 'test_service', protected_mock_service, schema=None)
|
||||
|
||||
await hass.services.async_call(
|
||||
'test_domain',
|
||||
'test_service', {},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_admin_user.id))
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_domain_control_no_user(hass, mock_entities):
|
||||
"""Test domain verification in a service call with no user."""
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch('homeassistant.helpers.entity_registry.async_get_registry',
|
||||
return_value=mock_coro(Mock(entities=mock_entities))):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
'test_domain')(mock_service_log)
|
||||
|
||||
hass.services.async_register(
|
||||
'test_domain', 'test_service', protected_mock_service, schema=None)
|
||||
|
||||
await hass.services.async_call(
|
||||
'test_domain',
|
||||
'test_service', {},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=None))
|
||||
|
||||
assert len(calls) == 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue