diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index 9c074588a17..d0831f2e90e 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.const import __version__ from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.util.json import JsonValueType from .connection import ActiveConnection from .error import Disconnect @@ -67,10 +68,10 @@ class AuthPhase: self._logger = logger self._request = request - async def async_handle(self, msg: dict[str, str]) -> ActiveConnection: + async def async_handle(self, msg: JsonValueType) -> ActiveConnection: """Handle authentication.""" try: - msg = AUTH_MESSAGE_SCHEMA(msg) + valid_msg = AUTH_MESSAGE_SCHEMA(msg) except vol.Invalid as err: error_msg = ( f"Auth message incorrectly formatted: {humanize_error(msg, err)}" @@ -79,20 +80,19 @@ class AuthPhase: self._send_message(auth_invalid_message(error_msg)) raise Disconnect from err - if "access_token" in msg: - self._logger.debug("Received access_token") - refresh_token = await self._hass.auth.async_validate_access_token( - msg["access_token"] + if (access_token := valid_msg.get("access_token")) and ( + refresh_token := await self._hass.auth.async_validate_access_token( + access_token + ) + ): + conn = await self._async_finish_auth(refresh_token.user, refresh_token) + conn.subscriptions[ + "auth" + ] = self._hass.auth.async_register_revoke_token_callback( + refresh_token.id, self._cancel_ws ) - if refresh_token is not None: - conn = await self._async_finish_auth(refresh_token.user, refresh_token) - conn.subscriptions[ - "auth" - ] = self._hass.auth.async_register_revoke_token_callback( - refresh_token.id, self._cancel_ws - ) - return conn + return conn self._send_message(auth_invalid_message("Invalid access token or password")) await process_wrong_login(self._request) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index a91a5178830..c07661893f7 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -13,6 +13,7 @@ from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http import current_request from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized +from homeassistant.util.json import JsonValueType from . import const, messages from .util import describe_request @@ -144,7 +145,7 @@ class ActiveConnection: self.binary_handlers[index] = None @callback - def async_handle(self, msg: dict[str, Any]) -> None: + def async_handle(self, msg: JsonValueType) -> None: """Handle a single incoming message.""" if ( # Not using isinstance as we don't care about children @@ -157,10 +158,11 @@ class ActiveConnection: or type(type_) is not str # pylint: disable=unidiomatic-typecheck ) ): - self.logger.error("Received invalid command", msg) + self.logger.error("Received invalid command: %s", msg) + id_ = msg.get("id") if isinstance(msg, dict) else 0 self.send_message( messages.error_message( - msg.get("id"), + id_, # type: ignore[arg-type] const.ERR_INVALID_FORMAT, "Message incorrectly formatted.", ) diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 54daf89d8dd..6ac0e10a76c 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -56,8 +56,7 @@ class WebSocketAdapter(logging.LoggerAdapter): def process(self, msg: str, kwargs: Any) -> tuple[str, Any]: """Add connid to websocket log messages.""" - if not self.extra or "connid" not in self.extra: - return msg, kwargs + assert self.extra is not None return f'[{self.extra["connid"]}] {msg}', kwargs @@ -81,7 +80,7 @@ class WebSocketHandler: # to where messages are queued. This allows the implementation # to use a deque and an asyncio.Future to avoid the overhead of # an asyncio.Queue. - self._message_queue: deque = deque() + self._message_queue: deque[str | Callable[[], str] | None] = deque() self._ready_future: asyncio.Future[None] | None = None def __repr__(self) -> str: @@ -302,14 +301,14 @@ class WebSocketHandler: raise Disconnect try: - msg_data = msg.json(loads=json_loads) + auth_msg_data = json_loads(msg.data) except ValueError as err: disconnect_warn = "Received invalid JSON." raise Disconnect from err if is_enabled_for(logging_debug): - debug("%s: Received %s", self.description, msg_data) - connection = await auth.async_handle(msg_data) + debug("%s: Received %s", self.description, auth_msg_data) + connection = await auth.async_handle(auth_msg_data) self._connection = connection hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) @@ -317,7 +316,7 @@ class WebSocketHandler: self._authenticated = True # # - # Our websocket implementation is backed by an asyncio.Queue + # Our websocket implementation is backed by a deque # # As back-pressure builds, the queue will back up and use more memory # until we disconnect the client when the queue size reaches @@ -351,6 +350,8 @@ class WebSocketHandler: # reach the code to set the limit, so we have to set it directly. # wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access + async_handle_str = connection.async_handle + async_handle_binary = connection.async_handle_binary # Command phase while not wsock.closed: @@ -365,7 +366,7 @@ class WebSocketHandler: break handler = msg.data[0] payload = msg.data[1:] - connection.async_handle_binary(handler, payload) + async_handle_binary(handler, payload) continue if msg.type != WSMsgType.TEXT: @@ -373,20 +374,20 @@ class WebSocketHandler: break try: - msg_data = msg.json(loads=json_loads) + command_msg_data = json_loads(msg.data) except ValueError: disconnect_warn = "Received invalid JSON." break if is_enabled_for(logging_debug): - debug("%s: Received %s", self.description, msg_data) + debug("%s: Received %s", self.description, command_msg_data) - if not isinstance(msg_data, list): - connection.async_handle(msg_data) + if not isinstance(command_msg_data, list): + async_handle_str(command_msg_data) continue - for split_msg in msg_data: - connection.async_handle(split_msg) + for split_msg in command_msg_data: + async_handle_str(split_msg) except asyncio.CancelledError: debug("%s: Connection cancelled", self.description) diff --git a/tests/components/websocket_api/test_auth.py b/tests/components/websocket_api/test_auth.py index 070bd68d44a..51bff1af0d7 100644 --- a/tests/components/websocket_api/test_auth.py +++ b/tests/components/websocket_api/test_auth.py @@ -2,6 +2,7 @@ from unittest.mock import patch import aiohttp +from aiohttp import WSMsgType import pytest from homeassistant.auth.providers.legacy_api_password import ( @@ -223,3 +224,79 @@ async def test_auth_close_after_revoke( msg = await websocket_client.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert websocket_client.closed + + +async def test_auth_sending_invalid_json_disconnects( + hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator +) -> None: + """Test sending invalid json during auth.""" + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + async with client.ws_connect(URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg["type"] == TYPE_AUTH_REQUIRED + + await ws.send_str("[--INVALID--JSON--]") + + auth_msg = await ws.receive() + assert auth_msg.type == WSMsgType.close + + +async def test_auth_sending_binary_disconnects( + hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator +) -> None: + """Test sending bytes during auth.""" + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + async with client.ws_connect(URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg["type"] == TYPE_AUTH_REQUIRED + + await ws.send_bytes(b"[INVALID]") + + auth_msg = await ws.receive() + assert auth_msg.type == WSMsgType.close + + +async def test_auth_close_disconnects( + hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator +) -> None: + """Test closing during auth.""" + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + async with client.ws_connect(URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg["type"] == TYPE_AUTH_REQUIRED + + await ws.close() + + auth_msg = await ws.receive() + assert auth_msg.type == WSMsgType.CLOSED + + +async def test_auth_sending_unknown_type_disconnects( + hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator +) -> None: + """Test sending unknown type during auth.""" + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + async with client.ws_connect(URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg["type"] == TYPE_AUTH_REQUIRED + + # pylint: disable-next=protected-access + await ws._writer._send_frame(b"1" * 130, 0x30) + auth_msg = await ws.receive() + assert auth_msg.type == WSMsgType.close diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 3205d40b52d..b94df47213e 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -136,6 +136,89 @@ async def test_cleanup_on_cancellation( assert len(subscriptions) == 0 +async def test_delayed_response_handler( + hass: HomeAssistant, + websocket_client: MockHAClientWebSocket, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test a handler that responds after a connection has already been closed.""" + + subscriptions = None + + # Register a handler that responds after it returns + @callback + @websocket_command( + { + "type": "late_responder", + } + ) + def async_late_responder( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: + msg_id: int = msg["id"] + nonlocal subscriptions + subscriptions = connection.subscriptions + connection.subscriptions[msg_id] = lambda: None + connection.send_result(msg_id) + + async def _async_late_send_message(): + await asyncio.sleep(0.05) + connection.send_event(msg_id, {"event": "any"}) + + hass.async_create_task(_async_late_send_message()) + + async_register_command(hass, async_late_responder) + + await websocket_client.send_json({"id": 1, "type": "ping"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 1 + assert msg["type"] == "pong" + assert not subscriptions + await websocket_client.send_json({"id": 2, "type": "late_responder"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 2 + assert msg["type"] == "result" + assert len(subscriptions) == 2 + assert await websocket_client.close() + await hass.async_block_till_done() + assert len(subscriptions) == 0 + + assert "Tried to send message" in caplog.text + assert "on closed connection" in caplog.text + + +async def test_ensure_disconnect_invalid_json( + hass: HomeAssistant, + websocket_client: MockHAClientWebSocket, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test we get disconnected when sending invalid JSON.""" + + await websocket_client.send_json({"id": 1, "type": "ping"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 1 + assert msg["type"] == "pong" + await websocket_client.send_str("[--INVALID-JSON--]") + msg = await websocket_client.receive() + assert msg.type == WSMsgType.CLOSE + + +async def test_ensure_disconnect_invalid_binary( + hass: HomeAssistant, + websocket_client: MockHAClientWebSocket, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test we get disconnected when sending invalid bytes.""" + + await websocket_client.send_json({"id": 1, "type": "ping"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 1 + assert msg["type"] == "pong" + await websocket_client.send_bytes(b"") + msg = await websocket_client.receive() + assert msg.type == WSMsgType.CLOSE + + async def test_pending_msg_peak( hass: HomeAssistant, mock_low_peak, @@ -299,6 +382,58 @@ async def test_prepare_fail( assert "Timeout preparing request" in caplog.text +async def test_enable_coalesce( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test enabling coalesce.""" + websocket_client = await hass_ws_client(hass) + + await websocket_client.send_json( + { + "id": 1, + "type": "supported_features", + "features": {const.FEATURE_COALESCE_MESSAGES: 1}, + } + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 1 + assert msg["success"] is True + send_tasks: list[asyncio.Future] = [] + ids: set[int] = set() + start_id = 2 + + for idx in range(10): + id_ = idx + start_id + ids.add(id_) + send_tasks.append(websocket_client.send_json({"id": id_, "type": "ping"})) + + await asyncio.gather(*send_tasks) + returned_ids: set[int] = set() + for _ in range(10): + msg = await websocket_client.receive_json() + assert msg["type"] == "pong" + returned_ids.add(msg["id"]) + + assert ids == returned_ids + + # Now close + send_tasks_with_close: list[asyncio.Future] = [] + start_id = 12 + for idx in range(10): + id_ = idx + start_id + send_tasks_with_close.append( + websocket_client.send_json({"id": id_, "type": "ping"}) + ) + + send_tasks_with_close.append(websocket_client.close()) + send_tasks_with_close.append(websocket_client.send_json({"id": 50, "type": "ping"})) + + with pytest.raises(ConnectionResetError): + await asyncio.gather(*send_tasks_with_close) + + async def test_binary_message( hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture ) -> None: