diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 036cd690da2..8995f075f32 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -77,7 +77,7 @@ def handle_subscribe_events(hass, connection, msg): ): return - connection.send_message(messages.event_message(msg["id"], event)) + connection.send_message(messages.cached_event_message(msg["id"], event)) else: @@ -87,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg): if event.event_type == EVENT_TIME_CHANGED: return - connection.send_message(messages.event_message(msg["id"], event.as_dict())) + connection.send_message(messages.cached_event_message(msg["id"], event)) connection.subscriptions[msg["id"]] = hass.bus.async_listen( event_type, forward_events diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 7c56fcbc606..27dac62791e 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -11,17 +11,11 @@ from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback from homeassistant.helpers.event import async_call_later -from homeassistant.util.json import ( - find_paths_unserializable_data, - format_unserializable_data, -) from .auth import AuthPhase, auth_required_message from .const import ( CANCELLATION_ERRORS, DATA_CONNECTIONS, - ERR_UNKNOWN_ERROR, - JSON_DUMP, MAX_PENDING_MSG, PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, @@ -30,7 +24,7 @@ from .const import ( URL, ) from .error import Disconnect -from .messages import error_message +from .messages import message_to_json # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs @@ -72,27 +66,10 @@ class WebSocketHandler: self._logger.debug("Sending %s", message) - if isinstance(message, str): - await self.wsock.send_str(message) - continue + if not isinstance(message, str): + message = message_to_json(message) - try: - dumped = JSON_DUMP(message) - except (ValueError, TypeError): - await self.wsock.send_json( - error_message( - message["id"], ERR_UNKNOWN_ERROR, "Invalid JSON in response" - ) - ) - self._logger.error( - "Unable to serialize to JSON. Bad data found at %s", - format_unserializable_data( - find_paths_unserializable_data(message, dump=JSON_DUMP) - ), - ) - continue - - await self.wsock.send_str(dumped) + await self.wsock.send_str(message) # Clean up the peaker checker when we shut down the writer if self._peak_checker_unsub: diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 27d557e8110..52e97b60ccf 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -1,11 +1,21 @@ """Message templates for websocket commands.""" +from functools import lru_cache +import logging +from typing import Any, Dict + import voluptuous as vol +from homeassistant.core import Event from homeassistant.helpers import config_validation as cv +from homeassistant.util.json import ( + find_paths_unserializable_data, + format_unserializable_data, +) from . import const +_LOGGER = logging.getLogger(__name__) # mypy: allow-untyped-defs # Minimal requirements of a message @@ -18,12 +28,12 @@ MINIMAL_MESSAGE_SCHEMA = vol.Schema( BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int}) -def result_message(iden, result=None): +def result_message(iden: int, result: Any = None) -> Dict: """Return a success result message.""" return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} -def error_message(iden, code, message): +def error_message(iden: int, code: str, message: str) -> Dict: """Return an error result message.""" return { "id": iden, @@ -33,6 +43,37 @@ def error_message(iden, code, message): } -def event_message(iden, event): +def event_message(iden: int, event: Any) -> Dict: """Return an event message.""" return {"id": iden, "type": "event", "event": event} + + +@lru_cache(maxsize=128) +def cached_event_message(iden: int, event: Event) -> str: + """Return an event message. + + Serialize to json once per message. + + Since we can have many clients connected that are + all getting many of the same events (mostly state changed) + we can avoid serializing the same data for each connection. + """ + return message_to_json(event_message(iden, event)) + + +def message_to_json(message: Any) -> str: + """Serialize a websocket message to json.""" + try: + return const.JSON_DUMP(message) + except (ValueError, TypeError): + _LOGGER.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(message, dump=const.JSON_DUMP) + ), + ) + return const.JSON_DUMP( + error_message( + message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response" + ) + ) diff --git a/homeassistant/core.py b/homeassistant/core.py index f230fef01eb..fd34032112b 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -548,6 +548,11 @@ class Event: self.time_fired = time_fired or dt_util.utcnow() self.context: Context = context or Context() + def __hash__(self) -> int: + """Make hashable.""" + # The only event type that shares context are the TIME_CHANGED + return hash((self.event_type, self.context.id, self.time_fired)) + def as_dict(self) -> Dict: """Create a dict representation of this Event. diff --git a/tests/components/websocket_api/test_messages.py b/tests/components/websocket_api/test_messages.py new file mode 100644 index 00000000000..832b72c5c1c --- /dev/null +++ b/tests/components/websocket_api/test_messages.py @@ -0,0 +1,65 @@ +"""Test Websocket API messages module.""" + +from homeassistant.components.websocket_api.messages import ( + cached_event_message, + message_to_json, +) +from homeassistant.const import EVENT_STATE_CHANGED +from homeassistant.core import callback + + +async def test_cached_event_message(hass): + """Test that we cache event messages.""" + + events = [] + + @callback + def _event_listener(event): + events.append(event) + + hass.bus.async_listen(EVENT_STATE_CHANGED, _event_listener) + + hass.states.async_set("light.window", "on") + hass.states.async_set("light.window", "off") + await hass.async_block_till_done() + + assert len(events) == 2 + + msg0 = cached_event_message(2, events[0]) + assert msg0 == cached_event_message(2, events[0]) + + msg1 = cached_event_message(2, events[1]) + assert msg1 == cached_event_message(2, events[1]) + + assert msg0 != msg1 + + cache_info = cached_event_message.cache_info() + assert cache_info.hits == 2 + assert cache_info.misses == 2 + assert cache_info.currsize == 2 + + cached_event_message(2, events[1]) + cache_info = cached_event_message.cache_info() + assert cache_info.hits == 3 + assert cache_info.misses == 2 + assert cache_info.currsize == 2 + + +async def test_message_to_json(caplog): + """Test we can serialize websocket messages.""" + + json_str = message_to_json({"id": 1, "message": "xyz"}) + + assert json_str == '{"id": 1, "message": "xyz"}' + + json_str2 = message_to_json({"id": 1, "message": _Unserializeable()}) + + assert ( + json_str2 + == '{"id": 1, "type": "result", "success": false, "error": {"code": "unknown_error", "message": "Invalid JSON in response"}}' + ) + assert "Unable to serialize to JSON" in caplog.text + + +class _Unserializeable: + """A class that cannot be serialized."""