Improve websocket api coverage and typing (#94891)

This commit is contained in:
J. Nick Koston 2023-06-20 15:21:24 +01:00 committed by GitHub
parent b51dcb600e
commit 3f18f515e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 246 additions and 31 deletions

View file

@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.util.json import JsonValueType
from .connection import ActiveConnection
from .error import Disconnect
@ -67,10 +68,10 @@ class AuthPhase:
self._logger = logger
self._request = request
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
"""Handle authentication."""
try:
msg = AUTH_MESSAGE_SCHEMA(msg)
valid_msg = AUTH_MESSAGE_SCHEMA(msg)
except vol.Invalid as err:
error_msg = (
f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
@ -79,20 +80,19 @@ class AuthPhase:
self._send_message(auth_invalid_message(error_msg))
raise Disconnect from err
if "access_token" in msg:
self._logger.debug("Received access_token")
refresh_token = await self._hass.auth.async_validate_access_token(
msg["access_token"]
if (access_token := valid_msg.get("access_token")) and (
refresh_token := await self._hass.auth.async_validate_access_token(
access_token
)
):
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
conn.subscriptions[
"auth"
] = self._hass.auth.async_register_revoke_token_callback(
refresh_token.id, self._cancel_ws
)
if refresh_token is not None:
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
conn.subscriptions[
"auth"
] = self._hass.auth.async_register_revoke_token_callback(
refresh_token.id, self._cancel_ws
)
return conn
return conn
self._send_message(auth_invalid_message("Invalid access token or password"))
await process_wrong_login(self._request)