* Significantly reduce websocket API connection auth phase latency. Since the auth phase has exclusive control over the websocket until ActiveConnection is created, we can bypass the queue and send messages right away. This reduces the latency and reconnect time since we do not have to wait for the background processing of the queue to send the auth ok message. * Only start the writer queue after auth is successful
480 lines
18 KiB
Python
480 lines
18 KiB
Python
"""View to accept incoming websocket connection."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import deque
|
|
from collections.abc import Callable, Coroutine
|
|
import datetime as dt
|
|
from functools import partial
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any, Final
|
|
|
|
from aiohttp import WSMsgType, web
|
|
|
|
from homeassistant.components.http import HomeAssistantView
|
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|
from homeassistant.core import Event, HomeAssistant, callback
|
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
from homeassistant.helpers.event import async_call_later
|
|
from homeassistant.util.json import json_loads
|
|
|
|
from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
|
|
from .const import (
|
|
DATA_CONNECTIONS,
|
|
MAX_PENDING_MSG,
|
|
PENDING_MSG_PEAK,
|
|
PENDING_MSG_PEAK_TIME,
|
|
SIGNAL_WEBSOCKET_CONNECTED,
|
|
SIGNAL_WEBSOCKET_DISCONNECTED,
|
|
URL,
|
|
)
|
|
from .error import Disconnect
|
|
from .messages import message_to_json_bytes
|
|
from .util import describe_request
|
|
|
|
if TYPE_CHECKING:
|
|
from .connection import ActiveConnection
|
|
|
|
|
|
# Dedicated child logger ("<module>.connection") for per-connection messages.
# WebSocketHandler wraps it in a WebSocketAdapter so every record is prefixed
# with the connection id.
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
|
|
|
|
|
class WebsocketAPIView(HomeAssistantView):
    """HTTP view that exposes the websocket API endpoint."""

    name: str = "websocketapi"
    url: str = URL
    # The WebSocketHandler runs its own auth phase over the socket.
    requires_auth: bool = False

    async def get(self, request: web.Request) -> web.WebSocketResponse:
        """Hand an incoming websocket request over to a fresh handler.

        A new WebSocketHandler is created per request; the response it
        returns once the connection has finished is passed back unchanged.
        """
        hass = request.app["hass"]
        handler = WebSocketHandler(hass, request)
        return await handler.async_handle()
|
|
|
|
|
|
class WebSocketAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every record with the connection id."""

    def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
        """Prefix *msg* with ``[<connid>]`` and pass *kwargs* through as-is."""
        extra = self.extra
        assert extra is not None
        connid = extra["connid"]
        return f"[{connid}] {msg}", kwargs
|
|
|
|
|
|
class WebSocketHandler:
    """Handle an active websocket client connection.

    One handler is created per incoming request. ``async_handle`` drives the
    connection through the auth phase and the command phase, while outgoing
    messages are queued with ``_send_message`` and drained by the background
    ``_writer`` task, which is only started once auth has succeeded.
    """

    __slots__ = (
        "_hass",
        "_request",
        "_wsock",
        "_handle_task",
        "_writer_task",
        "_closing",
        "_authenticated",
        "_logger",
        "_peak_checker_unsub",
        "_connection",
        "_message_queue",
        "_ready_future",
    )

    def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
        """Initialize an active connection."""
        self._hass = hass
        self._request: web.Request = request
        self._wsock = web.WebSocketResponse(heartbeat=55)
        # Task running async_handle; set once the socket is prepared.
        self._handle_task: asyncio.Task | None = None
        # Task running _writer; set once the auth phase has succeeded.
        self._writer_task: asyncio.Task | None = None
        self._closing: bool = False
        self._authenticated: bool = False
        self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
        # Unsubscribe callable of the scheduled _check_write_peak call, if any.
        self._peak_checker_unsub: Callable[[], None] | None = None
        self._connection: ActiveConnection | None = None

        # The WebSocketHandler has a single consumer and path
        # to where messages are queued. This allows the implementation
        # to use a deque and an asyncio.Future to avoid the overhead of
        # an asyncio.Queue.
        #
        # A None entry in the queue is the end-of-connection sentinel.
        self._message_queue: deque[bytes | None] = deque()
        # Future the writer awaits while the queue is empty.
        self._ready_future: asyncio.Future[None] | None = None

    def __repr__(self) -> str:
        """Return the representation."""
        return (
            "<WebSocketHandler "
            f"closing={self._closing} "
            f"authenticated={self._authenticated} "
            f"description={self.description}>"
        )

    @property
    def description(self) -> str:
        """Return a description of the connection.

        Prefers the description supplied by the active connection, falls back
        to one derived from the request, and returns a fixed string once the
        handler has dropped its references at the end of ``async_handle``.
        """
        if connection := self._connection:
            return connection.get_description(self._request)
        if request := self._request:
            return describe_request(request)
        return "finished connection"

    async def _writer(
        self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]]
    ) -> None:
        """Write outgoing messages.

        Runs until the socket is closed, the ``None`` end-of-connection
        sentinel is dequeued, or the task is cancelled. When several messages
        are pending and the connection allows coalescing, they are joined
        into a single JSON array and sent in one call.

        ``send_bytes_text`` is the coroutine function used to send one
        already-encoded payload.

        ``asyncio.CancelledError`` is re-raised; ``RuntimeError`` and
        ``ConnectionResetError`` from sending are logged at debug level and
        end the writer.
        """
        # Variables are set locally to avoid lookups in the loop
        message_queue = self._message_queue
        logger = self._logger
        wsock = self._wsock
        loop = self._hass.loop
        debug = logger.debug
        is_enabled_for = logger.isEnabledFor
        logging_debug = logging.DEBUG
        # Exceptions if Socket disconnected or cancelled by connection handler
        try:
            while not wsock.closed:
                if (messages_remaining := len(message_queue)) == 0:
                    # Nothing to send: park until _send_message (or the
                    # shutdown path in async_handle) resolves the future.
                    self._ready_future = loop.create_future()
                    await self._ready_future
                    messages_remaining = len(message_queue)

                # A None message is used to signal the end of the connection
                if (message := message_queue.popleft()) is None:
                    return

                debug_enabled = is_enabled_for(logging_debug)
                messages_remaining -= 1

                if (
                    not messages_remaining
                    or not (connection := self._connection)
                    or not connection.can_coalesce
                ):
                    if debug_enabled:
                        debug("%s: Sending %s", self.description, message)
                    await send_bytes_text(message)
                    continue

                messages: list[bytes] = [message]
                end_of_connection = False
                while messages_remaining:
                    # A None message is used to signal the end of the connection
                    if (message := message_queue.popleft()) is None:
                        # Do not return yet: the messages collected so far
                        # were queued ahead of the sentinel and must still be
                        # delivered, exactly as on the non-coalescing path.
                        end_of_connection = True
                        break
                    messages.append(message)
                    messages_remaining -= 1

                coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
                if debug_enabled:
                    debug("%s: Sending %s", self.description, coalesced_messages)
                await send_bytes_text(coalesced_messages)
                if end_of_connection:
                    return
        except asyncio.CancelledError:
            debug("%s: Writer cancelled", self.description)
            raise
        except (RuntimeError, ConnectionResetError) as ex:
            debug("%s: Unexpected error in writer: %s", self.description, ex)
        finally:
            debug("%s: Writer done", self.description)
            # Clean up the peak checker when we shut down the writer
            self._cancel_peak_checker()

    @callback
    def _cancel_peak_checker(self) -> None:
        """Cancel the peak checker."""
        if self._peak_checker_unsub is not None:
            self._peak_checker_unsub()
            self._peak_checker_unsub = None

    @callback
    def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
        """Queue sending a message to the client.

        ``dict`` messages are serialized with ``message_to_json_bytes`` and
        ``str`` messages are UTF-8 encoded before being queued.

        Closes connection if the client is not reading the messages: the
        connection is cancelled immediately once ``MAX_PENDING_MSG`` messages
        are pending, and a delayed check is scheduled while the backlog is
        above ``PENDING_MSG_PEAK``.

        Async friendly.
        """
        if self._closing:
            # Connection is cancelled, don't flood logs about exceeding
            # max pending messages.
            return

        if isinstance(message, dict):
            message = message_to_json_bytes(message)
        elif isinstance(message, str):
            message = message.encode("utf-8")

        message_queue = self._message_queue
        queue_size_before_add = len(message_queue)
        if queue_size_before_add >= MAX_PENDING_MSG:
            self._logger.error(
                (
                    "%s: Client unable to keep up with pending messages. Reached %s pending"
                    " messages. The system's load is too high or an integration is"
                    " misbehaving; Last message was: %s"
                ),
                self.description,
                MAX_PENDING_MSG,
                message,
            )
            self._cancel()
            return

        message_queue.append(message)
        # Wake the writer if it is parked waiting for work.
        ready_future = self._ready_future
        if ready_future and not ready_future.done():
            ready_future.set_result(None)

        peak_checker_active = self._peak_checker_unsub is not None

        if queue_size_before_add <= PENDING_MSG_PEAK:
            # Backlog is at or below the peak: any pending check is moot.
            if peak_checker_active:
                self._cancel_peak_checker()
            return

        if not peak_checker_active:
            self._peak_checker_unsub = async_call_later(
                self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
            )

    @callback
    def _check_write_peak(self, _utc_time: dt.datetime) -> None:
        """Check that we are no longer above the write peak.

        Scheduled by ``_send_message``; cancels the connection if the queue
        still holds at least ``PENDING_MSG_PEAK`` messages when it fires.
        """
        self._peak_checker_unsub = None

        # NOTE(review): this compares with ``<`` while _send_message treats
        # ``<= PENDING_MSG_PEAK`` as "below the peak" -- confirm the
        # difference at exactly PENDING_MSG_PEAK is intended.
        if len(self._message_queue) < PENDING_MSG_PEAK:
            return

        self._logger.error(
            (
                "%s: Client unable to keep up with pending messages. Stayed over %s for %s"
                " seconds. The system's load is too high or an integration is"
                " misbehaving; Last message was: %s"
            ),
            self.description,
            PENDING_MSG_PEAK,
            PENDING_MSG_PEAK_TIME,
            self._message_queue[-1],
        )
        self._cancel()

    @callback
    def _cancel(self) -> None:
        """Cancel the connection.

        Marks the handler as closing, cancels any pending peak check, and
        cancels the handler and writer tasks if they exist.
        """
        self._closing = True
        self._cancel_peak_checker()
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()

    @callback
    def _async_handle_hass_stop(self, event: Event) -> None:
        """Cancel this connection."""
        self._cancel()

    async def async_handle(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Prepares the socket (10 second timeout), sends the auth-required
        message, waits up to 10 seconds for the auth message, and on success
        starts the writer task and processes commands until the socket
        closes or an invalid frame is received. All cleanup happens in the
        ``finally`` blocks, after which the handler drops its references.

        Returns the ``WebSocketResponse`` for the request.
        ``asyncio.CancelledError`` is re-raised after cleanup.
        """
        request = self._request
        wsock = self._wsock
        logger = self._logger
        debug = logger.debug
        hass = self._hass
        is_enabled_for = logger.isEnabledFor
        logging_debug = logging.DEBUG

        try:
            async with asyncio.timeout(10):
                await wsock.prepare(request)
        except asyncio.TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        debug("%s: Connected from %s", self.description, request.remote)
        self._handle_task = asyncio.current_task()

        unsub_stop = hass.bus.async_listen(
            EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
        )

        writer = wsock._writer  # pylint: disable=protected-access
        if TYPE_CHECKING:
            assert writer is not None

        # During the auth phase nothing else writes to the socket, so
        # messages are sent directly instead of going through the queue.
        send_bytes_text = partial(writer.send, binary=False)
        auth = AuthPhase(
            logger, hass, self._send_message, self._cancel, request, send_bytes_text
        )
        connection = None
        disconnect_warn = None

        try:
            await send_bytes_text(AUTH_REQUIRED_MESSAGE)

            # Auth Phase
            try:
                msg = await wsock.receive(10)
            except asyncio.TimeoutError as err:
                disconnect_warn = "Did not receive auth message within 10 seconds"
                raise Disconnect from err

            if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
                raise Disconnect

            if msg.type != WSMsgType.TEXT:
                disconnect_warn = "Received non-Text message."
                raise Disconnect

            try:
                auth_msg_data = json_loads(msg.data)
            except ValueError as err:
                disconnect_warn = "Received invalid JSON."
                raise Disconnect from err

            if is_enabled_for(logging_debug):
                debug("%s: Received %s", self.description, auth_msg_data)
            connection = await auth.async_handle(auth_msg_data)
            # As the webserver is now started before the start
            # event we do not want to block for websocket responses
            #
            # We only start the writer queue after the auth phase is completed
            # since there is no need to queue messages before the auth phase
            self._connection = connection
            self._writer_task = asyncio.create_task(self._writer(send_bytes_text))
            hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
            async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)

            self._authenticated = True
            #
            #
            # Our websocket implementation is backed by a deque
            #
            # As back-pressure builds, the queue will back up and use more memory
            # until we disconnect the client when the queue size reaches
            # MAX_PENDING_MSG. When we are generating a high volume of websocket messages,
            # we hit a bottleneck in aiohttp where it will wait for
            # the buffer to drain before sending the next message and messages
            # start backing up in the queue.
            #
            # https://github.com/aio-libs/aiohttp/issues/1367 added drains
            # to the websocket writer to handle malicious clients and network issues.
            # The drain causes multiple problems for us since the buffer cannot be
            # drained fast enough when we deliver a high volume or large messages:
            #
            # - We end up disconnecting the client. The client will then reconnect,
            # and the cycle repeats itself, which results in a significant amount of
            # CPU usage.
            #
            # - Messages latency increases because messages cannot be moved into
            # the TCP buffer because it is blocked waiting for the drain to happen because
            # of the low default limit of 16KiB. By increasing the limit, we instead
            # rely on the underlying TCP buffer and stack to deliver the messages which
            # can typically happen much faster.
            #
            # After the auth phase is completed, and we are not concerned about
            # the user being a malicious client, we set the limit to force a drain
            # to 1MiB. 1MiB is the maximum expected size of the serialized entity
            # registry, which is the largest message we usually send.
            #
            # https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99
            # added a way to set the limit, but there is no way to actually
            # reach the code to set the limit, so we have to set it directly.
            #
            writer._limit = 2**20  # pylint: disable=protected-access
            async_handle_str = connection.async_handle
            async_handle_binary = connection.async_handle_binary

            # Command phase
            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
                    break

                if msg.type is WSMsgType.BINARY:
                    if len(msg.data) < 1:
                        disconnect_warn = "Received invalid binary message."
                        break
                    # The first byte selects the binary handler; the rest is
                    # passed to it as the payload.
                    handler = msg.data[0]
                    payload = msg.data[1:]
                    async_handle_binary(handler, payload)
                    continue

                if msg.type is not WSMsgType.TEXT:
                    disconnect_warn = "Received non-Text message."
                    break

                try:
                    command_msg_data = json_loads(msg.data)
                except ValueError:
                    disconnect_warn = "Received invalid JSON."
                    break

                if is_enabled_for(logging_debug):
                    debug("%s: Received %s", self.description, command_msg_data)

                # A frame carries either a single command or a JSON list of
                # commands; a list is dispatched one command at a time.
                if type(command_msg_data) is not list:  # noqa: E721
                    async_handle_str(command_msg_data)
                    continue

                for split_msg in command_msg_data:
                    async_handle_str(split_msg)

        except asyncio.CancelledError:
            debug("%s: Connection cancelled", self.description)
            raise

        except Disconnect as ex:
            debug("%s: Connection closed by client: %s", self.description, ex)

        except Exception:  # pylint: disable=broad-except
            self._logger.exception(
                "%s: Unexpected error inside websocket API", self.description
            )

        finally:
            unsub_stop()

            self._cancel_peak_checker()

            if connection is not None:
                connection.async_handle_close()

            self._closing = True

            # Queue the end-of-connection sentinel and wake the writer so it
            # can finish what is pending and exit.
            self._message_queue.append(None)
            if self._ready_future and not self._ready_future.done():
                self._ready_future.set_result(None)

            # If the writer gets canceled we still need to close the websocket
            # so we have another finally block to make sure we close the websocket
            # if the writer gets canceled.
            try:
                if self._writer_task:
                    await self._writer_task
            finally:
                try:
                    # Make sure all error messages are written before closing
                    await wsock.close()
                finally:
                    if disconnect_warn is None:
                        debug("%s: Disconnected", self.description)
                    else:
                        self._logger.warning(
                            "%s: Disconnected: %s", self.description, disconnect_warn
                        )

                    if connection is not None:
                        hass.data[DATA_CONNECTIONS] -= 1
                        self._connection = None

                    # NOTE(review): sent even when auth never completed, i.e.
                    # without a matching SIGNAL_WEBSOCKET_CONNECTED -- confirm
                    # this is intended.
                    async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)

                    # Break reference cycles to make sure GC can happen sooner
                    self._wsock = None  # type: ignore[assignment]
                    self._request = None  # type: ignore[assignment]
                    self._hass = None  # type: ignore[assignment]
                    self._logger = None  # type: ignore[assignment]
                    self._message_queue = None  # type: ignore[assignment]
                    self._handle_task = None
                    self._writer_task = None
                    self._ready_future = None

        return wsock
|