Advanced Ip filtering (#4424)
* Added IP Bans configuration * Fixing warnings * Added ban enabled option and unit tests * Fixed py34 tox * http: requested changes fix * Requested changes fix
This commit is contained in:
parent
95b439fbd5
commit
2a7bc0e55c
6 changed files with 225 additions and 18 deletions
|
@ -75,9 +75,12 @@ def setup(hass, yaml_config):
|
|||
api_password=None,
|
||||
ssl_certificate=None,
|
||||
ssl_key=None,
|
||||
cors_origins=[],
|
||||
cors_origins=None,
|
||||
use_x_forwarded_for=False,
|
||||
trusted_networks=[]
|
||||
trusted_networks=None,
|
||||
ip_bans=None,
|
||||
login_threshold=0,
|
||||
is_ban_enabled=False
|
||||
)
|
||||
|
||||
server.register_view(DescriptionXmlView(hass, config))
|
||||
|
|
|
@ -5,32 +5,36 @@ For more details about this component, please refer to the documentation at
|
|||
https://home-assistant.io/components/http/
|
||||
"""
|
||||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import ssl
|
||||
from datetime import datetime
|
||||
from ipaddress import ip_address, ip_network
|
||||
from pathlib import Path
|
||||
|
||||
import hmac
|
||||
import os
|
||||
import re
|
||||
import voluptuous as vol
|
||||
from aiohttp import web, hdrs
|
||||
from aiohttp.file_sender import FileSender
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified)
|
||||
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden)
|
||||
from aiohttp.web_urldispatcher import StaticResource
|
||||
|
||||
from homeassistant.core import is_callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.remote as rem
|
||||
from homeassistant import util
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.const import (
|
||||
SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL,
|
||||
CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.core import is_callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.yaml import dump
|
||||
|
||||
DOMAIN = 'http'
|
||||
REQUIREMENTS = ('aiohttp_cors==0.5.0',)
|
||||
|
@ -44,9 +48,16 @@ CONF_SSL_KEY = 'ssl_key'
|
|||
CONF_CORS_ORIGINS = 'cors_allowed_origins'
|
||||
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
|
||||
CONF_TRUSTED_NETWORKS = 'trusted_networks'
|
||||
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
|
||||
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
|
||||
|
||||
DATA_API_PASSWORD = 'api_password'
|
||||
NOTIFICATION_ID_LOGIN = 'http-login'
|
||||
NOTIFICATION_ID_BAN = 'ip-ban'
|
||||
|
||||
IP_BANS = 'ip_bans.yaml'
|
||||
ATTR_BANNED_AT = "banned_at"
|
||||
|
||||
|
||||
# TLS configuation follows the best-practice guidelines specified here:
|
||||
# https://wiki.mozilla.org/Security/Server_Side_TLS
|
||||
|
@ -85,7 +96,9 @@ CONFIG_SCHEMA = vol.Schema({
|
|||
vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
|
||||
vol.Optional(CONF_TRUSTED_NETWORKS):
|
||||
vol.All(cv.ensure_list, [ip_network])
|
||||
vol.All(cv.ensure_list, [ip_network]),
|
||||
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int,
|
||||
vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean
|
||||
}),
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
@ -131,6 +144,9 @@ def setup(hass, config):
|
|||
trusted_networks = [
|
||||
ip_network(trusted_network)
|
||||
for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])]
|
||||
is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False))
|
||||
login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1))
|
||||
ip_bans = load_ip_bans_config(hass.config.path(IP_BANS))
|
||||
|
||||
server = HomeAssistantWSGI(
|
||||
hass,
|
||||
|
@ -142,7 +158,10 @@ def setup(hass, config):
|
|||
ssl_key=ssl_key,
|
||||
cors_origins=cors_origins,
|
||||
use_x_forwarded_for=use_x_forwarded_for,
|
||||
trusted_networks=trusted_networks
|
||||
trusted_networks=trusted_networks,
|
||||
ip_bans=ip_bans,
|
||||
login_threshold=login_threshold,
|
||||
is_ban_enabled=is_ban_enabled
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -254,7 +273,8 @@ class HomeAssistantWSGI(object):
|
|||
|
||||
def __init__(self, hass, development, api_password, ssl_certificate,
|
||||
ssl_key, server_host, server_port, cors_origins,
|
||||
use_x_forwarded_for, trusted_networks):
|
||||
use_x_forwarded_for, trusted_networks,
|
||||
ip_bans, login_threshold, is_ban_enabled):
|
||||
"""Initialize the WSGI Home Assistant server."""
|
||||
import aiohttp_cors
|
||||
|
||||
|
@ -268,10 +288,15 @@ class HomeAssistantWSGI(object):
|
|||
self.server_host = server_host
|
||||
self.server_port = server_port
|
||||
self.use_x_forwarded_for = use_x_forwarded_for
|
||||
self.trusted_networks = trusted_networks
|
||||
self.trusted_networks = trusted_networks \
|
||||
if trusted_networks is not None else []
|
||||
self.event_forwarder = None
|
||||
self._handler = None
|
||||
self.server = None
|
||||
self.login_threshold = login_threshold
|
||||
self.ip_bans = ip_bans if ip_bans is not None else []
|
||||
self.failed_login_attempts = {}
|
||||
self.is_ban_enabled = is_ban_enabled
|
||||
|
||||
if cors_origins:
|
||||
self.cors = aiohttp_cors.setup(self.app, defaults={
|
||||
|
@ -385,6 +410,39 @@ class HomeAssistantWSGI(object):
|
|||
return any(ip_address(remote_addr) in trusted_network
|
||||
for trusted_network in self.hass.http.trusted_networks)
|
||||
|
||||
def wrong_login_attempt(self, remote_addr):
|
||||
"""Registering wrong login attempt."""
|
||||
if not self.is_ban_enabled or self.login_threshold < 1:
|
||||
return
|
||||
|
||||
if remote_addr in self.failed_login_attempts:
|
||||
self.failed_login_attempts[remote_addr] += 1
|
||||
else:
|
||||
self.failed_login_attempts[remote_addr] = 1
|
||||
|
||||
if self.failed_login_attempts[remote_addr] > self.login_threshold:
|
||||
new_ban = IpBan(remote_addr)
|
||||
self.ip_bans.append(new_ban)
|
||||
update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban)
|
||||
_LOGGER.warning('Banned IP %s for too many login attempts',
|
||||
remote_addr)
|
||||
persistent_notification.async_create(
|
||||
self.hass,
|
||||
'Too many login attempts from {}'.format(remote_addr),
|
||||
'Banning IP address', NOTIFICATION_ID_BAN)
|
||||
|
||||
def is_banned_ip(self, remote_addr):
|
||||
"""Check if IP address is in a ban list."""
|
||||
if not self.is_ban_enabled:
|
||||
return False
|
||||
|
||||
ip_address_ = ip_address(remote_addr)
|
||||
for ip_ban in self.ip_bans:
|
||||
if ip_ban.ip_address == ip_address_:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class HomeAssistantView(object):
|
||||
"""Base view for all views."""
|
||||
|
@ -465,6 +523,9 @@ def request_handler_factory(view, handler):
|
|||
|
||||
remote_addr = view.hass.http.get_real_ip(request)
|
||||
|
||||
if view.hass.http.is_banned_ip(remote_addr):
|
||||
raise HTTPForbidden()
|
||||
|
||||
# Auth code verbose on purpose
|
||||
authenticated = False
|
||||
|
||||
|
@ -484,6 +545,7 @@ def request_handler_factory(view, handler):
|
|||
authenticated = True
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
view.hass.http.wrong_login_attempt(remote_addr)
|
||||
_LOGGER.warning('Login attempt or request with an invalid '
|
||||
'password from %s', remote_addr)
|
||||
persistent_notification.async_create(
|
||||
|
@ -525,3 +587,55 @@ def request_handler_factory(view, handler):
|
|||
return web.Response(body=result, status=status_code)
|
||||
|
||||
return handle
|
||||
|
||||
|
||||
class IpBan(object):
|
||||
"""Represents banned IP address."""
|
||||
|
||||
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
|
||||
"""Initializing Ip Ban object."""
|
||||
self.ip_address = ip_address(ip_ban)
|
||||
self.banned_at = banned_at
|
||||
if self.banned_at is None:
|
||||
self.banned_at = datetime.utcnow()
|
||||
|
||||
|
||||
def load_ip_bans_config(path: str):
|
||||
"""Loading list of banned IPs from config file."""
|
||||
ip_list = []
|
||||
ip_schema = vol.Schema({
|
||||
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
|
||||
})
|
||||
|
||||
try:
|
||||
try:
|
||||
list_ = load_yaml_config_file(path)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error('Unable to load %s: %s', path, str(err))
|
||||
return []
|
||||
|
||||
for ip_ban, ip_info in list_.items():
|
||||
try:
|
||||
ip_info = ip_schema(ip_info)
|
||||
ip_info['ip_ban'] = ip_address(ip_ban)
|
||||
ip_list.append(IpBan(**ip_info))
|
||||
except vol.Invalid:
|
||||
_LOGGER.exception('Failed to load IP ban')
|
||||
continue
|
||||
|
||||
except(HomeAssistantError, FileNotFoundError):
|
||||
# No need to report error, file absence means
|
||||
# that no bans were applied.
|
||||
return []
|
||||
|
||||
return ip_list
|
||||
|
||||
|
||||
def update_ip_bans_config(path: str, ip_ban: IpBan):
|
||||
"""Update config file with new banned IP address."""
|
||||
with open(path, 'a') as out:
|
||||
ip_ = {str(ip_ban.ip_address): {
|
||||
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
}}
|
||||
out.write('\n')
|
||||
out.write(dump(ip_))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Helpers for config validation using voluptuous."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta, datetime as datetime_sys
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
@ -297,6 +297,22 @@ def time(value):
|
|||
return time_val
|
||||
|
||||
|
||||
def datetime(value):
|
||||
"""Validate datetime."""
|
||||
if isinstance(value, datetime_sys):
|
||||
return value
|
||||
|
||||
try:
|
||||
date_val = dt_util.parse_datetime(value)
|
||||
except TypeError:
|
||||
date_val = None
|
||||
|
||||
if date_val is None:
|
||||
raise vol.Invalid('Invalid datetime specified: {}'.format(value))
|
||||
|
||||
return date_val
|
||||
|
||||
|
||||
def time_zone(value):
|
||||
"""Validate timezone."""
|
||||
if dt_util.get_time_zone(value) is not None:
|
||||
|
|
|
@ -124,6 +124,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
|
@ -155,6 +156,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'invalid browser',
|
||||
|
@ -209,6 +211,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
|
@ -253,6 +256,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_3['subscription']
|
||||
|
@ -295,6 +299,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
with patch('homeassistant.components.notify.html5._save_config',
|
||||
return_value=False):
|
||||
|
@ -329,6 +334,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||
'type': 'push',
|
||||
|
@ -384,6 +390,7 @@ class TestHtml5Notify(object):
|
|||
app = web.Application(loop=loop)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||
'type': 'push',
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# pylint: disable=protected-access
|
||||
import logging
|
||||
from ipaddress import ip_network
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
import requests
|
||||
|
||||
|
@ -25,7 +25,7 @@ TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
|||
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
||||
'2001:DB8:ABCD::1']
|
||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||
|
||||
BANNED_IPS = ['200.201.202.203', '100.64.0.1']
|
||||
|
||||
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
|
||||
|
||||
|
@ -63,6 +63,9 @@ def setUpModule():
|
|||
ip_network(trusted_network)
|
||||
for trusted_network in TRUSTED_NETWORKS]
|
||||
|
||||
hass.http.ip_bans = [http.IpBan(banned_ip)
|
||||
for banned_ip in BANNED_IPS]
|
||||
|
||||
hass.start()
|
||||
|
||||
|
||||
|
@ -227,3 +230,56 @@ class TestHttp:
|
|||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||
assert req.headers.get(allow_headers) == \
|
||||
const.HTTP_HEADER_HA_AUTH.upper()
|
||||
|
||||
def test_access_from_banned_ip(self):
|
||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||
hass.http.is_ban_enabled = True
|
||||
for remote_addr in BANNED_IPS:
|
||||
with patch('homeassistant.components.http.'
|
||||
'HomeAssistantWSGI.get_real_ip',
|
||||
return_value=remote_addr):
|
||||
req = requests.get(
|
||||
_url(const.URL_API))
|
||||
assert req.status_code == 403
|
||||
|
||||
def test_access_from_banned_ip_when_ban_is_off(self):
|
||||
"""Test accessing to server from banned IP when feature is off"""
|
||||
hass.http.is_ban_enabled = False
|
||||
for remote_addr in BANNED_IPS:
|
||||
with patch('homeassistant.components.http.'
|
||||
'HomeAssistantWSGI.get_real_ip',
|
||||
return_value=remote_addr):
|
||||
req = requests.get(
|
||||
_url(const.URL_API),
|
||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
assert req.status_code == 200
|
||||
|
||||
def test_ip_bans_file_creation(self):
|
||||
"""Testing if banned IP file created"""
|
||||
hass.http.is_ban_enabled = True
|
||||
hass.http.login_threshold = 1
|
||||
|
||||
m = mock_open()
|
||||
|
||||
def call_server():
|
||||
with patch('homeassistant.components.http.'
|
||||
'HomeAssistantWSGI.get_real_ip',
|
||||
return_value="200.201.202.204"):
|
||||
return requests.get(
|
||||
_url(const.URL_API),
|
||||
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
||||
|
||||
with patch('homeassistant.components.http.open', m, create=True):
|
||||
req = call_server()
|
||||
assert req.status_code == 401
|
||||
assert len(hass.http.ip_bans) == len(BANNED_IPS)
|
||||
assert m.call_count == 0
|
||||
|
||||
req = call_server()
|
||||
assert req.status_code == 401
|
||||
assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1
|
||||
m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a')
|
||||
|
||||
req = call_server()
|
||||
assert req.status_code == 403
|
||||
assert m.call_count == 1
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Test config validators."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta, datetime, date
|
||||
import enum
|
||||
import os
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT
|
||||
|
@ -358,6 +358,17 @@ def test_time_zone():
|
|||
schema('UTC')
|
||||
|
||||
|
||||
def test_datetime():
|
||||
"""Test date time validation."""
|
||||
schema = vol.Schema(cv.datetime)
|
||||
for value in [date.today(), 'Wrong DateTime', '2016-11-23']:
|
||||
with pytest.raises(vol.MultipleInvalid):
|
||||
schema(value)
|
||||
|
||||
schema(datetime.now())
|
||||
schema('2016-11-23T18:59:08')
|
||||
|
||||
|
||||
def test_key_dependency():
|
||||
"""Test key_dependency validator."""
|
||||
schema = vol.Schema(cv.key_dependency('beer', 'soda'))
|
||||
|
|
Loading…
Add table
Reference in a new issue