Fail tests if recorder creates nested sessions (#122579)
* Fail tests if recorder creates nested sessions * Adjust import order * Move get_instance
This commit is contained in:
parent
32a0463f47
commit
5dbd7684ce
6 changed files with 139 additions and 49 deletions
|
@ -23,6 +23,7 @@ from homeassistant.helpers.entityfilter import (
|
|||
from homeassistant.helpers.integration_platform import (
|
||||
async_process_integration_platforms,
|
||||
)
|
||||
from homeassistant.helpers.recorder import DATA_INSTANCE
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util.event_type import EventType
|
||||
|
@ -30,7 +31,6 @@ from homeassistant.util.event_type import EventType
|
|||
from . import entity_registry, websocket_api
|
||||
from .const import ( # noqa: F401
|
||||
CONF_DB_INTEGRITY_CHECK,
|
||||
DATA_INSTANCE,
|
||||
DOMAIN,
|
||||
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
|
||||
INTEGRATION_PLATFORMS_RUN_IN_RECORDER_THREAD,
|
||||
|
|
|
@ -13,15 +13,11 @@ from homeassistant.const import (
|
|||
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, # noqa: F401
|
||||
)
|
||||
from homeassistant.helpers.json import JSON_DUMP # noqa: F401
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import Recorder # noqa: F401
|
||||
|
||||
|
||||
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
|
||||
|
||||
|
||||
SQLITE_URL_PREFIX = "sqlite://"
|
||||
MARIADB_URL_PREFIX = "mariadb://"
|
||||
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
|
||||
|
|
|
@ -29,10 +29,14 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, issue_registry as ir
|
||||
from homeassistant.helpers.recorder import ( # noqa: F401
|
||||
DATA_INSTANCE,
|
||||
get_instance,
|
||||
session_scope,
|
||||
)
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .const import (
|
||||
DATA_INSTANCE,
|
||||
DEFAULT_MAX_BIND_VARS,
|
||||
DOMAIN,
|
||||
SQLITE_MAX_BIND_VARS,
|
||||
|
@ -111,42 +115,6 @@ SUNDAY_WEEKDAY = 6
|
|||
DAYS_IN_WEEK = 7
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope(
|
||||
*,
|
||||
hass: HomeAssistant | None = None,
|
||||
session: Session | None = None,
|
||||
exception_filter: Callable[[Exception], bool] | None = None,
|
||||
read_only: bool = False,
|
||||
) -> Generator[Session]:
|
||||
"""Provide a transactional scope around a series of operations.
|
||||
|
||||
read_only is used to indicate that the session is only used for reading
|
||||
data and that no commit is required. It does not prevent the session
|
||||
from writing and is not a security measure.
|
||||
"""
|
||||
if session is None and hass is not None:
|
||||
session = get_instance(hass).get_session()
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError("Session required")
|
||||
|
||||
need_rollback = False
|
||||
try:
|
||||
yield session
|
||||
if not read_only and session.get_transaction():
|
||||
need_rollback = True
|
||||
session.commit()
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Error executing query")
|
||||
if need_rollback:
|
||||
session.rollback()
|
||||
if not exception_filter or not exception_filter(err):
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def execute(
|
||||
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
|
||||
) -> list[Row]:
|
||||
|
@ -769,12 +737,6 @@ def is_second_sunday(date_time: datetime) -> bool:
|
|||
return bool(second_sunday(date_time.year, date_time.month).day == date_time.day)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_instance(hass: HomeAssistant) -> Recorder:
|
||||
"""Get the recorder instance."""
|
||||
return hass.data[DATA_INSTANCE]
|
||||
|
||||
|
||||
PERIOD_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Exclusive("calendar", "period"): vol.Schema(
|
||||
|
|
|
@ -3,13 +3,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.components.recorder import Recorder
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
|
||||
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
|
|||
if DOMAIN not in hass.data:
|
||||
return False
|
||||
return await hass.data[DOMAIN].db_connected
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_instance(hass: HomeAssistant) -> Recorder:
|
||||
"""Get the recorder instance."""
|
||||
return hass.data[DATA_INSTANCE]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope(
|
||||
*,
|
||||
hass: HomeAssistant | None = None,
|
||||
session: Session | None = None,
|
||||
exception_filter: Callable[[Exception], bool] | None = None,
|
||||
read_only: bool = False,
|
||||
) -> Generator[Session]:
|
||||
"""Provide a transactional scope around a series of operations.
|
||||
|
||||
read_only is used to indicate that the session is only used for reading
|
||||
data and that no commit is required. It does not prevent the session
|
||||
from writing and is not a security measure.
|
||||
"""
|
||||
if session is None and hass is not None:
|
||||
session = get_instance(hass).get_session()
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError("Session required")
|
||||
|
||||
need_rollback = False
|
||||
try:
|
||||
yield session
|
||||
if not read_only and session.get_transaction():
|
||||
need_rollback = True
|
||||
session.commit()
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Error executing query")
|
||||
if need_rollback:
|
||||
session.rollback()
|
||||
if not exception_filter or not exception_filter(err):
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
|
|
@ -39,6 +39,9 @@ from syrupy.assertion import SnapshotAssertion
|
|||
from homeassistant import block_async_io
|
||||
from homeassistant.exceptions import ServiceNotFound
|
||||
|
||||
# Setup patching of recorder functions before any other Home Assistant imports
|
||||
from . import patch_recorder # noqa: F401, isort:skip
|
||||
|
||||
# Setup patching of dt_util time functions before any other Home Assistant imports
|
||||
from . import patch_time # noqa: F401, isort:skip
|
||||
|
||||
|
@ -1423,6 +1426,15 @@ async def _async_init_recorder_component(
|
|||
)
|
||||
|
||||
|
||||
class ThreadSession(threading.local):
|
||||
"""Keep track of session per thread."""
|
||||
|
||||
has_session = False
|
||||
|
||||
|
||||
thread_session = ThreadSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_test_recorder(
|
||||
recorder_db_url: str,
|
||||
|
@ -1444,6 +1456,39 @@ async def async_test_recorder(
|
|||
# pylint: disable-next=import-outside-toplevel
|
||||
from .components.recorder.common import async_recorder_block_till_done
|
||||
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from .patch_recorder import real_session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
@contextmanager
|
||||
def debug_session_scope(
|
||||
*,
|
||||
hass: HomeAssistant | None = None,
|
||||
session: Session | None = None,
|
||||
exception_filter: Callable[[Exception], bool] | None = None,
|
||||
read_only: bool = False,
|
||||
) -> Generator[Session]:
|
||||
"""Wrap session_scope to bark if we create nested sessions."""
|
||||
if thread_session.has_session:
|
||||
raise RuntimeError(
|
||||
f"Thread '{threading.current_thread().name}' already has an "
|
||||
"active session"
|
||||
)
|
||||
thread_session.has_session = True
|
||||
try:
|
||||
with real_session_scope(
|
||||
hass=hass,
|
||||
session=session,
|
||||
exception_filter=exception_filter,
|
||||
read_only=read_only,
|
||||
) as ses:
|
||||
yield ses
|
||||
finally:
|
||||
thread_session.has_session = False
|
||||
|
||||
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
|
||||
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
|
||||
schema_validate = (
|
||||
|
@ -1525,6 +1570,12 @@ async def async_test_recorder(
|
|||
side_effect=compile_missing,
|
||||
autospec=True,
|
||||
),
|
||||
patch.object(
|
||||
patch_recorder,
|
||||
"real_session_scope",
|
||||
side_effect=debug_session_scope,
|
||||
autospec=True,
|
||||
),
|
||||
):
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
27
tests/patch_recorder.py
Normal file
27
tests/patch_recorder.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
"""Patch recorder related functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import sys
|
||||
|
||||
# Patch recorder util session scope
|
||||
from homeassistant.helpers import recorder as recorder_helper # noqa: E402
|
||||
|
||||
# Make sure homeassistant.components.recorder.util is not already imported
|
||||
assert "homeassistant.components.recorder.util" not in sys.modules
|
||||
|
||||
real_session_scope = recorder_helper.session_scope
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _session_scope_wrapper(*args, **kwargs):
|
||||
"""Make session_scope patchable.
|
||||
|
||||
This function will be imported by recorder modules.
|
||||
"""
|
||||
with real_session_scope(*args, **kwargs) as ses:
|
||||
yield ses
|
||||
|
||||
|
||||
recorder_helper.session_scope = _session_scope_wrapper
|
Loading…
Add table
Reference in a new issue