Improve logging and handling when websocket gets behind (#86854)

fixes undefined
This commit is contained in:
J. Nick Koston 2023-01-29 10:49:27 -10:00 committed by GitHub
parent c612a92cfb
commit 0f4b17755e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 36 deletions

View file

@ -113,18 +113,14 @@ def handle_subscribe_events(
):
return
connection.send_message(
lambda: messages.cached_event_message(msg["id"], event)
)
connection.send_message(messages.cached_event_message(msg["id"], event))
else:
@callback
def forward_events(event: Event) -> None:
"""Forward events to websocket."""
connection.send_message(
lambda: messages.cached_event_message(msg["id"], event)
)
connection.send_message(messages.cached_event_message(msg["id"], event))
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
event_type, forward_events, run_immediately=True
@ -296,9 +292,7 @@ def handle_subscribe_entities(
if entity_ids and event.data["entity_id"] not in entity_ids:
return
connection.send_message(
lambda: messages.cached_state_diff_message(msg["id"], event)
)
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for
# state changed events or we will introduce a race condition

View file

@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from aiohttp import web
import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User
@ -14,6 +15,7 @@ from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from . import const, messages
from .util import describe_request
if TYPE_CHECKING:
from .http import WebSocketAdapter
@ -46,6 +48,13 @@ class ActiveConnection:
self.supported_features: dict[str, float] = {}
current_connection.set(self)
def get_description(self, request: web.Request | None) -> str:
"""Return a description of the connection."""
description = self.user.name or ""
if request:
description += " " + describe_request(request)
return description
def context(self, msg: dict[str, Any]) -> Context:
"""Return a context."""
return Context(user_id=self.user.id)
@ -142,9 +151,6 @@ class ActiveConnection:
if code:
err_message += f" ({code})"
if request := current_request.get():
err_message += f" from {request.remote}"
if user_agent := request.headers.get("user-agent"):
err_message += f" ({user_agent})"
err_message += " " + self.get_description(current_request.get())
log_handler("Error handling message: %s", err_message)

View file

@ -21,9 +21,12 @@ AsyncWebSocketCommandHandler = Callable[
DOMAIN: Final = "websocket_api"
URL: Final = "/api/websocket"
PENDING_MSG_PEAK: Final = 512
PENDING_MSG_PEAK: Final = 1024
PENDING_MSG_PEAK_TIME: Final = 5
MAX_PENDING_MSG: Final = 2048
# Maximum number of messages that can be pending at any given time.
# This is effectively the upper limit of the number of entities
# that can fire state changes within ~1 second.
MAX_PENDING_MSG: Final = 4096
ERR_ID_REUSE: Final = "id_reuse"
ERR_INVALID_FORMAT: Final = "invalid_format"

View file

@ -32,6 +32,7 @@ from .const import (
)
from .error import Disconnect
from .messages import message_to_json
from .util import describe_request
if TYPE_CHECKING:
from .connection import ActiveConnection
@ -73,10 +74,18 @@ class WebSocketHandler:
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None
self._closing: bool = False
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
@property
def description(self) -> str:
"""Return a description of the connection."""
if self.connection is not None:
return self.connection.get_description(self.request)
return describe_request(self.request)
async def _writer(self) -> None:
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
@ -89,7 +98,6 @@ class WebSocketHandler:
if (process := await to_write.get()) is None:
return
message = process if isinstance(process, str) else process()
if (
to_write.empty()
or not self.connection
@ -109,13 +117,18 @@ class WebSocketHandler:
)
coalesced_messages = "[" + ",".join(messages) + "]"
self._logger.debug("Sending %s", coalesced_messages)
await self.wsock.send_str(coalesced_messages)
logger.debug("Sending %s", coalesced_messages)
await wsock.send_str(coalesced_messages)
finally:
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
self._cancel_peak_checker()
@callback
def _cancel_peak_checker(self) -> None:
"""Cancel the peak checker."""
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
@ -125,25 +138,39 @@ class WebSocketHandler:
Async friendly.
"""
if self._closing:
# Connection is cancelled, don't flood logs about exceeding
# max pending messages.
return
if isinstance(message, dict):
message = message_to_json(message)
to_write = self._to_write
try:
self._to_write.put_nowait(message)
to_write.put_nowait(message)
except asyncio.QueueFull:
self._logger.error(
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
(
"%s: Client unable to keep up with pending messages. Reached %s pending"
" messages. The system's load is too high or an integration is"
" misbehaving. Last message was: %s"
),
self.description,
MAX_PENDING_MSG,
message,
)
self._cancel()
if self._to_write.qsize() < PENDING_MSG_PEAK:
if self._peak_checker_unsub:
self._peak_checker_unsub()
self._peak_checker_unsub = None
peak_checker_active = self._peak_checker_unsub is not None
if to_write.qsize() < PENDING_MSG_PEAK:
if peak_checker_active:
self._cancel_peak_checker()
return
if self._peak_checker_unsub is None:
if not peak_checker_active:
self._peak_checker_unsub = async_call_later(
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
)
@ -158,10 +185,11 @@ class WebSocketHandler:
self._logger.error(
(
"Client unable to keep up with pending messages. Stayed over %s for %s"
"%s: Client unable to keep up with pending messages. Stayed over %s for %s"
" seconds. The system's load is too high or an integration is"
" misbehaving"
),
self.description,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
)
@ -170,6 +198,7 @@ class WebSocketHandler:
@callback
def _cancel(self) -> None:
"""Cancel the connection."""
self._closing = True
if self._handle_task is not None:
self._handle_task.cancel()
if self._writer_task is not None:
@ -279,6 +308,8 @@ class WebSocketHandler:
if connection is not None:
connection.async_handle_close()
self._closing = True
try:
self._to_write.put_nowait(None)
# Make sure all error messages are written before closing

View file

@ -0,0 +1,13 @@
"""Websocket API util.""."""
from __future__ import annotations
from aiohttp import web
def describe_request(request: web.Request) -> str:
"""Describe a request."""
description = f"from {request.remote}"
if user_agent := request.headers.get("user-agent"):
description += f" ({user_agent})"
return description

View file

@ -21,37 +21,37 @@ from tests.common import MockUser
exceptions.Unauthorized(),
websocket_api.ERR_UNAUTHORIZED,
"Unauthorized",
"Error handling message: Unauthorized (unauthorized) from 127.0.0.42 (Browser)",
"Error handling message: Unauthorized (unauthorized) Mock User from 127.0.0.42 (Browser)",
),
(
vol.Invalid("Invalid something"),
websocket_api.ERR_INVALID_FORMAT,
"Invalid something. Got {'id': 5}",
"Error handling message: Invalid something. Got {'id': 5} (invalid_format) from 127.0.0.42 (Browser)",
"Error handling message: Invalid something. Got {'id': 5} (invalid_format) Mock User from 127.0.0.42 (Browser)",
),
(
asyncio.TimeoutError(),
websocket_api.ERR_TIMEOUT,
"Timeout",
"Error handling message: Timeout (timeout) from 127.0.0.42 (Browser)",
"Error handling message: Timeout (timeout) Mock User from 127.0.0.42 (Browser)",
),
(
exceptions.HomeAssistantError("Failed to do X"),
websocket_api.ERR_UNKNOWN_ERROR,
"Failed to do X",
"Error handling message: Failed to do X (unknown_error) from 127.0.0.42 (Browser)",
"Error handling message: Failed to do X (unknown_error) Mock User from 127.0.0.42 (Browser)",
),
(
ValueError("Really bad"),
websocket_api.ERR_UNKNOWN_ERROR,
"Unknown error",
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)",
"Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
),
(
exceptions.HomeAssistantError,
websocket_api.ERR_UNKNOWN_ERROR,
"Unknown error",
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)",
"Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
),
],
)

View file

@ -67,6 +67,50 @@ async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
assert "Client unable to keep up with pending messages" in caplog.text
async def test_pending_msg_peak_but_does_not_overflow(
hass, mock_low_peak, hass_ws_client, caplog
):
"""Test pending msg hits the low peak but recovers and does not overflow."""
orig_handler = http.WebSocketHandler
instance: http.WebSocketHandler | None = None
def instantiate_handler(*args):
nonlocal instance
instance = orig_handler(*args)
return instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
instantiate_handler,
):
websocket_client = await hass_ws_client()
assert instance is not None
# Kill writer task and fill queue past peak
for _ in range(5):
instance._to_write.put_nowait(None)
# Trigger the peak check
instance._send_message({})
# Clear the queue
while instance._to_write.qsize() > 0:
instance._to_write.get_nowait()
# Trigger the peak clear
instance._send_message({})
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
)
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
assert "Client unable to keep up with pending messages" not in caplog.text
async def test_non_json_message(hass, websocket_client, caplog):
"""Test trying to serialize non JSON objects."""
bad_data = object()