Implement websocket message coalescing (#77238)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2022-08-24 22:50:48 -05:00 committed by GitHub
parent 2161b6f049
commit f6a03625ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 361 additions and 22 deletions

View file

@ -74,6 +74,7 @@ def async_register_commands(
async_reg(hass, handle_validate_config)
async_reg(hass, handle_subscribe_entities)
async_reg(hass, handle_supported_brands)
async_reg(hass, handle_supported_features)
def pong_message(iden: int) -> dict[str, Any]:
@ -723,3 +724,18 @@ async def handle_supported_brands(
raise int_or_exc
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
connection.send_result(msg["id"], data)
@callback
@decorators.websocket_command(
{
vol.Required("type"): "supported_features",
vol.Required("features"): {str: int},
}
)
def handle_supported_features(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle setting supported features."""
connection.supported_features = msg["features"]
connection.send_result(msg["id"])

View file

@ -42,6 +42,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
self.supported_features: dict[str, float] = {}
current_connection.set(self)
def context(self, msg: dict[str, Any]) -> Context:

View file

@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"
COMPRESSED_STATE_LAST_CHANGED = "lc"
COMPRESSED_STATE_LAST_UPDATED = "lu"
FEATURE_COALESCE_MESSAGES = "coalesce_messages"

View file

@ -6,7 +6,7 @@ from collections.abc import Callable
from contextlib import suppress
import datetime as dt
import logging
from typing import Any, Final
from typing import TYPE_CHECKING, Any, Final
from aiohttp import WSMsgType, web
import async_timeout
@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.json import json_loads
from .auth import AuthPhase, auth_required_message
from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS,
FEATURE_COALESCE_MESSAGES,
MAX_PENDING_MSG,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
@ -31,6 +33,10 @@ from .const import (
from .error import Disconnect
from .messages import message_to_json
if TYPE_CHECKING:
from .connection import ActiveConnection
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
@ -67,26 +73,47 @@ class WebSocketHandler:
self._writer_task: asyncio.Task | None = None
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
async def _writer(self) -> None:
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await self._to_write.get()) is None:
break
to_write = self._to_write
logger = self._logger
wsock = self.wsock
try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await to_write.get()) is None:
return
message = process if isinstance(process, str) else process()
if not isinstance(process, str):
message: str = process()
else:
message = process
self._logger.debug("Sending %s", message)
await self.wsock.send_str(message)
if (
to_write.empty()
or not self.connection
or FEATURE_COALESCE_MESSAGES
not in self.connection.supported_features
):
logger.debug("Sending %s", message)
await wsock.send_str(message)
continue
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
messages: list[str] = [message]
while not to_write.empty():
if (process := to_write.get_nowait()) is None:
return
messages.append(
process if isinstance(process, str) else process()
)
coalesced_messages = "[" + ",".join(messages) + "]"
self._logger.debug("Sending %s", coalesced_messages)
await self.wsock.send_str(coalesced_messages)
finally:
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
@ -194,13 +221,13 @@ class WebSocketHandler:
raise Disconnect
try:
msg_data = msg.json()
msg_data = msg.json(loads=json_loads)
except ValueError as err:
disconnect_warn = "Received invalid JSON."
raise Disconnect from err
self._logger.debug("Received %s", msg_data)
connection = await auth.async_handle(msg_data)
self.connection = connection = await auth.async_handle(msg_data)
self.hass.data[DATA_CONNECTIONS] = (
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
)
@ -218,13 +245,18 @@ class WebSocketHandler:
break
try:
msg_data = msg.json()
msg_data = msg.json(loads=json_loads)
except ValueError:
disconnect_warn = "Received invalid JSON."
break
self._logger.debug("Received %s", msg_data)
connection.async_handle(msg_data)
if not isinstance(msg_data, list):
connection.async_handle(msg_data)
continue
for split_msg in msg_data:
connection.async_handle(split_msg)
except asyncio.CancelledError:
self._logger.info("Connection closed by client")
@ -257,6 +289,8 @@ class WebSocketHandler:
if connection is not None:
self.hass.data[DATA_CONNECTIONS] -= 1
self.connection = None
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
return wsock