This reverts commit cdce14d63d
.
This commit is contained in:
parent
390101720d
commit
6b7a4d2d3c
2 changed files with 53 additions and 50 deletions
|
@ -1,15 +1,16 @@
|
|||
"""HTTP Support for Hass.io."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp.client import ClientError, ClientTimeout
|
||||
from aiohttp.hdrs import CONTENT_TYPE
|
||||
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE
|
||||
from aiohttp.web_exceptions import HTTPBadGateway
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
|
||||
from homeassistant.components.onboarding import async_is_onboarded
|
||||
|
@ -19,6 +20,8 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
MAX_UPLOAD_SIZE = 1024 * 1024 * 1024
|
||||
|
||||
NO_TIMEOUT = re.compile(
|
||||
r"^(?:"
|
||||
r"|homeassistant/update"
|
||||
|
@ -72,28 +75,48 @@ class HassIOView(HomeAssistantView):
|
|||
|
||||
async def _command_proxy(
|
||||
self, path: str, request: web.Request
|
||||
) -> web.StreamResponse:
|
||||
) -> web.Response | web.StreamResponse:
|
||||
"""Return a client request with proxy origin for Hass.io supervisor.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
read_timeout = _get_timeout(path)
|
||||
client_timeout = 10
|
||||
data = None
|
||||
headers = _init_header(request)
|
||||
if path in ("snapshots/new/upload", "backups/new/upload"):
|
||||
# We need to reuse the full content type that includes the boundary
|
||||
headers[
|
||||
"Content-Type"
|
||||
] = request._stored_content_type # pylint: disable=protected-access
|
||||
|
||||
# Backups are big, so we need to adjust the allowed size
|
||||
request._client_max_size = ( # pylint: disable=protected-access
|
||||
MAX_UPLOAD_SIZE
|
||||
)
|
||||
client_timeout = 300
|
||||
|
||||
try:
|
||||
# Stream the request to the supervisor
|
||||
client = await self._websession.request(
|
||||
method=request.method,
|
||||
url=f"http://{self._host}/{path}",
|
||||
with async_timeout.timeout(client_timeout):
|
||||
data = await request.read()
|
||||
|
||||
method = getattr(self._websession, request.method.lower())
|
||||
client = await method(
|
||||
f"http://{self._host}/{path}",
|
||||
data=data,
|
||||
headers=headers,
|
||||
data=request.content,
|
||||
timeout=_get_timeout(path),
|
||||
timeout=read_timeout,
|
||||
)
|
||||
|
||||
# Stream the supervisor response back
|
||||
# Simple request
|
||||
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
|
||||
# Return Response
|
||||
body = await client.read()
|
||||
return web.Response(
|
||||
content_type=client.content_type, status=client.status, body=body
|
||||
)
|
||||
|
||||
# Stream response
|
||||
response = web.StreamResponse(status=client.status, headers=client.headers)
|
||||
response.content_type = client.content_type
|
||||
|
||||
|
@ -103,9 +126,12 @@ class HassIOView(HomeAssistantView):
|
|||
|
||||
return response
|
||||
|
||||
except ClientError as err:
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Client error on api %s request %s", path, err)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.error("Client timeout error on API request %s", path)
|
||||
|
||||
raise HTTPBadGateway()
|
||||
|
||||
|
||||
|
@ -125,11 +151,11 @@ def _init_header(request: web.Request) -> dict[str, str]:
|
|||
return headers
|
||||
|
||||
|
||||
def _get_timeout(path: str) -> ClientTimeout:
|
||||
def _get_timeout(path: str) -> int:
|
||||
"""Return timeout for a URL path."""
|
||||
if NO_TIMEOUT.match(path):
|
||||
return ClientTimeout(connect=10)
|
||||
return ClientTimeout(connect=10, total=300)
|
||||
return 0
|
||||
return 300
|
||||
|
||||
|
||||
def _need_auth(hass, path: str) -> bool:
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
"""The tests for the hassio component."""
|
||||
from aiohttp.client import ClientError
|
||||
from aiohttp.streams import StreamReader
|
||||
from aiohttp.test_utils import TestClient
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.hassio.http import _need_auth
|
||||
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
|
||||
|
||||
async def test_forward_request(hassio_client, aioclient_mock):
|
||||
"""Test fetching normal path."""
|
||||
|
@ -108,6 +106,16 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
|
|||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
|
||||
"""Test we get a bad gateway error if we can't find supervisor."""
|
||||
with patch(
|
||||
"homeassistant.components.hassio.http.async_timeout.timeout",
|
||||
side_effect=asyncio.TimeoutError,
|
||||
):
|
||||
resp = await hassio_client.get("/api/hassio/addons/test/info")
|
||||
assert resp.status == 502
|
||||
|
||||
|
||||
async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock):
|
||||
"""Test that we forward user info correctly."""
|
||||
aioclient_mock.get("http://127.0.0.1/hello")
|
||||
|
@ -163,37 +171,6 @@ async def test_backup_download_headers(hassio_client, aioclient_mock):
|
|||
assert resp.headers["Content-Disposition"] == content_disposition
|
||||
|
||||
|
||||
async def test_supervisor_client_error(
|
||||
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
|
||||
):
|
||||
"""Test any client error from the supervisor returns a 502."""
|
||||
# Create a request that throws a ClientError
|
||||
async def raise_client_error(*args):
|
||||
raise ClientError()
|
||||
|
||||
aioclient_mock.get(
|
||||
"http://127.0.0.1/test/raise/error",
|
||||
side_effect=raise_client_error,
|
||||
)
|
||||
|
||||
# Verify it returns bad gateway
|
||||
resp = await hassio_client.get("/api/hassio/test/raise/error")
|
||||
assert resp.status == 502
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_streamed_requests(
|
||||
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
|
||||
):
|
||||
"""Test requests get proxied to the supervisor as a stream."""
|
||||
aioclient_mock.get("http://127.0.0.1/test/stream")
|
||||
await hassio_client.get("/api/hassio/test/stream", data="Test data")
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
# Verify the request body is passed as a StreamReader
|
||||
assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader)
|
||||
|
||||
|
||||
def test_need_auth(hass):
|
||||
"""Test if the requested path needs authentication."""
|
||||
assert not _need_auth(hass, "addons/test/logo")
|
||||
|
|
Loading…
Add table
Reference in a new issue