diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index 620bdc7613c..81349fe95a1 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -26,7 +26,7 @@ from .view import HomeAssistantView _LOGGER: Final = logging.getLogger(__name__) -KEY_BANNED_IPS: Final = "ha_banned_ips" +KEY_BAN_MANAGER: Final = "ha_banned_ips_manager" KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts" KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold" @@ -50,9 +50,9 @@ def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> N async def ban_startup(app: Application) -> None: """Initialize bans when app starts up.""" - app[KEY_BANNED_IPS] = await async_load_ip_bans_config( - hass, hass.config.path(IP_BANS_FILE) - ) + ban_manager = IpBanManager(hass) + await ban_manager.async_load() + app[KEY_BAN_MANAGER] = ban_manager app.on_startup.append(ban_startup) @@ -62,18 +62,17 @@ async def ban_middleware( request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] ) -> StreamResponse: """IP Ban middleware.""" - if KEY_BANNED_IPS not in request.app: + ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER) + if ban_manager is None: _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded") return await handler(request) - # Verify if IP is not banned - ip_address_ = ip_address(request.remote) # type: ignore[arg-type] - is_banned = any( - ip_ban.ip_address == ip_address_ for ip_ban in request.app[KEY_BANNED_IPS] - ) - - if is_banned: - raise HTTPForbidden() + ip_bans_lookup = ban_manager.ip_bans_lookup + if ip_bans_lookup: + # Verify if IP is not banned + ip_address_ = ip_address(request.remote) # type: ignore[arg-type] + if ip_address_ in ip_bans_lookup: + raise HTTPForbidden() try: return await handler(request) @@ -129,7 +128,7 @@ async def process_wrong_login(request: Request) -> None: ) # Check if ban middleware is loaded - if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: + if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: return request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 @@ -146,14 +145,9 @@ async def process_wrong_login(request: Request) -> None: request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >= request.app[KEY_LOGIN_THRESHOLD] ): - new_ban = IpBan(remote_addr) - request.app[KEY_BANNED_IPS].append(new_ban) - - await hass.async_add_executor_job( - update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban - ) - + ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER] _LOGGER.warning("Banned IP %s for too many login attempts", remote_addr) + await ban_manager.async_add_ban(remote_addr) persistent_notification.async_create( hass, @@ -173,7 +167,7 @@ async def process_success_login(request: Request) -> None: remote_addr = ip_address(request.remote) # type: ignore[arg-type] # Check if ban middleware is loaded - if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: + if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: return if ( @@ -199,32 +193,49 @@ class IpBan: self.banned_at = banned_at or dt_util.utcnow() -async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> list[IpBan]: - """Load list of banned IPs from config file.""" - ip_list: list[IpBan] = [] +class IpBanManager: + """Manage IP bans.""" - try: - list_ = await hass.async_add_executor_job(load_yaml_config_file, path) - except FileNotFoundError: - return ip_list - except HomeAssistantError as err: - _LOGGER.error("Unable to load %s: %s", path, str(err)) - return ip_list + def __init__(self, hass: HomeAssistant) -> None: + """Init the ban manager.""" + self.hass = hass + self.path = hass.config.path(IP_BANS_FILE) + self.ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {} - for ip_ban, ip_info in list_.items(): + async def async_load(self) -> None: + """Load the existing IP bans.""" try: - ip_info = SCHEMA_IP_BAN_ENTRY(ip_info) - ip_list.append(IpBan(ip_ban, ip_info["banned_at"])) - except vol.Invalid as err: - _LOGGER.error("Failed to load IP ban %s: %s", ip_info, err) - continue + list_ = await self.hass.async_add_executor_job( + load_yaml_config_file, self.path + ) + except FileNotFoundError: + return + except HomeAssistantError as err: + _LOGGER.error("Unable to load %s: %s", self.path, str(err)) + return - return ip_list + ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {} + for ip_ban, ip_info in list_.items(): + try: + ip_info = SCHEMA_IP_BAN_ENTRY(ip_info) + ban = IpBan(ip_ban, ip_info["banned_at"]) + ip_bans_lookup[ban.ip_address] = ban + except vol.Invalid as err: + _LOGGER.error("Failed to load IP ban %s: %s", ip_info, err) + continue + self.ip_bans_lookup = ip_bans_lookup -def update_ip_bans_config(path: str, ip_ban: IpBan) -> None: - """Update config file with new banned IP address.""" - with open(path, "a", encoding="utf8") as out: - ip_ = {str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}} - out.write("\n") - out.write(yaml.dump(ip_)) + def _add_ban(self, ip_ban: IpBan) -> None: + """Update config file with new banned IP address.""" + with open(self.path, "a", encoding="utf8") as out: + ip_ = { + str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()} + } + # Write in a single write call to avoid interleaved writes + out.write("\n" + yaml.dump(ip_)) + + async def async_add_ban(self, remote_addr: IPv4Address | IPv6Address) -> None: + """Add a new IP address to the banned list.""" + new_ban = self.ip_bans_lookup[remote_addr] = IpBan(remote_addr) + await self.hass.async_add_executor_job(self._add_ban, new_ban) diff --git a/tests/components/conftest.py b/tests/components/conftest.py index f153263cbc6..6cad53aea72 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -19,8 +19,7 @@ def patch_zeroconf_multiple_catcher(): def prevent_io(): """Fixture to prevent certain I/O from happening.""" with patch( - "homeassistant.components.http.ban.async_load_ip_bans_config", - return_value=[], + "homeassistant.components.http.ban.load_yaml_config_file", ): yield diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index 5e482d16248..05a6493c9c2 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -15,12 +15,13 @@ import homeassistant.components.http as http from homeassistant.components.http import KEY_AUTHENTICATED from homeassistant.components.http.ban import ( IP_BANS_FILE, - KEY_BANNED_IPS, + KEY_BAN_MANAGER, KEY_FAILED_LOGIN_ATTEMPTS, - IpBan, + IpBanManager, setup_bans, ) from homeassistant.components.http.view import request_handler_factory +from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component from . import mock_real_ip @@ -58,8 +59,10 @@ async def test_access_from_banned_ip(hass, aiohttp_client): set_real_ip = mock_real_ip(app) with patch( - "homeassistant.components.http.ban.async_load_ip_bans_config", - return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], + "homeassistant.components.http.ban.load_yaml_config_file", + return_value={ + banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS + }, ): client = await aiohttp_client(app) @@ -69,6 +72,99 @@ async def test_access_from_banned_ip(hass, aiohttp_client): assert resp.status == HTTPStatus.FORBIDDEN +async def test_access_from_banned_ip_with_partially_broken_yaml_file( + hass, aiohttp_client, caplog +): + """Test accessing to server from banned IP. Both trusted and not. + + We inject some garbage into the yaml file to make sure it can + still load the bans. + """ + app = web.Application() + app["hass"] = hass + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + data = {banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS} + data["5.3.3.3"] = {"banned_at": "garbage"} + + with patch( + "homeassistant.components.http.ban.load_yaml_config_file", + return_value=data, + ): + client = await aiohttp_client(app) + + for remote_addr in BANNED_IPS: + set_real_ip(remote_addr) + resp = await client.get("/") + assert resp.status == HTTPStatus.FORBIDDEN + + # Ensure garbage data is ignored + set_real_ip("5.3.3.3") + resp = await client.get("/") + assert resp.status == HTTPStatus.NOT_FOUND + + assert "Failed to load IP ban" in caplog.text + + +async def test_no_ip_bans_file(hass, aiohttp_client): + """Test no ip bans file.""" + app = web.Application() + app["hass"] = hass + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + with patch( + "homeassistant.components.http.ban.load_yaml_config_file", + side_effect=FileNotFoundError, + ): + client = await aiohttp_client(app) + + set_real_ip("4.3.2.1") + resp = await client.get("/") + assert resp.status == HTTPStatus.NOT_FOUND + + +async def test_failure_loading_ip_bans_file(hass, aiohttp_client): + """Test failure loading ip bans file.""" + app = web.Application() + app["hass"] = hass + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + with patch( + "homeassistant.components.http.ban.load_yaml_config_file", + side_effect=HomeAssistantError, + ): + client = await aiohttp_client(app) + + set_real_ip("4.3.2.1") + resp = await client.get("/") + assert resp.status == HTTPStatus.NOT_FOUND + + +async def test_ip_ban_manager_never_started(hass, aiohttp_client, caplog): + """Test we handle the ip ban manager not being started.""" + app = web.Application() + app["hass"] = hass + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + with patch( + "homeassistant.components.http.ban.load_yaml_config_file", + side_effect=FileNotFoundError, + ): + client = await aiohttp_client(app) + + # Mock the manager never being started + del app[KEY_BAN_MANAGER] + + set_real_ip("4.3.2.1") + resp = await client.get("/") + assert resp.status == HTTPStatus.NOT_FOUND + assert "IP Ban middleware loaded but banned IPs not loaded" in caplog.text + + @pytest.mark.parametrize( "remote_addr, bans, status", list( @@ -95,10 +191,13 @@ async def test_access_from_supervisor_ip( mock_real_ip(app)(remote_addr) with patch( - "homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[] + "homeassistant.components.http.ban.load_yaml_config_file", + return_value={}, ): client = await aiohttp_client(app) + manager: IpBanManager = app[KEY_BAN_MANAGER] + assert await async_setup_component(hass, "hassio", {"hassio": {}}) m_open = mock_open() @@ -108,13 +207,13 @@ async def test_access_from_supervisor_ip( ): resp = await client.get("/") assert resp.status == HTTPStatus.UNAUTHORIZED - assert len(app[KEY_BANNED_IPS]) == bans + assert len(manager.ip_bans_lookup) == bans assert m_open.call_count == bans # second request should be forbidden if banned resp = await client.get("/") assert resp.status == status - assert len(app[KEY_BANNED_IPS]) == bans + assert len(manager.ip_bans_lookup) == bans async def test_ban_middleware_not_loaded_by_config(hass): @@ -149,22 +248,25 @@ async def test_ip_bans_file_creation(hass, aiohttp_client): mock_real_ip(app)("200.201.202.204") with patch( - "homeassistant.components.http.ban.async_load_ip_bans_config", - return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], + "homeassistant.components.http.ban.load_yaml_config_file", + return_value={ + banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS + }, ): client = await aiohttp_client(app) + manager: IpBanManager = app[KEY_BAN_MANAGER] m_open = mock_open() with patch("homeassistant.components.http.ban.open", m_open, create=True): resp = await client.get("/") assert resp.status == HTTPStatus.UNAUTHORIZED - assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + assert len(manager.ip_bans_lookup) == len(BANNED_IPS) assert m_open.call_count == 0 resp = await client.get("/") assert resp.status == HTTPStatus.UNAUTHORIZED - assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 + assert len(manager.ip_bans_lookup) == len(BANNED_IPS) + 1 m_open.assert_called_once_with( hass.config.path(IP_BANS_FILE), "a", encoding="utf8" )