Skip processing websocket_api schema if it has no arguments (#115618)

* Skip processing websocket_api schema if has no arguments

About 40% of the websocket commands on first connection have
no arguments. We can skip processing the schema for these cases

* cover

* fixes

* allow extra

* Revert "allow extra"

This reverts commit 85d9ec36b3.

* match behavior
This commit is contained in:
J. Nick Koston 2024-04-18 09:41:08 -05:00 committed by GitHub
parent ea8d4d0dca
commit 588c260dc5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 86 additions and 8 deletions

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Callable, Hashable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from aiohttp import web
import voluptuous as vol
@ -65,9 +65,9 @@ class ActiveConnection:
self.last_id = 0
self.can_coalesce = False
self.supported_features: dict[str, float] = {}
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
const.DOMAIN
]
self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = (
self.hass.data[const.DOMAIN]
)
self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self)
@ -185,6 +185,7 @@ class ActiveConnection:
or (
not (cur_id := msg.get("id"))
or type(cur_id) is not int # noqa: E721
or cur_id < 0
or not (type_ := msg.get("type"))
or type(type_) is not str # noqa: E721
)
@ -220,7 +221,12 @@ class ActiveConnection:
handler, schema = handler_schema
try:
handler(self.hass, self, schema(msg))
if schema is False:
if len(msg) > 2:
raise vol.Invalid("extra keys not allowed")
handler(self.hass, self, msg)
else:
handler(self.hass, self, schema(msg))
except Exception as err: # pylint: disable=broad-except
self.async_handle_exception(msg, err)

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import TYPE_CHECKING, Any
import voluptuous as vol
@ -137,7 +137,7 @@ def websocket_command(
The schema must be either a dictionary where the keys are voluptuous markers, or
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
"""
if isinstance(schema, dict):
if is_dict := isinstance(schema, dict):
command = schema["type"]
else:
command = schema.validators[0].schema["type"]
@ -145,9 +145,13 @@ def websocket_command(
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate ws command function."""
# pylint: disable=protected-access
if isinstance(schema, dict):
if is_dict and len(schema) == 1: # type only empty schema
func._ws_schema = False # type: ignore[attr-defined]
elif is_dict:
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
else:
if TYPE_CHECKING:
assert not isinstance(schema, dict)
extended_schema = vol.All(
schema.validators[0].extend(
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema

View file

@ -1,5 +1,7 @@
"""Test decorators."""
import voluptuous as vol
from homeassistant.components import http, websocket_api
from homeassistant.core import HomeAssistant
@ -31,9 +33,16 @@ async def test_async_response_request_context(
def get_request(hass, connection, msg):
handle_request(http.current_request.get(), connection, msg)
@websocket_api.websocket_command(
{"type": "test-get-request-with-arg", vol.Required("arg"): str}
)
def get_with_arg_request(hass, connection, msg):
handle_request(http.current_request.get(), connection, msg)
websocket_api.async_register_command(hass, executor_get_request)
websocket_api.async_register_command(hass, async_get_request)
websocket_api.async_register_command(hass, get_request)
websocket_api.async_register_command(hass, get_with_arg_request)
await websocket_client.send_json(
{
@ -71,6 +80,65 @@ async def test_async_response_request_context(
assert not msg["success"]
assert msg["error"]["code"] == "not_found"
await websocket_client.send_json(
{
"id": 8,
"type": "test-get-request-with-arg",
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 8
assert not msg["success"]
assert msg["error"]["code"] == "invalid_format"
assert (
msg["error"]["message"] == "required key not provided @ data['arg']. Got None"
)
await websocket_client.send_json(
{
"id": 9,
"type": "test-get-request-with-arg",
"arg": "dog",
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 9
assert msg["success"]
assert msg["result"] == "/api/websocket"
await websocket_client.send_json(
{
"id": -1,
"type": "test-get-request-with-arg",
"arg": "dog",
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == -1
assert not msg["success"]
assert msg["error"]["code"] == "invalid_format"
assert msg["error"]["message"] == "Message incorrectly formatted."
await websocket_client.send_json(
{
"id": 10,
"type": "test-get-request",
"not_valid": "dog",
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 10
assert not msg["success"]
assert msg["error"]["code"] == "invalid_format"
assert msg["error"]["message"] == (
"extra keys not allowed. "
"Got {'id': 10, 'type': 'test-get-request', 'not_valid': 'dog'}"
)
async def test_supervisor_only(hass: HomeAssistant, websocket_client) -> None:
"""Test that only the Supervisor can make requests."""