Protect waiting for event loop from within event loop (#3658)
* Protect waiting for event loop from within event loop * Faster fetching of loop attribute for ident check
This commit is contained in:
parent
e455daa61d
commit
abb8bcb6d9
4 changed files with 93 additions and 0 deletions
|
@ -199,6 +199,8 @@ class HomeAssistant(object):
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
self.loop._thread_ident = threading.get_ident()
|
||||
async_create_timer(self)
|
||||
async_monitor_worker_pool(self)
|
||||
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Asyncio backports for Python 3.4.3 compatibility."""
|
||||
import concurrent.futures
|
||||
import threading
|
||||
from asyncio import coroutines
|
||||
from asyncio.futures import Future
|
||||
|
||||
|
@ -97,6 +98,10 @@ def run_coroutine_threadsafe(coro, loop):
|
|||
|
||||
Return a concurrent.futures.Future to access the result.
|
||||
"""
|
||||
ident = loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
raise RuntimeError('Cannot be called from within the event loop')
|
||||
|
||||
if not coroutines.iscoroutine(coro):
|
||||
raise TypeError('A coroutine object is required')
|
||||
future = concurrent.futures.Future()
|
||||
|
@ -122,6 +127,10 @@ def fire_coroutine_threadsafe(coro, loop):
|
|||
is intended for fire-and-forget use. This reduces the
|
||||
work involved to fire the function on the loop.
|
||||
"""
|
||||
ident = loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
raise RuntimeError('Cannot be called from within the event loop')
|
||||
|
||||
if not coroutines.iscoroutine(coro):
|
||||
raise TypeError('A coroutine object is required: %s' % coro)
|
||||
|
||||
|
@ -139,6 +148,10 @@ def run_callback_threadsafe(loop, callback, *args):
|
|||
|
||||
Return a concurrent.futures.Future to access the result.
|
||||
"""
|
||||
ident = loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
raise RuntimeError('Cannot be called from within the event loop')
|
||||
|
||||
future = concurrent.futures.Future()
|
||||
|
||||
def run_callback():
|
||||
|
|
|
@ -58,6 +58,7 @@ def get_test_home_assistant(num_threads=None):
|
|||
stop_event = threading.Event()
|
||||
|
||||
def run_loop():
|
||||
loop._thread_ident = threading.get_ident()
|
||||
loop.run_forever()
|
||||
loop.close()
|
||||
stop_event.set()
|
||||
|
|
|
@ -1,10 +1,87 @@
|
|||
"""Tests for async util methods from Python source."""
|
||||
import asyncio
|
||||
from asyncio import test_utils
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.util import async as hasync
|
||||
|
||||
|
||||
@patch('asyncio.coroutines.iscoroutine', return_value=True)
|
||||
@patch('concurrent.futures.Future')
|
||||
@patch('threading.get_ident')
|
||||
def test_run_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
|
||||
"""Testing calling run_coroutine_threadsafe from inside an event loop."""
|
||||
coro = MagicMock()
|
||||
loop = MagicMock()
|
||||
|
||||
loop._thread_ident = None
|
||||
mock_ident.return_value = 5
|
||||
hasync.run_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 5
|
||||
mock_ident.return_value = 5
|
||||
with pytest.raises(RuntimeError):
|
||||
hasync.run_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 1
|
||||
mock_ident.return_value = 5
|
||||
hasync.run_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 2
|
||||
|
||||
|
||||
@patch('asyncio.coroutines.iscoroutine', return_value=True)
|
||||
@patch('concurrent.futures.Future')
|
||||
@patch('threading.get_ident')
|
||||
def test_fire_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
|
||||
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
|
||||
coro = MagicMock()
|
||||
loop = MagicMock()
|
||||
|
||||
loop._thread_ident = None
|
||||
mock_ident.return_value = 5
|
||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 5
|
||||
mock_ident.return_value = 5
|
||||
with pytest.raises(RuntimeError):
|
||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 1
|
||||
mock_ident.return_value = 5
|
||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 2
|
||||
|
||||
|
||||
@patch('concurrent.futures.Future')
|
||||
@patch('threading.get_ident')
|
||||
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _):
|
||||
"""Testing calling run_callback_threadsafe from inside an event loop."""
|
||||
callback = MagicMock()
|
||||
loop = MagicMock()
|
||||
|
||||
loop._thread_ident = None
|
||||
mock_ident.return_value = 5
|
||||
hasync.run_callback_threadsafe(loop, callback)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 5
|
||||
mock_ident.return_value = 5
|
||||
with pytest.raises(RuntimeError):
|
||||
hasync.run_callback_threadsafe(loop, callback)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
||||
|
||||
loop._thread_ident = 1
|
||||
mock_ident.return_value = 5
|
||||
hasync.run_callback_threadsafe(loop, callback)
|
||||
assert len(loop.call_soon_threadsafe.mock_calls) == 2
|
||||
|
||||
|
||||
class RunCoroutineThreadsafeTests(test_utils.TestCase):
|
||||
"""Test case for asyncio.run_coroutine_threadsafe."""
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue