Reset failed login attempts counter when login success (#15564)
This commit is contained in:
parent
f2a99e83cd
commit
f1286f8e6b
4 changed files with 91 additions and 6 deletions
|
@ -72,7 +72,11 @@ async def ban_middleware(request, handler):
|
||||||
|
|
||||||
|
|
||||||
async def process_wrong_login(request):
|
async def process_wrong_login(request):
|
||||||
"""Process a wrong login attempt."""
|
"""Process a wrong login attempt.
|
||||||
|
|
||||||
|
Increase failed login attempts counter for remote IP address.
|
||||||
|
Add ip ban entry if failed login attempts exceeds threshold.
|
||||||
|
"""
|
||||||
remote_addr = request[KEY_REAL_IP]
|
remote_addr = request[KEY_REAL_IP]
|
||||||
|
|
||||||
msg = ('Login attempt or request with invalid authentication '
|
msg = ('Login attempt or request with invalid authentication '
|
||||||
|
@ -107,6 +111,27 @@ async def process_wrong_login(request):
|
||||||
'Banning IP address', NOTIFICATION_ID_BAN)
|
'Banning IP address', NOTIFICATION_ID_BAN)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_success_login(request):
|
||||||
|
"""Process a success login attempt.
|
||||||
|
|
||||||
|
Reset failed login attempts counter for remote IP address.
|
||||||
|
No release IP address from banned list function, it can only be done by
|
||||||
|
manual modify ip bans config file.
|
||||||
|
"""
|
||||||
|
remote_addr = request[KEY_REAL_IP]
|
||||||
|
|
||||||
|
# Check if ban middleware is loaded
|
||||||
|
if (KEY_BANNED_IPS not in request.app or
|
||||||
|
request.app[KEY_LOGIN_THRESHOLD] < 1):
|
||||||
|
return
|
||||||
|
|
||||||
|
if remote_addr in request.app[KEY_FAILED_LOGIN_ATTEMPTS] and \
|
||||||
|
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] > 0:
|
||||||
|
_LOGGER.debug('Login success, reset failed login attempts counter'
|
||||||
|
' from %s', remote_addr)
|
||||||
|
request.app[KEY_FAILED_LOGIN_ATTEMPTS].pop(remote_addr)
|
||||||
|
|
||||||
|
|
||||||
class IpBan:
|
class IpBan:
|
||||||
"""Represents banned IP address."""
|
"""Represents banned IP address."""
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from aiohttp import web
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
||||||
|
|
||||||
import homeassistant.remote as rem
|
import homeassistant.remote as rem
|
||||||
|
from homeassistant.components.http.ban import process_success_login
|
||||||
from homeassistant.core import is_callback
|
from homeassistant.core import is_callback
|
||||||
from homeassistant.const import CONTENT_TYPE_JSON
|
from homeassistant.const import CONTENT_TYPE_JSON
|
||||||
|
|
||||||
|
@ -91,8 +92,11 @@ def request_handler_factory(view, handler):
|
||||||
|
|
||||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||||
|
|
||||||
if view.requires_auth and not authenticated:
|
if view.requires_auth:
|
||||||
raise HTTPUnauthorized()
|
if authenticated:
|
||||||
|
await process_success_login(request)
|
||||||
|
else:
|
||||||
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||||
|
|
|
@ -26,7 +26,8 @@ from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.http.auth import validate_password
|
from homeassistant.components.http.auth import validate_password
|
||||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||||
from homeassistant.components.http.ban import process_wrong_login
|
from homeassistant.components.http.ban import process_wrong_login, \
|
||||||
|
process_success_login
|
||||||
|
|
||||||
DOMAIN = 'websocket_api'
|
DOMAIN = 'websocket_api'
|
||||||
|
|
||||||
|
@ -360,6 +361,7 @@ class ActiveConnection:
|
||||||
return wsock
|
return wsock
|
||||||
|
|
||||||
self.debug("Auth OK")
|
self.debug("Auth OK")
|
||||||
|
await process_success_login(request)
|
||||||
await self.wsock.send_json(auth_ok_message())
|
await self.wsock.send_json(auth_ok_message())
|
||||||
|
|
||||||
# ---------- AUTH PHASE OVER ----------
|
# ---------- AUTH PHASE OVER ----------
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
from unittest.mock import patch, mock_open
|
from ipaddress import ip_address
|
||||||
|
from unittest.mock import patch, mock_open, Mock
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
|
from aiohttp.web_middlewares import middleware
|
||||||
|
|
||||||
|
from homeassistant.components.http import KEY_AUTHENTICATED
|
||||||
|
from homeassistant.components.http.view import request_handler_factory
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
from homeassistant.components.http.ban import (
|
from homeassistant.components.http.ban import (
|
||||||
IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS)
|
IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS, KEY_FAILED_LOGIN_ATTEMPTS)
|
||||||
|
|
||||||
from . import mock_real_ip
|
from . import mock_real_ip
|
||||||
|
|
||||||
|
@ -88,3 +92,53 @@ async def test_ip_bans_file_creation(hass, aiohttp_client):
|
||||||
resp = await client.get('/')
|
resp = await client.get('/')
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
assert m.call_count == 1
|
assert m.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_failed_login_attempts_counter(hass, aiohttp_client):
|
||||||
|
"""Testing if failed login attempts counter increased."""
|
||||||
|
app = web.Application()
|
||||||
|
app['hass'] = hass
|
||||||
|
|
||||||
|
async def auth_handler(request):
|
||||||
|
"""Return 200 status code."""
|
||||||
|
return None, 200
|
||||||
|
|
||||||
|
app.router.add_get('/auth_true', request_handler_factory(
|
||||||
|
Mock(requires_auth=True), auth_handler))
|
||||||
|
app.router.add_get('/auth_false', request_handler_factory(
|
||||||
|
Mock(requires_auth=True), auth_handler))
|
||||||
|
app.router.add_get('/', request_handler_factory(
|
||||||
|
Mock(requires_auth=False), auth_handler))
|
||||||
|
|
||||||
|
setup_bans(hass, app, 5)
|
||||||
|
remote_ip = ip_address("200.201.202.204")
|
||||||
|
mock_real_ip(app)("200.201.202.204")
|
||||||
|
|
||||||
|
@middleware
|
||||||
|
async def mock_auth(request, handler):
|
||||||
|
"""Mock auth middleware."""
|
||||||
|
if 'auth_true' in request.path:
|
||||||
|
request[KEY_AUTHENTICATED] = True
|
||||||
|
else:
|
||||||
|
request[KEY_AUTHENTICATED] = False
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
app.middlewares.append(mock_auth)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.get('/auth_false')
|
||||||
|
assert resp.status == 401
|
||||||
|
assert app[KEY_FAILED_LOGIN_ATTEMPTS][remote_ip] == 1
|
||||||
|
|
||||||
|
resp = await client.get('/auth_false')
|
||||||
|
assert resp.status == 401
|
||||||
|
assert app[KEY_FAILED_LOGIN_ATTEMPTS][remote_ip] == 2
|
||||||
|
|
||||||
|
resp = await client.get('/')
|
||||||
|
assert resp.status == 200
|
||||||
|
assert app[KEY_FAILED_LOGIN_ATTEMPTS][remote_ip] == 2
|
||||||
|
|
||||||
|
resp = await client.get('/auth_true')
|
||||||
|
assert resp.status == 200
|
||||||
|
assert remote_ip not in app[KEY_FAILED_LOGIN_ATTEMPTS]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue