Improve http decorator typing (#75541)

This commit is contained in:
Marc Mueller 2022-07-21 13:07:42 +02:00 committed by GitHub
parent 1d7d2875e1
commit b1ed1543c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 22 deletions

View file

@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA)) @RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth @log_invalid_auth
async def post(self, request, flow_id, data): async def post(self, request, data, flow_id):
"""Handle progressing a login flow request.""" """Handle progressing a login flow request."""
client_id = data.pop("client_id") client_id = data.pop("client_id")

View file

@ -2,17 +2,18 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Coroutine
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from ipaddress import IPv4Address, IPv6Address, ip_address from ipaddress import IPv4Address, IPv6Address, ip_address
import logging import logging
from socket import gethostbyaddr, herror from socket import gethostbyaddr, herror
from typing import Any, Final from typing import Any, Final, TypeVar
from aiohttp.web import Application, Request, StreamResponse, middleware from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol import voluptuous as vol
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
from .view import HomeAssistantView from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER: Final = logging.getLogger(__name__) _LOGGER: Final = logging.getLogger(__name__)
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager" KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
@ -82,13 +86,13 @@ async def ban_middleware(
def log_invalid_auth( def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]] func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
) -> Callable[..., Awaitable[StreamResponse]]: ) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
"""Decorate function to handle invalid auth or failed login attempts.""" """Decorate function to handle invalid auth or failed login attempts."""
async def handle_req( async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> StreamResponse: ) -> Response:
"""Try to log failed login attempts if response status >= BAD_REQUEST.""" """Try to log failed login attempts if response status >= BAD_REQUEST."""
resp = await func(view, request, *args, **kwargs) resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTPStatus.BAD_REQUEST: if resp.status >= HTTPStatus.BAD_REQUEST:

View file

@ -1,17 +1,21 @@
"""Decorator for view methods to help with data validation.""" """Decorator for view methods to help with data validation."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Coroutine
from functools import wraps from functools import wraps
from http import HTTPStatus from http import HTTPStatus
import logging import logging
from typing import Any from typing import Any, TypeVar
from aiohttp import web from aiohttp import web
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol import voluptuous as vol
from .view import HomeAssistantView from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -33,33 +37,40 @@ class RequestDataValidator:
self._allow_empty = allow_empty self._allow_empty = allow_empty
def __call__( def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]] self,
) -> Callable: method: Callable[
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
Awaitable[web.Response],
],
) -> Callable[
Concatenate[_HassViewT, web.Request, _P],
Coroutine[Any, Any, web.Response],
]:
"""Decorate a function.""" """Decorate a function."""
@wraps(method) @wraps(method)
async def wrapper( async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
) -> web.StreamResponse: ) -> web.Response:
"""Wrap a request handler with data validation.""" """Wrap a request handler with data validation."""
data = None raw_data = None
try: try:
data = await request.json() raw_data = await request.json()
except ValueError: except ValueError:
if not self._allow_empty or (await request.content.read()) != b"": if not self._allow_empty or (await request.content.read()) != b"":
_LOGGER.error("Invalid JSON received") _LOGGER.error("Invalid JSON received")
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST) return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
data = {} raw_data = {}
try: try:
kwargs["data"] = self._schema(data) data: dict[str, Any] = self._schema(raw_data)
except vol.Invalid as err: except vol.Invalid as err:
_LOGGER.error("Data does not match schema: %s", err) _LOGGER.error("Data does not match schema: %s", err)
return view.json_message( return view.json_message(
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
) )
result = await method(view, request, *args, **kwargs) result = await method(view, request, data, *args, **kwargs)
return result return result
return wrapper return wrapper

View file

@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
result = self._prepare_result_json(result) result = self._prepare_result_json(result)
return self.json(result) # pylint: disable=arguments-differ return self.json(result)
class RepairsFlowResourceView(FlowManagerResourceView): class RepairsFlowResourceView(FlowManagerResourceView):
@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
raise Unauthorized(permission=POLICY_EDIT) raise Unauthorized(permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
return await super().post(request, flow_id) # type: ignore[no-any-return] return await super().post(request, flow_id)

View file

@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
@RequestDataValidator(vol.Schema(dict), allow_empty=True) @RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post( async def post(
self, request: web.Request, flow_id: str, data: dict[str, Any] self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response: ) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
try: try: