Fail tests if recorder creates nested sessions (#122579)

* Fail tests if recorder creates nested sessions

* Adjust import order

* Move get_instance
This commit is contained in:
Erik Montnemery 2024-07-25 21:18:55 +02:00 committed by GitHub
parent 32a0463f47
commit 5dbd7684ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 139 additions and 49 deletions

View file

@ -29,10 +29,14 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, issue_registry as ir
from homeassistant.helpers.recorder import ( # noqa: F401
DATA_INSTANCE,
get_instance,
session_scope,
)
import homeassistant.util.dt as dt_util
from .const import (
DATA_INSTANCE,
DEFAULT_MAX_BIND_VARS,
DOMAIN,
SQLITE_MAX_BIND_VARS,
@ -111,42 +115,6 @@ SUNDAY_WEEKDAY = 6
DAYS_IN_WEEK = 7
@contextmanager
def session_scope(
*,
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session]:
"""Provide a transactional scope around a series of operations.
read_only is used to indicate that the session is only used for reading
data and that no commit is required. It does not prevent the session
from writing and is not a security measure.
"""
if session is None and hass is not None:
session = get_instance(hass).get_session()
if session is None:
raise RuntimeError("Session required")
need_rollback = False
try:
yield session
if not read_only and session.get_transaction():
need_rollback = True
session.commit()
except Exception as err:
_LOGGER.exception("Error executing query")
if need_rollback:
session.rollback()
if not exception_filter or not exception_filter(err):
raise
finally:
session.close()
def execute(
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
) -> list[Row]:
@ -769,12 +737,6 @@ def is_second_sunday(date_time: datetime) -> bool:
return bool(second_sunday(date_time.year, date_time.month).day == date_time.day)
@functools.lru_cache(maxsize=1)
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
PERIOD_SCHEMA = vol.Schema(
{
vol.Exclusive("calendar", "period"): vol.Schema(