Upgrade SQLAlchemy to 2.0.2 (#86436)

Co-authored-by: J. Nick Koston <nick@koston.org>
Erik Montnemery 2023-02-08 15:17:32 +01:00 committed by GitHub
parent 93dafefd96
commit 94519de8dd
37 changed files with 583 additions and 430 deletions


@ -82,7 +82,7 @@ class EventAsRow:
@callback @callback
def async_event_to_row(event: Event) -> EventAsRow | None: def async_event_to_row(event: Event) -> EventAsRow:
"""Convert an event to a row.""" """Convert an event to a row."""
if event.event_type != EVENT_STATE_CHANGED: if event.event_type != EVENT_STATE_CHANGED:
return EventAsRow( return EventAsRow(


@ -1,14 +1,14 @@
"""Event parser and human readable log generator.""" """Event parser and human readable log generator."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Generator from collections.abc import Callable, Generator, Sequence
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime as dt from datetime import datetime as dt
from typing import Any from typing import Any
from sqlalchemy.engine import Result
from sqlalchemy.engine.row import Row from sqlalchemy.engine.row import Row
from sqlalchemy.orm.query import Query
from homeassistant.components.recorder.filters import Filters from homeassistant.components.recorder.filters import Filters
from homeassistant.components.recorder.models import ( from homeassistant.components.recorder.models import (
@ -70,7 +70,7 @@ class LogbookRun:
event_cache: EventCache event_cache: EventCache
entity_name_cache: EntityNameCache entity_name_cache: EntityNameCache
include_entity_name: bool include_entity_name: bool
format_time: Callable[[Row], Any] format_time: Callable[[Row | EventAsRow], Any]
class EventProcessor: class EventProcessor:
@ -133,13 +133,13 @@ class EventProcessor:
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Get events for a period of time.""" """Get events for a period of time."""
def yield_rows(query: Query) -> Generator[Row, None, None]: def yield_rows(result: Result) -> Sequence[Row] | Result:
"""Yield rows from the database.""" """Yield rows from the database."""
# end_day - start_day intentionally checks .days and not .total_seconds() # end_day - start_day intentionally checks .days and not .total_seconds()
# since we don't want to switch over to buffered if they go # since we don't want to switch over to buffered if they go
# over one day by a few hours since the UI makes it so easy to do that. # over one day by a few hours since the UI makes it so easy to do that.
if self.limited_select or (end_day - start_day).days <= 1: if self.limited_select or (end_day - start_day).days <= 1:
return query.all() # type: ignore[no-any-return] return result.all()
# Only buffer rows to reduce memory pressure # Only buffer rows to reduce memory pressure
# if we expect the result set is going to be very large. # if we expect the result set is going to be very large.
# What is considered very large is going to differ # What is considered very large is going to differ
@ -149,7 +149,7 @@ class EventProcessor:
# even and RPi3 that number seems higher in testing # even and RPi3 that number seems higher in testing
# so we don't switch over until we request > 1 day+ of data. # so we don't switch over until we request > 1 day+ of data.
# #
return query.yield_per(1024) # type: ignore[no-any-return] return result.yield_per(1024)
stmt = statement_for_request( stmt = statement_for_request(
start_day, start_day,
@ -164,12 +164,12 @@ class EventProcessor:
return self.humanify(yield_rows(session.execute(stmt))) return self.humanify(yield_rows(session.execute(stmt)))
def humanify( def humanify(
self, row_generator: Generator[Row | EventAsRow, None, None] self, rows: Generator[EventAsRow, None, None] | Sequence[Row] | Result
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""Humanify rows.""" """Humanify rows."""
return list( return list(
_humanify( _humanify(
row_generator, rows,
self.ent_reg, self.ent_reg,
self.logbook_run, self.logbook_run,
self.context_augmenter, self.context_augmenter,
@ -178,7 +178,7 @@ class EventProcessor:
def _humanify( def _humanify(
rows: Generator[Row | EventAsRow, None, None], rows: Generator[EventAsRow, None, None] | Sequence[Row] | Result,
ent_reg: er.EntityRegistry, ent_reg: er.EntityRegistry,
logbook_run: LogbookRun, logbook_run: LogbookRun,
context_augmenter: ContextAugmenter, context_augmenter: ContextAugmenter,
@ -263,7 +263,7 @@ class ContextLookup:
self._memorize_new = True self._memorize_new = True
self._lookup: dict[str | None, Row | EventAsRow | None] = {None: None} self._lookup: dict[str | None, Row | EventAsRow | None] = {None: None}
def memorize(self, row: Row) -> str | None: def memorize(self, row: Row | EventAsRow) -> str | None:
"""Memorize a context from the database.""" """Memorize a context from the database."""
if self._memorize_new: if self._memorize_new:
context_id: str = row.context_id context_id: str = row.context_id
@ -276,7 +276,7 @@ class ContextLookup:
self._lookup.clear() self._lookup.clear()
self._memorize_new = False self._memorize_new = False
def get(self, context_id: str) -> Row | None: def get(self, context_id: str) -> Row | EventAsRow | None:
"""Get the context origin.""" """Get the context origin."""
return self._lookup.get(context_id) return self._lookup.get(context_id)
@ -294,7 +294,7 @@ class ContextAugmenter:
def _get_context_row( def _get_context_row(
self, context_id: str | None, row: Row | EventAsRow self, context_id: str | None, row: Row | EventAsRow
) -> Row | EventAsRow: ) -> Row | EventAsRow | None:
"""Get the context row from the id or row context.""" """Get the context row from the id or row context."""
if context_id: if context_id:
return self.context_lookup.get(context_id) return self.context_lookup.get(context_id)
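
The hunks above move the logbook from consuming an ORM Query to the Result object that session.execute() returns in SQLAlchemy 2.0, which exposes .all() and .yield_per() directly. A minimal sketch of that consumption pattern, using an invented stand-in table rather than the real recorder schema:

from __future__ import annotations

from collections.abc import Sequence

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.engine import Result
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session

metadata = MetaData()
# Stand-in for the recorder's Events table (illustrative only).
events = Table(
    "events",
    metadata,
    Column("event_id", Integer, primary_key=True),
    Column("event_type", String(64)),
)


def fetch_rows(session: Session, expect_large: bool) -> Sequence[Row] | Result:
    """Buffer small result sets, stream large ones, mirroring yield_rows() above."""
    result: Result = session.execute(select(events))
    if not expect_large:
        return result.all()  # fully buffered list of Row objects
    return result.yield_per(1024)  # stream rows from the cursor in chunks of 1024

Both branches can be iterated the same way, which is why humanify() above now accepts Sequence[Row] | Result in addition to a generator of EventAsRow.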


@ -2,9 +2,9 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import lambda_stmt from sqlalchemy import lambda_stmt
from sqlalchemy.orm import Query from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Select
from homeassistant.components.recorder.db_schema import ( from homeassistant.components.recorder.db_schema import (
LAST_UPDATED_INDEX_TS, LAST_UPDATED_INDEX_TS,
@ -24,8 +24,8 @@ def all_stmt(
start_day: float, start_day: float,
end_day: float, end_day: float,
event_types: tuple[str, ...], event_types: tuple[str, ...],
states_entity_filter: ClauseList | None = None, states_entity_filter: ColumnElement | None = None,
events_entity_filter: ClauseList | None = None, events_entity_filter: ColumnElement | None = None,
context_id: str | None = None, context_id: str | None = None,
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Generate a logbook query for all entities.""" """Generate a logbook query for all entities."""
@@ -37,8 +37,18 @@ def all_stmt(
        # are gone from the database remove the
        # _legacy_select_events_context_id()
        stmt += lambda s: s.where(Events.context_id == context_id).union_all(
-            _states_query_for_context_id(start_day, end_day, context_id),
-            legacy_select_events_context_id(start_day, end_day, context_id),
+            _states_query_for_context_id(
+                start_day,
+                end_day,
+                # https://github.com/python/mypy/issues/2608
+                context_id,  # type:ignore[arg-type]
+            ),
+            legacy_select_events_context_id(
+                start_day,
+                end_day,
+                # https://github.com/python/mypy/issues/2608
+                context_id,  # type:ignore[arg-type]
+            ),
        )
else: else:
if events_entity_filter is not None: if events_entity_filter is not None:
@@ -46,7 +56,10 @@ def all_stmt(
        if states_entity_filter is not None:
            stmt += lambda s: s.union_all(
-                _states_query_for_all(start_day, end_day).where(states_entity_filter)
+                _states_query_for_all(start_day, end_day).where(
+                    # https://github.com/python/mypy/issues/2608
+                    states_entity_filter  # type:ignore[arg-type]
+                )
            )
else: else:
stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day)) stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
@ -55,20 +68,20 @@ def all_stmt(
return stmt return stmt
def _states_query_for_all(start_day: float, end_day: float) -> Query: def _states_query_for_all(start_day: float, end_day: float) -> Select:
return apply_states_filters(_apply_all_hints(select_states()), start_day, end_day) return apply_states_filters(_apply_all_hints(select_states()), start_day, end_day)
def _apply_all_hints(query: Query) -> Query: def _apply_all_hints(sel: Select) -> Select:
"""Force mysql to use the right index on large selects.""" """Force mysql to use the right index on large selects."""
return query.with_hint( return sel.with_hint(
States, f"FORCE INDEX ({LAST_UPDATED_INDEX_TS})", dialect_name="mysql" States, f"FORCE INDEX ({LAST_UPDATED_INDEX_TS})", dialect_name="mysql"
) )
def _states_query_for_context_id( def _states_query_for_context_id(
start_day: float, end_day: float, context_id: str start_day: float, end_day: float, context_id: str
) -> Query: ) -> Select:
return apply_states_filters(select_states(), start_day, end_day).where( return apply_states_filters(select_states(), start_day, end_day).where(
States.context_id == context_id States.context_id == context_id
) )
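
For reference, the Query-to-Select change in this file amounts to building plain 2.0 Select objects and wrapping them in lambda_stmt for statement caching. A rough, self-contained sketch of the shape (the table is a stand-in, not the real States schema):

from __future__ import annotations

from sqlalchemy import Column, Float, MetaData, Table, lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata = MetaData()
# Stand-in for the recorder's States table.
states = Table("states", metadata, Column("last_updated_ts", Float))


def states_stmt(start_day: float, end_day: float) -> StatementLambdaElement:
    """Build a cacheable statement from a plain 2.0 Select (no orm.Query involved)."""
    stmt = lambda_stmt(
        lambda: select(states).where(
            (states.c.last_updated_ts > start_day)
            & (states.c.last_updated_ts < end_day)
        )
    )
    return stmt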


@ -5,8 +5,7 @@ from typing import Final
import sqlalchemy import sqlalchemy
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Query from sqlalchemy.sql.elements import BooleanClauseList, ColumnElement
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.expression import literal from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
@ -69,7 +68,7 @@ STATE_CONTEXT_ONLY_COLUMNS = (
literal(value=None, type_=sqlalchemy.String).label("old_format_icon"), literal(value=None, type_=sqlalchemy.String).label("old_format_icon"),
) )
EVENT_COLUMNS_FOR_STATE_SELECT = [ EVENT_COLUMNS_FOR_STATE_SELECT = (
literal(value=None, type_=sqlalchemy.Text).label("event_id"), literal(value=None, type_=sqlalchemy.Text).label("event_id"),
# We use PSEUDO_EVENT_STATE_CHANGED aka None for # We use PSEUDO_EVENT_STATE_CHANGED aka None for
# state_changed events since it takes up less # state_changed events since it takes up less
@ -84,7 +83,7 @@ EVENT_COLUMNS_FOR_STATE_SELECT = [
States.context_user_id.label("context_user_id"), States.context_user_id.label("context_user_id"),
States.context_parent_id.label("context_parent_id"), States.context_parent_id.label("context_parent_id"),
literal(value=None, type_=sqlalchemy.Text).label("shared_data"), literal(value=None, type_=sqlalchemy.Text).label("shared_data"),
] )
EMPTY_STATE_COLUMNS = ( EMPTY_STATE_COLUMNS = (
literal(value=0, type_=sqlalchemy.Integer).label("state_id"), literal(value=0, type_=sqlalchemy.Integer).label("state_id"),
@ -103,8 +102,8 @@ EVENT_ROWS_NO_STATES = (
# Virtual column to tell logbook if it should avoid processing # Virtual column to tell logbook if it should avoid processing
# the event as its only used to link contexts # the event as its only used to link contexts
CONTEXT_ONLY = literal("1").label("context_only") CONTEXT_ONLY = literal(value="1", type_=sqlalchemy.String).label("context_only")
NOT_CONTEXT_ONLY = literal(None).label("context_only") NOT_CONTEXT_ONLY = literal(value=None, type_=sqlalchemy.String).label("context_only")
def select_events_context_id_subquery( def select_events_context_id_subquery(
@ -188,7 +187,7 @@ def legacy_select_events_context_id(
) )
def apply_states_filters(query: Query, start_day: float, end_day: float) -> Query: def apply_states_filters(sel: Select, start_day: float, end_day: float) -> Select:
"""Filter states by time range. """Filter states by time range.
Filters states that do not have an old state or new state (added / removed) Filters states that do not have an old state or new state (added / removed)
@ -196,7 +195,7 @@ def apply_states_filters(query: Query, start_day: float, end_day: float) -> Quer
Filters states that do not have matching last_updated_ts and last_changed_ts. Filters states that do not have matching last_updated_ts and last_changed_ts.
""" """
return ( return (
query.filter( sel.filter(
(States.last_updated_ts > start_day) & (States.last_updated_ts < end_day) (States.last_updated_ts > start_day) & (States.last_updated_ts < end_day)
) )
.outerjoin(OLD_STATE, (States.old_state_id == OLD_STATE.state_id)) .outerjoin(OLD_STATE, (States.old_state_id == OLD_STATE.state_id))
@ -212,18 +211,18 @@ def apply_states_filters(query: Query, start_day: float, end_day: float) -> Quer
) )
def _missing_state_matcher() -> sqlalchemy.and_: def _missing_state_matcher() -> ColumnElement[bool]:
# The below removes state change events that do not have # The below removes state change events that do not have
# and old_state or the old_state is missing (newly added entities) # and old_state or the old_state is missing (newly added entities)
# or the new_state is missing (removed entities) # or the new_state is missing (removed entities)
return sqlalchemy.and_( return sqlalchemy.and_(
OLD_STATE.state_id.isnot(None), OLD_STATE.state_id.is_not(None),
(States.state != OLD_STATE.state), (States.state != OLD_STATE.state),
States.state.isnot(None), States.state.is_not(None),
) )
def _not_continuous_entity_matcher() -> sqlalchemy.or_: def _not_continuous_entity_matcher() -> ColumnElement[bool]:
"""Match non continuous entities.""" """Match non continuous entities."""
return sqlalchemy.or_( return sqlalchemy.or_(
# First exclude domains that may be continuous # First exclude domains that may be continuous
@ -236,7 +235,7 @@ def _not_continuous_entity_matcher() -> sqlalchemy.or_:
) )
def _not_possible_continuous_domain_matcher() -> sqlalchemy.and_: def _not_possible_continuous_domain_matcher() -> ColumnElement[bool]:
"""Match not continuous domains. """Match not continuous domains.
This matches domain that are always considered continuous This matches domain that are always considered continuous
@ -254,7 +253,7 @@ def _not_possible_continuous_domain_matcher() -> sqlalchemy.and_:
).self_group() ).self_group()
def _conditionally_continuous_domain_matcher() -> sqlalchemy.or_: def _conditionally_continuous_domain_matcher() -> ColumnElement[bool]:
"""Match conditionally continuous domains. """Match conditionally continuous domains.
This matches domain that are only considered This matches domain that are only considered
@ -268,22 +267,22 @@ def _conditionally_continuous_domain_matcher() -> sqlalchemy.or_:
).self_group() ).self_group()
def _not_uom_attributes_matcher() -> ClauseList: def _not_uom_attributes_matcher() -> BooleanClauseList:
"""Prefilter ATTR_UNIT_OF_MEASUREMENT as its much faster in sql.""" """Prefilter ATTR_UNIT_OF_MEASUREMENT as its much faster in sql."""
return ~StateAttributes.shared_attrs.like( return ~StateAttributes.shared_attrs.like(
UNIT_OF_MEASUREMENT_JSON_LIKE UNIT_OF_MEASUREMENT_JSON_LIKE
) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE) ) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
def apply_states_context_hints(query: Query) -> Query: def apply_states_context_hints(sel: Select) -> Select:
"""Force mysql to use the right index on large context_id selects.""" """Force mysql to use the right index on large context_id selects."""
return query.with_hint( return sel.with_hint(
States, f"FORCE INDEX ({STATES_CONTEXT_ID_INDEX})", dialect_name="mysql" States, f"FORCE INDEX ({STATES_CONTEXT_ID_INDEX})", dialect_name="mysql"
) )
def apply_events_context_hints(query: Query) -> Query: def apply_events_context_hints(sel: Select) -> Select:
"""Force mysql to use the right index on large context_id selects.""" """Force mysql to use the right index on large context_id selects."""
return query.with_hint( return sel.with_hint(
Events, f"FORCE INDEX ({EVENTS_CONTEXT_ID_INDEX})", dialect_name="mysql" Events, f"FORCE INDEX ({EVENTS_CONTEXT_ID_INDEX})", dialect_name="mysql"
) )
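
Two mechanical renames run through this file: the deprecated isnot() spelling becomes is_not(), and expressions built with and_()/or_() are annotated as ColumnElement[bool] rather than ClauseList. A small sketch of the resulting style, using invented columns:

from __future__ import annotations

import sqlalchemy
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.sql.elements import ColumnElement

metadata = MetaData()
# Stand-in columns; the real matchers use the recorder's States table.
states = Table(
    "states",
    metadata,
    Column("state", String(255)),
    Column("old_state_id", Integer),
)


def missing_state_matcher() -> ColumnElement[bool]:
    """and_() yields a ColumnElement[bool]; is_not() replaces the old isnot()."""
    return sqlalchemy.and_(
        states.c.old_state_id.is_not(None),
        states.c.state.is_not(None),
    )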


@ -5,10 +5,9 @@ from collections.abc import Iterable
import sqlalchemy import sqlalchemy
from sqlalchemy import lambda_stmt, select from sqlalchemy import lambda_stmt, select
from sqlalchemy.orm import Query from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import CTE, CompoundSelect from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
from homeassistant.components.recorder.db_schema import ( from homeassistant.components.recorder.db_schema import (
DEVICE_ID_IN_EVENT, DEVICE_ID_IN_EVENT,
@ -32,7 +31,7 @@ def _select_device_id_context_ids_sub_query(
end_day: float, end_day: float,
event_types: tuple[str, ...], event_types: tuple[str, ...],
json_quotable_device_ids: list[str], json_quotable_device_ids: list[str],
) -> CompoundSelect: ) -> Select:
"""Generate a subquery to find context ids for multiple devices.""" """Generate a subquery to find context ids for multiple devices."""
inner = select_events_context_id_subquery(start_day, end_day, event_types).where( inner = select_events_context_id_subquery(start_day, end_day, event_types).where(
apply_event_device_id_matchers(json_quotable_device_ids) apply_event_device_id_matchers(json_quotable_device_ids)
@ -41,7 +40,7 @@ def _select_device_id_context_ids_sub_query(
def _apply_devices_context_union( def _apply_devices_context_union(
query: Query, sel: Select,
start_day: float, start_day: float,
end_day: float, end_day: float,
event_types: tuple[str, ...], event_types: tuple[str, ...],
@ -54,7 +53,7 @@ def _apply_devices_context_union(
event_types, event_types,
json_quotable_device_ids, json_quotable_device_ids,
).cte() ).cte()
return query.union_all( return sel.union_all(
apply_events_context_hints( apply_events_context_hints(
select_events_context_only() select_events_context_only()
.select_from(devices_cte) .select_from(devices_cte)
@ -91,7 +90,7 @@ def devices_stmt(
def apply_event_device_id_matchers( def apply_event_device_id_matchers(
json_quotable_device_ids: Iterable[str], json_quotable_device_ids: Iterable[str],
) -> ClauseList: ) -> BooleanClauseList:
"""Create matchers for the device_ids in the event_data.""" """Create matchers for the device_ids in the event_data."""
return DEVICE_ID_IN_EVENT.is_not(None) & sqlalchemy.cast( return DEVICE_ID_IN_EVENT.is_not(None) & sqlalchemy.cast(
DEVICE_ID_IN_EVENT, sqlalchemy.Text() DEVICE_ID_IN_EVENT, sqlalchemy.Text()
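
The structure of the device query is unchanged: build a context-id Select, turn it into a CTE, and union_all the matching rows against it; only the annotations move from Query and CompoundSelect to Select. A simplified sketch of that shape with a made-up events table and no index hints:

from __future__ import annotations

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.sql.selectable import CompoundSelect

metadata = MetaData()
# Invented events table; the real query also joins EventData and States.
events = Table(
    "events",
    metadata,
    Column("event_id", Integer, primary_key=True),
    Column("context_id", String(36)),
    Column("device_id", String(32)),
)


def events_for_device(device_id: str) -> CompoundSelect:
    """Union the device's own events with everything sharing their context ids."""
    context_ids = (
        select(events.c.context_id).where(events.c.device_id == device_id).cte()
    )
    own_events = select(events).where(events.c.device_id == device_id)
    context_events = select(events).where(
        events.c.context_id.in_(select(context_ids.c.context_id))
    )
    return own_events.union_all(context_events)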


@ -5,9 +5,9 @@ from collections.abc import Iterable
import sqlalchemy import sqlalchemy
from sqlalchemy import lambda_stmt, select, union_all from sqlalchemy import lambda_stmt, select, union_all
from sqlalchemy.orm import Query from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import CTE, CompoundSelect from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
from homeassistant.components.recorder.db_schema import ( from homeassistant.components.recorder.db_schema import (
ENTITY_ID_IN_EVENT, ENTITY_ID_IN_EVENT,
@ -36,7 +36,7 @@ def _select_entities_context_ids_sub_query(
event_types: tuple[str, ...], event_types: tuple[str, ...],
entity_ids: list[str], entity_ids: list[str],
json_quoted_entity_ids: list[str], json_quoted_entity_ids: list[str],
) -> CompoundSelect: ) -> Select:
"""Generate a subquery to find context ids for multiple entities.""" """Generate a subquery to find context ids for multiple entities."""
union = union_all( union = union_all(
select_events_context_id_subquery(start_day, end_day, event_types).where( select_events_context_id_subquery(start_day, end_day, event_types).where(
@ -52,7 +52,7 @@ def _select_entities_context_ids_sub_query(
def _apply_entities_context_union( def _apply_entities_context_union(
query: Query, sel: Select,
start_day: float, start_day: float,
end_day: float, end_day: float,
event_types: tuple[str, ...], event_types: tuple[str, ...],
@ -72,8 +72,8 @@ def _apply_entities_context_union(
# query much slower on MySQL, and since we already filter them away # query much slower on MySQL, and since we already filter them away
# in the python code anyways since they will have context_only # in the python code anyways since they will have context_only
# set on them the impact is minimal. # set on them the impact is minimal.
return query.union_all( return sel.union_all(
states_query_for_entity_ids(start_day, end_day, entity_ids), states_select_for_entity_ids(start_day, end_day, entity_ids),
apply_events_context_hints( apply_events_context_hints(
select_events_context_only() select_events_context_only()
.select_from(entities_cte) .select_from(entities_cte)
@ -109,9 +109,9 @@ def entities_stmt(
) )
def states_query_for_entity_ids( def states_select_for_entity_ids(
start_day: float, end_day: float, entity_ids: list[str] start_day: float, end_day: float, entity_ids: list[str]
) -> Query: ) -> Select:
"""Generate a select for states from the States table for specific entities.""" """Generate a select for states from the States table for specific entities."""
return apply_states_filters( return apply_states_filters(
apply_entities_hints(select_states()), start_day, end_day apply_entities_hints(select_states()), start_day, end_day
@ -120,7 +120,7 @@ def states_query_for_entity_ids(
def apply_event_entity_id_matchers( def apply_event_entity_id_matchers(
json_quoted_entity_ids: Iterable[str], json_quoted_entity_ids: Iterable[str],
) -> sqlalchemy.or_: ) -> ColumnElement[bool]:
"""Create matchers for the entity_id in the event_data.""" """Create matchers for the entity_id in the event_data."""
return sqlalchemy.or_( return sqlalchemy.or_(
ENTITY_ID_IN_EVENT.is_not(None) ENTITY_ID_IN_EVENT.is_not(None)
@ -134,8 +134,8 @@ def apply_event_entity_id_matchers(
) )
def apply_entities_hints(query: Query) -> Query: def apply_entities_hints(sel: Select) -> Select:
"""Force mysql to use the right index on large selects.""" """Force mysql to use the right index on large selects."""
return query.with_hint( return sel.with_hint(
States, f"FORCE INDEX ({ENTITY_ID_LAST_UPDATED_INDEX_TS})", dialect_name="mysql" States, f"FORCE INDEX ({ENTITY_ID_LAST_UPDATED_INDEX_TS})", dialect_name="mysql"
) )
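
with_hint() exists on Select just as it did on Query, so the MySQL index-hint helpers only need new parameter and return annotations. A tiny sketch (the table and index name are invented for illustration):

from __future__ import annotations

from sqlalchemy import Column, Float, MetaData, String, Table, select
from sqlalchemy.sql.selectable import Select

metadata = MetaData()
# Stand-in for the States table.
states = Table(
    "states",
    metadata,
    Column("entity_id", String(255)),
    Column("last_updated_ts", Float),
)


def apply_hint(sel: Select) -> Select:
    """Force a specific MySQL index, mirroring apply_entities_hints() above."""
    # The index name below is made up for this sketch.
    return sel.with_hint(
        states, "FORCE INDEX (ix_states_entity_id_last_updated_ts)", dialect_name="mysql"
    )


stmt = apply_hint(select(states))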


@ -3,11 +3,10 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import sqlalchemy
from sqlalchemy import lambda_stmt, select, union_all from sqlalchemy import lambda_stmt, select, union_all
from sqlalchemy.orm import Query from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import CTE, CompoundSelect from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
from homeassistant.components.recorder.db_schema import EventData, Events, States from homeassistant.components.recorder.db_schema import EventData, Events, States
@ -23,7 +22,7 @@ from .devices import apply_event_device_id_matchers
from .entities import ( from .entities import (
apply_entities_hints, apply_entities_hints,
apply_event_entity_id_matchers, apply_event_entity_id_matchers,
states_query_for_entity_ids, states_select_for_entity_ids,
) )
@ -34,7 +33,7 @@ def _select_entities_device_id_context_ids_sub_query(
entity_ids: list[str], entity_ids: list[str],
json_quoted_entity_ids: list[str], json_quoted_entity_ids: list[str],
json_quoted_device_ids: list[str], json_quoted_device_ids: list[str],
) -> CompoundSelect: ) -> Select:
"""Generate a subquery to find context ids for multiple entities and multiple devices.""" """Generate a subquery to find context ids for multiple entities and multiple devices."""
union = union_all( union = union_all(
select_events_context_id_subquery(start_day, end_day, event_types).where( select_events_context_id_subquery(start_day, end_day, event_types).where(
@ -52,7 +51,7 @@ def _select_entities_device_id_context_ids_sub_query(
def _apply_entities_devices_context_union( def _apply_entities_devices_context_union(
query: Query, sel: Select,
start_day: float, start_day: float,
end_day: float, end_day: float,
event_types: tuple[str, ...], event_types: tuple[str, ...],
@ -73,8 +72,8 @@ def _apply_entities_devices_context_union(
# query much slower on MySQL, and since we already filter them away # query much slower on MySQL, and since we already filter them away
# in the python code anyways since they will have context_only # in the python code anyways since they will have context_only
# set on them the impact is minimal. # set on them the impact is minimal.
return query.union_all( return sel.union_all(
states_query_for_entity_ids(start_day, end_day, entity_ids), states_select_for_entity_ids(start_day, end_day, entity_ids),
apply_events_context_hints( apply_events_context_hints(
select_events_context_only() select_events_context_only()
.select_from(devices_entities_cte) .select_from(devices_entities_cte)
@ -117,7 +116,7 @@ def entities_devices_stmt(
def _apply_event_entity_id_device_id_matchers( def _apply_event_entity_id_device_id_matchers(
json_quoted_entity_ids: Iterable[str], json_quoted_device_ids: Iterable[str] json_quoted_entity_ids: Iterable[str], json_quoted_device_ids: Iterable[str]
) -> sqlalchemy.or_: ) -> ColumnElement[bool]:
"""Create matchers for the device_id and entity_id in the event_data.""" """Create matchers for the device_id and entity_id in the event_data."""
return apply_event_entity_id_matchers( return apply_event_entity_id_matchers(
json_quoted_entity_ids json_quoted_entity_ids


@ -190,7 +190,7 @@ class Recorder(threading.Thread):
self.schema_version = 0 self.schema_version = 0
self._commits_without_expire = 0 self._commits_without_expire = 0
self._old_states: dict[str, States] = {} self._old_states: dict[str | None, States] = {}
self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE) self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
self._event_data_ids: LRU = LRU(EVENT_DATA_ID_CACHE_SIZE) self._event_data_ids: LRU = LRU(EVENT_DATA_ID_CACHE_SIZE)
self._pending_state_attributes: dict[str, StateAttributes] = {} self._pending_state_attributes: dict[str, StateAttributes] = {}
@ -739,6 +739,7 @@ class Recorder(threading.Thread):
self.hass.add_job(self._async_migration_started) self.hass.add_job(self._async_migration_started)
try: try:
assert self.engine is not None
migration.migrate_schema( migration.migrate_schema(
self, self.hass, self.engine, self.get_session, schema_status self, self.hass, self.engine, self.get_session, schema_status
) )
@ -1026,6 +1027,8 @@ class Recorder(threading.Thread):
def _post_schema_migration(self, old_version: int, new_version: int) -> None: def _post_schema_migration(self, old_version: int, new_version: int) -> None:
"""Run post schema migration tasks.""" """Run post schema migration tasks."""
assert self.engine is not None
assert self.event_session is not None
migration.post_schema_migration( migration.post_schema_migration(
self.engine, self.event_session, old_version, new_version self.engine, self.event_session, old_version, new_version
) )
@ -1034,7 +1037,7 @@ class Recorder(threading.Thread):
"""Send a keep alive to keep the db connection open.""" """Send a keep alive to keep the db connection open."""
assert self.event_session is not None assert self.event_session is not None
_LOGGER.debug("Sending keepalive") _LOGGER.debug("Sending keepalive")
self.event_session.connection().scalar(select([1])) self.event_session.connection().scalar(select(1))
@callback @callback
def event_listener(self, event: Event) -> None: def event_listener(self, event: Event) -> None:
@ -1198,6 +1201,8 @@ class Recorder(threading.Thread):
start = start.replace(minute=0, second=0, microsecond=0) start = start.replace(minute=0, second=0, microsecond=0)
# Find the newest statistics run, if any # Find the newest statistics run, if any
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
if last_run := session.query(func.max(StatisticsRuns.start)).scalar(): if last_run := session.query(func.max(StatisticsRuns.start)).scalar():
start = max(start, process_timestamp(last_run) + timedelta(minutes=5)) start = max(start, process_timestamp(last_run) + timedelta(minutes=5))
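
The keepalive change reflects the 2.0 calling convention: select() takes column expressions positionally and no longer accepts a list, so select([1]) becomes select(1). A minimal, runnable illustration against a throwaway in-memory database:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # throwaway in-memory database

with Session(engine) as session:
    # 2.0 style: column expressions are passed positionally, not as a list.
    assert session.connection().scalar(select(1)) == 1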


@ -13,7 +13,7 @@ from sqlalchemy import (
JSON, JSON,
BigInteger, BigInteger,
Boolean, Boolean,
Column, ColumnElement,
DateTime, DateTime,
Float, Float,
ForeignKey, ForeignKey,
@ -27,8 +27,9 @@ from sqlalchemy import (
type_coerce, type_coerce,
) )
from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import aliased, declarative_base, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column, relationship
from sqlalchemy.orm.query import RowReturningQuery
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from typing_extensions import Self from typing_extensions import Self
@ -53,9 +54,12 @@ import homeassistant.util.dt as dt_util
from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
from .models import StatisticData, StatisticMetaData, process_timestamp from .models import StatisticData, StatisticMetaData, process_timestamp
# SQLAlchemy Schema # SQLAlchemy Schema
# pylint: disable=invalid-name # pylint: disable=invalid-name
Base = declarative_base() class Base(DeclarativeBase):
"""Base class for tables."""
SCHEMA_VERSION = 33 SCHEMA_VERSION = 33
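
Declaring Base as a subclass of DeclarativeBase is the 2.0 replacement for declarative_base(); because the base is now an ordinary class, the # type: ignore[misc,valid-type] comments on every model can be dropped, as the rest of this file does. A minimal sketch (model name and table are invented):

from __future__ import annotations

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Typed replacement for Base = declarative_base()."""


class Pet(Base):  # no type: ignore needed on subclasses
    """Invented example model."""

    __tablename__ = "pets"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str | None] = mapped_column(String(64))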
@ -101,7 +105,7 @@ EVENTS_CONTEXT_ID_INDEX = "ix_events_context_id"
STATES_CONTEXT_ID_INDEX = "ix_states_context_id" STATES_CONTEXT_ID_INDEX = "ix_states_context_id"
class FAST_PYSQLITE_DATETIME(sqlite.DATETIME): # type: ignore[misc] class FAST_PYSQLITE_DATETIME(sqlite.DATETIME):
"""Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex.""" """Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex."""
def result_processor(self, dialect, coltype): # type: ignore[no-untyped-def] def result_processor(self, dialect, coltype): # type: ignore[no-untyped-def]
@ -110,19 +114,19 @@ class FAST_PYSQLITE_DATETIME(sqlite.DATETIME): # type: ignore[misc]
JSON_VARIANT_CAST = Text().with_variant( JSON_VARIANT_CAST = Text().with_variant(
postgresql.JSON(none_as_null=True), "postgresql" postgresql.JSON(none_as_null=True), "postgresql" # type: ignore[no-untyped-call]
) )
JSONB_VARIANT_CAST = Text().with_variant( JSONB_VARIANT_CAST = Text().with_variant(
postgresql.JSONB(none_as_null=True), "postgresql" postgresql.JSONB(none_as_null=True), "postgresql" # type: ignore[no-untyped-call]
) )
DATETIME_TYPE = ( DATETIME_TYPE = (
DateTime(timezone=True) DateTime(timezone=True)
.with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql") .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql") # type: ignore[no-untyped-call]
.with_variant(FAST_PYSQLITE_DATETIME(), "sqlite") .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite") # type: ignore[no-untyped-call]
) )
DOUBLE_TYPE = ( DOUBLE_TYPE = (
Float() Float()
.with_variant(mysql.DOUBLE(asdecimal=False), "mysql") .with_variant(mysql.DOUBLE(asdecimal=False), "mysql") # type: ignore[no-untyped-call]
.with_variant(oracle.DOUBLE_PRECISION(), "oracle") .with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
) )
@ -130,10 +134,10 @@ DOUBLE_TYPE = (
TIMESTAMP_TYPE = DOUBLE_TYPE TIMESTAMP_TYPE = DOUBLE_TYPE
class JSONLiteral(JSON): # type: ignore[misc] class JSONLiteral(JSON):
"""Teach SA how to literalize json.""" """Teach SA how to literalize json."""
def literal_processor(self, dialect: str) -> Callable[[Any], str]: def literal_processor(self, dialect: Dialect) -> Callable[[Any], str]:
"""Processor to convert a value to JSON.""" """Processor to convert a value to JSON."""
def process(value: Any) -> str: def process(value: Any) -> str:
@ -147,7 +151,7 @@ EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)} EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
class Events(Base): # type: ignore[misc,valid-type] class Events(Base):
"""Event history data.""" """Event history data."""
__table_args__ = ( __table_args__ = (
@@ -157,18 +161,32 @@ class Events(Base):  # type: ignore[misc,valid-type]
        {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
    )
    __tablename__ = TABLE_EVENTS
-    event_id = Column(Integer, Identity(), primary_key=True)
-    event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
-    event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
-    origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))  # no longer used for new rows
-    origin_idx = Column(SmallInteger)
-    time_fired = Column(DATETIME_TYPE)  # no longer used for new rows
-    time_fired_ts = Column(TIMESTAMP_TYPE, index=True)
-    context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
-    context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True)
-    event_data_rel = relationship("EventData")
+    event_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_type: Mapped[str | None] = mapped_column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
+    event_data: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql")
+    )
+    origin: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_ORIGIN)
+    )  # no longer used for new rows
+    origin_idx: Mapped[int | None] = mapped_column(SmallInteger)
+    time_fired: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    time_fired_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True)
+    context_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True
+    )
+    context_user_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    context_parent_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    data_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("event_data.data_id"), index=True
+    )
+    event_data_rel: Mapped[EventData | None] = relationship("EventData")
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -180,12 +198,15 @@ class Events(Base): # type: ignore[misc,valid-type]
) )
@property @property
def _time_fired_isotime(self) -> str: def _time_fired_isotime(self) -> str | None:
"""Return time_fired as an isotime string.""" """Return time_fired as an isotime string."""
date_time: datetime | None
if self.time_fired_ts is not None: if self.time_fired_ts is not None:
date_time = dt_util.utc_from_timestamp(self.time_fired_ts) date_time = dt_util.utc_from_timestamp(self.time_fired_ts)
else: else:
date_time = process_timestamp(self.time_fired) date_time = process_timestamp(self.time_fired)
if date_time is None:
return None
return date_time.isoformat(sep=" ", timespec="seconds") return date_time.isoformat(sep=" ", timespec="seconds")
@staticmethod @staticmethod
@ -211,12 +232,12 @@ class Events(Base): # type: ignore[misc,valid-type]
) )
try: try:
return Event( return Event(
self.event_type, self.event_type or "",
json_loads_object(self.event_data) if self.event_data else {}, json_loads_object(self.event_data) if self.event_data else {},
EventOrigin(self.origin) EventOrigin(self.origin)
if self.origin if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx], else EVENT_ORIGIN_ORDER[self.origin_idx or 0],
dt_util.utc_from_timestamp(self.time_fired_ts), dt_util.utc_from_timestamp(self.time_fired_ts or 0),
context=context, context=context,
) )
except JSON_DECODE_EXCEPTIONS: except JSON_DECODE_EXCEPTIONS:
@ -225,17 +246,19 @@ class Events(Base): # type: ignore[misc,valid-type]
return None return None
class EventData(Base): # type: ignore[misc,valid-type] class EventData(Base):
"""Event data history.""" """Event data history."""
__table_args__ = ( __table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
) )
__tablename__ = TABLE_EVENT_DATA __tablename__ = TABLE_EVENT_DATA
data_id = Column(Integer, Identity(), primary_key=True) data_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True) hash: Mapped[int | None] = mapped_column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table # Note that this is not named attributes to avoid confusion with the states table
shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) shared_data: Mapped[str | None] = mapped_column(
Text().with_variant(mysql.LONGTEXT, "mysql")
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -260,15 +283,18 @@ class EventData(Base): # type: ignore[misc,valid-type]
return cast(int, fnv1a_32(shared_data_bytes)) return cast(int, fnv1a_32(shared_data_bytes))
def to_native(self) -> dict[str, Any]: def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object.""" """Convert to an event data dictionary."""
shared_data = self.shared_data
if shared_data is None:
return {}
try: try:
return cast(dict[str, Any], json_loads(self.shared_data)) return cast(dict[str, Any], json_loads(shared_data))
except JSON_DECODE_EXCEPTIONS: except JSON_DECODE_EXCEPTIONS:
_LOGGER.exception("Error converting row to event data: %s", self) _LOGGER.exception("Error converting row to event data: %s", self)
return {} return {}
class States(Base): # type: ignore[misc,valid-type] class States(Base):
"""State change history.""" """State change history."""
__table_args__ = ( __table_args__ = (
@@ -278,29 +304,45 @@ class States(Base):  # type: ignore[misc,valid-type]
        {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
    )
    __tablename__ = TABLE_STATES
-    state_id = Column(Integer, Identity(), primary_key=True)
-    entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
-    state = Column(String(MAX_LENGTH_STATE_STATE))
-    attributes = Column(
-        Text().with_variant(mysql.LONGTEXT, "mysql")
-    )  # no longer used for new rows
-    event_id = Column(  # no longer used for new rows
-        Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
-    )
-    last_changed = Column(DATETIME_TYPE)  # no longer used for new rows
-    last_changed_ts = Column(TIMESTAMP_TYPE)
-    last_updated = Column(DATETIME_TYPE)  # no longer used for new rows
-    last_updated_ts = Column(TIMESTAMP_TYPE, default=time.time, index=True)
-    old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
-    attributes_id = Column(
-        Integer, ForeignKey("state_attributes.attributes_id"), index=True
-    )
-    context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
-    context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    origin_idx = Column(SmallInteger)  # 0 is local, 1 is remote
-    old_state = relationship("States", remote_side=[state_id])
-    state_attributes = relationship("StateAttributes")
+    state_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    entity_id: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_ENTITY_ID))
+    state: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_STATE))
+    attributes: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql")
+    )  # no longer used for new rows
+    event_id: Mapped[int | None] = mapped_column(  # no longer used for new rows
+        Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
+    )
+    last_changed: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    last_changed_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE)
+    last_updated: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    last_updated_ts: Mapped[float | None] = mapped_column(
+        TIMESTAMP_TYPE, default=time.time, index=True
+    )
+    old_state_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("states.state_id"), index=True
+    )
+    attributes_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("state_attributes.attributes_id"), index=True
+    )
+    context_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True
+    )
+    context_user_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    context_parent_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    origin_idx: Mapped[int | None] = mapped_column(
+        SmallInteger
+    )  # 0 is local, 1 is remote
+    old_state: Mapped[States | None] = relationship("States", remote_side=[state_id])
+    state_attributes: Mapped[StateAttributes | None] = relationship("StateAttributes")
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -312,12 +354,15 @@ class States(Base): # type: ignore[misc,valid-type]
) )
@property @property
def _last_updated_isotime(self) -> str: def _last_updated_isotime(self) -> str | None:
"""Return last_updated as an isotime string.""" """Return last_updated as an isotime string."""
date_time: datetime | None
if self.last_updated_ts is not None: if self.last_updated_ts is not None:
date_time = dt_util.utc_from_timestamp(self.last_updated_ts) date_time = dt_util.utc_from_timestamp(self.last_updated_ts)
else: else:
date_time = process_timestamp(self.last_updated) date_time = process_timestamp(self.last_updated)
if date_time is None:
return None
return date_time.isoformat(sep=" ", timespec="seconds") return date_time.isoformat(sep=" ", timespec="seconds")
@staticmethod @staticmethod
@ -372,8 +417,8 @@ class States(Base): # type: ignore[misc,valid-type]
last_updated = dt_util.utc_from_timestamp(self.last_updated_ts or 0) last_updated = dt_util.utc_from_timestamp(self.last_updated_ts or 0)
last_changed = dt_util.utc_from_timestamp(self.last_changed_ts or 0) last_changed = dt_util.utc_from_timestamp(self.last_changed_ts or 0)
return State( return State(
self.entity_id, self.entity_id or "",
self.state, self.state, # type: ignore[arg-type]
# Join the state_attributes table on attributes_id to get the attributes # Join the state_attributes table on attributes_id to get the attributes
# for newer states # for newer states
attrs, attrs,
@ -384,17 +429,19 @@ class States(Base): # type: ignore[misc,valid-type]
) )
class StateAttributes(Base): # type: ignore[misc,valid-type] class StateAttributes(Base):
"""State attribute change history.""" """State attribute change history."""
__table_args__ = ( __table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
) )
__tablename__ = TABLE_STATE_ATTRIBUTES __tablename__ = TABLE_STATE_ATTRIBUTES
attributes_id = Column(Integer, Identity(), primary_key=True) attributes_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True) hash: Mapped[int | None] = mapped_column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table # Note that this is not named attributes to avoid confusion with the states table
shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) shared_attrs: Mapped[str | None] = mapped_column(
Text().with_variant(mysql.LONGTEXT, "mysql")
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -439,9 +486,12 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
return cast(int, fnv1a_32(shared_attrs_bytes)) return cast(int, fnv1a_32(shared_attrs_bytes))
def to_native(self) -> dict[str, Any]: def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object.""" """Convert to a state attributes dictionary."""
shared_attrs = self.shared_attrs
if shared_attrs is None:
return {}
try: try:
return cast(dict[str, Any], json_loads(self.shared_attrs)) return cast(dict[str, Any], json_loads(shared_attrs))
except JSON_DECODE_EXCEPTIONS: except JSON_DECODE_EXCEPTIONS:
# When json_loads fails # When json_loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self) _LOGGER.exception("Error converting row to state attributes: %s", self)
@@ -451,25 +501,22 @@ class StateAttributes(Base):  # type: ignore[misc,valid-type]
class StatisticsBase:
    """Statistics base class."""

-    id = Column(Integer, Identity(), primary_key=True)
-    created = Column(DATETIME_TYPE, default=dt_util.utcnow)
-    @declared_attr  # type: ignore[misc]
-    def metadata_id(self) -> Column:
-        """Define the metadata_id column for sub classes."""
-        return Column(
-            Integer,
-            ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
-            index=True,
-        )
-    start = Column(DATETIME_TYPE, index=True)
-    mean = Column(DOUBLE_TYPE)
-    min = Column(DOUBLE_TYPE)
-    max = Column(DOUBLE_TYPE)
-    last_reset = Column(DATETIME_TYPE)
-    state = Column(DOUBLE_TYPE)
-    sum = Column(DOUBLE_TYPE)
+    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
+    metadata_id: Mapped[int | None] = mapped_column(
+        Integer,
+        ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
+        index=True,
+    )
+    start: Mapped[datetime | None] = mapped_column(DATETIME_TYPE, index=True)
+    mean: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    min: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    max: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    last_reset: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
+    state: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    sum: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    duration: timedelta
@classmethod @classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData) -> Self: def from_stats(cls, metadata_id: int, stats: StatisticData) -> Self:
@ -480,7 +527,7 @@ class StatisticsBase:
) )
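
In the 2.0 annotated style, mapped_column() attributes declared on a mixin (including ones carrying a ForeignKey) are copied onto each mapped subclass, so the @declared_attr indirection for metadata_id is no longer needed. A hedged sketch of such a mixin with invented table names:

from __future__ import annotations

from datetime import datetime

from sqlalchemy import DateTime, ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Declarative base for this sketch."""


class Owner(Base):
    """Invented parent table referenced by the mixin's ForeignKey."""

    __tablename__ = "owner"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)


class StampedMixin:
    """Mixin with a ForeignKey column; no @declared_attr needed with mapped_column()."""

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    owner_id: Mapped[int | None] = mapped_column(
        ForeignKey("owner.id", ondelete="CASCADE"), index=True
    )
    created: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))


class Record(Base, StampedMixin):
    """Concrete table combining the base and the mixin, like Statistics above."""

    __tablename__ = "record"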
class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type] class Statistics(Base, StatisticsBase):
"""Long term statistics.""" """Long term statistics."""
duration = timedelta(hours=1) duration = timedelta(hours=1)
@ -492,7 +539,7 @@ class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type]
__tablename__ = TABLE_STATISTICS __tablename__ = TABLE_STATISTICS
class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type] class StatisticsShortTerm(Base, StatisticsBase):
"""Short term statistics.""" """Short term statistics."""
duration = timedelta(minutes=5) duration = timedelta(minutes=5)
@@ -509,20 +556,22 @@ class StatisticsShortTerm(Base, StatisticsBase):  # type: ignore[misc,valid-type]
    __tablename__ = TABLE_STATISTICS_SHORT_TERM

-class StatisticsMeta(Base):  # type: ignore[misc,valid-type]
+class StatisticsMeta(Base):
    """Statistics meta data."""

    __table_args__ = (
        {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
    )
    __tablename__ = TABLE_STATISTICS_META
-    id = Column(Integer, Identity(), primary_key=True)
-    statistic_id = Column(String(255), index=True, unique=True)
-    source = Column(String(32))
-    unit_of_measurement = Column(String(255))
-    has_mean = Column(Boolean)
-    has_sum = Column(Boolean)
-    name = Column(String(255))
+    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    statistic_id: Mapped[str | None] = mapped_column(
+        String(255), index=True, unique=True
+    )
+    source: Mapped[str | None] = mapped_column(String(32))
+    unit_of_measurement: Mapped[str | None] = mapped_column(String(255))
+    has_mean: Mapped[bool | None] = mapped_column(Boolean)
+    has_sum: Mapped[bool | None] = mapped_column(Boolean)
+    name: Mapped[str | None] = mapped_column(String(255))

    @staticmethod
    def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
@ -530,16 +579,16 @@ class StatisticsMeta(Base): # type: ignore[misc,valid-type]
return StatisticsMeta(**meta) return StatisticsMeta(**meta)
class RecorderRuns(Base): # type: ignore[misc,valid-type] class RecorderRuns(Base):
"""Representation of recorder run.""" """Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),) __table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS __tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True) run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
start = Column(DATETIME_TYPE, default=dt_util.utcnow) start: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
end = Column(DATETIME_TYPE) end: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
closed_incorrect = Column(Boolean, default=False) closed_incorrect: Mapped[bool] = mapped_column(Boolean, default=False)
created = Column(DATETIME_TYPE, default=dt_util.utcnow) created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -563,9 +612,9 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
assert session is not None, "RecorderRuns need to be persisted" assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter( query: RowReturningQuery[tuple[str]] = session.query(distinct(States.entity_id))
States.last_updated >= self.start
) query = query.filter(States.last_updated >= self.start)
if point_in_time is not None: if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time) query = query.filter(States.last_updated < point_in_time)
@ -579,13 +628,13 @@ class RecorderRuns(Base): # type: ignore[misc,valid-type]
return self return self
class SchemaChanges(Base): # type: ignore[misc,valid-type] class SchemaChanges(Base):
"""Representation of schema version changes.""" """Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES __tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True) change_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer) schema_version: Mapped[int | None] = mapped_column(Integer)
changed = Column(DATETIME_TYPE, default=dt_util.utcnow) changed: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -597,12 +646,12 @@ class SchemaChanges(Base): # type: ignore[misc,valid-type]
) )
class StatisticsRuns(Base): # type: ignore[misc,valid-type] class StatisticsRuns(Base):
"""Representation of statistics run.""" """Representation of statistics run."""
__tablename__ = TABLE_STATISTICS_RUNS __tablename__ = TABLE_STATISTICS_RUNS
run_id = Column(Integer, Identity(), primary_key=True) run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
start = Column(DATETIME_TYPE, index=True) start: Mapped[datetime] = mapped_column(DATETIME_TYPE, index=True)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return string representation of instance for debugging.""" """Return string representation of instance for debugging."""
@ -626,7 +675,7 @@ OLD_FORMAT_ATTRS_JSON = type_coerce(
States.attributes.cast(JSON_VARIANT_CAST), JSON(none_as_null=True) States.attributes.cast(JSON_VARIANT_CAST), JSON(none_as_null=True)
) )
ENTITY_ID_IN_EVENT: Column = EVENT_DATA_JSON["entity_id"] ENTITY_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["entity_id"]
OLD_ENTITY_ID_IN_EVENT: Column = OLD_FORMAT_EVENT_DATA_JSON["entity_id"] OLD_ENTITY_ID_IN_EVENT: ColumnElement = OLD_FORMAT_EVENT_DATA_JSON["entity_id"]
DEVICE_ID_IN_EVENT: Column = EVENT_DATA_JSON["device_id"] DEVICE_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["device_id"]
OLD_STATE = aliased(States, name="old_state") OLD_STATE = aliased(States, name="old_state")


@ -6,7 +6,7 @@ import json
from typing import Any from typing import Any
from sqlalchemy import Column, Text, cast, not_, or_ from sqlalchemy import Column, Text, cast, not_, or_
from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.elements import ColumnElement
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS
@ -125,7 +125,7 @@ class Filters:
def _generate_filter_for_columns( def _generate_filter_for_columns(
self, columns: Iterable[Column], encoder: Callable[[Any], Any] self, columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ColumnElement | None:
"""Generate a filter from pre-comuted sets and pattern lists. """Generate a filter from pre-comuted sets and pattern lists.
This must match exactly how homeassistant.helpers.entityfilter works. This must match exactly how homeassistant.helpers.entityfilter works.
@ -174,6 +174,8 @@ class Filters:
if self.included_domains or self.included_entity_globs: if self.included_domains or self.included_entity_globs:
return or_( return or_(
i_entities, i_entities,
# https://github.com/sqlalchemy/sqlalchemy/issues/9190
# pylint: disable-next=invalid-unary-operand-type
(~e_entities & (i_entity_globs | (~e_entity_globs & i_domains))), (~e_entities & (i_entity_globs | (~e_entity_globs & i_domains))),
).self_group() ).self_group()
@ -184,23 +186,24 @@ class Filters:
# - Otherwise, entity matches domain exclude: exclude # - Otherwise, entity matches domain exclude: exclude
# - Otherwise: include # - Otherwise: include
if self.excluded_domains or self.excluded_entity_globs: if self.excluded_domains or self.excluded_entity_globs:
return (not_(or_(*excludes)) | i_entities).self_group() return (not_(or_(*excludes)) | i_entities).self_group() # type: ignore[no-any-return, no-untyped-call]
# Case 6 - No Domain and/or glob includes or excludes # Case 6 - No Domain and/or glob includes or excludes
# - Entity listed in entities include: include # - Entity listed in entities include: include
# - Otherwise: exclude # - Otherwise: exclude
return i_entities return i_entities
def states_entity_filter(self) -> ClauseList: def states_entity_filter(self) -> ColumnElement | None:
"""Generate the entity filter query.""" """Generate the entity filter query."""
def _encoder(data: Any) -> Any: def _encoder(data: Any) -> Any:
"""Nothing to encode for states since there is no json.""" """Nothing to encode for states since there is no json."""
return data return data
-        return self._generate_filter_for_columns((States.entity_id,), _encoder)
+        # The type annotation should be improved so the type ignore can be removed
+        return self._generate_filter_for_columns((States.entity_id,), _encoder)  # type: ignore[arg-type]

-    def events_entity_filter(self) -> ClauseList:
+    def events_entity_filter(self) -> ColumnElement:
        """Generate the entity filter query."""
        _encoder = json.dumps
        return or_(
@@ -215,15 +218,16 @@ class Filters:
            & (
                (OLD_ENTITY_ID_IN_EVENT == JSON_NULL) | OLD_ENTITY_ID_IN_EVENT.is_(None)
            ),
-            self._generate_filter_for_columns(
-                (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder
-            ).self_group(),
+            # Needs https://github.com/bdraco/home-assistant/commit/bba91945006a46f3a01870008eb048e4f9cbb1ef
+            self._generate_filter_for_columns(  # type: ignore[union-attr]
+                (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder  # type: ignore[arg-type]
+            ).self_group(),
        )
def _globs_to_like( def _globs_to_like(
glob_strs: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any] glob_strs: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ColumnElement:
"""Translate glob to sql.""" """Translate glob to sql."""
matchers = [ matchers = [
( (
@ -240,7 +244,7 @@ def _globs_to_like(
def _entity_matcher( def _entity_matcher(
entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any] entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ColumnElement:
matchers = [ matchers = [
( (
column.is_not(None) column.is_not(None)
@ -253,7 +257,7 @@ def _entity_matcher(
def _domain_matcher( def _domain_matcher(
domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any] domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ColumnElement:
matchers = [ matchers = [
(column.is_not(None) & cast(column, Text()).like(encoder(domain_matcher))) (column.is_not(None) & cast(column, Text()).like(encoder(domain_matcher)))
for domain_matcher in like_domain_matchers(domains) for domain_matcher in like_domain_matchers(domains)

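A hedged sketch of the ColumnElement-returning filter helpers in this file, using a stand-in states table and a toy glob-to-LIKE translation (the exact entityfilter semantics are not reproduced):

from sqlalchemy import Column, MetaData, String, Table, Text, cast, not_, or_
from sqlalchemy.sql.elements import ColumnElement

metadata = MetaData()
states = Table("states", metadata, Column("entity_id", String(255)))


def glob_filter(globs: list[str]) -> ColumnElement:
    # The helpers now return ColumnElement; ClauseList is no longer the
    # public annotation for composed boolean expressions in 2.0.
    matchers = [
        cast(states.c.entity_id, Text()).like(g.replace("*", "%").replace("?", "_"))
        for g in globs
    ]
    return or_(*matchers).self_group()


print(not_(glob_filter(["light.*", "switch.?"])))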
View file

@ -59,87 +59,87 @@ NEED_ATTRIBUTE_DOMAINS = {
} }
_BASE_STATES = [ _BASE_STATES = (
States.entity_id, States.entity_id,
States.state, States.state,
States.last_changed_ts, States.last_changed_ts,
States.last_updated_ts, States.last_updated_ts,
] )
_BASE_STATES_NO_LAST_CHANGED = [ _BASE_STATES_NO_LAST_CHANGED = ( # type: ignore[var-annotated]
States.entity_id, States.entity_id,
States.state, States.state,
literal(value=None).label("last_changed_ts"), literal(value=None).label("last_changed_ts"),
States.last_updated_ts, States.last_updated_ts,
] )
_QUERY_STATE_NO_ATTR = [ _QUERY_STATE_NO_ATTR = (
*_BASE_STATES, *_BASE_STATES,
literal(value=None, type_=Text).label("attributes"), literal(value=None, type_=Text).label("attributes"),
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = [ _QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = (
*_BASE_STATES_NO_LAST_CHANGED, *_BASE_STATES_NO_LAST_CHANGED,
literal(value=None, type_=Text).label("attributes"), literal(value=None, type_=Text).label("attributes"),
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
_BASE_STATES_PRE_SCHEMA_31 = [ _BASE_STATES_PRE_SCHEMA_31 = (
States.entity_id, States.entity_id,
States.state, States.state,
States.last_changed, States.last_changed,
States.last_updated, States.last_updated,
] )
_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = [ _BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
States.entity_id, States.entity_id,
States.state, States.state,
literal(value=None, type_=Text).label("last_changed"), literal(value=None, type_=Text).label("last_changed"),
States.last_updated, States.last_updated,
] )
_QUERY_STATE_NO_ATTR_PRE_SCHEMA_31 = [ _QUERY_STATE_NO_ATTR_PRE_SCHEMA_31 = (
*_BASE_STATES_PRE_SCHEMA_31, *_BASE_STATES_PRE_SCHEMA_31,
literal(value=None, type_=Text).label("attributes"), literal(value=None, type_=Text).label("attributes"),
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED_PRE_SCHEMA_31 = [ _QUERY_STATE_NO_ATTR_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
*_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31, *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
literal(value=None, type_=Text).label("attributes"), literal(value=None, type_=Text).label("attributes"),
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
# Remove QUERY_STATES_PRE_SCHEMA_25 # Remove QUERY_STATES_PRE_SCHEMA_25
# and the migration_in_progress check # and the migration_in_progress check
# once schema 26 is created # once schema 26 is created
_QUERY_STATES_PRE_SCHEMA_25 = [ _QUERY_STATES_PRE_SCHEMA_25 = (
*_BASE_STATES_PRE_SCHEMA_31, *_BASE_STATES_PRE_SCHEMA_31,
States.attributes, States.attributes,
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
_QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED = [ _QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED = (
*_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31, *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
States.attributes, States.attributes,
literal(value=None, type_=Text).label("shared_attrs"), literal(value=None, type_=Text).label("shared_attrs"),
] )
_QUERY_STATES_PRE_SCHEMA_31 = [ _QUERY_STATES_PRE_SCHEMA_31 = (
*_BASE_STATES_PRE_SCHEMA_31, *_BASE_STATES_PRE_SCHEMA_31,
# Remove States.attributes once all attributes are in StateAttributes.shared_attrs # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
States.attributes, States.attributes,
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] )
_QUERY_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = [ _QUERY_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
*_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31, *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
# Remove States.attributes once all attributes are in StateAttributes.shared_attrs # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
States.attributes, States.attributes,
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] )
_QUERY_STATES = [ _QUERY_STATES = (
*_BASE_STATES, *_BASE_STATES,
# Remove States.attributes once all attributes are in StateAttributes.shared_attrs # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
States.attributes, States.attributes,
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] )
_QUERY_STATES_NO_LAST_CHANGED = [ _QUERY_STATES_NO_LAST_CHANGED = (
*_BASE_STATES_NO_LAST_CHANGED, *_BASE_STATES_NO_LAST_CHANGED,
# Remove States.attributes once all attributes are in StateAttributes.shared_attrs # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
States.attributes, States.attributes,
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] )
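The column collections in this hunk become immutable tuples and are unpacked into select() with *, the 2.0-style replacement for passing a list. A minimal sketch with a simplified stand-in model (column types and names are illustrative only):

from sqlalchemy import Text, literal, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class StatesSketch(Base):
    __tablename__ = "states_sketch"
    state_id: Mapped[int] = mapped_column(primary_key=True)
    entity_id: Mapped[str]
    state: Mapped[str]


# Tuple of column expressions, expanded into select() where the old lists
# were used; tuples keep the collections immutable.
_QUERY_STATE_SKETCH = (
    StatesSketch.entity_id,
    StatesSketch.state,
    literal(value=None, type_=Text).label("attributes"),
)
print(select(*_QUERY_STATE_SKETCH))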
def _schema_version(hass: HomeAssistant) -> int: def _schema_version(hass: HomeAssistant) -> int:
@ -305,7 +305,10 @@ def _significant_states_stmt(
) )
if entity_ids: if entity_ids:
stmt += lambda q: q.filter(States.entity_id.in_(entity_ids)) stmt += lambda q: q.filter(
# https://github.com/python/mypy/issues/2608
States.entity_id.in_(entity_ids) # type:ignore[arg-type]
)
else: else:
stmt += _ignore_domains_filter stmt += _ignore_domains_filter
if filters and filters.has_config: if filters and filters.has_config:
@ -598,6 +601,8 @@ def _get_states_for_entites_stmt(
stmt += lambda q: q.where( stmt += lambda q: q.where(
States.state_id States.state_id
== ( == (
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
select(func.max(States.state_id).label("max_state_id")) select(func.max(States.state_id).label("max_state_id"))
.filter( .filter(
(States.last_updated_ts >= run_start_ts) (States.last_updated_ts >= run_start_ts)
@ -612,6 +617,8 @@ def _get_states_for_entites_stmt(
stmt += lambda q: q.where( stmt += lambda q: q.where(
States.state_id States.state_id
== ( == (
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
select(func.max(States.state_id).label("max_state_id")) select(func.max(States.state_id).label("max_state_id"))
.filter( .filter(
(States.last_updated >= run_start) (States.last_updated >= run_start)
@ -641,6 +648,8 @@ def _generate_most_recent_states_by_date(
return ( return (
select( select(
States.entity_id.label("max_entity_id"), States.entity_id.label("max_entity_id"),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(States.last_updated_ts).label("max_last_updated"), func.max(States.last_updated_ts).label("max_last_updated"),
) )
.filter( .filter(
@ -653,6 +662,8 @@ def _generate_most_recent_states_by_date(
return ( return (
select( select(
States.entity_id.label("max_entity_id"), States.entity_id.label("max_entity_id"),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(States.last_updated).label("max_last_updated"), func.max(States.last_updated).label("max_last_updated"),
) )
.filter( .filter(
@ -686,6 +697,8 @@ def _get_states_for_all_stmt(
stmt += lambda q: q.where( stmt += lambda q: q.where(
States.state_id States.state_id
== ( == (
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
select(func.max(States.state_id).label("max_state_id")) select(func.max(States.state_id).label("max_state_id"))
.join( .join(
most_recent_states_by_date, most_recent_states_by_date,
@ -703,6 +716,8 @@ def _get_states_for_all_stmt(
stmt += lambda q: q.where( stmt += lambda q: q.where(
States.state_id States.state_id
== ( == (
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
select(func.max(States.state_id).label("max_state_id")) select(func.max(States.state_id).label("max_state_id"))
.join( .join(
most_recent_states_by_date, most_recent_states_by_date,
@ -876,11 +891,11 @@ def _sorted_states_to_dict(
_LOGGER.debug("getting %d first datapoints took %fs", len(result), elapsed) _LOGGER.debug("getting %d first datapoints took %fs", len(result), elapsed)
if entity_ids and len(entity_ids) == 1: if entity_ids and len(entity_ids) == 1:
states_iter: Iterable[tuple[str | Column, Iterator[States]]] = ( states_iter: Iterable[tuple[str, Iterator[Row]]] = (
(entity_ids[0], iter(states)), (entity_ids[0], iter(states)),
) )
else: else:
states_iter = groupby(states, lambda state: state.entity_id) states_iter = groupby(states, lambda state: state.entity_id) # type: ignore[no-any-return]
# Append all changes to it # Append all changes to it
for ent_id, group in states_iter: for ent_id, group in states_iter:

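The retyped states_iter groups rows by attribute access, which SQLAlchemy 2.0 Row objects support natively. A small illustrative sketch using a namedtuple as a stand-in for Row (the entity ids and states are made up):

from collections import namedtuple
from itertools import groupby

FakeRow = namedtuple("FakeRow", ["entity_id", "state"])  # stand-in for a SQLAlchemy Row
rows = [
    FakeRow("light.kitchen", "on"),
    FakeRow("light.kitchen", "off"),
    FakeRow("sensor.energy", "5.0"),
]

# The multi-entity branch groups rows (already ordered by entity_id in SQL)
# by attribute access, mirroring the groupby fallback above.
for entity_id, group in groupby(rows, lambda row: row.entity_id):
    print(entity_id, [row.state for row in group])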
View file

@ -2,7 +2,7 @@
"domain": "recorder", "domain": "recorder",
"name": "Recorder", "name": "Recorder",
"documentation": "https://www.home-assistant.io/integrations/recorder", "documentation": "https://www.home-assistant.io/integrations/recorder",
"requirements": ["sqlalchemy==1.4.45", "fnvhash==0.1.0"], "requirements": ["sqlalchemy==2.0.2", "fnvhash==0.1.0"],
"codeowners": ["@home-assistant/core"], "codeowners": ["@home-assistant/core"],
"quality_scale": "internal", "quality_scale": "internal",
"iot_class": "local_push", "iot_class": "local_push",

View file

@ -6,7 +6,7 @@ import contextlib
from dataclasses import dataclass, replace as dataclass_replace from dataclasses import dataclass, replace as dataclass_replace
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
import sqlalchemy import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -96,7 +96,7 @@ def _schema_is_current(current_version: int) -> bool:
def validate_db_schema( def validate_db_schema(
hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session] hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
) -> SchemaValidationStatus | None: ) -> SchemaValidationStatus | None:
"""Check if the schema is valid. """Check if the schema is valid.
@ -113,7 +113,7 @@ def validate_db_schema(
if is_current := _schema_is_current(current_version): if is_current := _schema_is_current(current_version):
# We can only check for further errors if the schema is current, because # We can only check for further errors if the schema is current, because
# columns may otherwise not exist etc. # columns may otherwise not exist etc.
schema_errors |= statistics_validate_db_schema(hass, engine, session_maker) schema_errors |= statistics_validate_db_schema(hass, instance, session_maker)
valid = is_current and not schema_errors valid = is_current and not schema_errors
@ -444,17 +444,17 @@ def _update_states_table_with_foreign_key_options(
states_key_constraints = Base.metadata.tables[TABLE_STATES].foreign_key_constraints states_key_constraints = Base.metadata.tables[TABLE_STATES].foreign_key_constraints
old_states_table = Table( # noqa: F841 pylint: disable=unused-variable old_states_table = Table( # noqa: F841 pylint: disable=unused-variable
TABLE_STATES, MetaData(), *(alter["old_fk"] for alter in alters) TABLE_STATES, MetaData(), *(alter["old_fk"] for alter in alters) # type: ignore[arg-type]
) )
for alter in alters: for alter in alters:
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute(DropConstraint(alter["old_fk"])) connection.execute(DropConstraint(alter["old_fk"])) # type: ignore[no-untyped-call]
for fkc in states_key_constraints: for fkc in states_key_constraints:
if fkc.column_keys == alter["columns"]: if fkc.column_keys == alter["columns"]:
connection.execute(AddConstraint(fkc)) connection.execute(AddConstraint(fkc)) # type: ignore[no-untyped-call]
except (InternalError, OperationalError): except (InternalError, OperationalError):
_LOGGER.exception( _LOGGER.exception(
"Could not update foreign options in %s table", TABLE_STATES "Could not update foreign options in %s table", TABLE_STATES
@ -484,7 +484,7 @@ def _drop_foreign_key_constraints(
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
try: try:
connection = session.connection() connection = session.connection()
connection.execute(DropConstraint(drop)) connection.execute(DropConstraint(drop)) # type: ignore[no-untyped-call]
except (InternalError, OperationalError): except (InternalError, OperationalError):
_LOGGER.exception( _LOGGER.exception(
"Could not drop foreign constraints in %s table on %s", "Could not drop foreign constraints in %s table on %s",
@ -630,18 +630,21 @@ def _apply_update( # noqa: C901
# Order matters! Statistics and StatisticsShortTerm have a relation with # Order matters! Statistics and StatisticsShortTerm have a relation with
# StatisticsMeta, so statistics need to be deleted before meta (or in pair # StatisticsMeta, so statistics need to be deleted before meta (or in pair
# depending on the SQL backend); and meta needs to be created before statistics. # depending on the SQL backend); and meta needs to be created before statistics.
# We need to cast __table__ to Table, explanation in
# https://github.com/sqlalchemy/sqlalchemy/issues/9130
Base.metadata.drop_all( Base.metadata.drop_all(
bind=engine, bind=engine,
tables=[ tables=[
StatisticsShortTerm.__table__, cast(Table, StatisticsShortTerm.__table__),
Statistics.__table__, cast(Table, Statistics.__table__),
StatisticsMeta.__table__, cast(Table, StatisticsMeta.__table__),
], ],
) )
StatisticsMeta.__table__.create(engine) cast(Table, StatisticsMeta.__table__).create(engine)
StatisticsShortTerm.__table__.create(engine) cast(Table, StatisticsShortTerm.__table__).create(engine)
Statistics.__table__.create(engine) cast(Table, Statistics.__table__).create(engine)
elif new_version == 19: elif new_version == 19:
# This adds the statistic runs table, insert a fake run to prevent duplicating # This adds the statistic runs table, insert a fake run to prevent duplicating
# statistics. # statistics.
@ -694,20 +697,22 @@ def _apply_update( # noqa: C901
# so statistics need to be deleted before meta (or in pair depending # so statistics need to be deleted before meta (or in pair depending
# on the SQL backend); and meta needs to be created before statistics. # on the SQL backend); and meta needs to be created before statistics.
if engine.dialect.name == "oracle": if engine.dialect.name == "oracle":
# We need to cast __table__ to Table, explanation in
# https://github.com/sqlalchemy/sqlalchemy/issues/9130
Base.metadata.drop_all( Base.metadata.drop_all(
bind=engine, bind=engine,
tables=[ tables=[
StatisticsShortTerm.__table__, cast(Table, StatisticsShortTerm.__table__),
Statistics.__table__, cast(Table, Statistics.__table__),
StatisticsMeta.__table__, cast(Table, StatisticsMeta.__table__),
StatisticsRuns.__table__, cast(Table, StatisticsRuns.__table__),
], ],
) )
StatisticsRuns.__table__.create(engine) cast(Table, StatisticsRuns.__table__).create(engine)
StatisticsMeta.__table__.create(engine) cast(Table, StatisticsMeta.__table__).create(engine)
StatisticsShortTerm.__table__.create(engine) cast(Table, StatisticsShortTerm.__table__).create(engine)
Statistics.__table__.create(engine) cast(Table, Statistics.__table__).create(engine)
# Block 5-minute statistics for one hour from the last run, or it will overlap # Block 5-minute statistics for one hour from the last run, or it will overlap
# with existing hourly statistics. Don't block on a database with no existing # with existing hourly statistics. Don't block on a database with no existing
@ -715,6 +720,8 @@ def _apply_update( # noqa: C901
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
if session.query(Statistics.id).count() and ( if session.query(Statistics.id).count() and (
last_run_string := session.query( last_run_string := session.query(
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(StatisticsRuns.start) func.max(StatisticsRuns.start)
).scalar() ).scalar()
): ):
@ -996,7 +1003,7 @@ def _migrate_columns_to_timestamp(
) )
) )
result = None result = None
while result is None or result.rowcount > 0: while result is None or result.rowcount > 0: # type: ignore[unreachable]
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
result = session.connection().execute( result = session.connection().execute(
text( text(
@ -1027,7 +1034,7 @@ def _migrate_columns_to_timestamp(
) )
) )
result = None result = None
while result is None or result.rowcount > 0: while result is None or result.rowcount > 0: # type: ignore[unreachable]
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
result = session.connection().execute( result = session.connection().execute(
text( text(

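The __table__ casts in this file's hunks follow from SQLAlchemy 2.0 typing, where Model.__table__ is declared as a FromClause. A minimal sketch of the same pattern on a throwaway model (the class and table names are invented for illustration):

from typing import cast

from sqlalchemy import Table, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class StatisticsRunsSketch(Base):
    __tablename__ = "statistics_runs_sketch"
    run_id: Mapped[int] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
# Cast __table__ to Table before Table-only calls such as create() and
# drop_all(tables=...), mirroring the casts added in this migration module.
table = cast(Table, StatisticsRunsSketch.__table__)
Base.metadata.drop_all(bind=engine, tables=[table])
table.create(engine)
print(table.name, "created")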
View file

@ -5,7 +5,12 @@ import traceback
from typing import Any from typing import Any
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.pool import NullPool, SingletonThreadPool, StaticPool from sqlalchemy.pool import (
ConnectionPoolEntry,
NullPool,
SingletonThreadPool,
StaticPool,
)
from homeassistant.helpers.frame import report from homeassistant.helpers.frame import report
from homeassistant.util.async_ import check_loop from homeassistant.util.async_ import check_loop
@ -47,11 +52,10 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX) thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
) )
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy def _do_return_conn(self, record: ConnectionPoolEntry) -> Any:
def _do_return_conn(self, conn: Any) -> Any:
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super()._do_return_conn(conn) return super()._do_return_conn(record)
conn.close() record.close()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Close the connection.""" """Close the connection."""
@ -92,7 +96,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
return super(NullPool, self)._create_connection() return super(NullPool, self)._create_connection()
class MutexPool(StaticPool): # type: ignore[misc] class MutexPool(StaticPool):
"""A pool which prevents concurrent accesses from multiple threads. """A pool which prevents concurrent accesses from multiple threads.
This is used in tests to prevent unsafe concurrent accesses to in-memory SQLite This is used in tests to prevent unsafe concurrent accesses to in-memory SQLite
@ -102,14 +106,14 @@ class MutexPool(StaticPool): # type: ignore[misc]
_reference_counter = 0 _reference_counter = 0
pool_lock: threading.RLock pool_lock: threading.RLock
def _do_return_conn(self, conn: Any) -> None: def _do_return_conn(self, record: ConnectionPoolEntry) -> Any:
if DEBUG_MUTEX_POOL_TRACE: if DEBUG_MUTEX_POOL_TRACE:
trace = traceback.extract_stack() trace = traceback.extract_stack()
trace_msg = "\n" + "".join(traceback.format_list(trace[:-1])) trace_msg = "\n" + "".join(traceback.format_list(trace[:-1]))
else: else:
trace_msg = "" trace_msg = ""
super()._do_return_conn(conn) super()._do_return_conn(record)
if DEBUG_MUTEX_POOL: if DEBUG_MUTEX_POOL:
self._reference_counter -= 1 self._reference_counter -= 1
_LOGGER.debug( _LOGGER.debug(

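SQLAlchemy 2.0 exposes ConnectionPoolEntry publicly, which is what lets the formerly Any-typed _do_return_conn parameter be annotated precisely in this hunk. A toy sketch of a pool subclass using the same annotation (RecorderPool's real thread handling and SingletonThreadPool mixing are not reproduced):

import threading
from typing import Any

from sqlalchemy import create_engine, text
from sqlalchemy.pool import ConnectionPoolEntry, NullPool


class TracingPool(NullPool):
    """Toy pool showing the new ConnectionPoolEntry parameter type."""

    def _do_return_conn(self, record: ConnectionPoolEntry) -> Any:
        # 2.0 passes a ConnectionPoolEntry here, so the override can be
        # typed against the public name instead of Any.
        print("returning connection on", threading.current_thread().name)
        return super()._do_return_conn(record)


engine = create_engine("sqlite://", poolclass=TracingPool)
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))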
View file

@ -8,6 +8,7 @@ from itertools import islice, zip_longest
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from sqlalchemy.engine.row import Row
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import distinct from sqlalchemy.sql.expression import distinct
@ -616,9 +617,9 @@ def _purge_filtered_states(
database_engine: DatabaseEngine, database_engine: DatabaseEngine,
) -> None: ) -> None:
"""Remove filtered states and linked events.""" """Remove filtered states and linked events."""
state_ids: list[int] state_ids: tuple[int, ...]
attributes_ids: list[int] attributes_ids: tuple[int, ...]
event_ids: list[int] event_ids: tuple[int, ...]
state_ids, attributes_ids, event_ids = zip( state_ids, attributes_ids, event_ids = zip(
*( *(
session.query(States.state_id, States.attributes_id, States.event_id) session.query(States.state_id, States.attributes_id, States.event_id)
@ -627,12 +628,12 @@ def _purge_filtered_states(
.all() .all()
) )
) )
event_ids = [id_ for id_ in event_ids if id_ is not None] filtered_event_ids = [id_ for id_ in event_ids if id_ is not None]
_LOGGER.debug( _LOGGER.debug(
"Selected %s state_ids to remove that should be filtered", len(state_ids) "Selected %s state_ids to remove that should be filtered", len(state_ids)
) )
_purge_state_ids(instance, session, set(state_ids)) _purge_state_ids(instance, session, set(state_ids))
_purge_event_ids(session, event_ids) _purge_event_ids(session, filtered_event_ids)
unused_attribute_ids_set = _select_unused_attributes_ids( unused_attribute_ids_set = _select_unused_attributes_ids(
session, {id_ for id_ in attributes_ids if id_ is not None}, database_engine session, {id_ for id_ in attributes_ids if id_ is not None}, database_engine
) )
@ -656,7 +657,7 @@ def _purge_filtered_events(
_LOGGER.debug( _LOGGER.debug(
"Selected %s event_ids to remove that should be filtered", len(event_ids) "Selected %s event_ids to remove that should be filtered", len(event_ids)
) )
states: list[States] = ( states: list[Row[tuple[int]]] = (
session.query(States.state_id).filter(States.event_id.in_(event_ids)).all() session.query(States.state_id).filter(States.event_id.in_(event_ids)).all()
) )
state_ids: set[int] = {state.state_id for state in states} state_ids: set[int] = {state.state_id for state in states}

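The list-to-tuple retyping in this purge hunk falls out of how zip(*rows) transposes the selected rows. A tiny sketch with made-up ids:

rows = [(1, 10, None), (2, 11, 7), (3, None, 8)]  # (state_id, attributes_id, event_id)

# zip(*rows) yields tuples, which is why state_ids/attributes_ids/event_ids
# are annotated as tuple[int, ...] and the filtered ids get a new name
# instead of reassigning the tuple.
state_ids, attributes_ids, event_ids = zip(*rows)
filtered_event_ids = [id_ for id_ in event_ids if id_ is not None]
print(state_ids, filtered_event_ids)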
View file

@ -42,6 +42,8 @@ def find_shared_data_id(attr_hash: int, shared_data: str) -> StatementLambdaElem
def _state_attrs_exist(attr: int | None) -> Select: def _state_attrs_exist(attr: int | None) -> Select:
"""Check if a state attributes id exists in the states table.""" """Check if a state attributes id exists in the states table."""
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
return select(func.min(States.attributes_id)).where(States.attributes_id == attr) return select(func.min(States.attributes_id)).where(States.attributes_id == attr)
@ -279,6 +281,8 @@ def data_ids_exist_in_events_with_fast_in_distinct(
def _event_data_id_exist(data_id: int | None) -> Select: def _event_data_id_exist(data_id: int | None) -> Select:
"""Check if a event data id exists in the events table.""" """Check if a event data id exists in the events table."""
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
return select(func.min(Events.data_id)).where(Events.data_id == data_id) return select(func.min(Events.data_id)).where(Events.data_id == data_id)
@ -620,6 +624,8 @@ def find_statistics_runs_to_purge(
def find_latest_statistics_runs_run_id() -> StatementLambdaElement: def find_latest_statistics_runs_run_id() -> StatementLambdaElement:
"""Find the latest statistics_runs run_id.""" """Find the latest statistics_runs run_id."""
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
return lambda_stmt(lambda: select(func.max(StatisticsRuns.run_id))) return lambda_stmt(lambda: select(func.max(StatisticsRuns.run_id)))
@ -639,4 +645,6 @@ def find_legacy_event_state_and_attributes_and_data_ids_to_purge(
def find_legacy_row() -> StatementLambdaElement: def find_legacy_row() -> StatementLambdaElement:
"""Check if there are still states in the table with an event_id.""" """Check if there are still states in the table with an event_id."""
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
return lambda_stmt(lambda: select(func.max(States.event_id))) return lambda_stmt(lambda: select(func.max(States.event_id)))

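The pylint disables added throughout these query helpers work around a not-callable false positive on dynamically generated func.* members under SQLAlchemy 2.0 (the issue linked in the comments); the calls themselves are valid. A minimal runnable sketch with an invented stand-in model:

from sqlalchemy import create_engine, func, lambda_stmt, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class StatesSketch(Base):
    __tablename__ = "states_sketch"
    state_id: Mapped[int] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# func.max is generated dynamically, which trips pylint's not-callable
# check; the statement builds and executes normally.
stmt = lambda_stmt(lambda: select(func.max(StatesSketch.state_id)))
with Session(engine) as session:
    print(session.execute(stmt).scalar())  # None on an empty table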
View file

@ -122,7 +122,8 @@ class RunHistory:
for run in session.query(RecorderRuns).order_by(RecorderRuns.start.asc()).all(): for run in session.query(RecorderRuns).order_by(RecorderRuns.start.asc()).all():
session.expunge(run) session.expunge(run)
if run_dt := process_timestamp(run.start): if run_dt := process_timestamp(run.start):
timestamp = run_dt.timestamp() # Not sure if this is correct or if the runs_by_timestamp annotation should be changed
timestamp = int(run_dt.timestamp())
run_timestamps.append(timestamp) run_timestamps.append(timestamp)
runs_by_timestamp[timestamp] = run runs_by_timestamp[timestamp] = run

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping, Sequence
import contextlib import contextlib
import dataclasses import dataclasses
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -13,7 +13,7 @@ import logging
import os import os
import re import re
from statistics import mean from statistics import mean
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import bindparam, func, lambda_stmt, select, text from sqlalchemy import bindparam, func, lambda_stmt, select, text
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -77,7 +77,7 @@ from .util import (
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Recorder from . import Recorder
QUERY_STATISTICS = [ QUERY_STATISTICS = (
Statistics.metadata_id, Statistics.metadata_id,
Statistics.start, Statistics.start,
Statistics.mean, Statistics.mean,
@ -86,9 +86,9 @@ QUERY_STATISTICS = [
Statistics.last_reset, Statistics.last_reset,
Statistics.state, Statistics.state,
Statistics.sum, Statistics.sum,
] )
QUERY_STATISTICS_SHORT_TERM = [ QUERY_STATISTICS_SHORT_TERM = (
StatisticsShortTerm.metadata_id, StatisticsShortTerm.metadata_id,
StatisticsShortTerm.start, StatisticsShortTerm.start,
StatisticsShortTerm.mean, StatisticsShortTerm.mean,
@ -97,30 +97,34 @@ QUERY_STATISTICS_SHORT_TERM = [
StatisticsShortTerm.last_reset, StatisticsShortTerm.last_reset,
StatisticsShortTerm.state, StatisticsShortTerm.state,
StatisticsShortTerm.sum, StatisticsShortTerm.sum,
] )
QUERY_STATISTICS_SUMMARY_MEAN = [ QUERY_STATISTICS_SUMMARY_MEAN = (
StatisticsShortTerm.metadata_id, StatisticsShortTerm.metadata_id,
func.avg(StatisticsShortTerm.mean), func.avg(StatisticsShortTerm.mean),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.min(StatisticsShortTerm.min), func.min(StatisticsShortTerm.min),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(StatisticsShortTerm.max), func.max(StatisticsShortTerm.max),
] )
QUERY_STATISTICS_SUMMARY_SUM = [ QUERY_STATISTICS_SUMMARY_SUM = (
StatisticsShortTerm.metadata_id, StatisticsShortTerm.metadata_id,
StatisticsShortTerm.start, StatisticsShortTerm.start,
StatisticsShortTerm.last_reset, StatisticsShortTerm.last_reset,
StatisticsShortTerm.state, StatisticsShortTerm.state,
StatisticsShortTerm.sum, StatisticsShortTerm.sum,
func.row_number() func.row_number()
.over( .over( # type: ignore[no-untyped-call]
partition_by=StatisticsShortTerm.metadata_id, partition_by=StatisticsShortTerm.metadata_id,
order_by=StatisticsShortTerm.start.desc(), order_by=StatisticsShortTerm.start.desc(),
) )
.label("rownum"), .label("rownum"),
] )
QUERY_STATISTIC_META = [ QUERY_STATISTIC_META = (
StatisticsMeta.id, StatisticsMeta.id,
StatisticsMeta.statistic_id, StatisticsMeta.statistic_id,
StatisticsMeta.source, StatisticsMeta.source,
@ -128,7 +132,7 @@ QUERY_STATISTIC_META = [
StatisticsMeta.has_mean, StatisticsMeta.has_mean,
StatisticsMeta.has_sum, StatisticsMeta.has_sum,
StatisticsMeta.name, StatisticsMeta.name,
] )
STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = { STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
@ -372,7 +376,7 @@ def _update_or_add_metadata(
statistic_id, statistic_id,
new_metadata, new_metadata,
) )
return meta.id # type: ignore[no-any-return] return meta.id
metadata_id, old_metadata = old_metadata_dict[statistic_id] metadata_id, old_metadata = old_metadata_dict[statistic_id]
if ( if (
@ -401,7 +405,7 @@ def _update_or_add_metadata(
def _find_duplicates( def _find_duplicates(
session: Session, table: type[Statistics | StatisticsShortTerm] session: Session, table: type[StatisticsBase]
) -> tuple[list[int], list[dict]]: ) -> tuple[list[int], list[dict]]:
"""Find duplicated statistics.""" """Find duplicated statistics."""
subquery = ( subquery = (
@ -411,6 +415,8 @@ def _find_duplicates(
literal_column("1").label("is_duplicate"), literal_column("1").label("is_duplicate"),
) )
.group_by(table.metadata_id, table.start) .group_by(table.metadata_id, table.start)
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
.having(func.count() > 1) .having(func.count() > 1)
.subquery() .subquery()
) )
@ -435,7 +441,7 @@ def _find_duplicates(
if not duplicates: if not duplicates:
return (duplicate_ids, non_identical_duplicates_as_dict) return (duplicate_ids, non_identical_duplicates_as_dict)
def columns_to_dict(duplicate: type[Statistics | StatisticsShortTerm]) -> dict: def columns_to_dict(duplicate: Row) -> dict:
"""Convert a SQLAlchemy row to dict.""" """Convert a SQLAlchemy row to dict."""
dict_ = {} dict_ = {}
for key in duplicate.__mapper__.c.keys(): for key in duplicate.__mapper__.c.keys():
@ -466,7 +472,7 @@ def _find_duplicates(
def _delete_duplicates_from_table( def _delete_duplicates_from_table(
session: Session, table: type[Statistics | StatisticsShortTerm] session: Session, table: type[StatisticsBase]
) -> tuple[int, list[dict]]: ) -> tuple[int, list[dict]]:
"""Identify and delete duplicated statistics from a specified table.""" """Identify and delete duplicated statistics from a specified table."""
all_non_identical_duplicates: list[dict] = [] all_non_identical_duplicates: list[dict] = []
@ -542,6 +548,8 @@ def _find_statistics_meta_duplicates(session: Session) -> list[int]:
literal_column("1").label("is_duplicate"), literal_column("1").label("is_duplicate"),
) )
.group_by(StatisticsMeta.statistic_id) .group_by(StatisticsMeta.statistic_id)
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
.having(func.count() > 1) .having(func.count() > 1)
.subquery() .subquery()
) )
@ -672,8 +680,8 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
} }
# Insert compiled hourly statistics in the database # Insert compiled hourly statistics in the database
for metadata_id, stat in summary.items(): for metadata_id, summary_item in summary.items():
session.add(Statistics.from_stats(metadata_id, stat)) session.add(Statistics.from_stats(metadata_id, summary_item))
@retryable_database_job("statistics") @retryable_database_job("statistics")
@ -743,7 +751,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
def _adjust_sum_statistics( def _adjust_sum_statistics(
session: Session, session: Session,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
start_time: datetime, start_time: datetime,
adj: float, adj: float,
@ -767,7 +775,7 @@ def _adjust_sum_statistics(
def _insert_statistics( def _insert_statistics(
session: Session, session: Session,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
statistic: StatisticData, statistic: StatisticData,
) -> None: ) -> None:
@ -784,7 +792,7 @@ def _insert_statistics(
def _update_statistics( def _update_statistics(
session: Session, session: Session,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
stat_id: int, stat_id: int,
statistic: StatisticData, statistic: StatisticData,
) -> None: ) -> None:
@ -816,8 +824,11 @@ def _generate_get_metadata_stmt(
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Generate a statement to fetch metadata.""" """Generate a statement to fetch metadata."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META)) stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
if statistic_ids is not None: if statistic_ids:
stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids)) stmt += lambda q: q.where(
# https://github.com/python/mypy/issues/2608
StatisticsMeta.statistic_id.in_(statistic_ids) # type:ignore[arg-type]
)
if statistic_source is not None: if statistic_source is not None:
stmt += lambda q: q.where(StatisticsMeta.source == statistic_source) stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
if statistic_type == "mean": if statistic_type == "mean":
@ -849,15 +860,15 @@ def get_metadata_with_session(
return {} return {}
return { return {
meta["statistic_id"]: ( meta.statistic_id: (
meta["id"], meta.id,
{ {
"has_mean": meta["has_mean"], "has_mean": meta.has_mean,
"has_sum": meta["has_sum"], "has_sum": meta.has_sum,
"name": meta["name"], "name": meta.name,
"source": meta["source"], "source": meta.source,
"statistic_id": meta["statistic_id"], "statistic_id": meta.statistic_id,
"unit_of_measurement": meta["unit_of_measurement"], "unit_of_measurement": meta.unit_of_measurement,
}, },
) )
for meta in result for meta in result
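The switch from meta["statistic_id"] to meta.statistic_id reflects how SQLAlchemy 2.0 rows behave as named tuples: attribute access replaces dict-style indexing, and mapping access remains available via ._mapping. A small sketch with made-up column values:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    row = conn.execute(
        text("SELECT 1 AS id, 'sensor.energy' AS statistic_id")
    ).one()

# Attribute access on the row, plus ._mapping when dict-style access is needed.
print(row.id, row.statistic_id)
print(row._mapping["statistic_id"])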
@ -1132,7 +1143,7 @@ def _statistics_during_period_stmt(
start_time: datetime, start_time: datetime,
end_time: datetime | None, end_time: datetime | None,
metadata_ids: list[int] | None, metadata_ids: list[int] | None,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> StatementLambdaElement: ) -> StatementLambdaElement:
"""Prepare a database query for statistics during a given period. """Prepare a database query for statistics during a given period.
@ -1140,25 +1151,28 @@ def _statistics_during_period_stmt(
This prepares a lambda_stmt query, so we don't insert the parameters yet. This prepares a lambda_stmt query, so we don't insert the parameters yet.
""" """
columns = [table.metadata_id, table.start] columns = select(table.metadata_id, table.start)
if "last_reset" in types: if "last_reset" in types:
columns.append(table.last_reset) columns = columns.add_columns(table.last_reset)
if "max" in types: if "max" in types:
columns.append(table.max) columns = columns.add_columns(table.max)
if "mean" in types: if "mean" in types:
columns.append(table.mean) columns = columns.add_columns(table.mean)
if "min" in types: if "min" in types:
columns.append(table.min) columns = columns.add_columns(table.min)
if "state" in types: if "state" in types:
columns.append(table.state) columns = columns.add_columns(table.state)
if "sum" in types: if "sum" in types:
columns.append(table.sum) columns = columns.add_columns(table.sum)
stmt = lambda_stmt(lambda: select(columns).filter(table.start >= start_time)) stmt = lambda_stmt(lambda: columns.filter(table.start >= start_time))
if end_time is not None: if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time) stmt += lambda q: q.filter(table.start < end_time)
if metadata_ids: if metadata_ids:
stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids)) stmt += lambda q: q.filter(
# https://github.com/python/mypy/issues/2608
table.metadata_id.in_(metadata_ids) # type:ignore[arg-type]
)
stmt += lambda q: q.order_by(table.metadata_id, table.start) stmt += lambda q: q.order_by(table.metadata_id, table.start)
return stmt return stmt
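The list-of-columns pattern is gone because select() in 2.0 no longer accepts a plain Python list; this hunk starts from a Select and extends it with add_columns() instead of appending to a list. A hedged sketch with an invented, simplified model (column types here are not the real schema):

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class StatisticsSketch(Base):
    __tablename__ = "statistics_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    metadata_id: Mapped[int]
    start: Mapped[float]
    mean: Mapped[float]


# Build the base Select, then add optional columns with add_columns(),
# mirroring the conditional column handling above.
columns = select(StatisticsSketch.metadata_id, StatisticsSketch.start)
columns = columns.add_columns(StatisticsSketch.mean)
stmt = columns.where(StatisticsSketch.metadata_id.in_([1, 2]))
print(stmt.order_by(StatisticsSketch.metadata_id, StatisticsSketch.start))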
@ -1168,34 +1182,43 @@ def _get_max_mean_min_statistic_in_sub_period(
result: dict[str, float], result: dict[str, float],
start_time: datetime | None, start_time: datetime | None,
end_time: datetime | None, end_time: datetime | None,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
types: set[Literal["max", "mean", "min", "change"]], types: set[Literal["max", "mean", "min", "change"]],
metadata_id: int, metadata_id: int,
) -> None: ) -> None:
"""Return max, mean and min during the period.""" """Return max, mean and min during the period."""
# Calculate max, mean, min # Calculate max, mean, min
columns = [] columns = select()
if "max" in types: if "max" in types:
columns.append(func.max(table.max)) # https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
columns = columns.add_columns(func.max(table.max))
if "mean" in types: if "mean" in types:
columns.append(func.avg(table.mean)) columns = columns.add_columns(func.avg(table.mean))
columns.append(func.count(table.mean)) # https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
columns = columns.add_columns(func.count(table.mean))
if "min" in types: if "min" in types:
columns.append(func.min(table.min)) # https://github.com/sqlalchemy/sqlalchemy/issues/9189
stmt = lambda_stmt(lambda: select(columns).filter(table.metadata_id == metadata_id)) # pylint: disable-next=not-callable
columns = columns.add_columns(func.min(table.min))
stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
if start_time is not None: if start_time is not None:
stmt += lambda q: q.filter(table.start >= start_time) stmt += lambda q: q.filter(table.start >= start_time)
if end_time is not None: if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time) stmt += lambda q: q.filter(table.start < end_time)
stats = execute_stmt_lambda_element(session, stmt) stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
if "max" in types and stats and (new_max := stats[0].max) is not None: if not stats:
return
if "max" in types and (new_max := stats[0].max) is not None:
old_max = result.get("max") old_max = result.get("max")
result["max"] = max(new_max, old_max) if old_max is not None else new_max result["max"] = max(new_max, old_max) if old_max is not None else new_max
if "mean" in types and stats and stats[0].avg is not None: if "mean" in types and stats[0].avg is not None:
duration = stats[0].count * table.duration.total_seconds() # https://github.com/sqlalchemy/sqlalchemy/issues/9127
duration = stats[0].count * table.duration.total_seconds() # type: ignore[operator]
result["duration"] = result.get("duration", 0.0) + duration result["duration"] = result.get("duration", 0.0) + duration
result["mean_acc"] = result.get("mean_acc", 0.0) + stats[0].avg * duration result["mean_acc"] = result.get("mean_acc", 0.0) + stats[0].avg * duration
if "min" in types and stats and (new_min := stats[0].min) is not None: if "min" in types and (new_min := stats[0].min) is not None:
old_min = result.get("min") old_min = result.get("min")
result["min"] = min(new_min, old_min) if old_min is not None else new_min result["min"] = min(new_min, old_min) if old_min is not None else new_min
@ -1268,7 +1291,7 @@ def _get_max_mean_min_statistic(
def _first_statistic( def _first_statistic(
session: Session, session: Session,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
) -> datetime | None: ) -> datetime | None:
"""Return the data of the oldest statistic row for a given metadata id.""" """Return the data of the oldest statistic row for a given metadata id."""
@ -1278,9 +1301,8 @@ def _first_statistic(
.order_by(table.start.asc()) .order_by(table.start.asc())
.limit(1) .limit(1)
) )
if stats := execute_stmt_lambda_element(session, stmt): stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
return process_timestamp(stats[0].start) # type: ignore[no-any-return] return process_timestamp(stats[0].start) if stats else None
return None
def _get_oldest_sum_statistic( def _get_oldest_sum_statistic(
@ -1297,7 +1319,7 @@ def _get_oldest_sum_statistic(
def _get_oldest_sum_statistic_in_sub_period( def _get_oldest_sum_statistic_in_sub_period(
session: Session, session: Session,
start_time: datetime | None, start_time: datetime | None,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
) -> float | None: ) -> float | None:
"""Return the oldest non-NULL sum during the period.""" """Return the oldest non-NULL sum during the period."""
@ -1317,7 +1339,7 @@ def _get_oldest_sum_statistic(
period = start_time.replace(minute=0, second=0, microsecond=0) period = start_time.replace(minute=0, second=0, microsecond=0)
prev_period = period - table.duration prev_period = period - table.duration
stmt += lambda q: q.filter(table.start >= prev_period) stmt += lambda q: q.filter(table.start >= prev_period)
stats = execute_stmt_lambda_element(session, stmt) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
return stats[0].sum if stats else None return stats[0].sum if stats else None
oldest_sum: float | None = None oldest_sum: float | None = None
@ -1380,7 +1402,7 @@ def _get_newest_sum_statistic(
session: Session, session: Session,
start_time: datetime | None, start_time: datetime | None,
end_time: datetime | None, end_time: datetime | None,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
) -> float | None: ) -> float | None:
"""Return the newest non-NULL sum during the period.""" """Return the newest non-NULL sum during the period."""
@ -1397,7 +1419,7 @@ def _get_newest_sum_statistic(
stmt += lambda q: q.filter(table.start >= start_time) stmt += lambda q: q.filter(table.start >= start_time)
if end_time is not None: if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time) stmt += lambda q: q.filter(table.start < end_time)
stats = execute_stmt_lambda_element(session, stmt) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
return stats[0].sum if stats else None return stats[0].sum if stats else None
@ -1700,7 +1722,7 @@ def _get_last_statistics(
number_of_stats: int, number_of_stats: int,
statistic_id: str, statistic_id: str,
convert_units: bool, convert_units: bool,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]: ) -> dict[str, list[dict]]:
"""Return the last number_of_stats statistics for a given statistic_id.""" """Return the last number_of_stats statistics for a given statistic_id."""
@ -1766,6 +1788,8 @@ def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
return ( return (
select( select(
StatisticsShortTerm.metadata_id, StatisticsShortTerm.metadata_id,
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(StatisticsShortTerm.start).label("start_max"), func.max(StatisticsShortTerm.start).label("start_max"),
) )
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
@ -1831,28 +1855,30 @@ def get_latest_short_term_statistics(
def _statistics_at_time( def _statistics_at_time(
session: Session, session: Session,
metadata_ids: set[int], metadata_ids: set[int],
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
start_time: datetime, start_time: datetime,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> list | None: ) -> Sequence[Row] | None:
"""Return last known statistics, earlier than start_time, for the metadata_ids.""" """Return last known statistics, earlier than start_time, for the metadata_ids."""
columns = [table.metadata_id, table.start] columns = select(table.metadata_id, table.start)
if "last_reset" in types: if "last_reset" in types:
columns.append(table.last_reset) columns = columns.add_columns(table.last_reset)
if "max" in types: if "max" in types:
columns.append(table.max) columns = columns.add_columns(table.max)
if "mean" in types: if "mean" in types:
columns.append(table.mean) columns = columns.add_columns(table.mean)
if "min" in types: if "min" in types:
columns.append(table.min) columns = columns.add_columns(table.min)
if "state" in types: if "state" in types:
columns.append(table.state) columns = columns.add_columns(table.state)
if "sum" in types: if "sum" in types:
columns.append(table.sum) columns = columns.add_columns(table.sum)
stmt = lambda_stmt(lambda: select(columns)) stmt = lambda_stmt(lambda: columns)
most_recent_statistic_ids = ( most_recent_statistic_ids = (
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
lambda_stmt(lambda: select(func.max(table.id).label("max_id"))) lambda_stmt(lambda: select(func.max(table.id).label("max_id")))
.filter(table.start < start_time) .filter(table.start < start_time)
.filter(table.metadata_id.in_(metadata_ids)) .filter(table.metadata_id.in_(metadata_ids))
@ -1864,7 +1890,7 @@ def _statistics_at_time(
most_recent_statistic_ids, most_recent_statistic_ids,
table.id == most_recent_statistic_ids.c.max_id, table.id == most_recent_statistic_ids.c.max_id,
) )
return execute_stmt_lambda_element(session, stmt) return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
def _sorted_statistics_to_dict( def _sorted_statistics_to_dict(
@ -1874,7 +1900,7 @@ def _sorted_statistics_to_dict(
statistic_ids: list[str] | None, statistic_ids: list[str] | None,
_metadata: dict[str, tuple[int, StatisticMetaData]], _metadata: dict[str, tuple[int, StatisticMetaData]],
convert_units: bool, convert_units: bool,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
start_time: datetime | None, start_time: datetime | None,
units: dict[str, str] | None, units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@ -1965,7 +1991,7 @@ def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]]
def _statistics_exists( def _statistics_exists(
session: Session, session: Session,
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
metadata_id: int, metadata_id: int,
start: datetime, start: datetime,
) -> int | None: ) -> int | None:
@ -1975,7 +2001,7 @@ def _statistics_exists(
.filter((table.metadata_id == metadata_id) & (table.start == start)) .filter((table.metadata_id == metadata_id) & (table.start == start))
.first() .first()
) )
return result["id"] if result else None return result.id if result else None
@callback @callback
@ -2067,11 +2093,16 @@ def _filter_unique_constraint_integrity_error(
ignore = True ignore = True
if ( if (
dialect_name == SupportedDialect.POSTGRESQL dialect_name == SupportedDialect.POSTGRESQL
and err.orig
and hasattr(err.orig, "pgcode") and hasattr(err.orig, "pgcode")
and err.orig.pgcode == "23505" and err.orig.pgcode == "23505"
): ):
ignore = True ignore = True
if dialect_name == SupportedDialect.MYSQL and hasattr(err.orig, "args"): if (
dialect_name == SupportedDialect.MYSQL
and err.orig
and hasattr(err.orig, "args")
):
with contextlib.suppress(TypeError): with contextlib.suppress(TypeError):
if err.orig.args[0] == 1062: if err.orig.args[0] == 1062:
ignore = True ignore = True
@ -2095,7 +2126,7 @@ def _import_statistics_with_session(
session: Session, session: Session,
metadata: StatisticMetaData, metadata: StatisticMetaData,
statistics: Iterable[StatisticData], statistics: Iterable[StatisticData],
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
) -> bool: ) -> bool:
"""Import statistics to the database.""" """Import statistics to the database."""
old_metadata_dict = get_metadata_with_session( old_metadata_dict = get_metadata_with_session(
@ -2116,7 +2147,7 @@ def import_statistics(
instance: Recorder, instance: Recorder,
metadata: StatisticMetaData, metadata: StatisticMetaData,
statistics: Iterable[StatisticData], statistics: Iterable[StatisticData],
table: type[Statistics | StatisticsShortTerm], table: type[StatisticsBase],
) -> bool: ) -> bool:
"""Process an import_statistics job.""" """Process an import_statistics job."""
@ -2174,7 +2205,7 @@ def _change_statistics_unit_for_table(
convert: Callable[[float | None], float | None], convert: Callable[[float | None], float | None],
) -> None: ) -> None:
"""Insert statistics in the database.""" """Insert statistics in the database."""
columns = [table.id, table.mean, table.min, table.max, table.state, table.sum] columns = (table.id, table.mean, table.min, table.max, table.state, table.sum)
query = session.query(*columns).filter_by(metadata_id=bindparam("metadata_id")) query = session.query(*columns).filter_by(metadata_id=bindparam("metadata_id"))
rows = execute(query.params(metadata_id=metadata_id)) rows = execute(query.params(metadata_id=metadata_id))
for row in rows: for row in rows:
@ -2215,7 +2246,11 @@ def change_statistics_unit(
metadata_id = metadata[0] metadata_id = metadata[0]
convert = _get_unit_converter(old_unit, new_unit) convert = _get_unit_converter(old_unit, new_unit)
for table in (StatisticsShortTerm, Statistics): tables: tuple[type[StatisticsBase], ...] = (
Statistics,
StatisticsShortTerm,
)
for table in tables:
_change_statistics_unit_for_table(session, table, metadata_id, convert) _change_statistics_unit_for_table(session, table, metadata_id, convert)
session.query(StatisticsMeta).filter( session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id StatisticsMeta.statistic_id == statistic_id

View file

@ -14,7 +14,7 @@ def db_size_bytes(session: Session, database_name: str) -> float | None:
"TABLE_SCHEMA=:database_name" "TABLE_SCHEMA=:database_name"
), ),
{"database_name": database_name}, {"database_name": database_name},
).first()[0] ).scalar()
if size is None: if size is None:
return None return None

View file

@ -5,11 +5,14 @@ from sqlalchemy import text
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
def db_size_bytes(session: Session, database_name: str) -> float: def db_size_bytes(session: Session, database_name: str) -> float | None:
"""Get the mysql database size.""" """Get the mysql database size."""
return float( size = session.execute(
session.execute( text("select pg_database_size(:database_name);"),
text("select pg_database_size(:database_name);"), {"database_name": database_name},
{"database_name": database_name}, ).scalar()
).first()[0]
) if not size:
return None
return float(size)

View file

@ -5,13 +5,16 @@ from sqlalchemy import text
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
def db_size_bytes(session: Session, database_name: str) -> float: def db_size_bytes(session: Session, database_name: str) -> float | None:
"""Get the mysql database size.""" """Get the mysql database size."""
return float( size = session.execute(
session.execute( text(
text( "SELECT page_count * page_size as size "
"SELECT page_count * page_size as size " "FROM pragma_page_count(), pragma_page_size();"
"FROM pragma_page_count(), pragma_page_size();" )
) ).scalar()
).first()[0]
) if not size:
return None
return float(size)

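The db-size helpers now use .scalar(), which returns the first column of the first row or None when there is no row, instead of the old .first()[0] chain that could raise on an empty result. A runnable sketch of the sqlite variant against an in-memory database:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    # .scalar() gives back None instead of raising when nothing is returned.
    size = conn.execute(
        text(
            "SELECT page_count * page_size as size "
            "FROM pragma_page_count(), pragma_page_size();"
        )
    ).scalar()
    print(None if size is None else float(size))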
View file

@ -1,7 +1,7 @@
"""SQLAlchemy util functions.""" """SQLAlchemy util functions."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Generator from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import functools import functools
@ -17,8 +17,7 @@ from awesomeversion import (
) )
import ciso8601 import ciso8601
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.engine.cursor import CursorFetchStrategy from sqlalchemy.engine import Result, Row
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -45,6 +44,8 @@ from .models import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlite3.dbapi2 import Cursor as SQLiteCursor
from . import Recorder from . import Recorder
_RecorderT = TypeVar("_RecorderT", bound="Recorder") _RecorderT = TypeVar("_RecorderT", bound="Recorder")
@ -202,8 +203,8 @@ def execute_stmt_lambda_element(
stmt: StatementLambdaElement, stmt: StatementLambdaElement,
start_time: datetime | None = None, start_time: datetime | None = None,
end_time: datetime | None = None, end_time: datetime | None = None,
yield_per: int | None = DEFAULT_YIELD_STATES_ROWS, yield_per: int = DEFAULT_YIELD_STATES_ROWS,
) -> list[Row]: ) -> Sequence[Row] | Result:
"""Execute a StatementLambdaElement. """Execute a StatementLambdaElement.
If the time window passed is greater than one day If the time window passed is greater than one day
@ -220,8 +221,8 @@ def execute_stmt_lambda_element(
for tryno in range(RETRIES): for tryno in range(RETRIES):
try: try:
if use_all: if use_all:
return executed.all() # type: ignore[no-any-return] return executed.all()
return executed.yield_per(yield_per) # type: ignore[no-any-return] return executed.yield_per(yield_per)
except SQLAlchemyError as err: except SQLAlchemyError as err:
_LOGGER.error("Error executing query: %s", err) _LOGGER.error("Error executing query: %s", err)
if tryno == RETRIES - 1: if tryno == RETRIES - 1:
@ -252,7 +253,7 @@ def dburl_to_path(dburl: str) -> str:
return dburl.removeprefix(SQLITE_URL_PREFIX) return dburl.removeprefix(SQLITE_URL_PREFIX)
def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool: def last_run_was_recently_clean(cursor: SQLiteCursor) -> bool:
"""Verify the last recorder run was recently clean.""" """Verify the last recorder run was recently clean."""
cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;") cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
@ -273,7 +274,7 @@ def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool:
return True return True
def basic_sanity_check(cursor: CursorFetchStrategy) -> bool: def basic_sanity_check(cursor: SQLiteCursor) -> bool:
"""Check tables to make sure select does not fail.""" """Check tables to make sure select does not fail."""
for table in TABLES_TO_CHECK: for table in TABLES_TO_CHECK:
@ -300,7 +301,7 @@ def validate_sqlite_database(dbpath: str) -> bool:
return True return True
def run_checks_on_open_db(dbpath: str, cursor: CursorFetchStrategy) -> None: def run_checks_on_open_db(dbpath: str, cursor: SQLiteCursor) -> None:
"""Run checks that will generate a sqlite3 exception if there is corruption.""" """Run checks that will generate a sqlite3 exception if there is corruption."""
sanity_check_passed = basic_sanity_check(cursor) sanity_check_passed = basic_sanity_check(cursor)
last_run_was_clean = last_run_was_recently_clean(cursor) last_run_was_clean = last_run_was_recently_clean(cursor)
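
A simplified sketch of the all-versus-streaming decision in execute_stmt_lambda_element, assuming a plain select() over a made-up Events table instead of the recorder's StatementLambdaElement: Session.execute() now returns a Result, and Result.yield_per(n) batches row fetching, so the return type can be Sequence[Row] | Result without the old type-ignore comments.

from collections.abc import Sequence
from datetime import datetime, timedelta

from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.engine import Result, Row
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Events(Base):
    __tablename__ = "events"
    id = Column(Integer, primary_key=True)

def execute_events(
    session: Session,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
    yield_per: int = 1024,
) -> Sequence[Row] | Result:
    """Fetch everything for small windows, stream larger ones."""
    result = session.execute(select(Events.id))
    if start_time and end_time and (end_time - start_time).days <= 1:
        return result.all()  # Sequence[Row], fully materialised
    return result.yield_per(yield_per)  # Result that buffers rows in batches

engine = create_engine("sqlite://", future=True)
Base.metadata.create_all(engine)
with Session(engine) as session:
    rows = execute_events(session, datetime.now() - timedelta(hours=2), datetime.now())
    print(list(rows))
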

View file

@ -7,7 +7,7 @@ from typing import Any
import sqlalchemy import sqlalchemy
from sqlalchemy.engine import Result from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import Session, scoped_session, sessionmaker
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
@ -47,7 +47,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
engine = sqlalchemy.create_engine(db_url, future=True) engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
sess: scoped_session = sessmaker() sess: Session = sessmaker()
try: try:
result: Result = sess.execute(sqlalchemy.text(query)) result: Result = sess.execute(sqlalchemy.text(query))

View file

@ -2,7 +2,7 @@
"domain": "sql", "domain": "sql",
"name": "SQL", "name": "SQL",
"documentation": "https://www.home-assistant.io/integrations/sql", "documentation": "https://www.home-assistant.io/integrations/sql",
"requirements": ["sqlalchemy==1.4.45"], "requirements": ["sqlalchemy==2.0.2"],
"codeowners": ["@dgomes", "@gjohansson-ST"], "codeowners": ["@dgomes", "@gjohansson-ST"],
"config_flow": true, "config_flow": true,
"iot_class": "local_polling" "iot_class": "local_polling"

View file

@ -8,7 +8,7 @@ import logging
import sqlalchemy import sqlalchemy
from sqlalchemy.engine import Result from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import Session, scoped_session, sessionmaker
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.components.sensor import SensorEntity from homeassistant.components.sensor import SensorEntity
@ -125,14 +125,14 @@ async def async_setup_sensor(
if not db_url: if not db_url:
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
sess: scoped_session | None = None sess: Session | None = None
try: try:
engine = sqlalchemy.create_engine(db_url, future=True) engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
# Run a dummy query just to test the db_url # Run a dummy query just to test the db_url
sess = sessmaker() sess = sessmaker()
sess.execute("SELECT 1;") sess.execute(sqlalchemy.text("SELECT 1;"))
except SQLAlchemyError as err: except SQLAlchemyError as err:
_LOGGER.error( _LOGGER.error(
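
A minimal, self-contained sketch of the connectivity probe above, assuming an in-memory SQLite URL in place of the configured db_url: SQLAlchemy 2.0 no longer accepts bare SQL strings in Session.execute(), so the dummy query is wrapped in sqlalchemy.text(), and the checked-out session is annotated as Session because that is what calling the scoped_session factory returns.

import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker

db_url = "sqlite://"  # stand-in for the configured db_url

sess: Session | None = None
try:
    engine = sqlalchemy.create_engine(db_url, future=True)
    sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
    # Run a dummy query just to test the db_url; a bare string is rejected
    # under 2.0, so it has to be wrapped in text().
    sess = sessmaker()
    sess.execute(sqlalchemy.text("SELECT 1;"))
except SQLAlchemyError as err:
    print(f"Couldn't connect using {db_url}: {err}")
finally:
    if sess:
        sess.close()
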

View file

@ -42,7 +42,7 @@ pyudev==0.23.2
pyyaml==6.0 pyyaml==6.0
requests==2.28.1 requests==2.28.1
scapy==2.5.0 scapy==2.5.0
sqlalchemy==1.4.45 sqlalchemy==2.0.2
typing-extensions>=4.4.0,<5.0 typing-extensions>=4.4.0,<5.0
voluptuous-serialize==2.5.0 voluptuous-serialize==2.5.0
voluptuous==0.13.1 voluptuous==0.13.1

View file

@ -2392,7 +2392,7 @@ spotipy==2.22.1
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==1.4.45 sqlalchemy==2.0.2
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View file

@ -1689,7 +1689,7 @@ spotipy==2.22.1
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==1.4.45 sqlalchemy==2.0.2
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View file

@ -782,9 +782,8 @@ async def test_fetch_period_api_with_entity_glob_exclude(
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_json = await response.json() response_json = await response.json()
assert len(response_json) == 3 assert len(response_json) == 3
assert response_json[0][0]["entity_id"] == "binary_sensor.sensor" entities = {state[0]["entity_id"] for state in response_json}
assert response_json[1][0]["entity_id"] == "light.cow" assert entities == {"binary_sensor.sensor", "light.cow", "light.match"}
assert response_json[2][0]["entity_id"] == "light.match"
async def test_fetch_period_api_with_entity_glob_include_and_exclude( async def test_fetch_period_api_with_entity_glob_include_and_exclude(
@ -824,10 +823,13 @@ async def test_fetch_period_api_with_entity_glob_include_and_exclude(
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_json = await response.json() response_json = await response.json()
assert len(response_json) == 4 assert len(response_json) == 4
assert response_json[0][0]["entity_id"] == "light.many_state_changes" entities = {state[0]["entity_id"] for state in response_json}
assert response_json[1][0]["entity_id"] == "light.match" assert entities == {
assert response_json[2][0]["entity_id"] == "media_player.test" "light.many_state_changes",
assert response_json[3][0]["entity_id"] == "switch.match" "light.match",
"media_player.test",
"switch.match",
}
async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client): async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client):
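
The rewritten assertions above compare entity_ids as a set because row order coming back from the database is not guaranteed; a small sketch of the pattern with a made-up payload:

# Hypothetical response payload: one list of states per entity.
response_json = [
    [{"entity_id": "light.cow", "state": "on"}],
    [{"entity_id": "binary_sensor.sensor", "state": "off"}],
    [{"entity_id": "light.match", "state": "on"}],
]

# Compare the set of entity_ids rather than asserting a fixed position
# for each entity in the response.
entities = {states[0]["entity_id"] for states in response_json}
assert entities == {"binary_sensor.sensor", "light.cow", "light.match"}
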

View file

@ -822,9 +822,8 @@ async def test_fetch_period_api_with_entity_glob_exclude(
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_json = await response.json() response_json = await response.json()
assert len(response_json) == 3 assert len(response_json) == 3
assert response_json[0][0]["entity_id"] == "binary_sensor.sensor" entities = {state[0]["entity_id"] for state in response_json}
assert response_json[1][0]["entity_id"] == "light.cow" assert entities == {"binary_sensor.sensor", "light.cow", "light.match"}
assert response_json[2][0]["entity_id"] == "light.match"
async def test_fetch_period_api_with_entity_glob_include_and_exclude( async def test_fetch_period_api_with_entity_glob_include_and_exclude(
@ -864,10 +863,13 @@ async def test_fetch_period_api_with_entity_glob_include_and_exclude(
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_json = await response.json() response_json = await response.json()
assert len(response_json) == 4 assert len(response_json) == 4
assert response_json[0][0]["entity_id"] == "light.many_state_changes" entities = {state[0]["entity_id"] for state in response_json}
assert response_json[1][0]["entity_id"] == "light.match" assert entities == {
assert response_json[2][0]["entity_id"] == "media_player.test" "light.many_state_changes",
assert response_json[3][0]["entity_id"] == "switch.match" "light.match",
"media_player.test",
"switch.match",
}
async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client): async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client):

View file

@ -23,7 +23,6 @@ from sqlalchemy import (
) )
from sqlalchemy.dialects import mysql, oracle, postgresql from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.engine.row import Row from sqlalchemy.engine.row import Row
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -322,16 +321,11 @@ class StatisticsBase:
id = Column(Integer, Identity(), primary_key=True) id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow) created = Column(DATETIME_TYPE, default=dt_util.utcnow)
metadata_id = Column(
@declared_attr # type: ignore[misc] Integer,
def metadata_id(self) -> Column: ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
"""Define the metadata_id column for sub classes.""" index=True,
return Column( )
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True) start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE) mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE) min = Column(DOUBLE_TYPE)

View file

@ -28,7 +28,6 @@ from sqlalchemy import (
) )
from sqlalchemy.dialects import mysql, oracle, postgresql from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.engine.row import Row from sqlalchemy.engine.row import Row
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -402,16 +401,11 @@ class StatisticsBase:
id = Column(Integer, Identity(), primary_key=True) id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow) created = Column(DATETIME_TYPE, default=dt_util.utcnow)
metadata_id = Column(
@declared_attr # type: ignore[misc] Integer,
def metadata_id(self) -> Column: ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
"""Define the metadata_id column for sub classes.""" index=True,
return Column( )
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True) start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE) mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE) min = Column(DOUBLE_TYPE)

View file

@ -30,7 +30,6 @@ from sqlalchemy import (
type_coerce, type_coerce,
) )
from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import aliased, declarative_base, relationship from sqlalchemy.orm import aliased, declarative_base, relationship
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from typing_extensions import Self from typing_extensions import Self
@ -477,16 +476,11 @@ class StatisticsBase:
id = Column(Integer, Identity(), primary_key=True) id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow) created = Column(DATETIME_TYPE, default=dt_util.utcnow)
metadata_id = Column(
@declared_attr # type: ignore[misc] Integer,
def metadata_id(self) -> Column: ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
"""Define the metadata_id column for sub classes.""" index=True,
return Column( )
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True) start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE) mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE) min = Column(DOUBLE_TYPE)
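
A minimal sketch of the mixin simplification above, with stand-in table definitions: SQLAlchemy 2.0 allows a Column carrying a ForeignKey to live directly on a declarative mixin and copies it into each mapped subclass, so the @declared_attr indirection can go (declared_attr remains necessary for constructs such as relationships).

from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class StatisticsMeta(Base):
    __tablename__ = "statistics_meta"
    id = Column(Integer, primary_key=True)

class StatisticsBase:
    """Mixin shared by the statistics tables."""

    id = Column(Integer, primary_key=True)
    # Declared directly on the mixin; 2.0 copies it, ForeignKey included,
    # into every subclass instead of requiring a @declared_attr method.
    metadata_id = Column(
        Integer,
        ForeignKey("statistics_meta.id", ondelete="CASCADE"),
        index=True,
    )

class Statistics(StatisticsBase, Base):
    __tablename__ = "statistics"

class StatisticsShortTerm(StatisticsBase, Base):
    __tablename__ = "statistics_short_term"

engine = create_engine("sqlite://", future=True)
Base.metadata.create_all(engine)
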

View file

@ -1904,6 +1904,9 @@ async def test_connect_args_priority(hass, config_url):
def on_connect_url(self, url): def on_connect_url(self, url):
return False return False
def _builtin_onconnect(self):
...
class MockEntrypoint: class MockEntrypoint:
def engine_created(*_): def engine_created(*_):
... ...

View file

@ -44,7 +44,7 @@ async def test_recorder_system_health_alternate_dbms(recorder_mock, hass, dialec
"homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
), patch( ), patch(
"sqlalchemy.orm.session.Session.execute", "sqlalchemy.orm.session.Session.execute",
return_value=Mock(first=Mock(return_value=("1048576",))), return_value=Mock(scalar=Mock(return_value=("1048576"))),
): ):
info = await get_system_health_info(hass, "recorder") info = await get_system_health_info(hass, "recorder")
instance = get_instance(hass) instance = get_instance(hass)
@ -76,7 +76,7 @@ async def test_recorder_system_health_db_url_missing_host(
"postgresql://homeassistant:blabla@/home_assistant?host=/config/socket", "postgresql://homeassistant:blabla@/home_assistant?host=/config/socket",
), patch( ), patch(
"sqlalchemy.orm.session.Session.execute", "sqlalchemy.orm.session.Session.execute",
return_value=Mock(first=Mock(return_value=("1048576",))), return_value=Mock(scalar=Mock(return_value=("1048576"))),
): ):
info = await get_system_health_info(hass, "recorder") info = await get_system_health_info(hass, "recorder")
assert info == { assert info == {
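
A hedged sketch of the mock adjustment above: since the production code now reads the size via .scalar() instead of .first()[0], the patched Session.execute must hand back an object whose scalar() yields the value; db_size_bytes below is only an illustrative stand-in, not the real system-health helper.

from unittest.mock import Mock, patch

from sqlalchemy import text
from sqlalchemy.orm import Session

def db_size_bytes(session: Session) -> float | None:
    """Stand-in for the production query that now ends in .scalar()."""
    size = session.execute(text("SELECT 1")).scalar()
    return float(size) if size else None

with patch(
    "sqlalchemy.orm.session.Session.execute",
    return_value=Mock(scalar=Mock(return_value="1048576")),
):
    assert db_size_bytes(Session()) == 1048576.0
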

View file

@ -5,6 +5,7 @@ from datetime import timedelta
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from sqlalchemy import text as sql_text
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from homeassistant.components.sql.const import DOMAIN from homeassistant.components.sql.const import DOMAIN
@ -114,7 +115,7 @@ async def test_query_mssql_no_result(
} }
with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch( with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
"homeassistant.components.sql.sensor.sqlalchemy.text", "homeassistant.components.sql.sensor.sqlalchemy.text",
return_value="SELECT TOP 1 5 as value where 1=2", return_value=sql_text("SELECT TOP 1 5 as value where 1=2"),
): ):
await init_integration(hass, config) await init_integration(hass, config)
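
A short sketch of why the patched return value changes above, using an in-memory engine and a made-up run_query helper: whatever the mocked sqlalchemy.text() returns is passed straight to Session.execute(), which under 2.0 only accepts executable constructs, so the fake must itself be a TextClause built with text() rather than a bare string.

from unittest.mock import patch

import sqlalchemy
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session

engine = sqlalchemy.create_engine("sqlite://", future=True)

def run_query(raw: str):
    """Stand-in for the sensor's update path: wrap the query and execute it."""
    with Session(engine) as sess:
        return sess.execute(sqlalchemy.text(raw)).scalar()

# If the patched text() handed back a bare string, execute() would reject it
# under 2.0, so the fake return value is itself a TextClause.
with patch(
    "sqlalchemy.text",
    return_value=sql_text("SELECT value FROM (SELECT 5 AS value) WHERE 1=2"),
):
    assert run_query("SELECT 1") is None  # no rows, so scalar() gives None
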

View file

@ -998,6 +998,23 @@ def recorder_config():
def recorder_db_url(pytestconfig): def recorder_db_url(pytestconfig):
"""Prepare a default database for tests and return a connection URL.""" """Prepare a default database for tests and return a connection URL."""
db_url: str = pytestconfig.getoption("dburl") db_url: str = pytestconfig.getoption("dburl")
if db_url.startswith(("postgresql://", "mysql://")):
import sqlalchemy_utils
def _ha_orm_quote(mixed, ident):
"""Conditionally quote an identifier.
Modified to include https://github.com/kvesteri/sqlalchemy-utils/pull/677
"""
if isinstance(mixed, sqlalchemy_utils.functions.orm.Dialect):
dialect = mixed
elif hasattr(mixed, "dialect"):
dialect = mixed.dialect
else:
dialect = sqlalchemy_utils.functions.orm.get_bind(mixed).dialect
return dialect.preparer(dialect).quote(ident)
sqlalchemy_utils.functions.database.quote = _ha_orm_quote
if db_url.startswith("mysql://"): if db_url.startswith("mysql://"):
import sqlalchemy_utils import sqlalchemy_utils