diff --git a/homeassistant/components/hassio/__init__.py b/homeassistant/components/hassio/__init__.py index bcb751faa64..f13db03ca4c 100644 --- a/homeassistant/components/hassio/__init__.py +++ b/homeassistant/components/hassio/__init__.py @@ -143,6 +143,14 @@ def is_hassio(hass): return DOMAIN in hass.config.components +@callback +def get_supervisor_ip(): + """Return the supervisor ip address.""" + if "SUPERVISOR" not in os.environ: + return None + return os.environ["SUPERVISOR"].partition(":")[0] + + async def async_setup(hass, config): """Set up the Hass.io component.""" # Check local setup diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index 38eda8e9b3f..7fffd19c467 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -110,6 +110,12 @@ async def process_wrong_login(request): request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 + # Supervisor IP should never be banned + if "hassio" in hass.config.components and hass.components.hassio.get_supervisor_ip() == str( + remote_addr + ): + return + if ( request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >= request.app[KEY_LOGIN_THRESHOLD] diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 528e804a01a..1670e0c2485 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -1,16 +1,12 @@ """Fixtures for component testing.""" -from unittest.mock import patch - +from asynctest import patch import pytest -from tests.common import mock_coro - @pytest.fixture(autouse=True) def prevent_io(): """Fixture to prevent certain I/O from happening.""" with patch( - "homeassistant.components.http.ban.async_load_ip_bans_config", - side_effect=lambda *args: mock_coro([]), + "homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[], ): yield diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index 8d9d19b6a12..28be5cc45c3 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -1,11 +1,14 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access from ipaddress import ip_address -from unittest.mock import Mock, mock_open, patch +import os +from unittest.mock import Mock, mock_open from aiohttp import web from aiohttp.web_exceptions import HTTPUnauthorized from aiohttp.web_middlewares import middleware +from asynctest import patch +import pytest import homeassistant.components.http as http from homeassistant.components.http import KEY_AUTHENTICATED @@ -21,20 +24,31 @@ from homeassistant.setup import async_setup_component from . import mock_real_ip -from tests.common import mock_coro - +SUPERVISOR_IP = "1.2.3.4" BANNED_IPS = ["200.201.202.203", "100.64.0.2"] +BANNED_IPS_WITH_SUPERVISOR = BANNED_IPS + [SUPERVISOR_IP] + + +@pytest.fixture(name="hassio_env") +def hassio_env_fixture(): + """Fixture to inject hassio env.""" + with patch.dict(os.environ, {"HASSIO": "127.0.0.1"}), patch( + "homeassistant.components.hassio.HassIO.is_connected", + return_value={"result": "ok", "data": {}}, + ), patch.dict(os.environ, {"HASSIO_TOKEN": "123456"}): + yield async def test_access_from_banned_ip(hass, aiohttp_client): """Test accessing to server from banned IP. Both trusted and not.""" app = web.Application() + app["hass"] = hass setup_bans(hass, app, 5) set_real_ip = mock_real_ip(app) with patch( "homeassistant.components.http.ban.async_load_ip_bans_config", - return_value=mock_coro([IpBan(banned_ip) for banned_ip in BANNED_IPS]), + return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], ): client = await aiohttp_client(app) @@ -44,6 +58,48 @@ async def test_access_from_banned_ip(hass, aiohttp_client): assert resp.status == 403 +@pytest.mark.parametrize( + "remote_addr, bans, status", + list(zip(BANNED_IPS_WITH_SUPERVISOR, [1, 1, 0], [403, 403, 401])), +) +async def test_access_from_supervisor_ip( + remote_addr, bans, status, hass, aiohttp_client, hassio_env +): + """Test accessing to server from supervisor IP.""" + app = web.Application() + app["hass"] = hass + + async def unauth_handler(request): + """Return a mock web response.""" + raise HTTPUnauthorized + + app.router.add_get("/", unauth_handler) + setup_bans(hass, app, 1) + mock_real_ip(app)(remote_addr) + + with patch( + "homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[], + ): + client = await aiohttp_client(app) + + assert await async_setup_component(hass, "hassio", {"hassio": {}}) + + m_open = mock_open() + + with patch.dict(os.environ, {"SUPERVISOR": SUPERVISOR_IP}), patch( + "homeassistant.components.http.ban.open", m_open, create=True + ): + resp = await client.get("/") + assert resp.status == 401 + assert len(app[KEY_BANNED_IPS]) == bans + assert m_open.call_count == bans + + # second request should be forbidden if banned + resp = await client.get("/") + assert resp.status == status + assert len(app[KEY_BANNED_IPS]) == bans + + async def test_ban_middleware_not_loaded_by_config(hass): """Test accessing to server from banned IP when feature is off.""" with patch("homeassistant.components.http.setup_bans") as mock_setup: @@ -77,26 +133,26 @@ async def test_ip_bans_file_creation(hass, aiohttp_client): with patch( "homeassistant.components.http.ban.async_load_ip_bans_config", - return_value=mock_coro([IpBan(banned_ip) for banned_ip in BANNED_IPS]), + return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], ): client = await aiohttp_client(app) - m = mock_open() + m_open = mock_open() - with patch("homeassistant.components.http.ban.open", m, create=True): + with patch("homeassistant.components.http.ban.open", m_open, create=True): resp = await client.get("/") assert resp.status == 401 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) - assert m.call_count == 0 + assert m_open.call_count == 0 resp = await client.get("/") assert resp.status == 401 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 - m.assert_called_once_with(hass.config.path(IP_BANS_FILE), "a") + m_open.assert_called_once_with(hass.config.path(IP_BANS_FILE), "a") resp = await client.get("/") assert resp.status == 403 - assert m.call_count == 1 + assert m_open.call_count == 1 async def test_failed_login_attempts_counter(hass, aiohttp_client):