diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 63b4418a19d..3c0743601dd 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable, Hashable from contextvars import ContextVar -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from aiohttp import web import voluptuous as vol @@ -65,9 +65,9 @@ class ActiveConnection: self.last_id = 0 self.can_coalesce = False self.supported_features: dict[str, float] = {} - self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[ - const.DOMAIN - ] + self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = ( + self.hass.data[const.DOMAIN] + ) self.binary_handlers: list[BinaryHandler | None] = [] current_connection.set(self) @@ -185,6 +185,7 @@ class ActiveConnection: or ( not (cur_id := msg.get("id")) or type(cur_id) is not int # noqa: E721 + or cur_id < 0 or not (type_ := msg.get("type")) or type(type_) is not str # noqa: E721 ) @@ -220,7 +221,12 @@ class ActiveConnection: handler, schema = handler_schema try: - handler(self.hass, self, schema(msg)) + if schema is False: + if len(msg) > 2: + raise vol.Invalid("extra keys not allowed") + handler(self.hass, self, msg) + else: + handler(self.hass, self, schema(msg)) except Exception as err: # pylint: disable=broad-except self.async_handle_exception(msg, err) diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index 51643752a0f..0ed8be30139 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any import voluptuous as vol @@ -137,7 +137,7 @@ def websocket_command( The schema must be either a dictionary where the keys are voluptuous markers, or a voluptuous.All schema where the first item is a voluptuous Mapping schema. """ - if isinstance(schema, dict): + if is_dict := isinstance(schema, dict): command = schema["type"] else: command = schema.validators[0].schema["type"] @@ -145,9 +145,13 @@ def websocket_command( def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Decorate ws command function.""" # pylint: disable=protected-access - if isinstance(schema, dict): + if is_dict and len(schema) == 1: # type only empty schema + func._ws_schema = False # type: ignore[attr-defined] + elif is_dict: func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] else: + if TYPE_CHECKING: + assert not isinstance(schema, dict) extended_schema = vol.All( schema.validators[0].extend( messages.BASE_COMMAND_MESSAGE_SCHEMA.schema diff --git a/tests/components/websocket_api/test_decorators.py b/tests/components/websocket_api/test_decorators.py index 3e9c13a8b15..0ade5329190 100644 --- a/tests/components/websocket_api/test_decorators.py +++ b/tests/components/websocket_api/test_decorators.py @@ -1,5 +1,7 @@ """Test decorators.""" +import voluptuous as vol + from homeassistant.components import http, websocket_api from homeassistant.core import HomeAssistant @@ -31,9 +33,16 @@ async def test_async_response_request_context( def get_request(hass, connection, msg): handle_request(http.current_request.get(), connection, msg) + @websocket_api.websocket_command( + {"type": "test-get-request-with-arg", vol.Required("arg"): str} + ) + def get_with_arg_request(hass, connection, msg): + handle_request(http.current_request.get(), connection, msg) + websocket_api.async_register_command(hass, executor_get_request) websocket_api.async_register_command(hass, async_get_request) websocket_api.async_register_command(hass, get_request) + websocket_api.async_register_command(hass, get_with_arg_request) await websocket_client.send_json( { @@ -71,6 +80,65 @@ async def test_async_response_request_context( assert not msg["success"] assert msg["error"]["code"] == "not_found" + await websocket_client.send_json( + { + "id": 8, + "type": "test-get-request-with-arg", + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 8 + assert not msg["success"] + assert msg["error"]["code"] == "invalid_format" + assert ( + msg["error"]["message"] == "required key not provided @ data['arg']. Got None" + ) + + await websocket_client.send_json( + { + "id": 9, + "type": "test-get-request-with-arg", + "arg": "dog", + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 9 + assert msg["success"] + assert msg["result"] == "/api/websocket" + + await websocket_client.send_json( + { + "id": -1, + "type": "test-get-request-with-arg", + "arg": "dog", + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == -1 + assert not msg["success"] + assert msg["error"]["code"] == "invalid_format" + assert msg["error"]["message"] == "Message incorrectly formatted." + + await websocket_client.send_json( + { + "id": 10, + "type": "test-get-request", + "not_valid": "dog", + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 10 + assert not msg["success"] + assert msg["error"]["code"] == "invalid_format" + assert msg["error"]["message"] == ( + "extra keys not allowed. " + "Got {'id': 10, 'type': 'test-get-request', 'not_valid': 'dog'}" + ) + async def test_supervisor_only(hass: HomeAssistant, websocket_client) -> None: """Test that only the Supervisor can make requests."""