Refactor asyncio loop protection to improve performance (#117295)
This commit is contained in:
parent
aae39759d9
commit
d06932bbc2
8 changed files with 91 additions and 59 deletions
|
@ -4,6 +4,7 @@ from contextlib import suppress
|
|||
from http.client import HTTPConnection
|
||||
import importlib
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
|
@ -25,7 +26,7 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
|
|||
# I/O and we are trying to avoid blocking calls.
|
||||
#
|
||||
# frame[0] is us
|
||||
# frame[1] is check_loop
|
||||
# frame[1] is raise_for_blocking_call
|
||||
# frame[2] is protected_loop_func
|
||||
# frame[3] is the offender
|
||||
with suppress(ValueError):
|
||||
|
@ -35,14 +36,18 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
|
|||
|
||||
def enable() -> None:
|
||||
"""Enable the detection of blocking calls in the event loop."""
|
||||
loop_thread_id = threading.get_ident()
|
||||
# Prevent urllib3 and requests doing I/O in event loop
|
||||
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign]
|
||||
HTTPConnection.putrequest
|
||||
HTTPConnection.putrequest, loop_thread_id=loop_thread_id
|
||||
)
|
||||
|
||||
# Prevent sleeping in event loop. Non-strict since 2022.02
|
||||
time.sleep = protect_loop(
|
||||
time.sleep, strict=False, check_allowed=_check_sleep_call_allowed
|
||||
time.sleep,
|
||||
strict=False,
|
||||
check_allowed=_check_sleep_call_allowed,
|
||||
loop_thread_id=loop_thread_id,
|
||||
)
|
||||
|
||||
# Currently disabled. pytz doing I/O when getting timezone.
|
||||
|
@ -57,4 +62,5 @@ def enable() -> None:
|
|||
strict_core=False,
|
||||
strict=False,
|
||||
check_allowed=_check_import_call_allowed,
|
||||
loop_thread_id=loop_thread_id,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""A pool for sqlite connections."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import traceback
|
||||
|
@ -14,7 +15,7 @@ from sqlalchemy.pool import (
|
|||
)
|
||||
|
||||
from homeassistant.helpers.frame import report
|
||||
from homeassistant.util.loop import check_loop
|
||||
from homeassistant.util.loop import raise_for_blocking_call
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -86,15 +87,22 @@ class RecorderPool(SingletonThreadPool, NullPool):
|
|||
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||
super().dispose()
|
||||
|
||||
def _do_get(self) -> ConnectionPoolEntry:
|
||||
def _do_get(self) -> ConnectionPoolEntry: # type: ignore[return]
|
||||
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||
return super()._do_get()
|
||||
check_loop(
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# Not in an event loop but not in the recorder or worker thread
|
||||
# which is allowed but discouraged since it's much slower
|
||||
return self._do_get_db_connection_protected()
|
||||
# In the event loop, raise an exception
|
||||
raise_for_blocking_call(
|
||||
self._do_get_db_connection_protected,
|
||||
strict=True,
|
||||
advise_msg=ADVISE_MSG,
|
||||
)
|
||||
return self._do_get_db_connection_protected()
|
||||
# raise_for_blocking_call will raise an exception
|
||||
|
||||
def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
|
||||
report(
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import get_running_loop
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import functools
|
||||
import linecache
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from homeassistant.core import HomeAssistant, async_get_hass
|
||||
|
@ -31,7 +31,7 @@ def _get_line_from_cache(filename: str, lineno: int) -> str:
|
|||
return (linecache.getline(filename, lineno) or "?").strip()
|
||||
|
||||
|
||||
def check_loop(
|
||||
def raise_for_blocking_call(
|
||||
func: Callable[..., Any],
|
||||
check_allowed: Callable[[dict[str, Any]], bool] | None = None,
|
||||
strict: bool = True,
|
||||
|
@ -44,15 +44,6 @@ def check_loop(
|
|||
The default advisory message is 'Use `await hass.async_add_executor_job()`'
|
||||
Set `advise_msg` to an alternate message if the solution differs.
|
||||
"""
|
||||
try:
|
||||
get_running_loop()
|
||||
in_loop = True
|
||||
except RuntimeError:
|
||||
in_loop = False
|
||||
|
||||
if not in_loop:
|
||||
return
|
||||
|
||||
if check_allowed is not None and check_allowed(mapped_args):
|
||||
return
|
||||
|
||||
|
@ -125,6 +116,7 @@ def check_loop(
|
|||
|
||||
def protect_loop(
|
||||
func: Callable[_P, _R],
|
||||
loop_thread_id: int,
|
||||
strict: bool = True,
|
||||
strict_core: bool = True,
|
||||
check_allowed: Callable[[dict[str, Any]], bool] | None = None,
|
||||
|
@ -133,14 +125,15 @@ def protect_loop(
|
|||
|
||||
@functools.wraps(func)
|
||||
def protected_loop_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
check_loop(
|
||||
func,
|
||||
strict=strict,
|
||||
strict_core=strict_core,
|
||||
check_allowed=check_allowed,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
if threading.get_ident() == loop_thread_id:
|
||||
raise_for_blocking_call(
|
||||
func,
|
||||
strict=strict,
|
||||
strict_core=strict_core,
|
||||
check_allowed=check_allowed,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return protected_loop_func
|
||||
|
|
|
@ -159,14 +159,18 @@ async def test_shutdown_before_startup_finishes(
|
|||
await recorder_helper.async_wait_recorder(hass)
|
||||
instance = get_instance(hass)
|
||||
|
||||
session = await hass.async_add_executor_job(instance.get_session)
|
||||
session = await instance.async_add_executor_job(instance.get_session)
|
||||
|
||||
with patch.object(instance, "engine"):
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_FINAL_WRITE)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_stop()
|
||||
|
||||
run_info = await hass.async_add_executor_job(run_information_with_session, session)
|
||||
def _run_information_with_session():
|
||||
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
return run_information_with_session(session)
|
||||
|
||||
run_info = await instance.async_add_executor_job(_run_information_with_session)
|
||||
|
||||
assert run_info.run_id == 1
|
||||
assert run_info.start is not None
|
||||
|
@ -1693,7 +1697,8 @@ async def test_database_corruption_while_running(
|
|||
await hass.async_block_till_done()
|
||||
caplog.clear()
|
||||
|
||||
original_start_time = get_instance(hass).recorder_runs_manager.recording_start
|
||||
instance = get_instance(hass)
|
||||
original_start_time = instance.recorder_runs_manager.recording_start
|
||||
|
||||
hass.states.async_set("test.lost", "on", {})
|
||||
|
||||
|
@ -1737,11 +1742,11 @@ async def test_database_corruption_while_running(
|
|||
assert db_states[0].event_id is None
|
||||
return db_states[0].to_native()
|
||||
|
||||
state = await hass.async_add_executor_job(_get_last_state)
|
||||
state = await instance.async_add_executor_job(_get_last_state)
|
||||
assert state.entity_id == "test.two"
|
||||
assert state.state == "on"
|
||||
|
||||
new_start_time = get_instance(hass).recorder_runs_manager.recording_start
|
||||
new_start_time = instance.recorder_runs_manager.recording_start
|
||||
assert original_start_time < new_start_time
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
|
@ -1850,7 +1855,7 @@ async def test_database_lock_and_unlock(
|
|||
assert instance.unlock_database()
|
||||
|
||||
await task
|
||||
db_events = await hass.async_add_executor_job(_get_db_events)
|
||||
db_events = await instance.async_add_executor_job(_get_db_events)
|
||||
assert len(db_events) == 1
|
||||
|
||||
|
||||
|
|
|
@ -9,12 +9,13 @@ import importlib
|
|||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components.recorder import SQLITE_URL_PREFIX
|
||||
from homeassistant.components.recorder import SQLITE_URL_PREFIX, get_instance
|
||||
from homeassistant.components.recorder.util import session_scope
|
||||
from homeassistant.helpers import recorder as recorder_helper
|
||||
from homeassistant.setup import setup_component
|
||||
|
@ -176,6 +177,7 @@ def test_delete_duplicates(caplog: pytest.LogCaptureFixture, tmp_path: Path) ->
|
|||
):
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
||||
|
@ -358,6 +360,7 @@ def test_delete_duplicates_many(
|
|||
):
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
||||
|
@ -517,6 +520,7 @@ def test_delete_duplicates_non_identical(
|
|||
):
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
||||
|
@ -631,6 +635,7 @@ def test_delete_duplicates_short_term(
|
|||
):
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta
|
|||
import os
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
import threading
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -843,9 +844,7 @@ async def test_periodic_db_cleanups(
|
|||
assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
|
||||
|
||||
|
||||
@patch("homeassistant.components.recorder.pool.check_loop")
|
||||
async def test_write_lock_db(
|
||||
skip_check_loop,
|
||||
async_setup_recorder_instance: RecorderInstanceGenerator,
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
|
@ -864,6 +863,7 @@ async def test_write_lock_db(
|
|||
with instance.engine.connect() as connection:
|
||||
connection.execute(text("DROP TABLE events;"))
|
||||
|
||||
instance.recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
with util.write_lock_db_sqlite(instance), pytest.raises(OperationalError):
|
||||
# Database should be locked now, try writing SQL command
|
||||
# This needs to be called in another thread since
|
||||
|
@ -872,7 +872,7 @@ async def test_write_lock_db(
|
|||
# in the same thread as the one holding the lock since it
|
||||
# would be allowed to proceed as the goal is to prevent
|
||||
# all the other threads from accessing the database
|
||||
await hass.async_add_executor_job(_drop_table)
|
||||
await instance.async_add_executor_job(_drop_table)
|
||||
|
||||
|
||||
def test_is_second_sunday() -> None:
|
||||
|
|
|
@ -2,11 +2,13 @@
|
|||
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.recorder import get_instance
|
||||
from homeassistant.components.recorder.history import get_significant_states
|
||||
from homeassistant.components.recorder.statistics import (
|
||||
get_latest_short_term_statistics_with_session,
|
||||
|
@ -57,6 +59,7 @@ def test_compile_missing_statistics(
|
|||
recorder_helper.async_initialize_recorder(hass)
|
||||
setup_component(hass, "sensor", {})
|
||||
setup_component(hass, "recorder", {"recorder": config})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
hass.start()
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
@ -98,6 +101,7 @@ def test_compile_missing_statistics(
|
|||
setup_component(hass, "sensor", {})
|
||||
hass.states.set("sensor.test1", "0", POWER_SENSOR_ATTRIBUTES)
|
||||
setup_component(hass, "recorder", {"recorder": config})
|
||||
get_instance(hass).recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
hass.start()
|
||||
wait_recording_done(hass)
|
||||
wait_recording_done(hass)
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
"""Tests for async util methods from Python source."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util import loop as haloop
|
||||
|
||||
from tests.common import extract_stack_to_frame
|
||||
|
@ -13,22 +15,24 @@ def banned_function():
|
|||
"""Mock banned function."""
|
||||
|
||||
|
||||
async def test_check_loop_async() -> None:
|
||||
"""Test check_loop detects when called from event loop without integration context."""
|
||||
async def test_raise_for_blocking_call_async() -> None:
|
||||
"""Test raise_for_blocking_call detects when called from event loop without integration context."""
|
||||
with pytest.raises(RuntimeError):
|
||||
haloop.check_loop(banned_function)
|
||||
haloop.raise_for_blocking_call(banned_function)
|
||||
|
||||
|
||||
async def test_check_loop_async_non_strict_core(
|
||||
async def test_raise_for_blocking_call_async_non_strict_core(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test non_strict_core check_loop detects from event loop without integration context."""
|
||||
haloop.check_loop(banned_function, strict_core=False)
|
||||
"""Test non_strict_core raise_for_blocking_call detects from event loop without integration context."""
|
||||
haloop.raise_for_blocking_call(banned_function, strict_core=False)
|
||||
assert "Detected blocking call to banned_function" in caplog.text
|
||||
|
||||
|
||||
async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test check_loop detects and raises when called from event loop from integration context."""
|
||||
async def test_raise_for_blocking_call_async_integration(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test raise_for_blocking_call detects and raises when called from event loop from integration context."""
|
||||
frames = extract_stack_to_frame(
|
||||
[
|
||||
Mock(
|
||||
|
@ -67,7 +71,7 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
|
|||
return_value=frames,
|
||||
),
|
||||
):
|
||||
haloop.check_loop(banned_function)
|
||||
haloop.raise_for_blocking_call(banned_function)
|
||||
assert (
|
||||
"Detected blocking call to banned_function inside the event loop by integration"
|
||||
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
|
||||
|
@ -77,10 +81,10 @@ async def test_check_loop_async_integration(caplog: pytest.LogCaptureFixture) ->
|
|||
)
|
||||
|
||||
|
||||
async def test_check_loop_async_integration_non_strict(
|
||||
async def test_raise_for_blocking_call_async_integration_non_strict(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test check_loop detects when called from event loop from integration context."""
|
||||
"""Test raise_for_blocking_call detects when called from event loop from integration context."""
|
||||
frames = extract_stack_to_frame(
|
||||
[
|
||||
Mock(
|
||||
|
@ -118,7 +122,7 @@ async def test_check_loop_async_integration_non_strict(
|
|||
return_value=frames,
|
||||
),
|
||||
):
|
||||
haloop.check_loop(banned_function, strict=False)
|
||||
haloop.raise_for_blocking_call(banned_function, strict=False)
|
||||
assert (
|
||||
"Detected blocking call to banned_function inside the event loop by integration"
|
||||
" 'hue' at homeassistant/components/hue/light.py, line 23: self.light.is_on "
|
||||
|
@ -128,8 +132,10 @@ async def test_check_loop_async_integration_non_strict(
|
|||
)
|
||||
|
||||
|
||||
async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test check_loop detects when called from event loop with custom component context."""
|
||||
async def test_raise_for_blocking_call_async_custom(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test raise_for_blocking_call detects when called from event loop with custom component context."""
|
||||
frames = extract_stack_to_frame(
|
||||
[
|
||||
Mock(
|
||||
|
@ -168,7 +174,7 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
|
|||
return_value=frames,
|
||||
),
|
||||
):
|
||||
haloop.check_loop(banned_function)
|
||||
haloop.raise_for_blocking_call(banned_function)
|
||||
assert (
|
||||
"Detected blocking call to banned_function inside the event loop by custom "
|
||||
"integration 'hue' at custom_components/hue/light.py, line 23: self.light.is_on"
|
||||
|
@ -178,18 +184,23 @@ async def test_check_loop_async_custom(caplog: pytest.LogCaptureFixture) -> None
|
|||
) in caplog.text
|
||||
|
||||
|
||||
def test_check_loop_sync(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test check_loop does nothing when called from thread."""
|
||||
haloop.check_loop(banned_function)
|
||||
async def test_raise_for_blocking_call_sync(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test raise_for_blocking_call does nothing when called from thread."""
|
||||
func = haloop.protect_loop(banned_function, threading.get_ident())
|
||||
await hass.async_add_executor_job(func)
|
||||
assert "Detected blocking call inside the event loop" not in caplog.text
|
||||
|
||||
|
||||
def test_protect_loop_sync() -> None:
|
||||
"""Test protect_loop calls check_loop."""
|
||||
async def test_protect_loop_async() -> None:
|
||||
"""Test protect_loop calls raise_for_blocking_call."""
|
||||
func = Mock()
|
||||
with patch("homeassistant.util.loop.check_loop") as mock_check_loop:
|
||||
haloop.protect_loop(func)(1, test=2)
|
||||
mock_check_loop.assert_called_once_with(
|
||||
with patch(
|
||||
"homeassistant.util.loop.raise_for_blocking_call"
|
||||
) as mock_raise_for_blocking_call:
|
||||
haloop.protect_loop(func, threading.get_ident())(1, test=2)
|
||||
mock_raise_for_blocking_call.assert_called_once_with(
|
||||
func,
|
||||
strict=True,
|
||||
args=(1,),
|
||||
|
|
Loading…
Add table
Reference in a new issue