Skip processing websocket_api schema if it has no arguments (#115618)
* Skip processing websocket_api schema if has no arguments
About 40% of the websocket commands on first connection have
no arguments. We can skip processing the schema for these cases
* cover
* fixes
* allow extra
* Revert "allow extra"
This reverts commit 85d9ec36b3
.
* match behavior
This commit is contained in:
parent
ea8d4d0dca
commit
588c260dc5
3 changed files with 86 additions and 8 deletions
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable, Hashable
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
@ -65,9 +65,9 @@ class ActiveConnection:
|
|||
self.last_id = 0
|
||||
self.can_coalesce = False
|
||||
self.supported_features: dict[str, float] = {}
|
||||
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
||||
const.DOMAIN
|
||||
]
|
||||
self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = (
|
||||
self.hass.data[const.DOMAIN]
|
||||
)
|
||||
self.binary_handlers: list[BinaryHandler | None] = []
|
||||
current_connection.set(self)
|
||||
|
||||
|
@ -185,6 +185,7 @@ class ActiveConnection:
|
|||
or (
|
||||
not (cur_id := msg.get("id"))
|
||||
or type(cur_id) is not int # noqa: E721
|
||||
or cur_id < 0
|
||||
or not (type_ := msg.get("type"))
|
||||
or type(type_) is not str # noqa: E721
|
||||
)
|
||||
|
@ -220,7 +221,12 @@ class ActiveConnection:
|
|||
handler, schema = handler_schema
|
||||
|
||||
try:
|
||||
handler(self.hass, self, schema(msg))
|
||||
if schema is False:
|
||||
if len(msg) > 2:
|
||||
raise vol.Invalid("extra keys not allowed")
|
||||
handler(self.hass, self, msg)
|
||||
else:
|
||||
handler(self.hass, self, schema(msg))
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self.async_handle_exception(msg, err)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -137,7 +137,7 @@ def websocket_command(
|
|||
The schema must be either a dictionary where the keys are voluptuous markers, or
|
||||
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if is_dict := isinstance(schema, dict):
|
||||
command = schema["type"]
|
||||
else:
|
||||
command = schema.validators[0].schema["type"]
|
||||
|
@ -145,9 +145,13 @@ def websocket_command(
|
|||
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||
"""Decorate ws command function."""
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(schema, dict):
|
||||
if is_dict and len(schema) == 1: # type only empty schema
|
||||
func._ws_schema = False # type: ignore[attr-defined]
|
||||
elif is_dict:
|
||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert not isinstance(schema, dict)
|
||||
extended_schema = vol.All(
|
||||
schema.validators[0].extend(
|
||||
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
"""Test decorators."""
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
@ -31,9 +33,16 @@ async def test_async_response_request_context(
|
|||
def get_request(hass, connection, msg):
|
||||
handle_request(http.current_request.get(), connection, msg)
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{"type": "test-get-request-with-arg", vol.Required("arg"): str}
|
||||
)
|
||||
def get_with_arg_request(hass, connection, msg):
|
||||
handle_request(http.current_request.get(), connection, msg)
|
||||
|
||||
websocket_api.async_register_command(hass, executor_get_request)
|
||||
websocket_api.async_register_command(hass, async_get_request)
|
||||
websocket_api.async_register_command(hass, get_request)
|
||||
websocket_api.async_register_command(hass, get_with_arg_request)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
|
@ -71,6 +80,65 @@ async def test_async_response_request_context(
|
|||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "not_found"
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 8,
|
||||
"type": "test-get-request-with-arg",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 8
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "invalid_format"
|
||||
assert (
|
||||
msg["error"]["message"] == "required key not provided @ data['arg']. Got None"
|
||||
)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 9,
|
||||
"type": "test-get-request-with-arg",
|
||||
"arg": "dog",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 9
|
||||
assert msg["success"]
|
||||
assert msg["result"] == "/api/websocket"
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": -1,
|
||||
"type": "test-get-request-with-arg",
|
||||
"arg": "dog",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == -1
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "invalid_format"
|
||||
assert msg["error"]["message"] == "Message incorrectly formatted."
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 10,
|
||||
"type": "test-get-request",
|
||||
"not_valid": "dog",
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 10
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "invalid_format"
|
||||
assert msg["error"]["message"] == (
|
||||
"extra keys not allowed. "
|
||||
"Got {'id': 10, 'type': 'test-get-request', 'not_valid': 'dog'}"
|
||||
)
|
||||
|
||||
|
||||
async def test_supervisor_only(hass: HomeAssistant, websocket_client) -> None:
|
||||
"""Test that only the Supervisor can make requests."""
|
||||
|
|
Loading…
Add table
Reference in a new issue