From 206d02d5313be45e57f9cd16fcae2981befb9aa3 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <paulus@paulusschoutsen.nl>
Date: Sat, 13 May 2017 16:34:45 -0700
Subject: [PATCH] Websocket_api: avoid parallel drain (#7576)

* Websocket_api: avoid parallel drain

* Remove send_message method
---
 homeassistant/components/websocket_api.py | 103 +++++++++++++---------
 1 file changed, 60 insertions(+), 43 deletions(-)

diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py
index 466236573c8..fcfd7f404e9 100644
--- a/homeassistant/components/websocket_api.py
+++ b/homeassistant/components/websocket_api.py
@@ -30,6 +30,8 @@ DOMAIN = 'websocket_api'
 URL = '/api/websocket'
 DEPENDENCIES = 'http',
 
+MAX_PENDING_MSG = 512
+
 ERR_ID_REUSE = 1
 ERR_INVALID_FORMAT = 2
 ERR_NOT_FOUND = 3
@@ -211,6 +213,7 @@ class ActiveConnection:
         self.request = request
         self.wsock = None
         self.event_listeners = {}
+        self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
 
     def debug(self, message1, message2=''):
         """Print a debug message."""
@@ -220,13 +223,19 @@ class ActiveConnection:
         """Print an error message."""
         _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
 
-    def send_message(self, message):
-        """Send messages.
-
-        Returns a coroutine object.
-        """
-        self.debug("Sending", message)
-        return self.wsock.send_json(message, dumps=JSON_DUMP)
+    @asyncio.coroutine
+    def _writer(self):
+        """Write outgoing messages."""
+        try:
+            while True:
+                message = yield from self.to_write.get()
+                if message is None:
+                    break
+                self.debug("Sending", message)
+                yield from self.wsock.send_json(message, dumps=JSON_DUMP)
+        except (RuntimeError, asyncio.CancelledError):
+            # Socket disconnected or cancelled by connection handler
+            pass
 
     @asyncio.coroutine
     def handle(self):
@@ -244,7 +253,8 @@ class ActiveConnection:
 
         unsub_stop = self.hass.bus.async_listen(
             EVENT_HOMEASSISTANT_STOP, cancel_connection)
-
+        writer_task = self.hass.async_add_job(self._writer())
+        final_message = None
         self.debug("Connected")
 
         msg = None
@@ -255,7 +265,7 @@ class ActiveConnection:
                 authenticated = True
 
             else:
-                yield from self.send_message(auth_required_message())
+                yield from self.wsock.send_json(auth_required_message())
                 msg = yield from wsock.receive_json()
                 msg = AUTH_MESSAGE_SCHEMA(msg)
 
@@ -264,14 +274,14 @@ class ActiveConnection:
 
                 else:
                     self.debug("Invalid password")
-                    yield from self.send_message(
+                    yield from self.wsock.send_json(
                         auth_invalid_message('Invalid password'))
 
             if not authenticated:
                 yield from process_wrong_login(self.request)
                 return wsock
 
-            yield from self.send_message(auth_ok_message())
+            yield from self.wsock.send_json(auth_ok_message())
 
             msg = yield from wsock.receive_json()
 
@@ -283,13 +293,13 @@ class ActiveConnection:
                 cur_id = msg['id']
 
                 if cur_id <= last_id:
-                    yield from self.send_message(error_message(
+                    self.to_write.put_nowait(error_message(
                         cur_id, ERR_ID_REUSE,
                         'Identifier values have to increase.'))
 
                 else:
                     handler_name = 'handle_{}'.format(msg['type'])
-                    yield from getattr(self, handler_name)(msg)
+                    getattr(self, handler_name)(msg)
 
                 last_id = cur_id
                 msg = yield from wsock.receive_json()
@@ -304,7 +314,7 @@ class ActiveConnection:
             self.log_error(error_msg)
 
             if not authenticated:
-                yield from self.send_message(auth_invalid_message(error_msg))
+                final_message = auth_invalid_message(error_msg)
 
             else:
                 if isinstance(msg, dict):
@@ -312,8 +322,8 @@ class ActiveConnection:
                 else:
                     iden = None
 
-                yield from self.send_message(error_message(
-                    iden, ERR_INVALID_FORMAT, error_msg))
+                final_message = error_message(
+                    iden, ERR_INVALID_FORMAT, error_msg)
 
         except TypeError as err:
             if wsock.closed:
@@ -331,6 +341,11 @@ class ActiveConnection:
         except asyncio.CancelledError:
             self.debug("Connection cancelled by server")
 
+        except asyncio.QueueFull:
+            self.log_error("Client exceeded max pending messages:",
+                           MAX_PENDING_MSG)
+            writer_task.cancel()
+
         except Exception:  # pylint: disable=broad-except
             error = "Unexpected error inside websocket API. "
             if msg is not None:
@@ -338,6 +353,15 @@ class ActiveConnection:
             _LOGGER.exception(error)
 
         finally:
+            try:
+                if final_message is not None:
+                    self.to_write.put_nowait(final_message)
+                self.to_write.put_nowait(None)
+                # Make sure all error messages are written before closing
+                yield from writer_task
+            except asyncio.QueueFull:
+                pass
+
             unsub_stop()
 
             for unsub in self.event_listeners.values():
@@ -348,9 +372,11 @@ class ActiveConnection:
 
         return wsock
 
-    @asyncio.coroutine
     def handle_subscribe_events(self, msg):
-        """Handle subscribe events command."""
+        """Handle subscribe events command.
+
+        Async friendly.
+        """
         msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
 
         @asyncio.coroutine
@@ -359,21 +385,17 @@ class ActiveConnection:
             if event.event_type == EVENT_TIME_CHANGED:
                 return
 
-            try:
-                yield from self.send_message(event_message(msg['id'], event))
-            except RuntimeError:
-                # Socket has been closed.
-                pass
+            self.to_write.put_nowait(event_message(msg['id'], event))
 
         self.event_listeners[msg['id']] = self.hass.bus.async_listen(
             msg['event_type'], forward_events)
 
-        return self.send_message(result_message(msg['id']))
+        self.to_write.put_nowait(result_message(msg['id']))
 
     def handle_unsubscribe_events(self, msg):
         """Handle unsubscribe events command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
         msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
 
@@ -381,13 +403,12 @@ class ActiveConnection:
 
         if subscription in self.event_listeners:
             self.event_listeners.pop(subscription)()
-            return self.send_message(result_message(msg['id']))
+            self.to_write.put_nowait(result_message(msg['id']))
         else:
-            return self.send_message(error_message(
+            self.to_write.put_nowait(error_message(
                 msg['id'], ERR_NOT_FOUND,
                 'Subscription not found.'))
 
-    @asyncio.coroutine
     def handle_call_service(self, msg):
         """Handle call service command.
 
@@ -400,57 +421,53 @@ class ActiveConnection:
             """Call a service and fire complete message."""
             yield from self.hass.services.async_call(
                 msg['domain'], msg['service'], msg['service_data'], True)
-            try:
-                yield from self.send_message(result_message(msg['id']))
-            except RuntimeError:
-                # Socket has been closed.
-                pass
+            self.to_write.put_nowait(result_message(msg['id']))
 
         self.hass.async_add_job(call_service_helper(msg))
 
     def handle_get_states(self, msg):
         """Handle get states command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
         msg = GET_STATES_MESSAGE_SCHEMA(msg)
 
-        return self.send_message(result_message(
+        self.to_write.put_nowait(result_message(
             msg['id'], self.hass.states.async_all()))
 
     def handle_get_services(self, msg):
         """Handle get services command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
         msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
 
-        return self.send_message(result_message(
+        self.to_write.put_nowait(result_message(
             msg['id'], self.hass.services.async_services()))
 
     def handle_get_config(self, msg):
         """Handle get config command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
         msg = GET_CONFIG_MESSAGE_SCHEMA(msg)
 
-        return self.send_message(result_message(
+        self.to_write.put_nowait(result_message(
             msg['id'], self.hass.config.as_dict()))
 
     def handle_get_panels(self, msg):
         """Handle get panels command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
         msg = GET_PANELS_MESSAGE_SCHEMA(msg)
 
-        return self.send_message(result_message(
+        self.to_write.put_nowait(result_message(
             msg['id'], self.hass.data[frontend.DATA_PANELS]))
 
     def handle_ping(self, msg):
         """Handle ping command.
 
-        Returns a coroutine object.
+        Async friendly.
         """
-        return self.send_message(pong_message(msg['id']))
+        self.to_write.put_nowait(pong_message(msg['id']))