diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index f09c2601328..0a681692c3d 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.const import __version__ from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.helpers.json import json_bytes from homeassistant.util.json import JsonValueType from .connection import ActiveConnection @@ -34,15 +35,10 @@ AUTH_MESSAGE_SCHEMA: Final = vol.Schema( } ) - -def auth_ok_message() -> dict[str, str]: - """Return an auth_ok message.""" - return {"type": TYPE_AUTH_OK, "ha_version": __version__} - - -def auth_required_message() -> dict[str, str]: - """Return an auth_required message.""" - return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__} +AUTH_OK_MESSAGE = json_bytes({"type": TYPE_AUTH_OK, "ha_version": __version__}) +AUTH_REQUIRED_MESSAGE = json_bytes( + {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__} +) def auth_invalid_message(message: str) -> dict[str, str]: @@ -104,7 +100,7 @@ class AuthPhase: """Create an active connection.""" self._logger.debug("Auth OK") process_success_login(self._request) - self._send_message(auth_ok_message()) + self._send_message(AUTH_OK_MESSAGE) return ActiveConnection( self._logger, self._hass, self._send_message, user, refresh_token ) diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index c8c5d00cb2a..d966e4e26ef 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -18,7 +18,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_call_later from homeassistant.util.json import json_loads -from .auth import AuthPhase, auth_required_message +from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase from .const import ( DATA_CONNECTIONS, MAX_PENDING_MSG, @@ -266,6 +266,11 @@ class WebSocketHandler: if self._writer_task is not None: self._writer_task.cancel() + @callback + def _async_handle_hass_stop(self, event: Event) -> None: + """Cancel this connection.""" + self._cancel() + async def async_handle(self) -> web.WebSocketResponse: """Handle a websocket response.""" request = self._request @@ -286,12 +291,9 @@ class WebSocketHandler: debug("%s: Connected from %s", self.description, request.remote) self._handle_task = asyncio.current_task() - @callback - def handle_hass_stop(event: Event) -> None: - """Cancel this connection.""" - self._cancel() - - unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop) + unsub_stop = hass.bus.async_listen( + EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop + ) # As the webserver is now started before the start # event we do not want to block for websocket responses @@ -302,7 +304,7 @@ class WebSocketHandler: disconnect_warn = None try: - self._send_message(auth_required_message()) + self._send_message(AUTH_REQUIRED_MESSAGE) # Auth Phase try: @@ -379,7 +381,7 @@ class WebSocketHandler: if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): break - if msg.type == WSMsgType.BINARY: + if msg.type is WSMsgType.BINARY: if len(msg.data) < 1: disconnect_warn = "Received invalid binary message." break @@ -388,7 +390,7 @@ class WebSocketHandler: async_handle_binary(handler, payload) continue - if msg.type != WSMsgType.TEXT: + if msg.type is not WSMsgType.TEXT: disconnect_warn = "Received non-Text message." break @@ -401,7 +403,8 @@ class WebSocketHandler: if is_enabled_for(logging_debug): debug("%s: Received %s", self.description, command_msg_data) - if not isinstance(command_msg_data, list): + # command_msg_data is always deserialized from JSON as a list + if type(command_msg_data) is not list: # noqa: E721 async_handle_str(command_msg_data) continue