Catch concurrent.futures.CancelledError in websocket code. (#12150)

* Catch concurrent.futures.CancelledError in websocket code.

* Added a comment about the use of futures.CancelledError
This commit is contained in:
Phil Elson 2018-02-07 05:30:18 +00:00 committed by Paulus Schoutsen
parent 7e246e4680
commit 5ba02c531e

View file

@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/developers/websocket_api/ https://home-assistant.io/developers/websocket_api/
""" """
import asyncio import asyncio
from concurrent import futures
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
import json import json
@ -120,6 +121,11 @@ BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
TYPE_PING) TYPE_PING)
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
def auth_ok_message(): def auth_ok_message():
"""Return an auth_ok message.""" """Return an auth_ok message."""
@ -231,7 +237,7 @@ class ActiveConnection:
def _writer(self): def _writer(self):
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, asyncio.CancelledError): with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed: while not self.wsock.closed:
message = yield from self.to_write.get() message = yield from self.to_write.get()
if message is None: if message is None:
@ -363,7 +369,7 @@ class ActiveConnection:
self.log_error(msg) self.log_error(msg)
self._writer_task.cancel() self._writer_task.cancel()
except asyncio.CancelledError: except CANCELLATION_ERRORS:
self.debug("Connection cancelled by server") self.debug("Connection cancelled by server")
except asyncio.QueueFull: except asyncio.QueueFull: