Add support for eager tasks (#111425)
* Add support for eager tasks python 3.12 supports eager tasks reading: https://docs.python.org/3/library/asyncio-task.html#eager-task-factory https://github.com/python/cpython/issues/97696 There are lots of places were we are unlikely to suspend, but we might suspend so creating a task makes sense * reduce * revert entity * revert * coverage * coverage * coverage * coverage * fix test
This commit is contained in:
parent
93cc6e0f36
commit
67e356904b
8 changed files with 162 additions and 18 deletions
|
@ -45,6 +45,7 @@ def async_response(
|
||||||
hass.async_create_background_task(
|
hass.async_create_background_task(
|
||||||
_handle_async_response(func, hass, connection, msg),
|
_handle_async_response(func, hass, connection, msg),
|
||||||
task_name,
|
task_name,
|
||||||
|
eager_start=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return schedule_handler
|
return schedule_handler
|
||||||
|
|
|
@ -915,6 +915,7 @@ class ConfigEntry:
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
target: Coroutine[Any, Any, _R],
|
target: Coroutine[Any, Any, _R],
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
|
eager_start: bool = False,
|
||||||
) -> asyncio.Task[_R]:
|
) -> asyncio.Task[_R]:
|
||||||
"""Create a task from within the event loop.
|
"""Create a task from within the event loop.
|
||||||
|
|
||||||
|
@ -923,7 +924,7 @@ class ConfigEntry:
|
||||||
target: target to call.
|
target: target to call.
|
||||||
"""
|
"""
|
||||||
task = hass.async_create_task(
|
task = hass.async_create_task(
|
||||||
target, f"{name} {self.title} {self.domain} {self.entry_id}"
|
target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start
|
||||||
)
|
)
|
||||||
self._tasks.add(task)
|
self._tasks.add(task)
|
||||||
task.add_done_callback(self._tasks.remove)
|
task.add_done_callback(self._tasks.remove)
|
||||||
|
@ -932,7 +933,11 @@ class ConfigEntry:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_background_task(
|
def async_create_background_task(
|
||||||
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
target: Coroutine[Any, Any, _R],
|
||||||
|
name: str,
|
||||||
|
eager_start: bool = False,
|
||||||
) -> asyncio.Task[_R]:
|
) -> asyncio.Task[_R]:
|
||||||
"""Create a background task tied to the config entry lifecycle.
|
"""Create a background task tied to the config entry lifecycle.
|
||||||
|
|
||||||
|
@ -940,7 +945,7 @@ class ConfigEntry:
|
||||||
|
|
||||||
target: target to call.
|
target: target to call.
|
||||||
"""
|
"""
|
||||||
task = hass.async_create_background_task(target, name)
|
task = hass.async_create_background_task(target, name, eager_start)
|
||||||
self._background_tasks.add(task)
|
self._background_tasks.add(task)
|
||||||
task.add_done_callback(self._background_tasks.remove)
|
task.add_done_callback(self._background_tasks.remove)
|
||||||
return task
|
return task
|
||||||
|
|
|
@ -91,6 +91,7 @@ from .helpers.json import json_bytes, json_fragment
|
||||||
from .util import dt as dt_util, location
|
from .util import dt as dt_util, location
|
||||||
from .util.async_ import (
|
from .util.async_ import (
|
||||||
cancelling,
|
cancelling,
|
||||||
|
create_eager_task,
|
||||||
run_callback_threadsafe,
|
run_callback_threadsafe,
|
||||||
shutdown_run_callback_threadsafe,
|
shutdown_run_callback_threadsafe,
|
||||||
)
|
)
|
||||||
|
@ -622,7 +623,10 @@ class HomeAssistant:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_task(
|
def async_create_task(
|
||||||
self, target: Coroutine[Any, Any, _R], name: str | None = None
|
self,
|
||||||
|
target: Coroutine[Any, Any, _R],
|
||||||
|
name: str | None = None,
|
||||||
|
eager_start: bool = False,
|
||||||
) -> asyncio.Task[_R]:
|
) -> asyncio.Task[_R]:
|
||||||
"""Create a task from within the event loop.
|
"""Create a task from within the event loop.
|
||||||
|
|
||||||
|
@ -631,16 +635,17 @@ class HomeAssistant:
|
||||||
|
|
||||||
target: target to call.
|
target: target to call.
|
||||||
"""
|
"""
|
||||||
task = self.loop.create_task(target, name=name)
|
if eager_start:
|
||||||
|
task = create_eager_task(target, name=name, loop=self.loop)
|
||||||
|
else:
|
||||||
|
task = self.loop.create_task(target, name=name)
|
||||||
self._tasks.add(task)
|
self._tasks.add(task)
|
||||||
task.add_done_callback(self._tasks.remove)
|
task.add_done_callback(self._tasks.remove)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_background_task(
|
def async_create_background_task(
|
||||||
self,
|
self, target: Coroutine[Any, Any, _R], name: str, eager_start: bool = False
|
||||||
target: Coroutine[Any, Any, _R],
|
|
||||||
name: str,
|
|
||||||
) -> asyncio.Task[_R]:
|
) -> asyncio.Task[_R]:
|
||||||
"""Create a task from within the event loop.
|
"""Create a task from within the event loop.
|
||||||
|
|
||||||
|
@ -650,7 +655,10 @@ class HomeAssistant:
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
task = self.loop.create_task(target, name=name)
|
if eager_start:
|
||||||
|
task = create_eager_task(target, name=name, loop=self.loop)
|
||||||
|
else:
|
||||||
|
task = self.loop.create_task(target, name=name)
|
||||||
self._background_tasks.add(task)
|
self._background_tasks.add(task)
|
||||||
task.add_done_callback(self._background_tasks.remove)
|
task.add_done_callback(self._background_tasks.remove)
|
||||||
return task
|
return task
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
"""Asyncio utilities."""
|
"""Asyncio utilities."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from asyncio import Future, Semaphore, gather, get_running_loop
|
from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop
|
||||||
from asyncio.events import AbstractEventLoop
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from traceback import extract_stack
|
from traceback import extract_stack
|
||||||
from typing import Any, ParamSpec, TypeVar, TypeVarTuple
|
from typing import Any, ParamSpec, TypeVar, TypeVarTuple
|
||||||
|
@ -23,6 +23,36 @@ _R = TypeVar("_R")
|
||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
_Ts = TypeVarTuple("_Ts")
|
_Ts = TypeVarTuple("_Ts")
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12, 0):
|
||||||
|
|
||||||
|
def create_eager_task(
|
||||||
|
coro: Awaitable[_T],
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
loop: AbstractEventLoop | None = None,
|
||||||
|
) -> Task[_T]:
|
||||||
|
"""Create a task from a coroutine and schedule it to run immediately."""
|
||||||
|
return Task(
|
||||||
|
coro,
|
||||||
|
loop=loop or get_running_loop(),
|
||||||
|
name=name,
|
||||||
|
eager_start=True, # type: ignore[call-arg]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def create_eager_task(
|
||||||
|
coro: Awaitable[_T],
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
loop: AbstractEventLoop | None = None,
|
||||||
|
) -> Task[_T]:
|
||||||
|
"""Create a task from a coroutine and schedule it to run immediately."""
|
||||||
|
return Task(
|
||||||
|
coro,
|
||||||
|
loop=loop or get_running_loop(),
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cancelling(task: Future[Any]) -> bool:
|
def cancelling(task: Future[Any]) -> bool:
|
||||||
"""Return True if task is cancelling."""
|
"""Return True if task is cancelling."""
|
||||||
|
|
|
@ -260,14 +260,14 @@ async def async_test_home_assistant(
|
||||||
|
|
||||||
return orig_async_add_executor_job(target, *args)
|
return orig_async_add_executor_job(target, *args)
|
||||||
|
|
||||||
def async_create_task(coroutine, name=None):
|
def async_create_task(coroutine, name=None, eager_start=False):
|
||||||
"""Create task."""
|
"""Create task."""
|
||||||
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
|
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
|
||||||
fut = asyncio.Future()
|
fut = asyncio.Future()
|
||||||
fut.set_result(None)
|
fut.set_result(None)
|
||||||
return fut
|
return fut
|
||||||
|
|
||||||
return orig_async_create_task(coroutine, name)
|
return orig_async_create_task(coroutine, name, eager_start)
|
||||||
|
|
||||||
hass.async_add_job = async_add_job
|
hass.async_add_job = async_add_job
|
||||||
hass.async_add_executor_job = async_add_executor_job
|
hass.async_add_executor_job = async_add_executor_job
|
||||||
|
|
|
@ -4228,11 +4228,16 @@ async def test_task_tracking(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
entry.async_on_unload(test_unload)
|
entry.async_on_unload(test_unload)
|
||||||
entry.async_create_task(hass, test_task())
|
entry.async_create_task(hass, test_task())
|
||||||
entry.async_create_background_task(hass, test_task(), "background-task-name")
|
entry.async_create_background_task(
|
||||||
|
hass, test_task(), "background-task-name", eager_start=True
|
||||||
|
)
|
||||||
|
entry.async_create_background_task(
|
||||||
|
hass, test_task(), "background-task-name", eager_start=False
|
||||||
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
hass.loop.call_soon(event.set)
|
hass.loop.call_soon(event.set)
|
||||||
await entry._async_process_on_unload(hass)
|
await entry._async_process_on_unload(hass)
|
||||||
assert results == ["on_unload", "background", "normal"]
|
assert results == ["on_unload", "background", "background", "normal"]
|
||||||
|
|
||||||
|
|
||||||
async def test_preview_supported(
|
async def test_preview_supported(
|
||||||
|
|
|
@ -8,6 +8,7 @@ import functools
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
@ -161,7 +162,9 @@ def test_async_add_job_add_hass_threaded_job_to_pool() -> None:
|
||||||
assert len(hass.loop.run_in_executor.mock_calls) == 2
|
assert len(hass.loop.run_in_executor.mock_calls) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_async_create_task_schedule_coroutine(event_loop) -> None:
|
def test_async_create_task_schedule_coroutine(
|
||||||
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
|
) -> None:
|
||||||
"""Test that we schedule coroutines and add jobs to the job pool."""
|
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||||
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
||||||
|
|
||||||
|
@ -174,6 +177,44 @@ def test_async_create_task_schedule_coroutine(event_loop) -> None:
|
||||||
assert len(hass.add_job.mock_calls) == 0
|
assert len(hass.add_job.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 12), reason="eager_start is only supported for Python 3.12"
|
||||||
|
)
|
||||||
|
def test_async_create_task_eager_start_schedule_coroutine(
|
||||||
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
|
) -> None:
|
||||||
|
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||||
|
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
||||||
|
|
||||||
|
async def job():
|
||||||
|
pass
|
||||||
|
|
||||||
|
ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
|
||||||
|
# Should create the task directly since 3.12 supports eager_start
|
||||||
|
assert len(hass.loop.create_task.mock_calls) == 0
|
||||||
|
assert len(hass.add_job.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info >= (3, 12), reason="eager_start is not supported on < 3.12"
|
||||||
|
)
|
||||||
|
def test_async_create_task_eager_start_fallback_schedule_coroutine(
|
||||||
|
event_loop: asyncio.AbstractEventLoop,
|
||||||
|
) -> None:
|
||||||
|
"""Test that we schedule coroutines and add jobs to the job pool."""
|
||||||
|
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
||||||
|
|
||||||
|
async def job():
|
||||||
|
pass
|
||||||
|
|
||||||
|
ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
|
||||||
|
assert len(hass.loop.call_soon.mock_calls) == 1
|
||||||
|
# Should fallback to loop.create_task since 3.11 does
|
||||||
|
# not support eager_start
|
||||||
|
assert len(hass.loop.create_task.mock_calls) == 0
|
||||||
|
assert len(hass.add_job.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_async_create_task_schedule_coroutine_with_name(event_loop) -> None:
|
def test_async_create_task_schedule_coroutine_with_name(event_loop) -> None:
|
||||||
"""Test that we schedule coroutines and add jobs to the job pool with a name."""
|
"""Test that we schedule coroutines and add jobs to the job pool with a name."""
|
||||||
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
hass = MagicMock(loop=MagicMock(wraps=event_loop))
|
||||||
|
@ -2598,7 +2639,8 @@ async def test_state_changed_events_to_not_leak_contexts(hass: HomeAssistant) ->
|
||||||
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
||||||
|
|
||||||
|
|
||||||
async def test_background_task(hass: HomeAssistant) -> None:
|
@pytest.mark.parametrize("eager_start", (True, False))
|
||||||
|
async def test_background_task(hass: HomeAssistant, eager_start: bool) -> None:
|
||||||
"""Test background tasks being quit."""
|
"""Test background tasks being quit."""
|
||||||
result = asyncio.Future()
|
result = asyncio.Future()
|
||||||
|
|
||||||
|
@ -2609,7 +2651,9 @@ async def test_background_task(hass: HomeAssistant) -> None:
|
||||||
result.set_result(hass.state)
|
result.set_result(hass.state)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
task = hass.async_create_background_task(test_task(), "happy task")
|
task = hass.async_create_background_task(
|
||||||
|
test_task(), "happy task", eager_start=eager_start
|
||||||
|
)
|
||||||
assert "happy task" in str(task)
|
assert "happy task" in str(task)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await hass.async_stop()
|
await hass.async_stop()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Tests for async util methods from Python source."""
|
"""Tests for async util methods from Python source."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
@ -246,3 +247,53 @@ async def test_callback_is_always_scheduled(hass: HomeAssistant) -> None:
|
||||||
hasync.run_callback_threadsafe(hass.loop, callback)
|
hasync.run_callback_threadsafe(hass.loop, callback)
|
||||||
|
|
||||||
mock_call_soon_threadsafe.assert_called_once()
|
mock_call_soon_threadsafe.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12+")
|
||||||
|
async def test_create_eager_task_312(hass: HomeAssistant) -> None:
|
||||||
|
"""Test create_eager_task schedules a task eagerly in the event loop.
|
||||||
|
|
||||||
|
For Python 3.12+, the task is scheduled eagerly in the event loop.
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def _normal_task():
|
||||||
|
events.append("normal")
|
||||||
|
|
||||||
|
async def _eager_task():
|
||||||
|
events.append("eager")
|
||||||
|
|
||||||
|
task1 = hasync.create_eager_task(_eager_task())
|
||||||
|
task2 = asyncio.create_task(_normal_task())
|
||||||
|
|
||||||
|
assert events == ["eager"]
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert events == ["eager", "normal"]
|
||||||
|
await task1
|
||||||
|
await task2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Test requires < Python 3.12")
|
||||||
|
async def test_create_eager_task_pre_312(hass: HomeAssistant) -> None:
|
||||||
|
"""Test create_eager_task schedules a task in the event loop.
|
||||||
|
|
||||||
|
For older python versions, the task is scheduled normally.
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def _normal_task():
|
||||||
|
events.append("normal")
|
||||||
|
|
||||||
|
async def _eager_task():
|
||||||
|
events.append("eager")
|
||||||
|
|
||||||
|
task1 = hasync.create_eager_task(_eager_task())
|
||||||
|
task2 = asyncio.create_task(_normal_task())
|
||||||
|
|
||||||
|
assert events == []
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert events == ["eager", "normal"]
|
||||||
|
await task1
|
||||||
|
await task2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue