Implement websocket message coalescing (#77238)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
2161b6f049
commit
f6a03625ba
6 changed files with 361 additions and 22 deletions
|
@ -74,6 +74,7 @@ def async_register_commands(
|
|||
async_reg(hass, handle_validate_config)
|
||||
async_reg(hass, handle_subscribe_entities)
|
||||
async_reg(hass, handle_supported_brands)
|
||||
async_reg(hass, handle_supported_features)
|
||||
|
||||
|
||||
def pong_message(iden: int) -> dict[str, Any]:
|
||||
|
@ -723,3 +724,18 @@ async def handle_supported_brands(
|
|||
raise int_or_exc
|
||||
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
|
||||
connection.send_result(msg["id"], data)
|
||||
|
||||
|
||||
@callback
|
||||
@decorators.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "supported_features",
|
||||
vol.Required("features"): {str: int},
|
||||
}
|
||||
)
|
||||
def handle_supported_features(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle setting supported features."""
|
||||
connection.supported_features = msg["features"]
|
||||
connection.send_result(msg["id"])
|
||||
|
|
|
@ -42,6 +42,7 @@ class ActiveConnection:
|
|||
self.refresh_token_id = refresh_token.id
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.supported_features: dict[str, float] = {}
|
||||
current_connection.set(self)
|
||||
|
||||
def context(self, msg: dict[str, Any]) -> Context:
|
||||
|
|
|
@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
|
|||
COMPRESSED_STATE_CONTEXT = "c"
|
||||
COMPRESSED_STATE_LAST_CHANGED = "lc"
|
||||
COMPRESSED_STATE_LAST_UPDATED = "lu"
|
||||
|
||||
FEATURE_COALESCE_MESSAGES = "coalesce_messages"
|
||||
|
|
|
@ -6,7 +6,7 @@ from collections.abc import Callable
|
|||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
import async_timeout
|
||||
|
@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.json import json_loads
|
||||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .const import (
|
||||
CANCELLATION_ERRORS,
|
||||
DATA_CONNECTIONS,
|
||||
FEATURE_COALESCE_MESSAGES,
|
||||
MAX_PENDING_MSG,
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
|
@ -31,6 +33,10 @@ from .const import (
|
|||
from .error import Disconnect
|
||||
from .messages import message_to_json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import ActiveConnection
|
||||
|
||||
|
||||
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||
|
||||
|
||||
|
@ -67,26 +73,47 @@ class WebSocketHandler:
|
|||
self._writer_task: asyncio.Task | None = None
|
||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
self.connection: ActiveConnection | None = None
|
||||
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
if (process := await self._to_write.get()) is None:
|
||||
break
|
||||
to_write = self._to_write
|
||||
logger = self._logger
|
||||
wsock = self.wsock
|
||||
try:
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
if (process := await to_write.get()) is None:
|
||||
return
|
||||
message = process if isinstance(process, str) else process()
|
||||
|
||||
if not isinstance(process, str):
|
||||
message: str = process()
|
||||
else:
|
||||
message = process
|
||||
self._logger.debug("Sending %s", message)
|
||||
await self.wsock.send_str(message)
|
||||
if (
|
||||
to_write.empty()
|
||||
or not self.connection
|
||||
or FEATURE_COALESCE_MESSAGES
|
||||
not in self.connection.supported_features
|
||||
):
|
||||
logger.debug("Sending %s", message)
|
||||
await wsock.send_str(message)
|
||||
continue
|
||||
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
messages: list[str] = [message]
|
||||
while not to_write.empty():
|
||||
if (process := to_write.get_nowait()) is None:
|
||||
return
|
||||
messages.append(
|
||||
process if isinstance(process, str) else process()
|
||||
)
|
||||
|
||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||
self._logger.debug("Sending %s", coalesced_messages)
|
||||
await self.wsock.send_str(coalesced_messages)
|
||||
finally:
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
@callback
|
||||
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
|
||||
|
@ -194,13 +221,13 @@ class WebSocketHandler:
|
|||
raise Disconnect
|
||||
|
||||
try:
|
||||
msg_data = msg.json()
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
except ValueError as err:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
raise Disconnect from err
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
connection = await auth.async_handle(msg_data)
|
||||
self.connection = connection = await auth.async_handle(msg_data)
|
||||
self.hass.data[DATA_CONNECTIONS] = (
|
||||
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
)
|
||||
|
@ -218,13 +245,18 @@ class WebSocketHandler:
|
|||
break
|
||||
|
||||
try:
|
||||
msg_data = msg.json()
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
except ValueError:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
break
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
connection.async_handle(msg_data)
|
||||
if not isinstance(msg_data, list):
|
||||
connection.async_handle(msg_data)
|
||||
continue
|
||||
|
||||
for split_msg in msg_data:
|
||||
connection.async_handle(split_msg)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._logger.info("Connection closed by client")
|
||||
|
@ -257,6 +289,8 @@ class WebSocketHandler:
|
|||
|
||||
if connection is not None:
|
||||
self.hass.data[DATA_CONNECTIONS] -= 1
|
||||
self.connection = None
|
||||
|
||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
|
||||
return wsock
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue