Mark database sessions that do not write data as read_only (#89600)

* Mark sessions that do not write data as read_only

* Mark sessions that do not write data as read_only
This commit is contained in:
J. Nick Koston 2023-03-12 15:33:28 -10:00 committed by GitHub
parent 977a07de13
commit 85ca94e9d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 22 additions and 16 deletions

View file

@ -168,7 +168,7 @@ class HistoryPeriodView(HomeAssistantView):
"""Fetch significant stats from the database as json."""
timer_start = time.perf_counter()
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
states = history.get_significant_states_with_session(
hass,
session,

View file

@ -150,7 +150,7 @@ class EventProcessor:
#
return result.yield_per(1024)
with session_scope(hass=self.hass) as session:
with session_scope(hass=self.hass, read_only=True) as session:
metadata_ids: list[int] | None = None
if self.entity_ids:
instance = get_instance(self.hass)

View file

@ -213,7 +213,7 @@ def get_significant_states(
compressed_state_format: bool = False,
) -> MutableMapping[str, list[State | dict[str, Any]]]:
"""Wrap get_significant_states_with_session with an sql session."""
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
return get_significant_states_with_session(
hass,
session,
@ -488,7 +488,7 @@ def state_changes_during_period(
entity_id = entity_id.lower() if entity_id is not None else None
entity_ids = [entity_id] if entity_id is not None else None
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
stmt = _state_changed_during_period_stmt(
_schema_version(hass),
start_time,
@ -558,7 +558,7 @@ def get_last_state_changes(
entity_id_lower = entity_id.lower()
entity_ids = [entity_id_lower]
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
stmt = _get_last_state_changes_stmt(
_schema_version(hass), number_of_states, entity_id_lower
)

View file

@ -115,7 +115,7 @@ def get_significant_states(
compressed_state_format: bool = False,
) -> MutableMapping[str, list[State | dict[str, Any]]]:
"""Wrap get_significant_states_with_session with an sql session."""
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
return get_significant_states_with_session(
hass,
session,
@ -360,7 +360,7 @@ def state_changes_during_period(
entity_id = entity_id.lower() if entity_id is not None else None
entity_ids = [entity_id] if entity_id is not None else None
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
metadata_id: int | None = None
entity_id_to_metadata_id = None
if entity_id:
@ -424,7 +424,7 @@ def get_last_state_changes(
entity_id_lower = entity_id.lower()
entity_ids = [entity_id_lower]
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
instance = recorder.get_instance(hass)
if not (metadata_id := instance.states_meta_manager.get(entity_id, session)):
return {}

View file

@ -925,7 +925,7 @@ def get_metadata(
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
"""Return metadata for statistic_ids."""
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
return get_metadata_with_session(
session,
statistic_ids=statistic_ids,
@ -985,7 +985,7 @@ def list_statistic_ids(
statistic_ids_set = set(statistic_ids) if statistic_ids else None
# Query the database
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
metadata = get_metadata_with_session(
session, statistic_type=statistic_type, statistic_ids=statistic_ids
)
@ -1589,7 +1589,7 @@ def statistic_during_period(
result: dict[str, Any] = {}
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_id
if not (
metadata := get_metadata_with_session(session, statistic_ids=[statistic_id])
@ -1814,7 +1814,7 @@ def statistics_during_period(
If end_time is omitted, returns statistics newer than or equal to start_time.
If statistic_ids is omitted, returns statistics for all statistics ids.
"""
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
return _statistics_during_period_with_session(
hass,
session,
@ -1866,7 +1866,7 @@ def _get_last_statistics(
) -> dict[str, list[dict]]:
"""Return the last number_of_stats statistics for a given statistic_id."""
statistic_ids = [statistic_id]
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_id
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
@ -1953,7 +1953,7 @@ def get_latest_short_term_statistics(
metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
) -> dict[str, list[dict]]:
"""Return the latest short term statistics for a list of statistic_ids."""
with session_scope(hass=hass) as session:
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_ids
if not metadata:
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)

View file

@ -110,8 +110,14 @@ def session_scope(
hass: HomeAssistant | None = None,
session: Session | None = None,
exception_filter: Callable[[Exception], bool] | None = None,
read_only: bool = False,
) -> Generator[Session, None, None]:
"""Provide a transactional scope around a series of operations."""
"""Provide a transactional scope around a series of operations.
read_only is used to indicate that the session is only used for reading
data and that no commit is required. It does not prevent the session
from writing and is not a security measure.
"""
if session is None and hass is not None:
session = get_instance(hass).get_session()
@ -121,7 +127,7 @@ def session_scope(
need_rollback = False
try:
yield session
if session.get_transaction():
if session.get_transaction() and not read_only:
need_rollback = True
session.commit()
except Exception as err: # pylint: disable=broad-except