* Significantly reduce websocket api connection auth phase latancy Since the auth phase has exclusive control over the websocket until ActiveConnection is created, we can bypass the queue and send messages right away. This reduces the latancy and reconnect time since we do not have to wait for the background processing of the queue to send the auth ok message. * only start the writer queue after auth is successful
108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
"""Handle the auth of a connection."""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Coroutine
|
|
from typing import TYPE_CHECKING, Any, Final
|
|
|
|
from aiohttp.web import Request
|
|
import voluptuous as vol
|
|
from voluptuous.humanize import humanize_error
|
|
|
|
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
|
from homeassistant.const import __version__
|
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
|
from homeassistant.helpers.json import json_bytes
|
|
from homeassistant.util.json import JsonValueType
|
|
|
|
from .connection import ActiveConnection
|
|
from .error import Disconnect
|
|
|
|
if TYPE_CHECKING:
|
|
from .http import WebSocketAdapter
|
|
|
|
|
|
TYPE_AUTH: Final = "auth"
|
|
TYPE_AUTH_INVALID: Final = "auth_invalid"
|
|
TYPE_AUTH_OK: Final = "auth_ok"
|
|
TYPE_AUTH_REQUIRED: Final = "auth_required"
|
|
|
|
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
|
|
{
|
|
vol.Required("type"): TYPE_AUTH,
|
|
vol.Exclusive("api_password", "auth"): str,
|
|
vol.Exclusive("access_token", "auth"): str,
|
|
}
|
|
)
|
|
|
|
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
|
|
AUTH_REQUIRED_MESSAGE = json_bytes(
|
|
{"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
|
)
|
|
|
|
|
|
def auth_invalid_message(message: str) -> bytes:
|
|
"""Return an auth_invalid message."""
|
|
return json_bytes({"type": TYPE_AUTH_INVALID, "message": message})
|
|
|
|
|
|
class AuthPhase:
|
|
"""Connection that requires client to authenticate first."""
|
|
|
|
def __init__(
|
|
self,
|
|
logger: WebSocketAdapter,
|
|
hass: HomeAssistant,
|
|
send_message: Callable[[bytes | str | dict[str, Any]], None],
|
|
cancel_ws: CALLBACK_TYPE,
|
|
request: Request,
|
|
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
|
) -> None:
|
|
"""Initialize the authenticated connection."""
|
|
self._hass = hass
|
|
# send_message will send a message to the client via the queue.
|
|
self._send_message = send_message
|
|
self._cancel_ws = cancel_ws
|
|
self._logger = logger
|
|
self._request = request
|
|
# send_bytes_text will directly send a message to the client.
|
|
self._send_bytes_text = send_bytes_text
|
|
|
|
async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
|
|
"""Handle authentication."""
|
|
try:
|
|
valid_msg = AUTH_MESSAGE_SCHEMA(msg)
|
|
except vol.Invalid as err:
|
|
error_msg = (
|
|
f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
|
|
)
|
|
self._logger.warning(error_msg)
|
|
await self._send_bytes_text(auth_invalid_message(error_msg))
|
|
raise Disconnect from err
|
|
|
|
if (access_token := valid_msg.get("access_token")) and (
|
|
refresh_token := await self._hass.auth.async_validate_access_token(
|
|
access_token
|
|
)
|
|
):
|
|
conn = ActiveConnection(
|
|
self._logger,
|
|
self._hass,
|
|
self._send_message,
|
|
refresh_token.user,
|
|
refresh_token,
|
|
)
|
|
conn.subscriptions[
|
|
"auth"
|
|
] = self._hass.auth.async_register_revoke_token_callback(
|
|
refresh_token.id, self._cancel_ws
|
|
)
|
|
await self._send_bytes_text(AUTH_OK_MESSAGE)
|
|
self._logger.debug("Auth OK")
|
|
process_success_login(self._request)
|
|
return conn
|
|
|
|
await self._send_bytes_text(
|
|
auth_invalid_message("Invalid access token or password")
|
|
)
|
|
await process_wrong_login(self._request)
|
|
raise Disconnect
|