Improve websocket throughput and reduce latency (#92967)
This commit is contained in:
parent
9a70f47049
commit
8711735ec0
4 changed files with 120 additions and 47 deletions
|
@ -715,7 +715,7 @@ def handle_supported_features(
|
|||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle setting supported features."""
|
||||
connection.supported_features = msg["features"]
|
||||
connection.set_supported_features(msg["features"])
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ class ActiveConnection:
|
|||
self.refresh_token_id = refresh_token.id
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.can_coalesce = False
|
||||
self.supported_features: dict[str, float] = {}
|
||||
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
||||
const.DOMAIN
|
||||
|
@ -55,6 +56,11 @@ class ActiveConnection:
|
|||
self.binary_handlers: list[BinaryHandler | None] = []
|
||||
current_connection.set(self)
|
||||
|
||||
def set_supported_features(self, features: dict[str, float]) -> None:
|
||||
"""Set supported features."""
|
||||
self.supported_features = features
|
||||
self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features
|
||||
|
||||
def get_description(self, request: web.Request | None) -> str:
|
||||
"""Return a description of the connection."""
|
||||
description = self.user.name or ""
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
|
@ -22,7 +23,6 @@ from .auth import AuthPhase, auth_required_message
|
|||
from .const import (
|
||||
CANCELLATION_ERRORS,
|
||||
DATA_CONNECTIONS,
|
||||
FEATURE_COALESCE_MESSAGES,
|
||||
MAX_PENDING_MSG,
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
|
@ -71,7 +71,6 @@ class WebSocketHandler:
|
|||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
||||
self._handle_task: asyncio.Task | None = None
|
||||
self._writer_task: asyncio.Task | None = None
|
||||
self._closing: bool = False
|
||||
|
@ -79,6 +78,13 @@ class WebSocketHandler:
|
|||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
self.connection: ActiveConnection | None = None
|
||||
|
||||
# The WebSocketHandler has a single consumer and path
|
||||
# to where messages are queued. This allows the implementation
|
||||
# to use a deque and an asyncio.Future to avoid the overhead of
|
||||
# an asyncio.Queue.
|
||||
self._message_queue: deque = deque()
|
||||
self._ready_future: asyncio.Future[None] | None = None
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Return a description of the connection."""
|
||||
|
@ -88,39 +94,52 @@ class WebSocketHandler:
|
|||
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
to_write = self._to_write
|
||||
# Variables are set locally to avoid lookups in the loop
|
||||
message_queue = self._message_queue
|
||||
logger = self._logger
|
||||
wsock = self.wsock
|
||||
send_str = self.wsock.send_str
|
||||
loop = self.hass.loop
|
||||
debug = logger.debug
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
try:
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
if (process := await to_write.get()) is None:
|
||||
if (messages_remaining := len(message_queue)) == 0:
|
||||
self._ready_future = loop.create_future()
|
||||
await self._ready_future
|
||||
messages_remaining = len(message_queue)
|
||||
|
||||
# A None message is used to signal the end of the connection
|
||||
if (process := message_queue.popleft()) is None:
|
||||
return
|
||||
|
||||
messages_remaining -= 1
|
||||
message = process if isinstance(process, str) else process()
|
||||
|
||||
if (
|
||||
to_write.empty()
|
||||
not messages_remaining
|
||||
or not self.connection
|
||||
or FEATURE_COALESCE_MESSAGES
|
||||
not in self.connection.supported_features
|
||||
or not self.connection.can_coalesce
|
||||
):
|
||||
logger.debug("Sending %s", message)
|
||||
await wsock.send_str(message)
|
||||
debug("Sending %s", message)
|
||||
await send_str(message)
|
||||
continue
|
||||
|
||||
messages: list[str] = [message]
|
||||
while not to_write.empty():
|
||||
if (process := to_write.get_nowait()) is None:
|
||||
while messages_remaining:
|
||||
# A None message is used to signal the end of the connection
|
||||
if (process := message_queue.popleft()) is None:
|
||||
return
|
||||
messages.append(
|
||||
process if isinstance(process, str) else process()
|
||||
)
|
||||
messages_remaining -= 1
|
||||
|
||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||
logger.debug("Sending %s", coalesced_messages)
|
||||
await wsock.send_str(coalesced_messages)
|
||||
debug("Sending %s", coalesced_messages)
|
||||
await send_str(coalesced_messages)
|
||||
finally:
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
# Clean up the peak checker when we shut down the writer
|
||||
self._cancel_peak_checker()
|
||||
|
||||
@callback
|
||||
|
@ -146,11 +165,9 @@ class WebSocketHandler:
|
|||
if isinstance(message, dict):
|
||||
message = message_to_json(message)
|
||||
|
||||
to_write = self._to_write
|
||||
|
||||
try:
|
||||
to_write.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
message_queue = self._message_queue
|
||||
queue_size_before_add = len(message_queue)
|
||||
if queue_size_before_add >= MAX_PENDING_MSG:
|
||||
self._logger.error(
|
||||
(
|
||||
"%s: Client unable to keep up with pending messages. Reached %s pending"
|
||||
|
@ -162,10 +179,15 @@ class WebSocketHandler:
|
|||
message,
|
||||
)
|
||||
self._cancel()
|
||||
return
|
||||
|
||||
message_queue.append(message)
|
||||
if self._ready_future and not self._ready_future.done():
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
peak_checker_active = self._peak_checker_unsub is not None
|
||||
|
||||
if to_write.qsize() < PENDING_MSG_PEAK:
|
||||
if queue_size_before_add <= PENDING_MSG_PEAK:
|
||||
if peak_checker_active:
|
||||
self._cancel_peak_checker()
|
||||
return
|
||||
|
@ -180,7 +202,7 @@ class WebSocketHandler:
|
|||
"""Check that we are no longer above the write peak."""
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||
if len(self._message_queue) < PENDING_MSG_PEAK:
|
||||
return
|
||||
|
||||
self._logger.error(
|
||||
|
@ -199,6 +221,7 @@ class WebSocketHandler:
|
|||
def _cancel(self) -> None:
|
||||
"""Cancel the connection."""
|
||||
self._closing = True
|
||||
self._cancel_peak_checker()
|
||||
if self._handle_task is not None:
|
||||
self._handle_task.cancel()
|
||||
if self._writer_task is not None:
|
||||
|
@ -356,14 +379,14 @@ class WebSocketHandler:
|
|||
|
||||
self._closing = True
|
||||
|
||||
self._message_queue.append(None)
|
||||
if self._ready_future and not self._ready_future.done():
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
try:
|
||||
self._to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
await self._writer_task
|
||||
await wsock.close()
|
||||
except asyncio.QueueFull: # can be raised by put_nowait
|
||||
self._writer_task.cancel()
|
||||
|
||||
finally:
|
||||
if disconnect_warn is None:
|
||||
self._logger.debug("Disconnected")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Test Websocket API http module."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
||||
|
@ -53,12 +53,12 @@ async def test_pending_msg_peak(
|
|||
) -> None:
|
||||
"""Test pending msg overflow command."""
|
||||
orig_handler = http.WebSocketHandler
|
||||
instance = None
|
||||
setup_instance: http.WebSocketHandler | None = None
|
||||
|
||||
def instantiate_handler(*args):
|
||||
nonlocal instance
|
||||
instance = orig_handler(*args)
|
||||
return instance
|
||||
nonlocal setup_instance
|
||||
setup_instance = orig_handler(*args)
|
||||
return setup_instance
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||
|
@ -66,12 +66,11 @@ async def test_pending_msg_peak(
|
|||
):
|
||||
websocket_client = await hass_ws_client()
|
||||
|
||||
# Kill writer task and fill queue past peak
|
||||
for _ in range(5):
|
||||
instance._to_write.put_nowait(None)
|
||||
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||
|
||||
# Trigger the peak check
|
||||
instance._send_message({})
|
||||
# Fill the queue past the allowed peak
|
||||
for _ in range(10):
|
||||
instance._send_message({})
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
|
||||
|
@ -79,8 +78,54 @@ async def test_pending_msg_peak(
|
|||
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
assert "Client unable to keep up with pending messages" in caplog.text
|
||||
assert "Stayed over 5 for 5 seconds"
|
||||
|
||||
|
||||
async def test_pending_msg_peak_recovery(
|
||||
hass: HomeAssistant,
|
||||
mock_low_peak,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test pending msg nears the peak but recovers."""
|
||||
orig_handler = http.WebSocketHandler
|
||||
setup_instance: http.WebSocketHandler | None = None
|
||||
|
||||
def instantiate_handler(*args):
|
||||
nonlocal setup_instance
|
||||
setup_instance = orig_handler(*args)
|
||||
return setup_instance
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||
instantiate_handler,
|
||||
):
|
||||
websocket_client = await hass_ws_client()
|
||||
|
||||
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||
|
||||
# Make sure the call later is started
|
||||
for _ in range(10):
|
||||
instance._send_message({})
|
||||
|
||||
for _ in range(10):
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
|
||||
instance._send_message({})
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
|
||||
# Cleanly shutdown
|
||||
instance._send_message({})
|
||||
instance._handle_task.cancel()
|
||||
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
assert "Client unable to keep up with pending messages" not in caplog.text
|
||||
|
||||
|
||||
async def test_pending_msg_peak_but_does_not_overflow(
|
||||
|
@ -91,12 +136,12 @@ async def test_pending_msg_peak_but_does_not_overflow(
|
|||
) -> None:
|
||||
"""Test pending msg hits the low peak but recovers and does not overflow."""
|
||||
orig_handler = http.WebSocketHandler
|
||||
instance: http.WebSocketHandler | None = None
|
||||
setup_instance: http.WebSocketHandler | None = None
|
||||
|
||||
def instantiate_handler(*args):
|
||||
nonlocal instance
|
||||
instance = orig_handler(*args)
|
||||
return instance
|
||||
nonlocal setup_instance
|
||||
setup_instance = orig_handler(*args)
|
||||
return setup_instance
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||
|
@ -104,18 +149,17 @@ async def test_pending_msg_peak_but_does_not_overflow(
|
|||
):
|
||||
websocket_client = await hass_ws_client()
|
||||
|
||||
assert instance is not None
|
||||
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||
|
||||
# Kill writer task and fill queue past peak
|
||||
for _ in range(5):
|
||||
instance._to_write.put_nowait(None)
|
||||
instance._message_queue.append(None)
|
||||
|
||||
# Trigger the peak check
|
||||
instance._send_message({})
|
||||
|
||||
# Clear the queue
|
||||
while instance._to_write.qsize() > 0:
|
||||
instance._to_write.get_nowait()
|
||||
instance._message_queue.clear()
|
||||
|
||||
# Trigger the peak clear
|
||||
instance._send_message({})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue