diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index 466236573c8..fcfd7f404e9 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -30,6 +30,8 @@ DOMAIN = 'websocket_api' URL = '/api/websocket' DEPENDENCIES = 'http', +MAX_PENDING_MSG = 512 + ERR_ID_REUSE = 1 ERR_INVALID_FORMAT = 2 ERR_NOT_FOUND = 3 @@ -211,6 +213,7 @@ class ActiveConnection: self.request = request self.wsock = None self.event_listeners = {} + self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop) def debug(self, message1, message2=''): """Print a debug message.""" @@ -220,13 +223,19 @@ class ActiveConnection: """Print an error message.""" _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2) - def send_message(self, message): - """Send messages. - - Returns a coroutine object. - """ - self.debug("Sending", message) - return self.wsock.send_json(message, dumps=JSON_DUMP) + @asyncio.coroutine + def _writer(self): + """Write outgoing messages.""" + try: + while True: + message = yield from self.to_write.get() + if message is None: + break + self.debug("Sending", message) + yield from self.wsock.send_json(message, dumps=JSON_DUMP) + except (RuntimeError, asyncio.CancelledError): + # Socket disconnected or cancelled by connection handler + pass @asyncio.coroutine def handle(self): @@ -244,7 +253,8 @@ class ActiveConnection: unsub_stop = self.hass.bus.async_listen( EVENT_HOMEASSISTANT_STOP, cancel_connection) - + writer_task = self.hass.async_add_job(self._writer()) + final_message = None self.debug("Connected") msg = None @@ -255,7 +265,7 @@ class ActiveConnection: authenticated = True else: - yield from self.send_message(auth_required_message()) + yield from self.wsock.send_json(auth_required_message()) msg = yield from wsock.receive_json() msg = AUTH_MESSAGE_SCHEMA(msg) @@ -264,14 +274,14 @@ class ActiveConnection: else: self.debug("Invalid password") - yield from self.send_message( + yield from self.wsock.send_json( auth_invalid_message('Invalid password')) if not authenticated: yield from process_wrong_login(self.request) return wsock - yield from self.send_message(auth_ok_message()) + yield from self.wsock.send_json(auth_ok_message()) msg = yield from wsock.receive_json() @@ -283,13 +293,13 @@ class ActiveConnection: cur_id = msg['id'] if cur_id <= last_id: - yield from self.send_message(error_message( + self.to_write.put_nowait(error_message( cur_id, ERR_ID_REUSE, 'Identifier values have to increase.')) else: handler_name = 'handle_{}'.format(msg['type']) - yield from getattr(self, handler_name)(msg) + getattr(self, handler_name)(msg) last_id = cur_id msg = yield from wsock.receive_json() @@ -304,7 +314,7 @@ class ActiveConnection: self.log_error(error_msg) if not authenticated: - yield from self.send_message(auth_invalid_message(error_msg)) + final_message = auth_invalid_message(error_msg) else: if isinstance(msg, dict): @@ -312,8 +322,8 @@ class ActiveConnection: else: iden = None - yield from self.send_message(error_message( - iden, ERR_INVALID_FORMAT, error_msg)) + final_message = error_message( + iden, ERR_INVALID_FORMAT, error_msg) except TypeError as err: if wsock.closed: @@ -331,6 +341,11 @@ class ActiveConnection: except asyncio.CancelledError: self.debug("Connection cancelled by server") + except asyncio.QueueFull: + self.log_error("Client exceeded max pending messages:", + MAX_PENDING_MSG) + writer_task.cancel() + except Exception: # pylint: disable=broad-except error = "Unexpected error inside websocket API. " if msg is not None: @@ -338,6 +353,15 @@ class ActiveConnection: _LOGGER.exception(error) finally: + try: + if final_message is not None: + self.to_write.put_nowait(final_message) + self.to_write.put_nowait(None) + # Make sure all error messages are written before closing + yield from writer_task + except asyncio.QueueFull: + pass + unsub_stop() for unsub in self.event_listeners.values(): @@ -348,9 +372,11 @@ class ActiveConnection: return wsock - @asyncio.coroutine def handle_subscribe_events(self, msg): - """Handle subscribe events command.""" + """Handle subscribe events command. + + Async friendly. + """ msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) @asyncio.coroutine @@ -359,21 +385,17 @@ class ActiveConnection: if event.event_type == EVENT_TIME_CHANGED: return - try: - yield from self.send_message(event_message(msg['id'], event)) - except RuntimeError: - # Socket has been closed. - pass + self.to_write.put_nowait(event_message(msg['id'], event)) self.event_listeners[msg['id']] = self.hass.bus.async_listen( msg['event_type'], forward_events) - return self.send_message(result_message(msg['id'])) + self.to_write.put_nowait(result_message(msg['id'])) def handle_unsubscribe_events(self, msg): """Handle unsubscribe events command. - Returns a coroutine object. + Async friendly. """ msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) @@ -381,13 +403,12 @@ class ActiveConnection: if subscription in self.event_listeners: self.event_listeners.pop(subscription)() - return self.send_message(result_message(msg['id'])) + self.to_write.put_nowait(result_message(msg['id'])) else: - return self.send_message(error_message( + self.to_write.put_nowait(error_message( msg['id'], ERR_NOT_FOUND, 'Subscription not found.')) - @asyncio.coroutine def handle_call_service(self, msg): """Handle call service command. @@ -400,57 +421,53 @@ class ActiveConnection: """Call a service and fire complete message.""" yield from self.hass.services.async_call( msg['domain'], msg['service'], msg['service_data'], True) - try: - yield from self.send_message(result_message(msg['id'])) - except RuntimeError: - # Socket has been closed. - pass + self.to_write.put_nowait(result_message(msg['id'])) self.hass.async_add_job(call_service_helper(msg)) def handle_get_states(self, msg): """Handle get states command. - Returns a coroutine object. + Async friendly. """ msg = GET_STATES_MESSAGE_SCHEMA(msg) - return self.send_message(result_message( + self.to_write.put_nowait(result_message( msg['id'], self.hass.states.async_all())) def handle_get_services(self, msg): """Handle get services command. - Returns a coroutine object. + Async friendly. """ msg = GET_SERVICES_MESSAGE_SCHEMA(msg) - return self.send_message(result_message( + self.to_write.put_nowait(result_message( msg['id'], self.hass.services.async_services())) def handle_get_config(self, msg): """Handle get config command. - Returns a coroutine object. + Async friendly. """ msg = GET_CONFIG_MESSAGE_SCHEMA(msg) - return self.send_message(result_message( + self.to_write.put_nowait(result_message( msg['id'], self.hass.config.as_dict())) def handle_get_panels(self, msg): """Handle get panels command. - Returns a coroutine object. + Async friendly. """ msg = GET_PANELS_MESSAGE_SCHEMA(msg) - return self.send_message(result_message( + self.to_write.put_nowait(result_message( msg['id'], self.hass.data[frontend.DATA_PANELS])) def handle_ping(self, msg): """Handle ping command. - Returns a coroutine object. + Async friendly. """ - return self.send_message(pong_message(msg['id'])) + self.to_write.put_nowait(pong_message(msg['id']))