Improve http decorator typing (#75541)
This commit is contained in:
parent
1d7d2875e1
commit
b1ed1543c8
5 changed files with 37 additions and 22 deletions
|
@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
|
||||||
|
|
||||||
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
|
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
|
||||||
@log_invalid_auth
|
@log_invalid_auth
|
||||||
async def post(self, request, flow_id, data):
|
async def post(self, request, data, flow_id):
|
||||||
"""Handle progressing a login flow request."""
|
"""Handle progressing a login flow request."""
|
||||||
client_id = data.pop("client_id")
|
client_id = data.pop("client_id")
|
||||||
|
|
||||||
|
|
|
@ -2,17 +2,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Coroutine
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||||
import logging
|
import logging
|
||||||
from socket import gethostbyaddr, herror
|
from socket import gethostbyaddr, herror
|
||||||
from typing import Any, Final
|
from typing import Any, Final, TypeVar
|
||||||
|
|
||||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
|
||||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||||
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import persistent_notification
|
from homeassistant.components import persistent_notification
|
||||||
|
@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
|
||||||
|
|
||||||
from .view import HomeAssistantView
|
from .view import HomeAssistantView
|
||||||
|
|
||||||
|
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
|
||||||
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
_LOGGER: Final = logging.getLogger(__name__)
|
_LOGGER: Final = logging.getLogger(__name__)
|
||||||
|
|
||||||
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
|
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
|
||||||
|
@ -82,13 +86,13 @@ async def ban_middleware(
|
||||||
|
|
||||||
|
|
||||||
def log_invalid_auth(
|
def log_invalid_auth(
|
||||||
func: Callable[..., Awaitable[StreamResponse]]
|
func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
|
||||||
) -> Callable[..., Awaitable[StreamResponse]]:
|
) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
|
||||||
"""Decorate function to handle invalid auth or failed login attempts."""
|
"""Decorate function to handle invalid auth or failed login attempts."""
|
||||||
|
|
||||||
async def handle_req(
|
async def handle_req(
|
||||||
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
|
view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
|
||||||
) -> StreamResponse:
|
) -> Response:
|
||||||
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
|
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
|
||||||
resp = await func(view, request, *args, **kwargs)
|
resp = await func(view, request, *args, **kwargs)
|
||||||
if resp.status >= HTTPStatus.BAD_REQUEST:
|
if resp.status >= HTTPStatus.BAD_REQUEST:
|
||||||
|
|
|
@ -1,17 +1,21 @@
|
||||||
"""Decorator for view methods to help with data validation."""
|
"""Decorator for view methods to help with data validation."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Coroutine
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from .view import HomeAssistantView
|
from .view import HomeAssistantView
|
||||||
|
|
||||||
|
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
|
||||||
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,33 +37,40 @@ class RequestDataValidator:
|
||||||
self._allow_empty = allow_empty
|
self._allow_empty = allow_empty
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, method: Callable[..., Awaitable[web.StreamResponse]]
|
self,
|
||||||
) -> Callable:
|
method: Callable[
|
||||||
|
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
|
||||||
|
Awaitable[web.Response],
|
||||||
|
],
|
||||||
|
) -> Callable[
|
||||||
|
Concatenate[_HassViewT, web.Request, _P],
|
||||||
|
Coroutine[Any, Any, web.Response],
|
||||||
|
]:
|
||||||
"""Decorate a function."""
|
"""Decorate a function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
async def wrapper(
|
async def wrapper(
|
||||||
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
|
view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
|
||||||
) -> web.StreamResponse:
|
) -> web.Response:
|
||||||
"""Wrap a request handler with data validation."""
|
"""Wrap a request handler with data validation."""
|
||||||
data = None
|
raw_data = None
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
raw_data = await request.json()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if not self._allow_empty or (await request.content.read()) != b"":
|
if not self._allow_empty or (await request.content.read()) != b"":
|
||||||
_LOGGER.error("Invalid JSON received")
|
_LOGGER.error("Invalid JSON received")
|
||||||
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
|
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
|
||||||
data = {}
|
raw_data = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kwargs["data"] = self._schema(data)
|
data: dict[str, Any] = self._schema(raw_data)
|
||||||
except vol.Invalid as err:
|
except vol.Invalid as err:
|
||||||
_LOGGER.error("Data does not match schema: %s", err)
|
_LOGGER.error("Data does not match schema: %s", err)
|
||||||
return view.json_message(
|
return view.json_message(
|
||||||
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
|
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await method(view, request, *args, **kwargs)
|
result = await method(view, request, data, *args, **kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
|
||||||
|
|
||||||
result = self._prepare_result_json(result)
|
result = self._prepare_result_json(result)
|
||||||
|
|
||||||
return self.json(result) # pylint: disable=arguments-differ
|
return self.json(result)
|
||||||
|
|
||||||
|
|
||||||
class RepairsFlowResourceView(FlowManagerResourceView):
|
class RepairsFlowResourceView(FlowManagerResourceView):
|
||||||
|
@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
|
||||||
raise Unauthorized(permission=POLICY_EDIT)
|
raise Unauthorized(permission=POLICY_EDIT)
|
||||||
|
|
||||||
# pylint: disable=no-value-for-parameter
|
# pylint: disable=no-value-for-parameter
|
||||||
return await super().post(request, flow_id) # type: ignore[no-any-return]
|
return await super().post(request, flow_id)
|
||||||
|
|
|
@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
|
||||||
|
|
||||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||||
async def post(
|
async def post(
|
||||||
self, request: web.Request, flow_id: str, data: dict[str, Any]
|
self, request: web.Request, data: dict[str, Any], flow_id: str
|
||||||
) -> web.Response:
|
) -> web.Response:
|
||||||
"""Handle a POST request."""
|
"""Handle a POST request."""
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue