Significantly reduce websocket api connection auth phase latency (#108564)

* Significantly reduce websocket api connection auth phase latancy

Since the auth phase has exclusive control over the websocket
until ActiveConnection is created, we can bypass the queue and
send messages right away. This reduces the latancy and reconnect
time since we do not have to wait for the background processing
of the queue to send the auth ok message.

* only start the writer queue after auth is successful
This commit is contained in:
J. Nick Koston 2024-01-21 17:33:31 -10:00 committed by GitHub
parent da1d530889
commit dbb5645e63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 38 deletions

View file

@ -1,14 +1,13 @@
"""Handle the auth of a connection.""" """Handle the auth of a connection."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
from aiohttp.web import Request from aiohttp.web import Request
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
@ -41,9 +40,9 @@ AUTH_REQUIRED_MESSAGE = json_bytes(
) )
def auth_invalid_message(message: str) -> dict[str, str]: def auth_invalid_message(message: str) -> bytes:
"""Return an auth_invalid message.""" """Return an auth_invalid message."""
return {"type": TYPE_AUTH_INVALID, "message": message} return json_bytes({"type": TYPE_AUTH_INVALID, "message": message})
class AuthPhase: class AuthPhase:
@ -56,13 +55,17 @@ class AuthPhase:
send_message: Callable[[bytes | str | dict[str, Any]], None], send_message: Callable[[bytes | str | dict[str, Any]], None],
cancel_ws: CALLBACK_TYPE, cancel_ws: CALLBACK_TYPE,
request: Request, request: Request,
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
) -> None: ) -> None:
"""Initialize the authentiated connection.""" """Initialize the authenticated connection."""
self._hass = hass self._hass = hass
# send_message will send a message to the client via the queue.
self._send_message = send_message self._send_message = send_message
self._cancel_ws = cancel_ws self._cancel_ws = cancel_ws
self._logger = logger self._logger = logger
self._request = request self._request = request
# send_bytes_text will directly send a message to the client.
self._send_bytes_text = send_bytes_text
async def async_handle(self, msg: JsonValueType) -> ActiveConnection: async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
"""Handle authentication.""" """Handle authentication."""
@ -73,7 +76,7 @@ class AuthPhase:
f"Auth message incorrectly formatted: {humanize_error(msg, err)}" f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
) )
self._logger.warning(error_msg) self._logger.warning(error_msg)
self._send_message(auth_invalid_message(error_msg)) await self._send_bytes_text(auth_invalid_message(error_msg))
raise Disconnect from err raise Disconnect from err
if (access_token := valid_msg.get("access_token")) and ( if (access_token := valid_msg.get("access_token")) and (
@ -81,26 +84,25 @@ class AuthPhase:
access_token access_token
) )
): ):
conn = await self._async_finish_auth(refresh_token.user, refresh_token) conn = ActiveConnection(
self._logger,
self._hass,
self._send_message,
refresh_token.user,
refresh_token,
)
conn.subscriptions[ conn.subscriptions[
"auth" "auth"
] = self._hass.auth.async_register_revoke_token_callback( ] = self._hass.auth.async_register_revoke_token_callback(
refresh_token.id, self._cancel_ws refresh_token.id, self._cancel_ws
) )
await self._send_bytes_text(AUTH_OK_MESSAGE)
return conn
self._send_message(auth_invalid_message("Invalid access token or password"))
await process_wrong_login(self._request)
raise Disconnect
async def _async_finish_auth(
self, user: User, refresh_token: RefreshToken
) -> ActiveConnection:
"""Create an active connection."""
self._logger.debug("Auth OK") self._logger.debug("Auth OK")
process_success_login(self._request) process_success_login(self._request)
self._send_message(AUTH_OK_MESSAGE) return conn
return ActiveConnection(
self._logger, self._hass, self._send_message, user, refresh_token await self._send_bytes_text(
auth_invalid_message("Invalid access token or password")
) )
await process_wrong_login(self._request)
raise Disconnect

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable from collections.abc import Callable, Coroutine
import datetime as dt import datetime as dt
from functools import partial from functools import partial
import logging import logging
@ -116,16 +116,14 @@ class WebSocketHandler:
return describe_request(request) return describe_request(request)
return "finished connection" return "finished connection"
async def _writer(self) -> None: async def _writer(
self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]]
) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Variables are set locally to avoid lookups in the loop # Variables are set locally to avoid lookups in the loop
message_queue = self._message_queue message_queue = self._message_queue
logger = self._logger logger = self._logger
wsock = self._wsock wsock = self._wsock
writer = wsock._writer # pylint: disable=protected-access
if TYPE_CHECKING:
assert writer is not None
send_str = partial(writer.send, binary=False)
loop = self._hass.loop loop = self._hass.loop
debug = logger.debug debug = logger.debug
is_enabled_for = logger.isEnabledFor is_enabled_for = logger.isEnabledFor
@ -152,7 +150,7 @@ class WebSocketHandler:
): ):
if debug_enabled: if debug_enabled:
debug("%s: Sending %s", self.description, message) debug("%s: Sending %s", self.description, message)
await send_str(message) await send_bytes_text(message)
continue continue
messages: list[bytes] = [message] messages: list[bytes] = [message]
@ -166,7 +164,7 @@ class WebSocketHandler:
coalesced_messages = b"".join((b"[", b",".join(messages), b"]")) coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
if debug_enabled: if debug_enabled:
debug("%s: Sending %s", self.description, coalesced_messages) debug("%s: Sending %s", self.description, coalesced_messages)
await send_str(coalesced_messages) await send_bytes_text(coalesced_messages)
except asyncio.CancelledError: except asyncio.CancelledError:
debug("%s: Writer cancelled", self.description) debug("%s: Writer cancelled", self.description)
raise raise
@ -186,7 +184,7 @@ class WebSocketHandler:
@callback @callback
def _send_message(self, message: str | bytes | dict[str, Any]) -> None: def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
"""Send a message to the client. """Queue sending a message to the client.
Closes connection if the client is not reading the messages. Closes connection if the client is not reading the messages.
@ -295,21 +293,23 @@ class WebSocketHandler:
EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
) )
# As the webserver is now started before the start writer = wsock._writer # pylint: disable=protected-access
# event we do not want to block for websocket responses if TYPE_CHECKING:
self._writer_task = asyncio.create_task(self._writer()) assert writer is not None
auth = AuthPhase(logger, hass, self._send_message, self._cancel, request) send_bytes_text = partial(writer.send, binary=False)
auth = AuthPhase(
logger, hass, self._send_message, self._cancel, request, send_bytes_text
)
connection = None connection = None
disconnect_warn = None disconnect_warn = None
try: try:
self._send_message(AUTH_REQUIRED_MESSAGE) await send_bytes_text(AUTH_REQUIRED_MESSAGE)
# Auth Phase # Auth Phase
try: try:
async with asyncio.timeout(10): msg = await wsock.receive(10)
msg = await wsock.receive()
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
disconnect_warn = "Did not receive auth message within 10 seconds" disconnect_warn = "Did not receive auth message within 10 seconds"
raise Disconnect from err raise Disconnect from err
@ -330,7 +330,13 @@ class WebSocketHandler:
if is_enabled_for(logging_debug): if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, auth_msg_data) debug("%s: Received %s", self.description, auth_msg_data)
connection = await auth.async_handle(auth_msg_data) connection = await auth.async_handle(auth_msg_data)
# As the webserver is now started before the start
# event we do not want to block for websocket responses
#
# We only start the writer queue after the auth phase is completed
# since there is no need to queue messages before the auth phase
self._connection = connection self._connection = connection
self._writer_task = asyncio.create_task(self._writer(send_bytes_text))
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
@ -370,7 +376,7 @@ class WebSocketHandler:
# added a way to set the limit, but there is no way to actually # added a way to set the limit, but there is no way to actually
# reach the code to set the limit, so we have to set it directly. # reach the code to set the limit, so we have to set it directly.
# #
wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access writer._limit = 2**20 # pylint: disable=protected-access
async_handle_str = connection.async_handle async_handle_str = connection.async_handle
async_handle_binary = connection.async_handle_binary async_handle_binary = connection.async_handle_binary
@ -441,6 +447,7 @@ class WebSocketHandler:
# so we have another finally block to make sure we close the websocket # so we have another finally block to make sure we close the websocket
# if the writer gets canceled. # if the writer gets canceled.
try: try:
if self._writer_task:
await self._writer_task await self._writer_task
finally: finally:
try: try: