Improve logging and handling when websocket gets behind (#86854)
fixes undefined
This commit is contained in:
parent
c612a92cfb
commit
0f4b17755e
7 changed files with 127 additions and 36 deletions
|
@ -113,18 +113,14 @@ def handle_subscribe_events(
|
|||
):
|
||||
return
|
||||
|
||||
connection.send_message(
|
||||
lambda: messages.cached_event_message(msg["id"], event)
|
||||
)
|
||||
connection.send_message(messages.cached_event_message(msg["id"], event))
|
||||
|
||||
else:
|
||||
|
||||
@callback
|
||||
def forward_events(event: Event) -> None:
|
||||
"""Forward events to websocket."""
|
||||
connection.send_message(
|
||||
lambda: messages.cached_event_message(msg["id"], event)
|
||||
)
|
||||
connection.send_message(messages.cached_event_message(msg["id"], event))
|
||||
|
||||
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
|
||||
event_type, forward_events, run_immediately=True
|
||||
|
@ -296,9 +292,7 @@ def handle_subscribe_entities(
|
|||
if entity_ids and event.data["entity_id"] not in entity_ids:
|
||||
return
|
||||
|
||||
connection.send_message(
|
||||
lambda: messages.cached_state_diff_message(msg["id"], event)
|
||||
)
|
||||
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
|
||||
|
||||
# We must never await between sending the states and listening for
|
||||
# state changed events or we will introduce a race condition
|
||||
|
|
|
@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable
|
|||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
|
@ -14,6 +15,7 @@ from homeassistant.core import Context, HomeAssistant, callback
|
|||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
|
||||
from . import const, messages
|
||||
from .util import describe_request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .http import WebSocketAdapter
|
||||
|
@ -46,6 +48,13 @@ class ActiveConnection:
|
|||
self.supported_features: dict[str, float] = {}
|
||||
current_connection.set(self)
|
||||
|
||||
def get_description(self, request: web.Request | None) -> str:
    """Describe this connection for use in log messages.

    The description starts with the connected user's name (an empty
    string when the user has no name). When a request is supplied, a
    space and the output of ``describe_request`` are appended.
    """
    user_name = self.user.name or ""
    if not request:
        return user_name
    return user_name + " " + describe_request(request)
|
||||
|
||||
def context(self, msg: dict[str, Any]) -> Context:
|
||||
"""Return a context."""
|
||||
return Context(user_id=self.user.id)
|
||||
|
@ -142,9 +151,6 @@ class ActiveConnection:
|
|||
|
||||
if code:
|
||||
err_message += f" ({code})"
|
||||
if request := current_request.get():
|
||||
err_message += f" from {request.remote}"
|
||||
if user_agent := request.headers.get("user-agent"):
|
||||
err_message += f" ({user_agent})"
|
||||
err_message += " " + self.get_description(current_request.get())
|
||||
|
||||
log_handler("Error handling message: %s", err_message)
|
||||
|
|
|
@ -21,9 +21,12 @@ AsyncWebSocketCommandHandler = Callable[
|
|||
|
||||
DOMAIN: Final = "websocket_api"
|
||||
URL: Final = "/api/websocket"
|
||||
PENDING_MSG_PEAK: Final = 512
|
||||
PENDING_MSG_PEAK: Final = 1024
|
||||
PENDING_MSG_PEAK_TIME: Final = 5
|
||||
MAX_PENDING_MSG: Final = 2048
|
||||
# Maximum number of messages that can be pending at any given time.
|
||||
# This is effectively the upper limit of the number of entities
|
||||
# that can fire state changes within ~1 second.
|
||||
MAX_PENDING_MSG: Final = 4096
|
||||
|
||||
ERR_ID_REUSE: Final = "id_reuse"
|
||||
ERR_INVALID_FORMAT: Final = "invalid_format"
|
||||
|
|
|
@ -32,6 +32,7 @@ from .const import (
|
|||
)
|
||||
from .error import Disconnect
|
||||
from .messages import message_to_json
|
||||
from .util import describe_request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import ActiveConnection
|
||||
|
@ -73,10 +74,18 @@ class WebSocketHandler:
|
|||
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
||||
self._handle_task: asyncio.Task | None = None
|
||||
self._writer_task: asyncio.Task | None = None
|
||||
self._closing: bool = False
|
||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
self.connection: ActiveConnection | None = None
|
||||
|
||||
@property
def description(self) -> str:
    """Describe the peer on the other end of this websocket.

    Once an ``ActiveConnection`` is attached, its own description of the
    request is used; before that, only the request summary is available.
    """
    connection = self.connection
    if connection is None:
        return describe_request(self.request)
    return connection.get_description(self.request)
|
||||
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
|
@ -89,7 +98,6 @@ class WebSocketHandler:
|
|||
if (process := await to_write.get()) is None:
|
||||
return
|
||||
message = process if isinstance(process, str) else process()
|
||||
|
||||
if (
|
||||
to_write.empty()
|
||||
or not self.connection
|
||||
|
@ -109,13 +117,18 @@ class WebSocketHandler:
|
|||
)
|
||||
|
||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||
self._logger.debug("Sending %s", coalesced_messages)
|
||||
await self.wsock.send_str(coalesced_messages)
|
||||
logger.debug("Sending %s", coalesced_messages)
|
||||
await wsock.send_str(coalesced_messages)
|
||||
finally:
|
||||
# Clean up the peak checker when we shut down the writer
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
self._cancel_peak_checker()
|
||||
|
||||
@callback
def _cancel_peak_checker(self) -> None:
    """Call the stored peak-checker unsubscribe callable, if any, and clear it."""
    unsub = self._peak_checker_unsub
    if unsub is None:
        # Nothing is scheduled, so there is nothing to cancel.
        return
    unsub()
    self._peak_checker_unsub = None
|
||||
|
||||
@callback
|
||||
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
|
||||
|
@ -125,25 +138,39 @@ class WebSocketHandler:
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
if self._closing:
|
||||
# Connection is cancelled, don't flood logs about exceeding
|
||||
# max pending messages.
|
||||
return
|
||||
|
||||
if isinstance(message, dict):
|
||||
message = message_to_json(message)
|
||||
|
||||
to_write = self._to_write
|
||||
|
||||
try:
|
||||
self._to_write.put_nowait(message)
|
||||
to_write.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self._logger.error(
|
||||
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
|
||||
(
|
||||
"%s: Client unable to keep up with pending messages. Reached %s pending"
|
||||
" messages. The system's load is too high or an integration is"
|
||||
" misbehaving. Last message was: %s"
|
||||
),
|
||||
self.description,
|
||||
MAX_PENDING_MSG,
|
||||
message,
|
||||
)
|
||||
|
||||
self._cancel()
|
||||
|
||||
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||
if self._peak_checker_unsub:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
peak_checker_active = self._peak_checker_unsub is not None
|
||||
|
||||
if to_write.qsize() < PENDING_MSG_PEAK:
|
||||
if peak_checker_active:
|
||||
self._cancel_peak_checker()
|
||||
return
|
||||
|
||||
if self._peak_checker_unsub is None:
|
||||
if not peak_checker_active:
|
||||
self._peak_checker_unsub = async_call_later(
|
||||
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||
)
|
||||
|
@ -158,10 +185,11 @@ class WebSocketHandler:
|
|||
|
||||
self._logger.error(
|
||||
(
|
||||
"Client unable to keep up with pending messages. Stayed over %s for %s"
|
||||
"%s: Client unable to keep up with pending messages. Stayed over %s for %s"
|
||||
" seconds. The system's load is too high or an integration is"
|
||||
" misbehaving"
|
||||
),
|
||||
self.description,
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
)
|
||||
|
@ -170,6 +198,7 @@ class WebSocketHandler:
|
|||
@callback
|
||||
def _cancel(self) -> None:
|
||||
"""Cancel the connection."""
|
||||
self._closing = True
|
||||
if self._handle_task is not None:
|
||||
self._handle_task.cancel()
|
||||
if self._writer_task is not None:
|
||||
|
@ -279,6 +308,8 @@ class WebSocketHandler:
|
|||
if connection is not None:
|
||||
connection.async_handle_close()
|
||||
|
||||
self._closing = True
|
||||
|
||||
try:
|
||||
self._to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
|
|
13
homeassistant/components/websocket_api/util.py
Normal file
13
homeassistant/components/websocket_api/util.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
"""Websocket API util."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
def describe_request(request: web.Request) -> str:
    """Summarise where a request came from.

    The result is ``"from <remote>"``, followed by the ``user-agent``
    header value in parentheses when the request carries a non-empty one.
    """
    parts = [f"from {request.remote}"]
    user_agent = request.headers.get("user-agent")
    if user_agent:
        parts.append(f"({user_agent})")
    return " ".join(parts)
|
|
@ -21,37 +21,37 @@ from tests.common import MockUser
|
|||
exceptions.Unauthorized(),
|
||||
websocket_api.ERR_UNAUTHORIZED,
|
||||
"Unauthorized",
|
||||
"Error handling message: Unauthorized (unauthorized) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Unauthorized (unauthorized) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
(
|
||||
vol.Invalid("Invalid something"),
|
||||
websocket_api.ERR_INVALID_FORMAT,
|
||||
"Invalid something. Got {'id': 5}",
|
||||
"Error handling message: Invalid something. Got {'id': 5} (invalid_format) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Invalid something. Got {'id': 5} (invalid_format) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
(
|
||||
asyncio.TimeoutError(),
|
||||
websocket_api.ERR_TIMEOUT,
|
||||
"Timeout",
|
||||
"Error handling message: Timeout (timeout) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Timeout (timeout) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
(
|
||||
exceptions.HomeAssistantError("Failed to do X"),
|
||||
websocket_api.ERR_UNKNOWN_ERROR,
|
||||
"Failed to do X",
|
||||
"Error handling message: Failed to do X (unknown_error) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Failed to do X (unknown_error) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
(
|
||||
ValueError("Really bad"),
|
||||
websocket_api.ERR_UNKNOWN_ERROR,
|
||||
"Unknown error",
|
||||
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
(
|
||||
exceptions.HomeAssistantError,
|
||||
websocket_api.ERR_UNKNOWN_ERROR,
|
||||
"Unknown error",
|
||||
"Error handling message: Unknown error (unknown_error) from 127.0.0.42 (Browser)",
|
||||
"Error handling message: Unknown error (unknown_error) Mock User from 127.0.0.42 (Browser)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -67,6 +67,50 @@ async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
|
|||
assert "Client unable to keep up with pending messages" in caplog.text
|
||||
|
||||
|
||||
async def test_pending_msg_peak_but_does_not_overflow(
    hass, mock_low_peak, hass_ws_client, caplog
):
    """Check that a queue which passes the peak and then drains logs no error."""
    real_handler_cls = http.WebSocketHandler
    handler: http.WebSocketHandler | None = None

    def instantiate_handler(*args):
        # Build the real handler but keep a reference so the test can
        # reach into its write queue.
        nonlocal handler
        handler = real_handler_cls(*args)
        return handler

    with patch(
        "homeassistant.components.websocket_api.http.WebSocketHandler",
        instantiate_handler,
    ):
        websocket_client = await hass_ws_client()

    assert handler is not None
    pending = handler._to_write

    # Kill writer task and fill queue past peak
    for _ in range(5):
        pending.put_nowait(None)

    # Trigger the peak check
    handler._send_message({})

    # Clear the queue
    while not pending.empty():
        pending.get_nowait()

    # Trigger the peak clear
    handler._send_message({})

    async_fire_time_changed(
        hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    )

    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.TEXT

    assert "Client unable to keep up with pending messages" not in caplog.text
|
||||
|
||||
|
||||
async def test_non_json_message(hass, websocket_client, caplog):
|
||||
"""Test trying to serialize non JSON objects."""
|
||||
bad_data = object()
|
||||
|
|
Loading…
Add table
Reference in a new issue