From 6b7a4d2d3c8d10cc3cdd27b07139a93932eadc89 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 29 Jul 2021 14:26:05 -0700 Subject: [PATCH] Revert "Allow uploading large snapshots (#53528)" (#53729) This reverts commit cdce14d63db209acedb9888e726b813a069bf720. --- homeassistant/components/hassio/http.py | 54 ++++++++++++++++++------- tests/components/hassio/test_http.py | 49 ++++++---------------- 2 files changed, 53 insertions(+), 50 deletions(-) diff --git a/homeassistant/components/hassio/http.py b/homeassistant/components/hassio/http.py index 73e5549be9a..302cc00bb9f 100644 --- a/homeassistant/components/hassio/http.py +++ b/homeassistant/components/hassio/http.py @@ -1,15 +1,16 @@ """HTTP Support for Hass.io.""" from __future__ import annotations +import asyncio import logging import os import re import aiohttp from aiohttp import web -from aiohttp.client import ClientError, ClientTimeout -from aiohttp.hdrs import CONTENT_TYPE +from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE from aiohttp.web_exceptions import HTTPBadGateway +import async_timeout from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView from homeassistant.components.onboarding import async_is_onboarded @@ -19,6 +20,8 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO _LOGGER = logging.getLogger(__name__) +MAX_UPLOAD_SIZE = 1024 * 1024 * 1024 + NO_TIMEOUT = re.compile( r"^(?:" r"|homeassistant/update" @@ -72,28 +75,48 @@ class HassIOView(HomeAssistantView): async def _command_proxy( self, path: str, request: web.Request - ) -> web.StreamResponse: + ) -> web.Response | web.StreamResponse: """Return a client request with proxy origin for Hass.io supervisor. This method is a coroutine. """ + read_timeout = _get_timeout(path) + client_timeout = 10 + data = None headers = _init_header(request) if path in ("snapshots/new/upload", "backups/new/upload"): # We need to reuse the full content type that includes the boundary headers[ "Content-Type" ] = request._stored_content_type # pylint: disable=protected-access + + # Backups are big, so we need to adjust the allowed size + request._client_max_size = ( # pylint: disable=protected-access + MAX_UPLOAD_SIZE + ) + client_timeout = 300 + try: - # Stream the request to the supervisor - client = await self._websession.request( - method=request.method, - url=f"http://{self._host}/{path}", + with async_timeout.timeout(client_timeout): + data = await request.read() + + method = getattr(self._websession, request.method.lower()) + client = await method( + f"http://{self._host}/{path}", + data=data, headers=headers, - data=request.content, - timeout=_get_timeout(path), + timeout=read_timeout, ) - # Stream the supervisor response back + # Simple request + if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000: + # Return Response + body = await client.read() + return web.Response( + content_type=client.content_type, status=client.status, body=body + ) + + # Stream response response = web.StreamResponse(status=client.status, headers=client.headers) response.content_type = client.content_type @@ -103,9 +126,12 @@ class HassIOView(HomeAssistantView): return response - except ClientError as err: + except aiohttp.ClientError as err: _LOGGER.error("Client error on api %s request %s", path, err) + except asyncio.TimeoutError: + _LOGGER.error("Client timeout error on API request %s", path) + raise HTTPBadGateway() @@ -125,11 +151,11 @@ def _init_header(request: web.Request) -> dict[str, str]: return headers -def _get_timeout(path: str) -> ClientTimeout: +def _get_timeout(path: str) -> int: """Return timeout for a URL path.""" if NO_TIMEOUT.match(path): - return ClientTimeout(connect=10) - return ClientTimeout(connect=10, total=300) + return 0 + return 300 def _need_auth(hass, path: str) -> bool: diff --git a/tests/components/hassio/test_http.py b/tests/components/hassio/test_http.py index 881d3cc26ed..fc4bb3e6a0d 100644 --- a/tests/components/hassio/test_http.py +++ b/tests/components/hassio/test_http.py @@ -1,13 +1,11 @@ """The tests for the hassio component.""" -from aiohttp.client import ClientError -from aiohttp.streams import StreamReader -from aiohttp.test_utils import TestClient +import asyncio +from unittest.mock import patch + import pytest from homeassistant.components.hassio.http import _need_auth -from tests.test_util.aiohttp import AiohttpClientMocker - async def test_forward_request(hassio_client, aioclient_mock): """Test fetching normal path.""" @@ -108,6 +106,16 @@ async def test_forward_log_request(hassio_client, aioclient_mock): assert len(aioclient_mock.mock_calls) == 1 +async def test_bad_gateway_when_cannot_find_supervisor(hassio_client): + """Test we get a bad gateway error if we can't find supervisor.""" + with patch( + "homeassistant.components.hassio.http.async_timeout.timeout", + side_effect=asyncio.TimeoutError, + ): + resp = await hassio_client.get("/api/hassio/addons/test/info") + assert resp.status == 502 + + async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock): """Test that we forward user info correctly.""" aioclient_mock.get("http://127.0.0.1/hello") @@ -163,37 +171,6 @@ async def test_backup_download_headers(hassio_client, aioclient_mock): assert resp.headers["Content-Disposition"] == content_disposition -async def test_supervisor_client_error( - hassio_client: TestClient, aioclient_mock: AiohttpClientMocker -): - """Test any client error from the supervisor returns a 502.""" - # Create a request that throws a ClientError - async def raise_client_error(*args): - raise ClientError() - - aioclient_mock.get( - "http://127.0.0.1/test/raise/error", - side_effect=raise_client_error, - ) - - # Verify it returns bad gateway - resp = await hassio_client.get("/api/hassio/test/raise/error") - assert resp.status == 502 - assert len(aioclient_mock.mock_calls) == 1 - - -async def test_streamed_requests( - hassio_client: TestClient, aioclient_mock: AiohttpClientMocker -): - """Test requests get proxied to the supervisor as a stream.""" - aioclient_mock.get("http://127.0.0.1/test/stream") - await hassio_client.get("/api/hassio/test/stream", data="Test data") - assert len(aioclient_mock.mock_calls) == 1 - - # Verify the request body is passed as a StreamReader - assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader) - - def test_need_auth(hass): """Test if the requested path needs authentication.""" assert not _need_auth(hass, "addons/test/logo")