Disconnect websockets after token is revoked (#57091)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
David Boslee 2021-10-08 10:38:22 -06:00 committed by GitHub
parent 830e2bc47a
commit 7f966613bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 94 additions and 7 deletions

View file

@ -9,7 +9,7 @@ from typing import Any, Dict, Mapping, Optional, Tuple, cast
import jwt
from homeassistant import data_entry_flow
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util import dt as dt_util
@ -155,6 +155,7 @@ class AuthManager:
self._providers = providers
self._mfa_modules = mfa_modules
self.login_flow = AuthManagerFlowManager(hass, self)
self._revoke_callbacks: dict[str, list[CALLBACK_TYPE]] = {}
@property
def auth_providers(self) -> list[AuthProvider]:
@ -446,6 +447,28 @@ class AuthManager:
"""Delete a refresh token."""
await self._store.async_remove_refresh_token(refresh_token)
callbacks = self._revoke_callbacks.pop(refresh_token.id, [])
for revoke_callback in callbacks:
revoke_callback()
@callback
def async_register_revoke_token_callback(
self, refresh_token_id: str, revoke_callback: CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Register a callback to be called when the refresh token id is revoked."""
if refresh_token_id not in self._revoke_callbacks:
self._revoke_callbacks[refresh_token_id] = []
callbacks = self._revoke_callbacks[refresh_token_id]
callbacks.append(revoke_callback)
@callback
def unregister() -> None:
if revoke_callback in callbacks:
callbacks.remove(revoke_callback)
return unregister
@callback
def async_create_access_token(
self, refresh_token: models.RefreshToken, remote_ip: str | None = None

View file

@ -11,7 +11,7 @@ from voluptuous.humanize import humanize_error
from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import HomeAssistant
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from .connection import ActiveConnection
from .error import Disconnect
@ -57,11 +57,13 @@ class AuthPhase:
logger: WebSocketAdapter,
hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None],
cancel_ws: CALLBACK_TYPE,
request: Request,
) -> None:
"""Initialize the authentiated connection."""
self._hass = hass
self._send_message = send_message
self._cancel_ws = cancel_ws
self._logger = logger
self._request = request
@ -83,7 +85,14 @@ class AuthPhase:
msg["access_token"]
)
if refresh_token is not None:
return await self._async_finish_auth(refresh_token.user, refresh_token)
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
conn.subscriptions[
"auth"
] = self._hass.auth.async_register_revoke_token_callback(
refresh_token.id, self._cancel_ws
)
return conn
self._send_message(auth_invalid_message("Invalid access token or password"))
await process_wrong_login(self._request)

View file

@ -161,7 +161,9 @@ class WebSocketHandler:
# event we do not want to block for websocket responses
self._writer_task = asyncio.create_task(self._writer())
auth = AuthPhase(self._logger, self.hass, self._send_message, request)
auth = AuthPhase(
self._logger, self.hass, self._send_message, self._cancel, request
)
connection = None
disconnect_warn = None

View file

@ -529,6 +529,42 @@ async def test_remove_refresh_token(mock_hass):
assert await manager.async_validate_access_token(access_token) is None
async def test_register_revoke_token_callback(mock_hass):
"""Test that a registered revoke token callback is called."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])
user = MockUser().add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
called = False
def cb():
nonlocal called
called = True
manager.async_register_revoke_token_callback(refresh_token.id, cb)
await manager.async_remove_refresh_token(refresh_token)
assert called
async def test_unregister_revoke_token_callback(mock_hass):
"""Test that a revoke token callback can be unregistered."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])
user = MockUser().add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
called = False
def cb():
nonlocal called
called = True
unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb)
unregister()
await manager.async_remove_refresh_token(refresh_token)
assert not called
async def test_create_access_token(mock_hass):
"""Test normal refresh_token's jwt_key keep same after used."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])

View file

@ -363,11 +363,15 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token):
assert token["last_used_ip"] == refresh_token.last_used_ip
async def test_ws_delete_refresh_token(hass, hass_ws_client, hass_access_token):
async def test_ws_delete_refresh_token(
hass, hass_admin_user, hass_admin_credential, hass_ws_client, hass_access_token
):
"""Test deleting a refresh token."""
assert await async_setup_component(hass, "auth", {"http": {}})
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
refresh_token = await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
ws_client = await hass_ws_client(hass, hass_access_token)
@ -382,7 +386,7 @@ async def test_ws_delete_refresh_token(hass, hass_ws_client, hass_access_token):
result = await ws_client.receive_json()
assert result["success"], result
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
refresh_token = await hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token is None

View file

@ -1,6 +1,7 @@
"""Test auth of websocket API."""
from unittest.mock import patch
import aiohttp
import pytest
from homeassistant.components.websocket_api.auth import (
@ -191,3 +192,15 @@ async def test_auth_with_invalid_token(hass, hass_client_no_auth):
auth_msg = await ws.receive_json()
assert auth_msg["type"] == TYPE_AUTH_INVALID
async def test_auth_close_after_revoke(hass, websocket_client, hass_access_token):
"""Test that a websocket is closed after the refresh token is revoked."""
assert not websocket_client.closed
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
await hass.auth.async_remove_refresh_token(refresh_token)
msg = await websocket_client.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE
assert websocket_client.closed