Improve recorder and worker thread matching in RecorderPool (#116886)

* Improve recorder and worker thread matching in RecorderPool

Previously we would look at the name of the threads. This
was a brittle approach because other integrations may name their
thread Recorder or DbWorker. Instead we now use explicit thread
ids, which ensures there will never be a conflict.

* fix

* fixes

* fixes
This commit is contained in:
J. Nick Koston 2024-05-05 15:25:10 -05:00 committed by GitHub
parent ee031f4850
commit 6339c63176
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 65 additions and 23 deletions

View file

@@ -187,6 +187,7 @@ class Recorder(threading.Thread):
self.hass = hass
self.thread_id: int | None = None
self.recorder_and_worker_thread_ids: set[int] = set()
self.auto_purge = auto_purge
self.auto_repack = auto_repack
self.keep_days = keep_days
@@ -294,6 +295,7 @@ class Recorder(threading.Thread):
def async_start_executor(self) -> None:
"""Start the executor."""
self._db_executor = DBInterruptibleThreadPoolExecutor(
self.recorder_and_worker_thread_ids,
thread_name_prefix=DB_WORKER_PREFIX,
max_workers=MAX_DB_EXECUTOR_WORKERS,
shutdown_hook=self._shutdown_pool,
@@ -717,7 +719,10 @@ class Recorder(threading.Thread):
def _run(self) -> None:
"""Start processing events to save."""
self.thread_id = threading.get_ident()
thread_id = threading.get_ident()
self.thread_id = thread_id
self.recorder_and_worker_thread_ids.add(thread_id)
setup_result = self._setup_recorder()
if not setup_result:
@@ -1411,6 +1416,9 @@ class Recorder(threading.Thread):
kwargs["pool_reset_on_return"] = None
elif self.db_url.startswith(SQLITE_URL_PREFIX):
kwargs["poolclass"] = RecorderPool
kwargs["recorder_and_worker_thread_ids"] = (
self.recorder_and_worker_thread_ids
)
elif self.db_url.startswith(
(
MARIADB_URL_PREFIX,

View file

@@ -12,9 +12,13 @@ from homeassistant.util.executor import InterruptibleThreadPoolExecutor
def _worker_with_shutdown_hook(
shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
shutdown_hook: Callable[[], None],
recorder_and_worker_thread_ids: set[int],
*args: Any,
**kwargs: Any,
) -> None:
"""Create a worker that calls a function after its finished."""
recorder_and_worker_thread_ids.add(threading.get_ident())
_worker(*args, **kwargs)
shutdown_hook()
@@ -22,9 +26,12 @@ def _worker_with_shutdown_hook(
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
"""A database instance that will not deadlock on shutdown."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any
) -> None:
"""Init the executor with a shutdown hook support."""
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
super().__init__(*args, **kwargs)
def _adjust_thread_count(self) -> None:
@@ -54,6 +61,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
target=_worker_with_shutdown_hook,
args=(
self._shutdown_hook,
self.recorder_and_worker_thread_ids,
weakref.ref(self, weakref_cb),
self._work_queue,
self._initializer,

View file

@@ -16,8 +16,6 @@ from sqlalchemy.pool import (
from homeassistant.helpers.frame import report
from homeassistant.util.loop import check_loop
from .const import DB_WORKER_PREFIX
_LOGGER = logging.getLogger(__name__)
# For debugging the MutexPool
@@ -31,7 +29,7 @@ ADVISE_MSG = (
)
class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
class RecorderPool(SingletonThreadPool, NullPool):
"""A hybrid of NullPool and SingletonThreadPool.
When called from the creating thread or db executor acts like SingletonThreadPool
@@ -39,29 +37,44 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
"""
def __init__( # pylint: disable=super-init-not-called
self, *args: Any, **kw: Any
self,
creator: Any,
recorder_and_worker_thread_ids: set[int] | None = None,
**kw: Any,
) -> None:
"""Create the pool."""
kw["pool_size"] = POOL_SIZE
SingletonThreadPool.__init__(self, *args, **kw)
assert (
recorder_and_worker_thread_ids is not None
), "recorder_and_worker_thread_ids is required"
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
SingletonThreadPool.__init__(self, creator, **kw)
@property
def recorder_or_dbworker(self) -> bool:
"""Check if the thread is a recorder or dbworker thread."""
thread_name = threading.current_thread().name
return bool(
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
def recreate(self) -> "RecorderPool":
"""Recreate the pool."""
self.logger.info("Pool recreating")
return self.__class__(
self._creator,
pool_size=self.size,
recycle=self._recycle,
echo=self.echo,
pre_ping=self._pre_ping,
logging_name=self._orig_logging_name,
reset_on_return=self._reset_on_return,
_dispatch=self.dispatch,
dialect=self._dialect,
recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids,
)
def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
if self.recorder_or_dbworker:
if threading.get_ident() in self.recorder_and_worker_thread_ids:
return super()._do_return_conn(record)
record.close()
def shutdown(self) -> None:
"""Close the connection."""
if (
self.recorder_or_dbworker
threading.get_ident() in self.recorder_and_worker_thread_ids
and self._conn
and hasattr(self._conn, "current")
and (conn := self._conn.current())
@@ -70,11 +83,11 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
def dispose(self) -> None:
"""Dispose of the connection."""
if self.recorder_or_dbworker:
if threading.get_ident() in self.recorder_and_worker_thread_ids:
super().dispose()
def _do_get(self) -> ConnectionPoolEntry:
if self.recorder_or_dbworker:
if threading.get_ident() in self.recorder_and_worker_thread_ids:
return super()._do_get()
check_loop(
self._do_get_db_connection_protected,

View file

@@ -14,6 +14,7 @@ from unittest.mock import MagicMock, Mock, patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
from sqlalchemy.pool import QueuePool
from homeassistant.components import recorder
from homeassistant.components.recorder import (
@@ -30,7 +31,6 @@ from homeassistant.components.recorder import (
db_schema,
get_instance,
migration,
pool,
statistics,
)
from homeassistant.components.recorder.const import (
@@ -2265,7 +2265,7 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
def engine_created(*args): ...
def get_dialect_pool_class(self, *args):
return pool.RecorderPool
return QueuePool
def initialize(*args): ...

View file

@@ -12,20 +12,32 @@ from homeassistant.components.recorder.pool import RecorderPool
async def test_recorder_pool_called_from_event_loop() -> None:
"""Test we raise an exception when calling from the event loop."""
engine = create_engine("sqlite://", poolclass=RecorderPool)
recorder_and_worker_thread_ids: set[int] = set()
engine = create_engine(
"sqlite://",
poolclass=RecorderPool,
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
)
with pytest.raises(RuntimeError):
sessionmaker(bind=engine)().connection()
def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
"""Test RecorderPool gives the same connection in the creating thread."""
engine = create_engine("sqlite://", poolclass=RecorderPool)
recorder_and_worker_thread_ids: set[int] = set()
engine = create_engine(
"sqlite://",
poolclass=RecorderPool,
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
)
get_session = sessionmaker(bind=engine)
shutdown = False
connections = []
add_thread = False
def _get_connection_twice():
if add_thread:
recorder_and_worker_thread_ids.add(threading.get_ident())
session = get_session()
connections.append(session.connection().connection.driver_connection)
session.close()
@@ -44,6 +56,7 @@ def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
assert "accesses the database without the database executor" in caplog.text
assert connections[0] != connections[1]
add_thread = True
caplog.clear()
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
new_thread.start()