Allow WS queue to temporarily peak (#34175)

* Allow WS queue to temporarily peak

* Remove unused code
This commit is contained in:
Paulus Schoutsen 2020-04-13 18:50:36 -07:00 committed by GitHub
parent dbcc294d67
commit bea354b82a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 34 deletions

View file

@ -16,7 +16,9 @@ WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], No
DOMAIN = "websocket_api" DOMAIN = "websocket_api"
URL = "/api/websocket" URL = "/api/websocket"
MAX_PENDING_MSG = 512 PENDING_MSG_PEAK = 512
PENDING_MSG_PEAK_TIME = 5
MAX_PENDING_MSG = 2048
ERR_ID_REUSE = "id_reuse" ERR_ID_REUSE = "id_reuse"
ERR_INVALID_FORMAT = "invalid_format" ERR_INVALID_FORMAT = "invalid_format"

View file

@ -10,6 +10,7 @@ import async_timeout
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.event import async_call_later
from .auth import AuthPhase, auth_required_message from .auth import AuthPhase, auth_required_message
from .const import ( from .const import (
@ -18,6 +19,8 @@ from .const import (
ERR_UNKNOWN_ERROR, ERR_UNKNOWN_ERROR,
JSON_DUMP, JSON_DUMP,
MAX_PENDING_MSG, MAX_PENDING_MSG,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
SIGNAL_WEBSOCKET_CONNECTED, SIGNAL_WEBSOCKET_CONNECTED,
SIGNAL_WEBSOCKET_DISCONNECTED, SIGNAL_WEBSOCKET_DISCONNECTED,
URL, URL,
@ -52,6 +55,7 @@ class WebSocketHandler:
self._handle_task = None self._handle_task = None
self._writer_task = None self._writer_task = None
self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self))) self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self)))
self._peak_checker_unsub = None
async def _writer(self): async def _writer(self):
"""Write outgoing messages.""" """Write outgoing messages."""
@ -83,6 +87,11 @@ class WebSocketHandler:
await self.wsock.send_str(dumped) await self.wsock.send_str(dumped)
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback @callback
def _send_message(self, message): def _send_message(self, message):
"""Send a message to the client. """Send a message to the client.
@ -97,8 +106,35 @@ class WebSocketHandler:
self._logger.error( self._logger.error(
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG "Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
) )
self._cancel() self._cancel()
if self._to_write.qsize() < PENDING_MSG_PEAK:
if self._peak_checker_unsub:
self._peak_checker_unsub()
self._peak_checker_unsub = None
return
if self._peak_checker_unsub is None:
self._peak_checker_unsub = async_call_later(
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
)
@callback
def _check_write_peak(self, _):
"""Check that we are no longer above the write peak."""
self._peak_checker_unsub = None
if self._to_write.qsize() < PENDING_MSG_PEAK:
return
self._logger.error(
"Client unable to keep up with pending messages. Stayed over %s for %s seconds",
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
)
self._cancel()
@callback @callback
def _cancel(self): def _cancel(self):
"""Cancel the connection.""" """Cancel the connection."""
@ -111,13 +147,7 @@ class WebSocketHandler:
wsock = self.wsock = web.WebSocketResponse(heartbeat=55) wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
await wsock.prepare(request) await wsock.prepare(request)
self._logger.debug("Connected") self._logger.debug("Connected")
self._handle_task = asyncio.current_task()
# Py3.7+
if hasattr(asyncio, "current_task"):
# pylint: disable=no-member
self._handle_task = asyncio.current_task()
else:
self._handle_task = asyncio.Task.current_task()
@callback @callback
def handle_hass_stop(event): def handle_hass_stop(event):

View file

@ -7,23 +7,23 @@ from homeassistant.setup import async_setup_component
@pytest.fixture @pytest.fixture
def websocket_client(hass, hass_ws_client, hass_access_token): async def websocket_client(hass, hass_ws_client):
"""Create a websocket client.""" """Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass, hass_access_token)) return await hass_ws_client(hass)
@pytest.fixture @pytest.fixture
def no_auth_websocket_client(hass, loop, aiohttp_client): async def no_auth_websocket_client(hass, aiohttp_client):
"""Websocket connection that requires authentication.""" """Websocket connection that requires authentication."""
assert loop.run_until_complete(async_setup_component(hass, "websocket_api", {})) assert await async_setup_component(hass, "websocket_api", {})
client = loop.run_until_complete(aiohttp_client(hass.http.app)) client = await aiohttp_client(hass.http.app)
ws = loop.run_until_complete(client.ws_connect(URL)) ws = await client.ws_connect(URL)
auth_ok = loop.run_until_complete(ws.receive_json()) auth_ok = await ws.receive_json()
assert auth_ok["type"] == TYPE_AUTH_REQUIRED assert auth_ok["type"] == TYPE_AUTH_REQUIRED
yield ws yield ws
if not ws.closed: if not ws.closed:
loop.run_until_complete(ws.close()) await ws.close()

View file

@ -0,0 +1,66 @@
"""Test Websocket API http module."""
from datetime import timedelta
from aiohttp import WSMsgType
from asynctest import patch
import pytest
from homeassistant.components.websocket_api import const, http
from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
yield
@pytest.fixture
def mock_low_peak():
"""Mock a low queue."""
with patch("homeassistant.components.websocket_api.http.PENDING_MSG_PEAK", 5):
yield
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
"""Test pending msg overflow command."""
orig_handler = http.WebSocketHandler
instance = None
def instantiate_handler(*args):
nonlocal instance
instance = orig_handler(*args)
return instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
instantiate_handler,
):
websocket_client = await hass_ws_client()
# Kill writer task and fill queue past peak
for _ in range(5):
instance._to_write.put_nowait(None)
# Trigger the peak check
instance._send_message({})
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
)
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
assert "Client unable to keep up with pending messages" in caplog.text

View file

@ -2,19 +2,11 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from aiohttp import WSMsgType from aiohttp import WSMsgType
import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.websocket_api import const, messages from homeassistant.components.websocket_api import const, messages
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
yield
async def test_invalid_message_format(websocket_client): async def test_invalid_message_format(websocket_client):
"""Test sending invalid JSON.""" """Test sending invalid JSON."""
await websocket_client.send_json({"type": 5}) await websocket_client.send_json({"type": 5})
@ -46,14 +38,6 @@ async def test_quiting_hass(hass, websocket_client):
assert msg.type == WSMsgType.CLOSE assert msg.type == WSMsgType.CLOSE
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
async def test_unknown_command(websocket_client): async def test_unknown_command(websocket_client):
"""Test get_panels command.""" """Test get_panels command."""
await websocket_client.send_json({"id": 5, "type": "unknown_command"}) await websocket_client.send_json({"id": 5, "type": "unknown_command"})

View file

@ -228,10 +228,10 @@ def hass_client(hass, aiohttp_client, hass_access_token):
@pytest.fixture @pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token): def hass_ws_client(aiohttp_client, hass_access_token, hass):
"""Websocket client fixture connected to websocket server.""" """Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=hass_access_token): async def create_client(hass=hass, access_token=hass_access_token):
"""Create a websocket client.""" """Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {}) assert await async_setup_component(hass, "websocket_api", {})