diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 183e7008853..121ea7496de 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -16,7 +16,9 @@ WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], No DOMAIN = "websocket_api" URL = "/api/websocket" -MAX_PENDING_MSG = 512 +PENDING_MSG_PEAK = 512 +PENDING_MSG_PEAK_TIME = 5 +MAX_PENDING_MSG = 2048 ERR_ID_REUSE = "id_reuse" ERR_INVALID_FORMAT = "invalid_format" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 3921413fd28..80ec35f5f7e 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -10,6 +10,7 @@ import async_timeout from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback +from homeassistant.helpers.event import async_call_later from .auth import AuthPhase, auth_required_message from .const import ( @@ -18,6 +19,8 @@ from .const import ( ERR_UNKNOWN_ERROR, JSON_DUMP, MAX_PENDING_MSG, + PENDING_MSG_PEAK, + PENDING_MSG_PEAK_TIME, SIGNAL_WEBSOCKET_CONNECTED, SIGNAL_WEBSOCKET_DISCONNECTED, URL, @@ -52,6 +55,7 @@ class WebSocketHandler: self._handle_task = None self._writer_task = None self._logger = logging.getLogger("{}.connection.{}".format(__name__, id(self))) + self._peak_checker_unsub = None async def _writer(self): """Write outgoing messages.""" @@ -83,6 +87,11 @@ class WebSocketHandler: await self.wsock.send_str(dumped) + # Clean up the peaker checker when we shut down the writer + if self._peak_checker_unsub: + self._peak_checker_unsub() + self._peak_checker_unsub = None + @callback def _send_message(self, message): """Send a message to the client. @@ -97,8 +106,35 @@ class WebSocketHandler: self._logger.error( "Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG ) + self._cancel() + if self._to_write.qsize() < PENDING_MSG_PEAK: + if self._peak_checker_unsub: + self._peak_checker_unsub() + self._peak_checker_unsub = None + return + + if self._peak_checker_unsub is None: + self._peak_checker_unsub = async_call_later( + self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak + ) + + @callback + def _check_write_peak(self, _): + """Check that we are no longer above the write peak.""" + self._peak_checker_unsub = None + + if self._to_write.qsize() < PENDING_MSG_PEAK: + return + + self._logger.error( + "Client unable to keep up with pending messages. Stayed over %s for %s seconds", + PENDING_MSG_PEAK, + PENDING_MSG_PEAK_TIME, + ) + self._cancel() + @callback def _cancel(self): """Cancel the connection.""" @@ -111,13 +147,7 @@ class WebSocketHandler: wsock = self.wsock = web.WebSocketResponse(heartbeat=55) await wsock.prepare(request) self._logger.debug("Connected") - - # Py3.7+ - if hasattr(asyncio, "current_task"): - # pylint: disable=no-member - self._handle_task = asyncio.current_task() - else: - self._handle_task = asyncio.Task.current_task() + self._handle_task = asyncio.current_task() @callback def handle_hass_stop(event): diff --git a/tests/components/websocket_api/conftest.py b/tests/components/websocket_api/conftest.py index 65b9232821f..93538f2b00b 100644 --- a/tests/components/websocket_api/conftest.py +++ b/tests/components/websocket_api/conftest.py @@ -7,23 +7,23 @@ from homeassistant.setup import async_setup_component @pytest.fixture -def websocket_client(hass, hass_ws_client, hass_access_token): +async def websocket_client(hass, hass_ws_client): """Create a websocket client.""" - return hass.loop.run_until_complete(hass_ws_client(hass, hass_access_token)) + return await hass_ws_client(hass) @pytest.fixture -def no_auth_websocket_client(hass, loop, aiohttp_client): +async def no_auth_websocket_client(hass, aiohttp_client): """Websocket connection that requires authentication.""" - assert loop.run_until_complete(async_setup_component(hass, "websocket_api", {})) + assert await async_setup_component(hass, "websocket_api", {}) - client = loop.run_until_complete(aiohttp_client(hass.http.app)) - ws = loop.run_until_complete(client.ws_connect(URL)) + client = await aiohttp_client(hass.http.app) + ws = await client.ws_connect(URL) - auth_ok = loop.run_until_complete(ws.receive_json()) + auth_ok = await ws.receive_json() assert auth_ok["type"] == TYPE_AUTH_REQUIRED yield ws if not ws.closed: - loop.run_until_complete(ws.close()) + await ws.close() diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py new file mode 100644 index 00000000000..33a019b2e70 --- /dev/null +++ b/tests/components/websocket_api/test_http.py @@ -0,0 +1,66 @@ +"""Test Websocket API http module.""" +from datetime import timedelta + +from aiohttp import WSMsgType +from asynctest import patch +import pytest + +from homeassistant.components.websocket_api import const, http +from homeassistant.util.dt import utcnow + +from tests.common import async_fire_time_changed + + +@pytest.fixture +def mock_low_queue(): + """Mock a low queue.""" + with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5): + yield + + +@pytest.fixture +def mock_low_peak(): + """Mock a low queue.""" + with patch("homeassistant.components.websocket_api.http.PENDING_MSG_PEAK", 5): + yield + + +async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): + """Test get_panels command.""" + for idx in range(10): + await websocket_client.send_json({"id": idx + 1, "type": "ping"}) + msg = await websocket_client.receive() + assert msg.type == WSMsgType.close + + +async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog): + """Test pending msg overflow command.""" + orig_handler = http.WebSocketHandler + instance = None + + def instantiate_handler(*args): + nonlocal instance + instance = orig_handler(*args) + return instance + + with patch( + "homeassistant.components.websocket_api.http.WebSocketHandler", + instantiate_handler, + ): + websocket_client = await hass_ws_client() + + # Kill writer task and fill queue past peak + for _ in range(5): + instance._to_write.put_nowait(None) + + # Trigger the peak check + instance._send_message({}) + + async_fire_time_changed( + hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1) + ) + + msg = await websocket_client.receive() + assert msg.type == WSMsgType.close + + assert "Client unable to keep up with pending messages" in caplog.text diff --git a/tests/components/websocket_api/test_init.py b/tests/components/websocket_api/test_init.py index d32f55516aa..041c0e76533 100644 --- a/tests/components/websocket_api/test_init.py +++ b/tests/components/websocket_api/test_init.py @@ -2,19 +2,11 @@ from unittest.mock import Mock, patch from aiohttp import WSMsgType -import pytest import voluptuous as vol from homeassistant.components.websocket_api import const, messages -@pytest.fixture -def mock_low_queue(): - """Mock a low queue.""" - with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 5): - yield - - async def test_invalid_message_format(websocket_client): """Test sending invalid JSON.""" await websocket_client.send_json({"type": 5}) @@ -46,14 +38,6 @@ async def test_quiting_hass(hass, websocket_client): assert msg.type == WSMsgType.CLOSE -async def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): - """Test get_panels command.""" - for idx in range(10): - await websocket_client.send_json({"id": idx + 1, "type": "ping"}) - msg = await websocket_client.receive() - assert msg.type == WSMsgType.close - - async def test_unknown_command(websocket_client): """Test get_panels command.""" await websocket_client.send_json({"id": 5, "type": "unknown_command"}) diff --git a/tests/conftest.py b/tests/conftest.py index f93d5190350..89a3065bcec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,10 +228,10 @@ def hass_client(hass, aiohttp_client, hass_access_token): @pytest.fixture -def hass_ws_client(aiohttp_client, hass_access_token): +def hass_ws_client(aiohttp_client, hass_access_token, hass): """Websocket client fixture connected to websocket server.""" - async def create_client(hass, access_token=hass_access_token): + async def create_client(hass=hass, access_token=hass_access_token): """Create a websocket client.""" assert await async_setup_component(hass, "websocket_api", {})