Allow WS queue to temporarily peak (#34175)
* Allow WS queue to temporarily peak * Remove unused code
This commit is contained in:
parent
dbcc294d67
commit
bea354b82a
6 changed files with 116 additions and 34 deletions
|
@ -16,7 +16,9 @@ WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], No
|
||||||
|
|
||||||
DOMAIN = "websocket_api"
|
DOMAIN = "websocket_api"
|
||||||
URL = "/api/websocket"
|
URL = "/api/websocket"
|
||||||
MAX_PENDING_MSG = 512
|
PENDING_MSG_PEAK = 512
|
||||||
|
PENDING_MSG_PEAK_TIME = 5
|
||||||
|
MAX_PENDING_MSG = 2048
|
||||||
|
|
||||||
ERR_ID_REUSE = "id_reuse"
|
ERR_ID_REUSE = "id_reuse"
|
||||||
ERR_INVALID_FORMAT = "invalid_format"
|
ERR_INVALID_FORMAT = "invalid_format"
|
||||||
|
|
|
@ -10,6 +10,7 @@ import async_timeout
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.helpers.event import async_call_later
|
||||||
|
|
||||||
from .auth import AuthPhase, auth_required_message
|
from .auth import AuthPhase, auth_required_message
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -18,6 +19,8 @@ from .const import (
|
||||||
ERR_UNKNOWN_ERROR,
|
ERR_UNKNOWN_ERROR,
|
||||||
JSON_DUMP,
|
JSON_DUMP,
|
||||||
MAX_PENDING_MSG,
|
MAX_PENDING_MSG,
|
||||||
|
PENDING_MSG_PEAK,
|
||||||
|
PENDING_MSG_PEAK_TIME,
|
||||||
SIGNAL_WEBSOCKET_CONNECTED,
|
SIGNAL_WEBSOCKET_CONNECTED,
|
||||||
SIGNAL_WEBSOCKET_DISCONNECTED,
|
SIGNAL_WEBSOCKET_DISCONNECTED,
|
||||||
URL,
|
URL,
|
||||||
|
@ -52,6 +55,7 @@ class WebSocketHandler:
|
||||||
self._handle_task = None
|
self._handle_task = None
|
||||||
self._writer_task = None
|
self._writer_task = None
|
||||||
self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self)))
|
self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self)))
|
||||||
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
async def _writer(self):
|
async def _writer(self):
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
|
@ -83,6 +87,11 @@ class WebSocketHandler:
|
||||||
|
|
||||||
await self.wsock.send_str(dumped)
|
await self.wsock.send_str(dumped)
|
||||||
|
|
||||||
|
# Clean up the peaker checker when we shut down the writer
|
||||||
|
if self._peak_checker_unsub:
|
||||||
|
self._peak_checker_unsub()
|
||||||
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _send_message(self, message):
|
def _send_message(self, message):
|
||||||
"""Send a message to the client.
|
"""Send a message to the client.
|
||||||
|
@ -97,8 +106,35 @@ class WebSocketHandler:
|
||||||
self._logger.error(
|
self._logger.error(
|
||||||
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
|
"Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
|
||||||
)
|
)
|
||||||
|
|
||||||
self._cancel()
|
self._cancel()
|
||||||
|
|
||||||
|
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||||
|
if self._peak_checker_unsub:
|
||||||
|
self._peak_checker_unsub()
|
||||||
|
self._peak_checker_unsub = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._peak_checker_unsub is None:
|
||||||
|
self._peak_checker_unsub = async_call_later(
|
||||||
|
self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _check_write_peak(self, _):
|
||||||
|
"""Check that we are no longer above the write peak."""
|
||||||
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
|
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._logger.error(
|
||||||
|
"Client unable to keep up with pending messages. Stayed over %s for %s seconds",
|
||||||
|
PENDING_MSG_PEAK,
|
||||||
|
PENDING_MSG_PEAK_TIME,
|
||||||
|
)
|
||||||
|
self._cancel()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _cancel(self):
|
def _cancel(self):
|
||||||
"""Cancel the connection."""
|
"""Cancel the connection."""
|
||||||
|
@ -111,13 +147,7 @@ class WebSocketHandler:
|
||||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||||
await wsock.prepare(request)
|
await wsock.prepare(request)
|
||||||
self._logger.debug("Connected")
|
self._logger.debug("Connected")
|
||||||
|
self._handle_task = asyncio.current_task()
|
||||||
# Py3.7+
|
|
||||||
if hasattr(asyncio, "current_task"):
|
|
||||||
# pylint: disable=no-member
|
|
||||||
self._handle_task = asyncio.current_task()
|
|
||||||
else:
|
|
||||||
self._handle_task = asyncio.Task.current_task()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def handle_hass_stop(event):
|
def handle_hass_stop(event):
|
||||||
|
|
|
@ -7,23 +7,23 @@ from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def websocket_client(hass, hass_ws_client, hass_access_token):
|
async def websocket_client(hass, hass_ws_client):
|
||||||
"""Create a websocket client."""
|
"""Create a websocket client."""
|
||||||
return hass.loop.run_until_complete(hass_ws_client(hass, hass_access_token))
|
return await hass_ws_client(hass)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def no_auth_websocket_client(hass, loop, aiohttp_client):
|
async def no_auth_websocket_client(hass, aiohttp_client):
|
||||||
"""Websocket connection that requires authentication."""
|
"""Websocket connection that requires authentication."""
|
||||||
assert loop.run_until_complete(async_setup_component(hass, "websocket_api", {}))
|
assert await async_setup_component(hass, "websocket_api", {})
|
||||||
|
|
||||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
client = await aiohttp_client(hass.http.app)
|
||||||
ws = loop.run_until_complete(client.ws_connect(URL))
|
ws = await client.ws_connect(URL)
|
||||||
|
|
||||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
auth_ok = await ws.receive_json()
|
||||||
assert auth_ok["type"] == TYPE_AUTH_REQUIRED
|
assert auth_ok["type"] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
yield ws
|
yield ws
|
||||||
|
|
||||||
if not ws.closed:
|
if not ws.closed:
|
||||||
loop.run_until_complete(ws.close())
|
await ws.close()
|
||||||
|
|
66
tests/components/websocket_api/test_http.py
Normal file
66
tests/components/websocket_api/test_http.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
"""Test Websocket API http module."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from aiohttp import WSMsgType
|
||||||
|
from asynctest import patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.websocket_api import const, http
|
||||||
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
|
from tests.common import async_fire_time_changed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_low_queue():
|
||||||
|
"""Mock a low queue."""
|
||||||
|
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_low_peak():
|
||||||
|
"""Mock a low queue."""
|
||||||
|
with patch("homeassistant.components.websocket_api.http.PENDING_MSG_PEAK", 5):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||||
|
"""Test get_panels command."""
|
||||||
|
for idx in range(10):
|
||||||
|
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.close
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog):
|
||||||
|
"""Test pending msg overflow command."""
|
||||||
|
orig_handler = http.WebSocketHandler
|
||||||
|
instance = None
|
||||||
|
|
||||||
|
def instantiate_handler(*args):
|
||||||
|
nonlocal instance
|
||||||
|
instance = orig_handler(*args)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||||
|
instantiate_handler,
|
||||||
|
):
|
||||||
|
websocket_client = await hass_ws_client()
|
||||||
|
|
||||||
|
# Kill writer task and fill queue past peak
|
||||||
|
for _ in range(5):
|
||||||
|
instance._to_write.put_nowait(None)
|
||||||
|
|
||||||
|
# Trigger the peak check
|
||||||
|
instance._send_message({})
|
||||||
|
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.close
|
||||||
|
|
||||||
|
assert "Client unable to keep up with pending messages" in caplog.text
|
|
@ -2,19 +2,11 @@
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
import pytest
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.websocket_api import const, messages
|
from homeassistant.components.websocket_api import const, messages
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_low_queue():
|
|
||||||
"""Mock a low queue."""
|
|
||||||
with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_message_format(websocket_client):
|
async def test_invalid_message_format(websocket_client):
|
||||||
"""Test sending invalid JSON."""
|
"""Test sending invalid JSON."""
|
||||||
await websocket_client.send_json({"type": 5})
|
await websocket_client.send_json({"type": 5})
|
||||||
|
@ -46,14 +38,6 @@ async def test_quiting_hass(hass, websocket_client):
|
||||||
assert msg.type == WSMsgType.CLOSE
|
assert msg.type == WSMsgType.CLOSE
|
||||||
|
|
||||||
|
|
||||||
async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
|
||||||
"""Test get_panels command."""
|
|
||||||
for idx in range(10):
|
|
||||||
await websocket_client.send_json({"id": idx + 1, "type": "ping"})
|
|
||||||
msg = await websocket_client.receive()
|
|
||||||
assert msg.type == WSMsgType.close
|
|
||||||
|
|
||||||
|
|
||||||
async def test_unknown_command(websocket_client):
|
async def test_unknown_command(websocket_client):
|
||||||
"""Test get_panels command."""
|
"""Test get_panels command."""
|
||||||
await websocket_client.send_json({"id": 5, "type": "unknown_command"})
|
await websocket_client.send_json({"id": 5, "type": "unknown_command"})
|
||||||
|
|
|
@ -228,10 +228,10 @@ def hass_client(hass, aiohttp_client, hass_access_token):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hass_ws_client(aiohttp_client, hass_access_token):
|
def hass_ws_client(aiohttp_client, hass_access_token, hass):
|
||||||
"""Websocket client fixture connected to websocket server."""
|
"""Websocket client fixture connected to websocket server."""
|
||||||
|
|
||||||
async def create_client(hass, access_token=hass_access_token):
|
async def create_client(hass=hass, access_token=hass_access_token):
|
||||||
"""Create a websocket client."""
|
"""Create a websocket client."""
|
||||||
assert await async_setup_component(hass, "websocket_api", {})
|
assert await async_setup_component(hass, "websocket_api", {})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue