Fix memory leaks in websocket api (#94780)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
a7d327afa2
commit
1206f2c1da
5 changed files with 244 additions and 84 deletions
|
@ -56,6 +56,10 @@ class ActiveConnection:
|
|||
self.binary_handlers: list[BinaryHandler | None] = []
|
||||
current_connection.set(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation."""
|
||||
return f"<ActiveConnection {self.get_description(None)}>"
|
||||
|
||||
def set_supported_features(self, features: dict[str, float]) -> None:
|
||||
"""Set supported features."""
|
||||
self.supported_features = features
|
||||
|
@ -193,7 +197,24 @@ class ActiveConnection:
|
|||
def async_handle_close(self) -> None:
|
||||
"""Handle closing down connection."""
|
||||
for unsub in self.subscriptions.values():
|
||||
try:
|
||||
unsub()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# If one fails, make sure we still try the rest
|
||||
self.logger.exception(
|
||||
"Error unsubscribing from subscription: %s", unsub
|
||||
)
|
||||
self.subscriptions.clear()
|
||||
self.send_message = self._connect_closed_error
|
||||
current_request.set(None)
|
||||
current_connection.set(None)
|
||||
|
||||
@callback
|
||||
def _connect_closed_error(
|
||||
self, msg: str | dict[str, Any] | Callable[[], str]
|
||||
) -> None:
|
||||
"""Send a message when the connection is closed."""
|
||||
self.logger.debug("Tried to send message %s on closed connection", msg)
|
||||
|
||||
@callback
|
||||
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
"""Websocket constants."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from concurrent import futures
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -42,10 +40,6 @@ ERR_TEMPLATE_ERROR: Final = "template_error"
|
|||
|
||||
TYPE_RESULT: Final = "result"
|
||||
|
||||
# Define the possible errors that occur when connections are cancelled.
|
||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||
# that futures.CancelledErrors can also occur in some situations.
|
||||
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
|
||||
|
||||
# Event types
|
||||
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
|
||||
|
|
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
@ -21,7 +20,6 @@ from homeassistant.util.json import json_loads
|
|||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .const import (
|
||||
CANCELLATION_ERRORS,
|
||||
DATA_CONNECTIONS,
|
||||
MAX_PENDING_MSG,
|
||||
PENDING_MSG_PEAK,
|
||||
|
@ -68,15 +66,16 @@ class WebSocketHandler:
|
|||
|
||||
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
||||
"""Initialize an active connection."""
|
||||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
self._hass = hass
|
||||
self._request: web.Request = request
|
||||
self._wsock = web.WebSocketResponse(heartbeat=55)
|
||||
self._handle_task: asyncio.Task | None = None
|
||||
self._writer_task: asyncio.Task | None = None
|
||||
self._closing: bool = False
|
||||
self._authenticated: bool = False
|
||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
self.connection: ActiveConnection | None = None
|
||||
self._connection: ActiveConnection | None = None
|
||||
|
||||
# The WebSocketHandler has a single consumer and path
|
||||
# to where messages are queued. This allows the implementation
|
||||
|
@ -85,25 +84,38 @@ class WebSocketHandler:
|
|||
self._message_queue: deque = deque()
|
||||
self._ready_future: asyncio.Future[None] | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation."""
|
||||
return (
|
||||
"<WebSocketHandler "
|
||||
f"closing={self._closing} "
|
||||
f"authenticated={self._authenticated} "
|
||||
f"description={self.description}>"
|
||||
)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Return a description of the connection."""
|
||||
if self.connection is not None:
|
||||
return self.connection.get_description(self.request)
|
||||
return describe_request(self.request)
|
||||
if connection := self._connection:
|
||||
return connection.get_description(self._request)
|
||||
if request := self._request:
|
||||
return describe_request(request)
|
||||
return "finished connection"
|
||||
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Variables are set locally to avoid lookups in the loop
|
||||
message_queue = self._message_queue
|
||||
logger = self._logger
|
||||
send_str = self.wsock.send_str
|
||||
loop = self.hass.loop
|
||||
wsock = self._wsock
|
||||
send_str = wsock.send_str
|
||||
loop = self._hass.loop
|
||||
debug = logger.debug
|
||||
is_enabled_for = logger.isEnabledFor
|
||||
logging_debug = logging.DEBUG
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
try:
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
while not wsock.closed:
|
||||
if (messages_remaining := len(message_queue)) == 0:
|
||||
self._ready_future = loop.create_future()
|
||||
await self._ready_future
|
||||
|
@ -113,15 +125,17 @@ class WebSocketHandler:
|
|||
if (process := message_queue.popleft()) is None:
|
||||
return
|
||||
|
||||
debug_enabled = is_enabled_for(logging_debug)
|
||||
messages_remaining -= 1
|
||||
message = process if isinstance(process, str) else process()
|
||||
|
||||
if (
|
||||
not messages_remaining
|
||||
or not self.connection
|
||||
or not self.connection.can_coalesce
|
||||
or not (connection := self._connection)
|
||||
or not connection.can_coalesce
|
||||
):
|
||||
debug("Sending %s", message)
|
||||
if debug_enabled:
|
||||
debug("%s: Sending %s", self.description, message)
|
||||
await send_str(message)
|
||||
continue
|
||||
|
||||
|
@ -130,16 +144,21 @@ class WebSocketHandler:
|
|||
# A None message is used to signal the end of the connection
|
||||
if (process := message_queue.popleft()) is None:
|
||||
return
|
||||
messages.append(
|
||||
process if isinstance(process, str) else process()
|
||||
)
|
||||
messages.append(process if isinstance(process, str) else process())
|
||||
messages_remaining -= 1
|
||||
|
||||
joined_messages = ",".join(messages)
|
||||
coalesced_messages = f"[{joined_messages}]"
|
||||
debug("Sending %s", coalesced_messages)
|
||||
if debug_enabled:
|
||||
debug("%s: Sending %s", self.description, coalesced_messages)
|
||||
await send_str(coalesced_messages)
|
||||
except asyncio.CancelledError:
|
||||
debug("%s: Writer cancelled", self.description)
|
||||
raise
|
||||
except (RuntimeError, ConnectionResetError) as ex:
|
||||
debug("%s: Unexpected error in writer: %s", self.description, ex)
|
||||
finally:
|
||||
debug("%s: Writer done", self.description)
|
||||
# Clean up the peak checker when we shut down the writer
|
||||
self._cancel_peak_checker()
|
||||
|
||||
|
@ -195,7 +214,7 @@ class WebSocketHandler:
|
|||
|
||||
if not peak_checker_active:
|
||||
self._peak_checker_unsub = async_call_later(
|
||||
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||
self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||
)
|
||||
|
||||
@callback
|
||||
|
@ -231,8 +250,14 @@ class WebSocketHandler:
|
|||
|
||||
async def async_handle(self) -> web.WebSocketResponse:
|
||||
"""Handle a websocket response."""
|
||||
request = self.request
|
||||
wsock = self.wsock
|
||||
request = self._request
|
||||
wsock = self._wsock
|
||||
logger = self._logger
|
||||
debug = logger.debug
|
||||
hass = self._hass
|
||||
is_enabled_for = logger.isEnabledFor
|
||||
logging_debug = logging.DEBUG
|
||||
|
||||
try:
|
||||
async with async_timeout.timeout(10):
|
||||
await wsock.prepare(request)
|
||||
|
@ -240,7 +265,7 @@ class WebSocketHandler:
|
|||
self._logger.warning("Timeout preparing request from %s", request.remote)
|
||||
return wsock
|
||||
|
||||
self._logger.debug("Connected from %s", request.remote)
|
||||
debug("%s: Connected from %s", self.description, request.remote)
|
||||
self._handle_task = asyncio.current_task()
|
||||
|
||||
@callback
|
||||
|
@ -248,17 +273,13 @@ class WebSocketHandler:
|
|||
"""Cancel this connection."""
|
||||
self._cancel()
|
||||
|
||||
unsub_stop = self.hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop
|
||||
)
|
||||
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responses
|
||||
self._writer_task = asyncio.create_task(self._writer())
|
||||
|
||||
auth = AuthPhase(
|
||||
self._logger, self.hass, self._send_message, self._cancel, request
|
||||
)
|
||||
auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
|
||||
connection = None
|
||||
disconnect_warn = None
|
||||
|
||||
|
@ -286,13 +307,14 @@ class WebSocketHandler:
|
|||
disconnect_warn = "Received invalid JSON."
|
||||
raise Disconnect from err
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
self.connection = connection = await auth.async_handle(msg_data)
|
||||
self.hass.data[DATA_CONNECTIONS] = (
|
||||
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
)
|
||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, msg_data)
|
||||
connection = await auth.async_handle(msg_data)
|
||||
self._connection = connection
|
||||
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||
|
||||
self._authenticated = True
|
||||
#
|
||||
#
|
||||
# Our websocket implementation is backed by an asyncio.Queue
|
||||
|
@ -356,7 +378,9 @@ class WebSocketHandler:
|
|||
disconnect_warn = "Received invalid JSON."
|
||||
break
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, msg_data)
|
||||
|
||||
if not isinstance(msg_data, list):
|
||||
connection.async_handle(msg_data)
|
||||
continue
|
||||
|
@ -365,17 +389,22 @@ class WebSocketHandler:
|
|||
connection.async_handle(split_msg)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._logger.info("Connection closed by client")
|
||||
debug("%s: Connection cancelled", self.description)
|
||||
raise
|
||||
|
||||
except Disconnect:
|
||||
pass
|
||||
except Disconnect as ex:
|
||||
debug("%s: Connection closed by client: %s", self.description, ex)
|
||||
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self._logger.exception("Unexpected error inside websocket API")
|
||||
self._logger.exception(
|
||||
"%s: Unexpected error inside websocket API", self.description
|
||||
)
|
||||
|
||||
finally:
|
||||
unsub_stop()
|
||||
|
||||
self._cancel_peak_checker()
|
||||
|
||||
if connection is not None:
|
||||
connection.async_handle_close()
|
||||
|
||||
|
@ -385,20 +414,37 @@ class WebSocketHandler:
|
|||
if self._ready_future and not self._ready_future.done():
|
||||
self._ready_future.set_result(None)
|
||||
|
||||
# If the writer gets canceled we still need to close the websocket
|
||||
# so we have another finally block to make sure we close the websocket
|
||||
# if the writer gets canceled.
|
||||
try:
|
||||
await self._writer_task
|
||||
finally:
|
||||
try:
|
||||
# Make sure all error messages are written before closing
|
||||
await self._writer_task
|
||||
await wsock.close()
|
||||
finally:
|
||||
if disconnect_warn is None:
|
||||
self._logger.debug("Disconnected")
|
||||
debug("%s: Disconnected", self.description)
|
||||
else:
|
||||
self._logger.warning("Disconnected: %s", disconnect_warn)
|
||||
self._logger.warning(
|
||||
"%s: Disconnected: %s", self.description, disconnect_warn
|
||||
)
|
||||
|
||||
if connection is not None:
|
||||
self.hass.data[DATA_CONNECTIONS] -= 1
|
||||
self.connection = None
|
||||
hass.data[DATA_CONNECTIONS] -= 1
|
||||
self._connection = None
|
||||
|
||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
|
||||
# Break reference cycles to make sure GC can happen sooner
|
||||
self._wsock = None # type: ignore[assignment]
|
||||
self._request = None # type: ignore[assignment]
|
||||
self._hass = None # type: ignore[assignment]
|
||||
self._logger = None # type: ignore[assignment]
|
||||
self._message_queue = None # type: ignore[assignment]
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._ready_future = None
|
||||
|
||||
return wsock
|
||||
|
|
|
@ -3,11 +3,19 @@ import pytest
|
|||
|
||||
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.typing import (
|
||||
MockHAClientWebSocket,
|
||||
WebSocketGenerator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def websocket_client(hass, hass_ws_client):
|
||||
async def websocket_client(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> MockHAClientWebSocket:
|
||||
"""Create a websocket client."""
|
||||
return await hass_ws_client(hass)
|
||||
|
||||
|
|
|
@ -18,7 +18,10 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.typing import WebSocketGenerator
|
||||
from tests.typing import (
|
||||
MockHAClientWebSocket,
|
||||
WebSocketGenerator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -36,15 +39,103 @@ def mock_low_peak():
|
|||
|
||||
|
||||
async def test_pending_msg_overflow(
|
||||
hass: HomeAssistant, mock_low_queue, websocket_client
|
||||
hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
|
||||
) -> None:
|
||||
"""Test get_panels command."""
|
||||
"""Test pending messages overflows."""
|
||||
for idx in range(10):
|
||||
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_cleanup_on_cancellation(
|
||||
hass: HomeAssistant, websocket_client: MockHAClientWebSocket
|
||||
) -> None:
|
||||
"""Test cleanup on cancellation."""
|
||||
|
||||
subscriptions = None
|
||||
|
||||
# Register a handler that registers a subscription
|
||||
@callback
|
||||
@websocket_command(
|
||||
{
|
||||
"type": "fake_subscription",
|
||||
}
|
||||
)
|
||||
def fake_subscription(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
nonlocal subscriptions
|
||||
msg_id: int = msg["id"]
|
||||
connection.subscriptions[msg_id] = callback(lambda: None)
|
||||
connection.send_result(msg_id)
|
||||
subscriptions = connection.subscriptions
|
||||
|
||||
async_register_command(hass, fake_subscription)
|
||||
|
||||
# Register a handler that raises on cancel
|
||||
@callback
|
||||
@websocket_command(
|
||||
{
|
||||
"type": "subscription_that_raises_on_cancel",
|
||||
}
|
||||
)
|
||||
def subscription_that_raises_on_cancel(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
nonlocal subscriptions
|
||||
msg_id: int = msg["id"]
|
||||
|
||||
@callback
|
||||
def _raise():
|
||||
raise ValueError()
|
||||
|
||||
connection.subscriptions[msg_id] = _raise
|
||||
connection.send_result(msg_id)
|
||||
subscriptions = connection.subscriptions
|
||||
|
||||
async_register_command(hass, subscription_that_raises_on_cancel)
|
||||
|
||||
# Register a handler that cancels in handler
|
||||
@callback
|
||||
@websocket_command(
|
||||
{
|
||||
"type": "cancel_in_handler",
|
||||
}
|
||||
)
|
||||
def cancel_in_handler(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
async_register_command(hass, cancel_in_handler)
|
||||
|
||||
await websocket_client.send_json({"id": 1, "type": "ping"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == "pong"
|
||||
assert not subscriptions
|
||||
await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 2
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert len(subscriptions) == 2
|
||||
await websocket_client.send_json(
|
||||
{"id": 3, "type": "subscription_that_raises_on_cancel"}
|
||||
)
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 3
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert len(subscriptions) == 3
|
||||
await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
|
||||
await hass.async_block_till_done()
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
assert len(subscriptions) == 0
|
||||
|
||||
|
||||
async def test_pending_msg_peak(
|
||||
hass: HomeAssistant,
|
||||
mock_low_peak,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue