Improve performance of websocket_api dispatch (#88496)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2023-02-20 12:51:34 -06:00 committed by GitHub
parent cc4a179ca8
commit ecf87ae979
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 11 deletions

View file

@ -46,6 +46,7 @@ class ActiveConnection:
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
self.supported_features: dict[str, float] = {} self.supported_features: dict[str, float] = {}
self.handlers = self.hass.data[const.DOMAIN]
current_connection.set(self) current_connection.set(self)
def get_description(self, request: web.Request | None) -> str: def get_description(self, request: web.Request | None) -> str:
@ -72,12 +73,17 @@ class ActiveConnection:
@callback @callback
def async_handle(self, msg: dict[str, Any]) -> None: def async_handle(self, msg: dict[str, Any]) -> None:
"""Handle a single incoming message.""" """Handle a single incoming message."""
handlers = self.hass.data[const.DOMAIN] if (
# Not using isinstance as we don't care about children
try: # as these are always coming from JSON
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg) type(msg) is not dict # pylint: disable=unidiomatic-typecheck
cur_id = msg["id"] or (
except vol.Invalid: not (cur_id := msg.get("id"))
or type(cur_id) is not int # pylint: disable=unidiomatic-typecheck
or not (type_ := msg.get("type"))
or type(type_) is not str # pylint: disable=unidiomatic-typecheck
)
):
self.logger.error("Received invalid command", msg) self.logger.error("Received invalid command", msg)
self.send_message( self.send_message(
messages.error_message( messages.error_message(
@ -96,8 +102,8 @@ class ActiveConnection:
) )
return return
if msg["type"] not in handlers: if not (handler_schema := self.handlers.get(type_)):
self.logger.info("Received unknown command: {}".format(msg["type"])) self.logger.info(f"Received unknown command: {type_}")
self.send_message( self.send_message(
messages.error_message( messages.error_message(
cur_id, const.ERR_UNKNOWN_COMMAND, "Unknown command." cur_id, const.ERR_UNKNOWN_COMMAND, "Unknown command."
@ -105,7 +111,7 @@ class ActiveConnection:
) )
return return
handler, schema = handlers[msg["type"]] handler, schema = handler_schema
try: try:
handler(self.hass, self, schema(msg)) handler(self.hass, self, schema(msg))

View file

@ -10,6 +10,8 @@ import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.const import DOMAIN
from homeassistant.core import HomeAssistant
from tests.common import MockUser from tests.common import MockUser
@ -56,6 +58,7 @@ from tests.common import MockUser
], ],
) )
async def test_exception_handling( async def test_exception_handling(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
exc: Exception, exc: Exception,
code: str, code: str,
@ -67,6 +70,7 @@ async def test_exception_handling(
user = MockUser() user = MockUser()
refresh_token = Mock() refresh_token = Mock()
current_request = AsyncMock() current_request = AsyncMock()
hass.data[DOMAIN] = {}
def get_extra_info(key: str) -> Any: def get_extra_info(key: str) -> Any:
if key == "sslcontext": if key == "sslcontext":
@ -89,7 +93,7 @@ async def test_exception_handling(
) as current_request: ) as current_request:
current_request.get.return_value = mocked_request current_request.get.return_value = mocked_request
conn = websocket_api.ActiveConnection( conn = websocket_api.ActiveConnection(
logging.getLogger(__name__), None, send_messages.append, user, refresh_token logging.getLogger(__name__), hass, send_messages.append, user, refresh_token
) )
conn.async_handle_exception({"id": 5}, exc) conn.async_handle_exception({"id": 5}, exc)

View file

@ -17,7 +17,7 @@ from tests.typing import WebSocketGenerator
@pytest.fixture @pytest.fixture
def mock_low_queue(): def mock_low_queue():
"""Mock a low queue.""" """Mock a low queue."""
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5): with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 1):
yield yield