Refactor asyncio loop protection to improve performance (#117295)

This commit is contained in:
J. Nick Koston 2024-05-13 07:01:55 +09:00 committed by GitHub
parent aae39759d9
commit d06932bbc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 91 additions and 59 deletions

View file

@ -4,6 +4,7 @@ from contextlib import suppress
from http.client import HTTPConnection
import importlib
import sys
import threading
import time
from typing import Any
@ -25,7 +26,7 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
# I/O and we are trying to avoid blocking calls.
#
# frame[0] is us
# frame[1] is check_loop
# frame[1] is raise_for_blocking_call
# frame[2] is protected_loop_func
# frame[3] is the offender
with suppress(ValueError):
@ -35,14 +36,18 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
def enable() -> None:
"""Enable the detection of blocking calls in the event loop."""
loop_thread_id = threading.get_ident()
# Prevent urllib3 and requests doing I/O in event loop
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign]
HTTPConnection.putrequest
HTTPConnection.putrequest, loop_thread_id=loop_thread_id
)
# Prevent sleeping in event loop. Non-strict since 2022.02
time.sleep = protect_loop(
time.sleep, strict=False, check_allowed=_check_sleep_call_allowed
time.sleep,
strict=False,
check_allowed=_check_sleep_call_allowed,
loop_thread_id=loop_thread_id,
)
# Currently disabled. pytz doing I/O when getting timezone.
@ -57,4 +62,5 @@ def enable() -> None:
strict_core=False,
strict=False,
check_allowed=_check_import_call_allowed,
loop_thread_id=loop_thread_id,
)

View file

@ -1,5 +1,6 @@
"""A pool for sqlite connections."""
import asyncio
import logging
import threading
import traceback
@ -14,7 +15,7 @@ from sqlalchemy.pool import (
)
from homeassistant.helpers.frame import report
from homeassistant.util.loop import check_loop
from homeassistant.util.loop import raise_for_blocking_call
_LOGGER = logging.getLogger(__name__)
@ -86,15 +87,22 @@ class RecorderPool(SingletonThreadPool, NullPool):
if threading.get_ident() in self.recorder_and_worker_thread_ids:
super().dispose()
def _do_get(self) -> ConnectionPoolEntry:
def _do_get(self) -> ConnectionPoolEntry: # type: ignore[return]
if threading.get_ident() in self.recorder_and_worker_thread_ids:
return super()._do_get()
check_loop(
try:
asyncio.get_running_loop()
except RuntimeError:
# Not in an event loop but not in the recorder or worker thread
# which is allowed but discouraged since it's much slower
return self._do_get_db_connection_protected()
# In the event loop, raise an exception
raise_for_blocking_call(
self._do_get_db_connection_protected,
strict=True,
advise_msg=ADVISE_MSG,
)
return self._do_get_db_connection_protected()
# raise_for_blocking_call will raise an exception
def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
report(

View file

@ -2,12 +2,12 @@
from __future__ import annotations
from asyncio import get_running_loop
from collections.abc import Callable
from contextlib import suppress
import functools
import linecache
import logging
import threading
from typing import Any, ParamSpec, TypeVar
from homeassistant.core import HomeAssistant, async_get_hass
@ -31,7 +31,7 @@ def _get_line_from_cache(filename: str, lineno: int) -> str:
return (linecache.getline(filename, lineno) or "?").strip()
def check_loop(
def raise_for_blocking_call(
func: Callable[..., Any],
check_allowed: Callable[[dict[str, Any]], bool] | None = None,
strict: bool = True,
@ -44,15 +44,6 @@ def check_loop(
The default advisory message is 'Use `await hass.async_add_executor_job()`'
Set `advise_msg` to an alternate message if the solution differs.
"""
try:
get_running_loop()
in_loop = True
except RuntimeError:
in_loop = False
if not in_loop:
return
if check_allowed is not None and check_allowed(mapped_args):
return
@ -125,6 +116,7 @@ def check_loop(
def protect_loop(
func: Callable[_P, _R],
loop_thread_id: int,
strict: bool = True,
strict_core: bool = True,
check_allowed: Callable[[dict[str, Any]], bool] | None = None,
@ -133,14 +125,15 @@ def protect_loop(
@functools.wraps(func)
def protected_loop_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
check_loop(
func,
strict=strict,
strict_core=strict_core,
check_allowed=check_allowed,
args=args,
kwargs=kwargs,
)
if threading.get_ident() == loop_thread_id:
raise_for_blocking_call(
func,
strict=strict,
strict_core=strict_core,
check_allowed=check_allowed,
args=args,
kwargs=kwargs,
)
return func(*args, **kwargs)
return protected_loop_func

View file

@ -159,14 +159,18 @@ async def test_shutdown_before_startup_finishes(
await recorder_helper.async_wait_recorder(hass)
instance = get_instance(hass)
session = await hass.async_add_executor_job(instance.get_session)
session = await instance.async_add_executor_job(instance.get_session)
with patch.object(instance, "engine"):
hass.bus.async_fire(EVENT_HOMEASSISTANT_FINAL_WRITE)
await hass.async_block_till_done()
await hass.async_stop()
run_info = await hass.async_add_executor_job(run_information_with_session, session)
def _run_information_with_session():
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
return run_information_with_session(session)
run_info = await instance.async_add_executor_job(_run_information_with_session)
assert run_info.run_id == 1
assert run_info.start is not None
@ -1693,7 +1697,8 @@ async def test_database_corruption_while_running(
await hass.async_block_till_done()
caplog.clear()
original_start_time = get_instance(hass).recorder_runs_manager.recording_start
instance = get_instance(hass)
original_start_time = instance.recorder_runs_manager.recording_start
hass.states.async_set("test.lost", "on", {})
@ -1737,11 +1742,11 @@ async def test_database_corruption_while_running(
assert db_states[0].event_id is None
return db_states[0].to_native()
state = await hass.async_add_executor_job(_get_last_state)
state = await instance.async_add_executor_job(_get_last_state)
assert state.entity_id == "test.two"
assert state.state == "on"
new_start_time = get_instance(hass).recorder_runs_manager.recording_start
new_start_time = instance.recorder_runs_manager.recording_start
assert original_start_time < new_start_time
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
@ -1850,7 +1855,7 @@ async def test_database_lock_and_unlock(
assert instance.unlock_database()
await task
db_events = await hass.async_add_executor_job(_get_db_events)
db_events = await instance.async_add_executor_job(_get_db_events)
assert len(db_events) == 1

View file

@ -9,12 +9,13 @@ import importlib
import json
from pathlib import Path
import sys
import threading
from unittest.mock import patch
import pytest
from homeassistant.components import recorder
from homeassistant.components.recorder import SQLITE_URL_PREFIX
from homeassistant.components.recorder import SQLITE_URL_PREFIX, get_instance
from homeassistant.components.recorder.util import session_scope
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import setup_component
@ -176,6 +177,7 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) ->
):
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass)
wait_recording_done(hass)
@ -358,6 +360,7 @@ def test_delete_duplicates_many(
):
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass)
wait_recording_done(hass)
@ -517,6 +520,7 @@ def test_delete_duplicates_non_identical(
):
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass)
wait_recording_done(hass)
@ -631,6 +635,7 @@ def test_delete_duplicates_short_term(
):
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
wait_recording_done(hass)
wait_recording_done(hass)

View file

@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta
import os
from pathlib import Path
import sqlite3
import threading
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -843,9 +844,7 @@ async def test_periodic_db_cleanups(
assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
@patch("homeassistant.components.recorder.pool.check_loop")
async def test_write_lock_db(
skip_check_loop,
async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant,
tmp_path: Path,
@ -864,6 +863,7 @@ async def test_write_lock_db(
with instance.engine.connect() as connection:
connection.execute(text("DROP TABLE events;"))
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
with util.write_lock_db_sqlite(instance), pytest.raises(OperationalError):
# Database should be locked now, try writing SQL command
# This needs to be called in another thread since
@ -872,7 +872,7 @@ async def test_write_lock_db(
# in the same thread as the one holding the lock since it
# would be allowed to proceed as the goal is to prevent
# all the other threads from accessing the database
await hass.async_add_executor_job(_drop_table)
await instance.async_add_executor_job(_drop_table)
def test_is_second_sunday() -> None:

View file

@ -2,11 +2,13 @@
from datetime import datetime, timedelta
from pathlib import Path
import threading
from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.history import get_significant_states
from homeassistant.components.recorder.statistics import (
get_latest_short_term_statistics_with_session,
@ -57,6 +59,7 @@ def test_compile_missing_statistics(
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "sensor", {})
setup_component(hass, "recorder", {"recorder": config})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)
@ -98,6 +101,7 @@ def test_compile_missing_statistics(
setup_component(hass, "sensor", {})
hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES)
setup_component(hass, "recorder", {"recorder": config})
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)

View file

@ -1,9 +1,11 @@
"""Tests for async util methods from Python source."""
import threading
from unittest.mock import Mock, patch
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.util import loop as haloop
from tests.common import extract_stack_to_frame
@ -13,22 +15,24 @@ def banned_function():
"""Mock banned function."""
async def test_check_loop_async() -> None:
"""Test check_loop detects when called from event loop without integration context."""
async def test_raise_for_blocking_call_async() -> None:
"""Test raise_for_blocking_call detects when called from event loop without integration context."""
with pytest.raises(RuntimeError):
haloop.check_loop(banned_function)
haloop.raise_for_blocking_call(banned_function)
async def test_check_loop_async_non_strict_core(
async def test_raise_for_blocking_call_async_non_strict_core(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test non_strict_core check_loop detects from event loop without integration context."""
haloop.check_loop(banned_function, strict_core=False)
"""Test non_strict_core raise_for_blocking_call detects from event loop without integration context."""
haloop.raise_for_blocking_call(banned_function, strict_core=False)
assert "Detected blocking call to banned_function" in caplog.text
async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> None:
"""Test check_loop detects and raises when called from event loop from integration context."""
async def test_raise_for_blocking_call_async_integration(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test raise_for_blocking_call detects and raises when called from event loop from integration context."""
frames = extract_stack_to_frame(
[
Mock(
@ -67,7 +71,7 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
return_value=frames,
),
):
haloop.check_loop(banned_function)
haloop.raise_for_blocking_call(banned_function)
assert (
"Detected blocking call to banned_function inside the event loop by integration"
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
@ -77,10 +81,10 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
)
async def test_check_loop_async_integration_non_strict(
async def test_raise_for_blocking_call_async_integration_non_strict(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test check_loop detects when called from event loop from integration context."""
"""Test raise_for_blocking_call detects when called from event loop from integration context."""
frames = extract_stack_to_frame(
[
Mock(
@ -118,7 +122,7 @@ async def test_check_loop_async_integration_non_strict(
return_value=frames,
),
):
haloop.check_loop(banned_function, strict=False)
haloop.raise_for_blocking_call(banned_function, strict=False)
assert (
"Detected blocking call to banned_function inside the event loop by integration"
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
@ -128,8 +132,10 @@ async def test_check_loop_async_integration_non_strict(
)
async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None:
"""Test check_loop detects when called from event loop with custom component context."""
async def test_raise_for_blocking_call_async_custom(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test raise_for_blocking_call detects when called from event loop with custom component context."""
frames = extract_stack_to_frame(
[
Mock(
@ -168,7 +174,7 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
return_value=frames,
),
):
haloop.check_loop(banned_function)
haloop.raise_for_blocking_call(banned_function)
assert (
"Detected blocking call to banned_function inside the event loop by custom "
"integration 'hue' at custom_components/hue/light.py, line 23: self.light.is_on"
@ -178,18 +184,23 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
) in caplog.text
def test_check_loop_sync(caplog: pytest.LogCaptureFixture) -> None:
"""Test check_loop does nothing when called from thread."""
haloop.check_loop(banned_function)
async def test_raise_for_blocking_call_sync(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test raise_for_blocking_call does nothing when called from thread."""
func = haloop.protect_loop(banned_function, threading.get_ident())
await hass.async_add_executor_job(func)
assert "Detected blocking call inside the event loop" not in caplog.text
def test_protect_loop_sync() -> None:
"""Test protect_loop calls check_loop."""
async def test_protect_loop_async() -> None:
"""Test protect_loop calls raise_for_blocking_call."""
func = Mock()
with patch("homeassistant.util.loop.check_loop") as mock_check_loop:
haloop.protect_loop(func)(1, test=2)
mock_check_loop.assert_called_once_with(
with patch(
"homeassistant.util.loop.raise_for_blocking_call"
) as mock_raise_for_blocking_call:
haloop.protect_loop(func, threading.get_ident())(1, test=2)
mock_raise_for_blocking_call.assert_called_once_with(
func,
strict=True,
args=(1,),