Improve websocket api coverage and typing (#94891)
This commit is contained in:
parent
b51dcb600e
commit
3f18f515e7
5 changed files with 246 additions and 31 deletions
|
@ -12,6 +12,7 @@ from homeassistant.auth.models import RefreshToken, User
|
|||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from .connection import ActiveConnection
|
||||
from .error import Disconnect
|
||||
|
@ -67,10 +68,10 @@ class AuthPhase:
|
|||
self._logger = logger
|
||||
self._request = request
|
||||
|
||||
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
|
||||
async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
|
||||
"""Handle authentication."""
|
||||
try:
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
valid_msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
except vol.Invalid as err:
|
||||
error_msg = (
|
||||
f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
|
||||
|
@ -79,12 +80,11 @@ class AuthPhase:
|
|||
self._send_message(auth_invalid_message(error_msg))
|
||||
raise Disconnect from err
|
||||
|
||||
if "access_token" in msg:
|
||||
self._logger.debug("Received access_token")
|
||||
refresh_token = await self._hass.auth.async_validate_access_token(
|
||||
msg["access_token"]
|
||||
if (access_token := valid_msg.get("access_token")) and (
|
||||
refresh_token := await self._hass.auth.async_validate_access_token(
|
||||
access_token
|
||||
)
|
||||
if refresh_token is not None:
|
||||
):
|
||||
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
|
||||
conn.subscriptions[
|
||||
"auth"
|
||||
|
|
|
@ -13,6 +13,7 @@ from homeassistant.auth.models import RefreshToken, User
|
|||
from homeassistant.components.http import current_request
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from . import const, messages
|
||||
from .util import describe_request
|
||||
|
@ -144,7 +145,7 @@ class ActiveConnection:
|
|||
self.binary_handlers[index] = None
|
||||
|
||||
@callback
|
||||
def async_handle(self, msg: dict[str, Any]) -> None:
|
||||
def async_handle(self, msg: JsonValueType) -> None:
|
||||
"""Handle a single incoming message."""
|
||||
if (
|
||||
# Not using isinstance as we don't care about children
|
||||
|
@ -157,10 +158,11 @@ class ActiveConnection:
|
|||
or type(type_) is not str # pylint: disable=unidiomatic-typecheck
|
||||
)
|
||||
):
|
||||
self.logger.error("Received invalid command", msg)
|
||||
self.logger.error("Received invalid command: %s", msg)
|
||||
id_ = msg.get("id") if isinstance(msg, dict) else 0
|
||||
self.send_message(
|
||||
messages.error_message(
|
||||
msg.get("id"),
|
||||
id_, # type: ignore[arg-type]
|
||||
const.ERR_INVALID_FORMAT,
|
||||
"Message incorrectly formatted.",
|
||||
)
|
||||
|
|
|
@ -56,8 +56,7 @@ class WebSocketAdapter(logging.LoggerAdapter):
|
|||
|
||||
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
|
||||
"""Add connid to websocket log messages."""
|
||||
if not self.extra or "connid" not in self.extra:
|
||||
return msg, kwargs
|
||||
assert self.extra is not None
|
||||
return f'[{self.extra["connid"]}] {msg}', kwargs
|
||||
|
||||
|
||||
|
@ -81,7 +80,7 @@ class WebSocketHandler:
|
|||
# to where messages are queued. This allows the implementation
|
||||
# to use a deque and an asyncio.Future to avoid the overhead of
|
||||
# an asyncio.Queue.
|
||||
self._message_queue: deque = deque()
|
||||
self._message_queue: deque[str | Callable[[], str] | None] = deque()
|
||||
self._ready_future: asyncio.Future[None] | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -302,14 +301,14 @@ class WebSocketHandler:
|
|||
raise Disconnect
|
||||
|
||||
try:
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
auth_msg_data = json_loads(msg.data)
|
||||
except ValueError as err:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
raise Disconnect from err
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, msg_data)
|
||||
connection = await auth.async_handle(msg_data)
|
||||
debug("%s: Received %s", self.description, auth_msg_data)
|
||||
connection = await auth.async_handle(auth_msg_data)
|
||||
self._connection = connection
|
||||
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||
|
@ -317,7 +316,7 @@ class WebSocketHandler:
|
|||
self._authenticated = True
|
||||
#
|
||||
#
|
||||
# Our websocket implementation is backed by an asyncio.Queue
|
||||
# Our websocket implementation is backed by a deque
|
||||
#
|
||||
# As back-pressure builds, the queue will back up and use more memory
|
||||
# until we disconnect the client when the queue size reaches
|
||||
|
@ -351,6 +350,8 @@ class WebSocketHandler:
|
|||
# reach the code to set the limit, so we have to set it directly.
|
||||
#
|
||||
wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access
|
||||
async_handle_str = connection.async_handle
|
||||
async_handle_binary = connection.async_handle_binary
|
||||
|
||||
# Command phase
|
||||
while not wsock.closed:
|
||||
|
@ -365,7 +366,7 @@ class WebSocketHandler:
|
|||
break
|
||||
handler = msg.data[0]
|
||||
payload = msg.data[1:]
|
||||
connection.async_handle_binary(handler, payload)
|
||||
async_handle_binary(handler, payload)
|
||||
continue
|
||||
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
|
@ -373,20 +374,20 @@ class WebSocketHandler:
|
|||
break
|
||||
|
||||
try:
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
command_msg_data = json_loads(msg.data)
|
||||
except ValueError:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
break
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, msg_data)
|
||||
debug("%s: Received %s", self.description, command_msg_data)
|
||||
|
||||
if not isinstance(msg_data, list):
|
||||
connection.async_handle(msg_data)
|
||||
if not isinstance(command_msg_data, list):
|
||||
async_handle_str(command_msg_data)
|
||||
continue
|
||||
|
||||
for split_msg in msg_data:
|
||||
connection.async_handle(split_msg)
|
||||
for split_msg in command_msg_data:
|
||||
async_handle_str(split_msg)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
debug("%s: Connection cancelled", self.description)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import WSMsgType
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth.providers.legacy_api_password import (
|
||||
|
@ -223,3 +224,79 @@ async def test_auth_close_after_revoke(
|
|||
msg = await websocket_client.receive()
|
||||
assert msg.type == aiohttp.WSMsgType.CLOSE
|
||||
assert websocket_client.closed
|
||||
|
||||
|
||||
async def test_auth_sending_invalid_json_disconnects(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test sending invalid json during auth."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
async with client.ws_connect(URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_str("[--INVALID--JSON--]")
|
||||
|
||||
auth_msg = await ws.receive()
|
||||
assert auth_msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_auth_sending_binary_disconnects(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test sending bytes during auth."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
async with client.ws_connect(URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_bytes(b"[INVALID]")
|
||||
|
||||
auth_msg = await ws.receive()
|
||||
assert auth_msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_auth_close_disconnects(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test closing during auth."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
async with client.ws_connect(URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.close()
|
||||
|
||||
auth_msg = await ws.receive()
|
||||
assert auth_msg.type == WSMsgType.CLOSED
|
||||
|
||||
|
||||
async def test_auth_sending_unknown_type_disconnects(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test sending unknown type during auth."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
async with client.ws_connect(URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
# pylint: disable-next=protected-access
|
||||
await ws._writer._send_frame(b"1" * 130, 0x30)
|
||||
auth_msg = await ws.receive()
|
||||
assert auth_msg.type == WSMsgType.close
|
||||
|
|
|
@ -136,6 +136,89 @@ async def test_cleanup_on_cancellation(
|
|||
assert len(subscriptions) == 0
|
||||
|
||||
|
||||
async def test_delayed_response_handler(
|
||||
hass: HomeAssistant,
|
||||
websocket_client: MockHAClientWebSocket,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test a handler that responds after a connection has already been closed."""
|
||||
|
||||
subscriptions = None
|
||||
|
||||
# Register a handler that responds after it returns
|
||||
@callback
|
||||
@websocket_command(
|
||||
{
|
||||
"type": "late_responder",
|
||||
}
|
||||
)
|
||||
def async_late_responder(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
msg_id: int = msg["id"]
|
||||
nonlocal subscriptions
|
||||
subscriptions = connection.subscriptions
|
||||
connection.subscriptions[msg_id] = lambda: None
|
||||
connection.send_result(msg_id)
|
||||
|
||||
async def _async_late_send_message():
|
||||
await asyncio.sleep(0.05)
|
||||
connection.send_event(msg_id, {"event": "any"})
|
||||
|
||||
hass.async_create_task(_async_late_send_message())
|
||||
|
||||
async_register_command(hass, async_late_responder)
|
||||
|
||||
await websocket_client.send_json({"id": 1, "type": "ping"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == "pong"
|
||||
assert not subscriptions
|
||||
await websocket_client.send_json({"id": 2, "type": "late_responder"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 2
|
||||
assert msg["type"] == "result"
|
||||
assert len(subscriptions) == 2
|
||||
assert await websocket_client.close()
|
||||
await hass.async_block_till_done()
|
||||
assert len(subscriptions) == 0
|
||||
|
||||
assert "Tried to send message" in caplog.text
|
||||
assert "on closed connection" in caplog.text
|
||||
|
||||
|
||||
async def test_ensure_disconnect_invalid_json(
|
||||
hass: HomeAssistant,
|
||||
websocket_client: MockHAClientWebSocket,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test we get disconnected when sending invalid JSON."""
|
||||
|
||||
await websocket_client.send_json({"id": 1, "type": "ping"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == "pong"
|
||||
await websocket_client.send_str("[--INVALID-JSON--]")
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.CLOSE
|
||||
|
||||
|
||||
async def test_ensure_disconnect_invalid_binary(
|
||||
hass: HomeAssistant,
|
||||
websocket_client: MockHAClientWebSocket,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test we get disconnected when sending invalid bytes."""
|
||||
|
||||
await websocket_client.send_json({"id": 1, "type": "ping"})
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == "pong"
|
||||
await websocket_client.send_bytes(b"")
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.CLOSE
|
||||
|
||||
|
||||
async def test_pending_msg_peak(
|
||||
hass: HomeAssistant,
|
||||
mock_low_peak,
|
||||
|
@ -299,6 +382,58 @@ async def test_prepare_fail(
|
|||
assert "Timeout preparing request" in caplog.text
|
||||
|
||||
|
||||
async def test_enable_coalesce(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test enabling coalesce."""
|
||||
websocket_client = await hass_ws_client(hass)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 1,
|
||||
"type": "supported_features",
|
||||
"features": {const.FEATURE_COALESCE_MESSAGES: 1},
|
||||
}
|
||||
)
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 1
|
||||
assert msg["success"] is True
|
||||
send_tasks: list[asyncio.Future] = []
|
||||
ids: set[int] = set()
|
||||
start_id = 2
|
||||
|
||||
for idx in range(10):
|
||||
id_ = idx + start_id
|
||||
ids.add(id_)
|
||||
send_tasks.append(websocket_client.send_json({"id": id_, "type": "ping"}))
|
||||
|
||||
await asyncio.gather(*send_tasks)
|
||||
returned_ids: set[int] = set()
|
||||
for _ in range(10):
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["type"] == "pong"
|
||||
returned_ids.add(msg["id"])
|
||||
|
||||
assert ids == returned_ids
|
||||
|
||||
# Now close
|
||||
send_tasks_with_close: list[asyncio.Future] = []
|
||||
start_id = 12
|
||||
for idx in range(10):
|
||||
id_ = idx + start_id
|
||||
send_tasks_with_close.append(
|
||||
websocket_client.send_json({"id": id_, "type": "ping"})
|
||||
)
|
||||
|
||||
send_tasks_with_close.append(websocket_client.close())
|
||||
send_tasks_with_close.append(websocket_client.send_json({"id": 50, "type": "ping"}))
|
||||
|
||||
with pytest.raises(ConnectionResetError):
|
||||
await asyncio.gather(*send_tasks_with_close)
|
||||
|
||||
|
||||
async def test_binary_message(
|
||||
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue