From 7f966613bdf372706d40c51565b89bc5820b6f7e Mon Sep 17 00:00:00 2001 From: David Boslee Date: Fri, 8 Oct 2021 10:38:22 -0600 Subject: [PATCH] Disconnect websockets after token is revoked (#57091) Co-authored-by: Paulus Schoutsen --- homeassistant/auth/__init__.py | 25 ++++++++++++- .../components/websocket_api/auth.py | 13 +++++-- .../components/websocket_api/http.py | 4 ++- tests/auth/test_init.py | 36 +++++++++++++++++++ tests/components/auth/test_init.py | 10 ++++-- tests/components/websocket_api/test_auth.py | 13 +++++++ 6 files changed, 94 insertions(+), 7 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index c528aff221f..f47228ee506 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Mapping, Optional, Tuple, cast import jwt from homeassistant import data_entry_flow -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.util import dt as dt_util @@ -155,6 +155,7 @@ class AuthManager: self._providers = providers self._mfa_modules = mfa_modules self.login_flow = AuthManagerFlowManager(hass, self) + self._revoke_callbacks: dict[str, list[CALLBACK_TYPE]] = {} @property def auth_providers(self) -> list[AuthProvider]: @@ -446,6 +447,28 @@ class AuthManager: """Delete a refresh token.""" await self._store.async_remove_refresh_token(refresh_token) + callbacks = self._revoke_callbacks.pop(refresh_token.id, []) + for revoke_callback in callbacks: + revoke_callback() + + @callback + def async_register_revoke_token_callback( + self, refresh_token_id: str, revoke_callback: CALLBACK_TYPE + ) -> CALLBACK_TYPE: + """Register a callback to be called when the refresh token id is revoked.""" + if refresh_token_id not in self._revoke_callbacks: + self._revoke_callbacks[refresh_token_id] = [] + + callbacks = self._revoke_callbacks[refresh_token_id] + callbacks.append(revoke_callback) + + @callback + def unregister() -> None: + if revoke_callback in callbacks: + callbacks.remove(revoke_callback) + + return unregister + @callback def async_create_access_token( self, refresh_token: models.RefreshToken, remote_ip: str | None = None diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index 130ffe82840..794dae77153 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -11,7 +11,7 @@ from voluptuous.humanize import humanize_error from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.const import __version__ -from homeassistant.core import HomeAssistant +from homeassistant.core import CALLBACK_TYPE, HomeAssistant from .connection import ActiveConnection from .error import Disconnect @@ -57,11 +57,13 @@ class AuthPhase: logger: WebSocketAdapter, hass: HomeAssistant, send_message: Callable[[str | dict[str, Any]], None], + cancel_ws: CALLBACK_TYPE, request: Request, ) -> None: """Initialize the authentiated connection.""" self._hass = hass self._send_message = send_message + self._cancel_ws = cancel_ws self._logger = logger self._request = request @@ -83,7 +85,14 @@ class AuthPhase: msg["access_token"] ) if refresh_token is not None: - return await self._async_finish_auth(refresh_token.user, refresh_token) + conn = await self._async_finish_auth(refresh_token.user, refresh_token) + conn.subscriptions[ + "auth" + ] = self._hass.auth.async_register_revoke_token_callback( + refresh_token.id, self._cancel_ws + ) + + return conn self._send_message(auth_invalid_message("Invalid access token or password")) await process_wrong_login(self._request) diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index aa6a74b27ec..bce6713403a 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -161,7 +161,9 @@ class WebSocketHandler: # event we do not want to block for websocket responses self._writer_task = asyncio.create_task(self._writer()) - auth = AuthPhase(self._logger, self.hass, self._send_message, request) + auth = AuthPhase( + self._logger, self.hass, self._send_message, self._cancel, request + ) connection = None disconnect_warn = None diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index 4c3d93ede15..fa8c86536ca 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -529,6 +529,42 @@ async def test_remove_refresh_token(mock_hass): assert await manager.async_validate_access_token(access_token) is None +async def test_register_revoke_token_callback(mock_hass): + """Test that a registered revoke token callback is called.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) + + called = False + + def cb(): + nonlocal called + called = True + + manager.async_register_revoke_token_callback(refresh_token.id, cb) + await manager.async_remove_refresh_token(refresh_token) + assert called + + +async def test_unregister_revoke_token_callback(mock_hass): + """Test that a revoke token callback can be unregistered.""" + manager = await auth.auth_manager_from_config(mock_hass, [], []) + user = MockUser().add_to_auth_manager(manager) + refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) + + called = False + + def cb(): + nonlocal called + called = True + + unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb) + unregister() + + await manager.async_remove_refresh_token(refresh_token) + assert not called + + async def test_create_access_token(mock_hass): """Test normal refresh_token's jwt_key keep same after used.""" manager = await auth.auth_manager_from_config(mock_hass, [], []) diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 207667fc26d..b615ba4156c 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -363,11 +363,15 @@ async def test_ws_refresh_tokens(hass, hass_ws_client, hass_access_token): assert token["last_used_ip"] == refresh_token.last_used_ip -async def test_ws_delete_refresh_token(hass, hass_ws_client, hass_access_token): +async def test_ws_delete_refresh_token( + hass, hass_admin_user, hass_admin_credential, hass_ws_client, hass_access_token +): """Test deleting a refresh token.""" assert await async_setup_component(hass, "auth", {"http": {}}) - refresh_token = await hass.auth.async_validate_access_token(hass_access_token) + refresh_token = await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID, credential=hass_admin_credential + ) ws_client = await hass_ws_client(hass, hass_access_token) @@ -382,7 +386,7 @@ async def test_ws_delete_refresh_token(hass, hass_ws_client, hass_access_token): result = await ws_client.receive_json() assert result["success"], result - refresh_token = await hass.auth.async_validate_access_token(hass_access_token) + refresh_token = await hass.auth.async_get_refresh_token(refresh_token.id) assert refresh_token is None diff --git a/tests/components/websocket_api/test_auth.py b/tests/components/websocket_api/test_auth.py index a57faf4a895..7834474470c 100644 --- a/tests/components/websocket_api/test_auth.py +++ b/tests/components/websocket_api/test_auth.py @@ -1,6 +1,7 @@ """Test auth of websocket API.""" from unittest.mock import patch +import aiohttp import pytest from homeassistant.components.websocket_api.auth import ( @@ -191,3 +192,15 @@ async def test_auth_with_invalid_token(hass, hass_client_no_auth): auth_msg = await ws.receive_json() assert auth_msg["type"] == TYPE_AUTH_INVALID + + +async def test_auth_close_after_revoke(hass, websocket_client, hass_access_token): + """Test that a websocket is closed after the refresh token is revoked.""" + assert not websocket_client.closed + + refresh_token = await hass.auth.async_validate_access_token(hass_access_token) + await hass.auth.async_remove_refresh_token(refresh_token) + + msg = await websocket_client.receive() + assert msg.type == aiohttp.WSMsgType.CLOSE + assert websocket_client.closed