Fix memory leaks in websocket api (#94780)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
J. Nick Koston 2023-06-19 18:27:22 -05:00 committed by GitHub
parent a7d327afa2
commit 1206f2c1da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 84 deletions

View file

@ -56,6 +56,10 @@ class ActiveConnection:
self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self)
def __repr__(self) -> str:
    """Return a short representation built from the connection description."""
    # get_description is called with None in place of a request.
    description = self.get_description(None)
    return f"<ActiveConnection {description}>"
def set_supported_features(self, features: dict[str, float]) -> None:
"""Set supported features."""
self.supported_features = features
@ -193,7 +197,24 @@ class ActiveConnection:
def async_handle_close(self) -> None:
    """Handle closing down connection.

    Calls every registered unsubscribe callback, empties
    ``self.subscriptions``, points ``self.send_message`` at
    ``self._connect_closed_error`` and calls ``.set(None)`` on both
    ``current_request`` and ``current_connection``.

    An exception raised by one unsubscribe callback is logged and does not
    prevent the remaining callbacks or the rest of the cleanup from running.
    """
    # Iterate over a snapshot: if an unsubscribe callback adds or removes an
    # entry in self.subscriptions, iterating the live view would raise
    # RuntimeError outside the per-callback try block, skipping the remaining
    # callbacks and all of the cleanup below.
    for unsub in list(self.subscriptions.values()):
        try:
            unsub()
        except Exception:  # pylint: disable=broad-except
            # If one fails, make sure we still try the rest
            self.logger.exception(
                "Error unsubscribing from subscription: %s", unsub
            )
    self.subscriptions.clear()
    # From here on, attempts to send are routed to _connect_closed_error.
    self.send_message = self._connect_closed_error
    current_request.set(None)
    current_connection.set(None)
@callback
def _connect_closed_error(
    self, msg: str | dict[str, Any] | Callable[[], str]
) -> None:
    """Log, at debug level, an attempt to send on a closed connection.

    The message is only written to the debug log; nothing else is done
    with it here.
    """
    log_debug = self.logger.debug
    log_debug("Tried to send message %s on closed connection", msg)
@callback
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:

View file

@ -1,9 +1,7 @@
"""Websocket constants."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from concurrent import futures
from typing import TYPE_CHECKING, Any, Final
from homeassistant.core import HomeAssistant
@ -42,10 +40,6 @@ ERR_TEMPLATE_ERROR: Final = "template_error"
TYPE_RESULT: Final = "result"
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
# Event types
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"

View file

@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import Callable
from contextlib import suppress
import datetime as dt
import logging
from typing import TYPE_CHECKING, Any, Final
@ -21,7 +20,6 @@ from homeassistant.util.json import json_loads
from .auth import AuthPhase, auth_required_message
from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS,
MAX_PENDING_MSG,
PENDING_MSG_PEAK,
@ -68,15 +66,16 @@ class WebSocketHandler:
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock = web.WebSocketResponse(heartbeat=55)
self._hass = hass
self._request: web.Request = request
self._wsock = web.WebSocketResponse(heartbeat=55)
self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None
self._closing: bool = False
self._authenticated: bool = False
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
self._connection: ActiveConnection | None = None
# The WebSocketHandler has a single consumer and path
# to where messages are queued. This allows the implementation
@ -85,25 +84,38 @@ class WebSocketHandler:
self._message_queue: deque = deque()
self._ready_future: asyncio.Future[None] | None = None
def __repr__(self) -> str:
    """Return a representation with the closing flag, auth flag and description."""
    fields = (
        f"closing={self._closing}",
        f"authenticated={self._authenticated}",
        f"description={self.description}",
    )
    return "<WebSocketHandler " + " ".join(fields) + ">"
@property
def description(self) -> str:
"""Return a description of the connection."""
if self.connection is not None:
return self.connection.get_description(self.request)
return describe_request(self.request)
if connection := self._connection:
return connection.get_description(self._request)
if request := self._request:
return describe_request(request)
return "finished connection"
async def _writer(self) -> None:
"""Write outgoing messages."""
# Variables are set locally to avoid lookups in the loop
message_queue = self._message_queue
logger = self._logger
send_str = self.wsock.send_str
loop = self.hass.loop
wsock = self._wsock
send_str = wsock.send_str
loop = self._hass.loop
debug = logger.debug
is_enabled_for = logger.isEnabledFor
logging_debug = logging.DEBUG
# Exceptions if Socket disconnected or cancelled by connection handler
try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
while not wsock.closed:
if (messages_remaining := len(message_queue)) == 0:
self._ready_future = loop.create_future()
await self._ready_future
@ -113,15 +125,17 @@ class WebSocketHandler:
if (process := message_queue.popleft()) is None:
return
debug_enabled = is_enabled_for(logging_debug)
messages_remaining -= 1
message = process if isinstance(process, str) else process()
if (
not messages_remaining
or not self.connection
or not self.connection.can_coalesce
or not (connection := self._connection)
or not connection.can_coalesce
):
debug("Sending %s", message)
if debug_enabled:
debug("%s: Sending %s", self.description, message)
await send_str(message)
continue
@ -130,16 +144,21 @@ class WebSocketHandler:
# A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None:
return
messages.append(
process if isinstance(process, str) else process()
)
messages.append(process if isinstance(process, str) else process())
messages_remaining -= 1
joined_messages = ",".join(messages)
coalesced_messages = f"[{joined_messages}]"
debug("Sending %s", coalesced_messages)
if debug_enabled:
debug("%s: Sending %s", self.description, coalesced_messages)
await send_str(coalesced_messages)
except asyncio.CancelledError:
debug("%s: Writer cancelled", self.description)
raise
except (RuntimeError, ConnectionResetError) as ex:
debug("%s: Unexpected error in writer: %s", self.description, ex)
finally:
debug("%s: Writer done", self.description)
# Clean up the peak checker when we shut down the writer
self._cancel_peak_checker()
@ -195,7 +214,7 @@ class WebSocketHandler:
if not peak_checker_active:
self._peak_checker_unsub = async_call_later(
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
)
@callback
@ -231,8 +250,14 @@ class WebSocketHandler:
async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response."""
request = self.request
wsock = self.wsock
request = self._request
wsock = self._wsock
logger = self._logger
debug = logger.debug
hass = self._hass
is_enabled_for = logger.isEnabledFor
logging_debug = logging.DEBUG
try:
async with async_timeout.timeout(10):
await wsock.prepare(request)
@ -240,7 +265,7 @@ class WebSocketHandler:
self._logger.warning("Timeout preparing request from %s", request.remote)
return wsock
self._logger.debug("Connected from %s", request.remote)
debug("%s: Connected from %s", self.description, request.remote)
self._handle_task = asyncio.current_task()
@callback
@ -248,17 +273,13 @@ class WebSocketHandler:
"""Cancel this connection."""
self._cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop
)
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
# As the webserver is now started before the start
# event we do not want to block for websocket responses
self._writer_task = asyncio.create_task(self._writer())
auth = AuthPhase(
self._logger, self.hass, self._send_message, self._cancel, request
)
auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
connection = None
disconnect_warn = None
@ -286,13 +307,14 @@ class WebSocketHandler:
disconnect_warn = "Received invalid JSON."
raise Disconnect from err
self._logger.debug("Received %s", msg_data)
self.connection = connection = await auth.async_handle(msg_data)
self.hass.data[DATA_CONNECTIONS] = (
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
)
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED)
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
connection = await auth.async_handle(msg_data)
self._connection = connection
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
self._authenticated = True
#
#
# Our websocket implementation is backed by an asyncio.Queue
@ -356,7 +378,9 @@ class WebSocketHandler:
disconnect_warn = "Received invalid JSON."
break
self._logger.debug("Received %s", msg_data)
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
if not isinstance(msg_data, list):
connection.async_handle(msg_data)
continue
@ -365,17 +389,22 @@ class WebSocketHandler:
connection.async_handle(split_msg)
except asyncio.CancelledError:
self._logger.info("Connection closed by client")
debug("%s: Connection cancelled", self.description)
raise
except Disconnect:
pass
except Disconnect as ex:
debug("%s: Connection closed by client: %s", self.description, ex)
except Exception: # pylint: disable=broad-except
self._logger.exception("Unexpected error inside websocket API")
self._logger.exception(
"%s: Unexpected error inside websocket API", self.description
)
finally:
unsub_stop()
self._cancel_peak_checker()
if connection is not None:
connection.async_handle_close()
@ -385,20 +414,37 @@ class WebSocketHandler:
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
# If the writer gets canceled we still need to close the websocket
# so we have another finally block to make sure we close the websocket
# if the writer gets canceled.
try:
await self._writer_task
finally:
try:
# Make sure all error messages are written before closing
await self._writer_task
await wsock.close()
finally:
if disconnect_warn is None:
self._logger.debug("Disconnected")
debug("%s: Disconnected", self.description)
else:
self._logger.warning("Disconnected: %s", disconnect_warn)
self._logger.warning(
"%s: Disconnected: %s", self.description, disconnect_warn
)
if connection is not None:
self.hass.data[DATA_CONNECTIONS] -= 1
self.connection = None
hass.data[DATA_CONNECTIONS] -= 1
self._connection = None
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
# Break reference cycles to make sure GC can happen sooner
self._wsock = None # type: ignore[assignment]
self._request = None # type: ignore[assignment]
self._hass = None # type: ignore[assignment]
self._logger = None # type: ignore[assignment]
self._message_queue = None # type: ignore[assignment]
self._handle_task = None
self._writer_task = None
self._ready_future = None
return wsock

View file

@ -3,11 +3,19 @@ import pytest
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
from homeassistant.components.websocket_api.http import URL
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.typing import (
MockHAClientWebSocket,
WebSocketGenerator,
)
@pytest.fixture
async def websocket_client(hass, hass_ws_client):
async def websocket_client(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> MockHAClientWebSocket:
"""Create a websocket client."""
return await hass_ws_client(hass)

View file

@ -18,7 +18,10 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed
from tests.typing import WebSocketGenerator
from tests.typing import (
MockHAClientWebSocket,
WebSocketGenerator,
)
@pytest.fixture
@ -36,15 +39,103 @@ def mock_low_peak():
async def test_pending_msg_overflow(
hass: HomeAssistant, mock_low_queue, websocket_client
hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
) -> None:
"""Test get_panels command."""
"""Test pending messages overflows."""
for idx in range(10):
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
async def test_cleanup_on_cancellation(
hass: HomeAssistant, websocket_client: MockHAClientWebSocket
) -> None:
"""Test cleanup on cancellation.

Registers three commands: one that stores a no-op unsubscribe callback,
one whose unsubscribe callback raises ValueError, and one whose handler
raises asyncio.CancelledError. After the last one is sent, the test
expects a close frame and an empty subscriptions dict.
"""
# Set by the handlers below to the connection's subscriptions dict so the
# test can inspect it after the connection is gone.
subscriptions = None
# Register a handler that registers a subscription
@callback
@websocket_command(
{
"type": "fake_subscription",
}
)
def fake_subscription(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
nonlocal subscriptions
msg_id: int = msg["id"]
connection.subscriptions[msg_id] = callback(lambda: None)
connection.send_result(msg_id)
subscriptions = connection.subscriptions
async_register_command(hass, fake_subscription)
# Register a handler that raises on cancel
@callback
@websocket_command(
{
"type": "subscription_that_raises_on_cancel",
}
)
def subscription_that_raises_on_cancel(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
nonlocal subscriptions
msg_id: int = msg["id"]
# Unsubscribe callback that always fails when it is invoked.
@callback
def _raise():
raise ValueError()
connection.subscriptions[msg_id] = _raise
connection.send_result(msg_id)
subscriptions = connection.subscriptions
async_register_command(hass, subscription_that_raises_on_cancel)
# Register a handler that cancels in handler
@callback
@websocket_command(
{
"type": "cancel_in_handler",
}
)
def cancel_in_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
raise asyncio.CancelledError()
async_register_command(hass, cancel_in_handler)
await websocket_client.send_json({"id": 1, "type": "ping"})
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "pong"
# Neither subscribing handler has run yet, so this is still the initial None.
assert not subscriptions
await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
msg = await websocket_client.receive_json()
assert msg["id"] == 2
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
# NOTE(review): this test has added only one subscription so far, yet two
# entries are expected - presumably the connection already holds another
# one; confirm where it comes from.
assert len(subscriptions) == 2
await websocket_client.send_json(
{"id": 3, "type": "subscription_that_raises_on_cancel"}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 3
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert len(subscriptions) == 3
# The handler for this command raises asyncio.CancelledError.
await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
await hass.async_block_till_done()
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
# The dict captured by the handlers must have been emptied, even though one
# of the stored unsubscribe callbacks raises ValueError.
assert len(subscriptions) == 0
async def test_pending_msg_peak(
hass: HomeAssistant,
mock_low_peak,