Fix circular imports in core integrations (#111875)
* Fix circular imports in core integrations * fix circular import * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * adjust * fix * increase timeout * remove unused logger * keep up to date * make sure its reprod
This commit is contained in:
parent
25510fc13c
commit
c1750f7c3a
9 changed files with 242 additions and 192 deletions
184
homeassistant/helpers/http.py
Normal file
184
homeassistant/helpers/http.py
Normal file
|
@ -0,0 +1,184 @@
|
|||
"""Helper to track the current http request."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
from aiohttp.web import Request
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPBadRequest,
|
||||
HTTPInternalServerError,
|
||||
HTTPUnauthorized,
|
||||
)
|
||||
from aiohttp.web_urldispatcher import AbstractRoute
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.const import CONTENT_TYPE_JSON
|
||||
from homeassistant.core import Context, HomeAssistant, is_callback
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
|
||||
|
||||
from .json import find_paths_unserializable_data, json_bytes, json_dumps
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||
|
||||
current_request: ContextVar[Request | None] = ContextVar(
|
||||
"current_request", default=None
|
||||
)
|
||||
|
||||
|
||||
def request_handler_factory(
|
||||
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
|
||||
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
||||
"""Wrap the handler classes."""
|
||||
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
|
||||
assert is_coroutinefunction or is_callback(
|
||||
handler
|
||||
), "Handler should be a coroutine or a callback."
|
||||
|
||||
async def handle(request: web.Request) -> web.StreamResponse:
|
||||
"""Handle incoming request."""
|
||||
if hass.is_stopping:
|
||||
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Serving %s to %s (auth: %s)",
|
||||
request.path,
|
||||
request.remote,
|
||||
authenticated,
|
||||
)
|
||||
|
||||
try:
|
||||
if is_coroutinefunction:
|
||||
result = await handler(request, **request.match_info)
|
||||
else:
|
||||
result = handler(request, **request.match_info)
|
||||
except vol.Invalid as err:
|
||||
raise HTTPBadRequest() from err
|
||||
except exceptions.ServiceNotFound as err:
|
||||
raise HTTPInternalServerError() from err
|
||||
except exceptions.Unauthorized as err:
|
||||
raise HTTPUnauthorized() from err
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
return result
|
||||
|
||||
status_code = HTTPStatus.OK
|
||||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, bytes):
|
||||
return web.Response(body=result, status=status_code)
|
||||
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result, status=status_code)
|
||||
|
||||
if result is None:
|
||||
return web.Response(body=b"", status=status_code)
|
||||
|
||||
raise TypeError(
|
||||
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
|
||||
)
|
||||
|
||||
return handle
|
||||
|
||||
|
||||
class HomeAssistantView:
|
||||
"""Base view for all views."""
|
||||
|
||||
url: str | None = None
|
||||
extra_urls: list[str] = []
|
||||
# Views inheriting from this class can override this
|
||||
requires_auth = True
|
||||
cors_allowed = False
|
||||
|
||||
@staticmethod
|
||||
def context(request: web.Request) -> Context:
|
||||
"""Generate a context from a request."""
|
||||
if (user := request.get("hass_user")) is None:
|
||||
return Context()
|
||||
|
||||
return Context(user_id=user.id)
|
||||
|
||||
@staticmethod
|
||||
def json(
|
||||
result: Any,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON response."""
|
||||
try:
|
||||
msg = json_bytes(result)
|
||||
except JSON_ENCODE_EXCEPTIONS as err:
|
||||
_LOGGER.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(result, dump=json_dumps)
|
||||
),
|
||||
)
|
||||
raise HTTPInternalServerError from err
|
||||
response = web.Response(
|
||||
body=msg,
|
||||
content_type=CONTENT_TYPE_JSON,
|
||||
status=int(status_code),
|
||||
headers=headers,
|
||||
zlib_executor_size=32768,
|
||||
)
|
||||
response.enable_compression()
|
||||
return response
|
||||
|
||||
def json_message(
|
||||
self,
|
||||
message: str,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
message_code: str | None = None,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON message response."""
|
||||
data = {"message": message}
|
||||
if message_code is not None:
|
||||
data["code"] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
def register(
|
||||
self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
|
||||
) -> None:
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, "No url set for view"
|
||||
urls = [self.url] + self.extra_urls
|
||||
routes: list[AbstractRoute] = []
|
||||
|
||||
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
||||
if not (handler := getattr(self, method, None)):
|
||||
continue
|
||||
|
||||
handler = request_handler_factory(hass, self, handler)
|
||||
|
||||
for url in urls:
|
||||
routes.append(router.add_route(method, url, handler))
|
||||
|
||||
# Use `get` because CORS middleware is not be loaded in emulated_hue
|
||||
if self.cors_allowed:
|
||||
allow_cors = app.get("allow_all_cors")
|
||||
else:
|
||||
allow_cors = app.get("allow_configured_cors")
|
||||
|
||||
if allow_cors:
|
||||
for route in routes:
|
||||
allow_cors(route)
|
Loading…
Add table
Add a link
Reference in a new issue