Improve recorder and worker thread matching in RecorderPool (#116886)
* Improve recorder and worker thread matching in RecorderPool. Previously we would look at the name of the threads. This was brittle because other integrations may name their thread Recorder or DbWorker. Instead we now use explicit thread ids, which ensures there will never be a conflict. * fix * fixes * fixes
This commit is contained in:
parent
ee031f4850
commit
6339c63176
5 changed files with 65 additions and 23 deletions
|
@ -187,6 +187,7 @@ class Recorder(threading.Thread):
|
|||
|
||||
self.hass = hass
|
||||
self.thread_id: int | None = None
|
||||
self.recorder_and_worker_thread_ids: set[int] = set()
|
||||
self.auto_purge = auto_purge
|
||||
self.auto_repack = auto_repack
|
||||
self.keep_days = keep_days
|
||||
|
@ -294,6 +295,7 @@ class Recorder(threading.Thread):
|
|||
def async_start_executor(self) -> None:
|
||||
"""Start the executor."""
|
||||
self._db_executor = DBInterruptibleThreadPoolExecutor(
|
||||
self.recorder_and_worker_thread_ids,
|
||||
thread_name_prefix=DB_WORKER_PREFIX,
|
||||
max_workers=MAX_DB_EXECUTOR_WORKERS,
|
||||
shutdown_hook=self._shutdown_pool,
|
||||
|
@ -717,7 +719,10 @@ class Recorder(threading.Thread):
|
|||
|
||||
def _run(self) -> None:
|
||||
"""Start processing events to save."""
|
||||
self.thread_id = threading.get_ident()
|
||||
thread_id = threading.get_ident()
|
||||
self.thread_id = thread_id
|
||||
self.recorder_and_worker_thread_ids.add(thread_id)
|
||||
|
||||
setup_result = self._setup_recorder()
|
||||
|
||||
if not setup_result:
|
||||
|
@ -1411,6 +1416,9 @@ class Recorder(threading.Thread):
|
|||
kwargs["pool_reset_on_return"] = None
|
||||
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||
kwargs["poolclass"] = RecorderPool
|
||||
kwargs["recorder_and_worker_thread_ids"] = (
|
||||
self.recorder_and_worker_thread_ids
|
||||
)
|
||||
elif self.db_url.startswith(
|
||||
(
|
||||
MARIADB_URL_PREFIX,
|
||||
|
|
|
@ -12,9 +12,13 @@ from homeassistant.util.executor import InterruptibleThreadPoolExecutor
|
|||
|
||||
|
||||
def _worker_with_shutdown_hook(
|
||||
shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
|
||||
shutdown_hook: Callable[[], None],
|
||||
recorder_and_worker_thread_ids: set[int],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a worker that calls a function after its finished."""
|
||||
recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
_worker(*args, **kwargs)
|
||||
shutdown_hook()
|
||||
|
||||
|
@ -22,9 +26,12 @@ def _worker_with_shutdown_hook(
|
|||
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
||||
"""A database instance that will not deadlock on shutdown."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Init the executor with a shutdown hook support."""
|
||||
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
|
||||
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _adjust_thread_count(self) -> None:
|
||||
|
@ -54,6 +61,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
|||
target=_worker_with_shutdown_hook,
|
||||
args=(
|
||||
self._shutdown_hook,
|
||||
self.recorder_and_worker_thread_ids,
|
||||
weakref.ref(self, weakref_cb),
|
||||
self._work_queue,
|
||||
self._initializer,
|
||||
|
|
|
@ -16,8 +16,6 @@ from sqlalchemy.pool import (
|
|||
from homeassistant.helpers.frame import report
|
||||
from homeassistant.util.loop import check_loop
|
||||
|
||||
from .const import DB_WORKER_PREFIX
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# For debugging the MutexPool
|
||||
|
@ -31,7 +29,7 @@ ADVISE_MSG = (
|
|||
)
|
||||
|
||||
|
||||
class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
||||
class RecorderPool(SingletonThreadPool, NullPool):
|
||||
"""A hybrid of NullPool and SingletonThreadPool.
|
||||
|
||||
When called from the creating thread or db executor acts like SingletonThreadPool
|
||||
|
@ -39,29 +37,44 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
|||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self, *args: Any, **kw: Any
|
||||
self,
|
||||
creator: Any,
|
||||
recorder_and_worker_thread_ids: set[int] | None = None,
|
||||
**kw: Any,
|
||||
) -> None:
|
||||
"""Create the pool."""
|
||||
kw["pool_size"] = POOL_SIZE
|
||||
SingletonThreadPool.__init__(self, *args, **kw)
|
||||
assert (
|
||||
recorder_and_worker_thread_ids is not None
|
||||
), "recorder_and_worker_thread_ids is required"
|
||||
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
|
||||
SingletonThreadPool.__init__(self, creator, **kw)
|
||||
|
||||
@property
|
||||
def recorder_or_dbworker(self) -> bool:
|
||||
"""Check if the thread is a recorder or dbworker thread."""
|
||||
thread_name = threading.current_thread().name
|
||||
return bool(
|
||||
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
|
||||
def recreate(self) -> "RecorderPool":
|
||||
"""Recreate the pool."""
|
||||
self.logger.info("Pool recreating")
|
||||
return self.__class__(
|
||||
self._creator,
|
||||
pool_size=self.size,
|
||||
recycle=self._recycle,
|
||||
echo=self.echo,
|
||||
pre_ping=self._pre_ping,
|
||||
logging_name=self._orig_logging_name,
|
||||
reset_on_return=self._reset_on_return,
|
||||
_dispatch=self.dispatch,
|
||||
dialect=self._dialect,
|
||||
recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids,
|
||||
)
|
||||
|
||||
def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
|
||||
if self.recorder_or_dbworker:
|
||||
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||
return super()._do_return_conn(record)
|
||||
record.close()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Close the connection."""
|
||||
if (
|
||||
self.recorder_or_dbworker
|
||||
threading.get_ident() in self.recorder_and_worker_thread_ids
|
||||
and self._conn
|
||||
and hasattr(self._conn, "current")
|
||||
and (conn := self._conn.current())
|
||||
|
@ -70,11 +83,11 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
|||
|
||||
def dispose(self) -> None:
|
||||
"""Dispose of the connection."""
|
||||
if self.recorder_or_dbworker:
|
||||
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||
super().dispose()
|
||||
|
||||
def _do_get(self) -> ConnectionPoolEntry:
|
||||
if self.recorder_or_dbworker:
|
||||
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||
return super()._do_get()
|
||||
check_loop(
|
||||
self._do_get_db_connection_protected,
|
||||
|
|
|
@ -14,6 +14,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components.recorder import (
|
||||
|
@ -30,7 +31,6 @@ from homeassistant.components.recorder import (
|
|||
db_schema,
|
||||
get_instance,
|
||||
migration,
|
||||
pool,
|
||||
statistics,
|
||||
)
|
||||
from homeassistant.components.recorder.const import (
|
||||
|
@ -2265,7 +2265,7 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
|
|||
def engine_created(*args): ...
|
||||
|
||||
def get_dialect_pool_class(self, *args):
|
||||
return pool.RecorderPool
|
||||
return QueuePool
|
||||
|
||||
def initialize(*args): ...
|
||||
|
||||
|
|
|
@ -12,20 +12,32 @@ from homeassistant.components.recorder.pool import RecorderPool
|
|||
|
||||
async def test_recorder_pool_called_from_event_loop() -> None:
|
||||
"""Test we raise an exception when calling from the event loop."""
|
||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
||||
recorder_and_worker_thread_ids: set[int] = set()
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
poolclass=RecorderPool,
|
||||
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
sessionmaker(bind=engine)().connection()
|
||||
|
||||
|
||||
def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test RecorderPool gives the same connection in the creating thread."""
|
||||
|
||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
||||
recorder_and_worker_thread_ids: set[int] = set()
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
poolclass=RecorderPool,
|
||||
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
|
||||
)
|
||||
get_session = sessionmaker(bind=engine)
|
||||
shutdown = False
|
||||
connections = []
|
||||
add_thread = False
|
||||
|
||||
def _get_connection_twice():
|
||||
if add_thread:
|
||||
recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||
session = get_session()
|
||||
connections.append(session.connection().connection.driver_connection)
|
||||
session.close()
|
||||
|
@ -44,6 +56,7 @@ def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
|
|||
assert "accesses the database without the database executor" in caplog.text
|
||||
assert connections[0] != connections[1]
|
||||
|
||||
add_thread = True
|
||||
caplog.clear()
|
||||
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
|
||||
new_thread.start()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue