Switch linear search to a dict lookup for ip bans (#74482)
This commit is contained in:
parent
ae295f1bf5
commit
0c29b68cf8
3 changed files with 170 additions and 58 deletions
|
@ -26,7 +26,7 @@ from .view import HomeAssistantView
|
|||
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
KEY_BANNED_IPS: Final = "ha_banned_ips"
|
||||
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
|
||||
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
|
||||
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
|
||||
|
||||
|
@ -50,9 +50,9 @@ def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> N
|
|||
|
||||
async def ban_startup(app: Application) -> None:
|
||||
"""Initialize bans when app starts up."""
|
||||
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
|
||||
hass, hass.config.path(IP_BANS_FILE)
|
||||
)
|
||||
ban_manager = IpBanManager(hass)
|
||||
await ban_manager.async_load()
|
||||
app[KEY_BAN_MANAGER] = ban_manager
|
||||
|
||||
app.on_startup.append(ban_startup)
|
||||
|
||||
|
@ -62,18 +62,17 @@ async def ban_middleware(
|
|||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""IP Ban middleware."""
|
||||
if KEY_BANNED_IPS not in request.app:
|
||||
ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER)
|
||||
if ban_manager is None:
|
||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||
return await handler(request)
|
||||
|
||||
# Verify if IP is not banned
|
||||
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
|
||||
is_banned = any(
|
||||
ip_ban.ip_address == ip_address_ for ip_ban in request.app[KEY_BANNED_IPS]
|
||||
)
|
||||
|
||||
if is_banned:
|
||||
raise HTTPForbidden()
|
||||
ip_bans_lookup = ban_manager.ip_bans_lookup
|
||||
if ip_bans_lookup:
|
||||
# Verify if IP is not banned
|
||||
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
|
||||
if ip_address_ in ip_bans_lookup:
|
||||
raise HTTPForbidden()
|
||||
|
||||
try:
|
||||
return await handler(request)
|
||||
|
@ -129,7 +128,7 @@ async def process_wrong_login(request: Request) -> None:
|
|||
)
|
||||
|
||||
# Check if ban middleware is loaded
|
||||
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
|
||||
if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
|
||||
return
|
||||
|
||||
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
|
||||
|
@ -146,14 +145,9 @@ async def process_wrong_login(request: Request) -> None:
|
|||
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
|
||||
>= request.app[KEY_LOGIN_THRESHOLD]
|
||||
):
|
||||
new_ban = IpBan(remote_addr)
|
||||
request.app[KEY_BANNED_IPS].append(new_ban)
|
||||
|
||||
await hass.async_add_executor_job(
|
||||
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban
|
||||
)
|
||||
|
||||
ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER]
|
||||
_LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
|
||||
await ban_manager.async_add_ban(remote_addr)
|
||||
|
||||
persistent_notification.async_create(
|
||||
hass,
|
||||
|
@ -173,7 +167,7 @@ async def process_success_login(request: Request) -> None:
|
|||
remote_addr = ip_address(request.remote) # type: ignore[arg-type]
|
||||
|
||||
# Check if ban middleware is loaded
|
||||
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
|
||||
if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
|
||||
return
|
||||
|
||||
if (
|
||||
|
@ -199,32 +193,49 @@ class IpBan:
|
|||
self.banned_at = banned_at or dt_util.utcnow()
|
||||
|
||||
|
||||
async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> list[IpBan]:
|
||||
"""Load list of banned IPs from config file."""
|
||||
ip_list: list[IpBan] = []
|
||||
class IpBanManager:
|
||||
"""Manage IP bans."""
|
||||
|
||||
try:
|
||||
list_ = await hass.async_add_executor_job(load_yaml_config_file, path)
|
||||
except FileNotFoundError:
|
||||
return ip_list
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Unable to load %s: %s", path, str(err))
|
||||
return ip_list
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Init the ban manager."""
|
||||
self.hass = hass
|
||||
self.path = hass.config.path(IP_BANS_FILE)
|
||||
self.ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
|
||||
|
||||
for ip_ban, ip_info in list_.items():
|
||||
async def async_load(self) -> None:
|
||||
"""Load the existing IP bans."""
|
||||
try:
|
||||
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
|
||||
ip_list.append(IpBan(ip_ban, ip_info["banned_at"]))
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err)
|
||||
continue
|
||||
list_ = await self.hass.async_add_executor_job(
|
||||
load_yaml_config_file, self.path
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Unable to load %s: %s", self.path, str(err))
|
||||
return
|
||||
|
||||
return ip_list
|
||||
ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
|
||||
for ip_ban, ip_info in list_.items():
|
||||
try:
|
||||
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
|
||||
ban = IpBan(ip_ban, ip_info["banned_at"])
|
||||
ip_bans_lookup[ban.ip_address] = ban
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err)
|
||||
continue
|
||||
|
||||
self.ip_bans_lookup = ip_bans_lookup
|
||||
|
||||
def update_ip_bans_config(path: str, ip_ban: IpBan) -> None:
|
||||
"""Update config file with new banned IP address."""
|
||||
with open(path, "a", encoding="utf8") as out:
|
||||
ip_ = {str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}}
|
||||
out.write("\n")
|
||||
out.write(yaml.dump(ip_))
|
||||
def _add_ban(self, ip_ban: IpBan) -> None:
|
||||
"""Update config file with new banned IP address."""
|
||||
with open(self.path, "a", encoding="utf8") as out:
|
||||
ip_ = {
|
||||
str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}
|
||||
}
|
||||
# Write in a single write call to avoid interleaved writes
|
||||
out.write("\n" + yaml.dump(ip_))
|
||||
|
||||
async def async_add_ban(self, remote_addr: IPv4Address | IPv6Address) -> None:
|
||||
"""Add a new IP address to the banned list."""
|
||||
new_ban = self.ip_bans_lookup[remote_addr] = IpBan(remote_addr)
|
||||
await self.hass.async_add_executor_job(self._add_ban, new_ban)
|
||||
|
|
|
@ -19,8 +19,7 @@ def patch_zeroconf_multiple_catcher():
|
|||
def prevent_io():
|
||||
"""Fixture to prevent certain I/O from happening."""
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.async_load_ip_bans_config",
|
||||
return_value=[],
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
):
|
||||
yield
|
||||
|
||||
|
|
|
@ -15,12 +15,13 @@ import homeassistant.components.http as http
|
|||
from homeassistant.components.http import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.ban import (
|
||||
IP_BANS_FILE,
|
||||
KEY_BANNED_IPS,
|
||||
KEY_BAN_MANAGER,
|
||||
KEY_FAILED_LOGIN_ATTEMPTS,
|
||||
IpBan,
|
||||
IpBanManager,
|
||||
setup_bans,
|
||||
)
|
||||
from homeassistant.components.http.view import request_handler_factory
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import mock_real_ip
|
||||
|
@ -58,8 +59,10 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
|
|||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.async_load_ip_bans_config",
|
||||
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS],
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value={
|
||||
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
|
||||
},
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
|
@ -69,6 +72,99 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
|
|||
assert resp.status == HTTPStatus.FORBIDDEN
|
||||
|
||||
|
||||
async def test_access_from_banned_ip_with_partially_broken_yaml_file(
|
||||
hass, aiohttp_client, caplog
|
||||
):
|
||||
"""Test accessing to server from banned IP. Both trusted and not.
|
||||
|
||||
We inject some garbage into the yaml file to make sure it can
|
||||
still load the bans.
|
||||
"""
|
||||
app = web.Application()
|
||||
app["hass"] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
data = {banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS}
|
||||
data["5.3.3.3"] = {"banned_at": "garbage"}
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value=data,
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
for remote_addr in BANNED_IPS:
|
||||
set_real_ip(remote_addr)
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.FORBIDDEN
|
||||
|
||||
# Ensure garbage data is ignored
|
||||
set_real_ip("5.3.3.3")
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
assert "Failed to load IP ban" in caplog.text
|
||||
|
||||
|
||||
async def test_no_ip_bans_file(hass, aiohttp_client):
|
||||
"""Test no ip bans file."""
|
||||
app = web.Application()
|
||||
app["hass"] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
side_effect=FileNotFoundError,
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
set_real_ip("4.3.2.1")
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
async def test_failure_loading_ip_bans_file(hass, aiohttp_client):
|
||||
"""Test failure loading ip bans file."""
|
||||
app = web.Application()
|
||||
app["hass"] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
side_effect=HomeAssistantError,
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
set_real_ip("4.3.2.1")
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
async def test_ip_ban_manager_never_started(hass, aiohttp_client, caplog):
|
||||
"""Test we handle the ip ban manager not being started."""
|
||||
app = web.Application()
|
||||
app["hass"] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
side_effect=FileNotFoundError,
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Mock the manager never being started
|
||||
del app[KEY_BAN_MANAGER]
|
||||
|
||||
set_real_ip("4.3.2.1")
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
assert "IP Ban middleware loaded but banned IPs not loaded" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"remote_addr, bans, status",
|
||||
list(
|
||||
|
@ -95,10 +191,13 @@ async def test_access_from_supervisor_ip(
|
|||
mock_real_ip(app)(remote_addr)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[]
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value={},
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
manager: IpBanManager = app[KEY_BAN_MANAGER]
|
||||
|
||||
assert await async_setup_component(hass, "hassio", {"hassio": {}})
|
||||
|
||||
m_open = mock_open()
|
||||
|
@ -108,13 +207,13 @@ async def test_access_from_supervisor_ip(
|
|||
):
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
assert len(app[KEY_BANNED_IPS]) == bans
|
||||
assert len(manager.ip_bans_lookup) == bans
|
||||
assert m_open.call_count == bans
|
||||
|
||||
# second request should be forbidden if banned
|
||||
resp = await client.get("/")
|
||||
assert resp.status == status
|
||||
assert len(app[KEY_BANNED_IPS]) == bans
|
||||
assert len(manager.ip_bans_lookup) == bans
|
||||
|
||||
|
||||
async def test_ban_middleware_not_loaded_by_config(hass):
|
||||
|
@ -149,22 +248,25 @@ async def test_ip_bans_file_creation(hass, aiohttp_client):
|
|||
mock_real_ip(app)("200.201.202.204")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.async_load_ip_bans_config",
|
||||
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS],
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value={
|
||||
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
|
||||
},
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
manager: IpBanManager = app[KEY_BAN_MANAGER]
|
||||
m_open = mock_open()
|
||||
|
||||
with patch("homeassistant.components.http.ban.open", m_open, create=True):
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||
assert len(manager.ip_bans_lookup) == len(BANNED_IPS)
|
||||
assert m_open.call_count == 0
|
||||
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||
assert len(manager.ip_bans_lookup) == len(BANNED_IPS) + 1
|
||||
m_open.assert_called_once_with(
|
||||
hass.config.path(IP_BANS_FILE), "a", encoding="utf8"
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue