Replace fire_coroutine_threadsafe with asyncio.run_coroutine_threadsafe (#88572)
fire_coroutine_threadsafe did not hold a reference to the asyncio task which meant the task had the risk of being prematurely garbage collected
This commit is contained in:
parent
e54eb7e2c8
commit
5bc0636905
3 changed files with 14 additions and 70 deletions
|
@ -14,6 +14,7 @@ from collections.abc import (
|
||||||
Iterable,
|
Iterable,
|
||||||
Mapping,
|
Mapping,
|
||||||
)
|
)
|
||||||
|
import concurrent.futures
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -79,11 +80,7 @@ from .exceptions import (
|
||||||
)
|
)
|
||||||
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
|
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
|
||||||
from .util import dt as dt_util, location, ulid as ulid_util
|
from .util import dt as dt_util, location, ulid as ulid_util
|
||||||
from .util.async_ import (
|
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
|
||||||
fire_coroutine_threadsafe,
|
|
||||||
run_callback_threadsafe,
|
|
||||||
shutdown_run_callback_threadsafe,
|
|
||||||
)
|
|
||||||
from .util.read_only_dict import ReadOnlyDict
|
from .util.read_only_dict import ReadOnlyDict
|
||||||
from .util.timeout import TimeoutManager
|
from .util.timeout import TimeoutManager
|
||||||
from .util.unit_system import (
|
from .util.unit_system import (
|
||||||
|
@ -294,6 +291,7 @@ class HomeAssistant:
|
||||||
self._stopped: asyncio.Event | None = None
|
self._stopped: asyncio.Event | None = None
|
||||||
# Timeout handler for Core/Helper namespace
|
# Timeout handler for Core/Helper namespace
|
||||||
self.timeout: TimeoutManager = TimeoutManager()
|
self.timeout: TimeoutManager = TimeoutManager()
|
||||||
|
self._stop_future: concurrent.futures.Future[None] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
|
@ -312,12 +310,14 @@ class HomeAssistant:
|
||||||
For regular use, use "await hass.run()".
|
For regular use, use "await hass.run()".
|
||||||
"""
|
"""
|
||||||
# Register the async start
|
# Register the async start
|
||||||
fire_coroutine_threadsafe(self.async_start(), self.loop)
|
_future = asyncio.run_coroutine_threadsafe(self.async_start(), self.loop)
|
||||||
|
|
||||||
# Run forever
|
# Run forever
|
||||||
# Block until stopped
|
# Block until stopped
|
||||||
_LOGGER.info("Starting Home Assistant core loop")
|
_LOGGER.info("Starting Home Assistant core loop")
|
||||||
self.loop.run_forever()
|
self.loop.run_forever()
|
||||||
|
# The future is never retrieved but we still hold a reference to it
|
||||||
|
# to prevent the task from being garbage collected prematurely.
|
||||||
|
del _future
|
||||||
return self.exit_code
|
return self.exit_code
|
||||||
|
|
||||||
async def async_run(self, *, attach_signals: bool = True) -> int:
|
async def async_run(self, *, attach_signals: bool = True) -> int:
|
||||||
|
@ -682,7 +682,11 @@ class HomeAssistant:
|
||||||
"""Stop Home Assistant and shuts down all threads."""
|
"""Stop Home Assistant and shuts down all threads."""
|
||||||
if self.state == CoreState.not_running: # just ignore
|
if self.state == CoreState.not_running: # just ignore
|
||||||
return
|
return
|
||||||
fire_coroutine_threadsafe(self.async_stop(), self.loop)
|
# The future is never retrieved, and we only hold a reference
|
||||||
|
# to it to prevent it from being garbage collected.
|
||||||
|
self._stop_future = asyncio.run_coroutine_threadsafe(
|
||||||
|
self.async_stop(), self.loop
|
||||||
|
)
|
||||||
|
|
||||||
async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
|
async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
|
||||||
"""Stop Home Assistant and shuts down all threads.
|
"""Stop Home Assistant and shuts down all threads.
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
"""Asyncio utilities."""
|
"""Asyncio utilities."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from asyncio import Semaphore, coroutines, ensure_future, gather, get_running_loop
|
from asyncio import Semaphore, gather, get_running_loop
|
||||||
from asyncio.events import AbstractEventLoop
|
from asyncio.events import AbstractEventLoop
|
||||||
from collections.abc import Awaitable, Callable, Coroutine
|
from collections.abc import Awaitable, Callable
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
@ -20,29 +20,6 @@ _R = TypeVar("_R")
|
||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
|
|
||||||
def fire_coroutine_threadsafe(
|
|
||||||
coro: Coroutine[Any, Any, Any], loop: AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
"""Submit a coroutine object to a given event loop.
|
|
||||||
|
|
||||||
This method does not provide a way to retrieve the result and
|
|
||||||
is intended for fire-and-forget use. This reduces the
|
|
||||||
work involved to fire the function on the loop.
|
|
||||||
"""
|
|
||||||
ident = loop.__dict__.get("_thread_ident")
|
|
||||||
if ident is not None and ident == threading.get_ident():
|
|
||||||
raise RuntimeError("Cannot be called from within the event loop")
|
|
||||||
|
|
||||||
if not coroutines.iscoroutine(coro):
|
|
||||||
raise TypeError(f"A coroutine object is required: {coro}")
|
|
||||||
|
|
||||||
def callback() -> None:
|
|
||||||
"""Handle the firing of a coroutine."""
|
|
||||||
ensure_future(coro, loop=loop)
|
|
||||||
|
|
||||||
loop.call_soon_threadsafe(callback)
|
|
||||||
|
|
||||||
|
|
||||||
def run_callback_threadsafe(
|
def run_callback_threadsafe(
|
||||||
loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any
|
loop: AbstractEventLoop, callback: Callable[..., _T], *args: Any
|
||||||
) -> concurrent.futures.Future[_T]:
|
) -> concurrent.futures.Future[_T]:
|
||||||
|
|
|
@ -10,43 +10,6 @@ from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.util import async_ as hasync
|
from homeassistant.util import async_ as hasync
|
||||||
|
|
||||||
|
|
||||||
@patch("asyncio.coroutines.iscoroutine")
|
|
||||||
@patch("concurrent.futures.Future")
|
|
||||||
@patch("threading.get_ident")
|
|
||||||
def test_fire_coroutine_threadsafe_from_inside_event_loop(
|
|
||||||
mock_ident, _, mock_iscoroutine
|
|
||||||
) -> None:
|
|
||||||
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
|
|
||||||
coro = MagicMock()
|
|
||||||
loop = MagicMock()
|
|
||||||
|
|
||||||
loop._thread_ident = None
|
|
||||||
mock_ident.return_value = 5
|
|
||||||
mock_iscoroutine.return_value = True
|
|
||||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
|
||||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
|
||||||
|
|
||||||
loop._thread_ident = 5
|
|
||||||
mock_ident.return_value = 5
|
|
||||||
mock_iscoroutine.return_value = True
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
|
||||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
|
||||||
|
|
||||||
loop._thread_ident = 1
|
|
||||||
mock_ident.return_value = 5
|
|
||||||
mock_iscoroutine.return_value = False
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
|
||||||
assert len(loop.call_soon_threadsafe.mock_calls) == 1
|
|
||||||
|
|
||||||
loop._thread_ident = 1
|
|
||||||
mock_ident.return_value = 5
|
|
||||||
mock_iscoroutine.return_value = True
|
|
||||||
hasync.fire_coroutine_threadsafe(coro, loop)
|
|
||||||
assert len(loop.call_soon_threadsafe.mock_calls) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@patch("concurrent.futures.Future")
|
@patch("concurrent.futures.Future")
|
||||||
@patch("threading.get_ident")
|
@patch("threading.get_ident")
|
||||||
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None:
|
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue