diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index b9ba90caf3f..41fa8db5814 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -23,6 +23,7 @@ from homeassistant.helpers.entityfilter import ( from homeassistant.helpers.integration_platform import ( async_process_integration_platforms, ) +from homeassistant.helpers.recorder import DATA_INSTANCE from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.util.event_type import EventType @@ -30,7 +31,6 @@ from homeassistant.util.event_type import EventType from . import entity_registry, websocket_api from .const import ( # noqa: F401 CONF_DB_INTEGRITY_CHECK, - DATA_INSTANCE, DOMAIN, INTEGRATION_PLATFORM_COMPILE_STATISTICS, INTEGRATION_PLATFORMS_RUN_IN_RECORDER_THREAD, diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index 31870a5db2d..c7dba18cad9 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -13,15 +13,11 @@ from homeassistant.const import ( EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, # noqa: F401 ) from homeassistant.helpers.json import JSON_DUMP # noqa: F401 -from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: from .core import Recorder # noqa: F401 -DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance") - - SQLITE_URL_PREFIX = "sqlite://" MARIADB_URL_PREFIX = "mariadb://" MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://" diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 89621821ff8..1ef85b28f8d 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -29,10 +29,14 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv, issue_registry as ir +from homeassistant.helpers.recorder import ( # noqa: F401 + DATA_INSTANCE, + get_instance, + session_scope, +) import homeassistant.util.dt as dt_util from .const import ( - DATA_INSTANCE, DEFAULT_MAX_BIND_VARS, DOMAIN, SQLITE_MAX_BIND_VARS, @@ -111,42 +115,6 @@ SUNDAY_WEEKDAY = 6 DAYS_IN_WEEK = 7 -@contextmanager -def session_scope( - *, - hass: HomeAssistant | None = None, - session: Session | None = None, - exception_filter: Callable[[Exception], bool] | None = None, - read_only: bool = False, -) -> Generator[Session]: - """Provide a transactional scope around a series of operations. - - read_only is used to indicate that the session is only used for reading - data and that no commit is required. It does not prevent the session - from writing and is not a security measure. - """ - if session is None and hass is not None: - session = get_instance(hass).get_session() - - if session is None: - raise RuntimeError("Session required") - - need_rollback = False - try: - yield session - if not read_only and session.get_transaction(): - need_rollback = True - session.commit() - except Exception as err: - _LOGGER.exception("Error executing query") - if need_rollback: - session.rollback() - if not exception_filter or not exception_filter(err): - raise - finally: - session.close() - - def execute( qry: Query, to_native: bool = False, validate_entity_ids: bool = True ) -> list[Row]: @@ -769,12 +737,6 @@ def is_second_sunday(date_time: datetime) -> bool: return bool(second_sunday(date_time.year, date_time.month).day == date_time.day) -@functools.lru_cache(maxsize=1) -def get_instance(hass: HomeAssistant) -> Recorder: - """Get the recorder instance.""" - return hass.data[DATA_INSTANCE] - - PERIOD_SCHEMA = vol.Schema( { vol.Exclusive("calendar", "period"): vol.Schema( diff --git a/homeassistant/helpers/recorder.py b/homeassistant/helpers/recorder.py index f6657efc6d7..59604944eeb 100644 --- a/homeassistant/helpers/recorder.py +++ b/homeassistant/helpers/recorder.py @@ -3,13 +3,25 @@ from __future__ import annotations import asyncio +from collections.abc import Callable, Generator +from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any +import functools +import logging +from typing import TYPE_CHECKING, Any from homeassistant.core import HomeAssistant, callback from homeassistant.util.hass_dict import HassKey +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + from homeassistant.components.recorder import Recorder + +_LOGGER = logging.getLogger(__name__) + DOMAIN: HassKey[RecorderData] = HassKey("recorder") +DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance") @dataclass(slots=True) @@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool: if DOMAIN not in hass.data: return False return await hass.data[DOMAIN].db_connected + + +@functools.lru_cache(maxsize=1) +def get_instance(hass: HomeAssistant) -> Recorder: + """Get the recorder instance.""" + return hass.data[DATA_INSTANCE] + + +@contextmanager +def session_scope( + *, + hass: HomeAssistant | None = None, + session: Session | None = None, + exception_filter: Callable[[Exception], bool] | None = None, + read_only: bool = False, +) -> Generator[Session]: + """Provide a transactional scope around a series of operations. + + read_only is used to indicate that the session is only used for reading + data and that no commit is required. It does not prevent the session + from writing and is not a security measure. + """ + if session is None and hass is not None: + session = get_instance(hass).get_session() + + if session is None: + raise RuntimeError("Session required") + + need_rollback = False + try: + yield session + if not read_only and session.get_transaction(): + need_rollback = True + session.commit() + except Exception as err: + _LOGGER.exception("Error executing query") + if need_rollback: + session.rollback() + if not exception_filter or not exception_filter(err): + raise + finally: + session.close() diff --git a/tests/conftest.py b/tests/conftest.py index bc139255e66..de0dbc2e0d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,9 @@ from syrupy.assertion import SnapshotAssertion from homeassistant import block_async_io from homeassistant.exceptions import ServiceNotFound +# Setup patching of recorder functions before any other Home Assistant imports +from . import patch_recorder # noqa: F401, isort:skip + # Setup patching of dt_util time functions before any other Home Assistant imports from . import patch_time # noqa: F401, isort:skip @@ -1423,6 +1426,15 @@ async def _async_init_recorder_component( ) +class ThreadSession(threading.local): + """Keep track of session per thread.""" + + has_session = False + + +thread_session = ThreadSession() + + @pytest.fixture async def async_test_recorder( recorder_db_url: str, @@ -1444,6 +1456,39 @@ async def async_test_recorder( # pylint: disable-next=import-outside-toplevel from .components.recorder.common import async_recorder_block_till_done + # pylint: disable-next=import-outside-toplevel + from .patch_recorder import real_session_scope + + if TYPE_CHECKING: + # pylint: disable-next=import-outside-toplevel + from sqlalchemy.orm.session import Session + + @contextmanager + def debug_session_scope( + *, + hass: HomeAssistant | None = None, + session: Session | None = None, + exception_filter: Callable[[Exception], bool] | None = None, + read_only: bool = False, + ) -> Generator[Session]: + """Wrap session_scope to bark if we create nested sessions.""" + if thread_session.has_session: + raise RuntimeError( + f"Thread '{threading.current_thread().name}' already has an " + "active session" + ) + thread_session.has_session = True + try: + with real_session_scope( + hass=hass, + session=session, + exception_filter=exception_filter, + read_only=read_only, + ) as ses: + yield ses + finally: + thread_session.has_session = False + nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None schema_validate = ( @@ -1525,6 +1570,12 @@ async def async_test_recorder( side_effect=compile_missing, autospec=True, ), + patch.object( + patch_recorder, + "real_session_scope", + side_effect=debug_session_scope, + autospec=True, + ), ): @asynccontextmanager diff --git a/tests/patch_recorder.py b/tests/patch_recorder.py new file mode 100644 index 00000000000..4993e84fc30 --- /dev/null +++ b/tests/patch_recorder.py @@ -0,0 +1,27 @@ +"""Patch recorder related functions.""" + +from __future__ import annotations + +from contextlib import contextmanager +import sys + +# Patch recorder util session scope +from homeassistant.helpers import recorder as recorder_helper # noqa: E402 + +# Make sure homeassistant.components.recorder.util is not already imported +assert "homeassistant.components.recorder.util" not in sys.modules + +real_session_scope = recorder_helper.session_scope + + +@contextmanager +def _session_scope_wrapper(*args, **kwargs): + """Make session_scope patchable. + + This function will be imported by recorder modules. + """ + with real_session_scope(*args, **kwargs) as ses: + yield ses + + +recorder_helper.session_scope = _session_scope_wrapper