Improve websocket api coverage and typing (#94891)

This commit is contained in:
J. Nick Koston 2023-06-20 15:21:24 +01:00 committed by GitHub
parent b51dcb600e
commit 3f18f515e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 246 additions and 31 deletions

View file

@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.util.json import JsonValueType
from .connection import ActiveConnection
from .error import Disconnect
@ -67,10 +68,10 @@ class AuthPhase:
self._logger = logger
self._request = request
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
"""Handle authentication."""
try:
msg = AUTH_MESSAGE_SCHEMA(msg)
valid_msg = AUTH_MESSAGE_SCHEMA(msg)
except vol.Invalid as err:
error_msg = (
f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
@ -79,12 +80,11 @@ class AuthPhase:
self._send_message(auth_invalid_message(error_msg))
raise Disconnect from err
if "access_token" in msg:
self._logger.debug("Received access_token")
refresh_token = await self._hass.auth.async_validate_access_token(
msg["access_token"]
if (access_token := valid_msg.get("access_token")) and (
refresh_token := await self._hass.auth.async_validate_access_token(
access_token
)
if refresh_token is not None:
):
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
conn.subscriptions[
"auth"

View file

@ -13,6 +13,7 @@ from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http import current_request
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.util.json import JsonValueType
from . import const, messages
from .util import describe_request
@ -144,7 +145,7 @@ class ActiveConnection:
self.binary_handlers[index] = None
@callback
def async_handle(self, msg: dict[str, Any]) -> None:
def async_handle(self, msg: JsonValueType) -> None:
"""Handle a single incoming message."""
if (
# Not using isinstance as we don't care about children
@ -157,10 +158,11 @@ class ActiveConnection:
or type(type_) is not str # pylint: disable=unidiomatic-typecheck
)
):
self.logger.error("Received invalid command", msg)
self.logger.error("Received invalid command: %s", msg)
id_ = msg.get("id") if isinstance(msg, dict) else 0
self.send_message(
messages.error_message(
msg.get("id"),
id_, # type: ignore[arg-type]
const.ERR_INVALID_FORMAT,
"Message incorrectly formatted.",
)

View file

@ -56,8 +56,7 @@ class WebSocketAdapter(logging.LoggerAdapter):
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
"""Add connid to websocket log messages."""
if not self.extra or "connid" not in self.extra:
return msg, kwargs
assert self.extra is not None
return f'[{self.extra["connid"]}] {msg}', kwargs
@ -81,7 +80,7 @@ class WebSocketHandler:
# to where messages are queued. This allows the implementation
# to use a deque and an asyncio.Future to avoid the overhead of
# an asyncio.Queue.
self._message_queue: deque = deque()
self._message_queue: deque[str | Callable[[], str] | None] = deque()
self._ready_future: asyncio.Future[None] | None = None
def __repr__(self) -> str:
@ -302,14 +301,14 @@ class WebSocketHandler:
raise Disconnect
try:
msg_data = msg.json(loads=json_loads)
auth_msg_data = json_loads(msg.data)
except ValueError as err:
disconnect_warn = "Received invalid JSON."
raise Disconnect from err
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
connection = await auth.async_handle(msg_data)
debug("%s: Received %s", self.description, auth_msg_data)
connection = await auth.async_handle(auth_msg_data)
self._connection = connection
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
@ -317,7 +316,7 @@ class WebSocketHandler:
self._authenticated = True
#
#
# Our websocket implementation is backed by an asyncio.Queue
# Our websocket implementation is backed by a deque
#
# As back-pressure builds, the queue will back up and use more memory
# until we disconnect the client when the queue size reaches
@ -351,6 +350,8 @@ class WebSocketHandler:
# reach the code to set the limit, so we have to set it directly.
#
wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access
async_handle_str = connection.async_handle
async_handle_binary = connection.async_handle_binary
# Command phase
while not wsock.closed:
@ -365,7 +366,7 @@ class WebSocketHandler:
break
handler = msg.data[0]
payload = msg.data[1:]
connection.async_handle_binary(handler, payload)
async_handle_binary(handler, payload)
continue
if msg.type != WSMsgType.TEXT:
@ -373,20 +374,20 @@ class WebSocketHandler:
break
try:
msg_data = msg.json(loads=json_loads)
command_msg_data = json_loads(msg.data)
except ValueError:
disconnect_warn = "Received invalid JSON."
break
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, msg_data)
debug("%s: Received %s", self.description, command_msg_data)
if not isinstance(msg_data, list):
connection.async_handle(msg_data)
if not isinstance(command_msg_data, list):
async_handle_str(command_msg_data)
continue
for split_msg in msg_data:
connection.async_handle(split_msg)
for split_msg in command_msg_data:
async_handle_str(split_msg)
except asyncio.CancelledError:
debug("%s: Connection cancelled", self.description)

View file

@ -2,6 +2,7 @@
from unittest.mock import patch
import aiohttp
from aiohttp import WSMsgType
import pytest
from homeassistant.auth.providers.legacy_api_password import (
@ -223,3 +224,79 @@ async def test_auth_close_after_revoke(
msg = await websocket_client.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE
assert websocket_client.closed
async def test_auth_sending_invalid_json_disconnects(
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
) -> None:
"""Test sending invalid json during auth."""
assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done()
client = await hass_client_no_auth()
async with client.ws_connect(URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
await ws.send_str("[--INVALID--JSON--]")
auth_msg = await ws.receive()
assert auth_msg.type == WSMsgType.close
async def test_auth_sending_binary_disconnects(
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
) -> None:
"""Test sending bytes during auth."""
assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done()
client = await hass_client_no_auth()
async with client.ws_connect(URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
await ws.send_bytes(b"[INVALID]")
auth_msg = await ws.receive()
assert auth_msg.type == WSMsgType.close
async def test_auth_close_disconnects(
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
) -> None:
"""Test closing during auth."""
assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done()
client = await hass_client_no_auth()
async with client.ws_connect(URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
await ws.close()
auth_msg = await ws.receive()
assert auth_msg.type == WSMsgType.CLOSED
async def test_auth_sending_unknown_type_disconnects(
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
) -> None:
"""Test sending unknown type during auth."""
assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done()
client = await hass_client_no_auth()
async with client.ws_connect(URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
# pylint: disable-next=protected-access
await ws._writer._send_frame(b"1" * 130, 0x30)
auth_msg = await ws.receive()
assert auth_msg.type == WSMsgType.close

View file

@ -136,6 +136,89 @@ async def test_cleanup_on_cancellation(
assert len(subscriptions) == 0
async def test_delayed_response_handler(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test a handler that responds after a connection has already been closed."""
subscriptions = None
# Register a handler that responds after it returns
@callback
@websocket_command(
{
"type": "late_responder",
}
)
def async_late_responder(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
msg_id: int = msg["id"]
nonlocal subscriptions
subscriptions = connection.subscriptions
connection.subscriptions[msg_id] = lambda: None
connection.send_result(msg_id)
async def _async_late_send_message():
await asyncio.sleep(0.05)
connection.send_event(msg_id, {"event": "any"})
hass.async_create_task(_async_late_send_message())
async_register_command(hass, async_late_responder)
await websocket_client.send_json({"id": 1, "type": "ping"})
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "pong"
assert not subscriptions
await websocket_client.send_json({"id": 2, "type": "late_responder"})
msg = await websocket_client.receive_json()
assert msg["id"] == 2
assert msg["type"] == "result"
assert len(subscriptions) == 2
assert await websocket_client.close()
await hass.async_block_till_done()
assert len(subscriptions) == 0
assert "Tried to send message" in caplog.text
assert "on closed connection" in caplog.text
async def test_ensure_disconnect_invalid_json(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test we get disconnected when sending invalid JSON."""
await websocket_client.send_json({"id": 1, "type": "ping"})
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "pong"
await websocket_client.send_str("[--INVALID-JSON--]")
msg = await websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
async def test_ensure_disconnect_invalid_binary(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test we get disconnected when sending invalid bytes."""
await websocket_client.send_json({"id": 1, "type": "ping"})
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "pong"
await websocket_client.send_bytes(b"")
msg = await websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
async def test_pending_msg_peak(
hass: HomeAssistant,
mock_low_peak,
@ -299,6 +382,58 @@ async def test_prepare_fail(
assert "Timeout preparing request" in caplog.text
async def test_enable_coalesce(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test enabling coalesce."""
websocket_client = await hass_ws_client(hass)
await websocket_client.send_json(
{
"id": 1,
"type": "supported_features",
"features": {const.FEATURE_COALESCE_MESSAGES: 1},
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 1
assert msg["success"] is True
send_tasks: list[asyncio.Future] = []
ids: set[int] = set()
start_id = 2
for idx in range(10):
id_ = idx + start_id
ids.add(id_)
send_tasks.append(websocket_client.send_json({"id": id_, "type": "ping"}))
await asyncio.gather(*send_tasks)
returned_ids: set[int] = set()
for _ in range(10):
msg = await websocket_client.receive_json()
assert msg["type"] == "pong"
returned_ids.add(msg["id"])
assert ids == returned_ids
# Now close
send_tasks_with_close: list[asyncio.Future] = []
start_id = 12
for idx in range(10):
id_ = idx + start_id
send_tasks_with_close.append(
websocket_client.send_json({"id": id_, "type": "ping"})
)
send_tasks_with_close.append(websocket_client.close())
send_tasks_with_close.append(websocket_client.send_json({"id": 50, "type": "ping"}))
with pytest.raises(ConnectionResetError):
await asyncio.gather(*send_tasks_with_close)
async def test_binary_message(
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
) -> None: