Make sql subqueries threadsafe (#89254)

* Make sql subqueries threadsafe (fixes #89224)
* Fix join outside of lambda
* Move statement generation into a separate function to make it easier to test
* Add cache key tests
* No need to mock hass
This commit is contained in:
parent 9672b5f02c / commit 3c70dd9b42
3 changed files with 257 additions and 179 deletions
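Background, for reviewers (an illustration, not code from this PR): SQLAlchemy's lambda_stmt caches compiled SQL keyed on the lambda's code location. When a subquery is built outside the lambda and only referenced from its closure, the literal values embedded in that subquery (timestamps, entity ids) are baked into the cached statement, so a second caller on another thread can be handed SQL built from the first caller's arguments. Constructing the subquery inside the lambda, with a walrus assignment so its columns remain addressable for the join clause, keeps those values visible to the lambda analysis. A minimal sketch using a hypothetical toy model (SQLAlchemy 2.0 style, names are not from the HA schema):

from sqlalchemy import func, lambda_stmt, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Event(Base):
    """Hypothetical table standing in for the recorder schema."""

    __tablename__ = "events"
    event_id: Mapped[int] = mapped_column(primary_key=True)
    entity_id: Mapped[str]
    updated_ts: Mapped[float]


def unsafe_stmt(cutoff_ts: float):
    # Subquery built OUTSIDE the lambda: cutoff_ts is frozen into the
    # statement the first time the lambda is cached.
    most_recent = (
        select(Event.entity_id, func.max(Event.updated_ts).label("max_ts"))
        .filter(Event.updated_ts < cutoff_ts)
        .group_by(Event.entity_id)
        .subquery()
    )
    return lambda_stmt(
        lambda: select(Event).join(
            most_recent, Event.entity_id == most_recent.c.entity_id
        )
    )


def safe_stmt(cutoff_ts: float):
    # Subquery built INSIDE the lambda via a walrus assignment: the lambda
    # analysis now sees cutoff_ts and handles it per invocation.
    return lambda_stmt(
        lambda: select(Event).join(
            (
                most_recent := (
                    select(
                        Event.entity_id,
                        func.max(Event.updated_ts).label("max_ts"),
                    )
                    .filter(Event.updated_ts < cutoff_ts)
                    .group_by(Event.entity_id)
                    .subquery()
                )
            ),
            Event.entity_id == most_recent.c.entity_id,
        )
    )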
homeassistant/components/recorder/history.py

@@ -17,7 +17,6 @@ from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery

 from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE
 from homeassistant.core import HomeAssistant, State, split_entity_id
@@ -592,48 +591,6 @@ def get_last_state_changes(
     )
-
-
-def _generate_most_recent_states_for_entities_by_date(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-    entity_ids: list[str],
-) -> Subquery:
-    """Generate the sub query for the most recent states for specific entities by date."""
-    if schema_version >= 31:
-        run_start_ts = process_timestamp(run_start).timestamp()
-        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
-            select(
-                States.entity_id.label("max_entity_id"),
-                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-                # pylint: disable-next=not-callable
-                func.max(States.last_updated_ts).label("max_last_updated"),
-            )
-            .filter(
-                (States.last_updated_ts >= run_start_ts)
-                & (States.last_updated_ts < utc_point_in_time_ts)
-            )
-            .filter(States.entity_id.in_(entity_ids))
-            .group_by(States.entity_id)
-            .subquery()
-        )
-    return (
-        select(
-            States.entity_id.label("max_entity_id"),
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .filter(States.entity_id.in_(entity_ids))
-        .group_by(States.entity_id)
-        .subquery()
-    )


 def _get_states_for_entities_stmt(
     schema_version: int,
     run_start: datetime,
@@ -645,16 +602,29 @@ def _get_states_for_entities_stmt(
     stmt, join_attributes = lambda_stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=True
     )
-    most_recent_states_for_entities_by_date = (
-        _generate_most_recent_states_for_entities_by_date(
-            schema_version, run_start, utc_point_in_time, entity_ids
-        )
-    )
     # We got an include-list of entities, accelerate the query by filtering already
     # in the inner query.
     if schema_version >= 31:
+        run_start_ts = process_timestamp(run_start).timestamp()
+        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
         stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
+            (
+                most_recent_states_for_entities_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated_ts).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated_ts >= run_start_ts)
+                        & (States.last_updated_ts < utc_point_in_time_ts)
+                    )
+                    .filter(States.entity_id.in_(entity_ids))
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id
                 == most_recent_states_for_entities_by_date.c.max_entity_id,
@@ -664,7 +634,21 @@ def _get_states_for_entities_stmt(
         )
     else:
         stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
+            (
+                most_recent_states_for_entities_by_date := select(
+                    States.entity_id.label("max_entity_id"),
+                    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                    # pylint: disable-next=not-callable
+                    func.max(States.last_updated).label("max_last_updated"),
+                )
+                .filter(
+                    (States.last_updated >= run_start)
+                    & (States.last_updated < utc_point_in_time)
+                )
+                .filter(States.entity_id.in_(entity_ids))
+                .group_by(States.entity_id)
+                .subquery()
+            ),
             and_(
                 States.entity_id
                 == most_recent_states_for_entities_by_date.c.max_entity_id,
@@ -679,45 +663,6 @@ def _get_states_for_entities_stmt(
     return stmt
-
-
-def _generate_most_recent_states_by_date(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-) -> Subquery:
-    """Generate the sub query for the most recent states by date."""
-    if schema_version >= 31:
-        run_start_ts = process_timestamp(run_start).timestamp()
-        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
-            select(
-                States.entity_id.label("max_entity_id"),
-                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-                # pylint: disable-next=not-callable
-                func.max(States.last_updated_ts).label("max_last_updated"),
-            )
-            .filter(
-                (States.last_updated_ts >= run_start_ts)
-                & (States.last_updated_ts < utc_point_in_time_ts)
-            )
-            .group_by(States.entity_id)
-            .subquery()
-        )
-    return (
-        select(
-            States.entity_id.label("max_entity_id"),
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .group_by(States.entity_id)
-        .subquery()
-    )


 def _get_states_for_all_stmt(
     schema_version: int,
     run_start: datetime,
@@ -733,12 +678,26 @@ def _get_states_for_all_stmt(
     # query, then filter out unwanted domains as well as applying the custom filter.
     # This filtering can't be done in the inner query because the domain column is
     # not indexed and we can't control what's in the custom filter.
-    most_recent_states_by_date = _generate_most_recent_states_by_date(
-        schema_version, run_start, utc_point_in_time
-    )
     if schema_version >= 31:
+        run_start_ts = process_timestamp(run_start).timestamp()
+        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated_ts).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated_ts >= run_start_ts)
+                        & (States.last_updated_ts < utc_point_in_time_ts)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated_ts == most_recent_states_by_date.c.max_last_updated,
@@ -746,7 +705,22 @@ def _get_states_for_all_stmt(
         )
     else:
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated >= run_start)
+                        & (States.last_updated < utc_point_in_time)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated == most_recent_states_by_date.c.max_last_updated,
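The join rewrites above all follow the safe_stmt shape from the sketch near the top. What the new tests at the end of this PR assert can be shown against that toy code too (illustrative only; _generate_cache_key() is the same private SQLAlchemy API the tests call):

# Same lambda code path with different argument values: the cache keys
# match, so the cached SQL is reused with fresh bound parameters.
assert safe_stmt(1.0)._generate_cache_key() == safe_stmt(2.0)._generate_cache_key()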
homeassistant/components/recorder/statistics.py

@@ -16,14 +16,13 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal, cast

-from sqlalchemy import and_, bindparam, func, lambda_stmt, select, text
+from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery
 import voluptuous as vol

 from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT
@@ -650,27 +649,19 @@ def _compile_hourly_statistics_summary_mean_stmt(
     )


-def _compile_hourly_statistics_last_sum_stmt_subquery(
-    start_time_ts: float, end_time_ts: float
-) -> Subquery:
-    """Generate the summary mean statement for hourly statistics."""
-    return (
-        select(*QUERY_STATISTICS_SUMMARY_SUM)
-        .filter(StatisticsShortTerm.start_ts >= start_time_ts)
-        .filter(StatisticsShortTerm.start_ts < end_time_ts)
-        .subquery()
-    )
-
-
 def _compile_hourly_statistics_last_sum_stmt(
     start_time_ts: float, end_time_ts: float
 ) -> StatementLambdaElement:
     """Generate the summary mean statement for hourly statistics."""
-    subquery = _compile_hourly_statistics_last_sum_stmt_subquery(
-        start_time_ts, end_time_ts
-    )
     return lambda_stmt(
-        lambda: select(subquery)
+        lambda: select(
+            subquery := (
+                select(*QUERY_STATISTICS_SUMMARY_SUM)
+                .filter(StatisticsShortTerm.start_ts >= start_time_ts)
+                .filter(StatisticsShortTerm.start_ts < end_time_ts)
+                .subquery()
+            )
+        )
         .filter(subquery.c.rownum == 1)
         .order_by(subquery.c.metadata_id)
     )
@@ -1267,7 +1258,8 @@ def _reduce_statistics_per_month(
     )


-def _statistics_during_period_stmt(
+def _generate_statistics_during_period_stmt(
+    columns: Select,
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
@@ -1279,21 +1271,6 @@ def _statistics_during_period_stmt(
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
     start_time_ts = start_time.timestamp()
-
-    columns = select(table.metadata_id, table.start_ts)
-    if "last_reset" in types:
-        columns = columns.add_columns(table.last_reset_ts)
-    if "max" in types:
-        columns = columns.add_columns(table.max)
-    if "mean" in types:
-        columns = columns.add_columns(table.mean)
-    if "min" in types:
-        columns = columns.add_columns(table.min)
-    if "state" in types:
-        columns = columns.add_columns(table.state)
-    if "sum" in types:
-        columns = columns.add_columns(table.sum)
-
     stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts))
     if end_time is not None:
         end_time_ts = end_time.timestamp()
@@ -1307,6 +1284,23 @@ def _statistics_during_period_stmt(
     return stmt


+def _generate_max_mean_min_statistic_in_sub_period_stmt(
+    columns: Select,
+    start_time: datetime | None,
+    end_time: datetime | None,
+    table: type[StatisticsBase],
+    metadata_id: int,
+) -> StatementLambdaElement:
+    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
+    if start_time is not None:
+        start_time_ts = start_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
+    if end_time is not None:
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    return stmt
+
+
 def _get_max_mean_min_statistic_in_sub_period(
     session: Session,
     result: dict[str, float],
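The stmt += lambda q: ... pattern in the new helper composes criteria incrementally: each appended lambda contributes its own code location to the statement's cache key, so the optional start/end filters produce distinct cached SQL only when a different combination of branches runs. A toy equivalent, reusing the hypothetical Event model from the sketch at the top (illustration, not PR code):

def toy_window_stmt(min_ts: float | None, max_ts: float | None):
    # Base statement; optional filters are appended as extra lambdas.
    stmt = lambda_stmt(lambda: select(Event))
    if min_ts is not None:
        stmt += lambda q: q.filter(Event.updated_ts >= min_ts)
    if max_ts is not None:
        stmt += lambda q: q.filter(Event.updated_ts < max_ts)
    return stmt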
@@ -1332,13 +1326,9 @@ def _get_max_mean_min_statistic_in_sub_period(
         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
         # pylint: disable-next=not-callable
         columns = columns.add_columns(func.min(table.min))
-    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
-    if start_time is not None:
-        start_time_ts = start_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
-    if end_time is not None:
-        end_time_ts = end_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns, start_time, end_time, table, metadata_id
+    )
     stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
     if not stats:
         return
@@ -1753,8 +1743,21 @@ def _statistics_during_period_with_session(
     table: type[Statistics | StatisticsShortTerm] = (
         Statistics if period != "5minute" else StatisticsShortTerm
     )
-    stmt = _statistics_during_period_stmt(
-        start_time, end_time, metadata_ids, table, types
+    columns = select(table.metadata_id, table.start_ts)  # type: ignore[call-overload]
+    if "last_reset" in types:
+        columns = columns.add_columns(table.last_reset_ts)
+    if "max" in types:
+        columns = columns.add_columns(table.max)
+    if "mean" in types:
+        columns = columns.add_columns(table.mean)
+    if "min" in types:
+        columns = columns.add_columns(table.min)
+    if "state" in types:
+        columns = columns.add_columns(table.state)
+    if "sum" in types:
+        columns = columns.add_columns(table.sum)
+    stmt = _generate_statistics_during_period_stmt(
+        columns, start_time, end_time, metadata_ids, table, types
     )
     stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
@@ -1919,28 +1922,24 @@ def get_last_short_term_statistics(
     )
-
-
-def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
-        select(
-            StatisticsShortTerm.metadata_id,
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(StatisticsShortTerm.start_ts).label("start_max"),
-        )
-        .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
-        .group_by(StatisticsShortTerm.metadata_id)
-    ).subquery()


 def _latest_short_term_statistics_stmt(
     metadata_ids: list[int],
 ) -> StatementLambdaElement:
     """Create the statement for finding the latest short term stat rows."""
     stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
     stmt += lambda s: s.join(
-        most_recent_statistic_row,
+        (
+            most_recent_statistic_row := (
+                select(
+                    StatisticsShortTerm.metadata_id,
+                    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                    # pylint: disable-next=not-callable
+                    func.max(StatisticsShortTerm.start_ts).label("start_max"),
+                )
+                .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
+                .group_by(StatisticsShortTerm.metadata_id)
+            ).subquery()
+        ),
         (
             StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
             == most_recent_statistic_row.c.metadata_id
@@ -1988,21 +1987,34 @@ def get_latest_short_term_statistics(
     )


-def _get_most_recent_statistics_subquery(
-    metadata_ids: set[int], table: type[StatisticsBase], start_time_ts: float
-) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
-        select(
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(table.start_ts).label("max_start_ts"),
-            table.metadata_id.label("max_metadata_id"),
-        )
-        .filter(table.start_ts < start_time_ts)
-        .filter(table.metadata_id.in_(metadata_ids))
-        .group_by(table.metadata_id)
-        .subquery()
-    )
+def _generate_statistics_at_time_stmt(
+    columns: Select,
+    table: type[StatisticsBase],
+    metadata_ids: set[int],
+    start_time_ts: float,
+) -> StatementLambdaElement:
+    """Create the statement for finding the statistics for a given time."""
+    return lambda_stmt(
+        lambda: columns.join(
+            (
+                most_recent_statistic_ids := (
+                    select(
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(table.start_ts).label("max_start_ts"),
+                        table.metadata_id.label("max_metadata_id"),
+                    )
+                    .filter(table.start_ts < start_time_ts)
+                    .filter(table.metadata_id.in_(metadata_ids))
+                    .group_by(table.metadata_id)
+                    .subquery()
+                )
+            ),
+            and_(
+                table.start_ts == most_recent_statistic_ids.c.max_start_ts,
+                table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
+            ),
+        )
+    )
@@ -2027,19 +2039,10 @@ def _statistics_at_time(
         columns = columns.add_columns(table.state)
     if "sum" in types:
         columns = columns.add_columns(table.sum)

     start_time_ts = start_time.timestamp()
-    most_recent_statistic_ids = _get_most_recent_statistics_subquery(
-        metadata_ids, table, start_time_ts
+    stmt = _generate_statistics_at_time_stmt(
+        columns, table, metadata_ids, start_time_ts
     )
-    stmt = lambda_stmt(lambda: columns).join(
-        most_recent_statistic_ids,
-        and_(
-            table.start_ts == most_recent_statistic_ids.c.max_start_ts,
-            table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
-        ),
-    )
-
     return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
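Note: the old _statistics_at_time body called .join(...) on the result of lambda_stmt(lambda: columns), attaching the join after the lambda was built; that is the "fix join outside of lambda" bullet in the commit message. The join now lives inside the lambda returned by _generate_statistics_at_time_stmt.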
tests/components/recorder/test_statistics.py

@@ -8,7 +8,7 @@ import sys
 from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel

 import pytest
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
@@ -22,6 +22,10 @@ from homeassistant.components.recorder.models import (
 )
 from homeassistant.components.recorder.statistics import (
     STATISTIC_UNIT_TO_UNIT_CONVERTER,
+    _generate_get_metadata_stmt,
+    _generate_max_mean_min_statistic_in_sub_period_stmt,
+    _generate_statistics_at_time_stmt,
+    _generate_statistics_during_period_stmt,
     _statistics_during_period_with_session,
     _update_or_add_metadata,
     async_add_external_statistics,
@@ -1799,3 +1803,100 @@ def record_states(hass):
     states[sns4].append(set_state(sns4, "20", attributes=sns4_attr))

     return zero, four, states
+
+
+def test_cache_key_for_generate_statistics_during_period_stmt():
+    """Test cache key for _generate_statistics_during_period_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_statistics_during_period_stmt(
+        columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
+    )
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_statistics_during_period_stmt(
+        columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
+    )
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_statistics_during_period_stmt(
+        columns2,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        [0],
+        StatisticsShortTerm,
+        {"max", "mean"},
+    )
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
+
+
+def test_cache_key_for_generate_get_metadata_stmt():
+    """Test cache key for _generate_get_metadata_stmt."""
+    stmt_mean = _generate_get_metadata_stmt([0], "mean")
+    stmt_mean2 = _generate_get_metadata_stmt([1], "mean")
+    stmt_sum = _generate_get_metadata_stmt([0], "sum")
+    stmt_none = _generate_get_metadata_stmt()
+    assert stmt_mean._generate_cache_key() == stmt_mean2._generate_cache_key()
+    assert stmt_mean._generate_cache_key() != stmt_sum._generate_cache_key()
+    assert stmt_mean._generate_cache_key() != stmt_none._generate_cache_key()
+
+
+def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt():
+    """Test cache key for _generate_max_mean_min_statistic_in_sub_period_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns2,
+        dt_util.utcnow(),
+        dt_util.utcnow(),
+        StatisticsShortTerm,
+        [0],
+    )
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
+
+
+def test_cache_key_for_generate_statistics_at_time_stmt():
+    """Test cache key for _generate_statistics_at_time_stmt."""
+    columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
+    stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
+    cache_key_1 = stmt._generate_cache_key()
+    stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
+    cache_key_2 = stmt2._generate_cache_key()
+    assert cache_key_1 == cache_key_2
+    columns2 = select(
+        StatisticsShortTerm.metadata_id,
+        StatisticsShortTerm.start_ts,
+        StatisticsShortTerm.sum,
+        StatisticsShortTerm.mean,
+    )
+    stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0)
+    cache_key_3 = stmt3._generate_cache_key()
+    assert cache_key_1 != cache_key_3
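These tests deliberately compare _generate_cache_key() output instead of executed results: equal keys across two calls with different datetimes show the values are carried as bound parameters rather than baked into the SQL, while unequal keys for a different column set or types selection show that distinct statements cannot collide in the cache.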