Make database access in the eventloop raise an exception (#71547)
parent 2560d35f1c
commit 222baa53dd
7 changed files with 66 additions and 32 deletions
@@ -8,6 +8,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.pool import NullPool, SingletonThreadPool, StaticPool

 from homeassistant.helpers.frame import report
+from homeassistant.util.async_ import check_loop

 from .const import DB_WORKER_PREFIX
@@ -19,6 +20,10 @@ DEBUG_MUTEX_POOL_TRACE = False

 POOL_SIZE = 5

+ADVISE_MSG = (
+    "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job()"
+)
+

 class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]
     """A hybrid of NullPool and SingletonThreadPool.
@@ -62,9 +67,17 @@ class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]
     def _do_get(self) -> Any:
         if self.recorder_or_dbworker:
             return super()._do_get()
+        check_loop(
+            self._do_get_db_connection_protected,
+            strict=True,
+            advise_msg=ADVISE_MSG,
+        )
+        return self._do_get_db_connection_protected()
+
+    def _do_get_db_connection_protected(self) -> Any:
         report(
             "accesses the database without the database executor; "
-            "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
+            f"{ADVISE_MSG} "
             "for faster database operations",
             exclude_integrations={"recorder"},
             error_if_core=False,
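The pool hunks above make `RecorderPool` refuse to hand out a connection on the event loop thread: `_do_get` now calls `check_loop` with `strict=True` before falling back to the NullPool-style path in `_do_get_db_connection_protected`. The following is not the Home Assistant implementation, only a minimal, standard-library sketch of the same guard pattern; `LoopGuardedPool` is a made-up name.

```python
import asyncio
import sqlite3


class LoopGuardedPool:
    """Hand out SQLite connections only when no event loop runs in this thread."""

    def _in_event_loop(self) -> bool:
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    def connect(self) -> sqlite3.Connection:
        if self._in_event_loop():
            # Same spirit as check_loop(..., strict=True): fail loudly instead
            # of silently blocking the loop with database I/O.
            raise RuntimeError(
                "Detected blocking database access inside the event loop; "
                "run it in an executor job instead"
            )
        return sqlite3.connect(":memory:")
```

Calling `LoopGuardedPool().connect()` inside `asyncio.run(...)` raises, while calling it from a plain worker thread (or via `loop.run_in_executor`) succeeds, which mirrors the recorder/db-worker exemption in `_do_get`.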
@@ -94,8 +94,14 @@ def run_callback_threadsafe(
     return future


-def check_loop(func: Callable[..., Any], strict: bool = True) -> None:
-    """Warn if called inside the event loop. Raise if `strict` is True."""
+def check_loop(
+    func: Callable[..., Any], strict: bool = True, advise_msg: str | None = None
+) -> None:
+    """Warn if called inside the event loop. Raise if `strict` is True.
+
+    The default advisory message is 'Use `await hass.async_add_executor_job()`'.
+    Set `advise_msg` to an alternate message if the solution differs.
+    """
     try:
         get_running_loop()
         in_loop = True
@@ -134,6 +140,7 @@ def check_loop(func: Callable[..., Any], strict: bool = True) -> None:
     if found_frame is None:
         raise RuntimeError(
             f"Detected blocking call to {func.__name__} inside the event loop. "
+            f"{advise_msg or 'Use `await hass.async_add_executor_job()`'}; "
             "This is causing stability issues. Please report issue"
         )
@@ -160,7 +167,7 @@ def check_loop(func: Callable[..., Any], strict: bool = True) -> None:
     if strict:
         raise RuntimeError(
             "Blocking calls must be done in the executor or a separate thread; "
-            "Use `await hass.async_add_executor_job()` "
+            f"{advise_msg or 'Use `await hass.async_add_executor_job()`'}; "
             f"at {found_frame.filename[index:]}, line {found_frame.lineno}: {(found_frame.line or '?').strip()}"
         )
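The three hunks above extend `check_loop` in `homeassistant.util.async_` with an optional `advise_msg`, so a caller such as `RecorderPool` can point users at the recorder executor instead of the generic `hass.async_add_executor_job()` advice. A hedged sketch of how a blocking helper might use the new argument follows; `fetch_rows` is a made-up function, not part of this change.

```python
from homeassistant.util.async_ import check_loop


def fetch_rows() -> list:
    """Blocking helper that must never run on the event loop thread."""
    check_loop(
        fetch_rows,
        strict=True,
        advise_msg="Use await hass.async_add_executor_job(fetch_rows)",
    )
    return []  # real blocking work would go here
```

Roughly: invoked from inside the running loop, `check_loop` raises a RuntimeError carrying the custom advisory; invoked from a worker thread it returns immediately and the blocking work proceeds.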
@@ -905,10 +905,9 @@ def init_recorder_component(hass, add_config=None):
     if recorder.CONF_COMMIT_INTERVAL not in config:
         config[recorder.CONF_COMMIT_INTERVAL] = 0

-    with patch(
-        "homeassistant.components.recorder.ALLOW_IN_MEMORY_DB",
-        True,
-    ), patch("homeassistant.components.recorder.migration.migrate_schema"):
+    with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
+        "homeassistant.components.recorder.migration.migrate_schema"
+    ):
         assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config})
         assert recorder.DOMAIN in hass.config.components
     _LOGGER.info(
@@ -1319,7 +1319,9 @@ def test_entity_id_filter(hass_recorder):


 async def test_database_lock_and_unlock(
-    hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT, tmp_path
+    hass: HomeAssistant,
+    async_setup_recorder_instance: SetupRecorderInstanceT,
+    tmp_path,
 ):
     """Test writing events during lock getting written after unlocking."""
     # Use file DB, in memory DB cannot do write locks.
@@ -1330,6 +1332,10 @@ async def test_database_lock_and_unlock(
     await async_setup_recorder_instance(hass, config)
     await hass.async_block_till_done()

+    def _get_db_events():
+        with session_scope(hass=hass) as session:
+            return list(session.query(Events).filter_by(event_type=event_type))
+
     instance: Recorder = hass.data[DATA_INSTANCE]

     assert await instance.lock_database()
@@ -1344,21 +1350,20 @@ async def test_database_lock_and_unlock(
     # Recording can't be finished while lock is held
     with pytest.raises(asyncio.TimeoutError):
         await asyncio.wait_for(asyncio.shield(task), timeout=1)
-        with session_scope(hass=hass) as session:
-            db_events = list(session.query(Events).filter_by(event_type=event_type))
-            assert len(db_events) == 0
+        db_events = await hass.async_add_executor_job(_get_db_events)
+        assert len(db_events) == 0

     assert instance.unlock_database()

     await task
-    with session_scope(hass=hass) as session:
-        db_events = list(session.query(Events).filter_by(event_type=event_type))
-        assert len(db_events) == 1
+    db_events = await hass.async_add_executor_job(_get_db_events)
+    assert len(db_events) == 1


 async def test_database_lock_and_overflow(
-    hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT, tmp_path
+    hass: HomeAssistant,
+    async_setup_recorder_instance: SetupRecorderInstanceT,
+    tmp_path,
 ):
     """Test writing events during lock leading to overflow the queue causes the database to unlock."""
     # Use file DB, in memory DB cannot do write locks.
@@ -1369,6 +1374,10 @@ async def test_database_lock_and_overflow(
     await async_setup_recorder_instance(hass, config)
     await hass.async_block_till_done()

+    def _get_db_events():
+        with session_scope(hass=hass) as session:
+            return list(session.query(Events).filter_by(event_type=event_type))
+
     instance: Recorder = hass.data[DATA_INSTANCE]

     with patch.object(recorder.core, "MAX_QUEUE_BACKLOG", 1), patch.object(
@@ -1384,9 +1393,8 @@ async def test_database_lock_and_overflow(
     # even before unlocking.
     await async_wait_recording_done(hass)

-    with session_scope(hass=hass) as session:
-        db_events = list(session.query(Events).filter_by(event_type=event_type))
-        assert len(db_events) == 1
+    db_events = await hass.async_add_executor_job(_get_db_events)
+    assert len(db_events) == 1

     assert not instance.unlock_database()
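The recorder test hunks above replace direct `session_scope` queries in async tests with a small `_get_db_events` helper that runs via `hass.async_add_executor_job`, so the tests no longer query the database from the event loop. Below is a self-contained sketch of the same pattern using only the standard library; the file name, table, and helper names are placeholders, not part of the diff.

```python
import asyncio
import sqlite3


def _setup_db(path: str) -> None:
    """Blocking setup: create the table and insert one event row."""
    with sqlite3.connect(path) as conn:
        conn.execute("CREATE TABLE IF NOT EXISTS events (event_type TEXT)")
        conn.execute("INSERT INTO events VALUES ('test_event')")


def _get_db_events(path: str, event_type: str) -> list:
    """Blocking query, equivalent in spirit to the tests' _get_db_events."""
    with sqlite3.connect(path) as conn:
        cursor = conn.execute(
            "SELECT event_type FROM events WHERE event_type = ?", (event_type,)
        )
        return cursor.fetchall()


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Mirrors: db_events = await hass.async_add_executor_job(_get_db_events)
    await loop.run_in_executor(None, _setup_db, "example.db")
    db_events = await loop.run_in_executor(
        None, _get_db_events, "example.db", "test_event"
    )
    assert len(db_events) == 1


asyncio.run(main())
```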
@@ -1,6 +1,7 @@
 """Test pool."""
 import threading

+import pytest
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
@@ -8,6 +9,13 @@ from homeassistant.components.recorder.const import DB_WORKER_PREFIX
 from homeassistant.components.recorder.pool import RecorderPool


+async def test_recorder_pool_called_from_event_loop():
+    """Test we raise an exception when calling from the event loop."""
+    engine = create_engine("sqlite://", poolclass=RecorderPool)
+    with pytest.raises(RuntimeError):
+        sessionmaker(bind=engine)().connection()
+
+
 def test_recorder_pool(caplog):
     """Test RecorderPool gives the same connection in the creating thread."""
@@ -28,30 +36,26 @@ def test_recorder_pool(caplog):
         connections.append(session.connection().connection.connection)
         session.close()

-    _get_connection_twice()
-    assert "accesses the database without the database executor" in caplog.text
-    assert connections[0] != connections[1]
-
     caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice)
     new_thread.start()
     new_thread.join()
     assert "accesses the database without the database executor" in caplog.text
-    assert connections[2] != connections[3]
+    assert connections[0] != connections[1]

     caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
     new_thread.start()
     new_thread.join()
     assert "accesses the database without the database executor" not in caplog.text
-    assert connections[4] == connections[5]
+    assert connections[2] == connections[3]

     caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
     new_thread.start()
     new_thread.join()
     assert "accesses the database without the database executor" not in caplog.text
-    assert connections[6] == connections[7]
+    assert connections[4] == connections[5]

     shutdown = True
     caplog.clear()
@@ -59,4 +63,4 @@ def test_recorder_pool(caplog):
     new_thread.start()
     new_thread.join()
     assert "accesses the database without the database executor" not in caplog.text
-    assert connections[8] != connections[9]
+    assert connections[6] != connections[7]
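The new `test_recorder_pool_called_from_event_loop` asserts that checking out a connection from a running event loop now raises. A hedged repro outside pytest might look like the sketch below; it assumes a Home Assistant development checkout so the imports shown in the diff resolve.

```python
import asyncio

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from homeassistant.components.recorder.pool import RecorderPool


async def main() -> None:
    engine = create_engine("sqlite://", poolclass=RecorderPool)
    try:
        # Same call the test wraps in pytest.raises(RuntimeError).
        sessionmaker(bind=engine)().connection()
    except RuntimeError as err:
        print(f"raised as expected: {err}")


asyncio.run(main())
```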
@@ -597,8 +597,12 @@ def test_periodic_db_cleanups(hass_recorder):
     assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"


+@patch("homeassistant.components.recorder.pool.check_loop")
 async def test_write_lock_db(
-    hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT, tmp_path
+    skip_check_loop,
+    hass: HomeAssistant,
+    async_setup_recorder_instance: SetupRecorderInstanceT,
+    tmp_path,
 ):
     """Test database write lock."""
     from sqlalchemy.exc import OperationalError
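In the hunk above, `check_loop` is patched out so `test_write_lock_db` can keep exercising the lock path without tripping the new guard; the mock that `@patch` creates is injected as the first positional argument, which is why `skip_check_loop` leads the parameter list. A self-contained illustration of that decorator behaviour, with placeholder names:

```python
from unittest.mock import patch


def _target() -> str:
    return "real"


@patch(f"{__name__}._target", return_value="patched")
def run_with_patch(mock_target) -> str:
    # mock_target is the MagicMock injected by @patch, just like
    # skip_check_loop in the test above.
    return _target()


assert run_with_patch() == "patched"
```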
@@ -692,10 +692,9 @@ async def _async_init_recorder_component(hass, add_config=None):
     if recorder.CONF_COMMIT_INTERVAL not in config:
         config[recorder.CONF_COMMIT_INTERVAL] = 0

-    with patch(
-        "homeassistant.components.recorder.ALLOW_IN_MEMORY_DB",
-        True,
-    ), patch("homeassistant.components.recorder.migration.migrate_schema"):
+    with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
+        "homeassistant.components.recorder.migration.migrate_schema"
+    ):
         assert await async_setup_component(
             hass, recorder.DOMAIN, {recorder.DOMAIN: config}
         )