Allow WS queue to temporarily peak (#34175)
* Allow WS queue to temporarily peak * Remove unused code
This commit is contained in:
parent
dbcc294d67
commit
bea354b82a
6 changed files with 116 additions and 34 deletions
|
@ -16,7 +16,9 @@ WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], No
|
|||
|
||||
DOMAIN = "websocket_api"
|
||||
URL = "/api/websocket"
|
||||
MAX_PENDING_MSG = 512
|
||||
PENDING_MSG_PEAK = 512
|
||||
PENDING_MSG_PEAK_TIME = 5
|
||||
MAX_PENDING_MSG = 2048
|
||||
|
||||
ERR_ID_REUSE = "id_reuse"
|
||||
ERR_INVALID_FORMAT = "invalid_format"
|
||||
|
|
|
@ -10,6 +10,7 @@ import async_timeout
|
|||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .const import (
|
||||
|
@ -18,6 +19,8 @@ from .const import (
|
|||
ERR_UNKNOWN_ERROR,
|
||||
JSON_DUMP,
|
||||
MAX_PENDING_MSG,
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
SIGNAL_WEBSOCKET_CONNECTED,
|
||||
SIGNAL_WEBSOCKET_DISCONNECTED,
|
||||
URL,
|
||||
|
@ -52,6 +55,7 @@ class WebSocketHandler:
|
|||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self)))
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
async def _writer(self):
|
||||
"""Write outgoing messages."""
|
||||
|
@ -83,6 +87,11 @@ class WebSocketHandler:
|
|||
|
||||
await self.wsock.send_str(dumped)
|
||||
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
@callback
|
||||
def _send_message(self, message):
|
||||
"""Send a message to the client.
|
||||
|
@ -97,8 +106,35 @@ class WebSocketHandler:
|
|||
self._logger.error(
|
||||
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
|
||||
)
|
||||
|
||||
self._cancel()
|
||||
|
||||
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||
if self._peak_checker_unsub:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
return
|
||||
|
||||
if self._peak_checker_unsub is None:
|
||||
self._peak_checker_unsub = async_call_later(
|
||||
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||
)
|
||||
|
||||
@callback
|
||||
def _check_write_peak(self, _):
|
||||
"""Check that we are no longer above the write peak."""
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||
return
|
||||
|
||||
self._logger.error(
|
||||
"Client unable to keep up with pending messages. Stayed over %s for %s seconds",
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
)
|
||||
self._cancel()
|
||||
|
||||
@callback
|
||||
def _cancel(self):
|
||||
"""Cancel the connection."""
|
||||
|
@ -111,13 +147,7 @@ class WebSocketHandler:
|
|||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
await wsock.prepare(request)
|
||||
self._logger.debug("Connected")
|
||||
|
||||
# Py3.7+
|
||||
if hasattr(asyncio, "current_task"):
|
||||
# pylint: disable=no-member
|
||||
self._handle_task = asyncio.current_task()
|
||||
else:
|
||||
self._handle_task = asyncio.Task.current_task()
|
||||
self._handle_task = asyncio.current_task()
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event):
|
||||
|
|
|
@ -7,23 +7,23 @@ from homeassistant.setup import async_setup_component
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_client(hass, hass_ws_client, hass_access_token):
|
||||
async def websocket_client(hass, hass_ws_client):
|
||||
"""Create a websocket client."""
|
||||
return hass.loop.run_until_complete(hass_ws_client(hass, hass_access_token))
|
||||
return await hass_ws_client(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_auth_websocket_client(hass, loop, aiohttp_client):
|
||||
async def no_auth_websocket_client(hass, aiohttp_client):
|
||||
"""Websocket connection that requires authentication."""
|
||||
assert loop.run_until_complete(async_setup_component(hass, "websocket_api", {}))
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
|
||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(URL))
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
ws = await client.ws_connect(URL)
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
auth_ok = await ws.receive_json()
|
||||
assert auth_ok["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
yield ws
|
||||
|
||||
if not ws.closed:
|
||||
loop.run_until_complete(ws.close())
|
||||
await ws.close()
|
||||
|
|
66
tests/components/websocket_api/test_http.py
Normal file
66
tests/components/websocket_api/test_http.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
"""Test Websocket API http module."""
|
||||
from datetime import timedelta
|
||||
|
||||
from aiohttp import WSMsgType
|
||||
from asynctest import patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.websocket_api import const, http
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_peak():
|
||||
"""Mock a low queue."""
|
||||
with patch("homeassistant.components.websocket_api.http.PENDING_MSG_PEAK", 5):
|
||||
yield
|
||||
|
||||
|
||||
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||
"""Test get_panels command."""
|
||||
for idx in range(10):
|
||||
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
|
||||
"""Test pending msg overflow command."""
|
||||
orig_handler = http.WebSocketHandler
|
||||
instance = None
|
||||
|
||||
def instantiate_handler(*args):
|
||||
nonlocal instance
|
||||
instance = orig_handler(*args)
|
||||
return instance
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||
instantiate_handler,
|
||||
):
|
||||
websocket_client = await hass_ws_client()
|
||||
|
||||
# Kill writer task and fill queue past peak
|
||||
for _ in range(5):
|
||||
instance._to_write.put_nowait(None)
|
||||
|
||||
# Trigger the peak check
|
||||
instance._send_message({})
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
assert "Client unable to keep up with pending messages" in caplog.text
|
|
@ -2,19 +2,11 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from aiohttp import WSMsgType
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.websocket_api import const, messages
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
|
||||
yield
|
||||
|
||||
|
||||
async def test_invalid_message_format(websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
await websocket_client.send_json({"type": 5})
|
||||
|
@ -46,14 +38,6 @@ async def test_quiting_hass(hass, websocket_client):
|
|||
assert msg.type == WSMsgType.CLOSE
|
||||
|
||||
|
||||
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||
"""Test get_panels command."""
|
||||
for idx in range(10):
|
||||
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_unknown_command(websocket_client):
|
||||
"""Test get_panels command."""
|
||||
await websocket_client.send_json({"id": 5, "type": "unknown_command"})
|
||||
|
|
|
@ -228,10 +228,10 @@ def hass_client(hass, aiohttp_client, hass_access_token):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client, hass_access_token):
|
||||
def hass_ws_client(aiohttp_client, hass_access_token, hass):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
|
||||
async def create_client(hass, access_token=hass_access_token):
|
||||
async def create_client(hass=hass, access_token=hass_access_token):
|
||||
"""Create a websocket client."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue