Improve websocket message coalescing to handle thundering herds better (#118268)

* Increase websocket peak messages to match max expected entities

During startup, the websocket would frequently disconnect if more than
4096 entities were added back to back. Some MQTT setups have more than
10000 entities, so match the websocket peak value to the maximum expected
number of entities. (A short sketch of the resulting coalescing format
follows the commit list below.)

* coalesce more

* delay more if the backlog gets large

* wait to send if the queue is building rapidly

* tweak

* tweak for chrome since it works great in firefox but chrome cannot handle it

* Revert "tweak for chrome since it works great in firefox but chrome cannot handle it"

This reverts commit 439e2d76b1.

* adjust for chrome

* lower number

* remove code

* fixes

* fast path for bytes

* compact

* adjust test since we see the close right away now on overload

* simplify check

* reduce loop

* tweak

* handle ready right away
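
For orientation, and not part of the commit itself: the coalescing described above amounts to joining the already-serialized JSON messages in the writer queue into a single JSON array frame, so one websocket write delivers many messages at once. A minimal sketch of that framing, assuming each queued item is a complete JSON object already encoded as bytes:

# Minimal sketch of the coalesced frame format (illustration only, not the
# exact Home Assistant writer loop).
import json
from collections import deque

# Each queued item is a complete JSON message, already encoded as bytes.
message_queue: deque[bytes] = deque(
    json.dumps({"id": i, "type": "event"}).encode() for i in range(3)
)

# Join everything currently queued into one JSON array frame and send it
# as a single websocket text message.
coalesced = b"".join((b"[", b",".join(message_queue), b"]"))
message_queue.clear()

# The frame is itself valid JSON, so the client can parse it in one pass.
assert json.loads(coalesced) == [
    {"id": 0, "type": "event"},
    {"id": 1, "type": "event"},
    {"id": 2, "type": "event"},
]

Because the frame is built by byte concatenation, the writer queue can only hold bytes; the None sentinel that used to signal shutdown is replaced by checking the _closing flag, which is why the deque type changes from deque[bytes | None] to deque[bytes] in the diff below.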
J. Nick Koston, 2024-05-28 17:14:06 -10:00, committed by GitHub
commit 79bc179ce8, parent b94bf1f214
5 changed files with 124 additions and 69 deletions

homeassistant/components/auth/__init__.py

@@ -125,6 +125,7 @@ as part of a config flow.
 from __future__ import annotations

+import asyncio
 from collections.abc import Callable
 from datetime import datetime, timedelta
 from http import HTTPStatus
@@ -168,6 +169,8 @@ type RetrieveResultType = Callable[[str, str], Credentials | None]
 CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)

+DELETE_CURRENT_TOKEN_DELAY = 2
+

 @bind_hass
 def create_auth_code(
@@ -644,11 +647,34 @@ def websocket_delete_all_refresh_tokens(
     else:
         connection.send_result(msg["id"], {})

+    async def _delete_current_token_soon() -> None:
+        """Delete the current token after a delay.
+
+        We do not want to delete the current token immediately as it will
+        close the connection.
+
+        This is implemented as a tracked task to ensure the token
+        is still deleted if Home Assistant is shut down during
+        the delay.
+
+        It should not be refactored to use a call_later as that
+        would not be tracked and the token would not be deleted
+        if Home Assistant was shut down during the delay.
+        """
+        try:
+            await asyncio.sleep(DELETE_CURRENT_TOKEN_DELAY)
+        finally:
+            # If the task is cancelled because we are shutting down, delete
+            # the token right away.
+            hass.auth.async_remove_refresh_token(current_refresh_token)
+
     if delete_current_token and (
         not limit_token_types or current_refresh_token.token_type == token_type
     ):
-        # This will close the connection so we need to send the result first.
-        hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
+        # Deleting the token will close the connection so we need
+        # to do it with a delay in a tracked task to ensure it still
+        # happens if Home Assistant is shutting down.
+        hass.async_create_task(_delete_current_token_soon())


 @websocket_api.websocket_command(
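
As background for the docstring above, and as an illustration rather than code from this change: cancelling a task that is sleeping inside a try/finally still runs the finally block, which is why the tracked task deletes the token even if shutdown cancels it mid-delay, whereas a call_later callback scheduled for a later time would simply never run. A minimal standalone sketch of that behavior (the helper names here are made up):

# Standalone sketch (not Home Assistant code) showing that a cancelled task
# still executes its finally block, so the delayed cleanup is not lost.
import asyncio
from collections.abc import Callable


async def remove_token_soon(delay: float, remove_token: Callable[[], None]) -> None:
    try:
        await asyncio.sleep(delay)
    finally:
        # Runs on normal completion *and* when the task is cancelled.
        remove_token()


async def main() -> None:
    removed: list[bool] = []
    task = asyncio.create_task(remove_token_soon(60, lambda: removed.append(True)))
    await asyncio.sleep(0)  # let the task reach its sleep
    task.cancel()           # simulate shutdown cancelling tracked tasks
    try:
        await task
    except asyncio.CancelledError:
        pass
    assert removed  # the cleanup still ran


asyncio.run(main())

A call_later timer handle, by contrast, simply never fires if the loop stops before the delay expires, which is what the docstring warns about.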

homeassistant/components/websocket_api/const.py

@@ -25,8 +25,15 @@ PENDING_MSG_PEAK_TIME: Final = 5
 # Maximum number of messages that can be pending at any given time.
 # This is effectively the upper limit of the number of entities
 # that can fire state changes within ~1 second.
+# Ideally we would use homeassistant.const.MAX_EXPECTED_ENTITY_IDS
+# but since chrome will lock up with too many messages we need to
+# limit it to a lower number.
 MAX_PENDING_MSG: Final = 4096

+# Maximum number of messages that are pending before we force
+# resolve the ready future.
+PENDING_MSG_MAX_FORCE_READY: Final = 256
+
 ERR_ID_REUSE: Final = "id_reuse"
 ERR_INVALID_FORMAT: Final = "invalid_format"
 ERR_NOT_ALLOWED: Final = "not_allowed"

homeassistant/components/websocket_api/http.py

@@ -24,6 +24,7 @@ from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
 from .const import (
     DATA_CONNECTIONS,
     MAX_PENDING_MSG,
+    PENDING_MSG_MAX_FORCE_READY,
     PENDING_MSG_PEAK,
     PENDING_MSG_PEAK_TIME,
     SIGNAL_WEBSOCKET_CONNECTED,
@@ -67,6 +68,7 @@ class WebSocketHandler:
     __slots__ = (
         "_hass",
+        "_loop",
         "_request",
         "_wsock",
         "_handle_task",
@@ -78,11 +80,13 @@
         "_connection",
         "_message_queue",
         "_ready_future",
+        "_release_ready_queue_size",
     )

     def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
         """Initialize an active connection."""
         self._hass = hass
+        self._loop = hass.loop
         self._request: web.Request = request
         self._wsock = web.WebSocketResponse(heartbeat=55)
         self._handle_task: asyncio.Task | None = None
@@ -97,8 +101,9 @@
         # to where messages are queued. This allows the implementation
         # to use a deque and an asyncio.Future to avoid the overhead of
         # an asyncio.Queue.
-        self._message_queue: deque[bytes | None] = deque()
-        self._ready_future: asyncio.Future[None] | None = None
+        self._message_queue: deque[bytes] = deque()
+        self._ready_future: asyncio.Future[int] | None = None
+        self._release_ready_queue_size: int = 0

     def __repr__(self) -> str:
         """Return the representation."""
@@ -126,45 +131,35 @@
         message_queue = self._message_queue
         logger = self._logger
         wsock = self._wsock
-        loop = self._hass.loop
+        loop = self._loop
+        is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG)
         debug = logger.debug
-        is_enabled_for = logger.isEnabledFor
-        logging_debug = logging.DEBUG
+        can_coalesce = self._connection and self._connection.can_coalesce
+        ready_message_count = len(message_queue)
         # Exceptions if Socket disconnected or cancelled by connection handler
         try:
             while not wsock.closed:
-                if (messages_remaining := len(message_queue)) == 0:
+                if not message_queue:
                     self._ready_future = loop.create_future()
-                    await self._ready_future
-                    messages_remaining = len(message_queue)
+                    ready_message_count = await self._ready_future

-                # A None message is used to signal the end of the connection
-                if (message := message_queue.popleft()) is None:
+                if self._closing:
                     return

-                debug_enabled = is_enabled_for(logging_debug)
-                messages_remaining -= 1
+                if not can_coalesce:
+                    # coalesce may be enabled later in the connection
+                    can_coalesce = self._connection and self._connection.can_coalesce

-                if (
-                    not messages_remaining
-                    or not (connection := self._connection)
-                    or not connection.can_coalesce
-                ):
-                    if debug_enabled:
+                if not can_coalesce or ready_message_count == 1:
+                    message = message_queue.popleft()
+                    if is_debug_log_enabled():
                         debug("%s: Sending %s", self.description, message)
                     await send_bytes_text(message)
                     continue

-                messages: list[bytes] = [message]
-                while messages_remaining:
-                    # A None message is used to signal the end of the connection
-                    if (message := message_queue.popleft()) is None:
-                        return
-                    messages.append(message)
-                    messages_remaining -= 1
-
-                coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
-                if debug_enabled:
+                coalesced_messages = b"".join((b"[", b",".join(message_queue), b"]"))
+                message_queue.clear()
+                if is_debug_log_enabled():
                     debug("%s: Sending %s", self.description, coalesced_messages)
                 await send_bytes_text(coalesced_messages)
         except asyncio.CancelledError:
@@ -197,14 +192,15 @@
             # max pending messages.
             return

-        if isinstance(message, dict):
-            message = message_to_json_bytes(message)
-        elif isinstance(message, str):
-            message = message.encode("utf-8")
+        if type(message) is not bytes:  # noqa: E721
+            if isinstance(message, dict):
+                message = message_to_json_bytes(message)
+            elif isinstance(message, str):
+                message = message.encode("utf-8")

         message_queue = self._message_queue
-        queue_size_before_add = len(message_queue)
-        if queue_size_before_add >= MAX_PENDING_MSG:
+        message_queue.append(message)
+        if (queue_size_after_add := len(message_queue)) >= MAX_PENDING_MSG:
             self._logger.error(
                 (
                     "%s: Client unable to keep up with pending messages. Reached %s pending"
@@ -218,14 +214,14 @@
             self._cancel()
             return

-        message_queue.append(message)
-        ready_future = self._ready_future
-        if ready_future and not ready_future.done():
-            ready_future.set_result(None)
+        if self._release_ready_queue_size == 0:
+            # Try to coalesce more messages to reduce the number of writes
+            self._release_ready_queue_size = queue_size_after_add
+            self._loop.call_soon(self._release_ready_future_or_reschedule)

         peak_checker_active = self._peak_checker_unsub is not None

-        if queue_size_before_add <= PENDING_MSG_PEAK:
+        if queue_size_after_add <= PENDING_MSG_PEAK:
             if peak_checker_active:
                 self._cancel_peak_checker()
             return
@@ -235,6 +231,32 @@
             self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
         )

+    @callback
+    def _release_ready_future_or_reschedule(self) -> None:
+        """Release the ready future or reschedule.
+
+        We will release the ready future if the queue did not grow since the
+        last time we tried to release the ready future.
+
+        If we reach PENDING_MSG_MAX_FORCE_READY, we will release the ready future
+        immediately to keep the coalesced messages from growing too large.
+        """
+        if not (ready_future := self._ready_future) or not (
+            queue_size := len(self._message_queue)
+        ):
+            self._release_ready_queue_size = 0
+            return
+        # If we are below the max pending to force ready, and there are new messages
+        # in the queue since the last time we tried to release the ready future, we
+        # try again later so we can coalesce more messages.
+        if queue_size > self._release_ready_queue_size < PENDING_MSG_MAX_FORCE_READY:
+            self._release_ready_queue_size = queue_size
+            self._loop.call_soon(self._release_ready_future_or_reschedule)
+            return
+        self._release_ready_queue_size = 0
+        if not ready_future.done():
+            ready_future.set_result(queue_size)
+
     @callback
     def _check_write_peak(self, _utc_time: dt.datetime) -> None:
         """Check that we are no longer above the write peak."""
@@ -440,10 +462,8 @@
             connection.async_handle_close()

         self._closing = True
-        self._message_queue.append(None)
-
         if self._ready_future and not self._ready_future.done():
-            self._ready_future.set_result(None)
+            self._ready_future.set_result(len(self._message_queue))

         # If the writer gets canceled we still need to close the websocket
         # so we have another finally block to make sure we close the websocket
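
To make the release-or-reschedule logic above easier to follow, here is a self-contained sketch of the same idea, with simplified names and a hypothetical CoalescingQueue class rather than the real WebSocketHandler API: the reader waits on a future, and the sender keeps postponing the wake-up with call_soon while the queue is still growing, but forces the release once a cap is reached so the coalesced frame cannot grow without bound.

# Self-contained sketch of "release the ready future or reschedule"
# (simplified illustration, not the actual WebSocketHandler implementation).
import asyncio
from collections import deque

MAX_FORCE_READY = 256  # cap on how long the wake-up may be postponed


class CoalescingQueue:
    """Queue that wakes its reader only once nearby messages have arrived."""

    def __init__(self) -> None:
        # Must be constructed while the event loop is running.
        self._loop = asyncio.get_running_loop()
        self._queue: deque[bytes] = deque()
        self._ready: asyncio.Future[int] | None = None
        self._pending_release = 0

    def send(self, message: bytes) -> None:
        """Queue a serialized message and schedule a (possibly deferred) wake-up."""
        self._queue.append(message)
        if self._pending_release == 0:
            # First message since the last release: check again on the next
            # loop iteration so nearby messages can share one frame.
            self._pending_release = len(self._queue)
            self._loop.call_soon(self._release_or_reschedule)

    def _release_or_reschedule(self) -> None:
        queue_size = len(self._queue)
        if self._ready is None or queue_size == 0:
            self._pending_release = 0
            return
        if queue_size > self._pending_release < MAX_FORCE_READY:
            # Still growing and below the cap: wait one more loop iteration.
            self._pending_release = queue_size
            self._loop.call_soon(self._release_or_reschedule)
            return
        self._pending_release = 0
        if not self._ready.done():
            self._ready.set_result(queue_size)

    async def drain(self) -> bytes:
        """Wait for messages and return them as one coalesced JSON array frame."""
        if not self._queue:
            self._ready = self._loop.create_future()
            await self._ready
            self._ready = None
        frame = b"".join((b"[", b",".join(self._queue), b"]"))
        self._queue.clear()
        return frame


async def main() -> None:
    q = CoalescingQueue()
    for i in range(5):
        q.send(b'{"id": %d}' % i)
    print(await q.drain())  # b'[{"id": 0},{"id": 1},{"id": 2},{"id": 3},{"id": 4}]'


asyncio.run(main())

In the real handler, send() corresponds to _send_message() appending to the writer deque, drain() to the _writer coroutine awaiting the ready future, and PENDING_MSG_MAX_FORCE_READY plays the role of MAX_FORCE_READY.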

tests/components/auth/test_init.py

@@ -546,20 +546,21 @@ async def test_ws_delete_all_refresh_tokens_error(
     tokens = result["result"]

-    await ws_client.send_json(
-        {
-            "id": 6,
-            "type": "auth/delete_all_refresh_tokens",
-        }
-    )
+    with patch("homeassistant.components.auth.DELETE_CURRENT_TOKEN_DELAY", 0.001):
+        await ws_client.send_json(
+            {
+                "id": 6,
+                "type": "auth/delete_all_refresh_tokens",
+            }
+        )

     caplog.clear()
     result = await ws_client.receive_json()
     assert result, result["success"] is False
     assert result["error"] == {
         "code": "token_removing_error",
         "message": "During removal, an error was raised.",
     }

     records = [
         record
@@ -571,6 +572,7 @@ async def test_ws_delete_all_refresh_tokens_error(
     assert records[0].exc_info and str(records[0].exc_info[1]) == "I'm bad"
     assert records[0].name == "homeassistant.components.auth"

+    await hass.async_block_till_done()
     for token in tokens:
         refresh_token = hass.auth.async_get_refresh_token(token["id"])
         assert refresh_token is None
@@ -629,18 +631,20 @@ async def test_ws_delete_all_refresh_tokens(
     result = await ws_client.receive_json()
     assert result["success"], result

-    await ws_client.send_json(
-        {
-            "id": 6,
-            "type": "auth/delete_all_refresh_tokens",
-            **delete_token_type,
-            **delete_current_token,
-        }
-    )
+    with patch("homeassistant.components.auth.DELETE_CURRENT_TOKEN_DELAY", 0.001):
+        await ws_client.send_json(
+            {
+                "id": 6,
+                "type": "auth/delete_all_refresh_tokens",
+                **delete_token_type,
+                **delete_current_token,
+            }
+        )

     result = await ws_client.receive_json()
     assert result, result["success"]

+    await hass.async_block_till_done()
     # We need to enumerate the user since we may remove the token
     # that is used to authenticate the user which will prevent the websocket
     # connection from working
tests/components/websocket_api/test_http.py

@@ -294,8 +294,6 @@ async def test_pending_msg_peak_recovery(
         instance._send_message({})
     instance._handle_task.cancel()

-    msg = await websocket_client.receive()
-    assert msg.type == WSMsgType.TEXT
     msg = await websocket_client.receive()
     assert msg.type is WSMsgType.CLOSE
     assert "Client unable to keep up with pending messages" not in caplog.text