diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 92d9baed771..281b130486f 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -187,6 +187,7 @@ class Recorder(threading.Thread): self.hass = hass self.thread_id: int | None = None + self.recorder_and_worker_thread_ids: set[int] = set() self.auto_purge = auto_purge self.auto_repack = auto_repack self.keep_days = keep_days @@ -294,6 +295,7 @@ class Recorder(threading.Thread): def async_start_executor(self) -> None: """Start the executor.""" self._db_executor = DBInterruptibleThreadPoolExecutor( + self.recorder_and_worker_thread_ids, thread_name_prefix=DB_WORKER_PREFIX, max_workers=MAX_DB_EXECUTOR_WORKERS, shutdown_hook=self._shutdown_pool, @@ -717,7 +719,10 @@ class Recorder(threading.Thread): def _run(self) -> None: """Start processing events to save.""" - self.thread_id = threading.get_ident() + thread_id = threading.get_ident() + self.thread_id = thread_id + self.recorder_and_worker_thread_ids.add(thread_id) + setup_result = self._setup_recorder() if not setup_result: @@ -1411,6 +1416,9 @@ class Recorder(threading.Thread): kwargs["pool_reset_on_return"] = None elif self.db_url.startswith(SQLITE_URL_PREFIX): kwargs["poolclass"] = RecorderPool + kwargs["recorder_and_worker_thread_ids"] = ( + self.recorder_and_worker_thread_ids + ) elif self.db_url.startswith( ( MARIADB_URL_PREFIX, diff --git a/homeassistant/components/recorder/executor.py b/homeassistant/components/recorder/executor.py index b17547499e8..8102c769ac1 100644 --- a/homeassistant/components/recorder/executor.py +++ b/homeassistant/components/recorder/executor.py @@ -12,9 +12,13 @@ from homeassistant.util.executor import InterruptibleThreadPoolExecutor def _worker_with_shutdown_hook( - shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any + shutdown_hook: Callable[[], None], + recorder_and_worker_thread_ids: set[int], + *args: Any, + **kwargs: Any, ) -> None: """Create a worker that calls a function after its finished.""" + recorder_and_worker_thread_ids.add(threading.get_ident()) _worker(*args, **kwargs) shutdown_hook() @@ -22,9 +26,12 @@ def _worker_with_shutdown_hook( class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor): """A database instance that will not deadlock on shutdown.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any + ) -> None: """Init the executor with a shutdown hook support.""" self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook") + self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids super().__init__(*args, **kwargs) def _adjust_thread_count(self) -> None: @@ -54,6 +61,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor): target=_worker_with_shutdown_hook, args=( self._shutdown_hook, + self.recorder_and_worker_thread_ids, weakref.ref(self, weakref_cb), self._work_queue, self._initializer, diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py index ec7aa5bdcb6..bc5b02983da 100644 --- a/homeassistant/components/recorder/pool.py +++ b/homeassistant/components/recorder/pool.py @@ -16,8 +16,6 @@ from sqlalchemy.pool import ( from homeassistant.helpers.frame import report from homeassistant.util.loop import check_loop -from .const import DB_WORKER_PREFIX - _LOGGER = logging.getLogger(__name__) # For debugging the MutexPool @@ -31,7 +29,7 @@ ADVISE_MSG = ( ) -class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] +class RecorderPool(SingletonThreadPool, NullPool): """A hybrid of NullPool and SingletonThreadPool. When called from the creating thread or db executor acts like SingletonThreadPool @@ -39,29 +37,44 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] """ def __init__( # pylint: disable=super-init-not-called - self, *args: Any, **kw: Any + self, + creator: Any, + recorder_and_worker_thread_ids: set[int] | None = None, + **kw: Any, ) -> None: """Create the pool.""" kw["pool_size"] = POOL_SIZE - SingletonThreadPool.__init__(self, *args, **kw) + assert ( + recorder_and_worker_thread_ids is not None + ), "recorder_and_worker_thread_ids is required" + self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids + SingletonThreadPool.__init__(self, creator, **kw) - @property - def recorder_or_dbworker(self) -> bool: - """Check if the thread is a recorder or dbworker thread.""" - thread_name = threading.current_thread().name - return bool( - thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX) + def recreate(self) -> "RecorderPool": + """Recreate the pool.""" + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + pre_ping=self._pre_ping, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids, ) def _do_return_conn(self, record: ConnectionPoolEntry) -> None: - if self.recorder_or_dbworker: + if threading.get_ident() in self.recorder_and_worker_thread_ids: return super()._do_return_conn(record) record.close() def shutdown(self) -> None: """Close the connection.""" if ( - self.recorder_or_dbworker + threading.get_ident() in self.recorder_and_worker_thread_ids and self._conn and hasattr(self._conn, "current") and (conn := self._conn.current()) @@ -70,11 +83,11 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] def dispose(self) -> None: """Dispose of the connection.""" - if self.recorder_or_dbworker: + if threading.get_ident() in self.recorder_and_worker_thread_ids: super().dispose() def _do_get(self) -> ConnectionPoolEntry: - if self.recorder_or_dbworker: + if threading.get_ident() in self.recorder_and_worker_thread_ids: return super()._do_get() check_loop( self._do_get_db_connection_protected, diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index d9f0e7d296f..feeb7e04547 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock, Mock, patch from freezegun.api import FrozenDateTimeFactory import pytest from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError +from sqlalchemy.pool import QueuePool from homeassistant.components import recorder from homeassistant.components.recorder import ( @@ -30,7 +31,6 @@ from homeassistant.components.recorder import ( db_schema, get_instance, migration, - pool, statistics, ) from homeassistant.components.recorder.const import ( @@ -2265,7 +2265,7 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None: def engine_created(*args): ... def get_dialect_pool_class(self, *args): - return pool.RecorderPool + return QueuePool def initialize(*args): ... diff --git a/tests/components/recorder/test_pool.py b/tests/components/recorder/test_pool.py index 541fc8d714b..3cca095399b 100644 --- a/tests/components/recorder/test_pool.py +++ b/tests/components/recorder/test_pool.py @@ -12,20 +12,32 @@ from homeassistant.components.recorder.pool import RecorderPool async def test_recorder_pool_called_from_event_loop() -> None: """Test we raise an exception when calling from the event loop.""" - engine = create_engine("sqlite://", poolclass=RecorderPool) + recorder_and_worker_thread_ids: set[int] = set() + engine = create_engine( + "sqlite://", + poolclass=RecorderPool, + recorder_and_worker_thread_ids=recorder_and_worker_thread_ids, + ) with pytest.raises(RuntimeError): sessionmaker(bind=engine)().connection() def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None: """Test RecorderPool gives the same connection in the creating thread.""" - - engine = create_engine("sqlite://", poolclass=RecorderPool) + recorder_and_worker_thread_ids: set[int] = set() + engine = create_engine( + "sqlite://", + poolclass=RecorderPool, + recorder_and_worker_thread_ids=recorder_and_worker_thread_ids, + ) get_session = sessionmaker(bind=engine) shutdown = False connections = [] + add_thread = False def _get_connection_twice(): + if add_thread: + recorder_and_worker_thread_ids.add(threading.get_ident()) session = get_session() connections.append(session.connection().connection.driver_connection) session.close() @@ -44,6 +56,7 @@ def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None: assert "accesses the database without the database executor" in caplog.text assert connections[0] != connections[1] + add_thread = True caplog.clear() new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX) new_thread.start()