diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index a148ed2be8d..b4c72d497cd 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -45,6 +45,7 @@ def async_response( hass.async_create_background_task( _handle_async_response(func, hass, connection, msg), task_name, + eager_start=True, ) return schedule_handler diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d7994305712..f2753322bc0 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -915,6 +915,7 @@ class ConfigEntry: hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str | None = None, + eager_start: bool = False, ) -> asyncio.Task[_R]: """Create a task from within the event loop. @@ -923,7 +924,7 @@ class ConfigEntry: target: target to call. """ task = hass.async_create_task( - target, f"{name} {self.title} {self.domain} {self.entry_id}" + target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start ) self._tasks.add(task) task.add_done_callback(self._tasks.remove) @@ -932,7 +933,11 @@ class ConfigEntry: @callback def async_create_background_task( - self, hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str + self, + hass: HomeAssistant, + target: Coroutine[Any, Any, _R], + name: str, + eager_start: bool = False, ) -> asyncio.Task[_R]: """Create a background task tied to the config entry lifecycle. @@ -940,7 +945,7 @@ class ConfigEntry: target: target to call. """ - task = hass.async_create_background_task(target, name) + task = hass.async_create_background_task(target, name, eager_start) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.remove) return task diff --git a/homeassistant/core.py b/homeassistant/core.py index c49777a67fb..47c05442eed 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -91,6 +91,7 @@ from .helpers.json import json_bytes, json_fragment from .util import dt as dt_util, location from .util.async_ import ( cancelling, + create_eager_task, run_callback_threadsafe, shutdown_run_callback_threadsafe, ) @@ -622,7 +623,10 @@ class HomeAssistant: @callback def async_create_task( - self, target: Coroutine[Any, Any, _R], name: str | None = None + self, + target: Coroutine[Any, Any, _R], + name: str | None = None, + eager_start: bool = False, ) -> asyncio.Task[_R]: """Create a task from within the event loop. @@ -631,16 +635,17 @@ class HomeAssistant: target: target to call. """ - task = self.loop.create_task(target, name=name) + if eager_start: + task = create_eager_task(target, name=name, loop=self.loop) + else: + task = self.loop.create_task(target, name=name) self._tasks.add(task) task.add_done_callback(self._tasks.remove) return task @callback def async_create_background_task( - self, - target: Coroutine[Any, Any, _R], - name: str, + self, target: Coroutine[Any, Any, _R], name: str, eager_start: bool = False ) -> asyncio.Task[_R]: """Create a task from within the event loop. @@ -650,7 +655,10 @@ class HomeAssistant: This method must be run in the event loop. """ - task = self.loop.create_task(target, name=name) + if eager_start: + task = create_eager_task(target, name=name, loop=self.loop) + else: + task = self.loop.create_task(target, name=name) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.remove) return task diff --git a/homeassistant/util/async_.py b/homeassistant/util/async_.py index 1b8496fe327..2fb44b0623b 100644 --- a/homeassistant/util/async_.py +++ b/homeassistant/util/async_.py @@ -1,13 +1,13 @@ """Asyncio utilities.""" from __future__ import annotations -from asyncio import Future, Semaphore, gather, get_running_loop -from asyncio.events import AbstractEventLoop +from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop from collections.abc import Awaitable, Callable import concurrent.futures from contextlib import suppress import functools import logging +import sys import threading from traceback import extract_stack from typing import Any, ParamSpec, TypeVar, TypeVarTuple @@ -23,6 +23,36 @@ _R = TypeVar("_R") _P = ParamSpec("_P") _Ts = TypeVarTuple("_Ts") +if sys.version_info >= (3, 12, 0): + + def create_eager_task( + coro: Awaitable[_T], + *, + name: str | None = None, + loop: AbstractEventLoop | None = None, + ) -> Task[_T]: + """Create a task from a coroutine and schedule it to run immediately.""" + return Task( + coro, + loop=loop or get_running_loop(), + name=name, + eager_start=True, # type: ignore[call-arg] + ) +else: + + def create_eager_task( + coro: Awaitable[_T], + *, + name: str | None = None, + loop: AbstractEventLoop | None = None, + ) -> Task[_T]: + """Create a task from a coroutine and schedule it to run immediately.""" + return Task( + coro, + loop=loop or get_running_loop(), + name=name, + ) + def cancelling(task: Future[Any]) -> bool: """Return True if task is cancelling.""" diff --git a/tests/common.py b/tests/common.py index dca847ff71c..14cacdf5d68 100644 --- a/tests/common.py +++ b/tests/common.py @@ -260,14 +260,14 @@ async def async_test_home_assistant( return orig_async_add_executor_job(target, *args) - def async_create_task(coroutine, name=None): + def async_create_task(coroutine, name=None, eager_start=False): """Create task.""" if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock): fut = asyncio.Future() fut.set_result(None) return fut - return orig_async_create_task(coroutine, name) + return orig_async_create_task(coroutine, name, eager_start) hass.async_add_job = async_add_job hass.async_add_executor_job = async_add_executor_job diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index bb38437ee09..188340f8ade 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -4228,11 +4228,16 @@ async def test_task_tracking(hass: HomeAssistant) -> None: entry.async_on_unload(test_unload) entry.async_create_task(hass, test_task()) - entry.async_create_background_task(hass, test_task(), "background-task-name") + entry.async_create_background_task( + hass, test_task(), "background-task-name", eager_start=True + ) + entry.async_create_background_task( + hass, test_task(), "background-task-name", eager_start=False + ) await asyncio.sleep(0) hass.loop.call_soon(event.set) await entry._async_process_on_unload(hass) - assert results == ["on_unload", "background", "normal"] + assert results == ["on_unload", "background", "background", "normal"] async def test_preview_supported( diff --git a/tests/test_core.py b/tests/test_core.py index 466d4578c7e..e9680202d5f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -8,6 +8,7 @@ import functools import gc import logging import os +import sys from tempfile import TemporaryDirectory import threading import time @@ -161,7 +162,9 @@ def test_async_add_job_add_hass_threaded_job_to_pool() -> None: assert len(hass.loop.run_in_executor.mock_calls) == 2 -def test_async_create_task_schedule_coroutine(event_loop) -> None: +def test_async_create_task_schedule_coroutine( + event_loop: asyncio.AbstractEventLoop, +) -> None: """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock(loop=MagicMock(wraps=event_loop)) @@ -174,6 +177,44 @@ def test_async_create_task_schedule_coroutine(event_loop) -> None: assert len(hass.add_job.mock_calls) == 0 +@pytest.mark.skipif( + sys.version_info < (3, 12), reason="eager_start is only supported for Python 3.12" +) +def test_async_create_task_eager_start_schedule_coroutine( + event_loop: asyncio.AbstractEventLoop, +) -> None: + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock(loop=MagicMock(wraps=event_loop)) + + async def job(): + pass + + ha.HomeAssistant.async_create_task(hass, job(), eager_start=True) + # Should create the task directly since 3.12 supports eager_start + assert len(hass.loop.create_task.mock_calls) == 0 + assert len(hass.add_job.mock_calls) == 0 + + +@pytest.mark.skipif( + sys.version_info >= (3, 12), reason="eager_start is not supported on < 3.12" +) +def test_async_create_task_eager_start_fallback_schedule_coroutine( + event_loop: asyncio.AbstractEventLoop, +) -> None: + """Test that we schedule coroutines and add jobs to the job pool.""" + hass = MagicMock(loop=MagicMock(wraps=event_loop)) + + async def job(): + pass + + ha.HomeAssistant.async_create_task(hass, job(), eager_start=True) + assert len(hass.loop.call_soon.mock_calls) == 1 + # Should fallback to loop.create_task since 3.11 does + # not support eager_start + assert len(hass.loop.create_task.mock_calls) == 0 + assert len(hass.add_job.mock_calls) == 0 + + def test_async_create_task_schedule_coroutine_with_name(event_loop) -> None: """Test that we schedule coroutines and add jobs to the job pool with a name.""" hass = MagicMock(loop=MagicMock(wraps=event_loop)) @@ -2598,7 +2639,8 @@ async def test_state_changed_events_to_not_leak_contexts(hass: HomeAssistant) -> assert len(_get_by_type("homeassistant.core.Context")) == init_count -async def test_background_task(hass: HomeAssistant) -> None: +@pytest.mark.parametrize("eager_start", (True, False)) +async def test_background_task(hass: HomeAssistant, eager_start: bool) -> None: """Test background tasks being quit.""" result = asyncio.Future() @@ -2609,7 +2651,9 @@ async def test_background_task(hass: HomeAssistant) -> None: result.set_result(hass.state) raise - task = hass.async_create_background_task(test_task(), "happy task") + task = hass.async_create_background_task( + test_task(), "happy task", eager_start=eager_start + ) assert "happy task" in str(task) await asyncio.sleep(0) await hass.async_stop() diff --git a/tests/util/test_async.py b/tests/util/test_async.py index 60f86ee7af4..ad2c9329fb7 100644 --- a/tests/util/test_async.py +++ b/tests/util/test_async.py @@ -1,5 +1,6 @@ """Tests for async util methods from Python source.""" import asyncio +import sys import time from unittest.mock import MagicMock, Mock, patch @@ -246,3 +247,53 @@ async def test_callback_is_always_scheduled(hass: HomeAssistant) -> None: hasync.run_callback_threadsafe(hass.loop, callback) mock_call_soon_threadsafe.assert_called_once() + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+") +async def test_create_eager_task_312(hass: HomeAssistant) -> None: + """Test create_eager_task schedules a task eagerly in the event loop. + + For Python 3.12+, the task is scheduled eagerly in the event loop. + """ + events = [] + + async def _normal_task(): + events.append("normal") + + async def _eager_task(): + events.append("eager") + + task1 = hasync.create_eager_task(_eager_task()) + task2 = asyncio.create_task(_normal_task()) + + assert events == ["eager"] + + await asyncio.sleep(0) + assert events == ["eager", "normal"] + await task1 + await task2 + + +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12") +async def test_create_eager_task_pre_312(hass: HomeAssistant) -> None: + """Test create_eager_task schedules a task in the event loop. + + For older python versions, the task is scheduled normally. + """ + events = [] + + async def _normal_task(): + events.append("normal") + + async def _eager_task(): + events.append("eager") + + task1 = hasync.create_eager_task(_eager_task()) + task2 = asyncio.create_task(_normal_task()) + + assert events == [] + + await asyncio.sleep(0) + assert events == ["eager", "normal"] + await task1 + await task2