Replace fire_coroutine_threadsafe with asyncio.run_coroutine_threadsafe (#88572)

fire_coroutine_threadsafe did not hold a reference to the asyncio
task which meant the task had the risk of being prematurely
garbage collected
This commit is contained in:
J. Nick Koston 2023-02-21 20:16:18 -06:00 committed by GitHub
parent e54eb7e2c8
commit 5bc0636905
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 70 deletions

View file

@ -14,6 +14,7 @@ from collections.abc import (
Iterable, Iterable,
Mapping, Mapping,
) )
import concurrent.futures
from contextlib import suppress from contextlib import suppress
from contextvars import ContextVar from contextvars import ContextVar
import datetime import datetime
@ -79,11 +80,7 @@ from .exceptions import (
) )
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
from .util import dt as dt_util, location, ulid as ulid_util from .util import dt as dt_util, location, ulid as ulid_util
from .util.async_ import ( from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
fire_coroutine_threadsafe,
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
from .util.read_only_dict import ReadOnlyDict from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager from .util.timeout import TimeoutManager
from .util.unit_system import ( from .util.unit_system import (
@ -294,6 +291,7 @@ class HomeAssistant:
self._stopped: asyncio.Event | None = None self._stopped: asyncio.Event | None = None
# Timeout handler for Core/Helper namespace # Timeout handler for Core/Helper namespace
self.timeout: TimeoutManager = TimeoutManager() self.timeout: TimeoutManager = TimeoutManager()
self._stop_future: concurrent.futures.Future[None] | None = None
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@ -312,12 +310,14 @@ class HomeAssistant:
For regular use, use "await hass.run()". For regular use, use "await hass.run()".
""" """
# Register the async start # Register the async start
fire_coroutine_threadsafe(self.async_start(), self.loop) _future = asyncio.run_coroutine_threadsafe(self.async_start(), self.loop)
# Run forever # Run forever
# Block until stopped # Block until stopped
_LOGGER.info("Starting Home Assistant core loop") _LOGGER.info("Starting Home Assistant core loop")
self.loop.run_forever() self.loop.run_forever()
# The future is never retrieved but we still hold a reference to it
# to prevent the task from being garbage collected prematurely.
del _future
return self.exit_code return self.exit_code
async def async_run(self, *, attach_signals: bool = True) -> int: async def async_run(self, *, attach_signals: bool = True) -> int:
@ -682,7 +682,11 @@ class HomeAssistant:
"""Stop Home Assistant and shuts down all threads.""" """Stop Home Assistant and shuts down all threads."""
if self.state == CoreState.not_running: # just ignore if self.state == CoreState.not_running: # just ignore
return return
fire_coroutine_threadsafe(self.async_stop(), self.loop) # The future is never retrieved, and we only hold a reference
# to it to prevent it from being garbage collected.
self._stop_future = asyncio.run_coroutine_threadsafe(
self.async_stop(), self.loop
)
async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None: async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
"""Stop Home Assistant and shuts down all threads. """Stop Home Assistant and shuts down all threads.

View file

@ -1,9 +1,9 @@
"""Asyncio utilities.""" """Asyncio utilities."""
from __future__ import annotations from __future__ import annotations
from asyncio import Semaphore, coroutines, ensure_future, gather, get_running_loop from asyncio import Semaphore, gather, get_running_loop
from asyncio.events import AbstractEventLoop from asyncio.events import AbstractEventLoop
from collections.abc import Awaitable, Callable, Coroutine from collections.abc import Awaitable, Callable
import concurrent.futures import concurrent.futures
import functools import functools
import logging import logging
@ -20,29 +20,6 @@ _R = TypeVar("_R")
_P = ParamSpec("_P") _P = ParamSpec("_P")
def fire_coroutine_threadsafe(
coro: Coroutine[Any, Any, Any], loop: AbstractEventLoop
) -> None:
"""Submit a coroutine object to a given event loop.
This method does not provide a way to retrieve the result and
is intended for fire-and-forget use. This reduces the
work involved to fire the function on the loop.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError("Cannot be called from within the event loop")
if not coroutines.iscoroutine(coro):
raise TypeError(f"A coroutine object is required: {coro}")
def callback() -> None:
"""Handle the firing of a coroutine."""
ensure_future(coro, loop=loop)
loop.call_soon_threadsafe(callback)
def run_callback_threadsafe( def run_callback_threadsafe(
loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any
) -> concurrent.futures.Future[_T]: ) -> concurrent.futures.Future[_T]:

View file

@ -10,43 +10,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.util import async_ as hasync from homeassistant.util import async_ as hasync
@patch("asyncio.coroutines.iscoroutine")
@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_fire_coroutine_threadsafe_from_inside_event_loop(
mock_ident, _, mock_iscoroutine
) -> None:
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
with pytest.raises(RuntimeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
mock_iscoroutine.return_value = False
with pytest.raises(TypeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
mock_iscoroutine.return_value = True
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch("concurrent.futures.Future") @patch("concurrent.futures.Future")
@patch("threading.get_ident") @patch("threading.get_ident")
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None: def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None: