Add is_admin checks to cloud APIs (#97804)

This commit is contained in:
Franck Nijhof 2023-08-08 11:02:42 +02:00 committed by GitHub
parent 3859d2e2a6
commit 5e020ea354
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 3 deletions

View file

@ -24,7 +24,7 @@ from homeassistant.components.alexa import (
)
from homeassistant.components.google_assistant import helpers as google_helpers
from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant
@ -128,7 +128,6 @@ def _handle_cloud_errors(
try:
result = await handler(view, request, *args, **kwargs)
return result
except Exception as err: # pylint: disable=broad-except
status, msg = _process_cloud_exception(err, request.path)
return view.json_message(
@ -188,6 +187,7 @@ class GoogleActionsSyncView(HomeAssistantView):
url = "/api/cloud/google_actions/sync"
name = "api:cloud:google_actions/sync"
@require_admin
@_handle_cloud_errors
async def post(self, request: web.Request) -> web.Response:
"""Trigger a Google Actions sync."""
@ -204,6 +204,7 @@ class CloudLoginView(HomeAssistantView):
url = "/api/cloud/login"
name = "api:cloud:login"
@require_admin
@_handle_cloud_errors
@RequestDataValidator(
vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
@ -244,6 +245,7 @@ class CloudLogoutView(HomeAssistantView):
url = "/api/cloud/logout"
name = "api:cloud:logout"
@require_admin
@_handle_cloud_errors
async def post(self, request: web.Request) -> web.Response:
"""Handle logout request."""
@ -262,6 +264,7 @@ class CloudRegisterView(HomeAssistantView):
url = "/api/cloud/register"
name = "api:cloud:register"
@require_admin
@_handle_cloud_errors
@RequestDataValidator(
vol.Schema(
@ -305,6 +308,7 @@ class CloudResendConfirmView(HomeAssistantView):
url = "/api/cloud/resend_confirm"
name = "api:cloud:resend_confirm"
@require_admin
@_handle_cloud_errors
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
@ -324,6 +328,7 @@ class CloudForgotPasswordView(HomeAssistantView):
url = "/api/cloud/forgot_password"
name = "api:cloud:forgot_password"
@require_admin
@_handle_cloud_errors
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:

View file

@ -52,6 +52,7 @@ from .const import ( # noqa: F401
KEY_HASS_USER,
)
from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded
from .headers import setup_headers
from .request_context import current_request, setup_request_context

View file

@ -0,0 +1,31 @@
"""Decorators for the Home Assistant API."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Concatenate, ParamSpec, TypeVar
from aiohttp.web import Request, Response
from homeassistant.exceptions import Unauthorized
from .view import HomeAssistantView
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
def require_admin(
func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]:
"""Home Assistant API decorator to require user to be an admin."""
async def with_admin(
self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Check admin and call function."""
if not request["hass_user"].is_admin:
raise Unauthorized()
return await func(self, request, *args, **kwargs)
return with_admin

View file

@ -1,6 +1,7 @@
"""Tests for the HTTP API for the cloud component."""
import asyncio
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import aiohttp
@ -24,7 +25,7 @@ from . import mock_cloud, mock_cloud_prefs
from tests.components.google_assistant import MockConfig
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import WebSocketGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator
SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info"
@ -1207,3 +1208,28 @@ async def test_tts_info(
assert response["success"]
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
@pytest.mark.parametrize(
("endpoint", "data"),
[
("/api/cloud/forgot_password", {"email": "fake@example.com"}),
("/api/cloud/google_actions/sync", None),
("/api/cloud/login", {"email": "fake@example.com", "password": "secret"}),
("/api/cloud/logout", None),
("/api/cloud/register", {"email": "fake@example.com", "password": "secret"}),
("/api/cloud/resend_confirm", {"email": "fake@example.com"}),
],
)
async def test_api_calls_require_admin(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_read_only_access_token: str,
endpoint: str,
data: dict[str, Any] | None,
) -> None:
"""Test cloud APIs endpoints do not work as a normal user."""
client = await hass_client(hass_read_only_access_token)
resp = await client.post(endpoint, json=data)
assert resp.status == HTTPStatus.UNAUTHORIZED