Ensure asyncio blocking checks are undone after tests run (#119542)
* Ensure asyncio blocking checks are undone after tests run
* no reason to ever enable twice
* we are patching objects, make it more generic
* make sure bootstrap unblocks as well
* move disable to tests only
* re-protect
* Update tests/test_block_async_io.py
Co-authored-by: Erik Montnemery <erik@montnemery.com>
* Revert "Update tests/test_block_async_io.py"
This reverts commit 2d46028e21
.
* tweak name
* fixture only
* Update tests/conftest.py
* Update tests/conftest.py
* Apply suggestions from code review
---------
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
669569ca49
commit
d52ce03aa4
4 changed files with 184 additions and 52 deletions
|
@ -1,7 +1,9 @@
|
|||
"""Block blocking calls being done in asyncio."""
|
||||
|
||||
import builtins
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
import glob
|
||||
from http.client import HTTPConnection
|
||||
import importlib
|
||||
|
@ -46,53 +48,131 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class BlockingCall:
|
||||
"""Class to hold information about a blocking call."""
|
||||
|
||||
original_func: Callable
|
||||
object: object
|
||||
function: str
|
||||
check_allowed: Callable[[dict[str, Any]], bool] | None
|
||||
strict: bool
|
||||
strict_core: bool
|
||||
skip_for_tests: bool
|
||||
|
||||
|
||||
_BLOCKING_CALLS: tuple[BlockingCall, ...] = (
|
||||
BlockingCall(
|
||||
original_func=HTTPConnection.putrequest,
|
||||
object=HTTPConnection,
|
||||
function="putrequest",
|
||||
check_allowed=None,
|
||||
strict=True,
|
||||
strict_core=True,
|
||||
skip_for_tests=False,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=time.sleep,
|
||||
object=time,
|
||||
function="sleep",
|
||||
check_allowed=_check_sleep_call_allowed,
|
||||
strict=True,
|
||||
strict_core=True,
|
||||
skip_for_tests=False,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=glob.glob,
|
||||
object=glob,
|
||||
function="glob",
|
||||
check_allowed=None,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=False,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=glob.iglob,
|
||||
object=glob,
|
||||
function="iglob",
|
||||
check_allowed=None,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=False,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=os.walk,
|
||||
object=os,
|
||||
function="walk",
|
||||
check_allowed=None,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=False,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=os.listdir,
|
||||
object=os,
|
||||
function="listdir",
|
||||
check_allowed=None,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=True,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=os.scandir,
|
||||
object=os,
|
||||
function="scandir",
|
||||
check_allowed=None,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=True,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=builtins.open,
|
||||
object=builtins,
|
||||
function="open",
|
||||
check_allowed=_check_file_allowed,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=True,
|
||||
),
|
||||
BlockingCall(
|
||||
original_func=importlib.import_module,
|
||||
object=importlib,
|
||||
function="import_module",
|
||||
check_allowed=_check_import_call_allowed,
|
||||
strict=False,
|
||||
strict_core=False,
|
||||
skip_for_tests=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BlockedCalls:
|
||||
"""Class to track which calls are blocked."""
|
||||
|
||||
calls: set[BlockingCall]
|
||||
|
||||
|
||||
_BLOCKED_CALLS = BlockedCalls(set())
|
||||
|
||||
|
||||
def enable() -> None:
|
||||
"""Enable the detection of blocking calls in the event loop."""
|
||||
calls = _BLOCKED_CALLS.calls
|
||||
if calls:
|
||||
raise RuntimeError("Blocking call detection is already enabled")
|
||||
|
||||
loop_thread_id = threading.get_ident()
|
||||
# Prevent urllib3 and requests doing I/O in event loop
|
||||
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign]
|
||||
HTTPConnection.putrequest, loop_thread_id=loop_thread_id
|
||||
)
|
||||
for blocking_call in _BLOCKING_CALLS:
|
||||
if _IN_TESTS and blocking_call.skip_for_tests:
|
||||
continue
|
||||
|
||||
# Prevent sleeping in event loop.
|
||||
time.sleep = protect_loop(
|
||||
time.sleep,
|
||||
check_allowed=_check_sleep_call_allowed,
|
||||
loop_thread_id=loop_thread_id,
|
||||
)
|
||||
|
||||
glob.glob = protect_loop(
|
||||
glob.glob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
|
||||
)
|
||||
glob.iglob = protect_loop(
|
||||
glob.iglob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
|
||||
)
|
||||
os.walk = protect_loop(
|
||||
os.walk, strict_core=False, strict=False, loop_thread_id=loop_thread_id
|
||||
)
|
||||
|
||||
if not _IN_TESTS:
|
||||
# Prevent files being opened inside the event loop
|
||||
os.listdir = protect_loop( # type: ignore[assignment]
|
||||
os.listdir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
|
||||
)
|
||||
os.scandir = protect_loop( # type: ignore[assignment]
|
||||
os.scandir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
|
||||
)
|
||||
|
||||
builtins.open = protect_loop( # type: ignore[assignment]
|
||||
builtins.open,
|
||||
strict_core=False,
|
||||
strict=False,
|
||||
check_allowed=_check_file_allowed,
|
||||
loop_thread_id=loop_thread_id,
|
||||
)
|
||||
# unittest uses `importlib.import_module` to do mocking
|
||||
# so we cannot protect it if we are running tests
|
||||
importlib.import_module = protect_loop(
|
||||
importlib.import_module,
|
||||
strict_core=False,
|
||||
strict=False,
|
||||
check_allowed=_check_import_call_allowed,
|
||||
protected_function = protect_loop(
|
||||
blocking_call.original_func,
|
||||
strict=blocking_call.strict,
|
||||
strict_core=blocking_call.strict_core,
|
||||
check_allowed=blocking_call.check_allowed,
|
||||
loop_thread_id=loop_thread_id,
|
||||
)
|
||||
setattr(blocking_call.object, blocking_call.function, protected_function)
|
||||
calls.add(blocking_call)
|
||||
|
|
|
@ -35,6 +35,8 @@ import requests_mock
|
|||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import AsyncGenerator, Generator
|
||||
|
||||
from homeassistant import block_async_io
|
||||
|
||||
# Setup patching if dt_util time functions before any other Home Assistant imports
|
||||
from . import patch_time # noqa: F401, isort:skip
|
||||
|
||||
|
@ -1814,3 +1816,15 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall
|
|||
def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
|
||||
"""Return snapshot assertion fixture with the Home Assistant extension."""
|
||||
return snapshot.use_extension(HomeAssistantSnapshotExtension)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_block_async_io() -> Generator[Any, Any, None]:
|
||||
"""Fixture to disable the loop protection from block_async_io."""
|
||||
yield
|
||||
calls = block_async_io._BLOCKED_CALLS.calls
|
||||
for blocking_call in calls:
|
||||
setattr(
|
||||
blocking_call.object, blocking_call.function, blocking_call.original_func
|
||||
)
|
||||
calls.clear()
|
||||
|
|
|
@ -17,6 +17,11 @@ from homeassistant.core import HomeAssistant
|
|||
from .common import extract_stack_to_frame
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_block_async_io(disable_block_async_io):
|
||||
"""Disable the loop protection from block_async_io after each test."""
|
||||
|
||||
|
||||
async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test time.sleep injected by the debugger is not reported."""
|
||||
block_async_io.enable()
|
||||
|
@ -214,13 +219,25 @@ async def test_protect_loop_open(caplog: pytest.LogCaptureFixture) -> None:
|
|||
|
||||
async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test opening a file in the event loop logs."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
open("/config/data_not_exist", encoding="utf8").close()
|
||||
|
||||
assert "Detected blocking call to open with args" in caplog.text
|
||||
|
||||
|
||||
async def test_enable_multiple_times(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test trying to enable multiple times."""
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Blocking call detection is already enabled"
|
||||
):
|
||||
block_async_io.enable()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
|
@ -231,7 +248,8 @@ async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
|
|||
)
|
||||
async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test opening a file by path in the event loop logs."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
open(path, encoding="utf8").close()
|
||||
|
||||
|
@ -242,7 +260,8 @@ async def test_protect_loop_glob(
|
|||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test glob calls in the loop are logged."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
glob.glob("/dev/null")
|
||||
assert "Detected blocking call to glob with args" in caplog.text
|
||||
caplog.clear()
|
||||
|
@ -254,7 +273,8 @@ async def test_protect_loop_iglob(
|
|||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test iglob calls in the loop are logged."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
glob.iglob("/dev/null")
|
||||
assert "Detected blocking call to iglob with args" in caplog.text
|
||||
caplog.clear()
|
||||
|
@ -266,7 +286,8 @@ async def test_protect_loop_scandir(
|
|||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test glob calls in the loop are logged."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
os.scandir("/path/that/does/not/exists")
|
||||
assert "Detected blocking call to scandir with args" in caplog.text
|
||||
|
@ -280,7 +301,8 @@ async def test_protect_loop_listdir(
|
|||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test listdir calls in the loop are logged."""
|
||||
block_async_io.enable()
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
os.listdir("/path/that/does/not/exists")
|
||||
assert "Detected blocking call to listdir with args" in caplog.text
|
||||
|
@ -293,8 +315,9 @@ async def test_protect_loop_listdir(
|
|||
async def test_protect_loop_walk(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test glob calls in the loop are logged."""
|
||||
block_async_io.enable()
|
||||
"""Test os.walk calls in the loop are logged."""
|
||||
with patch.object(block_async_io, "_IN_TESTS", False):
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
os.walk("/path/that/does/not/exists")
|
||||
assert "Detected blocking call to walk with args" in caplog.text
|
||||
|
@ -302,3 +325,13 @@ async def test_protect_loop_walk(
|
|||
with contextlib.suppress(FileNotFoundError):
|
||||
await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists")
|
||||
assert "Detected blocking call to walk with args" not in caplog.text
|
||||
|
||||
|
||||
async def test_open_calls_ignored_in_tests(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test opening a file in tests is ignored."""
|
||||
assert block_async_io._IN_TESTS
|
||||
block_async_io.enable()
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
open("/config/data_not_exist", encoding="utf8").close()
|
||||
|
||||
assert "Detected blocking call to open with args" not in caplog.text
|
||||
|
|
|
@ -55,6 +55,11 @@ async def apply_stop_hass(stop_hass: None) -> None:
|
|||
"""Make sure all hass are stopped."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_block_async_io(disable_block_async_io):
|
||||
"""Disable the loop protection from block_async_io after each test."""
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def mock_http_start_stop() -> Generator[None]:
|
||||
"""Mock HTTP start and stop."""
|
||||
|
|
Loading…
Add table
Reference in a new issue