diff --git a/homeassistant/components/logbook/queries/__init__.py b/homeassistant/components/logbook/queries/__init__.py index 3c027823612..a59ebc94b87 100644 --- a/homeassistant/components/logbook/queries/__init__.py +++ b/homeassistant/components/logbook/queries/__init__.py @@ -2,8 +2,9 @@ from __future__ import annotations from datetime import datetime as dt +import json -from sqlalchemy.sql.lambdas import StatementLambdaElement +from sqlalchemy.sql.selectable import Select from homeassistant.components.recorder.filters import Filters @@ -21,7 +22,7 @@ def statement_for_request( device_ids: list[str] | None = None, filters: Filters | None = None, context_id: str | None = None, -) -> StatementLambdaElement: +) -> Select: """Generate the logbook statement for a logbook request.""" # No entities: logbook sends everything for the timeframe @@ -38,41 +39,36 @@ def statement_for_request( context_id, ) - # sqlalchemy caches object quoting, the - # json quotable ones must be a different - # object from the non-json ones to prevent - # sqlalchemy from quoting them incorrectly - # entities and devices: logbook sends everything for the timeframe for the entities and devices if entity_ids and device_ids: - json_quotable_entity_ids = list(entity_ids) - json_quotable_device_ids = list(device_ids) + json_quoted_entity_ids = [json.dumps(entity_id) for entity_id in entity_ids] + json_quoted_device_ids = [json.dumps(device_id) for device_id in device_ids] return entities_devices_stmt( start_day, end_day, event_types, entity_ids, - json_quotable_entity_ids, - json_quotable_device_ids, + json_quoted_entity_ids, + json_quoted_device_ids, ) # entities: logbook sends everything for the timeframe for the entities if entity_ids: - json_quotable_entity_ids = list(entity_ids) + json_quoted_entity_ids = [json.dumps(entity_id) for entity_id in entity_ids] return entities_stmt( start_day, end_day, event_types, entity_ids, - json_quotable_entity_ids, + json_quoted_entity_ids, ) # devices: logbook sends everything for the timeframe for the devices assert device_ids is not None - json_quotable_device_ids = list(device_ids) + json_quoted_device_ids = [json.dumps(device_id) for device_id in device_ids] return devices_stmt( start_day, end_day, event_types, - json_quotable_device_ids, + json_quoted_device_ids, ) diff --git a/homeassistant/components/logbook/queries/all.py b/homeassistant/components/logbook/queries/all.py index da05aa02fff..e0a651c7972 100644 --- a/homeassistant/components/logbook/queries/all.py +++ b/homeassistant/components/logbook/queries/all.py @@ -3,10 +3,9 @@ from __future__ import annotations from datetime import datetime as dt -from sqlalchemy import lambda_stmt from sqlalchemy.orm import Query from sqlalchemy.sql.elements import ClauseList -from sqlalchemy.sql.lambdas import StatementLambdaElement +from sqlalchemy.sql.selectable import Select from homeassistant.components.recorder.db_schema import ( LAST_UPDATED_INDEX, @@ -29,32 +28,29 @@ def all_stmt( states_entity_filter: ClauseList | None = None, events_entity_filter: ClauseList | None = None, context_id: str | None = None, -) -> StatementLambdaElement: +) -> Select: """Generate a logbook query for all entities.""" - stmt = lambda_stmt( - lambda: select_events_without_states(start_day, end_day, event_types) - ) + stmt = select_events_without_states(start_day, end_day, event_types) if context_id is not None: # Once all the old `state_changed` events # are gone from the database remove the # _legacy_select_events_context_id() - stmt += lambda s: s.where(Events.context_id == context_id).union_all( + stmt = stmt.where(Events.context_id == context_id).union_all( _states_query_for_context_id(start_day, end_day, context_id), legacy_select_events_context_id(start_day, end_day, context_id), ) else: if events_entity_filter is not None: - stmt += lambda s: s.where(events_entity_filter) + stmt = stmt.where(events_entity_filter) if states_entity_filter is not None: - stmt += lambda s: s.union_all( + stmt = stmt.union_all( _states_query_for_all(start_day, end_day).where(states_entity_filter) ) else: - stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day)) + stmt = stmt.union_all(_states_query_for_all(start_day, end_day)) - stmt += lambda s: s.order_by(Events.time_fired) - return stmt + return stmt.order_by(Events.time_fired) def _states_query_for_all(start_day: dt, end_day: dt) -> Query: diff --git a/homeassistant/components/logbook/queries/devices.py b/homeassistant/components/logbook/queries/devices.py index f750c552bc4..cbe766fb02c 100644 --- a/homeassistant/components/logbook/queries/devices.py +++ b/homeassistant/components/logbook/queries/devices.py @@ -4,11 +4,10 @@ from __future__ import annotations from collections.abc import Iterable from datetime import datetime as dt -from sqlalchemy import lambda_stmt, select +from sqlalchemy import select from sqlalchemy.orm import Query from sqlalchemy.sql.elements import ClauseList -from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import CTE, CompoundSelect +from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select from homeassistant.components.recorder.db_schema import ( DEVICE_ID_IN_EVENT, @@ -31,11 +30,11 @@ def _select_device_id_context_ids_sub_query( start_day: dt, end_day: dt, event_types: tuple[str, ...], - json_quotable_device_ids: list[str], + json_quoted_device_ids: list[str], ) -> CompoundSelect: """Generate a subquery to find context ids for multiple devices.""" inner = select_events_context_id_subquery(start_day, end_day, event_types).where( - apply_event_device_id_matchers(json_quotable_device_ids) + apply_event_device_id_matchers(json_quoted_device_ids) ) return select(inner.c.context_id).group_by(inner.c.context_id) @@ -45,14 +44,14 @@ def _apply_devices_context_union( start_day: dt, end_day: dt, event_types: tuple[str, ...], - json_quotable_device_ids: list[str], + json_quoted_device_ids: list[str], ) -> CompoundSelect: """Generate a CTE to find the device context ids and a query to find linked row.""" devices_cte: CTE = _select_device_id_context_ids_sub_query( start_day, end_day, event_types, - json_quotable_device_ids, + json_quoted_device_ids, ).cte() return query.union_all( apply_events_context_hints( @@ -72,25 +71,22 @@ def devices_stmt( start_day: dt, end_day: dt, event_types: tuple[str, ...], - json_quotable_device_ids: list[str], -) -> StatementLambdaElement: + json_quoted_device_ids: list[str], +) -> Select: """Generate a logbook query for multiple devices.""" - stmt = lambda_stmt( - lambda: _apply_devices_context_union( - select_events_without_states(start_day, end_day, event_types).where( - apply_event_device_id_matchers(json_quotable_device_ids) - ), - start_day, - end_day, - event_types, - json_quotable_device_ids, - ).order_by(Events.time_fired) - ) - return stmt + return _apply_devices_context_union( + select_events_without_states(start_day, end_day, event_types).where( + apply_event_device_id_matchers(json_quoted_device_ids) + ), + start_day, + end_day, + event_types, + json_quoted_device_ids, + ).order_by(Events.time_fired) def apply_event_device_id_matchers( - json_quotable_device_ids: Iterable[str], + json_quoted_device_ids: Iterable[str], ) -> ClauseList: """Create matchers for the device_ids in the event_data.""" - return DEVICE_ID_IN_EVENT.in_(json_quotable_device_ids) + return DEVICE_ID_IN_EVENT.in_(json_quoted_device_ids) diff --git a/homeassistant/components/logbook/queries/entities.py b/homeassistant/components/logbook/queries/entities.py index 4ef96c100d7..4d250fbb0f1 100644 --- a/homeassistant/components/logbook/queries/entities.py +++ b/homeassistant/components/logbook/queries/entities.py @@ -5,10 +5,9 @@ from collections.abc import Iterable from datetime import datetime as dt import sqlalchemy -from sqlalchemy import lambda_stmt, select, union_all +from sqlalchemy import select, union_all from sqlalchemy.orm import Query -from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import CTE, CompoundSelect +from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select from homeassistant.components.recorder.db_schema import ( ENTITY_ID_IN_EVENT, @@ -36,12 +35,12 @@ def _select_entities_context_ids_sub_query( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], + json_quoted_entity_ids: list[str], ) -> CompoundSelect: """Generate a subquery to find context ids for multiple entities.""" union = union_all( select_events_context_id_subquery(start_day, end_day, event_types).where( - apply_event_entity_id_matchers(json_quotable_entity_ids) + apply_event_entity_id_matchers(json_quoted_entity_ids) ), apply_entities_hints(select(States.context_id)) .filter((States.last_updated > start_day) & (States.last_updated < end_day)) @@ -56,7 +55,7 @@ def _apply_entities_context_union( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], + json_quoted_entity_ids: list[str], ) -> CompoundSelect: """Generate a CTE to find the entity and device context ids and a query to find linked row.""" entities_cte: CTE = _select_entities_context_ids_sub_query( @@ -64,7 +63,7 @@ def _apply_entities_context_union( end_day, event_types, entity_ids, - json_quotable_entity_ids, + json_quoted_entity_ids, ).cte() # We used to optimize this to exclude rows we already in the union with # a States.entity_id.not_in(entity_ids) but that made the @@ -91,21 +90,19 @@ def entities_stmt( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], -) -> StatementLambdaElement: + json_quoted_entity_ids: list[str], +) -> Select: """Generate a logbook query for multiple entities.""" - return lambda_stmt( - lambda: _apply_entities_context_union( - select_events_without_states(start_day, end_day, event_types).where( - apply_event_entity_id_matchers(json_quotable_entity_ids) - ), - start_day, - end_day, - event_types, - entity_ids, - json_quotable_entity_ids, - ).order_by(Events.time_fired) - ) + return _apply_entities_context_union( + select_events_without_states(start_day, end_day, event_types).where( + apply_event_entity_id_matchers(json_quoted_entity_ids) + ), + start_day, + end_day, + event_types, + entity_ids, + json_quoted_entity_ids, + ).order_by(Events.time_fired) def states_query_for_entity_ids( @@ -118,12 +115,12 @@ def states_query_for_entity_ids( def apply_event_entity_id_matchers( - json_quotable_entity_ids: Iterable[str], + json_quoted_entity_ids: Iterable[str], ) -> sqlalchemy.or_: """Create matchers for the entity_id in the event_data.""" - return ENTITY_ID_IN_EVENT.in_( - json_quotable_entity_ids - ) | OLD_ENTITY_ID_IN_EVENT.in_(json_quotable_entity_ids) + return ENTITY_ID_IN_EVENT.in_(json_quoted_entity_ids) | OLD_ENTITY_ID_IN_EVENT.in_( + json_quoted_entity_ids + ) def apply_entities_hints(query: Query) -> Query: diff --git a/homeassistant/components/logbook/queries/entities_and_devices.py b/homeassistant/components/logbook/queries/entities_and_devices.py index 591918dd653..8b8051e2966 100644 --- a/homeassistant/components/logbook/queries/entities_and_devices.py +++ b/homeassistant/components/logbook/queries/entities_and_devices.py @@ -5,10 +5,9 @@ from collections.abc import Iterable from datetime import datetime as dt import sqlalchemy -from sqlalchemy import lambda_stmt, select, union_all +from sqlalchemy import select, union_all from sqlalchemy.orm import Query -from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import CTE, CompoundSelect +from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select from homeassistant.components.recorder.db_schema import EventData, Events, States @@ -33,14 +32,14 @@ def _select_entities_device_id_context_ids_sub_query( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], - json_quotable_device_ids: list[str], + json_quoted_entity_ids: list[str], + json_quoted_device_ids: list[str], ) -> CompoundSelect: """Generate a subquery to find context ids for multiple entities and multiple devices.""" union = union_all( select_events_context_id_subquery(start_day, end_day, event_types).where( _apply_event_entity_id_device_id_matchers( - json_quotable_entity_ids, json_quotable_device_ids + json_quoted_entity_ids, json_quoted_device_ids ) ), apply_entities_hints(select(States.context_id)) @@ -56,16 +55,16 @@ def _apply_entities_devices_context_union( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], - json_quotable_device_ids: list[str], + json_quoted_entity_ids: list[str], + json_quoted_device_ids: list[str], ) -> CompoundSelect: devices_entities_cte: CTE = _select_entities_device_id_context_ids_sub_query( start_day, end_day, event_types, entity_ids, - json_quotable_entity_ids, - json_quotable_device_ids, + json_quoted_entity_ids, + json_quoted_device_ids, ).cte() # We used to optimize this to exclude rows we already in the union with # a States.entity_id.not_in(entity_ids) but that made the @@ -92,32 +91,30 @@ def entities_devices_stmt( end_day: dt, event_types: tuple[str, ...], entity_ids: list[str], - json_quotable_entity_ids: list[str], - json_quotable_device_ids: list[str], -) -> StatementLambdaElement: + json_quoted_entity_ids: list[str], + json_quoted_device_ids: list[str], +) -> Select: """Generate a logbook query for multiple entities.""" - stmt = lambda_stmt( - lambda: _apply_entities_devices_context_union( - select_events_without_states(start_day, end_day, event_types).where( - _apply_event_entity_id_device_id_matchers( - json_quotable_entity_ids, json_quotable_device_ids - ) - ), - start_day, - end_day, - event_types, - entity_ids, - json_quotable_entity_ids, - json_quotable_device_ids, - ).order_by(Events.time_fired) - ) + stmt = _apply_entities_devices_context_union( + select_events_without_states(start_day, end_day, event_types).where( + _apply_event_entity_id_device_id_matchers( + json_quoted_entity_ids, json_quoted_device_ids + ) + ), + start_day, + end_day, + event_types, + entity_ids, + json_quoted_entity_ids, + json_quoted_device_ids, + ).order_by(Events.time_fired) return stmt def _apply_event_entity_id_device_id_matchers( - json_quotable_entity_ids: Iterable[str], json_quotable_device_ids: Iterable[str] + json_quoted_entity_ids: Iterable[str], json_quoted_device_ids: Iterable[str] ) -> sqlalchemy.or_: """Create matchers for the device_id and entity_id in the event_data.""" return apply_event_entity_id_matchers( - json_quotable_entity_ids - ) | apply_event_device_id_matchers(json_quotable_device_ids) + json_quoted_entity_ids + ) | apply_event_device_id_matchers(json_quoted_device_ids) diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index e1eca282a3a..1238b63f3c9 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -9,13 +9,11 @@ import logging import time from typing import Any, cast -from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select +from sqlalchemy import Column, Text, and_, func, or_, select from sqlalchemy.engine.row import Row -from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal -from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery +from sqlalchemy.sql.selectable import Select, Subquery from homeassistant.components import recorder from homeassistant.components.websocket_api.const import ( @@ -34,7 +32,7 @@ from .models import ( process_timestamp_to_utc_isoformat, row_to_compressed_state, ) -from .util import execute_stmt_lambda_element, session_scope +from .util import execute_stmt, session_scope # mypy: allow-untyped-defs, no-check-untyped-defs @@ -114,22 +112,18 @@ def _schema_version(hass: HomeAssistant) -> int: return recorder.get_instance(hass).schema_version -def lambda_stmt_and_join_attributes( +def stmt_and_join_attributes( schema_version: int, no_attributes: bool, include_last_changed: bool = True -) -> tuple[StatementLambdaElement, bool]: - """Return the lambda_stmt and if StateAttributes should be joined. - - Because these are lambda_stmt the values inside the lambdas need - to be explicitly written out to avoid caching the wrong values. - """ +) -> tuple[Select, bool]: + """Return the stmt and if StateAttributes should be joined.""" # If no_attributes was requested we do the query # without the attributes fields and do not join the # state_attributes table if no_attributes: if include_last_changed: - return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False + return select(*QUERY_STATE_NO_ATTR), False return ( - lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)), + select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED), False, ) # If we in the process of migrating schema we do @@ -138,19 +132,19 @@ def lambda_stmt_and_join_attributes( if schema_version < 25: if include_last_changed: return ( - lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)), + select(*QUERY_STATES_PRE_SCHEMA_25), False, ) return ( - lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)), + select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED), False, ) # Finally if no migration is in progress and no_attributes # was not requested, we query both attributes columns and # join state_attributes if include_last_changed: - return lambda_stmt(lambda: select(*QUERY_STATES)), True - return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True + return select(*QUERY_STATES), True + return select(*QUERY_STATES_NO_LAST_CHANGED), True def get_significant_states( @@ -182,7 +176,7 @@ def get_significant_states( ) -def _ignore_domains_filter(query: Query) -> Query: +def _ignore_domains_filter(query: Select) -> Select: """Add a filter to ignore domains we do not fetch history for.""" return query.filter( and_( @@ -202,9 +196,9 @@ def _significant_states_stmt( filters: Filters | None, significant_changes_only: bool, no_attributes: bool, -) -> StatementLambdaElement: +) -> Select: """Query the database for significant state changes.""" - stmt, join_attributes = lambda_stmt_and_join_attributes( + stmt, join_attributes = stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=not significant_changes_only ) if ( @@ -213,11 +207,11 @@ def _significant_states_stmt( and significant_changes_only and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS ): - stmt += lambda q: q.filter( + stmt = stmt.filter( (States.last_changed == States.last_updated) | States.last_changed.is_(None) ) elif significant_changes_only: - stmt += lambda q: q.filter( + stmt = stmt.filter( or_( *[ States.entity_id.like(entity_domain) @@ -231,25 +225,22 @@ def _significant_states_stmt( ) if entity_ids: - stmt += lambda q: q.filter(States.entity_id.in_(entity_ids)) + stmt = stmt.filter(States.entity_id.in_(entity_ids)) else: - stmt += _ignore_domains_filter + stmt = _ignore_domains_filter(stmt) if filters and filters.has_config: entity_filter = filters.states_entity_filter() - stmt = stmt.add_criteria( - lambda q: q.filter(entity_filter), track_on=[filters] - ) + stmt = stmt.filter(entity_filter) - stmt += lambda q: q.filter(States.last_updated > start_time) + stmt = stmt.filter(States.last_updated > start_time) if end_time: - stmt += lambda q: q.filter(States.last_updated < end_time) + stmt = stmt.filter(States.last_updated < end_time) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) - stmt += lambda q: q.order_by(States.entity_id, States.last_updated) - return stmt + return stmt.order_by(States.entity_id, States.last_updated) def get_significant_states_with_session( @@ -286,9 +277,7 @@ def get_significant_states_with_session( significant_changes_only, no_attributes, ) - states = execute_stmt_lambda_element( - session, stmt, None if entity_ids else start_time, end_time - ) + states = execute_stmt(session, stmt, None if entity_ids else start_time, end_time) return _sorted_states_to_dict( hass, session, @@ -340,28 +329,28 @@ def _state_changed_during_period_stmt( no_attributes: bool, descending: bool, limit: int | None, -) -> StatementLambdaElement: - stmt, join_attributes = lambda_stmt_and_join_attributes( +) -> Select: + stmt, join_attributes = stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=False ) - stmt += lambda q: q.filter( + stmt = stmt.filter( ((States.last_changed == States.last_updated) | States.last_changed.is_(None)) & (States.last_updated > start_time) ) if end_time: - stmt += lambda q: q.filter(States.last_updated < end_time) + stmt = stmt.filter(States.last_updated < end_time) if entity_id: - stmt += lambda q: q.filter(States.entity_id == entity_id) + stmt = stmt.filter(States.entity_id == entity_id) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) if descending: - stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()) + stmt = stmt.order_by(States.entity_id, States.last_updated.desc()) else: - stmt += lambda q: q.order_by(States.entity_id, States.last_updated) + stmt = stmt.order_by(States.entity_id, States.last_updated) if limit: - stmt += lambda q: q.limit(limit) + stmt = stmt.limit(limit) return stmt @@ -389,7 +378,7 @@ def state_changes_during_period( descending, limit, ) - states = execute_stmt_lambda_element( + states = execute_stmt( session, stmt, None if entity_id else start_time, end_time ) return cast( @@ -407,23 +396,22 @@ def state_changes_during_period( def _get_last_state_changes_stmt( schema_version: int, number_of_states: int, entity_id: str | None -) -> StatementLambdaElement: - stmt, join_attributes = lambda_stmt_and_join_attributes( +) -> Select: + stmt, join_attributes = stmt_and_join_attributes( schema_version, False, include_last_changed=False ) - stmt += lambda q: q.filter( + stmt = stmt.filter( (States.last_changed == States.last_updated) | States.last_changed.is_(None) ) if entity_id: - stmt += lambda q: q.filter(States.entity_id == entity_id) + stmt = stmt.filter(States.entity_id == entity_id) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) - stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit( + return stmt.order_by(States.entity_id, States.last_updated.desc()).limit( number_of_states ) - return stmt def get_last_state_changes( @@ -438,7 +426,7 @@ def get_last_state_changes( stmt = _get_last_state_changes_stmt( _schema_version(hass), number_of_states, entity_id ) - states = list(execute_stmt_lambda_element(session, stmt)) + states = list(execute_stmt(session, stmt)) return cast( MutableMapping[str, list[State]], _sorted_states_to_dict( @@ -458,14 +446,14 @@ def _get_states_for_entites_stmt( utc_point_in_time: datetime, entity_ids: list[str], no_attributes: bool, -) -> StatementLambdaElement: +) -> Select: """Baked query to get states for specific entities.""" - stmt, join_attributes = lambda_stmt_and_join_attributes( + stmt, join_attributes = stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=True ) # We got an include-list of entities, accelerate the query by filtering already # in the inner query. - stmt += lambda q: q.where( + stmt = stmt.where( States.state_id == ( select(func.max(States.state_id).label("max_state_id")) @@ -479,7 +467,7 @@ def _get_states_for_entites_stmt( ).c.max_state_id ) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, (States.attributes_id == StateAttributes.attributes_id) ) return stmt @@ -510,9 +498,9 @@ def _get_states_for_all_stmt( utc_point_in_time: datetime, filters: Filters | None, no_attributes: bool, -) -> StatementLambdaElement: +) -> Select: """Baked query to get states for all entities.""" - stmt, join_attributes = lambda_stmt_and_join_attributes( + stmt, join_attributes = stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=True ) # We did not get an include-list of entities, query all states in the inner @@ -522,7 +510,7 @@ def _get_states_for_all_stmt( most_recent_states_by_date = _generate_most_recent_states_by_date( run_start, utc_point_in_time ) - stmt += lambda q: q.where( + stmt = stmt.where( States.state_id == ( select(func.max(States.state_id).label("max_state_id")) @@ -538,12 +526,12 @@ def _get_states_for_all_stmt( .subquery() ).c.max_state_id, ) - stmt += _ignore_domains_filter + stmt = _ignore_domains_filter(stmt) if filters and filters.has_config: entity_filter = filters.states_entity_filter() - stmt = stmt.add_criteria(lambda q: q.filter(entity_filter), track_on=[filters]) + stmt = stmt.filter(entity_filter) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, (States.attributes_id == StateAttributes.attributes_id) ) return stmt @@ -561,7 +549,7 @@ def _get_rows_with_session( """Return the states at a specific point in time.""" schema_version = _schema_version(hass) if entity_ids and len(entity_ids) == 1: - return execute_stmt_lambda_element( + return execute_stmt( session, _get_single_entity_states_stmt( schema_version, utc_point_in_time, entity_ids[0], no_attributes @@ -586,7 +574,7 @@ def _get_rows_with_session( schema_version, run.start, utc_point_in_time, filters, no_attributes ) - return execute_stmt_lambda_element(session, stmt) + return execute_stmt(session, stmt) def _get_single_entity_states_stmt( @@ -594,14 +582,14 @@ def _get_single_entity_states_stmt( utc_point_in_time: datetime, entity_id: str, no_attributes: bool = False, -) -> StatementLambdaElement: +) -> Select: # Use an entirely different (and extremely fast) query if we only # have a single entity id - stmt, join_attributes = lambda_stmt_and_join_attributes( + stmt, join_attributes = stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=True ) - stmt += ( - lambda q: q.filter( + stmt = ( + stmt.filter( States.last_updated < utc_point_in_time, States.entity_id == entity_id, ) @@ -609,7 +597,7 @@ def _get_single_entity_states_stmt( .limit(1) ) if join_attributes: - stmt += lambda q: q.outerjoin( + stmt = stmt.outerjoin( StateAttributes, States.attributes_id == StateAttributes.attributes_id ) return stmt diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 26221aa199b..8d314830ec4 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -14,13 +14,12 @@ import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal, overload -from sqlalchemy import bindparam, func, lambda_stmt, select +from sqlalchemy import bindparam, func, select from sqlalchemy.engine.row import Row from sqlalchemy.exc import SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal_column, true -from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery +from sqlalchemy.sql.selectable import Select, Subquery import voluptuous as vol from homeassistant.const import ( @@ -50,12 +49,7 @@ from .models import ( process_timestamp, process_timestamp_to_utc_isoformat, ) -from .util import ( - execute, - execute_stmt_lambda_element, - retryable_database_job, - session_scope, -) +from .util import execute, execute_stmt, retryable_database_job, session_scope if TYPE_CHECKING: from . import Recorder @@ -480,10 +474,10 @@ def delete_statistics_meta_duplicates(session: Session) -> None: def _compile_hourly_statistics_summary_mean_stmt( start_time: datetime, end_time: datetime -) -> StatementLambdaElement: +) -> Select: """Generate the summary mean statement for hourly statistics.""" - return lambda_stmt( - lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN) + return ( + select(*QUERY_STATISTICS_SUMMARY_MEAN) .filter(StatisticsShortTerm.start >= start_time) .filter(StatisticsShortTerm.start < end_time) .group_by(StatisticsShortTerm.metadata_id) @@ -506,7 +500,7 @@ def compile_hourly_statistics( # Compute last hour's average, min, max summary: dict[str, StatisticData] = {} stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time) - stats = execute_stmt_lambda_element(session, stmt) + stats = execute_stmt(session, stmt) if stats: for stat in stats: @@ -688,17 +682,17 @@ def _generate_get_metadata_stmt( statistic_ids: list[str] | tuple[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, -) -> StatementLambdaElement: +) -> Select: """Generate a statement to fetch metadata.""" - stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META)) + stmt = select(*QUERY_STATISTIC_META) if statistic_ids is not None: - stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids)) + stmt = stmt.where(StatisticsMeta.statistic_id.in_(statistic_ids)) if statistic_source is not None: - stmt += lambda q: q.where(StatisticsMeta.source == statistic_source) + stmt = stmt.where(StatisticsMeta.source == statistic_source) if statistic_type == "mean": - stmt += lambda q: q.where(StatisticsMeta.has_mean == true()) + stmt = stmt.where(StatisticsMeta.has_mean == true()) elif statistic_type == "sum": - stmt += lambda q: q.where(StatisticsMeta.has_sum == true()) + stmt = stmt.where(StatisticsMeta.has_sum == true()) return stmt @@ -720,7 +714,7 @@ def get_metadata_with_session( # Fetch metatadata from the database stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source) - result = execute_stmt_lambda_element(session, stmt) + result = execute_stmt(session, stmt) if not result: return {} @@ -982,44 +976,30 @@ def _statistics_during_period_stmt( start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, -) -> StatementLambdaElement: - """Prepare a database query for statistics during a given period. - - This prepares a lambda_stmt query, so we don't insert the parameters yet. - """ - stmt = lambda_stmt( - lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time) - ) +) -> Select: + """Prepare a database query for statistics during a given period.""" + stmt = select(*QUERY_STATISTICS).filter(Statistics.start >= start_time) if end_time is not None: - stmt += lambda q: q.filter(Statistics.start < end_time) + stmt = stmt.filter(Statistics.start < end_time) if metadata_ids: - stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids)) - stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start) - return stmt + stmt = stmt.filter(Statistics.metadata_id.in_(metadata_ids)) + return stmt.order_by(Statistics.metadata_id, Statistics.start) def _statistics_during_period_stmt_short_term( start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, -) -> StatementLambdaElement: - """Prepare a database query for short term statistics during a given period. - - This prepares a lambda_stmt query, so we don't insert the parameters yet. - """ - stmt = lambda_stmt( - lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter( - StatisticsShortTerm.start >= start_time - ) +) -> Select: + """Prepare a database query for short term statistics during a given period.""" + stmt = select(*QUERY_STATISTICS_SHORT_TERM).filter( + StatisticsShortTerm.start >= start_time ) if end_time is not None: - stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time) + stmt = stmt.filter(StatisticsShortTerm.start < end_time) if metadata_ids: - stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids)) - stmt += lambda q: q.order_by( - StatisticsShortTerm.metadata_id, StatisticsShortTerm.start - ) - return stmt + stmt = stmt.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids)) + return stmt.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start) def statistics_during_period( @@ -1054,7 +1034,7 @@ def statistics_during_period( else: table = Statistics stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids) - stats = execute_stmt_lambda_element(session, stmt) + stats = execute_stmt(session, stmt) if not stats: return {} @@ -1085,10 +1065,10 @@ def statistics_during_period( def _get_last_statistics_stmt( metadata_id: int, number_of_stats: int, -) -> StatementLambdaElement: +) -> Select: """Generate a statement for number_of_stats statistics for a given statistic_id.""" - return lambda_stmt( - lambda: select(*QUERY_STATISTICS) + return ( + select(*QUERY_STATISTICS) .filter_by(metadata_id=metadata_id) .order_by(Statistics.metadata_id, Statistics.start.desc()) .limit(number_of_stats) @@ -1098,10 +1078,10 @@ def _get_last_statistics_stmt( def _get_last_statistics_short_term_stmt( metadata_id: int, number_of_stats: int, -) -> StatementLambdaElement: +) -> Select: """Generate a statement for number_of_stats short term statistics for a given statistic_id.""" - return lambda_stmt( - lambda: select(*QUERY_STATISTICS_SHORT_TERM) + return ( + select(*QUERY_STATISTICS_SHORT_TERM) .filter_by(metadata_id=metadata_id) .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()) .limit(number_of_stats) @@ -1127,7 +1107,7 @@ def _get_last_statistics( stmt = _get_last_statistics_stmt(metadata_id, number_of_stats) else: stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats) - stats = execute_stmt_lambda_element(session, stmt) + stats = execute_stmt(session, stmt) if not stats: return {} @@ -1177,11 +1157,11 @@ def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery: def _latest_short_term_statistics_stmt( metadata_ids: list[int], -) -> StatementLambdaElement: +) -> Select: """Create the statement for finding the latest short term stat rows.""" - stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) + stmt = select(*QUERY_STATISTICS_SHORT_TERM) most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids) - stmt += lambda s: s.join( + return stmt.join( most_recent_statistic_row, ( StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable @@ -1189,7 +1169,6 @@ def _latest_short_term_statistics_stmt( ) & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max), ) - return stmt def get_latest_short_term_statistics( @@ -1212,7 +1191,7 @@ def get_latest_short_term_statistics( if statistic_id in metadata ] stmt = _latest_short_term_statistics_stmt(metadata_ids) - stats = execute_stmt_lambda_element(session, stmt) + stats = execute_stmt(session, stmt) if not stats: return {} diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index c1fbc831987..7e183f7f64f 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -22,7 +22,6 @@ from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session -from sqlalchemy.sql.lambdas import StatementLambdaElement from typing_extensions import Concatenate, ParamSpec from homeassistant.core import HomeAssistant @@ -166,9 +165,9 @@ def execute( assert False # unreachable # pragma: no cover -def execute_stmt_lambda_element( +def execute_stmt( session: Session, - stmt: StatementLambdaElement, + query: Query, start_time: datetime | None = None, end_time: datetime | None = None, yield_per: int | None = DEFAULT_YIELD_STATES_ROWS, @@ -184,11 +183,12 @@ def execute_stmt_lambda_element( specific entities) since they are usually faster with .all(). """ - executed = session.execute(stmt) use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1 for tryno in range(0, RETRIES): try: - return executed.all() if use_all else executed.yield_per(yield_per) # type: ignore[no-any-return] + if use_all: + return session.execute(query).all() # type: ignore[no-any-return] + return session.execute(query).yield_per(yield_per) # type: ignore[no-any-return] except SQLAlchemyError as err: _LOGGER.error("Error executing query: %s", err) if tryno == RETRIES - 1: diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 8624719f951..97cf4a58b5c 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -9,7 +9,6 @@ from sqlalchemy import text from sqlalchemy.engine.result import ChunkedIteratorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql.elements import TextClause -from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.components import recorder from homeassistant.components.recorder import history, util @@ -713,8 +712,8 @@ def test_build_mysqldb_conv(): @patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0) -def test_execute_stmt_lambda_element(hass_recorder): - """Test executing with execute_stmt_lambda_element.""" +def test_execute_stmt(hass_recorder): + """Test executing with execute_stmt.""" hass = hass_recorder() instance = recorder.get_instance(hass) hass.states.set("sensor.on", "on") @@ -725,13 +724,15 @@ def test_execute_stmt_lambda_element(hass_recorder): one_week_from_now = now + timedelta(days=7) class MockExecutor: + + _calls = 0 + def __init__(self, stmt): - assert isinstance(stmt, StatementLambdaElement) - self.calls = 0 + """Init the mock.""" def all(self): - self.calls += 1 - if self.calls == 2: + MockExecutor._calls += 1 + if MockExecutor._calls == 2: return ["mock_row"] raise SQLAlchemyError @@ -740,24 +741,24 @@ def test_execute_stmt_lambda_element(hass_recorder): stmt = history._get_single_entity_states_stmt( instance.schema_version, dt_util.utcnow(), "sensor.on", False ) - rows = util.execute_stmt_lambda_element(session, stmt) + rows = util.execute_stmt(session, stmt) assert isinstance(rows, list) assert rows[0].state == new_state.state assert rows[0].entity_id == new_state.entity_id # Time window >= 2 days, we get a ChunkedIteratorResult - rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now) + rows = util.execute_stmt(session, stmt, now, one_week_from_now) assert isinstance(rows, ChunkedIteratorResult) row = next(rows) assert row.state == new_state.state assert row.entity_id == new_state.entity_id # Time window < 2 days, we get a list - rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow) + rows = util.execute_stmt(session, stmt, now, tomorrow) assert isinstance(rows, list) assert rows[0].state == new_state.state assert rows[0].entity_id == new_state.entity_id with patch.object(session, "execute", MockExecutor): - rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow) + rows = util.execute_stmt(session, stmt, now, tomorrow) assert rows == ["mock_row"]