Fail tests if recorder creates nested sessions (#122579)

* Fail tests if recorder creates nested sessions

* Adjust import order

* Move get_instance
This commit is contained in:
Erik Montnemery 2024-07-25 21:18:55 +02:00 committed by GitHub
parent 32a0463f47
commit 5dbd7684ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 139 additions and 49 deletions

View file

@ -23,6 +23,7 @@ from homeassistant.helpers.entityfilter import (
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
from homeassistant.helpers.recorder import DATA_INSTANCE
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util.event_type import EventType
@ -30,7 +31,6 @@ from homeassistant.util.event_type import EventType
from . import entity_registry, websocket_api
from .const import ( # noqa: F401
CONF_DB_INTEGRITY_CHECK,
DATA_INSTANCE,
DOMAIN,
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
INTEGRATION_PLATFORMS_RUN_IN_RECORDER_THREAD,

View file

@ -13,15 +13,11 @@ from homeassistant.const import (
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, # noqa: F401
)
from homeassistant.helpers.json import JSON_DUMP # noqa: F401
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from .core import Recorder # noqa: F401
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
SQLITE_URL_PREFIX = "sqlite://"
MARIADB_URL_PREFIX = "mariadb://"
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"

View file

@ -29,10 +29,14 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, issue_registry as ir
from homeassistant.helpers.recorder import ( # noqa: F401
DATA_INSTANCE,
get_instance,
session_scope,
)
import homeassistant.util.dt as dt_util
from .const import (
DATA_INSTANCE,
DEFAULT_MAX_BIND_VARS,
DOMAIN,
SQLITE_MAX_BIND_VARS,
@ -111,42 +115,6 @@ SUNDAY_WEEKDAY = 6
DAYS_IN_WEEK = 7
@contextmanager
def session_scope(
*,
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session]:
"""Provide a transactional scope around a series of operations.
read_only is used to indicate that the session is only used for reading
data and that no commit is required. It does not prevent the session
from writing and is not a security measure.
"""
if session is None and hass is not None:
session = get_instance(hass).get_session()
if session is None:
raise RuntimeError("Session required")
need_rollback = False
try:
yield session
if not read_only and session.get_transaction():
need_rollback = True
session.commit()
except Exception as err:
_LOGGER.exception("Error executing query")
if need_rollback:
session.rollback()
if not exception_filter or not exception_filter(err):
raise
finally:
session.close()
def execute(
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
) -> list[Row]:
@ -769,12 +737,6 @@ def is_second_sunday(date_time: datetime) -> bool:
return bool(second_sunday(date_time.year, date_time.month).day == date_time.day)
@functools.lru_cache(maxsize=1)
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
PERIOD_SCHEMA = vol.Schema(
{
vol.Exclusive("calendar", "period"): vol.Schema(

View file

@ -3,13 +3,25 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
import functools
import logging
from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from homeassistant.components.recorder import Recorder
_LOGGER = logging.getLogger(__name__)
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
@dataclass(slots=True)
@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
if DOMAIN not in hass.data:
return False
return await hass.data[DOMAIN].db_connected
@functools.lru_cache(maxsize=1)
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
@contextmanager
def session_scope(
*,
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session]:
"""Provide a transactional scope around a series of operations.
read_only is used to indicate that the session is only used for reading
data and that no commit is required. It does not prevent the session
from writing and is not a security measure.
"""
if session is None and hass is not None:
session = get_instance(hass).get_session()
if session is None:
raise RuntimeError("Session required")
need_rollback = False
try:
yield session
if not read_only and session.get_transaction():
need_rollback = True
session.commit()
except Exception as err:
_LOGGER.exception("Error executing query")
if need_rollback:
session.rollback()
if not exception_filter or not exception_filter(err):
raise
finally:
session.close()

View file

@ -39,6 +39,9 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant import block_async_io
from homeassistant.exceptions import ServiceNotFound
# Setup patching of recorder functions before any other Home Assistant imports
from . import patch_recorder # noqa: F401, isort:skip
# Setup patching of dt_util time functions before any other Home Assistant imports
from . import patch_time # noqa: F401, isort:skip
@ -1423,6 +1426,15 @@ async def _async_init_recorder_component(
)
class ThreadSession(threading.local):
"""Keep track of session per thread."""
has_session = False
thread_session = ThreadSession()
@pytest.fixture
async def async_test_recorder(
recorder_db_url: str,
@ -1444,6 +1456,39 @@ async def async_test_recorder(
# pylint: disable-next=import-outside-toplevel
from .components.recorder.common import async_recorder_block_till_done
# pylint: disable-next=import-outside-toplevel
from .patch_recorder import real_session_scope
if TYPE_CHECKING:
# pylint: disable-next=import-outside-toplevel
from sqlalchemy.orm.session import Session
@contextmanager
def debug_session_scope(
*,
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session]:
"""Wrap session_scope to bark if we create nested sessions."""
if thread_session.has_session:
raise RuntimeError(
f"Thread '{threading.current_thread().name}' already has an "
"active session"
)
thread_session.has_session = True
try:
with real_session_scope(
hass=hass,
session=session,
exception_filter=exception_filter,
read_only=read_only,
) as ses:
yield ses
finally:
thread_session.has_session = False
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
schema_validate = (
@ -1525,6 +1570,12 @@ async def async_test_recorder(
side_effect=compile_missing,
autospec=True,
),
patch.object(
patch_recorder,
"real_session_scope",
side_effect=debug_session_scope,
autospec=True,
),
):
@asynccontextmanager

27
tests/patch_recorder.py Normal file
View file

@ -0,0 +1,27 @@
"""Patch recorder related functions."""
from __future__ import annotations
from contextlib import contextmanager
import sys
# Patch recorder util session scope
from homeassistant.helpers import recorder as recorder_helper # noqa: E402
# Make sure homeassistant.components.recorder.util is not already imported
assert "homeassistant.components.recorder.util" not in sys.modules
real_session_scope = recorder_helper.session_scope
@contextmanager
def _session_scope_wrapper(*args, **kwargs):
"""Make session_scope patchable.
This function will be imported by recorder modules.
"""
with real_session_scope(*args, **kwargs) as ses:
yield ses
recorder_helper.session_scope = _session_scope_wrapper