Small cleanups to the websocket api handler (#108274)
This commit is contained in:
parent
c656024365
commit
b4b041d4bf
2 changed files with 20 additions and 21 deletions
|
@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
|
|||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from .connection import ActiveConnection
|
||||
|
@ -34,15 +35,10 @@ AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
|
||||
def auth_ok_message() -> dict[str, str]:
|
||||
"""Return an auth_ok message."""
|
||||
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
|
||||
|
||||
|
||||
def auth_required_message() -> dict[str, str]:
|
||||
"""Return an auth_required message."""
|
||||
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
||||
AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__})
|
||||
AUTH_REQUIRED_MESSAGE = json_bytes(
|
||||
{"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
||||
)
|
||||
|
||||
|
||||
def auth_invalid_message(message: str) -> dict[str, str]:
|
||||
|
@ -104,7 +100,7 @@ class AuthPhase:
|
|||
"""Create an active connection."""
|
||||
self._logger.debug("Auth OK")
|
||||
process_success_login(self._request)
|
||||
self._send_message(auth_ok_message())
|
||||
self._send_message(AUTH_OK_MESSAGE)
|
||||
return ActiveConnection(
|
||||
self._logger, self._hass, self._send_message, user, refresh_token
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
|
||||
from .const import (
|
||||
DATA_CONNECTIONS,
|
||||
MAX_PENDING_MSG,
|
||||
|
@ -266,6 +266,11 @@ class WebSocketHandler:
|
|||
if self._writer_task is not None:
|
||||
self._writer_task.cancel()
|
||||
|
||||
@callback
|
||||
def _async_handle_hass_stop(self, event: Event) -> None:
|
||||
"""Cancel this connection."""
|
||||
self._cancel()
|
||||
|
||||
async def async_handle(self) -> web.WebSocketResponse:
|
||||
"""Handle a websocket response."""
|
||||
request = self._request
|
||||
|
@ -286,12 +291,9 @@ class WebSocketHandler:
|
|||
debug("%s: Connected from %s", self.description, request.remote)
|
||||
self._handle_task = asyncio.current_task()
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event: Event) -> None:
|
||||
"""Cancel this connection."""
|
||||
self._cancel()
|
||||
|
||||
unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||
unsub_stop = hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
|
||||
)
|
||||
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responses
|
||||
|
@ -302,7 +304,7 @@ class WebSocketHandler:
|
|||
disconnect_warn = None
|
||||
|
||||
try:
|
||||
self._send_message(auth_required_message())
|
||||
self._send_message(AUTH_REQUIRED_MESSAGE)
|
||||
|
||||
# Auth Phase
|
||||
try:
|
||||
|
@ -379,7 +381,7 @@ class WebSocketHandler:
|
|||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
break
|
||||
|
||||
if msg.type == WSMsgType.BINARY:
|
||||
if msg.type is WSMsgType.BINARY:
|
||||
if len(msg.data) < 1:
|
||||
disconnect_warn = "Received invalid binary message."
|
||||
break
|
||||
|
@ -388,7 +390,7 @@ class WebSocketHandler:
|
|||
async_handle_binary(handler, payload)
|
||||
continue
|
||||
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
disconnect_warn = "Received non-Text message."
|
||||
break
|
||||
|
||||
|
@ -401,7 +403,8 @@ class WebSocketHandler:
|
|||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, command_msg_data)
|
||||
|
||||
if not isinstance(command_msg_data, list):
|
||||
# command_msg_data is always deserialized from JSON as a list
|
||||
if type(command_msg_data) is not list: # noqa: E721
|
||||
async_handle_str(command_msg_data)
|
||||
continue
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue