Make sql subqueries threadsafe (#89254)

* Make sql subqueries threadsafe

fixes #89224

* fix join outside of lambda

* move statement generation into a separate function to make it easier to test

* add cache key tests

* no need to mock hass
J. Nick Koston 2023-03-06 15:44:11 -10:00 committed by GitHub
parent 9672b5f02c
commit 3c70dd9b42
3 changed files with 257 additions and 179 deletions
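
Editor's note before the diff: the change hinges on how SQLAlchemy's lambda statement caching interacts with closures. lambda_stmt caches the compiled statement keyed on the lambda's code object; a Subquery built outside the lambda is captured by the closure once, so whichever thread builds the statement first can bake its subquery, with that call's datetimes and ids, into the cached statement that later callers on any thread reuse. Building the subquery inside the lambda puts the construction in the code SQLAlchemy analyzes, so those inputs become tracked bound parameters. A minimal before/after sketch, using a hypothetical Event table rather than the recorder schema:

from sqlalchemy import Column, Integer, func, lambda_stmt, select
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.lambdas import StatementLambdaElement

Base = declarative_base()


class Event(Base):
    """Hypothetical table for illustration only."""

    __tablename__ = "event"
    id = Column(Integer, primary_key=True)
    group_id = Column(Integer)
    ts = Column(Integer)


def stmt_unsafe(cutoff: int) -> StatementLambdaElement:
    # Anti-pattern (what this commit removes): the subquery is constructed
    # outside the lambda, so the first caller's `cutoff` can be frozen into
    # the statement that the lambda cache hands to every later caller.
    latest = (
        select(Event.group_id, func.max(Event.ts).label("max_ts"))
        .filter(Event.ts < cutoff)
        .group_by(Event.group_id)
        .subquery()
    )
    return lambda_stmt(
        lambda: select(Event).join(latest, Event.ts == latest.c.max_ts)
    )


def stmt_safe(cutoff: int) -> StatementLambdaElement:
    # Pattern this commit adopts: the subquery is built inside the lambda,
    # bound with a walrus so the join condition can reference its columns;
    # `cutoff` is now a tracked bound parameter instead of a captured value.
    return lambda_stmt(
        lambda: select(Event).join(
            (
                latest := (
                    select(Event.group_id, func.max(Event.ts).label("max_ts"))
                    .filter(Event.ts < cutoff)
                    .group_by(Event.group_id)
                    .subquery()
                )
            ),
            Event.ts == latest.c.max_ts,
        )
    )
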

View file

@@ -17,7 +17,6 @@ from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery

 from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE
 from homeassistant.core import HomeAssistant, State, split_entity_id
@@ -592,48 +591,6 @@ def get_last_state_changes(
     )


-def _generate_most_recent_states_for_entities_by_date(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-    entity_ids: list[str],
-) -> Subquery:
-    """Generate the sub query for the most recent states for specific entities by date."""
-    if schema_version >= 31:
-        run_start_ts = process_timestamp(run_start).timestamp()
-        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
-            select(
-                States.entity_id.label("max_entity_id"),
-                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-                # pylint: disable-next=not-callable
-                func.max(States.last_updated_ts).label("max_last_updated"),
-            )
-            .filter(
-                (States.last_updated_ts >= run_start_ts)
-                & (States.last_updated_ts < utc_point_in_time_ts)
-            )
-            .filter(States.entity_id.in_(entity_ids))
-            .group_by(States.entity_id)
-            .subquery()
-        )
-    return (
-        select(
-            States.entity_id.label("max_entity_id"),
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .filter(States.entity_id.in_(entity_ids))
-        .group_by(States.entity_id)
-        .subquery()
-    )
-
-
 def _get_states_for_entities_stmt(
     schema_version: int,
     run_start: datetime,
@@ -645,16 +602,29 @@ def _get_states_for_entities_stmt(
     stmt, join_attributes = lambda_stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=True
     )
-    most_recent_states_for_entities_by_date = (
-        _generate_most_recent_states_for_entities_by_date(
-            schema_version, run_start, utc_point_in_time, entity_ids
-        )
-    )
     # We got an include-list of entities, accelerate the query by filtering already
     # in the inner query.
     if schema_version >= 31:
+        run_start_ts = process_timestamp(run_start).timestamp()
+        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
         stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
+            (
+                most_recent_states_for_entities_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated_ts).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated_ts >= run_start_ts)
+                        & (States.last_updated_ts < utc_point_in_time_ts)
+                    )
+                    .filter(States.entity_id.in_(entity_ids))
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id
                 == most_recent_states_for_entities_by_date.c.max_entity_id,
@@ -664,7 +634,21 @@ def _get_states_for_entities_stmt(
         )
     else:
         stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
+            (
+                most_recent_states_for_entities_by_date := select(
+                    States.entity_id.label("max_entity_id"),
+                    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                    # pylint: disable-next=not-callable
+                    func.max(States.last_updated).label("max_last_updated"),
+                )
+                .filter(
+                    (States.last_updated >= run_start)
+                    & (States.last_updated < utc_point_in_time)
+                )
+                .filter(States.entity_id.in_(entity_ids))
+                .group_by(States.entity_id)
+                .subquery()
+            ),
             and_(
                 States.entity_id
                 == most_recent_states_for_entities_by_date.c.max_entity_id,
@@ -679,45 +663,6 @@ def _get_states_for_entities_stmt(
     return stmt


-def _generate_most_recent_states_by_date(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-) -> Subquery:
-    """Generate the sub query for the most recent states by date."""
-    if schema_version >= 31:
-        run_start_ts = process_timestamp(run_start).timestamp()
-        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
-            select(
-                States.entity_id.label("max_entity_id"),
-                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-                # pylint: disable-next=not-callable
-                func.max(States.last_updated_ts).label("max_last_updated"),
-            )
-            .filter(
-                (States.last_updated_ts >= run_start_ts)
-                & (States.last_updated_ts < utc_point_in_time_ts)
-            )
-            .group_by(States.entity_id)
-            .subquery()
-        )
-    return (
-        select(
-            States.entity_id.label("max_entity_id"),
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .group_by(States.entity_id)
-        .subquery()
-    )
-
-
 def _get_states_for_all_stmt(
     schema_version: int,
     run_start: datetime,
@@ -733,12 +678,26 @@ def _get_states_for_all_stmt(
     # query, then filter out unwanted domains as well as applying the custom filter.
     # This filtering can't be done in the inner query because the domain column is
    # not indexed and we can't control what's in the custom filter.
-    most_recent_states_by_date = _generate_most_recent_states_by_date(
-        schema_version, run_start, utc_point_in_time
-    )
     if schema_version >= 31:
+        run_start_ts = process_timestamp(run_start).timestamp()
+        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated_ts).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated_ts >= run_start_ts)
+                        & (States.last_updated_ts < utc_point_in_time_ts)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated_ts == most_recent_states_by_date.c.max_last_updated,
@@ -746,7 +705,22 @@
         )
     else:
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated >= run_start)
+                        & (States.last_updated < utc_point_in_time)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated == most_recent_states_by_date.c.max_last_updated,

View file

@@ -16,14 +16,13 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal, cast

-from sqlalchemy import and_, bindparam, func, lambda_stmt, select, text
+from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery
 import voluptuous as vol

 from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT
@@ -650,27 +649,19 @@
     )


-def _compile_hourly_statistics_last_sum_stmt_subquery(
-    start_time_ts: float, end_time_ts: float
-) -> Subquery:
-    """Generate the summary mean statement for hourly statistics."""
-    return (
-        select(*QUERY_STATISTICS_SUMMARY_SUM)
-        .filter(StatisticsShortTerm.start_ts >= start_time_ts)
-        .filter(StatisticsShortTerm.start_ts < end_time_ts)
-        .subquery()
-    )
-
-
 def _compile_hourly_statistics_last_sum_stmt(
     start_time_ts: float, end_time_ts: float
 ) -> StatementLambdaElement:
     """Generate the summary mean statement for hourly statistics."""
-    subquery = _compile_hourly_statistics_last_sum_stmt_subquery(
-        start_time_ts, end_time_ts
-    )
     return lambda_stmt(
-        lambda: select(subquery)
+        lambda: select(
+            subquery := (
+                select(*QUERY_STATISTICS_SUMMARY_SUM)
+                .filter(StatisticsShortTerm.start_ts >= start_time_ts)
+                .filter(StatisticsShortTerm.start_ts < end_time_ts)
+                .subquery()
+            )
+        )
         .filter(subquery.c.rownum == 1)
         .order_by(subquery.c.metadata_id)
     )
@@ -1267,7 +1258,8 @@
     )


-def _statistics_during_period_stmt(
+def _generate_statistics_during_period_stmt(
+    columns: Select,
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
@@ -1279,21 +1271,6 @@
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
     start_time_ts = start_time.timestamp()
-    columns = select(table.metadata_id, table.start_ts)
-    if "last_reset" in types:
-        columns = columns.add_columns(table.last_reset_ts)
-    if "max" in types:
-        columns = columns.add_columns(table.max)
-    if "mean" in types:
-        columns = columns.add_columns(table.mean)
-    if "min" in types:
-        columns = columns.add_columns(table.min)
-    if "state" in types:
-        columns = columns.add_columns(table.state)
-    if "sum" in types:
-        columns = columns.add_columns(table.sum)
     stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts))
     if end_time is not None:
         end_time_ts = end_time.timestamp()
@@ -1307,6 +1284,23 @@
     return stmt


+def _generate_max_mean_min_statistic_in_sub_period_stmt(
+    columns: Select,
+    start_time: datetime | None,
+    end_time: datetime | None,
+    table: type[StatisticsBase],
+    metadata_id: int,
+) -> StatementLambdaElement:
+    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
+    if start_time is not None:
+        start_time_ts = start_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
+    if end_time is not None:
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    return stmt
+
+
 def _get_max_mean_min_statistic_in_sub_period(
     session: Session,
     result: dict[str, float],
@@ -1332,13 +1326,9 @@
         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
         # pylint: disable-next=not-callable
         columns = columns.add_columns(func.min(table.min))
-    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
-    if start_time is not None:
-        start_time_ts = start_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
-    if end_time is not None:
-        end_time_ts = end_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns, start_time, end_time, table, metadata_id
+    )
     stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
     if not stats:
         return
@@ -1753,8 +1743,21 @@ def _statistics_during_period_with_session(
     table: type[Statistics | StatisticsShortTerm] = (
         Statistics if period != "5minute" else StatisticsShortTerm
     )
-    stmt = _statistics_during_period_stmt(
-        start_time, end_time, metadata_ids, table, types
+    columns = select(table.metadata_id, table.start_ts)  # type: ignore[call-overload]
+    if "last_reset" in types:
+        columns = columns.add_columns(table.last_reset_ts)
+    if "max" in types:
+        columns = columns.add_columns(table.max)
+    if "mean" in types:
+        columns = columns.add_columns(table.mean)
+    if "min" in types:
+        columns = columns.add_columns(table.min)
+    if "state" in types:
+        columns = columns.add_columns(table.state)
+    if "sum" in types:
+        columns = columns.add_columns(table.sum)
+    stmt = _generate_statistics_during_period_stmt(
+        columns, start_time, end_time, metadata_ids, table, types
     )

     stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
@@ -1919,28 +1922,24 @@
     )


-def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
-        select(
-            StatisticsShortTerm.metadata_id,
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(StatisticsShortTerm.start_ts).label("start_max"),
-        )
-        .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
-        .group_by(StatisticsShortTerm.metadata_id)
-    ).subquery()
-
-
 def _latest_short_term_statistics_stmt(
     metadata_ids: list[int],
 ) -> StatementLambdaElement:
     """Create the statement for finding the latest short term stat rows."""
     stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
     stmt += lambda s: s.join(
-        most_recent_statistic_row,
+        (
+            most_recent_statistic_row := (
+                select(
+                    StatisticsShortTerm.metadata_id,
+                    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                    # pylint: disable-next=not-callable
+                    func.max(StatisticsShortTerm.start_ts).label("start_max"),
+                )
+                .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
+                .group_by(StatisticsShortTerm.metadata_id)
+            ).subquery()
+        ),
         (
             StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
             == most_recent_statistic_row.c.metadata_id
@@ -1988,21 +1987,34 @@ def get_latest_short_term_statistics(
     )


-def _get_most_recent_statistics_subquery(
-    metadata_ids: set[int], table: type[StatisticsBase], start_time_ts: float
-) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
-        select(
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(table.start_ts).label("max_start_ts"),
-            table.metadata_id.label("max_metadata_id"),
-        )
-        .filter(table.start_ts < start_time_ts)
-        .filter(table.metadata_id.in_(metadata_ids))
-        .group_by(table.metadata_id)
-        .subquery()
-    )
+def _generate_statistics_at_time_stmt(
+    columns: Select,
+    table: type[StatisticsBase],
+    metadata_ids: set[int],
+    start_time_ts: float,
+) -> StatementLambdaElement:
+    """Create the statement for finding the statistics for a given time."""
+    return lambda_stmt(
+        lambda: columns.join(
+            (
+                most_recent_statistic_ids := (
+                    select(
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(table.start_ts).label("max_start_ts"),
+                        table.metadata_id.label("max_metadata_id"),
+                    )
+                    .filter(table.start_ts < start_time_ts)
+                    .filter(table.metadata_id.in_(metadata_ids))
+                    .group_by(table.metadata_id)
+                    .subquery()
+                )
+            ),
+            and_(
+                table.start_ts == most_recent_statistic_ids.c.max_start_ts,
+                table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
+            ),
+        )
+    )
@@ -2027,19 +2039,10 @@ def _statistics_at_time(
         columns = columns.add_columns(table.state)
     if "sum" in types:
         columns = columns.add_columns(table.sum)
     start_time_ts = start_time.timestamp()
-    most_recent_statistic_ids = _get_most_recent_statistics_subquery(
-        metadata_ids, table, start_time_ts
+    stmt = _generate_statistics_at_time_stmt(
+        columns, table, metadata_ids, start_time_ts
     )
-    stmt = lambda_stmt(lambda: columns).join(
-        most_recent_statistic_ids,
-        and_(
-            table.start_ts == most_recent_statistic_ids.c.max_start_ts,
-            table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
-        ),
-    )
     return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))

View file

@@ -8,7 +8,7 @@ import sys
 from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel

 import pytest
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
@@ -22,6 +22,10 @@ from homeassistant.components.recorder.models import (
 )
 from homeassistant.components.recorder.statistics import (
     STATISTIC_UNIT_TO_UNIT_CONVERTER,
+    _generate_get_metadata_stmt,
+    _generate_max_mean_min_statistic_in_sub_period_stmt,
+    _generate_statistics_at_time_stmt,
+    _generate_statistics_during_period_stmt,
     _statistics_during_period_with_session,
     _update_or_add_metadata,
     async_add_external_statistics,
@@ -1799,3 +1803,100 @@ def record_states(hass):
     states[sns4].append(set_state(sns4, "20", attributes=sns4_attr))

     return zero, four, states
+
+
+def test_cache_key_for_generate_statistics_during_period_stmt():
+    """Test cache key for _generate_statistics_during_period_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_statistics_during_period_stmt(
+        columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
+    )
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_statistics_during_period_stmt(
+        columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
+    )
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_statistics_during_period_stmt(
+        columns2,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        [0],
+        StatisticsShortTerm,
+        {"max", "mean"},
+    )
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
+
+
+def test_cache_key_for_generate_get_metadata_stmt():
+    """Test cache key for _generate_get_metadata_stmt."""
+    stmt_mean = _generate_get_metadata_stmt([0], "mean")
+    stmt_mean2 = _generate_get_metadata_stmt([1], "mean")
+    stmt_sum = _generate_get_metadata_stmt([0], "sum")
+    stmt_none = _generate_get_metadata_stmt()
+    assert stmt_mean._generate_cache_key() == stmt_mean2._generate_cache_key()
+    assert stmt_mean._generate_cache_key() != stmt_sum._generate_cache_key()
+    assert stmt_mean._generate_cache_key() != stmt_none._generate_cache_key()
+
+
+def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt():
+    """Test cache key for _generate_max_mean_min_statistic_in_sub_period_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns2,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
+
+
+def test_cache_key_for_generate_statistics_at_time_stmt():
+    """Test cache key for _generate_statistics_at_time_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0)
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
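
Editor's note on the assertions above (background, not part of the commit): StatementLambdaElement._generate_cache_key() is SQLAlchemy's internal cache key, which lets these tests verify thread-safe statement generation without a database. Equal keys mean two calls compile to the same SQL shape and differ only in bound parameter values; unequal keys mean a genuinely different statement, as when extra columns are selected. A standalone sketch of the same idea, using a hypothetical Item table:

from sqlalchemy import Column, Integer, lambda_stmt, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Item(Base):
    """Hypothetical table for illustration only."""

    __tablename__ = "item"
    id = Column(Integer, primary_key=True)


def make_stmt(threshold: int):
    # `threshold` is tracked as a bound parameter, not baked into the key.
    return lambda_stmt(lambda: select(Item).filter(Item.id > threshold))


# Same statement shape with different parameter values: the keys match,
# so SQLAlchemy reuses the compiled statement and only swaps the value.
assert make_stmt(1)._generate_cache_key() == make_stmt(2)._generate_cache_key()
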