Improve websocket api coverage and typing (#94891)

This commit is contained in:
J. Nick Koston 2023-06-20 15:21:24 +01:00 committed by GitHub
parent b51dcb600e
commit 3f18f515e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 246 additions and 31 deletions

View file

@ -56,8 +56,7 @@ class WebSocketAdapter(logging.LoggerAdapter):
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
"""Add connid to websocket log messages."""
if not self.extra or "connid" not in self.extra:
return msg, kwargs
assert self.extra is not None
return f'[{self.extra["connid"]}] {msg}', kwargs
@ -81,7 +80,7 @@ class WebSocketHandler:
# to where messages are queued. This allows the implementation
# to use a deque and an asyncio.Future to avoid the overhead of
# an asyncio.Queue.
self._message_queue: deque = deque()
self._message_queue: deque[str | Callable[[], str] | None] = deque()
self._ready_future: asyncio.Future[None] | None = None
def __repr__(self) -> str:
@ -302,14 +301,14 @@ class WebSocketHandler:
raise Disconnect
try:
msg_data = msg.json(loads=json_loads)
auth_msg_data = json_loads(msg.data)
except ValueError as err:
disconnect_warn = "Received invalid JSON."
raise Disconnect from err
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
connection = await auth.async_handle(msg_data)
debug("%s: Received %s", self.description, auth_msg_data)
connection = await auth.async_handle(auth_msg_data)
self._connection = connection
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
@ -317,7 +316,7 @@ class WebSocketHandler:
self._authenticated = True
#
#
# Our websocket implementation is backed by an asyncio.Queue
# Our websocket implementation is backed by a deque
#
# As back-pressure builds, the queue will back up and use more memory
# until we disconnect the client when the queue size reaches
@ -351,6 +350,8 @@ class WebSocketHandler:
# reach the code to set the limit, so we have to set it directly.
#
wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access
async_handle_str = connection.async_handle
async_handle_binary = connection.async_handle_binary
# Command phase
while not wsock.closed:
@ -365,7 +366,7 @@ class WebSocketHandler:
break
handler = msg.data[0]
payload = msg.data[1:]
connection.async_handle_binary(handler, payload)
async_handle_binary(handler, payload)
continue
if msg.type != WSMsgType.TEXT:
@ -373,20 +374,20 @@ class WebSocketHandler:
break
try:
msg_data = msg.json(loads=json_loads)
command_msg_data = json_loads(msg.data)
except ValueError:
disconnect_warn = "Received invalid JSON."
break
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
debug("%s: Received %s", self.description, command_msg_data)
if not isinstance(msg_data, list):
connection.async_handle(msg_data)
if not isinstance(command_msg_data, list):
async_handle_str(command_msg_data)
continue
for split_msg in msg_data:
connection.async_handle(split_msg)
for split_msg in command_msg_data:
async_handle_str(split_msg)
except asyncio.CancelledError:
debug("%s: Connection cancelled", self.description)