Use @require_admin decorator (#98061)

Co-authored-by: Robert Resch <robert@resch.dev>
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Robert Resch 2023-08-14 15:07:20 +02:00 committed by GitHub
parent 525f39fe28
commit b0f68f1ef3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 136 additions and 66 deletions

View file

@ -11,7 +11,7 @@ import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.bootstrap import DATA_LOGGING
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
MATCH_ALL,
@ -110,10 +110,9 @@ class APIEventStream(HomeAssistantView):
url = URL_API_STREAM
name = "api:stream"
@require_admin
async def get(self, request):
"""Provide a streaming interface for the event bus."""
if not request["hass_user"].is_admin:
raise Unauthorized()
hass = request.app["hass"]
stop_obj = object()
to_write = asyncio.Queue()
@ -278,10 +277,9 @@ class APIEventView(HomeAssistantView):
url = "/api/events/{event_type}"
name = "api:event"
@require_admin
async def post(self, request, event_type):
"""Fire events."""
if not request["hass_user"].is_admin:
raise Unauthorized()
body = await request.text()
try:
event_data = json_loads(body) if body else None
@ -385,10 +383,9 @@ class APITemplateView(HomeAssistantView):
url = URL_API_TEMPLATE
name = "api:template"
@require_admin
async def post(self, request):
"""Render a template."""
if not request["hass_user"].is_admin:
raise Unauthorized()
try:
data = await request.json()
tpl = _cached_template(data["template"], request.app["hass"])
@ -405,10 +402,9 @@ class APIErrorLog(HomeAssistantView):
url = URL_API_ERROR_LOG
name = "api:error_log"
@require_admin
async def get(self, request):
"""Retrieve API error log."""
if not request["hass_user"].is_admin:
raise Unauthorized()
return web.FileResponse(request.app["hass"].data[DATA_LOGGING])

View file

@ -11,7 +11,7 @@ import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import DependencyError, Unauthorized
import homeassistant.helpers.config_validation as cv
@ -138,12 +138,11 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
"""Not implemented."""
raise aiohttp.web_exceptions.HTTPMethodNotAllowed("GET", ["POST"])
# pylint: disable=arguments-differ
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
async def post(self, request):
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
# pylint: disable=no-value-for-parameter
try:
return await super().post(request)
@ -164,19 +163,18 @@ class ConfigManagerFlowResourceView(FlowManagerResourceView):
url = "/api/config/config_entries/flow/{flow_id}"
name = "api:config:config_entries:flow:resource"
async def get(self, request, flow_id):
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
async def get(self, request, /, flow_id):
"""Get the current state of a data_entry_flow."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
return await super().get(request, flow_id)
# pylint: disable=arguments-differ
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
async def post(self, request, flow_id):
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id)
@ -206,15 +204,14 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
url = "/api/config/config_entries/options/flow"
name = "api:config:config_entries:option:flow"
# pylint: disable=arguments-differ
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
async def post(self, request):
"""Handle a POST request.
handler in request is entry_id.
"""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request)
@ -225,19 +222,18 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
url = "/api/config/config_entries/options/flow/{flow_id}"
name = "api:config:config_entries:options:flow:resource"
async def get(self, request, flow_id):
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
async def get(self, request, /, flow_id):
"""Get the current state of a data_entry_flow."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
return await super().get(request, flow_id)
# pylint: disable=arguments-differ
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
async def post(self, request, flow_id):
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id)

View file

@ -1,8 +1,9 @@
"""Decorators for the Home Assistant API."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Concatenate, ParamSpec, TypeVar
from collections.abc import Callable, Coroutine
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar, overload
from aiohttp.web import Request, Response
@ -12,20 +13,61 @@ from .view import HomeAssistantView
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_FuncType = Callable[
Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, Response]
]
@overload
def require_admin(
_func: None = None,
*,
error: Unauthorized | None = None,
) -> Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]:
...
@overload
def require_admin(
_func: _FuncType[_HomeAssistantViewT, _P],
) -> _FuncType[_HomeAssistantViewT, _P]:
...
def require_admin(
func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]:
_func: _FuncType[_HomeAssistantViewT, _P] | None = None,
*,
error: Unauthorized | None = None,
) -> (
Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]
| _FuncType[_HomeAssistantViewT, _P]
):
"""Home Assistant API decorator to require user to be an admin."""
async def with_admin(
self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Check admin and call function."""
if not request["hass_user"].is_admin:
raise Unauthorized()
def decorator_require_admin(
func: _FuncType[_HomeAssistantViewT, _P]
) -> _FuncType[_HomeAssistantViewT, _P]:
"""Wrap the provided with_admin function."""
return await func(self, request, *args, **kwargs)
@wraps(func)
async def with_admin(
self: _HomeAssistantViewT,
request: Request,
*args: _P.args,
**kwargs: _P.kwargs,
) -> Response:
"""Check admin and call function."""
if not request["hass_user"].is_admin:
raise error or Unauthorized()
return with_admin
return await func(self, request, *args, **kwargs)
return with_admin
# See if we're being called as @require_admin or @require_admin().
if _func is None:
# We're called with brackets.
return decorator_require_admin
# We're called as @require_admin without brackets.
return decorator_require_admin(_func)

View file

@ -12,9 +12,9 @@ from aiohttp.web_request import FileField
import voluptuous as vol
from homeassistant.components import http, websocket_api
from homeassistant.components.http import require_admin
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
@ -254,11 +254,9 @@ class UploadMediaView(http.HomeAssistantView):
}
)
@require_admin
async def post(self, request: web.Request) -> web.Response:
"""Handle upload."""
if not request["hass_user"].is_admin:
raise Unauthorized()
# Increase max payload
request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access

View file

@ -12,6 +12,7 @@ from homeassistant import data_entry_flow
from homeassistant.auth.permissions.const import POLICY_EDIT
from homeassistant.components import websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.decorators import require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers.data_entry_flow import (
@ -88,6 +89,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
url = "/api/repairs/issues/fix"
name = "api:repairs:issues:fix"
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
@RequestDataValidator(
vol.Schema(
{
@ -99,9 +101,6 @@ class RepairsFlowIndexView(FlowManagerIndexView):
)
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(permission=POLICY_EDIT)
try:
result = await self._flow_mgr.async_init(
data["handler"],
@ -125,18 +124,12 @@ class RepairsFlowResourceView(FlowManagerResourceView):
url = "/api/repairs/issues/fix/{flow_id}"
name = "api:repairs:issues:fix:resource"
async def get(self, request: web.Request, flow_id: str) -> web.Response:
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
if not request["hass_user"].is_admin:
raise Unauthorized(permission=POLICY_EDIT)
return await super().get(request, flow_id)
# pylint: disable=arguments-differ
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
async def post(self, request: web.Request, flow_id: str) -> web.Response:
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id)

View file

@ -55,6 +55,7 @@ from zwave_js_server.model.utils import (
from zwave_js_server.util.node import async_set_config_parameter
from homeassistant.components import websocket_api
from homeassistant.components.http import require_admin
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.components.websocket_api import (
ERR_INVALID_FORMAT,
@ -65,7 +66,6 @@ from homeassistant.components.websocket_api import (
)
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.device_registry as dr
@ -2149,10 +2149,9 @@ class FirmwareUploadView(HomeAssistantView):
super().__init__()
self._dev_reg = dev_reg
@require_admin
async def post(self, request: web.Request, device_id: str) -> web.Response:
"""Handle upload."""
if not request["hass_user"].is_admin:
raise Unauthorized()
hass = request.app["hass"]
try:

View file

@ -90,7 +90,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
class FlowManagerResourceView(_BaseFlowManagerView):
"""View to interact with the flow manager."""
async def get(self, request: web.Request, flow_id: str) -> web.Response:
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
try:
result = await self._flow_mgr.async_configure(flow_id)

View file

@ -825,6 +825,52 @@ async def test_options_flow(hass: HomeAssistant, client) -> None:
}
@pytest.mark.parametrize(
("endpoint", "method"),
[
("/api/config/config_entries/options/flow", "post"),
("/api/config/config_entries/options/flow/1", "get"),
("/api/config/config_entries/options/flow/1", "post"),
],
)
async def test_options_flow_unauth(
hass: HomeAssistant, client, hass_admin_user: MockUser, endpoint: str, method: str
) -> None:
"""Test unauthorized on options flow."""
class TestFlow(core_ce.ConfigFlow):
@staticmethod
@callback
def async_get_options_flow(config_entry):
class OptionsFlowHandler(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
schema = OrderedDict()
schema[vol.Required("enabled")] = bool
return self.async_show_form(
step_id="user",
data_schema=schema,
description_placeholders={"enabled": "Set to true to be true"},
)
return OptionsFlowHandler()
mock_integration(hass, MockModule("test"))
mock_entity_platform(hass, "config_flow.test", None)
MockConfigEntry(
domain="test",
entry_id="test1",
source="bla",
).add_to_hass(hass)
entry = hass.config_entries.async_entries()[0]
hass_admin_user.groups = []
with patch.dict(HANDLERS, {"test": TestFlow}):
resp = await getattr(client, method)(endpoint, json={"handler": entry.entry_id})
assert resp.status == HTTPStatus.UNAUTHORIZED
async def test_two_step_options_flow(hass: HomeAssistant, client) -> None:
"""Test we can finish a two step options flow."""
mock_integration(