Fix memory leaks in websocket api (#94780)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
J. Nick Koston 2023-06-19 18:27:22 -05:00 committed by GitHub
parent a7d327afa2
commit 1206f2c1da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 84 deletions

View file

@ -56,6 +56,10 @@ class ActiveConnection:
self.binary_handlers: list[BinaryHandler | None] = [] self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self) current_connection.set(self)
def __repr__(self) -> str:
"""Return a debug representation wrapping the connection description."""
# No request object is supplied here, so get_description is called with None.
return f"<ActiveConnection {self.get_description(None)}>"
def set_supported_features(self, features: dict[str, float]) -> None: def set_supported_features(self, features: dict[str, float]) -> None:
"""Set supported features.""" """Set supported features."""
self.supported_features = features self.supported_features = features
@ -193,7 +197,24 @@ class ActiveConnection:
def async_handle_close(self) -> None: def async_handle_close(self) -> None:
"""Handle closing down connection.""" """Handle closing down connection."""
for unsub in self.subscriptions.values(): for unsub in self.subscriptions.values():
try:
unsub() unsub()
except Exception: # pylint: disable=broad-except
# If one fails, make sure we still try the rest
self.logger.exception(
"Error unsubscribing from subscription: %s", unsub
)
self.subscriptions.clear()
self.send_message = self._connect_closed_error
current_request.set(None)
current_connection.set(None)
@callback
def _connect_closed_error(
self, msg: str | dict[str, Any] | Callable[[], str]
) -> None:
"""Log, at debug level, a message that was sent after the connection closed.

async_handle_close assigns this method to send_message, so a late
message is only logged here; nothing else is done with it.
"""
self.logger.debug("Tried to send message %s on closed connection", msg)
@callback @callback
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None: def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:

View file

@ -1,9 +1,7 @@
"""Websocket constants.""" """Websocket constants."""
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from concurrent import futures
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -42,10 +40,6 @@ ERR_TEMPLATE_ERROR: Final = "template_error"
TYPE_RESULT: Final = "result" TYPE_RESULT: Final = "result"
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
# Event types # Event types
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected" SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"

View file

@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress
import datetime as dt import datetime as dt
import logging import logging
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
@ -21,7 +20,6 @@ from homeassistant.util.json import json_loads
from .auth import AuthPhase, auth_required_message from .auth import AuthPhase, auth_required_message
from .const import ( from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS, DATA_CONNECTIONS,
MAX_PENDING_MSG, MAX_PENDING_MSG,
PENDING_MSG_PEAK, PENDING_MSG_PEAK,
@ -68,15 +66,16 @@ class WebSocketHandler:
def __init__(self, hass: HomeAssistant, request: web.Request) -> None: def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection.""" """Initialize an active connection."""
self.hass = hass self._hass = hass
self.request = request self._request: web.Request = request
self.wsock = web.WebSocketResponse(heartbeat=55) self._wsock = web.WebSocketResponse(heartbeat=55)
self._handle_task: asyncio.Task | None = None self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None
self._closing: bool = False self._closing: bool = False
self._authenticated: bool = False
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None self._connection: ActiveConnection | None = None
# The WebSocketHandler has a single consumer and path # The WebSocketHandler has a single consumer and path
# to where messages are queued. This allows the implementation # to where messages are queued. This allows the implementation
@ -85,25 +84,38 @@ class WebSocketHandler:
self._message_queue: deque = deque() self._message_queue: deque = deque()
self._ready_future: asyncio.Future[None] | None = None self._ready_future: asyncio.Future[None] | None = None
def __repr__(self) -> str:
"""Return a debug representation of the handler.

Includes the closing and authenticated flags and the description property.
"""
return (
"<WebSocketHandler "
f"closing={self._closing} "
f"authenticated={self._authenticated} "
f"description={self.description}>"
)
@property @property
def description(self) -> str: def description(self) -> str:
"""Return a description of the connection.""" """Return a description of the connection."""
if self.connection is not None: if connection := self._connection:
return self.connection.get_description(self.request) return connection.get_description(self._request)
return describe_request(self.request) if request := self._request:
return describe_request(request)
return "finished connection"
async def _writer(self) -> None: async def _writer(self) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Variables are set locally to avoid lookups in the loop # Variables are set locally to avoid lookups in the loop
message_queue = self._message_queue message_queue = self._message_queue
logger = self._logger logger = self._logger
send_str = self.wsock.send_str wsock = self._wsock
loop = self.hass.loop send_str = wsock.send_str
loop = self._hass.loop
debug = logger.debug debug = logger.debug
is_enabled_for = logger.isEnabledFor
logging_debug = logging.DEBUG
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
try: try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): while not wsock.closed:
while not self.wsock.closed:
if (messages_remaining := len(message_queue)) == 0: if (messages_remaining := len(message_queue)) == 0:
self._ready_future = loop.create_future() self._ready_future = loop.create_future()
await self._ready_future await self._ready_future
@ -113,15 +125,17 @@ class WebSocketHandler:
if (process := message_queue.popleft()) is None: if (process := message_queue.popleft()) is None:
return return
debug_enabled = is_enabled_for(logging_debug)
messages_remaining -= 1 messages_remaining -= 1
message = process if isinstance(process, str) else process() message = process if isinstance(process, str) else process()
if ( if (
not messages_remaining not messages_remaining
or not self.connection or not (connection := self._connection)
or not self.connection.can_coalesce or not connection.can_coalesce
): ):
debug("Sending %s", message) if debug_enabled:
debug("%s: Sending %s", self.description, message)
await send_str(message) await send_str(message)
continue continue
@ -130,16 +144,21 @@ class WebSocketHandler:
# A None message is used to signal the end of the connection # A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None: if (process := message_queue.popleft()) is None:
return return
messages.append( messages.append(process if isinstance(process, str) else process())
process if isinstance(process, str) else process()
)
messages_remaining -= 1 messages_remaining -= 1
joined_messages = ",".join(messages) joined_messages = ",".join(messages)
coalesced_messages = f"[{joined_messages}]" coalesced_messages = f"[{joined_messages}]"
debug("Sending %s", coalesced_messages) if debug_enabled:
debug("%s: Sending %s", self.description, coalesced_messages)
await send_str(coalesced_messages) await send_str(coalesced_messages)
except asyncio.CancelledError:
debug("%s: Writer cancelled", self.description)
raise
except (RuntimeError, ConnectionResetError) as ex:
debug("%s: Unexpected error in writer: %s", self.description, ex)
finally: finally:
debug("%s: Writer done", self.description)
# Clean up the peak checker when we shut down the writer # Clean up the peak checker when we shut down the writer
self._cancel_peak_checker() self._cancel_peak_checker()
@ -195,7 +214,7 @@ class WebSocketHandler:
if not peak_checker_active: if not peak_checker_active:
self._peak_checker_unsub = async_call_later( self._peak_checker_unsub = async_call_later(
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
) )
@callback @callback
@ -231,8 +250,14 @@ class WebSocketHandler:
async def async_handle(self) -> web.WebSocketResponse: async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response.""" """Handle a websocket response."""
request = self.request request = self._request
wsock = self.wsock wsock = self._wsock
logger = self._logger
debug = logger.debug
hass = self._hass
is_enabled_for = logger.isEnabledFor
logging_debug = logging.DEBUG
try: try:
async with async_timeout.timeout(10): async with async_timeout.timeout(10):
await wsock.prepare(request) await wsock.prepare(request)
@ -240,7 +265,7 @@ class WebSocketHandler:
self._logger.warning("Timeout preparing request from %s", request.remote) self._logger.warning("Timeout preparing request from %s", request.remote)
return wsock return wsock
self._logger.debug("Connected from %s", request.remote) debug("%s: Connected from %s", self.description, request.remote)
self._handle_task = asyncio.current_task() self._handle_task = asyncio.current_task()
@callback @callback
@ -248,17 +273,13 @@ class WebSocketHandler:
"""Cancel this connection.""" """Cancel this connection."""
self._cancel() self._cancel()
unsub_stop = self.hass.bus.async_listen( unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
EVENT_HOMEASSISTANT_STOP, handle_hass_stop
)
# As the webserver is now started before the start # As the webserver is now started before the start
# event we do not want to block for websocket responses # event we do not want to block for websocket responses
self._writer_task = asyncio.create_task(self._writer()) self._writer_task = asyncio.create_task(self._writer())
auth = AuthPhase( auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
self._logger, self.hass, self._send_message, self._cancel, request
)
connection = None connection = None
disconnect_warn = None disconnect_warn = None
@ -286,13 +307,14 @@ class WebSocketHandler:
disconnect_warn = "Received invalid JSON." disconnect_warn = "Received invalid JSON."
raise Disconnect from err raise Disconnect from err
self._logger.debug("Received %s", msg_data) if is_enabled_for(logging_debug):
self.connection = connection = await auth.async_handle(msg_data) debug("%s: Received %s", self.description, msg_data)
self.hass.data[DATA_CONNECTIONS] = ( connection = await auth.async_handle(msg_data)
self.hass.data.get(DATA_CONNECTIONS, 0) + 1 self._connection = connection
) hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED) async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
self._authenticated = True
# #
# #
# Our websocket implementation is backed by an asyncio.Queue # Our websocket implementation is backed by an asyncio.Queue
@ -356,7 +378,9 @@ class WebSocketHandler:
disconnect_warn = "Received invalid JSON." disconnect_warn = "Received invalid JSON."
break break
self._logger.debug("Received %s", msg_data) if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
if not isinstance(msg_data, list): if not isinstance(msg_data, list):
connection.async_handle(msg_data) connection.async_handle(msg_data)
continue continue
@ -365,17 +389,22 @@ class WebSocketHandler:
connection.async_handle(split_msg) connection.async_handle(split_msg)
except asyncio.CancelledError: except asyncio.CancelledError:
self._logger.info("Connection closed by client") debug("%s: Connection cancelled", self.description)
raise
except Disconnect: except Disconnect as ex:
pass debug("%s: Connection closed by client: %s", self.description, ex)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
self._logger.exception("Unexpected error inside websocket API") self._logger.exception(
"%s: Unexpected error inside websocket API", self.description
)
finally: finally:
unsub_stop() unsub_stop()
self._cancel_peak_checker()
if connection is not None: if connection is not None:
connection.async_handle_close() connection.async_handle_close()
@ -385,20 +414,37 @@ class WebSocketHandler:
if self._ready_future and not self._ready_future.done(): if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None) self._ready_future.set_result(None)
# If the writer gets canceled we still need to close the websocket
# so we have another finally block to make sure we close the websocket
# if the writer gets canceled.
try:
await self._writer_task
finally:
try: try:
# Make sure all error messages are written before closing # Make sure all error messages are written before closing
await self._writer_task
await wsock.close() await wsock.close()
finally: finally:
if disconnect_warn is None: if disconnect_warn is None:
self._logger.debug("Disconnected") debug("%s: Disconnected", self.description)
else: else:
self._logger.warning("Disconnected: %s", disconnect_warn) self._logger.warning(
"%s: Disconnected: %s", self.description, disconnect_warn
)
if connection is not None: if connection is not None:
self.hass.data[DATA_CONNECTIONS] -= 1 hass.data[DATA_CONNECTIONS] -= 1
self.connection = None self._connection = None
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED) async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
# Break reference cycles to make sure GC can happen sooner
self._wsock = None # type: ignore[assignment]
self._request = None # type: ignore[assignment]
self._hass = None # type: ignore[assignment]
self._logger = None # type: ignore[assignment]
self._message_queue = None # type: ignore[assignment]
self._handle_task = None
self._writer_task = None
self._ready_future = None
return wsock return wsock

View file

@ -3,11 +3,19 @@ import pytest
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
from homeassistant.components.websocket_api.http import URL from homeassistant.components.websocket_api.http import URL
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.typing import (
MockHAClientWebSocket,
WebSocketGenerator,
)
@pytest.fixture @pytest.fixture
async def websocket_client(hass, hass_ws_client): async def websocket_client(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> MockHAClientWebSocket:
"""Create a websocket client.""" """Create a websocket client."""
return await hass_ws_client(hass) return await hass_ws_client(hass)

View file

@ -18,7 +18,10 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
from tests.typing import WebSocketGenerator from tests.typing import (
MockHAClientWebSocket,
WebSocketGenerator,
)
@pytest.fixture @pytest.fixture
@ -36,15 +39,103 @@ def mock_low_peak():
async def test_pending_msg_overflow( async def test_pending_msg_overflow(
hass: HomeAssistant, mock_low_queue, websocket_client hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
) -> None: ) -> None:
"""Test get_panels command.""" """Test pending messages overflows."""
for idx in range(10): for idx in range(10):
await websocket_client.send_json({"id": idx + 1, "type": "ping"}) await websocket_client.send_json({"id": idx + 1, "type": "ping"})
msg = await websocket_client.receive() msg = await websocket_client.receive()
assert msg.type == WSMsgType.close assert msg.type == WSMsgType.close
async def test_cleanup_on_cancellation(
hass: HomeAssistant, websocket_client: MockHAClientWebSocket
) -> None:
"""Test cleanup on cancellation.

Registers three commands, creates subscriptions through two of them, then
raises CancelledError from the third and checks that the socket is closed
and the subscriptions mapping is empty.
"""
# Set by the handlers below so the test can inspect connection.subscriptions
subscriptions = None
# Register a handler that registers a subscription
@callback
@websocket_command(
{
"type": "fake_subscription",
}
)
def fake_subscription(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
nonlocal subscriptions
msg_id: int = msg["id"]
# The stored unsubscribe callback does nothing
connection.subscriptions[msg_id] = callback(lambda: None)
connection.send_result(msg_id)
subscriptions = connection.subscriptions
async_register_command(hass, fake_subscription)
# Register a handler that raises on cancel
@callback
@websocket_command(
{
"type": "subscription_that_raises_on_cancel",
}
)
def subscription_that_raises_on_cancel(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
nonlocal subscriptions
msg_id: int = msg["id"]
@callback
def _raise():
raise ValueError()
# The stored unsubscribe callback raises ValueError when it is called
connection.subscriptions[msg_id] = _raise
connection.send_result(msg_id)
subscriptions = connection.subscriptions
async_register_command(hass, subscription_that_raises_on_cancel)
# Register a handler that cancels in handler
@callback
@websocket_command(
{
"type": "cancel_in_handler",
}
)
def cancel_in_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
raise asyncio.CancelledError()
async_register_command(hass, cancel_in_handler)
# A plain ping is answered and no handler has captured the subscriptions yet
await websocket_client.send_json({"id": 1, "type": "ping"})
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "pong"
assert not subscriptions
await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
msg = await websocket_client.receive_json()
assert msg["id"] == 2
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
# NOTE(review): two entries are expected after one fake_subscription call;
# presumably the connection already holds one subscription - confirm
assert len(subscriptions) == 2
await websocket_client.send_json(
{"id": 3, "type": "subscription_that_raises_on_cancel"}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 3
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert len(subscriptions) == 3
# The handler raises CancelledError; the socket is expected to close and the
# subscriptions mapping to end up empty, although one unsubscribe callback raises
await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
await hass.async_block_till_done()
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
assert len(subscriptions) == 0
async def test_pending_msg_peak( async def test_pending_msg_peak(
hass: HomeAssistant, hass: HomeAssistant,
mock_low_peak, mock_low_peak,