Improve websocket throughput and reduce latency (#92967)

This commit is contained in:
J. Nick Koston 2023-05-13 00:13:57 +09:00 committed by GitHub
parent 9a70f47049
commit 8711735ec0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 120 additions and 47 deletions

View file

@ -715,7 +715,7 @@ def handle_supported_features(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle setting supported features."""
connection.supported_features = msg["features"]
connection.set_supported_features(msg["features"])
connection.send_result(msg["id"])

View file

@ -48,6 +48,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
self.can_coalesce = False
self.supported_features: dict[str, float] = {}
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
const.DOMAIN
@ -55,6 +56,11 @@ class ActiveConnection:
self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self)
def set_supported_features(self, features: dict[str, float]) -> None:
"""Set supported features."""
self.supported_features = features
self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features
def get_description(self, request: web.Request | None) -> str:
"""Return a description of the connection."""
description = self.user.name or ""

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import Callable
from contextlib import suppress
import datetime as dt
@ -22,7 +23,6 @@ from .auth import AuthPhase, auth_required_message
from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS,
FEATURE_COALESCE_MESSAGES,
MAX_PENDING_MSG,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
@ -71,7 +71,6 @@ class WebSocketHandler:
self.hass = hass
self.request = request
self.wsock = web.WebSocketResponse(heartbeat=55)
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None
self._closing: bool = False
@ -79,6 +78,13 @@ class WebSocketHandler:
self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
# The WebSocketHandler has a single consumer and path
# to where messages are queued. This allows the implementation
# to use a deque and an asyncio.Future to avoid the overhead of
# an asyncio.Queue.
self._message_queue: deque = deque()
self._ready_future: asyncio.Future[None] | None = None
@property
def description(self) -> str:
"""Return a description of the connection."""
@ -88,39 +94,52 @@ class WebSocketHandler:
async def _writer(self) -> None:
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
to_write = self._to_write
# Variables are set locally to avoid lookups in the loop
message_queue = self._message_queue
logger = self._logger
wsock = self.wsock
send_str = self.wsock.send_str
loop = self.hass.loop
debug = logger.debug
# Exceptions if Socket disconnected or cancelled by connection handler
try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await to_write.get()) is None:
if (messages_remaining := len(message_queue)) == 0:
self._ready_future = loop.create_future()
await self._ready_future
messages_remaining = len(message_queue)
# A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None:
return
messages_remaining -= 1
message = process if isinstance(process, str) else process()
if (
to_write.empty()
not messages_remaining
or not self.connection
or FEATURE_COALESCE_MESSAGES
not in self.connection.supported_features
or not self.connection.can_coalesce
):
logger.debug("Sending %s", message)
await wsock.send_str(message)
debug("Sending %s", message)
await send_str(message)
continue
messages: list[str] = [message]
while not to_write.empty():
if (process := to_write.get_nowait()) is None:
while messages_remaining:
# A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None:
return
messages.append(
process if isinstance(process, str) else process()
)
messages_remaining -= 1
coalesced_messages = "[" + ",".join(messages) + "]"
logger.debug("Sending %s", coalesced_messages)
await wsock.send_str(coalesced_messages)
debug("Sending %s", coalesced_messages)
await send_str(coalesced_messages)
finally:
# Clean up the peaker checker when we shut down the writer
# Clean up the peak checker when we shut down the writer
self._cancel_peak_checker()
@callback
@ -146,11 +165,9 @@ class WebSocketHandler:
if isinstance(message, dict):
message = message_to_json(message)
to_write = self._to_write
try:
to_write.put_nowait(message)
except asyncio.QueueFull:
message_queue = self._message_queue
queue_size_before_add = len(message_queue)
if queue_size_before_add >= MAX_PENDING_MSG:
self._logger.error(
(
"%s: Client unable to keep up with pending messages. Reached %s pending"
@ -162,10 +179,15 @@ class WebSocketHandler:
message,
)
self._cancel()
return
message_queue.append(message)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
peak_checker_active = self._peak_checker_unsub is not None
if to_write.qsize() < PENDING_MSG_PEAK:
if queue_size_before_add <= PENDING_MSG_PEAK:
if peak_checker_active:
self._cancel_peak_checker()
return
@ -180,7 +202,7 @@ class WebSocketHandler:
"""Check that we are no longer above the write peak."""
self._peak_checker_unsub = None
if self._to_write.qsize() < PENDING_MSG_PEAK:
if len(self._message_queue) < PENDING_MSG_PEAK:
return
self._logger.error(
@ -199,6 +221,7 @@ class WebSocketHandler:
def _cancel(self) -> None:
"""Cancel the connection."""
self._closing = True
self._cancel_peak_checker()
if self._handle_task is not None:
self._handle_task.cancel()
if self._writer_task is not None:
@ -356,14 +379,14 @@ class WebSocketHandler:
self._closing = True
self._message_queue.append(None)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
try:
self._to_write.put_nowait(None)
# Make sure all error messages are written before closing
await self._writer_task
await wsock.close()
except asyncio.QueueFull: # can be raised by put_nowait
self._writer_task.cancel()
finally:
if disconnect_warn is None:
self._logger.debug("Disconnected")

View file

@ -1,7 +1,7 @@
"""Test Websocket API http module."""
import asyncio
from datetime import timedelta
from typing import Any
from typing import Any, cast
from unittest.mock import patch
from aiohttp import ServerDisconnectedError, WSMsgType, web
@ -53,12 +53,12 @@ async def test_pending_msg_peak(
) -> None:
"""Test pending msg overflow command."""
orig_handler = http.WebSocketHandler
instance = None
setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args):
nonlocal instance
instance = orig_handler(*args)
return instance
nonlocal setup_instance
setup_instance = orig_handler(*args)
return setup_instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
@ -66,12 +66,11 @@ async def test_pending_msg_peak(
):
websocket_client = await hass_ws_client()
# Kill writer task and fill queue past peak
for _ in range(5):
instance._to_write.put_nowait(None)
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
# Trigger the peak check
instance._send_message({})
# Fill the queue past the allowed peak
for _ in range(10):
instance._send_message({})
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
@ -79,8 +78,54 @@ async def test_pending_msg_peak(
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
assert "Client unable to keep up with pending messages" in caplog.text
assert "Stayed over 5 for 5 seconds"
async def test_pending_msg_peak_recovery(
hass: HomeAssistant,
mock_low_peak,
hass_ws_client: WebSocketGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test pending msg nears the peak but recovers."""
orig_handler = http.WebSocketHandler
setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args):
nonlocal setup_instance
setup_instance = orig_handler(*args)
return setup_instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
instantiate_handler,
):
websocket_client = await hass_ws_client()
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
# Make sure the call later is started
for _ in range(10):
instance._send_message({})
for _ in range(10):
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
instance._send_message({})
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
# Cleanly shutdown
instance._send_message({})
instance._handle_task.cancel()
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
assert "Client unable to keep up with pending messages" not in caplog.text
async def test_pending_msg_peak_but_does_not_overflow(
@ -91,12 +136,12 @@ async def test_pending_msg_peak_but_does_not_overflow(
) -> None:
"""Test pending msg hits the low peak but recovers and does not overflow."""
orig_handler = http.WebSocketHandler
instance: http.WebSocketHandler | None = None
setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args):
nonlocal instance
instance = orig_handler(*args)
return instance
nonlocal setup_instance
setup_instance = orig_handler(*args)
return setup_instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
@ -104,18 +149,17 @@ async def test_pending_msg_peak_but_does_not_overflow(
):
websocket_client = await hass_ws_client()
assert instance is not None
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
# Kill writer task and fill queue past peak
for _ in range(5):
instance._to_write.put_nowait(None)
instance._message_queue.append(None)
# Trigger the peak check
instance._send_message({})
# Clear the queue
while instance._to_write.qsize() > 0:
instance._to_write.get_nowait()
instance._message_queue.clear()
# Trigger the peak clear
instance._send_message({})