Switch linear search to a dict lookup for ip bans (#74482)

This commit is contained in:
J. Nick Koston 2022-07-07 03:57:44 -05:00 committed by GitHub
parent ae295f1bf5
commit 0c29b68cf8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 58 deletions

View file

@ -26,7 +26,7 @@ from .view import HomeAssistantView
_LOGGER: Final = logging.getLogger(__name__)
KEY_BANNED_IPS: Final = "ha_banned_ips"
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
@ -50,9 +50,9 @@ def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> N
async def ban_startup(app: Application) -> None:
"""Initialize bans when app starts up."""
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
hass, hass.config.path(IP_BANS_FILE)
)
ban_manager = IpBanManager(hass)
await ban_manager.async_load()
app[KEY_BAN_MANAGER] = ban_manager
app.on_startup.append(ban_startup)
@ -62,18 +62,17 @@ async def ban_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""IP Ban middleware."""
if KEY_BANNED_IPS not in request.app:
ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER)
if ban_manager is None:
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
return await handler(request)
# Verify if IP is not banned
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
is_banned = any(
ip_ban.ip_address == ip_address_ for ip_ban in request.app[KEY_BANNED_IPS]
)
if is_banned:
raise HTTPForbidden()
ip_bans_lookup = ban_manager.ip_bans_lookup
if ip_bans_lookup:
# Verify if IP is not banned
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
if ip_address_ in ip_bans_lookup:
raise HTTPForbidden()
try:
return await handler(request)
@ -129,7 +128,7 @@ async def process_wrong_login(request: Request) -> None:
)
# Check if ban middleware is loaded
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
return
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
@ -146,14 +145,9 @@ async def process_wrong_login(request: Request) -> None:
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
>= request.app[KEY_LOGIN_THRESHOLD]
):
new_ban = IpBan(remote_addr)
request.app[KEY_BANNED_IPS].append(new_ban)
await hass.async_add_executor_job(
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban
)
ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER]
_LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
await ban_manager.async_add_ban(remote_addr)
persistent_notification.async_create(
hass,
@ -173,7 +167,7 @@ async def process_success_login(request: Request) -> None:
remote_addr = ip_address(request.remote) # type: ignore[arg-type]
# Check if ban middleware is loaded
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
return
if (
@ -199,32 +193,49 @@ class IpBan:
self.banned_at = banned_at or dt_util.utcnow()
async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> list[IpBan]:
"""Load list of banned IPs from config file."""
ip_list: list[IpBan] = []
class IpBanManager:
"""Manage IP bans."""
try:
list_ = await hass.async_add_executor_job(load_yaml_config_file, path)
except FileNotFoundError:
return ip_list
except HomeAssistantError as err:
_LOGGER.error("Unable to load %s: %s", path, str(err))
return ip_list
def __init__(self, hass: HomeAssistant) -> None:
"""Init the ban manager."""
self.hass = hass
self.path = hass.config.path(IP_BANS_FILE)
self.ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
for ip_ban, ip_info in list_.items():
async def async_load(self) -> None:
"""Load the existing IP bans."""
try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
ip_list.append(IpBan(ip_ban, ip_info["banned_at"]))
except vol.Invalid as err:
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err)
continue
list_ = await self.hass.async_add_executor_job(
load_yaml_config_file, self.path
)
except FileNotFoundError:
return
except HomeAssistantError as err:
_LOGGER.error("Unable to load %s: %s", self.path, str(err))
return
return ip_list
ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
for ip_ban, ip_info in list_.items():
try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
ban = IpBan(ip_ban, ip_info["banned_at"])
ip_bans_lookup[ban.ip_address] = ban
except vol.Invalid as err:
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err)
continue
self.ip_bans_lookup = ip_bans_lookup
def update_ip_bans_config(path: str, ip_ban: IpBan) -> None:
"""Update config file with new banned IP address."""
with open(path, "a", encoding="utf8") as out:
ip_ = {str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}}
out.write("\n")
out.write(yaml.dump(ip_))
def _add_ban(self, ip_ban: IpBan) -> None:
"""Update config file with new banned IP address."""
with open(self.path, "a", encoding="utf8") as out:
ip_ = {
str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}
}
# Write in a single write call to avoid interleaved writes
out.write("\n" + yaml.dump(ip_))
async def async_add_ban(self, remote_addr: IPv4Address | IPv6Address) -> None:
"""Add a new IP address to the banned list."""
new_ban = self.ip_bans_lookup[remote_addr] = IpBan(remote_addr)
await self.hass.async_add_executor_job(self._add_ban, new_ban)

View file

@ -19,8 +19,7 @@ def patch_zeroconf_multiple_catcher():
def prevent_io():
"""Fixture to prevent certain I/O from happening."""
with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config",
return_value=[],
"homeassistant.components.http.ban.load_yaml_config_file",
):
yield

View file

@ -15,12 +15,13 @@ import homeassistant.components.http as http
from homeassistant.components.http import KEY_AUTHENTICATED
from homeassistant.components.http.ban import (
IP_BANS_FILE,
KEY_BANNED_IPS,
KEY_BAN_MANAGER,
KEY_FAILED_LOGIN_ATTEMPTS,
IpBan,
IpBanManager,
setup_bans,
)
from homeassistant.components.http.view import request_handler_factory
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from . import mock_real_ip
@ -58,8 +59,10 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config",
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS],
"homeassistant.components.http.ban.load_yaml_config_file",
return_value={
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
},
):
client = await aiohttp_client(app)
@ -69,6 +72,99 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
assert resp.status == HTTPStatus.FORBIDDEN
async def test_access_from_banned_ip_with_partially_broken_yaml_file(
hass, aiohttp_client, caplog
):
"""Test accessing to server from banned IP. Both trusted and not.
We inject some garbage into the yaml file to make sure it can
still load the bans.
"""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
data = {banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS}
data["5.3.3.3"] = {"banned_at": "garbage"}
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
return_value=data,
):
client = await aiohttp_client(app)
for remote_addr in BANNED_IPS:
set_real_ip(remote_addr)
resp = await client.get("/")
assert resp.status == HTTPStatus.FORBIDDEN
# Ensure garbage data is ignored
set_real_ip("5.3.3.3")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
assert "Failed to load IP ban" in caplog.text
async def test_no_ip_bans_file(hass, aiohttp_client):
"""Test no ip bans file."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=FileNotFoundError,
):
client = await aiohttp_client(app)
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
async def test_failure_loading_ip_bans_file(hass, aiohttp_client):
"""Test failure loading ip bans file."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=HomeAssistantError,
):
client = await aiohttp_client(app)
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
async def test_ip_ban_manager_never_started(hass, aiohttp_client, caplog):
"""Test we handle the ip ban manager not being started."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=FileNotFoundError,
):
client = await aiohttp_client(app)
# Mock the manager never being started
del app[KEY_BAN_MANAGER]
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
assert "IP Ban middleware loaded but banned IPs not loaded" in caplog.text
@pytest.mark.parametrize(
"remote_addr, bans, status",
list(
@ -95,10 +191,13 @@ async def test_access_from_supervisor_ip(
mock_real_ip(app)(remote_addr)
with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[]
"homeassistant.components.http.ban.load_yaml_config_file",
return_value={},
):
client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER]
assert await async_setup_component(hass, "hassio", {"hassio": {}})
m_open = mock_open()
@ -108,13 +207,13 @@ async def test_access_from_supervisor_ip(
):
resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == bans
assert len(manager.ip_bans_lookup) == bans
assert m_open.call_count == bans
# second request should be forbidden if banned
resp = await client.get("/")
assert resp.status == status
assert len(app[KEY_BANNED_IPS]) == bans
assert len(manager.ip_bans_lookup) == bans
async def test_ban_middleware_not_loaded_by_config(hass):
@ -149,22 +248,25 @@ async def test_ip_bans_file_creation(hass, aiohttp_client):
mock_real_ip(app)("200.201.202.204")
with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config",
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS],
"homeassistant.components.http.ban.load_yaml_config_file",
return_value={
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
},
):
client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER]
m_open = mock_open()
with patch("homeassistant.components.http.ban.open", m_open, create=True):
resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert len(manager.ip_bans_lookup) == len(BANNED_IPS)
assert m_open.call_count == 0
resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
assert len(manager.ip_bans_lookup) == len(BANNED_IPS) + 1
m_open.assert_called_once_with(
hass.config.path(IP_BANS_FILE), "a", encoding="utf8"
)