diff --git a/homeassistant/components/rainmachine/__init__.py b/homeassistant/components/rainmachine/__init__.py index 6d986fa5c67..2ff5ddcd4aa 100644 --- a/homeassistant/components/rainmachine/__init__.py +++ b/homeassistant/components/rainmachine/__init__.py @@ -1,15 +1,18 @@ """Support for RainMachine devices.""" import logging from datetime import timedelta +from functools import wraps import voluptuous as vol +from homeassistant.auth.permissions.const import POLICY_CONTROL from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import ( ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL, CONF_MONITORED_CONDITIONS, CONF_SWITCHES) -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.exceptions import ( + ConfigEntryNotReady, Unauthorized, UnknownUser) from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import Entity @@ -128,6 +131,44 @@ CONFIG_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) +def _check_valid_user(hass): + """Ensure the user of a service call has proper permissions.""" + def decorator(service): + """Decorate.""" + @wraps(service) + async def check_permissions(call): + """Check user permission and raise before call if unauthorized.""" + if not call.context.user_id: + return + + user = await hass.auth.async_get_user(call.context.user_id) + if user is None: + raise UnknownUser( + context=call.context, + permission=POLICY_CONTROL + ) + + # RainMachine services don't interact with specific entities. + # Therefore, we examine _all_ RainMachine entities and if the user + # has permission to control _any_ of them, the user has permission + # to call the service: + en_reg = await hass.helpers.entity_registry.async_get_registry() + rainmachine_entities = [ + entity.entity_id for entity in en_reg.entities.values() + if entity.platform == DOMAIN + ] + for entity_id in rainmachine_entities: + if user.permissions.check_entity(entity_id, POLICY_CONTROL): + return await service(call) + + raise Unauthorized( + context=call.context, + permission=POLICY_CONTROL, + ) + return check_permissions + return decorator + + async def async_setup(hass, config): """Set up the RainMachine component.""" hass.data[DOMAIN] = {} @@ -197,59 +238,70 @@ async def async_setup_entry(hass, config_entry): refresh, timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL])) - async def disable_program(service): + @_check_valid_user(hass) + async def disable_program(call): """Disable a program.""" await rainmachine.client.programs.disable( - service.data[CONF_PROGRAM_ID]) + call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def disable_zone(service): + @_check_valid_user(hass) + async def disable_zone(call): """Disable a zone.""" - await rainmachine.client.zones.disable(service.data[CONF_ZONE_ID]) + await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - async def enable_program(service): + @_check_valid_user(hass) + async def enable_program(call): """Enable a program.""" - await rainmachine.client.programs.enable(service.data[CONF_PROGRAM_ID]) + await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def enable_zone(service): + @_check_valid_user(hass) + async def enable_zone(call): """Enable a zone.""" - await rainmachine.client.zones.enable(service.data[CONF_ZONE_ID]) + await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - async def pause_watering(service): + @_check_valid_user(hass) + async def pause_watering(call): """Pause watering for a set number of seconds.""" - await rainmachine.client.watering.pause_all(service.data[CONF_SECONDS]) + await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def start_program(service): + @_check_valid_user(hass) + async def start_program(call): """Start a particular program.""" - await rainmachine.client.programs.start(service.data[CONF_PROGRAM_ID]) + await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def start_zone(service): + @_check_valid_user(hass) + async def start_zone(call): """Start a particular zone for a certain amount of time.""" await rainmachine.client.zones.start( - service.data[CONF_ZONE_ID], service.data[CONF_ZONE_RUN_TIME]) + call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - async def stop_all(service): + @_check_valid_user(hass) + async def stop_all(call): """Stop all watering.""" await rainmachine.client.watering.stop_all() async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def stop_program(service): + @_check_valid_user(hass) + async def stop_program(call): """Stop a program.""" - await rainmachine.client.programs.stop(service.data[CONF_PROGRAM_ID]) + await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - async def stop_zone(service): + @_check_valid_user(hass) + async def stop_zone(call): """Stop a zone.""" - await rainmachine.client.zones.stop(service.data[CONF_ZONE_ID]) + await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - async def unpause_watering(service): + @_check_valid_user(hass) + async def unpause_watering(call): """Unpause watering.""" await rainmachine.client.watering.unpause_all() async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) diff --git a/tests/components/rainmachine/conftest.py b/tests/components/rainmachine/conftest.py new file mode 100644 index 00000000000..fdc81151995 --- /dev/null +++ b/tests/components/rainmachine/conftest.py @@ -0,0 +1,23 @@ +"""Configuration for Rainmachine tests.""" +import pytest + +from homeassistant.components.rainmachine.const import DOMAIN +from homeassistant.const import ( + CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SSL) + +from tests.common import MockConfigEntry + + +@pytest.fixture(name="config_entry") +def config_entry_fixture(): + """Create a mock RainMachine config entry.""" + return MockConfigEntry( + domain=DOMAIN, + title='192.168.1.101', + data={ + CONF_IP_ADDRESS: '192.168.1.101', + CONF_PASSWORD: '12345', + CONF_PORT: 8080, + CONF_SSL: True, + CONF_SCAN_INTERVAL: 60, + }) diff --git a/tests/components/rainmachine/test_service_permissions.py b/tests/components/rainmachine/test_service_permissions.py new file mode 100644 index 00000000000..caa84337517 --- /dev/null +++ b/tests/components/rainmachine/test_service_permissions.py @@ -0,0 +1,41 @@ +"""Define tests for permissions on RainMachine service calls.""" +import asynctest +import pytest + +from homeassistant.components.rainmachine.const import DOMAIN +from homeassistant.core import Context +from homeassistant.exceptions import Unauthorized, UnknownUser +from homeassistant.setup import async_setup_component + +from tests.common import mock_coro + + +async def setup_platform(hass, config_entry): + """Set up the media player platform for testing.""" + with asynctest.mock.patch('regenmaschine.login') as mock_login: + mock_client = mock_login.return_value + mock_client.restrictions.current.return_value = mock_coro() + mock_client.restrictions.universal.return_value = mock_coro() + config_entry.add_to_hass(hass) + assert await async_setup_component(hass, DOMAIN) + await hass.async_block_till_done() + + +async def test_services_authorization( + hass, config_entry, hass_read_only_user): + """Test that a RainMachine service is halted on incorrect permissions.""" + await setup_platform(hass, config_entry) + + with pytest.raises(UnknownUser): + await hass.services.async_call( + 'rainmachine', + 'unpause_watering', {}, + blocking=True, + context=Context(user_id='fake_user_id')) + + with pytest.raises(Unauthorized): + await hass.services.async_call( + 'rainmachine', + 'unpause_watering', {}, + blocking=True, + context=Context(user_id=hass_read_only_user.id))