From 98809675ff8cc87dbb5e165ad3096d75e69f3b44 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 May 2022 10:47:29 -0500 Subject: [PATCH] Convert history queries to use lambda_stmt (#71870) Co-authored-by: Paulus Schoutsen --- homeassistant/components/recorder/__init__.py | 3 +- homeassistant/components/recorder/filters.py | 11 - homeassistant/components/recorder/history.py | 454 ++++++++---------- homeassistant/components/recorder/util.py | 46 +- tests/components/recorder/test_util.py | 61 ++- 5 files changed, 313 insertions(+), 262 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index ca1ecd8c71a..4063e443e8b 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -20,7 +20,7 @@ from homeassistant.helpers.integration_platform import ( from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass -from . import history, statistics, websocket_api +from . import statistics, websocket_api from .const import ( CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, @@ -166,7 +166,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: instance.async_register() instance.start() async_register_services(hass, instance) - history.async_setup(hass) statistics.async_setup(hass) websocket_api.async_setup(hass) await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform) diff --git a/homeassistant/components/recorder/filters.py b/homeassistant/components/recorder/filters.py index bb19dfc6d62..adc746379e6 100644 --- a/homeassistant/components/recorder/filters.py +++ b/homeassistant/components/recorder/filters.py @@ -2,7 +2,6 @@ from __future__ import annotations from sqlalchemy import not_, or_ -from sqlalchemy.ext.baked import BakedQuery from sqlalchemy.sql.elements import ClauseList from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE @@ -60,16 +59,6 @@ class Filters: or self.included_entity_globs ) - def bake(self, baked_query: BakedQuery) -> BakedQuery: - """Update a baked query. - - Works the same as apply on a baked_query. - """ - if not self.has_config: - return - - baked_query += lambda q: q.filter(self.entity_filter()) - def entity_filter(self) -> ClauseList: """Generate the entity filter query.""" includes = [] diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index 316e5ab27c8..3df444faccc 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -9,13 +9,12 @@ import logging import time from typing import Any, cast -from sqlalchemy import Column, Text, and_, bindparam, func, or_ +from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select from sqlalchemy.engine.row import Row -from sqlalchemy.ext import baked -from sqlalchemy.ext.baked import BakedQuery from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal +from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.components import recorder from homeassistant.components.websocket_api.const import ( @@ -36,7 +35,7 @@ from .models import ( process_timestamp_to_utc_isoformat, row_to_compressed_state, ) -from .util import execute, session_scope +from .util import execute_stmt_lambda_element, session_scope # mypy: allow-untyped-defs, no-check-untyped-defs @@ -111,52 +110,48 @@ QUERY_STATES_NO_LAST_CHANGED = [ StateAttributes.shared_attrs, ] -HISTORY_BAKERY = "recorder_history_bakery" + +def _schema_version(hass: HomeAssistant) -> int: + return recorder.get_instance(hass).schema_version -def bake_query_and_join_attributes( - hass: HomeAssistant, no_attributes: bool, include_last_changed: bool = True -) -> tuple[Any, bool]: - """Return the initial backed query and if StateAttributes should be joined. +def lambda_stmt_and_join_attributes( + schema_version: int, no_attributes: bool, include_last_changed: bool = True +) -> tuple[StatementLambdaElement, bool]: + """Return the lambda_stmt and if StateAttributes should be joined. - Because these are baked queries the values inside the lambdas need + Because these are lambda_stmt the values inside the lambdas need to be explicitly written out to avoid caching the wrong values. """ - bakery: baked.bakery = hass.data[HISTORY_BAKERY] # If no_attributes was requested we do the query # without the attributes fields and do not join the # state_attributes table if no_attributes: if include_last_changed: - return bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR)), False + return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False return ( - bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)), + lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)), False, ) # If we in the process of migrating schema we do # not want to join the state_attributes table as we # do not know if it will be there yet - if recorder.get_instance(hass).schema_version < 25: + if schema_version < 25: if include_last_changed: return ( - bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25)), + lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)), False, ) return ( - bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)), + lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)), False, ) # Finally if no migration is in progress and no_attributes # was not requested, we query both attributes columns and # join state_attributes if include_last_changed: - return bakery(lambda s: s.query(*QUERY_STATES)), True - return bakery(lambda s: s.query(*QUERY_STATES_NO_LAST_CHANGED)), True - - -def async_setup(hass: HomeAssistant) -> None: - """Set up the history hooks.""" - hass.data[HISTORY_BAKERY] = baked.bakery() + return lambda_stmt(lambda: select(*QUERY_STATES)), True + return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True def get_significant_states( @@ -200,38 +195,30 @@ def _ignore_domains_filter(query: Query) -> Query: ) -def _query_significant_states_with_session( - hass: HomeAssistant, - session: Session, +def _significant_states_stmt( + schema_version: int, start_time: datetime, - end_time: datetime | None = None, - entity_ids: list[str] | None = None, - filters: Filters | None = None, - significant_changes_only: bool = True, - no_attributes: bool = False, -) -> list[Row]: + end_time: datetime | None, + entity_ids: list[str] | None, + filters: Filters | None, + significant_changes_only: bool, + no_attributes: bool, +) -> StatementLambdaElement: """Query the database for significant state changes.""" - if _LOGGER.isEnabledFor(logging.DEBUG): - timer_start = time.perf_counter() - - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=True + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, no_attributes, include_last_changed=not significant_changes_only ) - - if entity_ids is not None and len(entity_ids) == 1: - if ( - significant_changes_only - and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS - ): - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=False - ) - baked_query += lambda q: q.filter( - (States.last_changed == States.last_updated) - | States.last_changed.is_(None) - ) + if ( + entity_ids + and len(entity_ids) == 1 + and significant_changes_only + and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS + ): + stmt += lambda q: q.filter( + (States.last_changed == States.last_updated) | States.last_changed.is_(None) + ) elif significant_changes_only: - baked_query += lambda q: q.filter( + stmt += lambda q: q.filter( or_( *[ States.entity_id.like(entity_domain) @@ -244,36 +231,24 @@ def _query_significant_states_with_session( ) ) - if entity_ids is not None: - baked_query += lambda q: q.filter( - States.entity_id.in_(bindparam("entity_ids", expanding=True)) - ) + if entity_ids: + stmt += lambda q: q.filter(States.entity_id.in_(entity_ids)) else: - baked_query += _ignore_domains_filter - if filters: - filters.bake(baked_query) + stmt += _ignore_domains_filter + if filters and filters.has_config: + entity_filter = filters.entity_filter() + stmt += lambda q: q.filter(entity_filter) - baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time")) - if end_time is not None: - baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time")) + stmt += lambda q: q.filter(States.last_updated > start_time) + if end_time: + stmt += lambda q: q.filter(States.last_updated < end_time) if join_attributes: - baked_query += lambda q: q.outerjoin( + stmt += lambda q: q.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) - baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) - - states = execute( - baked_query(session).params( - start_time=start_time, end_time=end_time, entity_ids=entity_ids - ) - ) - - if _LOGGER.isEnabledFor(logging.DEBUG): - elapsed = time.perf_counter() - timer_start - _LOGGER.debug("get_significant_states took %fs", elapsed) - - return states + stmt += lambda q: q.order_by(States.entity_id, States.last_updated) + return stmt def get_significant_states_with_session( @@ -301,9 +276,8 @@ def get_significant_states_with_session( as well as all states from certain domains (for instance thermostat so that we get current temperature in our graphs). """ - states = _query_significant_states_with_session( - hass, - session, + stmt = _significant_states_stmt( + _schema_version(hass), start_time, end_time, entity_ids, @@ -311,6 +285,9 @@ def get_significant_states_with_session( significant_changes_only, no_attributes, ) + states = execute_stmt_lambda_element( + session, stmt, None if entity_ids else start_time, end_time + ) return _sorted_states_to_dict( hass, session, @@ -354,6 +331,38 @@ def get_full_significant_states_with_session( ) +def _state_changed_during_period_stmt( + schema_version: int, + start_time: datetime, + end_time: datetime | None, + entity_id: str | None, + no_attributes: bool, + descending: bool, + limit: int | None, +) -> StatementLambdaElement: + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, no_attributes, include_last_changed=False + ) + stmt += lambda q: q.filter( + ((States.last_changed == States.last_updated) | States.last_changed.is_(None)) + & (States.last_updated > start_time) + ) + if end_time: + stmt += lambda q: q.filter(States.last_updated < end_time) + stmt += lambda q: q.filter(States.entity_id == entity_id) + if join_attributes: + stmt += lambda q: q.outerjoin( + StateAttributes, States.attributes_id == StateAttributes.attributes_id + ) + if descending: + stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()) + else: + stmt += lambda q: q.order_by(States.entity_id, States.last_updated) + if limit: + stmt += lambda q: q.limit(limit) + return stmt + + def state_changes_during_period( hass: HomeAssistant, start_time: datetime, @@ -365,52 +374,21 @@ def state_changes_during_period( include_start_time_state: bool = True, ) -> MutableMapping[str, list[State]]: """Return states changes during UTC period start_time - end_time.""" + entity_id = entity_id.lower() if entity_id is not None else None + with session_scope(hass=hass) as session: - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=False + stmt = _state_changed_during_period_stmt( + _schema_version(hass), + start_time, + end_time, + entity_id, + no_attributes, + descending, + limit, ) - - baked_query += lambda q: q.filter( - ( - (States.last_changed == States.last_updated) - | States.last_changed.is_(None) - ) - & (States.last_updated > bindparam("start_time")) + states = execute_stmt_lambda_element( + session, stmt, None if entity_id else start_time, end_time ) - - if end_time is not None: - baked_query += lambda q: q.filter( - States.last_updated < bindparam("end_time") - ) - - if entity_id is not None: - baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) - entity_id = entity_id.lower() - - if join_attributes: - baked_query += lambda q: q.outerjoin( - StateAttributes, States.attributes_id == StateAttributes.attributes_id - ) - - if descending: - baked_query += lambda q: q.order_by( - States.entity_id, States.last_updated.desc() - ) - else: - baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) - - if limit: - baked_query += lambda q: q.limit(bindparam("limit")) - - states = execute( - baked_query(session).params( - start_time=start_time, - end_time=end_time, - entity_id=entity_id, - limit=limit, - ) - ) - entity_ids = [entity_id] if entity_id is not None else None return cast( @@ -426,41 +404,37 @@ def state_changes_during_period( ) +def _get_last_state_changes_stmt( + schema_version: int, number_of_states: int, entity_id: str +) -> StatementLambdaElement: + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, False, include_last_changed=False + ) + stmt += lambda q: q.filter( + (States.last_changed == States.last_updated) | States.last_changed.is_(None) + ).filter(States.entity_id == entity_id) + if join_attributes: + stmt += lambda q: q.outerjoin( + StateAttributes, States.attributes_id == StateAttributes.attributes_id + ) + stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit( + number_of_states + ) + return stmt + + def get_last_state_changes( hass: HomeAssistant, number_of_states: int, entity_id: str ) -> MutableMapping[str, list[State]]: """Return the last number_of_states.""" start_time = dt_util.utcnow() + entity_id = entity_id.lower() if entity_id is not None else None with session_scope(hass=hass) as session: - baked_query, join_attributes = bake_query_and_join_attributes( - hass, False, include_last_changed=False + stmt = _get_last_state_changes_stmt( + _schema_version(hass), number_of_states, entity_id ) - - baked_query += lambda q: q.filter( - (States.last_changed == States.last_updated) | States.last_changed.is_(None) - ) - - if entity_id is not None: - baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) - entity_id = entity_id.lower() - - if join_attributes: - baked_query += lambda q: q.outerjoin( - StateAttributes, States.attributes_id == StateAttributes.attributes_id - ) - baked_query += lambda q: q.order_by( - States.entity_id, States.last_updated.desc() - ) - - baked_query += lambda q: q.limit(bindparam("number_of_states")) - - states = execute( - baked_query(session).params( - number_of_states=number_of_states, entity_id=entity_id - ) - ) - + states = list(execute_stmt_lambda_element(session, stmt)) entity_ids = [entity_id] if entity_id is not None else None return cast( @@ -476,96 +450,91 @@ def get_last_state_changes( ) -def _most_recent_state_ids_entities_subquery(query: Query) -> Query: - """Query to find the most recent state id for specific entities.""" +def _get_states_for_entites_stmt( + schema_version: int, + run_start: datetime, + utc_point_in_time: datetime, + entity_ids: list[str], + no_attributes: bool, +) -> StatementLambdaElement: + """Baked query to get states for specific entities.""" + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, no_attributes, include_last_changed=True + ) # We got an include-list of entities, accelerate the query by filtering already # in the inner query. - most_recent_state_ids = ( - query.session.query(func.max(States.state_id).label("max_state_id")) - .filter( - (States.last_updated >= bindparam("run_start")) - & (States.last_updated < bindparam("utc_point_in_time")) - ) - .filter(States.entity_id.in_(bindparam("entity_ids", expanding=True))) - .group_by(States.entity_id) - .subquery() + stmt += lambda q: q.where( + States.state_id + == ( + select(func.max(States.state_id).label("max_state_id")) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .filter(States.entity_id.in_(entity_ids)) + .group_by(States.entity_id) + .subquery() + ).c.max_state_id ) - return query.join( - most_recent_state_ids, - States.state_id == most_recent_state_ids.c.max_state_id, - ) - - -def _get_states_baked_query_for_entites( - hass: HomeAssistant, - no_attributes: bool = False, -) -> BakedQuery: - """Baked query to get states for specific entities.""" - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=True - ) - baked_query += _most_recent_state_ids_entities_subquery if join_attributes: - baked_query += lambda q: q.outerjoin( + stmt += lambda q: q.outerjoin( StateAttributes, (States.attributes_id == StateAttributes.attributes_id) ) - return baked_query + return stmt -def _most_recent_state_ids_subquery(query: Query) -> Query: - """Find the most recent state ids for all entiites.""" +def _get_states_for_all_stmt( + schema_version: int, + run_start: datetime, + utc_point_in_time: datetime, + filters: Filters | None, + no_attributes: bool, +) -> StatementLambdaElement: + """Baked query to get states for all entities.""" + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, no_attributes, include_last_changed=True + ) # We did not get an include-list of entities, query all states in the inner # query, then filter out unwanted domains as well as applying the custom filter. # This filtering can't be done in the inner query because the domain column is # not indexed and we can't control what's in the custom filter. most_recent_states_by_date = ( - query.session.query( + select( States.entity_id.label("max_entity_id"), func.max(States.last_updated).label("max_last_updated"), ) .filter( - (States.last_updated >= bindparam("run_start")) - & (States.last_updated < bindparam("utc_point_in_time")) + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) ) .group_by(States.entity_id) .subquery() ) - most_recent_state_ids = ( - query.session.query(func.max(States.state_id).label("max_state_id")) - .join( - most_recent_states_by_date, - and_( - States.entity_id == most_recent_states_by_date.c.max_entity_id, - States.last_updated == most_recent_states_by_date.c.max_last_updated, - ), - ) - .group_by(States.entity_id) - .subquery() + stmt += lambda q: q.where( + States.state_id + == ( + select(func.max(States.state_id).label("max_state_id")) + .join( + most_recent_states_by_date, + and_( + States.entity_id == most_recent_states_by_date.c.max_entity_id, + States.last_updated + == most_recent_states_by_date.c.max_last_updated, + ), + ) + .group_by(States.entity_id) + .subquery() + ).c.max_state_id, ) - return query.join( - most_recent_state_ids, - States.state_id == most_recent_state_ids.c.max_state_id, - ) - - -def _get_states_baked_query_for_all( - hass: HomeAssistant, - filters: Filters | None = None, - no_attributes: bool = False, -) -> BakedQuery: - """Baked query to get states for all entities.""" - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=True - ) - baked_query += _most_recent_state_ids_subquery - baked_query += _ignore_domains_filter - if filters: - filters.bake(baked_query) + stmt += _ignore_domains_filter + if filters and filters.has_config: + entity_filter = filters.entity_filter() + stmt += lambda q: q.filter(entity_filter) if join_attributes: - baked_query += lambda q: q.outerjoin( + stmt += lambda q: q.outerjoin( StateAttributes, (States.attributes_id == StateAttributes.attributes_id) ) - return baked_query + return stmt def _get_rows_with_session( @@ -576,11 +545,15 @@ def _get_rows_with_session( run: RecorderRuns | None = None, filters: Filters | None = None, no_attributes: bool = False, -) -> list[Row]: +) -> Iterable[Row]: """Return the states at a specific point in time.""" + schema_version = _schema_version(hass) if entity_ids and len(entity_ids) == 1: - return _get_single_entity_states_with_session( - hass, session, utc_point_in_time, entity_ids[0], no_attributes + return execute_stmt_lambda_element( + session, + _get_single_entity_states_stmt( + schema_version, utc_point_in_time, entity_ids[0], no_attributes + ), ) if run is None: @@ -593,46 +566,41 @@ def _get_rows_with_session( # We have more than one entity to look at so we need to do a query on states # since the last recorder run started. if entity_ids: - baked_query = _get_states_baked_query_for_entites(hass, no_attributes) - else: - baked_query = _get_states_baked_query_for_all(hass, filters, no_attributes) - - return execute( - baked_query(session).params( - run_start=run.start, - utc_point_in_time=utc_point_in_time, - entity_ids=entity_ids, + stmt = _get_states_for_entites_stmt( + schema_version, run.start, utc_point_in_time, entity_ids, no_attributes ) - ) + else: + stmt = _get_states_for_all_stmt( + schema_version, run.start, utc_point_in_time, filters, no_attributes + ) + + return execute_stmt_lambda_element(session, stmt) -def _get_single_entity_states_with_session( - hass: HomeAssistant, - session: Session, +def _get_single_entity_states_stmt( + schema_version: int, utc_point_in_time: datetime, entity_id: str, no_attributes: bool = False, -) -> list[Row]: +) -> StatementLambdaElement: # Use an entirely different (and extremely fast) query if we only # have a single entity id - baked_query, join_attributes = bake_query_and_join_attributes( - hass, no_attributes, include_last_changed=True + stmt, join_attributes = lambda_stmt_and_join_attributes( + schema_version, no_attributes, include_last_changed=True ) - baked_query += lambda q: q.filter( - States.last_updated < bindparam("utc_point_in_time"), - States.entity_id == bindparam("entity_id"), + stmt += ( + lambda q: q.filter( + States.last_updated < utc_point_in_time, + States.entity_id == entity_id, + ) + .order_by(States.last_updated.desc()) + .limit(1) ) if join_attributes: - baked_query += lambda q: q.outerjoin( + stmt += lambda q: q.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) - baked_query += lambda q: q.order_by(States.last_updated.desc()).limit(1) - - query = baked_query(session).params( - utc_point_in_time=utc_point_in_time, entity_id=entity_id - ) - - return execute(query) + return stmt def _sorted_states_to_dict( diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index f48a6126ea9..eb99c304808 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -1,7 +1,7 @@ """SQLAlchemy util functions.""" from __future__ import annotations -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager from datetime import date, datetime, timedelta import functools @@ -18,9 +18,12 @@ from awesomeversion import ( import ciso8601 from sqlalchemy import text from sqlalchemy.engine.cursor import CursorFetchStrategy +from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError +from sqlalchemy.ext.baked import Result from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session +from sqlalchemy.sql.lambdas import StatementLambdaElement from typing_extensions import Concatenate, ParamSpec from homeassistant.core import HomeAssistant @@ -46,6 +49,7 @@ _LOGGER = logging.getLogger(__name__) RETRIES = 3 QUERY_RETRY_WAIT = 0.1 SQLITE3_POSTFIXES = ["", "-wal", "-shm"] +DEFAULT_YIELD_STATES_ROWS = 32768 MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER) @@ -119,8 +123,10 @@ def commit(session: Session, work: Any) -> bool: def execute( - qry: Query, to_native: bool = False, validate_entity_ids: bool = True -) -> list: + qry: Query | Result, + to_native: bool = False, + validate_entity_ids: bool = True, +) -> list[Row]: """Query the database and convert the objects to HA native form. This method also retries a few times in the case of stale connections. @@ -163,7 +169,39 @@ def execute( raise time.sleep(QUERY_RETRY_WAIT) - assert False # unreachable + assert False # unreachable # pragma: no cover + + +def execute_stmt_lambda_element( + session: Session, + stmt: StatementLambdaElement, + start_time: datetime | None = None, + end_time: datetime | None = None, + yield_per: int | None = DEFAULT_YIELD_STATES_ROWS, +) -> Iterable[Row]: + """Execute a StatementLambdaElement. + + If the time window passed is greater than one day + the execution method will switch to yield_per to + reduce memory pressure. + + It is not recommended to pass a time window + when selecting non-ranged rows (ie selecting + specific entities) since they are usually faster + with .all(). + """ + executed = session.execute(stmt) + use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1 + for tryno in range(0, RETRIES): + try: + return executed.all() if use_all else executed.yield_per(yield_per) # type: ignore[no-any-return] + except SQLAlchemyError as err: + _LOGGER.error("Error executing query: %s", err) + if tryno == RETRIES - 1: + raise + time.sleep(QUERY_RETRY_WAIT) + + assert False # unreachable # pragma: no cover def validate_or_move_away_sqlite_database(dburl: str) -> bool: diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 237030c1186..c650ec20fb0 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -6,10 +6,13 @@ from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy import text +from sqlalchemy.engine.result import ChunkedIteratorResult +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.components import recorder -from homeassistant.components.recorder import util +from homeassistant.components.recorder import history, util from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.models import RecorderRuns from homeassistant.components.recorder.util import ( @@ -24,6 +27,7 @@ from homeassistant.util import dt as dt_util from .common import corrupt_db_file, run_information_with_session from tests.common import SetupRecorderInstanceT, async_test_home_assistant +from tests.components.recorder.common import wait_recording_done def test_session_scope_not_setup(hass_recorder): @@ -510,8 +514,10 @@ def test_basic_sanity_check(hass_recorder): def test_combined_checks(hass_recorder, caplog): """Run Checks on the open database.""" hass = hass_recorder() + instance = recorder.get_instance(hass) + instance.db_retry_wait = 0 - cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor() + cursor = instance.engine.raw_connection().cursor() assert util.run_checks_on_open_db("fake_db_path", cursor) is None assert "could not validate that the sqlite3 database" in caplog.text @@ -658,3 +664,54 @@ def test_build_mysqldb_conv(): assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime( 2022, 5, 13, 22, 33, 12, 741000, tzinfo=None ) + + +@patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0) +def test_execute_stmt_lambda_element(hass_recorder): + """Test executing with execute_stmt_lambda_element.""" + hass = hass_recorder() + instance = recorder.get_instance(hass) + hass.states.set("sensor.on", "on") + new_state = hass.states.get("sensor.on") + wait_recording_done(hass) + now = dt_util.utcnow() + tomorrow = now + timedelta(days=1) + one_week_from_now = now + timedelta(days=7) + + class MockExecutor: + def __init__(self, stmt): + assert isinstance(stmt, StatementLambdaElement) + self.calls = 0 + + def all(self): + self.calls += 1 + if self.calls == 2: + return ["mock_row"] + raise SQLAlchemyError + + with session_scope(hass=hass) as session: + # No time window, we always get a list + stmt = history._get_single_entity_states_stmt( + instance.schema_version, dt_util.utcnow(), "sensor.on", False + ) + rows = util.execute_stmt_lambda_element(session, stmt) + assert isinstance(rows, list) + assert rows[0].state == new_state.state + assert rows[0].entity_id == new_state.entity_id + + # Time window >= 2 days, we get a ChunkedIteratorResult + rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now) + assert isinstance(rows, ChunkedIteratorResult) + row = next(rows) + assert row.state == new_state.state + assert row.entity_id == new_state.entity_id + + # Time window < 2 days, we get a list + rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow) + assert isinstance(rows, list) + assert rows[0].state == new_state.state + assert rows[0].entity_id == new_state.entity_id + + with patch.object(session, "execute", MockExecutor): + rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow) + assert rows == ["mock_row"]