Fix memory leaks in websocket api (#94780)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
a7d327afa2
commit
1206f2c1da
5 changed files with 244 additions and 84 deletions
|
@ -56,6 +56,10 @@ class ActiveConnection:
|
||||||
self.binary_handlers: list[BinaryHandler | None] = []
|
self.binary_handlers: list[BinaryHandler | None] = []
|
||||||
current_connection.set(self)
|
current_connection.set(self)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return the representation."""
|
||||||
|
return f"<ActiveConnection {self.get_description(None)}>"
|
||||||
|
|
||||||
def set_supported_features(self, features: dict[str, float]) -> None:
|
def set_supported_features(self, features: dict[str, float]) -> None:
|
||||||
"""Set supported features."""
|
"""Set supported features."""
|
||||||
self.supported_features = features
|
self.supported_features = features
|
||||||
|
@ -193,7 +197,24 @@ class ActiveConnection:
|
||||||
def async_handle_close(self) -> None:
|
def async_handle_close(self) -> None:
|
||||||
"""Handle closing down connection."""
|
"""Handle closing down connection."""
|
||||||
for unsub in self.subscriptions.values():
|
for unsub in self.subscriptions.values():
|
||||||
|
try:
|
||||||
unsub()
|
unsub()
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# If one fails, make sure we still try the rest
|
||||||
|
self.logger.exception(
|
||||||
|
"Error unsubscribing from subscription: %s", unsub
|
||||||
|
)
|
||||||
|
self.subscriptions.clear()
|
||||||
|
self.send_message = self._connect_closed_error
|
||||||
|
current_request.set(None)
|
||||||
|
current_connection.set(None)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _connect_closed_error(
|
||||||
|
self, msg: str | dict[str, Any] | Callable[[], str]
|
||||||
|
) -> None:
|
||||||
|
"""Send a message when the connection is closed."""
|
||||||
|
self.logger.debug("Tried to send message %s on closed connection", msg)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
|
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
"""Websocket constants."""
|
"""Websocket constants."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from concurrent import futures
|
|
||||||
from typing import TYPE_CHECKING, Any, Final
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
@ -42,10 +40,6 @@ ERR_TEMPLATE_ERROR: Final = "template_error"
|
||||||
|
|
||||||
TYPE_RESULT: Final = "result"
|
TYPE_RESULT: Final = "result"
|
||||||
|
|
||||||
# Define the possible errors that occur when connections are cancelled.
|
|
||||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
|
||||||
# that futures.CancelledErrors can also occur in some situations.
|
|
||||||
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
|
|
||||||
|
|
||||||
# Event types
|
# Event types
|
||||||
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
|
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
|
||||||
|
|
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Final
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
@ -21,7 +20,6 @@ from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
from .auth import AuthPhase, auth_required_message
|
from .auth import AuthPhase, auth_required_message
|
||||||
from .const import (
|
from .const import (
|
||||||
CANCELLATION_ERRORS,
|
|
||||||
DATA_CONNECTIONS,
|
DATA_CONNECTIONS,
|
||||||
MAX_PENDING_MSG,
|
MAX_PENDING_MSG,
|
||||||
PENDING_MSG_PEAK,
|
PENDING_MSG_PEAK,
|
||||||
|
@ -68,15 +66,16 @@ class WebSocketHandler:
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
||||||
"""Initialize an active connection."""
|
"""Initialize an active connection."""
|
||||||
self.hass = hass
|
self._hass = hass
|
||||||
self.request = request
|
self._request: web.Request = request
|
||||||
self.wsock = web.WebSocketResponse(heartbeat=55)
|
self._wsock = web.WebSocketResponse(heartbeat=55)
|
||||||
self._handle_task: asyncio.Task | None = None
|
self._handle_task: asyncio.Task | None = None
|
||||||
self._writer_task: asyncio.Task | None = None
|
self._writer_task: asyncio.Task | None = None
|
||||||
self._closing: bool = False
|
self._closing: bool = False
|
||||||
|
self._authenticated: bool = False
|
||||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||||
self.connection: ActiveConnection | None = None
|
self._connection: ActiveConnection | None = None
|
||||||
|
|
||||||
# The WebSocketHandler has a single consumer and path
|
# The WebSocketHandler has a single consumer and path
|
||||||
# to where messages are queued. This allows the implementation
|
# to where messages are queued. This allows the implementation
|
||||||
|
@ -85,25 +84,38 @@ class WebSocketHandler:
|
||||||
self._message_queue: deque = deque()
|
self._message_queue: deque = deque()
|
||||||
self._ready_future: asyncio.Future[None] | None = None
|
self._ready_future: asyncio.Future[None] | None = None
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return the representation."""
|
||||||
|
return (
|
||||||
|
"<WebSocketHandler "
|
||||||
|
f"closing={self._closing} "
|
||||||
|
f"authenticated={self._authenticated} "
|
||||||
|
f"description={self.description}>"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
"""Return a description of the connection."""
|
"""Return a description of the connection."""
|
||||||
if self.connection is not None:
|
if connection := self._connection:
|
||||||
return self.connection.get_description(self.request)
|
return connection.get_description(self._request)
|
||||||
return describe_request(self.request)
|
if request := self._request:
|
||||||
|
return describe_request(request)
|
||||||
|
return "finished connection"
|
||||||
|
|
||||||
async def _writer(self) -> None:
|
async def _writer(self) -> None:
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
# Variables are set locally to avoid lookups in the loop
|
# Variables are set locally to avoid lookups in the loop
|
||||||
message_queue = self._message_queue
|
message_queue = self._message_queue
|
||||||
logger = self._logger
|
logger = self._logger
|
||||||
send_str = self.wsock.send_str
|
wsock = self._wsock
|
||||||
loop = self.hass.loop
|
send_str = wsock.send_str
|
||||||
|
loop = self._hass.loop
|
||||||
debug = logger.debug
|
debug = logger.debug
|
||||||
|
is_enabled_for = logger.isEnabledFor
|
||||||
|
logging_debug = logging.DEBUG
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
try:
|
try:
|
||||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
while not wsock.closed:
|
||||||
while not self.wsock.closed:
|
|
||||||
if (messages_remaining := len(message_queue)) == 0:
|
if (messages_remaining := len(message_queue)) == 0:
|
||||||
self._ready_future = loop.create_future()
|
self._ready_future = loop.create_future()
|
||||||
await self._ready_future
|
await self._ready_future
|
||||||
|
@ -113,15 +125,17 @@ class WebSocketHandler:
|
||||||
if (process := message_queue.popleft()) is None:
|
if (process := message_queue.popleft()) is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
debug_enabled = is_enabled_for(logging_debug)
|
||||||
messages_remaining -= 1
|
messages_remaining -= 1
|
||||||
message = process if isinstance(process, str) else process()
|
message = process if isinstance(process, str) else process()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not messages_remaining
|
not messages_remaining
|
||||||
or not self.connection
|
or not (connection := self._connection)
|
||||||
or not self.connection.can_coalesce
|
or not connection.can_coalesce
|
||||||
):
|
):
|
||||||
debug("Sending %s", message)
|
if debug_enabled:
|
||||||
|
debug("%s: Sending %s", self.description, message)
|
||||||
await send_str(message)
|
await send_str(message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -130,16 +144,21 @@ class WebSocketHandler:
|
||||||
# A None message is used to signal the end of the connection
|
# A None message is used to signal the end of the connection
|
||||||
if (process := message_queue.popleft()) is None:
|
if (process := message_queue.popleft()) is None:
|
||||||
return
|
return
|
||||||
messages.append(
|
messages.append(process if isinstance(process, str) else process())
|
||||||
process if isinstance(process, str) else process()
|
|
||||||
)
|
|
||||||
messages_remaining -= 1
|
messages_remaining -= 1
|
||||||
|
|
||||||
joined_messages = ",".join(messages)
|
joined_messages = ",".join(messages)
|
||||||
coalesced_messages = f"[{joined_messages}]"
|
coalesced_messages = f"[{joined_messages}]"
|
||||||
debug("Sending %s", coalesced_messages)
|
if debug_enabled:
|
||||||
|
debug("%s: Sending %s", self.description, coalesced_messages)
|
||||||
await send_str(coalesced_messages)
|
await send_str(coalesced_messages)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
debug("%s: Writer cancelled", self.description)
|
||||||
|
raise
|
||||||
|
except (RuntimeError, ConnectionResetError) as ex:
|
||||||
|
debug("%s: Unexpected error in writer: %s", self.description, ex)
|
||||||
finally:
|
finally:
|
||||||
|
debug("%s: Writer done", self.description)
|
||||||
# Clean up the peak checker when we shut down the writer
|
# Clean up the peak checker when we shut down the writer
|
||||||
self._cancel_peak_checker()
|
self._cancel_peak_checker()
|
||||||
|
|
||||||
|
@ -195,7 +214,7 @@ class WebSocketHandler:
|
||||||
|
|
||||||
if not peak_checker_active:
|
if not peak_checker_active:
|
||||||
self._peak_checker_unsub = async_call_later(
|
self._peak_checker_unsub = async_call_later(
|
||||||
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -231,8 +250,14 @@ class WebSocketHandler:
|
||||||
|
|
||||||
async def async_handle(self) -> web.WebSocketResponse:
|
async def async_handle(self) -> web.WebSocketResponse:
|
||||||
"""Handle a websocket response."""
|
"""Handle a websocket response."""
|
||||||
request = self.request
|
request = self._request
|
||||||
wsock = self.wsock
|
wsock = self._wsock
|
||||||
|
logger = self._logger
|
||||||
|
debug = logger.debug
|
||||||
|
hass = self._hass
|
||||||
|
is_enabled_for = logger.isEnabledFor
|
||||||
|
logging_debug = logging.DEBUG
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with async_timeout.timeout(10):
|
async with async_timeout.timeout(10):
|
||||||
await wsock.prepare(request)
|
await wsock.prepare(request)
|
||||||
|
@ -240,7 +265,7 @@ class WebSocketHandler:
|
||||||
self._logger.warning("Timeout preparing request from %s", request.remote)
|
self._logger.warning("Timeout preparing request from %s", request.remote)
|
||||||
return wsock
|
return wsock
|
||||||
|
|
||||||
self._logger.debug("Connected from %s", request.remote)
|
debug("%s: Connected from %s", self.description, request.remote)
|
||||||
self._handle_task = asyncio.current_task()
|
self._handle_task = asyncio.current_task()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -248,17 +273,13 @@ class WebSocketHandler:
|
||||||
"""Cancel this connection."""
|
"""Cancel this connection."""
|
||||||
self._cancel()
|
self._cancel()
|
||||||
|
|
||||||
unsub_stop = self.hass.bus.async_listen(
|
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop
|
|
||||||
)
|
|
||||||
|
|
||||||
# As the webserver is now started before the start
|
# As the webserver is now started before the start
|
||||||
# event we do not want to block for websocket responses
|
# event we do not want to block for websocket responses
|
||||||
self._writer_task = asyncio.create_task(self._writer())
|
self._writer_task = asyncio.create_task(self._writer())
|
||||||
|
|
||||||
auth = AuthPhase(
|
auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
|
||||||
self._logger, self.hass, self._send_message, self._cancel, request
|
|
||||||
)
|
|
||||||
connection = None
|
connection = None
|
||||||
disconnect_warn = None
|
disconnect_warn = None
|
||||||
|
|
||||||
|
@ -286,13 +307,14 @@ class WebSocketHandler:
|
||||||
disconnect_warn = "Received invalid JSON."
|
disconnect_warn = "Received invalid JSON."
|
||||||
raise Disconnect from err
|
raise Disconnect from err
|
||||||
|
|
||||||
self._logger.debug("Received %s", msg_data)
|
if is_enabled_for(logging_debug):
|
||||||
self.connection = connection = await auth.async_handle(msg_data)
|
debug("%s: Received %s", self.description, msg_data)
|
||||||
self.hass.data[DATA_CONNECTIONS] = (
|
connection = await auth.async_handle(msg_data)
|
||||||
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
self._connection = connection
|
||||||
)
|
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED)
|
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||||
|
|
||||||
|
self._authenticated = True
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# Our websocket implementation is backed by an asyncio.Queue
|
# Our websocket implementation is backed by an asyncio.Queue
|
||||||
|
@ -356,7 +378,9 @@ class WebSocketHandler:
|
||||||
disconnect_warn = "Received invalid JSON."
|
disconnect_warn = "Received invalid JSON."
|
||||||
break
|
break
|
||||||
|
|
||||||
self._logger.debug("Received %s", msg_data)
|
if is_enabled_for(logging_debug):
|
||||||
|
debug("%s: Received %s", self.description, msg_data)
|
||||||
|
|
||||||
if not isinstance(msg_data, list):
|
if not isinstance(msg_data, list):
|
||||||
connection.async_handle(msg_data)
|
connection.async_handle(msg_data)
|
||||||
continue
|
continue
|
||||||
|
@ -365,17 +389,22 @@ class WebSocketHandler:
|
||||||
connection.async_handle(split_msg)
|
connection.async_handle(split_msg)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self._logger.info("Connection closed by client")
|
debug("%s: Connection cancelled", self.description)
|
||||||
|
raise
|
||||||
|
|
||||||
except Disconnect:
|
except Disconnect as ex:
|
||||||
pass
|
debug("%s: Connection closed by client: %s", self.description, ex)
|
||||||
|
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
self._logger.exception("Unexpected error inside websocket API")
|
self._logger.exception(
|
||||||
|
"%s: Unexpected error inside websocket API", self.description
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
unsub_stop()
|
unsub_stop()
|
||||||
|
|
||||||
|
self._cancel_peak_checker()
|
||||||
|
|
||||||
if connection is not None:
|
if connection is not None:
|
||||||
connection.async_handle_close()
|
connection.async_handle_close()
|
||||||
|
|
||||||
|
@ -385,20 +414,37 @@ class WebSocketHandler:
|
||||||
if self._ready_future and not self._ready_future.done():
|
if self._ready_future and not self._ready_future.done():
|
||||||
self._ready_future.set_result(None)
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
|
# If the writer gets canceled we still need to close the websocket
|
||||||
|
# so we have another finally block to make sure we close the websocket
|
||||||
|
# if the writer gets canceled.
|
||||||
|
try:
|
||||||
|
await self._writer_task
|
||||||
|
finally:
|
||||||
try:
|
try:
|
||||||
# Make sure all error messages are written before closing
|
# Make sure all error messages are written before closing
|
||||||
await self._writer_task
|
|
||||||
await wsock.close()
|
await wsock.close()
|
||||||
finally:
|
finally:
|
||||||
if disconnect_warn is None:
|
if disconnect_warn is None:
|
||||||
self._logger.debug("Disconnected")
|
debug("%s: Disconnected", self.description)
|
||||||
else:
|
else:
|
||||||
self._logger.warning("Disconnected: %s", disconnect_warn)
|
self._logger.warning(
|
||||||
|
"%s: Disconnected: %s", self.description, disconnect_warn
|
||||||
|
)
|
||||||
|
|
||||||
if connection is not None:
|
if connection is not None:
|
||||||
self.hass.data[DATA_CONNECTIONS] -= 1
|
hass.data[DATA_CONNECTIONS] -= 1
|
||||||
self.connection = None
|
self._connection = None
|
||||||
|
|
||||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||||
|
|
||||||
|
# Break reference cycles to make sure GC can happen sooner
|
||||||
|
self._wsock = None # type: ignore[assignment]
|
||||||
|
self._request = None # type: ignore[assignment]
|
||||||
|
self._hass = None # type: ignore[assignment]
|
||||||
|
self._logger = None # type: ignore[assignment]
|
||||||
|
self._message_queue = None # type: ignore[assignment]
|
||||||
|
self._handle_task = None
|
||||||
|
self._writer_task = None
|
||||||
|
self._ready_future = None
|
||||||
|
|
||||||
return wsock
|
return wsock
|
||||||
|
|
|
@ -3,11 +3,19 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
|
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
|
||||||
from homeassistant.components.websocket_api.http import URL
|
from homeassistant.components.websocket_api.http import URL
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.typing import (
|
||||||
|
MockHAClientWebSocket,
|
||||||
|
WebSocketGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def websocket_client(hass, hass_ws_client):
|
async def websocket_client(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||||
|
) -> MockHAClientWebSocket:
|
||||||
"""Create a websocket client."""
|
"""Create a websocket client."""
|
||||||
return await hass_ws_client(hass)
|
return await hass_ws_client(hass)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,10 @@ from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed
|
from tests.common import async_fire_time_changed
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import (
|
||||||
|
MockHAClientWebSocket,
|
||||||
|
WebSocketGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -36,15 +39,103 @@ def mock_low_peak():
|
||||||
|
|
||||||
|
|
||||||
async def test_pending_msg_overflow(
|
async def test_pending_msg_overflow(
|
||||||
hass: HomeAssistant, mock_low_queue, websocket_client
|
hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test get_panels command."""
|
"""Test pending messages overflows."""
|
||||||
for idx in range(10):
|
for idx in range(10):
|
||||||
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
||||||
msg = await websocket_client.receive()
|
msg = await websocket_client.receive()
|
||||||
assert msg.type == WSMsgType.close
|
assert msg.type == WSMsgType.close
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cleanup_on_cancellation(
|
||||||
|
hass: HomeAssistant, websocket_client: MockHAClientWebSocket
|
||||||
|
) -> None:
|
||||||
|
"""Test cleanup on cancellation."""
|
||||||
|
|
||||||
|
subscriptions = None
|
||||||
|
|
||||||
|
# Register a handler that registers a subscription
|
||||||
|
@callback
|
||||||
|
@websocket_command(
|
||||||
|
{
|
||||||
|
"type": "fake_subscription",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def fake_subscription(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
nonlocal subscriptions
|
||||||
|
msg_id: int = msg["id"]
|
||||||
|
connection.subscriptions[msg_id] = callback(lambda: None)
|
||||||
|
connection.send_result(msg_id)
|
||||||
|
subscriptions = connection.subscriptions
|
||||||
|
|
||||||
|
async_register_command(hass, fake_subscription)
|
||||||
|
|
||||||
|
# Register a handler that raises on cancel
|
||||||
|
@callback
|
||||||
|
@websocket_command(
|
||||||
|
{
|
||||||
|
"type": "subscription_that_raises_on_cancel",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def subscription_that_raises_on_cancel(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
nonlocal subscriptions
|
||||||
|
msg_id: int = msg["id"]
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _raise():
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
connection.subscriptions[msg_id] = _raise
|
||||||
|
connection.send_result(msg_id)
|
||||||
|
subscriptions = connection.subscriptions
|
||||||
|
|
||||||
|
async_register_command(hass, subscription_that_raises_on_cancel)
|
||||||
|
|
||||||
|
# Register a handler that cancels in handler
|
||||||
|
@callback
|
||||||
|
@websocket_command(
|
||||||
|
{
|
||||||
|
"type": "cancel_in_handler",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def cancel_in_handler(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
async_register_command(hass, cancel_in_handler)
|
||||||
|
|
||||||
|
await websocket_client.send_json({"id": 1, "type": "ping"})
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 1
|
||||||
|
assert msg["type"] == "pong"
|
||||||
|
assert not subscriptions
|
||||||
|
await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 2
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(subscriptions) == 2
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{"id": 3, "type": "subscription_that_raises_on_cancel"}
|
||||||
|
)
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 3
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(subscriptions) == 3
|
||||||
|
await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.close
|
||||||
|
assert len(subscriptions) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_pending_msg_peak(
|
async def test_pending_msg_peak(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_low_peak,
|
mock_low_peak,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue