Fail tests if recorder creates nested sessions (#122579)
* Fail tests if recorder creates nested sessions * Adjust import order * Move get_instance
This commit is contained in:
parent
32a0463f47
commit
5dbd7684ce
6 changed files with 139 additions and 49 deletions
|
@ -23,6 +23,7 @@ from homeassistant.helpers.entityfilter import (
|
||||||
from homeassistant.helpers.integration_platform import (
|
from homeassistant.helpers.integration_platform import (
|
||||||
async_process_integration_platforms,
|
async_process_integration_platforms,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.recorder import DATA_INSTANCE
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util.event_type import EventType
|
from homeassistant.util.event_type import EventType
|
||||||
|
@ -30,7 +31,6 @@ from homeassistant.util.event_type import EventType
|
||||||
from . import entity_registry, websocket_api
|
from . import entity_registry, websocket_api
|
||||||
from .const import ( # noqa: F401
|
from .const import ( # noqa: F401
|
||||||
CONF_DB_INTEGRITY_CHECK,
|
CONF_DB_INTEGRITY_CHECK,
|
||||||
DATA_INSTANCE,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
|
INTEGRATION_PLATFORM_COMPILE_STATISTICS,
|
||||||
INTEGRATION_PLATFORMS_RUN_IN_RECORDER_THREAD,
|
INTEGRATION_PLATFORMS_RUN_IN_RECORDER_THREAD,
|
||||||
|
|
|
@ -13,15 +13,11 @@ from homeassistant.const import (
|
||||||
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, # noqa: F401
|
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, # noqa: F401
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.json import JSON_DUMP # noqa: F401
|
from homeassistant.helpers.json import JSON_DUMP # noqa: F401
|
||||||
from homeassistant.util.hass_dict import HassKey
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .core import Recorder # noqa: F401
|
from .core import Recorder # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
|
|
||||||
|
|
||||||
|
|
||||||
SQLITE_URL_PREFIX = "sqlite://"
|
SQLITE_URL_PREFIX = "sqlite://"
|
||||||
MARIADB_URL_PREFIX = "mariadb://"
|
MARIADB_URL_PREFIX = "mariadb://"
|
||||||
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
|
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
|
||||||
|
|
|
@ -29,10 +29,14 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv, issue_registry as ir
|
from homeassistant.helpers import config_validation as cv, issue_registry as ir
|
||||||
|
from homeassistant.helpers.recorder import ( # noqa: F401
|
||||||
|
DATA_INSTANCE,
|
||||||
|
get_instance,
|
||||||
|
session_scope,
|
||||||
|
)
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
DATA_INSTANCE,
|
|
||||||
DEFAULT_MAX_BIND_VARS,
|
DEFAULT_MAX_BIND_VARS,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SQLITE_MAX_BIND_VARS,
|
SQLITE_MAX_BIND_VARS,
|
||||||
|
@ -111,42 +115,6 @@ SUNDAY_WEEKDAY = 6
|
||||||
DAYS_IN_WEEK = 7
|
DAYS_IN_WEEK = 7
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def session_scope(
|
|
||||||
*,
|
|
||||||
hass: HomeAssistant | None = None,
|
|
||||||
session: Session | None = None,
|
|
||||||
exception_filter: Callable[[Exception], bool] | None = None,
|
|
||||||
read_only: bool = False,
|
|
||||||
) -> Generator[Session]:
|
|
||||||
"""Provide a transactional scope around a series of operations.
|
|
||||||
|
|
||||||
read_only is used to indicate that the session is only used for reading
|
|
||||||
data and that no commit is required. It does not prevent the session
|
|
||||||
from writing and is not a security measure.
|
|
||||||
"""
|
|
||||||
if session is None and hass is not None:
|
|
||||||
session = get_instance(hass).get_session()
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
raise RuntimeError("Session required")
|
|
||||||
|
|
||||||
need_rollback = False
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
if not read_only and session.get_transaction():
|
|
||||||
need_rollback = True
|
|
||||||
session.commit()
|
|
||||||
except Exception as err:
|
|
||||||
_LOGGER.exception("Error executing query")
|
|
||||||
if need_rollback:
|
|
||||||
session.rollback()
|
|
||||||
if not exception_filter or not exception_filter(err):
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
|
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
|
||||||
) -> list[Row]:
|
) -> list[Row]:
|
||||||
|
@ -769,12 +737,6 @@ def is_second_sunday(date_time: datetime) -> bool:
|
||||||
return bool(second_sunday(date_time.year, date_time.month).day == date_time.day)
|
return bool(second_sunday(date_time.year, date_time.month).day == date_time.day)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=1)
|
|
||||||
def get_instance(hass: HomeAssistant) -> Recorder:
|
|
||||||
"""Get the recorder instance."""
|
|
||||||
return hass.data[DATA_INSTANCE]
|
|
||||||
|
|
||||||
|
|
||||||
PERIOD_SCHEMA = vol.Schema(
|
PERIOD_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Exclusive("calendar", "period"): vol.Schema(
|
vol.Exclusive("calendar", "period"): vol.Schema(
|
||||||
|
|
|
@ -3,13 +3,25 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
import functools
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
from homeassistant.components.recorder import Recorder
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
|
DOMAIN: HassKey[RecorderData] = HassKey("recorder")
|
||||||
|
DATA_INSTANCE: HassKey[Recorder] = HassKey("recorder_instance")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
|
@ -56,3 +68,45 @@ async def async_wait_recorder(hass: HomeAssistant) -> bool:
|
||||||
if DOMAIN not in hass.data:
|
if DOMAIN not in hass.data:
|
||||||
return False
|
return False
|
||||||
return await hass.data[DOMAIN].db_connected
|
return await hass.data[DOMAIN].db_connected
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=1)
|
||||||
|
def get_instance(hass: HomeAssistant) -> Recorder:
|
||||||
|
"""Get the recorder instance."""
|
||||||
|
return hass.data[DATA_INSTANCE]
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_scope(
|
||||||
|
*,
|
||||||
|
hass: HomeAssistant | None = None,
|
||||||
|
session: Session | None = None,
|
||||||
|
exception_filter: Callable[[Exception], bool] | None = None,
|
||||||
|
read_only: bool = False,
|
||||||
|
) -> Generator[Session]:
|
||||||
|
"""Provide a transactional scope around a series of operations.
|
||||||
|
|
||||||
|
read_only is used to indicate that the session is only used for reading
|
||||||
|
data and that no commit is required. It does not prevent the session
|
||||||
|
from writing and is not a security measure.
|
||||||
|
"""
|
||||||
|
if session is None and hass is not None:
|
||||||
|
session = get_instance(hass).get_session()
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
raise RuntimeError("Session required")
|
||||||
|
|
||||||
|
need_rollback = False
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
if not read_only and session.get_transaction():
|
||||||
|
need_rollback = True
|
||||||
|
session.commit()
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.exception("Error executing query")
|
||||||
|
if need_rollback:
|
||||||
|
session.rollback()
|
||||||
|
if not exception_filter or not exception_filter(err):
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
|
@ -39,6 +39,9 @@ from syrupy.assertion import SnapshotAssertion
|
||||||
from homeassistant import block_async_io
|
from homeassistant import block_async_io
|
||||||
from homeassistant.exceptions import ServiceNotFound
|
from homeassistant.exceptions import ServiceNotFound
|
||||||
|
|
||||||
|
# Setup patching of recorder functions before any other Home Assistant imports
|
||||||
|
from . import patch_recorder # noqa: F401, isort:skip
|
||||||
|
|
||||||
# Setup patching of dt_util time functions before any other Home Assistant imports
|
# Setup patching of dt_util time functions before any other Home Assistant imports
|
||||||
from . import patch_time # noqa: F401, isort:skip
|
from . import patch_time # noqa: F401, isort:skip
|
||||||
|
|
||||||
|
@ -1423,6 +1426,15 @@ async def _async_init_recorder_component(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSession(threading.local):
|
||||||
|
"""Keep track of session per thread."""
|
||||||
|
|
||||||
|
has_session = False
|
||||||
|
|
||||||
|
|
||||||
|
thread_session = ThreadSession()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def async_test_recorder(
|
async def async_test_recorder(
|
||||||
recorder_db_url: str,
|
recorder_db_url: str,
|
||||||
|
@ -1444,6 +1456,39 @@ async def async_test_recorder(
|
||||||
# pylint: disable-next=import-outside-toplevel
|
# pylint: disable-next=import-outside-toplevel
|
||||||
from .components.recorder.common import async_recorder_block_till_done
|
from .components.recorder.common import async_recorder_block_till_done
|
||||||
|
|
||||||
|
# pylint: disable-next=import-outside-toplevel
|
||||||
|
from .patch_recorder import real_session_scope
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# pylint: disable-next=import-outside-toplevel
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def debug_session_scope(
|
||||||
|
*,
|
||||||
|
hass: HomeAssistant | None = None,
|
||||||
|
session: Session | None = None,
|
||||||
|
exception_filter: Callable[[Exception], bool] | None = None,
|
||||||
|
read_only: bool = False,
|
||||||
|
) -> Generator[Session]:
|
||||||
|
"""Wrap session_scope to bark if we create nested sessions."""
|
||||||
|
if thread_session.has_session:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Thread '{threading.current_thread().name}' already has an "
|
||||||
|
"active session"
|
||||||
|
)
|
||||||
|
thread_session.has_session = True
|
||||||
|
try:
|
||||||
|
with real_session_scope(
|
||||||
|
hass=hass,
|
||||||
|
session=session,
|
||||||
|
exception_filter=exception_filter,
|
||||||
|
read_only=read_only,
|
||||||
|
) as ses:
|
||||||
|
yield ses
|
||||||
|
finally:
|
||||||
|
thread_session.has_session = False
|
||||||
|
|
||||||
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
|
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
|
||||||
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
|
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
|
||||||
schema_validate = (
|
schema_validate = (
|
||||||
|
@ -1525,6 +1570,12 @@ async def async_test_recorder(
|
||||||
side_effect=compile_missing,
|
side_effect=compile_missing,
|
||||||
autospec=True,
|
autospec=True,
|
||||||
),
|
),
|
||||||
|
patch.object(
|
||||||
|
patch_recorder,
|
||||||
|
"real_session_scope",
|
||||||
|
side_effect=debug_session_scope,
|
||||||
|
autospec=True,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
27
tests/patch_recorder.py
Normal file
27
tests/patch_recorder.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
"""Patch recorder related functions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Patch recorder util session scope
|
||||||
|
from homeassistant.helpers import recorder as recorder_helper # noqa: E402
|
||||||
|
|
||||||
|
# Make sure homeassistant.components.recorder.util is not already imported
|
||||||
|
assert "homeassistant.components.recorder.util" not in sys.modules
|
||||||
|
|
||||||
|
real_session_scope = recorder_helper.session_scope
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _session_scope_wrapper(*args, **kwargs):
|
||||||
|
"""Make session_scope patchable.
|
||||||
|
|
||||||
|
This function will be imported by recorder modules.
|
||||||
|
"""
|
||||||
|
with real_session_scope(*args, **kwargs) as ses:
|
||||||
|
yield ses
|
||||||
|
|
||||||
|
|
||||||
|
recorder_helper.session_scope = _session_scope_wrapper
|
Loading…
Add table
Reference in a new issue