Use @require_admin decorator (#98061)
Co-authored-by: Robert Resch <robert@resch.dev> Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
525f39fe28
commit
b0f68f1ef3
8 changed files with 136 additions and 66 deletions
|
@ -11,7 +11,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.auth.permissions.const import POLICY_READ
|
||||
from homeassistant.bootstrap import DATA_LOGGING
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http import HomeAssistantView, require_admin
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
MATCH_ALL,
|
||||
|
@ -110,10 +110,9 @@ class APIEventStream(HomeAssistantView):
|
|||
url = URL_API_STREAM
|
||||
name = "api:stream"
|
||||
|
||||
@require_admin
|
||||
async def get(self, request):
|
||||
"""Provide a streaming interface for the event bus."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
hass = request.app["hass"]
|
||||
stop_obj = object()
|
||||
to_write = asyncio.Queue()
|
||||
|
@ -278,10 +277,9 @@ class APIEventView(HomeAssistantView):
|
|||
url = "/api/events/{event_type}"
|
||||
name = "api:event"
|
||||
|
||||
@require_admin
|
||||
async def post(self, request, event_type):
|
||||
"""Fire events."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
body = await request.text()
|
||||
try:
|
||||
event_data = json_loads(body) if body else None
|
||||
|
@ -385,10 +383,9 @@ class APITemplateView(HomeAssistantView):
|
|||
url = URL_API_TEMPLATE
|
||||
name = "api:template"
|
||||
|
||||
@require_admin
|
||||
async def post(self, request):
|
||||
"""Render a template."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
try:
|
||||
data = await request.json()
|
||||
tpl = _cached_template(data["template"], request.app["hass"])
|
||||
|
@ -405,10 +402,9 @@ class APIErrorLog(HomeAssistantView):
|
|||
url = URL_API_ERROR_LOG
|
||||
name = "api:error_log"
|
||||
|
||||
@require_admin
|
||||
async def get(self, request):
|
||||
"""Retrieve API error log."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
return web.FileResponse(request.app["hass"].data[DATA_LOGGING])
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import voluptuous as vol
|
|||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http import HomeAssistantView, require_admin
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import DependencyError, Unauthorized
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
@ -138,12 +138,11 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
|
|||
"""Not implemented."""
|
||||
raise aiohttp.web_exceptions.HTTPMethodNotAllowed("GET", ["POST"])
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
)
|
||||
async def post(self, request):
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
try:
|
||||
return await super().post(request)
|
||||
|
@ -164,19 +163,18 @@ class ConfigManagerFlowResourceView(FlowManagerResourceView):
|
|||
url = "/api/config/config_entries/flow/{flow_id}"
|
||||
name = "api:config:config_entries:flow:resource"
|
||||
|
||||
async def get(self, request, flow_id):
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
)
|
||||
async def get(self, request, /, flow_id):
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
|
||||
return await super().get(request, flow_id)
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
)
|
||||
async def post(self, request, flow_id):
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request, flow_id)
|
||||
|
||||
|
@ -206,15 +204,14 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
|
|||
url = "/api/config/config_entries/options/flow"
|
||||
name = "api:config:config_entries:option:flow"
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
)
|
||||
async def post(self, request):
|
||||
"""Handle a POST request.
|
||||
|
||||
handler in request is entry_id.
|
||||
"""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request)
|
||||
|
||||
|
@ -225,19 +222,18 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
|
|||
url = "/api/config/config_entries/options/flow/{flow_id}"
|
||||
name = "api:config:config_entries:options:flow:resource"
|
||||
|
||||
async def get(self, request, flow_id):
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
)
|
||||
async def get(self, request, /, flow_id):
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
return await super().get(request, flow_id)
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@require_admin(
|
||||
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
)
|
||||
async def post(self, request, flow_id):
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request, flow_id)
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
"""Decorators for the Home Assistant API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
|
||||
from aiohttp.web import Request, Response
|
||||
|
||||
|
@ -12,20 +13,61 @@ from .view import HomeAssistantView
|
|||
|
||||
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
|
||||
_P = ParamSpec("_P")
|
||||
_FuncType = Callable[
|
||||
Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, Response]
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
def require_admin(
|
||||
_func: None = None,
|
||||
*,
|
||||
error: Unauthorized | None = None,
|
||||
) -> Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def require_admin(
|
||||
_func: _FuncType[_HomeAssistantViewT, _P],
|
||||
) -> _FuncType[_HomeAssistantViewT, _P]:
|
||||
...
|
||||
|
||||
|
||||
def require_admin(
|
||||
func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]
|
||||
) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]:
|
||||
_func: _FuncType[_HomeAssistantViewT, _P] | None = None,
|
||||
*,
|
||||
error: Unauthorized | None = None,
|
||||
) -> (
|
||||
Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]
|
||||
| _FuncType[_HomeAssistantViewT, _P]
|
||||
):
|
||||
"""Home Assistant API decorator to require user to be an admin."""
|
||||
|
||||
async def with_admin(
|
||||
self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> Response:
|
||||
"""Check admin and call function."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
def decorator_require_admin(
|
||||
func: _FuncType[_HomeAssistantViewT, _P]
|
||||
) -> _FuncType[_HomeAssistantViewT, _P]:
|
||||
"""Wrap the provided with_admin function."""
|
||||
|
||||
return await func(self, request, *args, **kwargs)
|
||||
@wraps(func)
|
||||
async def with_admin(
|
||||
self: _HomeAssistantViewT,
|
||||
request: Request,
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> Response:
|
||||
"""Check admin and call function."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise error or Unauthorized()
|
||||
|
||||
return with_admin
|
||||
return await func(self, request, *args, **kwargs)
|
||||
|
||||
return with_admin
|
||||
|
||||
# See if we're being called as @require_admin or @require_admin().
|
||||
if _func is None:
|
||||
# We're called with brackets.
|
||||
return decorator_require_admin
|
||||
|
||||
# We're called as @require_admin without brackets.
|
||||
return decorator_require_admin(_func)
|
||||
|
|
|
@ -12,9 +12,9 @@ from aiohttp.web_request import FileField
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.components.http import require_admin
|
||||
from homeassistant.components.media_player import BrowseError, MediaClass
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||
|
||||
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
|
||||
|
@ -254,11 +254,9 @@ class UploadMediaView(http.HomeAssistantView):
|
|||
}
|
||||
)
|
||||
|
||||
@require_admin
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""Handle upload."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
|
||||
# Increase max payload
|
||||
request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from homeassistant import data_entry_flow
|
|||
from homeassistant.auth.permissions.const import POLICY_EDIT
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.components.http.decorators import require_admin
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.helpers.data_entry_flow import (
|
||||
|
@ -88,6 +89,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
|
|||
url = "/api/repairs/issues/fix"
|
||||
name = "api:repairs:issues:fix"
|
||||
|
||||
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
|
||||
@RequestDataValidator(
|
||||
vol.Schema(
|
||||
{
|
||||
|
@ -99,9 +101,6 @@ class RepairsFlowIndexView(FlowManagerIndexView):
|
|||
)
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(permission=POLICY_EDIT)
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
data["handler"],
|
||||
|
@ -125,18 +124,12 @@ class RepairsFlowResourceView(FlowManagerResourceView):
|
|||
url = "/api/repairs/issues/fix/{flow_id}"
|
||||
name = "api:repairs:issues:fix:resource"
|
||||
|
||||
async def get(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
|
||||
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(permission=POLICY_EDIT)
|
||||
|
||||
return await super().get(request, flow_id)
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
|
||||
async def post(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request, flow_id)
|
||||
|
|
|
@ -55,6 +55,7 @@ from zwave_js_server.model.utils import (
|
|||
from zwave_js_server.util.node import async_set_config_parameter
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import require_admin
|
||||
from homeassistant.components.http.view import HomeAssistantView
|
||||
from homeassistant.components.websocket_api import (
|
||||
ERR_INVALID_FORMAT,
|
||||
|
@ -65,7 +66,6 @@ from homeassistant.components.websocket_api import (
|
|||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
|
@ -2149,10 +2149,9 @@ class FirmwareUploadView(HomeAssistantView):
|
|||
super().__init__()
|
||||
self._dev_reg = dev_reg
|
||||
|
||||
@require_admin
|
||||
async def post(self, request: web.Request, device_id: str) -> web.Response:
|
||||
"""Handle upload."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized()
|
||||
hass = request.app["hass"]
|
||||
|
||||
try:
|
||||
|
|
|
@ -90,7 +90,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
|
|||
class FlowManagerResourceView(_BaseFlowManagerView):
|
||||
"""View to interact with the flow manager."""
|
||||
|
||||
async def get(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
try:
|
||||
result = await self._flow_mgr.async_configure(flow_id)
|
||||
|
|
|
@ -825,6 +825,52 @@ async def test_options_flow(hass: HomeAssistant, client) -> None:
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("endpoint", "method"),
|
||||
[
|
||||
("/api/config/config_entries/options/flow", "post"),
|
||||
("/api/config/config_entries/options/flow/1", "get"),
|
||||
("/api/config/config_entries/options/flow/1", "post"),
|
||||
],
|
||||
)
|
||||
async def test_options_flow_unauth(
|
||||
hass: HomeAssistant, client, hass_admin_user: MockUser, endpoint: str, method: str
|
||||
) -> None:
|
||||
"""Test unauthorized on options flow."""
|
||||
|
||||
class TestFlow(core_ce.ConfigFlow):
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
class OptionsFlowHandler(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
schema = OrderedDict()
|
||||
schema[vol.Required("enabled")] = bool
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=schema,
|
||||
description_placeholders={"enabled": "Set to true to be true"},
|
||||
)
|
||||
|
||||
return OptionsFlowHandler()
|
||||
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_entity_platform(hass, "config_flow.test", None)
|
||||
MockConfigEntry(
|
||||
domain="test",
|
||||
entry_id="test1",
|
||||
source="bla",
|
||||
).add_to_hass(hass)
|
||||
entry = hass.config_entries.async_entries()[0]
|
||||
|
||||
hass_admin_user.groups = []
|
||||
|
||||
with patch.dict(HANDLERS, {"test": TestFlow}):
|
||||
resp = await getattr(client, method)(endpoint, json={"handler": entry.entry_id})
|
||||
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
||||
async def test_two_step_options_flow(hass: HomeAssistant, client) -> None:
|
||||
"""Test we can finish a two step options flow."""
|
||||
mock_integration(
|
||||
|
|
Loading…
Add table
Reference in a new issue