Improve logging and handling when websocket gets behind (#86854)

fixes undefined
This commit is contained in:
J. Nick Koston 2023-01-29 10:49:27 -10:00 committed by GitHub
parent c612a92cfb
commit 0f4b17755e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 36 deletions

View file

@ -113,18 +113,14 @@ def handle_subscribe_events(
): ):
return return
connection.send_message( connection.send_message(messages.cached_event_message(msg["id"], event))
lambda: messages.cached_event_message(msg["id"], event)
)
else: else:
@callback @callback
def forward_events(event: Event) -> None: def forward_events(event: Event) -> None:
"""Forward events to websocket.""" """Forward events to websocket."""
connection.send_message( connection.send_message(messages.cached_event_message(msg["id"], event))
lambda: messages.cached_event_message(msg["id"], event)
)
connection.subscriptions[msg["id"]] = hass.bus.async_listen( connection.subscriptions[msg["id"]] = hass.bus.async_listen(
event_type, forward_events, run_immediately=True event_type, forward_events, run_immediately=True
@ -296,9 +292,7 @@ def handle_subscribe_entities(
if entity_ids and event.data["entity_id"] not in entity_ids: if entity_ids and event.data["entity_id"] not in entity_ids:
return return
connection.send_message( connection.send_message(messages.cached_state_diff_message(msg["id"], event))
lambda: messages.cached_state_diff_message(msg["id"], event)
)
# We must never await between sending the states and listening for # We must never await between sending the states and listening for
# state changed events or we will introduce a race condition # state changed events or we will introduce a race condition

View file

@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable
from contextvars import ContextVar from contextvars import ContextVar
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User from homeassistant.auth.models import RefreshToken, User
@ -14,6 +15,7 @@ from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from . import const, messages from . import const, messages
from .util import describe_request
if TYPE_CHECKING: if TYPE_CHECKING:
from .http import WebSocketAdapter from .http import WebSocketAdapter
@ -46,6 +48,13 @@ class ActiveConnection:
self.supported_features: dict[str, float] = {} self.supported_features: dict[str, float] = {}
current_connection.set(self) current_connection.set(self)
def get_description(self, request: web.Request | None) -> str:
    """Return a description of the connection for log messages.

    The description is built from the user's name (when it is set and
    non-empty) and the output of ``describe_request`` (when a request is
    given), separated by a single space.

    Only non-empty parts are joined, so a user without a name does not
    produce a description that starts with a stray space.
    """
    parts: list[str] = []
    if name := self.user.name:
        parts.append(name)
    if request:
        parts.append(describe_request(request))
    return " ".join(parts)
def context(self, msg: dict[str, Any]) -> Context: def context(self, msg: dict[str, Any]) -> Context:
"""Return a context.""" """Return a context."""
return Context(user_id=self.user.id) return Context(user_id=self.user.id)
@ -142,9 +151,6 @@ class ActiveConnection:
if code: if code:
err_message += f" ({code})" err_message += f" ({code})"
if request := current_request.get(): err_message += " " + self.get_description(current_request.get())
err_message += f" from {request.remote}"
if user_agent := request.headers.get("user-agent"):
err_message += f" ({user_agent})"
log_handler("Error handling message: %s", err_message) log_handler("Error handling message: %s", err_message)

View file

@ -21,9 +21,12 @@ AsyncWebSocketCommandHandler = Callable[
DOMAIN: Final = "websocket_api" DOMAIN: Final = "websocket_api"
URL: Final = "/api/websocket" URL: Final = "/api/websocket"
PENDING_MSG_PEAK: Final = 512 PENDING_MSG_PEAK: Final = 1024
PENDING_MSG_PEAK_TIME: Final = 5 PENDING_MSG_PEAK_TIME: Final = 5
MAX_PENDING_MSG: Final = 2048 # Maximum number of messages that can be pending at any given time.
# This is effectively the upper limit of the number of entities
# that can fire state changes within ~1 second.
MAX_PENDING_MSG: Final = 4096
ERR_ID_REUSE: Final = "id_reuse" ERR_ID_REUSE: Final = "id_reuse"
ERR_INVALID_FORMAT: Final = "invalid_format" ERR_INVALID_FORMAT: Final = "invalid_format"

View file

@ -32,6 +32,7 @@ from .const import (
) )
from .error import Disconnect from .error import Disconnect
from .messages import message_to_json from .messages import message_to_json
from .util import describe_request
if TYPE_CHECKING: if TYPE_CHECKING:
from .connection import ActiveConnection from .connection import ActiveConnection
@ -73,10 +74,18 @@ class WebSocketHandler:
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task: asyncio.Task | None = None self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None
self._closing: bool = False
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None self.connection: ActiveConnection | None = None
@property
def description(self) -> str:
    """Return a description of the connection for log messages.

    When a connection is attached (``self.connection`` is not ``None``) its
    ``get_description`` is called with this handler's request. Otherwise
    only the request itself is described via ``describe_request``.
    """
    if (connection := self.connection) is not None:
        return connection.get_description(self.request)
    return describe_request(self.request)
async def _writer(self) -> None: async def _writer(self) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
@ -89,7 +98,6 @@ class WebSocketHandler:
if (process := await to_write.get()) is None: if (process := await to_write.get()) is None:
return return
message = process if isinstance(process, str) else process() message = process if isinstance(process, str) else process()
if ( if (
to_write.empty() to_write.empty()
or not self.connection or not self.connection
@ -109,10 +117,15 @@ class WebSocketHandler:
) )
coalesced_messages = "[" + ",".join(messages) + "]" coalesced_messages = "[" + ",".join(messages) + "]"
self._logger.debug("Sending %s", coalesced_messages) logger.debug("Sending %s", coalesced_messages)
await self.wsock.send_str(coalesced_messages) await wsock.send_str(coalesced_messages)
finally: finally:
# Clean up the peak checker when we shut down the writer
self._cancel_peak_checker()
@callback
def _cancel_peak_checker(self) -> None:
"""Cancel the peak checker."""
if self._peak_checker_unsub is not None: if self._peak_checker_unsub is not None:
self._peak_checker_unsub() self._peak_checker_unsub()
self._peak_checker_unsub = None self._peak_checker_unsub = None
@ -125,25 +138,39 @@ class WebSocketHandler:
Async friendly. Async friendly.
""" """
if self._closing:
# Connection is cancelled, don't flood logs about exceeding
# max pending messages.
return
if isinstance(message, dict): if isinstance(message, dict):
message = message_to_json(message) message = message_to_json(message)
to_write = self._to_write
try: try:
self._to_write.put_nowait(message) to_write.put_nowait(message)
except asyncio.QueueFull: except asyncio.QueueFull:
self._logger.error( self._logger.error(
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG (
"%s: Client unable to keep up with pending messages. Reached %s pending"
" messages. The system's load is too high or an integration is"
" misbehaving. Last message was: %s"
),
self.description,
MAX_PENDING_MSG,
message,
) )
self._cancel() self._cancel()
if self._to_write.qsize() < PENDING_MSG_PEAK: peak_checker_active = self._peak_checker_unsub is not None
if self._peak_checker_unsub:
self._peak_checker_unsub() if to_write.qsize() < PENDING_MSG_PEAK:
self._peak_checker_unsub = None if peak_checker_active:
self._cancel_peak_checker()
return return
if self._peak_checker_unsub is None: if not peak_checker_active:
self._peak_checker_unsub = async_call_later( self._peak_checker_unsub = async_call_later(
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
) )
@ -158,10 +185,11 @@ class WebSocketHandler:
self._logger.error( self._logger.error(
( (
"Client unable to keep up with pending messages. Stayed over %s for %s" "%s: Client unable to keep up with pending messages. Stayed over %s for %s"
" seconds. The system's load is too high or an integration is" " seconds. The system's load is too high or an integration is"
" misbehaving" " misbehaving"
), ),
self.description,
PENDING_MSG_PEAK, PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME, PENDING_MSG_PEAK_TIME,
) )
@ -170,6 +198,7 @@ class WebSocketHandler:
@callback @callback
def _cancel(self) -> None: def _cancel(self) -> None:
"""Cancel the connection.""" """Cancel the connection."""
self._closing = True
if self._handle_task is not None: if self._handle_task is not None:
self._handle_task.cancel() self._handle_task.cancel()
if self._writer_task is not None: if self._writer_task is not None:
@ -279,6 +308,8 @@ class WebSocketHandler:
if connection is not None: if connection is not None:
connection.async_handle_close() connection.async_handle_close()
self._closing = True
try: try:
self._to_write.put_nowait(None) self._to_write.put_nowait(None)
# Make sure all error messages are written before closing # Make sure all error messages are written before closing

View file

@ -0,0 +1,13 @@
"""Websocket API util.""."""
from __future__ import annotations
from aiohttp import web
def describe_request(request: web.Request) -> str:
    """Describe where a request came from, for use in log messages.

    Returns ``"from <remote>"`` built from ``request.remote``. When the
    request carries a non-empty ``user-agent`` header, `` (<user-agent>)``
    is appended.
    """
    description = f"from {request.remote}"
    # An absent or empty header is falsy, so it is left out entirely.
    if user_agent := request.headers.get("user-agent"):
        description += f" ({user_agent})"
    return description

View file

@ -21,37 +21,37 @@ from tests.common import MockUser
exceptions.Unauthorized(), exceptions.Unauthorized(),
websocket_api.ERR_UNAUTHORIZED, websocket_api.ERR_UNAUTHORIZED,
"Unauthorized", "Unauthorized",
"Error handling message: Unauthorized (unauthorized) from 127.0.0.42 (Browser)", "Error handling message: Unauthorized (unauthorized) Mock User from 127.0.0.42 (Browser)",
), ),
( (
vol.Invalid("Invalid something"), vol.Invalid("Invalid something"),
websocket_api.ERR_INVALID_FORMAT, websocket_api.ERR_INVALID_FORMAT,
"Invalid something. Got {'id': 5}", "Invalid something. Got {'id': 5}",
"Error handling message: Invalid something. Got {'id': 5} (invalid_format) from 127.0.0.42 (Browser)", "Error handling message: Invalid something. Got {'id': 5} (invalid_format) Mock User from 127.0.0.42 (Browser)",
), ),
( (
asyncio.TimeoutError(), asyncio.TimeoutError(),
websocket_api.ERR_TIMEOUT, websocket_api.ERR_TIMEOUT,
"Timeout", "Timeout",
"Error handling message: Timeout (timeout) from 127.0.0.42 (Browser)", "Error handling message: Timeout (timeout) Mock User from 127.0.0.42 (Browser)",
), ),
( (
exceptions.HomeAssistantError("Failed to do X"), exceptions.HomeAssistantError("Failed to do X"),
websocket_api.ERR_UNKNOWN_ERROR, websocket_api.ERR_UNKNOWN_ERROR,
"Failed to do X", "Failed to do X",
"Error handling message: Failed to do X (unknown_error) from 127.0.0.42 (Browser)", "Error handling message: Failed to do X (unknown_error) Mock User from 127.0.0.42 (Browser)",
), ),
( (
ValueError("Really bad"), ValueError("Really bad"),
websocket_api.ERR_UNKNOWN_ERROR, websocket_api.ERR_UNKNOWN_ERROR,
"Unknown error", "Unknown error",
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)", "Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
), ),
( (
exceptions.HomeAssistantError, exceptions.HomeAssistantError,
websocket_api.ERR_UNKNOWN_ERROR, websocket_api.ERR_UNKNOWN_ERROR,
"Unknown error", "Unknown error",
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)", "Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
), ),
], ],
) )

View file

@ -67,6 +67,50 @@ async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
assert "Client unable to keep up with pending messages" in caplog.text assert "Client unable to keep up with pending messages" in caplog.text
async def test_pending_msg_peak_but_does_not_overflow(
    hass, mock_low_peak, hass_ws_client, caplog
):
    """Test the pending queue can cross the peak, recover, and not be reported.

    The queue is grown past the peak, drained again, and time is then moved
    beyond ``PENDING_MSG_PEAK_TIME``; the "unable to keep up" message must
    not appear in the log.
    """
    handler_cls = http.WebSocketHandler
    instance: http.WebSocketHandler | None = None

    def instantiate_handler(*args):
        """Create the real handler and keep a reference to it for the test."""
        nonlocal instance
        instance = handler_cls(*args)
        return instance

    with patch(
        "homeassistant.components.websocket_api.http.WebSocketHandler",
        instantiate_handler,
    ):
        websocket_client = await hass_ws_client()

    assert instance is not None
    pending = instance._to_write

    # Kill writer task and fill queue past peak
    for _ in range(5):
        pending.put_nowait(None)

    # This send is what triggers the peak check
    instance._send_message({})

    # Drain everything that is still queued
    while not pending.empty():
        pending.get_nowait()

    # With the queue drained, this send triggers the peak clear
    instance._send_message({})

    after_peak_window = utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    async_fire_time_changed(hass, after_peak_window)

    received = await websocket_client.receive()
    assert received.type == WSMsgType.TEXT

    assert "Client unable to keep up with pending messages" not in caplog.text
async def test_non_json_message(hass, websocket_client, caplog): async def test_non_json_message(hass, websocket_client, caplog):
"""Test trying to serialize non JSON objects.""" """Test trying to serialize non JSON objects."""
bad_data = object() bad_data = object()