diff --git a/homeassistant/components/logbook/models.py b/homeassistant/components/logbook/models.py
index 0b0a9aeb414..3fc4b5dac8b 100644
--- a/homeassistant/components/logbook/models.py
+++ b/homeassistant/components/logbook/models.py
@@ -82,7 +82,7 @@ class EventAsRow:
 
 
 @callback
-def async_event_to_row(event: Event) -> EventAsRow | None:
+def async_event_to_row(event: Event) -> EventAsRow:
     """Convert an event to a row."""
     if event.event_type != EVENT_STATE_CHANGED:
         return EventAsRow(
diff --git a/homeassistant/components/logbook/processor.py b/homeassistant/components/logbook/processor.py
index d73a852ca1c..289ee677a21 100644
--- a/homeassistant/components/logbook/processor.py
+++ b/homeassistant/components/logbook/processor.py
@@ -1,14 +1,14 @@
 """Event parser and human readable log generator."""
 from __future__ import annotations
 
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Sequence
 from contextlib import suppress
 from dataclasses import dataclass
 from datetime import datetime as dt
 from typing import Any
 
+from sqlalchemy.engine import Result
 from sqlalchemy.engine.row import Row
-from sqlalchemy.orm.query import Query
 
 from homeassistant.components.recorder.filters import Filters
 from homeassistant.components.recorder.models import (
@@ -70,7 +70,7 @@ class LogbookRun:
     event_cache: EventCache
     entity_name_cache: EntityNameCache
     include_entity_name: bool
-    format_time: Callable[[Row], Any]
+    format_time: Callable[[Row | EventAsRow], Any]
 
 
 class EventProcessor:
@@ -133,13 +133,13 @@ class EventProcessor:
     ) -> list[dict[str, Any]]:
         """Get events for a period of time."""
 
-        def yield_rows(query: Query) -> Generator[Row, None, None]:
+        def yield_rows(result: Result) -> Sequence[Row] | Result:
            """Yield rows from the database."""
            # end_day - start_day intentionally checks .days and not .total_seconds()
            # since we don't want to switch over to buffered if they go
            # over one day by a few hours since the UI makes it so easy to do that.
            if self.limited_select or (end_day - start_day).days <= 1:
-                return query.all()  # type: ignore[no-any-return]
+                return result.all()
            # Only buffer rows to reduce memory pressure
            # if we expect the result set is going to be very large.
            # What is considered very large is going to differ
            # even an RPi3 that number seems higher in testing
            # so we don't switch over until we request > 1 day+ of data.
            #
-            return query.yield_per(1024)  # type: ignore[no-any-return]
+            return result.yield_per(1024)
 
         stmt = statement_for_request(
             start_day,
@@ -164,12 +164,12 @@ class EventProcessor:
             return self.humanify(yield_rows(session.execute(stmt)))
 
     def humanify(
-        self, row_generator: Generator[Row | EventAsRow, None, None]
+        self, rows: Generator[EventAsRow, None, None] | Sequence[Row] | Result
     ) -> list[dict[str, str]]:
         """Humanify rows."""
         return list(
             _humanify(
-                row_generator,
+                rows,
                 self.ent_reg,
                 self.logbook_run,
                 self.context_augmenter,
@@ -178,7 +178,7 @@ class EventProcessor:
 
 
 def _humanify(
-    rows: Generator[Row | EventAsRow, None, None],
+    rows: Generator[EventAsRow, None, None] | Sequence[Row] | Result,
     ent_reg: er.EntityRegistry,
     logbook_run: LogbookRun,
     context_augmenter: ContextAugmenter,
@@ -263,7 +263,7 @@ class ContextLookup:
         self._memorize_new = True
         self._lookup: dict[str | None, Row | EventAsRow | None] = {None: None}
 
-    def memorize(self, row: Row) -> str | None:
+    def memorize(self, row: Row | EventAsRow) -> str | None:
         """Memorize a context from the database."""
         if self._memorize_new:
             context_id: str = row.context_id
@@ -276,7 +276,7 @@ class ContextLookup:
         self._lookup.clear()
         self._memorize_new = False
 
-    def get(self, context_id: str) -> Row | None:
+    def get(self, context_id: str) -> Row | EventAsRow | None:
         """Get the context origin."""
         return self._lookup.get(context_id)
@@ -294,7 +294,7 @@ class ContextAugmenter:
 
     def _get_context_row(
         self, context_id: str | None, row: Row | EventAsRow
-    ) -> Row | EventAsRow:
+    ) -> Row | EventAsRow | None:
         """Get the context row from the id or row context."""
         if context_id:
             return self.context_lookup.get(context_id)
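The `yield_rows()` change above is the core of the 1.x-to-2.0 result-handling shift: the helper now receives the `Result` returned by `session.execute(stmt)` instead of an ORM `Query`, and calls `.all()` or `.yield_per()` on it. A minimal, self-contained sketch of that pattern, assuming an illustrative `Event` model and in-memory SQLite (not the actual recorder schema):

```python
# Sketch of the Query -> Result change; model/table names are illustrative.
from __future__ import annotations

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Event(Base):
    __tablename__ = "events"
    event_id: Mapped[int] = mapped_column(primary_key=True)
    event_type: Mapped[str | None]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # session.execute() returns a Result in 2.0; .all() is the
    # materializing equivalent of the old Query.all().
    rows = session.execute(select(Event.event_id, Event.event_type)).all()

    # For very large result sets, Result.yield_per(n) switches to
    # buffered streaming, fetching rows in chunks of n.
    for row in session.execute(select(Event.event_id)).yield_per(1024):
        ...
```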
diff --git a/homeassistant/components/logbook/queries/all.py b/homeassistant/components/logbook/queries/all.py
index 21624181a3b..729a4d2195a 100644
--- a/homeassistant/components/logbook/queries/all.py
+++ b/homeassistant/components/logbook/queries/all.py
@@ -2,9 +2,9 @@
 from __future__ import annotations
 
 from sqlalchemy import lambda_stmt
-from sqlalchemy.orm import Query
-from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 
 from homeassistant.components.recorder.db_schema import (
     LAST_UPDATED_INDEX_TS,
@@ -24,8 +24,8 @@ def all_stmt(
     start_day: float,
     end_day: float,
     event_types: tuple[str, ...],
-    states_entity_filter: ClauseList | None = None,
-    events_entity_filter: ClauseList | None = None,
+    states_entity_filter: ColumnElement | None = None,
+    events_entity_filter: ColumnElement | None = None,
     context_id: str | None = None,
 ) -> StatementLambdaElement:
     """Generate a logbook query for all entities."""
@@ -37,8 +37,18 @@ def all_stmt(
         # are gone from the database remove the
         # _legacy_select_events_context_id()
         stmt += lambda s: s.where(Events.context_id == context_id).union_all(
-            _states_query_for_context_id(start_day, end_day, context_id),
-            legacy_select_events_context_id(start_day, end_day, context_id),
+            _states_query_for_context_id(
+                start_day,
+                end_day,
+                # https://github.com/python/mypy/issues/2608
+                context_id,  # type:ignore[arg-type]
+            ),
+            legacy_select_events_context_id(
+                start_day,
+                end_day,
+                # https://github.com/python/mypy/issues/2608
+                context_id,  # type:ignore[arg-type]
+            ),
         )
     else:
         if events_entity_filter is not None:
@@ -46,7 +56,10 @@
 
         if states_entity_filter is not None:
             stmt += lambda s: s.union_all(
-                _states_query_for_all(start_day, end_day).where(states_entity_filter)
+                _states_query_for_all(start_day, end_day).where(
+                    # https://github.com/python/mypy/issues/2608
+                    states_entity_filter  # type:ignore[arg-type]
+                )
             )
         else:
             stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
@@ -55,20 +68,20 @@ def all_stmt(
     return stmt
 
 
-def _states_query_for_all(start_day: float, end_day: float) -> Query:
+def _states_query_for_all(start_day: float, end_day: float) -> Select:
     return apply_states_filters(_apply_all_hints(select_states()), start_day, end_day)
 
 
-def _apply_all_hints(query: Query) -> Query:
+def _apply_all_hints(sel: Select) -> Select:
     """Force mysql to use the right index on large selects."""
-    return query.with_hint(
+    return sel.with_hint(
         States, f"FORCE INDEX ({LAST_UPDATED_INDEX_TS})", dialect_name="mysql"
     )
 
 
 def _states_query_for_context_id(
     start_day: float, end_day: float, context_id: str
-) -> Query:
+) -> Select:
     return apply_states_filters(select_states(), start_day, end_day).where(
         States.context_id == context_id
     )
diff --git a/homeassistant/components/logbook/queries/common.py b/homeassistant/components/logbook/queries/common.py
index 362766504e3..ca00f31615a 100644
--- a/homeassistant/components/logbook/queries/common.py
+++ b/homeassistant/components/logbook/queries/common.py
@@ -5,8 +5,7 @@
 from typing import Final
 
 import sqlalchemy
 from sqlalchemy import select
-from sqlalchemy.orm import Query
-from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import BooleanClauseList, ColumnElement
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.selectable import Select
 
@@ -69,7 +68,7 @@ STATE_CONTEXT_ONLY_COLUMNS = (
     literal(value=None, type_=sqlalchemy.String).label("old_format_icon"),
 )
 
-EVENT_COLUMNS_FOR_STATE_SELECT = [
+EVENT_COLUMNS_FOR_STATE_SELECT = (
     literal(value=None, type_=sqlalchemy.Text).label("event_id"),
     # We use PSEUDO_EVENT_STATE_CHANGED aka None for
     # state_changed events since it takes up less
@@ -84,7 +83,7 @@ EVENT_COLUMNS_FOR_STATE_SELECT = [
     States.context_user_id.label("context_user_id"),
     States.context_parent_id.label("context_parent_id"),
     literal(value=None, type_=sqlalchemy.Text).label("shared_data"),
-]
+)
 
 EMPTY_STATE_COLUMNS = (
     literal(value=0, type_=sqlalchemy.Integer).label("state_id"),
@@ -103,8 +102,8 @@ EVENT_ROWS_NO_STATES = (
 
 # Virtual column to tell logbook if it should avoid processing
 # the event as it's only used to link contexts
-CONTEXT_ONLY = literal("1").label("context_only")
-NOT_CONTEXT_ONLY = literal(None).label("context_only")
+CONTEXT_ONLY = literal(value="1", type_=sqlalchemy.String).label("context_only")
+NOT_CONTEXT_ONLY = literal(value=None, type_=sqlalchemy.String).label("context_only")
 
 
 def select_events_context_id_subquery(
@@ -188,7 +187,7 @@ def legacy_select_events_context_id(
     )
 
 
-def apply_states_filters(query: Query, start_day: float, end_day: float) -> Query:
+def apply_states_filters(sel: Select, start_day: float, end_day: float) -> Select:
     """Filter states by time range.
 
     Filters states that do not have an old state or new state (added / removed)
@@ -196,7 +195,7 @@ def apply_states_filters(query: Query, start_day: float, end_day: float) -> Quer
     Filters states that do not have matching last_updated_ts and last_changed_ts.
     """
     return (
-        query.filter(
+        sel.filter(
             (States.last_updated_ts > start_day) & (States.last_updated_ts < end_day)
         )
         .outerjoin(OLD_STATE, (States.old_state_id == OLD_STATE.state_id))
@@ -212,18 +211,18 @@ def apply_states_filters(query: Query, start_day: float, end_day: float) -> Quer
     )
 
 
-def _missing_state_matcher() -> sqlalchemy.and_:
+def _missing_state_matcher() -> ColumnElement[bool]:
     # The below removes state change events that do not have
     # an old_state or the old_state is missing (newly added entities)
     # or the new_state is missing (removed entities)
     return sqlalchemy.and_(
-        OLD_STATE.state_id.isnot(None),
+        OLD_STATE.state_id.is_not(None),
         (States.state != OLD_STATE.state),
-        States.state.isnot(None),
+        States.state.is_not(None),
     )
 
 
-def _not_continuous_entity_matcher() -> sqlalchemy.or_:
+def _not_continuous_entity_matcher() -> ColumnElement[bool]:
     """Match non continuous entities."""
     return sqlalchemy.or_(
         # First exclude domains that may be continuous
@@ -236,7 +235,7 @@ def _not_continuous_entity_matcher() -> sqlalchemy.or_:
     )
 
 
-def _not_possible_continuous_domain_matcher() -> sqlalchemy.and_:
+def _not_possible_continuous_domain_matcher() -> ColumnElement[bool]:
     """Match not continuous domains.
 
     This matches domains that are always considered continuous
@@ -254,7 +253,7 @@ def _not_possible_continuous_domain_matcher() -> sqlalchemy.and_:
     ).self_group()
 
 
-def _conditionally_continuous_domain_matcher() -> sqlalchemy.or_:
+def _conditionally_continuous_domain_matcher() -> ColumnElement[bool]:
     """Match conditionally continuous domains.
 
     This matches domains that are only considered
@@ -268,22 +267,22 @@ def _conditionally_continuous_domain_matcher() -> sqlalchemy.or_:
     ).self_group()
 
 
-def _not_uom_attributes_matcher() -> ClauseList:
+def _not_uom_attributes_matcher() -> BooleanClauseList:
    """Prefilter ATTR_UNIT_OF_MEASUREMENT as it's much faster in SQL."""
    return ~StateAttributes.shared_attrs.like(
        UNIT_OF_MEASUREMENT_JSON_LIKE
    ) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
 
 
-def apply_states_context_hints(query: Query) -> Query:
+def apply_states_context_hints(sel: Select) -> Select:
     """Force mysql to use the right index on large context_id selects."""
-    return query.with_hint(
+    return sel.with_hint(
         States, f"FORCE INDEX ({STATES_CONTEXT_ID_INDEX})", dialect_name="mysql"
     )
 
 
-def apply_events_context_hints(query: Query) -> Query:
+def apply_events_context_hints(sel: Select) -> Select:
     """Force mysql to use the right index on large context_id selects."""
-    return query.with_hint(
+    return sel.with_hint(
         Events, f"FORCE INDEX ({EVENTS_CONTEXT_ID_INDEX})", dialect_name="mysql"
     )
diff --git a/homeassistant/components/logbook/queries/devices.py b/homeassistant/components/logbook/queries/devices.py
index a270f1996ce..fa2deaf4c02 100644
--- a/homeassistant/components/logbook/queries/devices.py
+++ b/homeassistant/components/logbook/queries/devices.py
@@ -5,10 +5,9 @@ from collections.abc import Iterable
 
 import sqlalchemy
 from sqlalchemy import lambda_stmt, select
-from sqlalchemy.orm import Query
-from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.db_schema import (
     DEVICE_ID_IN_EVENT,
@@ -32,7 +31,7 @@ def _select_device_id_context_ids_sub_query(
     start_day: float,
     end_day: float,
     event_types: tuple[str, ...],
     json_quotable_device_ids: list[str],
-) -> CompoundSelect:
+) -> Select:
     """Generate a subquery to find context ids for multiple devices."""
     inner = select_events_context_id_subquery(start_day, end_day, event_types).where(
         apply_event_device_id_matchers(json_quotable_device_ids)
@@ -41,7 +40,7 @@ def _select_device_id_context_ids_sub_query(
 
 
 def _apply_devices_context_union(
-    query: Query,
+    sel: Select,
     start_day: float,
     end_day: float,
     event_types: tuple[str, ...],
@@ -54,7 +53,7 @@ def _apply_devices_context_union(
         event_types,
         json_quotable_device_ids,
     ).cte()
-    return query.union_all(
+    return sel.union_all(
         apply_events_context_hints(
             select_events_context_only()
             .select_from(devices_cte)
@@ -91,7 +90,7 @@ def devices_stmt(
 
 def apply_event_device_id_matchers(
     json_quotable_device_ids: Iterable[str],
-) -> ClauseList:
+) -> BooleanClauseList:
     """Create matchers for the device_ids in the event_data."""
     return DEVICE_ID_IN_EVENT.is_not(None) & sqlalchemy.cast(
         DEVICE_ID_IN_EVENT, sqlalchemy.Text()
diff --git a/homeassistant/components/logbook/queries/entities.py b/homeassistant/components/logbook/queries/entities.py
index afe7c7c7c2e..3d26443ce90 100644
--- a/homeassistant/components/logbook/queries/entities.py
+++ b/homeassistant/components/logbook/queries/entities.py
@@ -5,9 +5,9 @@ from collections.abc import Iterable
 
 import sqlalchemy
 from sqlalchemy import lambda_stmt, select, union_all
-from sqlalchemy.orm import Query
+from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.db_schema import (
     ENTITY_ID_IN_EVENT,
@@ -36,7 +36,7 @@ def _select_entities_context_ids_sub_query(
     event_types: tuple[str, ...],
     entity_ids: list[str],
     json_quoted_entity_ids: list[str],
-) -> CompoundSelect:
+) -> Select:
     """Generate a subquery to find context ids for multiple entities."""
     union = union_all(
         select_events_context_id_subquery(start_day, end_day, event_types).where(
@@ -52,7 +52,7 @@ def _select_entities_context_ids_sub_query(
 
 
 def _apply_entities_context_union(
-    query: Query,
+    sel: Select,
     start_day: float,
     end_day: float,
     event_types: tuple[str, ...],
@@ -72,8 +72,8 @@ def _apply_entities_context_union(
     # query much slower on MySQL, and since we already filter them away
     # in the Python code anyway since they will have context_only
     # set on them the impact is minimal.
-    return query.union_all(
-        states_query_for_entity_ids(start_day, end_day, entity_ids),
+    return sel.union_all(
+        states_select_for_entity_ids(start_day, end_day, entity_ids),
         apply_events_context_hints(
             select_events_context_only()
             .select_from(entities_cte)
@@ -109,9 +109,9 @@ def entities_stmt(
     )
 
 
-def states_query_for_entity_ids(
+def states_select_for_entity_ids(
     start_day: float, end_day: float, entity_ids: list[str]
-) -> Query:
+) -> Select:
     """Generate a select for states from the States table for specific entities."""
     return apply_states_filters(
         apply_entities_hints(select_states()), start_day, end_day
@@ -120,7 +120,7 @@ def states_query_for_entity_ids(
 
 def apply_event_entity_id_matchers(
     json_quoted_entity_ids: Iterable[str],
-) -> sqlalchemy.or_:
+) -> ColumnElement[bool]:
     """Create matchers for the entity_id in the event_data."""
     return sqlalchemy.or_(
         ENTITY_ID_IN_EVENT.is_not(None)
@@ -134,8 +134,8 @@ def apply_event_entity_id_matchers(
     )
 
 
-def apply_entities_hints(query: Query) -> Query:
+def apply_entities_hints(sel: Select) -> Select:
     """Force mysql to use the right index on large selects."""
-    return query.with_hint(
+    return sel.with_hint(
         States, f"FORCE INDEX ({ENTITY_ID_LAST_UPDATED_INDEX_TS})", dialect_name="mysql"
     )
diff --git a/homeassistant/components/logbook/queries/entities_and_devices.py b/homeassistant/components/logbook/queries/entities_and_devices.py
index 94e9afc551d..43d11d0bdff 100644
--- a/homeassistant/components/logbook/queries/entities_and_devices.py
+++ b/homeassistant/components/logbook/queries/entities_and_devices.py
@@ -3,11 +3,10 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
 
-import sqlalchemy
 from sqlalchemy import lambda_stmt, select, union_all
-from sqlalchemy.orm import Query
+from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.db_schema import EventData, Events, States
 
@@ -23,7 +22,7 @@ from .devices import apply_event_device_id_matchers
 from .entities import (
     apply_entities_hints,
     apply_event_entity_id_matchers,
-    states_query_for_entity_ids,
+    states_select_for_entity_ids,
 )
 
 
@@ -34,7 +33,7 @@ def _select_entities_device_id_context_ids_sub_query(
     entity_ids: list[str],
     json_quoted_entity_ids: list[str],
     json_quoted_device_ids: list[str],
-) -> CompoundSelect:
+) -> Select:
     """Generate a subquery to find context ids for multiple entities and multiple devices."""
     union = union_all(
         select_events_context_id_subquery(start_day, end_day, event_types).where(
@@ -52,7 +51,7 @@ def _select_entities_device_id_context_ids_sub_query(
 
 
 def _apply_entities_devices_context_union(
-    query: Query,
+    sel: Select,
     start_day: float,
     end_day: float,
     event_types: tuple[str, ...],
@@ -73,8 +72,8 @@ def _apply_entities_devices_context_union(
     # query much slower on MySQL, and since we already filter them away
     # in the Python code anyway since they will have context_only
     # set on them the impact is minimal.
-    return query.union_all(
-        states_query_for_entity_ids(start_day, end_day, entity_ids),
+    return sel.union_all(
+        states_select_for_entity_ids(start_day, end_day, entity_ids),
         apply_events_context_hints(
             select_events_context_only()
             .select_from(devices_entities_cte)
@@ -117,7 +116,7 @@ def entities_devices_stmt(
 
 def _apply_event_entity_id_device_id_matchers(
     json_quoted_entity_ids: Iterable[str], json_quoted_device_ids: Iterable[str]
-) -> sqlalchemy.or_:
+) -> ColumnElement[bool]:
     """Create matchers for the device_id and entity_id in the event_data."""
     return apply_event_entity_id_matchers(
         json_quoted_entity_ids
diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 129e77ec54a..002618738a1 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -190,7 +190,7 @@ class Recorder(threading.Thread):
         self.schema_version = 0
         self._commits_without_expire = 0
-        self._old_states: dict[str, States] = {}
+        self._old_states: dict[str | None, States] = {}
         self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
         self._event_data_ids: LRU = LRU(EVENT_DATA_ID_CACHE_SIZE)
         self._pending_state_attributes: dict[str, StateAttributes] = {}
@@ -739,6 +739,7 @@ class Recorder(threading.Thread):
             self.hass.add_job(self._async_migration_started)
 
         try:
+            assert self.engine is not None
             migration.migrate_schema(
                 self, self.hass, self.engine, self.get_session, schema_status
             )
@@ -1026,6 +1027,8 @@ class Recorder(threading.Thread):
 
     def _post_schema_migration(self, old_version: int, new_version: int) -> None:
         """Run post schema migration tasks."""
+        assert self.engine is not None
+        assert self.event_session is not None
         migration.post_schema_migration(
             self.engine, self.event_session, old_version, new_version
         )
@@ -1034,7 +1037,7 @@ class Recorder(threading.Thread):
         """Send a keep alive to keep the db connection open."""
         assert self.event_session is not None
         _LOGGER.debug("Sending keepalive")
-        self.event_session.connection().scalar(select([1]))
+        self.event_session.connection().scalar(select(1))
 
     @callback
     def event_listener(self, event: Event) -> None:
@@ -1198,6 +1201,8 @@ class Recorder(threading.Thread):
             start = start.replace(minute=0, second=0, microsecond=0)
 
             # Find the newest statistics run, if any
+            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+            # pylint: disable-next=not-callable
             if last_run := session.query(func.max(StatisticsRuns.start)).scalar():
                 start = max(start, process_timestamp(last_run) + timedelta(minutes=5))
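The `select([1])` change in core.py is one of the few hard API removals in this bump: SQLAlchemy 2.0 only accepts columns and scalar expressions as positional arguments to `select()`, never a list. A tiny runnable sketch of the new keepalive spelling:

```python
# 1.x allowed select([1]); 2.0 requires select(1).
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")
with Session(engine) as session:
    # Equivalent of the recorder keepalive round trip above.
    value = session.connection().scalar(select(1))
    assert value == 1
```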
diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py
index aa0192b2412..5d10a459d88 100644
--- a/homeassistant/components/recorder/db_schema.py
+++ b/homeassistant/components/recorder/db_schema.py
@@ -13,7 +13,7 @@ from sqlalchemy import (
     JSON,
     BigInteger,
     Boolean,
-    Column,
+    ColumnElement,
     DateTime,
     Float,
     ForeignKey,
@@ -27,8 +27,9 @@ from sqlalchemy import (
     type_coerce,
 )
 from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
-from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import aliased, declarative_base, relationship
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column, relationship
+from sqlalchemy.orm.query import RowReturningQuery
 from sqlalchemy.orm.session import Session
 from typing_extensions import Self
 
@@ -53,9 +54,12 @@ import homeassistant.util.dt as dt_util
 from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
 from .models import StatisticData, StatisticMetaData, process_timestamp
 
+
 # SQLAlchemy Schema
 # pylint: disable=invalid-name
-Base = declarative_base()
+class Base(DeclarativeBase):
+    """Base class for tables."""
+
 
 SCHEMA_VERSION = 33
@@ -101,7 +105,7 @@ EVENTS_CONTEXT_ID_INDEX = "ix_events_context_id"
 STATES_CONTEXT_ID_INDEX = "ix_states_context_id"
 
 
-class FAST_PYSQLITE_DATETIME(sqlite.DATETIME):  # type: ignore[misc]
+class FAST_PYSQLITE_DATETIME(sqlite.DATETIME):
     """Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex."""
 
     def result_processor(self, dialect, coltype):  # type: ignore[no-untyped-def]
@@ -110,19 +114,19 @@ class FAST_PYSQLITE_DATETIME(sqlite.DATETIME):  # type: ignore[misc]
 
 
 JSON_VARIANT_CAST = Text().with_variant(
-    postgresql.JSON(none_as_null=True), "postgresql"
+    postgresql.JSON(none_as_null=True), "postgresql"  # type: ignore[no-untyped-call]
 )
 JSONB_VARIANT_CAST = Text().with_variant(
-    postgresql.JSONB(none_as_null=True), "postgresql"
+    postgresql.JSONB(none_as_null=True), "postgresql"  # type: ignore[no-untyped-call]
 )
 DATETIME_TYPE = (
     DateTime(timezone=True)
-    .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql")
-    .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite")
+    .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql")  # type: ignore[no-untyped-call]
+    .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite")  # type: ignore[no-untyped-call]
 )
 DOUBLE_TYPE = (
     Float()
-    .with_variant(mysql.DOUBLE(asdecimal=False), "mysql")
+    .with_variant(mysql.DOUBLE(asdecimal=False), "mysql")  # type: ignore[no-untyped-call]
     .with_variant(oracle.DOUBLE_PRECISION(), "oracle")
     .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
 )
@@ -130,10 +134,10 @@ DOUBLE_TYPE = (
 TIMESTAMP_TYPE = DOUBLE_TYPE
 
 
-class JSONLiteral(JSON):  # type: ignore[misc]
+class JSONLiteral(JSON):
     """Teach SA how to literalize json."""
 
-    def literal_processor(self, dialect: str) -> Callable[[Any], str]:
+    def literal_processor(self, dialect: Dialect) -> Callable[[Any], str]:
         """Processor to convert a value to JSON."""
 
         def process(value: Any) -> str:
@@ -147,7 +151,7 @@ EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
 EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
 
 
-class Events(Base):  # type: ignore[misc,valid-type]
+class Events(Base):
     """Event history data."""
 
     __table_args__ = (
@@ -157,18 +161,32 @@ class Events(Base):  # type: ignore[misc,valid-type]
         {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
     )
     __tablename__ = TABLE_EVENTS
-    event_id = Column(Integer, Identity(), primary_key=True)
-    event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
-    event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
-    origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))  # no longer used for new rows
-    origin_idx = Column(SmallInteger)
-    time_fired = Column(DATETIME_TYPE)  # no longer used for new rows
-    time_fired_ts = Column(TIMESTAMP_TYPE, index=True)
-    context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
-    context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True)
-    event_data_rel = relationship("EventData")
+    event_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_type: Mapped[str | None] = mapped_column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
+    event_data: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql")
+    )
+    origin: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_ORIGIN)
+    )  # no longer used for new rows
+    origin_idx: Mapped[int | None] = mapped_column(SmallInteger)
+    time_fired: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    time_fired_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True)
+    context_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True
+    )
+    context_user_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    context_parent_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    data_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("event_data.data_id"), index=True
+    )
+    event_data_rel: Mapped[EventData | None] = relationship("EventData")
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -180,12 +198,15 @@ class Events(Base):  # type: ignore[misc,valid-type]
         )
 
     @property
-    def _time_fired_isotime(self) -> str:
+    def _time_fired_isotime(self) -> str | None:
         """Return time_fired as an isotime string."""
+        date_time: datetime | None
         if self.time_fired_ts is not None:
             date_time = dt_util.utc_from_timestamp(self.time_fired_ts)
         else:
             date_time = process_timestamp(self.time_fired)
+        if date_time is None:
+            return None
         return date_time.isoformat(sep=" ", timespec="seconds")
 
     @staticmethod
@@ -211,12 +232,12 @@ class Events(Base):  # type: ignore[misc,valid-type]
         )
         try:
             return Event(
-                self.event_type,
+                self.event_type or "",
                 json_loads_object(self.event_data) if self.event_data else {},
                 EventOrigin(self.origin)
                 if self.origin
-                else EVENT_ORIGIN_ORDER[self.origin_idx],
-                dt_util.utc_from_timestamp(self.time_fired_ts),
+                else EVENT_ORIGIN_ORDER[self.origin_idx or 0],
+                dt_util.utc_from_timestamp(self.time_fired_ts or 0),
                 context=context,
             )
         except JSON_DECODE_EXCEPTIONS:
@@ -225,17 +246,19 @@ class Events(Base):  # type: ignore[misc,valid-type]
             return None
 
 
-class EventData(Base):  # type: ignore[misc,valid-type]
+class EventData(Base):
     """Event data history."""
 
     __table_args__ = (
         {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
     )
     __tablename__ = TABLE_EVENT_DATA
-    data_id = Column(Integer, Identity(), primary_key=True)
-    hash = Column(BigInteger, index=True)
+    data_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    hash: Mapped[int | None] = mapped_column(BigInteger, index=True)
     # Note that this is not named attributes to avoid confusion with the states table
-    shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
+    shared_data: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql")
+    )
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -260,15 +283,18 @@ class EventData(Base):  # type: ignore[misc,valid-type]
         return cast(int, fnv1a_32(shared_data_bytes))
 
     def to_native(self) -> dict[str, Any]:
-        """Convert to an HA state object."""
+        """Convert to an event data dictionary."""
+        shared_data = self.shared_data
+        if shared_data is None:
+            return {}
         try:
-            return cast(dict[str, Any], json_loads(self.shared_data))
+            return cast(dict[str, Any], json_loads(shared_data))
         except JSON_DECODE_EXCEPTIONS:
             _LOGGER.exception("Error converting row to event data: %s", self)
             return {}
 
 
-class States(Base):  # type: ignore[misc,valid-type]
+class States(Base):
     """State change history."""
 
     __table_args__ = (
@@ -278,29 +304,45 @@
         {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
     )
     __tablename__ = TABLE_STATES
-    state_id = Column(Integer, Identity(), primary_key=True)
-    entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
-    state = Column(String(MAX_LENGTH_STATE_STATE))
-    attributes = Column(
+    state_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    entity_id: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_ENTITY_ID))
+    state: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_STATE))
+    attributes: Mapped[str | None] = mapped_column(
         Text().with_variant(mysql.LONGTEXT, "mysql")
     )  # no longer used for new rows
-    event_id = Column(  # no longer used for new rows
+    event_id: Mapped[int | None] = mapped_column(  # no longer used for new rows
         Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
     )
-    last_changed = Column(DATETIME_TYPE)  # no longer used for new rows
-    last_changed_ts = Column(TIMESTAMP_TYPE)
-    last_updated = Column(DATETIME_TYPE)  # no longer used for new rows
-    last_updated_ts = Column(TIMESTAMP_TYPE, default=time.time, index=True)
-    old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
-    attributes_id = Column(
+    last_changed: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    last_changed_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE)
+    last_updated: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE
+    )  # no longer used for new rows
+    last_updated_ts: Mapped[float | None] = mapped_column(
+        TIMESTAMP_TYPE, default=time.time, index=True
+    )
+    old_state_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("states.state_id"), index=True
+    )
+    attributes_id: Mapped[int | None] = mapped_column(
         Integer, ForeignKey("state_attributes.attributes_id"), index=True
     )
-    context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
-    context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
-    origin_idx = Column(SmallInteger)  # 0 is local, 1 is remote
-    old_state = relationship("States", remote_side=[state_id])
-    state_attributes = relationship("StateAttributes")
+    context_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True
+    )
+    context_user_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    context_parent_id: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_CONTEXT_ID)
+    )
+    origin_idx: Mapped[int | None] = mapped_column(
+        SmallInteger
+    )  # 0 is local, 1 is remote
+    old_state: Mapped[States | None] = relationship("States", remote_side=[state_id])
+    state_attributes: Mapped[StateAttributes | None] = relationship("StateAttributes")
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -312,12 +354,15 @@ class States(Base):  # type: ignore[misc,valid-type]
         )
 
     @property
-    def _last_updated_isotime(self) -> str:
+    def _last_updated_isotime(self) -> str | None:
         """Return last_updated as an isotime string."""
+        date_time: datetime | None
         if self.last_updated_ts is not None:
             date_time = dt_util.utc_from_timestamp(self.last_updated_ts)
         else:
             date_time = process_timestamp(self.last_updated)
+        if date_time is None:
+            return None
         return date_time.isoformat(sep=" ", timespec="seconds")
 
     @staticmethod
@@ -372,8 +417,8 @@ class States(Base):  # type: ignore[misc,valid-type]
             last_updated = dt_util.utc_from_timestamp(self.last_updated_ts or 0)
             last_changed = dt_util.utc_from_timestamp(self.last_changed_ts or 0)
         return State(
-            self.entity_id,
-            self.state,
+            self.entity_id or "",
+            self.state,  # type: ignore[arg-type]
             # Join the state_attributes table on attributes_id to get the attributes
             # for newer states
             attrs,
@@ -384,17 +429,19 @@ class States(Base):  # type: ignore[misc,valid-type]
         )
 
 
-class StateAttributes(Base):  # type: ignore[misc,valid-type]
+class StateAttributes(Base):
     """State attribute change history."""
 
     __table_args__ = (
         {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
     )
     __tablename__ = TABLE_STATE_ATTRIBUTES
-    attributes_id = Column(Integer, Identity(), primary_key=True)
-    hash = Column(BigInteger, index=True)
+    attributes_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    hash: Mapped[int | None] = mapped_column(BigInteger, index=True)
     # Note that this is not named attributes to avoid confusion with the states table
-    shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
+    shared_attrs: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql")
+    )
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -439,9 +486,12 @@ class StateAttributes(Base):  # type: ignore[misc,valid-type]
         return cast(int, fnv1a_32(shared_attrs_bytes))
 
     def to_native(self) -> dict[str, Any]:
-        """Convert to an HA state object."""
+        """Convert to a state attributes dictionary."""
+        shared_attrs = self.shared_attrs
+        if shared_attrs is None:
+            return {}
         try:
-            return cast(dict[str, Any], json_loads(self.shared_attrs))
+            return cast(dict[str, Any], json_loads(shared_attrs))
         except JSON_DECODE_EXCEPTIONS:
             # When json_loads fails
             _LOGGER.exception("Error converting row to state attributes: %s", self)
@@ -451,25 +501,22 @@
 class StatisticsBase:
     """Statistics base class."""
 
-    id = Column(Integer, Identity(), primary_key=True)
-    created = Column(DATETIME_TYPE, default=dt_util.utcnow)
+    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
+    metadata_id: Mapped[int | None] = mapped_column(
+        Integer,
+        ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
+        index=True,
+    )
+    start: Mapped[datetime | None] = mapped_column(DATETIME_TYPE, index=True)
+    mean: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    min: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    max: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    last_reset: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
+    state: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
+    sum: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
 
-    @declared_attr  # type: ignore[misc]
-    def metadata_id(self) -> Column:
-        """Define the metadata_id column for sub classes."""
-        return Column(
-            Integer,
-            ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
-            index=True,
-        )
-
-    start = Column(DATETIME_TYPE, index=True)
-    mean = Column(DOUBLE_TYPE)
-    min = Column(DOUBLE_TYPE)
-    max = Column(DOUBLE_TYPE)
-    last_reset = Column(DATETIME_TYPE)
-    state = Column(DOUBLE_TYPE)
-    sum = Column(DOUBLE_TYPE)
+    duration: timedelta
 
     @classmethod
     def from_stats(cls, metadata_id: int, stats: StatisticData) -> Self:
@@ -480,7 +527,7 @@ class StatisticsBase:
         )
 
 
-class Statistics(Base, StatisticsBase):  # type: ignore[misc,valid-type]
+class Statistics(Base, StatisticsBase):
     """Long term statistics."""
 
     duration = timedelta(hours=1)
@@ -492,7 +539,7 @@ class Statistics(Base, StatisticsBase):  # type: ignore[misc,valid-type]
     __tablename__ = TABLE_STATISTICS
 
 
-class StatisticsShortTerm(Base, StatisticsBase):  # type: ignore[misc,valid-type]
+class StatisticsShortTerm(Base, StatisticsBase):
     """Short term statistics."""
 
     duration = timedelta(minutes=5)
@@ -509,20 +556,22 @@ class StatisticsShortTerm(Base, StatisticsBase):  # type: ignore[misc,valid-type
     __tablename__ = TABLE_STATISTICS_SHORT_TERM
 
 
-class StatisticsMeta(Base):  # type: ignore[misc,valid-type]
+class StatisticsMeta(Base):
     """Statistics meta data."""
 
     __table_args__ = (
         {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
     )
     __tablename__ = TABLE_STATISTICS_META
-    id = Column(Integer, Identity(), primary_key=True)
-    statistic_id = Column(String(255), index=True, unique=True)
-    source = Column(String(32))
-    unit_of_measurement = Column(String(255))
-    has_mean = Column(Boolean)
-    has_sum = Column(Boolean)
-    name = Column(String(255))
+    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    statistic_id: Mapped[str | None] = mapped_column(
+        String(255), index=True, unique=True
+    )
+    source: Mapped[str | None] = mapped_column(String(32))
+    unit_of_measurement: Mapped[str | None] = mapped_column(String(255))
+    has_mean: Mapped[bool | None] = mapped_column(Boolean)
+    has_sum: Mapped[bool | None] = mapped_column(Boolean)
+    name: Mapped[str | None] = mapped_column(String(255))
 
     @staticmethod
     def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
@@ -530,16 +579,16 @@ class StatisticsMeta(Base):  # type: ignore[misc,valid-type]
         return StatisticsMeta(**meta)
 
 
-class RecorderRuns(Base):  # type: ignore[misc,valid-type]
+class RecorderRuns(Base):
     """Representation of recorder run."""
 
     __table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
     __tablename__ = TABLE_RECORDER_RUNS
-    run_id = Column(Integer, Identity(), primary_key=True)
-    start = Column(DATETIME_TYPE, default=dt_util.utcnow)
-    end = Column(DATETIME_TYPE)
-    closed_incorrect = Column(Boolean, default=False)
-    created = Column(DATETIME_TYPE, default=dt_util.utcnow)
+    run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    start: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
+    end: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
+    closed_incorrect: Mapped[bool] = mapped_column(Boolean, default=False)
+    created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -563,9 +612,9 @@ class RecorderRuns(Base):  # type: ignore[misc,valid-type]
 
         assert session is not None, "RecorderRuns need to be persisted"
 
-        query = session.query(distinct(States.entity_id)).filter(
-            States.last_updated >= self.start
-        )
+        query: RowReturningQuery[tuple[str]] = session.query(distinct(States.entity_id))
+
+        query = query.filter(States.last_updated >= self.start)
         if point_in_time is not None:
             query = query.filter(States.last_updated < point_in_time)
 
@@ -579,13 +628,13 @@ class RecorderRuns(Base):  # type: ignore[misc,valid-type]
         return self
 
 
-class SchemaChanges(Base):  # type: ignore[misc,valid-type]
+class SchemaChanges(Base):
     """Representation of schema version changes."""
 
     __tablename__ = TABLE_SCHEMA_CHANGES
-    change_id = Column(Integer, Identity(), primary_key=True)
-    schema_version = Column(Integer)
-    changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
+    change_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    schema_version: Mapped[int | None] = mapped_column(Integer)
+    changed: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -597,12 +646,12 @@ class SchemaChanges(Base):  # type: ignore[misc,valid-type]
         )
 
 
-class StatisticsRuns(Base):  # type: ignore[misc,valid-type]
+class StatisticsRuns(Base):
     """Representation of statistics run."""
 
     __tablename__ = TABLE_STATISTICS_RUNS
-    run_id = Column(Integer, Identity(), primary_key=True)
-    start = Column(DATETIME_TYPE, index=True)
+    run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    start: Mapped[datetime] = mapped_column(DATETIME_TYPE, index=True)
 
     def __repr__(self) -> str:
         """Return string representation of instance for debugging."""
@@ -626,7 +675,7 @@ OLD_FORMAT_ATTRS_JSON = type_coerce(
     States.attributes.cast(JSON_VARIANT_CAST), JSON(none_as_null=True)
 )
 
-ENTITY_ID_IN_EVENT: Column = EVENT_DATA_JSON["entity_id"]
-OLD_ENTITY_ID_IN_EVENT: Column = OLD_FORMAT_EVENT_DATA_JSON["entity_id"]
-DEVICE_ID_IN_EVENT: Column = EVENT_DATA_JSON["device_id"]
+ENTITY_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["entity_id"]
+OLD_ENTITY_ID_IN_EVENT: ColumnElement = OLD_FORMAT_EVENT_DATA_JSON["entity_id"]
+DEVICE_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["device_id"]
 OLD_STATE = aliased(States, name="old_state")
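The bulk of the schema module above moves from `declarative_base()` plus untyped `Column` attributes to the 2.0 typed declarative API: a `DeclarativeBase` subclass, `Mapped[...]` annotations, and `mapped_column()`. Annotating a column as `Mapped[str | None]` both marks it nullable and gives mypy the correct attribute type, which is why the old `# type: ignore[misc,valid-type]` comments on every model class can be dropped. A condensed sketch of the pattern with an illustrative model (not the recorder schema):

```python
# Sketch of the declarative_base() -> DeclarativeBase migration pattern.
from __future__ import annotations

from datetime import datetime

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Typed replacement for Base = declarative_base()."""


class Run(Base):
    __tablename__ = "runs"
    run_id: Mapped[int] = mapped_column(primary_key=True)
    # Mapped[str | None] makes the column nullable and type-checks
    # attribute access as str | None, replacing bare Column(String(255)).
    name: Mapped[str | None] = mapped_column(String(255))
    # With no mapped_column() arguments needed, the annotation alone is
    # enough; the type map picks DateTime for datetime.
    start: Mapped[datetime | None]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
```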
diff --git a/homeassistant/components/recorder/filters.py b/homeassistant/components/recorder/filters.py
index 48251b6db59..90f7d8c0a06 100644
--- a/homeassistant/components/recorder/filters.py
+++ b/homeassistant/components/recorder/filters.py
@@ -6,7 +6,7 @@ import json
 from typing import Any
 
 from sqlalchemy import Column, Text, cast, not_, or_
-from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import ColumnElement
 
 from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
 from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS
@@ -125,7 +125,7 @@ class Filters:
 
     def _generate_filter_for_columns(
         self, columns: Iterable[Column], encoder: Callable[[Any], Any]
-    ) -> ClauseList:
+    ) -> ColumnElement | None:
        """Generate a filter from pre-computed sets and pattern lists.

        This must match exactly how homeassistant.helpers.entityfilter works.
@@ -174,6 +174,8 @@ class Filters:
         if self.included_domains or self.included_entity_globs:
             return or_(
                 i_entities,
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9190
+                # pylint: disable-next=invalid-unary-operand-type
                 (~e_entities & (i_entity_globs | (~e_entity_globs & i_domains))),
             ).self_group()
 
@@ -184,23 +186,24 @@ class Filters:
         #  - Otherwise, entity matches domain exclude: exclude
         #  - Otherwise: include
         if self.excluded_domains or self.excluded_entity_globs:
-            return (not_(or_(*excludes)) | i_entities).self_group()
+            return (not_(or_(*excludes)) | i_entities).self_group()  # type: ignore[no-any-return, no-untyped-call]
 
         # Case 6 - No Domain and/or glob includes or excludes
         #  - Entity listed in entities include: include
         #  - Otherwise: exclude
         return i_entities
 
-    def states_entity_filter(self) -> ClauseList:
+    def states_entity_filter(self) -> ColumnElement | None:
         """Generate the entity filter query."""
 
         def _encoder(data: Any) -> Any:
             """Nothing to encode for states since there is no json."""
             return data
 
-        return self._generate_filter_for_columns((States.entity_id,), _encoder)
+        # The type annotation should be improved so the type ignore can be removed
+        return self._generate_filter_for_columns((States.entity_id,), _encoder)  # type: ignore[arg-type]
 
-    def events_entity_filter(self) -> ClauseList:
+    def events_entity_filter(self) -> ColumnElement:
         """Generate the entity filter query."""
         _encoder = json.dumps
         return or_(
@@ -215,15 +218,16 @@ class Filters:
             & (
                 (OLD_ENTITY_ID_IN_EVENT == JSON_NULL) | OLD_ENTITY_ID_IN_EVENT.is_(None)
             ),
-            self._generate_filter_for_columns(
-                (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder
+            # Needs https://github.com/bdraco/home-assistant/commit/bba91945006a46f3a01870008eb048e4f9cbb1ef
+            self._generate_filter_for_columns(  # type: ignore[union-attr]
+                (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder  # type: ignore[arg-type]
             ).self_group(),
         )
 
 
 def _globs_to_like(
     glob_strs: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
-) -> ClauseList:
+) -> ColumnElement:
     """Translate glob to sql."""
     matchers = [
         (
@@ -240,7 +244,7 @@ def _globs_to_like(
 
 def _entity_matcher(
     entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
-) -> ClauseList:
+) -> ColumnElement:
     matchers = [
         (
             column.is_not(None)
@@ -253,7 +257,7 @@ def _entity_matcher(
 
 def _domain_matcher(
     domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
-) -> ClauseList:
+) -> ColumnElement:
     matchers = [
         (column.is_not(None) & cast(column, Text()).like(encoder(domain_matcher)))
         for domain_matcher in like_domain_matchers(domains)
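The matcher helpers in filters.py all reduce to the same shape: build one boolean expression per column or pattern, OR them together, and return a `ColumnElement` rather than the old `ClauseList`. A rough, self-contained sketch of that shape under assumed table and column names:

```python
# Rough shape of the _domain_matcher-style helpers; names are illustrative.
from collections.abc import Iterable

from sqlalchemy import Column, Integer, MetaData, String, Table, or_
from sqlalchemy.sql.elements import ColumnElement

states = Table(
    "states",
    MetaData(),
    Column("state_id", Integer, primary_key=True),
    Column("entity_id", String),
)


def domain_matcher(domains: Iterable[str]) -> ColumnElement[bool]:
    """OR together one NULL-safe LIKE per domain, e.g. 'light.%'."""
    matchers = [
        states.c.entity_id.is_not(None) & states.c.entity_id.like(f"{domain}.%")
        for domain in domains
    ]
    # self_group() parenthesizes the OR so it composes safely with AND.
    return or_(*matchers).self_group()


print(domain_matcher(["light", "switch"]))
```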
diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py
index 0f3db70f66a..7a5bc80b956 100644
--- a/homeassistant/components/recorder/history.py
+++ b/homeassistant/components/recorder/history.py
@@ -59,87 +59,87 @@ NEED_ATTRIBUTE_DOMAINS = {
 }
 
 
-_BASE_STATES = [
+_BASE_STATES = (
     States.entity_id,
     States.state,
     States.last_changed_ts,
     States.last_updated_ts,
-]
-_BASE_STATES_NO_LAST_CHANGED = [
+)
+_BASE_STATES_NO_LAST_CHANGED = (  # type: ignore[var-annotated]
     States.entity_id,
     States.state,
     literal(value=None).label("last_changed_ts"),
     States.last_updated_ts,
-]
-_QUERY_STATE_NO_ATTR = [
+)
+_QUERY_STATE_NO_ATTR = (
     *_BASE_STATES,
     literal(value=None, type_=Text).label("attributes"),
     literal(value=None, type_=Text).label("shared_attrs"),
-]
-_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = [
+)
+_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = (
     *_BASE_STATES_NO_LAST_CHANGED,
     literal(value=None, type_=Text).label("attributes"),
     literal(value=None, type_=Text).label("shared_attrs"),
-]
-_BASE_STATES_PRE_SCHEMA_31 = [
+)
+_BASE_STATES_PRE_SCHEMA_31 = (
     States.entity_id,
     States.state,
     States.last_changed,
     States.last_updated,
-]
-_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = [
+)
+_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
     States.entity_id,
     States.state,
     literal(value=None, type_=Text).label("last_changed"),
     States.last_updated,
-]
-_QUERY_STATE_NO_ATTR_PRE_SCHEMA_31 = [
+)
+_QUERY_STATE_NO_ATTR_PRE_SCHEMA_31 = (
     *_BASE_STATES_PRE_SCHEMA_31,
     literal(value=None, type_=Text).label("attributes"),
     literal(value=None, type_=Text).label("shared_attrs"),
-]
-_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED_PRE_SCHEMA_31 = [
+)
+_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
     *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
     literal(value=None, type_=Text).label("attributes"),
     literal(value=None, type_=Text).label("shared_attrs"),
-]
+)
 # Remove QUERY_STATES_PRE_SCHEMA_25
 # and the migration_in_progress check
 # once schema 26 is created
-_QUERY_STATES_PRE_SCHEMA_25 = [
+_QUERY_STATES_PRE_SCHEMA_25 = (
     *_BASE_STATES_PRE_SCHEMA_31,
     States.attributes,
     literal(value=None, type_=Text).label("shared_attrs"),
-]
-_QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED = [
+)
+_QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED = (
    *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
    States.attributes,
    literal(value=None, type_=Text).label("shared_attrs"),
-]
-_QUERY_STATES_PRE_SCHEMA_31 = [
+)
+_QUERY_STATES_PRE_SCHEMA_31 = (
     *_BASE_STATES_PRE_SCHEMA_31,
     # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
     States.attributes,
     StateAttributes.shared_attrs,
-]
-_QUERY_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = [
+)
+_QUERY_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31 = (
     *_BASE_STATES_NO_LAST_CHANGED_PRE_SCHEMA_31,
     # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
     States.attributes,
     StateAttributes.shared_attrs,
-]
-_QUERY_STATES = [
+)
+_QUERY_STATES = (
     *_BASE_STATES,
     # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
     States.attributes,
     StateAttributes.shared_attrs,
-]
-_QUERY_STATES_NO_LAST_CHANGED = [
+)
+_QUERY_STATES_NO_LAST_CHANGED = (
     *_BASE_STATES_NO_LAST_CHANGED,
     # Remove States.attributes once all attributes are in StateAttributes.shared_attrs
     States.attributes,
     StateAttributes.shared_attrs,
-]
+)
 
 
 def _schema_version(hass: HomeAssistant) -> int:
@@ -305,7 +305,10 @@ def _significant_states_stmt(
     )
 
     if entity_ids:
-        stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
+        stmt += lambda q: q.filter(
+            # https://github.com/python/mypy/issues/2608
+            States.entity_id.in_(entity_ids)  # type:ignore[arg-type]
+        )
     else:
         stmt += _ignore_domains_filter
     if filters and filters.has_config:
@@ -598,6 +601,8 @@ def _get_states_for_entites_stmt(
         stmt += lambda q: q.where(
             States.state_id
             == (
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 select(func.max(States.state_id).label("max_state_id"))
                 .filter(
                     (States.last_updated_ts >= run_start_ts)
@@ -612,6 +617,8 @@ def _get_states_for_entites_stmt(
         stmt += lambda q: q.where(
             States.state_id
             == (
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 select(func.max(States.state_id).label("max_state_id"))
                 .filter(
                     (States.last_updated >= run_start)
@@ -641,6 +648,8 @@ def _generate_most_recent_states_by_date(
         return (
             select(
                 States.entity_id.label("max_entity_id"),
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 func.max(States.last_updated_ts).label("max_last_updated"),
             )
             .filter(
@@ -653,6 +662,8 @@ def _generate_most_recent_states_by_date(
         return (
             select(
                 States.entity_id.label("max_entity_id"),
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 func.max(States.last_updated).label("max_last_updated"),
            )
            .filter(
@@ -686,6 +697,8 @@ def _get_states_for_all_stmt(
         stmt += lambda q: q.where(
             States.state_id
             == (
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 select(func.max(States.state_id).label("max_state_id"))
                 .join(
                     most_recent_states_by_date,
@@ -703,6 +716,8 @@ def _get_states_for_all_stmt(
         stmt += lambda q: q.where(
             States.state_id
             == (
+                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                # pylint: disable-next=not-callable
                 select(func.max(States.state_id).label("max_state_id"))
                 .join(
                     most_recent_states_by_date,
@@ -876,11 +891,11 @@ def _sorted_states_to_dict(
         _LOGGER.debug("getting %d first datapoints took %fs", len(result), elapsed)
 
     if entity_ids and len(entity_ids) == 1:
-        states_iter: Iterable[tuple[str | Column, Iterator[States]]] = (
+        states_iter: Iterable[tuple[str, Iterator[Row]]] = (
             (entity_ids[0], iter(states)),
         )
     else:
-        states_iter = groupby(states, lambda state: state.entity_id)
+        states_iter = groupby(states, lambda state: state.entity_id)  # type: ignore[no-any-return]
 
     # Append all changes to it
     for ent_id, group in states_iter:
diff --git a/homeassistant/components/recorder/manifest.json b/homeassistant/components/recorder/manifest.json
index c9e05cb17f2..bf3b83f189d 100644
--- a/homeassistant/components/recorder/manifest.json
+++ b/homeassistant/components/recorder/manifest.json
@@ -2,7 +2,7 @@
   "domain": "recorder",
   "name": "Recorder",
   "documentation": "https://www.home-assistant.io/integrations/recorder",
-  "requirements": ["sqlalchemy==1.4.45", "fnvhash==0.1.0"],
+  "requirements": ["sqlalchemy==2.0.2", "fnvhash==0.1.0"],
   "codeowners": ["@home-assistant/core"],
   "quality_scale": "internal",
   "iot_class": "local_push",
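A recurring mechanical change throughout this diff is the `# pylint: disable-next=not-callable` comment in front of `func.max(...)`/`func.min(...)`: with the type stubs bundled in SQLAlchemy 2.0, pylint misreads `func` attributes as non-callable (sqlalchemy issue #9189), so every call site is annotated. A minimal runnable sketch of the pattern with an illustrative table:

```python
# The pattern behind the recurring pylint disables in this diff.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, func, select

runs = Table("runs", MetaData(), Column("run_id", Integer, primary_key=True))
engine = create_engine("sqlite://")
runs.metadata.create_all(engine)

with engine.connect() as conn:
    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
    # pylint: disable-next=not-callable
    max_run_id = conn.scalar(select(func.max(runs.c.run_id)))
    print(max_run_id)  # None on an empty table
```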
diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index 99ac4adcd69..b7df00175e9 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -6,7 +6,7 @@ import contextlib
 from dataclasses import dataclass, replace as dataclass_replace
 from datetime import timedelta
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import sqlalchemy
 from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@@ -96,7 +96,7 @@ def _schema_is_current(current_version: int) -> bool:
 
 
 def validate_db_schema(
-    hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session]
+    hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
 ) -> SchemaValidationStatus | None:
     """Check if the schema is valid.
 
@@ -113,7 +113,7 @@ def validate_db_schema(
     if is_current := _schema_is_current(current_version):
         # We can only check for further errors if the schema is current, because
         # columns may otherwise not exist etc.
-        schema_errors |= statistics_validate_db_schema(hass, engine, session_maker)
+        schema_errors |= statistics_validate_db_schema(hass, instance, session_maker)
 
     valid = is_current and not schema_errors
 
@@ -444,17 +444,17 @@ def _update_states_table_with_foreign_key_options(
     states_key_constraints = Base.metadata.tables[TABLE_STATES].foreign_key_constraints
     old_states_table = Table(  # noqa: F841 pylint: disable=unused-variable
-        TABLE_STATES, MetaData(), *(alter["old_fk"] for alter in alters)
+        TABLE_STATES, MetaData(), *(alter["old_fk"] for alter in alters)  # type: ignore[arg-type]
     )
 
     for alter in alters:
         with session_scope(session=session_maker()) as session:
             try:
                 connection = session.connection()
-                connection.execute(DropConstraint(alter["old_fk"]))
+                connection.execute(DropConstraint(alter["old_fk"]))  # type: ignore[no-untyped-call]
                 for fkc in states_key_constraints:
                     if fkc.column_keys == alter["columns"]:
-                        connection.execute(AddConstraint(fkc))
+                        connection.execute(AddConstraint(fkc))  # type: ignore[no-untyped-call]
             except (InternalError, OperationalError):
                 _LOGGER.exception(
                     "Could not update foreign options in %s table", TABLE_STATES
@@ -484,7 +484,7 @@ def _drop_foreign_key_constraints(
         with session_scope(session=session_maker()) as session:
             try:
                 connection = session.connection()
-                connection.execute(DropConstraint(drop))
+                connection.execute(DropConstraint(drop))  # type: ignore[no-untyped-call]
             except (InternalError, OperationalError):
                 _LOGGER.exception(
                     "Could not drop foreign constraints in %s table on %s",
@@ -630,18 +630,21 @@ def _apply_update(  # noqa: C901
         # Order matters! Statistics and StatisticsShortTerm have a relation with
         # StatisticsMeta, so statistics need to be deleted before meta (or in pair
         # depending on the SQL backend); and meta needs to be created before statistics.
+
+        # We need to cast __table__ to Table, explanation in
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9130
         Base.metadata.drop_all(
             bind=engine,
             tables=[
-                StatisticsShortTerm.__table__,
-                Statistics.__table__,
-                StatisticsMeta.__table__,
+                cast(Table, StatisticsShortTerm.__table__),
+                cast(Table, Statistics.__table__),
+                cast(Table, StatisticsMeta.__table__),
             ],
         )
 
-        StatisticsMeta.__table__.create(engine)
-        StatisticsShortTerm.__table__.create(engine)
-        Statistics.__table__.create(engine)
+        cast(Table, StatisticsMeta.__table__).create(engine)
+        cast(Table, StatisticsShortTerm.__table__).create(engine)
+        cast(Table, Statistics.__table__).create(engine)
     elif new_version == 19:
         # This adds the statistic runs table, insert a fake run to prevent duplicating
         # statistics.
@@ -694,20 +697,22 @@ def _apply_update(  # noqa: C901
         # so statistics need to be deleted before meta (or in pair depending
         # on the SQL backend); and meta needs to be created before statistics.
         if engine.dialect.name == "oracle":
+            # We need to cast __table__ to Table, explanation in
+            # https://github.com/sqlalchemy/sqlalchemy/issues/9130
             Base.metadata.drop_all(
                 bind=engine,
                 tables=[
-                    StatisticsShortTerm.__table__,
-                    Statistics.__table__,
-                    StatisticsMeta.__table__,
-                    StatisticsRuns.__table__,
+                    cast(Table, StatisticsShortTerm.__table__),
+                    cast(Table, Statistics.__table__),
+                    cast(Table, StatisticsMeta.__table__),
+                    cast(Table, StatisticsRuns.__table__),
                 ],
             )
 
-            StatisticsRuns.__table__.create(engine)
-            StatisticsMeta.__table__.create(engine)
-            StatisticsShortTerm.__table__.create(engine)
-            Statistics.__table__.create(engine)
+            cast(Table, StatisticsRuns.__table__).create(engine)
+            cast(Table, StatisticsMeta.__table__).create(engine)
+            cast(Table, StatisticsShortTerm.__table__).create(engine)
+            cast(Table, Statistics.__table__).create(engine)
 
         # Block 5-minute statistics for one hour from the last run, or it will overlap
         # with existing hourly statistics. Don't block on a database with no existing
@@ -715,6 +720,8 @@ def _apply_update(  # noqa: C901
         with session_scope(session=session_maker()) as session:
             if session.query(Statistics.id).count() and (
                 last_run_string := session.query(
+                    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                    # pylint: disable-next=not-callable
                     func.max(StatisticsRuns.start)
                 ).scalar()
             ):
@@ -996,7 +1003,7 @@ def _migrate_columns_to_timestamp(
             )
         )
         result = None
-        while result is None or result.rowcount > 0:
+        while result is None or result.rowcount > 0:  # type: ignore[unreachable]
             with session_scope(session=session_maker()) as session:
                 result = session.connection().execute(
                     text(
@@ -1027,7 +1034,7 @@ def _migrate_columns_to_timestamp(
             )
         )
         result = None
-        while result is None or result.rowcount > 0:
+        while result is None or result.rowcount > 0:  # type: ignore[unreachable]
             with session_scope(session=session_maker()) as session:
                 result = session.connection().execute(
                     text(
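The `cast(Table, ...)` churn in migration.py exists because the 2.0 stubs type `Model.__table__` as `FromClause`, while `MetaData.drop_all(tables=...)` and `Table.create()` expect `Table` (sqlalchemy issue #9130, as the added comments note). A condensed runnable sketch with an illustrative model:

```python
# Sketch of the cast(Table, Model.__table__) workaround used above.
from typing import cast

from sqlalchemy import Table, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Statistics(Base):
    __tablename__ = "statistics"
    id: Mapped[int] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
# __table__ is typed FromClause; cast() narrows it for the Table-only APIs.
statistics_table = cast(Table, Statistics.__table__)
statistics_table.create(engine)
Base.metadata.drop_all(bind=engine, tables=[statistics_table])
```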
diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py
index fa380e5a7e2..eb6413fa786 100644
--- a/homeassistant/components/recorder/purge.py
+++ b/homeassistant/components/recorder/purge.py
@@ -8,6 +8,7 @@ from itertools import islice, zip_longest
 import logging
 from typing import TYPE_CHECKING, Any
 
+from sqlalchemy.engine.row import Row
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import distinct
 
@@ -616,9 +617,9 @@ def _purge_filtered_states(
     database_engine: DatabaseEngine,
 ) -> None:
     """Remove filtered states and linked events."""
-    state_ids: list[int]
-    attributes_ids: list[int]
-    event_ids: list[int]
+    state_ids: tuple[int, ...]
+    attributes_ids: tuple[int, ...]
+    event_ids: tuple[int, ...]
     state_ids, attributes_ids, event_ids = zip(
         *(
             session.query(States.state_id, States.attributes_id, States.event_id)
@@ -627,12 +628,12 @@ def _purge_filtered_states(
             .all()
         )
     )
-    event_ids = [id_ for id_ in event_ids if id_ is not None]
+    filtered_event_ids = [id_ for id_ in event_ids if id_ is not None]
     _LOGGER.debug(
         "Selected %s state_ids to remove that should be filtered", len(state_ids)
     )
     _purge_state_ids(instance, session, set(state_ids))
-    _purge_event_ids(session, event_ids)
+    _purge_event_ids(session, filtered_event_ids)
     unused_attribute_ids_set = _select_unused_attributes_ids(
         session, {id_ for id_ in attributes_ids if id_ is not None}, database_engine
     )
@@ -656,7 +657,7 @@ def _purge_filtered_events(
     _LOGGER.debug(
         "Selected %s event_ids to remove that should be filtered", len(event_ids)
     )
-    states: list[States] = (
+    states: list[Row[tuple[int]]] = (
         session.query(States.state_id).filter(States.event_id.in_(event_ids)).all()
     )
     state_ids: set[int] = {state.state_id for state in states}
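The annotation change in `_purge_filtered_states` reflects how Python itself behaves: `zip(*rows)` transposes a list of row tuples into per-column tuples, never lists. A standalone sketch of why the annotations moved to `tuple[int, ...]`:

```python
# Three fake (state_id, attributes_id, event_id) rows.
rows = [(1, 10, 100), (2, 20, None), (3, 30, 300)]

state_ids: tuple[int, ...]
attributes_ids: tuple[int, ...]
event_ids: tuple[int | None, ...]
# zip(*rows) transposes rows into columns, producing tuples.
state_ids, attributes_ids, event_ids = zip(*rows)
assert state_ids == (1, 2, 3)

# None values go into a new name, so the tuple annotation stays honest
# instead of reassigning event_ids to a list.
filtered_event_ids = [id_ for id_ in event_ids if id_ is not None]
assert filtered_event_ids == [100, 300]
```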
diff --git a/homeassistant/components/recorder/queries.py b/homeassistant/components/recorder/queries.py
index 29bac70eef6..c20393c69ae 100644
--- a/homeassistant/components/recorder/queries.py
+++ b/homeassistant/components/recorder/queries.py
@@ -42,6 +42,8 @@ def find_shared_data_id(attr_hash: int, shared_data: str) -> StatementLambdaElem
 
 def _state_attrs_exist(attr: int | None) -> Select:
     """Check if a state attributes id exists in the states table."""
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     return select(func.min(States.attributes_id)).where(States.attributes_id == attr)
 
@@ -279,6 +281,8 @@ def data_ids_exist_in_events_with_fast_in_distinct(
 
 def _event_data_id_exist(data_id: int | None) -> Select:
     """Check if a event data id exists in the events table."""
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     return select(func.min(Events.data_id)).where(Events.data_id == data_id)
 
@@ -620,6 +624,8 @@ def find_statistics_runs_to_purge(
 
 def find_latest_statistics_runs_run_id() -> StatementLambdaElement:
     """Find the latest statistics_runs run_id."""
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     return lambda_stmt(lambda: select(func.max(StatisticsRuns.run_id)))
 
@@ -639,4 +645,6 @@ def find_legacy_event_state_and_attributes_and_data_ids_to_purge(
 
 def find_legacy_row() -> StatementLambdaElement:
     """Check if there are still states in the table with an event_id."""
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     return lambda_stmt(lambda: select(func.max(States.event_id)))
diff --git a/homeassistant/components/recorder/run_history.py b/homeassistant/components/recorder/run_history.py
index 02b2df066bd..2b1a65b4e99 100644
--- a/homeassistant/components/recorder/run_history.py
+++ b/homeassistant/components/recorder/run_history.py
@@ -122,7 +122,8 @@ class RunHistory:
         for run in session.query(RecorderRuns).order_by(RecorderRuns.start.asc()).all():
             session.expunge(run)
             if run_dt := process_timestamp(run.start):
-                timestamp = run_dt.timestamp()
+                # Not sure if this is correct or if the runs_by_timestamp annotation should be changed
+                timestamp = int(run_dt.timestamp())
                 run_timestamps.append(timestamp)
                 runs_by_timestamp[timestamp] = run
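The repeated `# pylint: disable-next=not-callable` comments work around sqlalchemy#9189, where the 2.0 type stubs make pylint believe `func.max()` and friends are not callable; the calls are fine at runtime. A self-contained sketch of the `lambda_stmt` pattern these query helpers use (the `States` model here is a reduced, hypothetical stand-in):

```python
from typing import Optional

from sqlalchemy import Integer, func, lambda_stmt, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.sql.lambdas import StatementLambdaElement


class Base(DeclarativeBase):
    pass


class States(Base):  # hypothetical stand-in for the recorder's States model
    __tablename__ = "states"
    state_id: Mapped[int] = mapped_column(Integer, primary_key=True)
    event_id: Mapped[Optional[int]] = mapped_column(Integer)


def find_legacy_row_stmt() -> StatementLambdaElement:
    # lambda_stmt caches the compiled SQL by the lambda's code location,
    # so repeated calls avoid re-compiling the statement.
    # pylint: disable-next=not-callable
    return lambda_stmt(lambda: select(func.max(States.event_id)))
```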
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 244074cc10c..f52be1faa8f 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from collections.abc import Callable, Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping, Sequence
 import contextlib
 import dataclasses
 from datetime import datetime, timedelta
@@ -13,7 +13,7 @@ import logging
 import os
 import re
 from statistics import mean
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from sqlalchemy import bindparam, func, lambda_stmt, select, text
 from sqlalchemy.engine import Engine
@@ -77,7 +77,7 @@ from .util import (
 if TYPE_CHECKING:
     from . import Recorder
 
-QUERY_STATISTICS = [
+QUERY_STATISTICS = (
     Statistics.metadata_id,
     Statistics.start,
     Statistics.mean,
@@ -86,9 +86,9 @@
     Statistics.last_reset,
     Statistics.state,
     Statistics.sum,
-]
+)
 
-QUERY_STATISTICS_SHORT_TERM = [
+QUERY_STATISTICS_SHORT_TERM = (
     StatisticsShortTerm.metadata_id,
     StatisticsShortTerm.start,
     StatisticsShortTerm.mean,
@@ -97,30 +97,34 @@
     StatisticsShortTerm.last_reset,
     StatisticsShortTerm.state,
     StatisticsShortTerm.sum,
-]
+)
 
-QUERY_STATISTICS_SUMMARY_MEAN = [
+QUERY_STATISTICS_SUMMARY_MEAN = (
     StatisticsShortTerm.metadata_id,
     func.avg(StatisticsShortTerm.mean),
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     func.min(StatisticsShortTerm.min),
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+    # pylint: disable-next=not-callable
     func.max(StatisticsShortTerm.max),
-]
+)
 
-QUERY_STATISTICS_SUMMARY_SUM = [
+QUERY_STATISTICS_SUMMARY_SUM = (
     StatisticsShortTerm.metadata_id,
     StatisticsShortTerm.start,
     StatisticsShortTerm.last_reset,
     StatisticsShortTerm.state,
     StatisticsShortTerm.sum,
     func.row_number()
-    .over(
+    .over(  # type: ignore[no-untyped-call]
         partition_by=StatisticsShortTerm.metadata_id,
         order_by=StatisticsShortTerm.start.desc(),
     )
     .label("rownum"),
-]
+)
 
-QUERY_STATISTIC_META = [
+QUERY_STATISTIC_META = (
     StatisticsMeta.id,
     StatisticsMeta.statistic_id,
     StatisticsMeta.source,
@@ -128,7 +132,7 @@
     StatisticsMeta.has_mean,
     StatisticsMeta.has_sum,
     StatisticsMeta.name,
-]
+)
 
 
 STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
@@ -372,7 +376,7 @@ def _update_or_add_metadata(
             statistic_id,
             new_metadata,
         )
-        return meta.id  # type: ignore[no-any-return]
+        return meta.id
 
     metadata_id, old_metadata = old_metadata_dict[statistic_id]
     if (
@@ -401,7 +405,7 @@ def _update_or_add_metadata(
 
 
 def _find_duplicates(
-    session: Session, table: type[Statistics | StatisticsShortTerm]
+    session: Session, table: type[StatisticsBase]
 ) -> tuple[list[int], list[dict]]:
     """Find duplicated statistics."""
     subquery = (
@@ -411,6 +415,8 @@ def _find_duplicates(
             literal_column("1").label("is_duplicate"),
         )
         .group_by(table.metadata_id, table.start)
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+        # pylint: disable-next=not-callable
         .having(func.count() > 1)
         .subquery()
     )
@@ -435,7 +441,7 @@ def _find_duplicates(
     if not duplicates:
         return (duplicate_ids, non_identical_duplicates_as_dict)
 
-    def columns_to_dict(duplicate: type[Statistics | StatisticsShortTerm]) -> dict:
+    def columns_to_dict(duplicate: Row) -> dict:
         """Convert a SQLAlchemy row to dict."""
         dict_ = {}
         for key in duplicate.__mapper__.c.keys():
@@ -466,7 +472,7 @@ def _find_duplicates(
 
 
 def _delete_duplicates_from_table(
-    session: Session, table: type[Statistics | StatisticsShortTerm]
+    session: Session, table: type[StatisticsBase]
 ) -> tuple[int, list[dict]]:
     """Identify and delete duplicated statistics from a specified table."""
     all_non_identical_duplicates: list[dict] = []
@@ -542,6 +548,8 @@ def _find_statistics_meta_duplicates(session: Session) -> list[int]:
             literal_column("1").label("is_duplicate"),
         )
         .group_by(StatisticsMeta.statistic_id)
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+        # pylint: disable-next=not-callable
         .having(func.count() > 1)
         .subquery()
     )
@@ -672,8 +680,8 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
         }
 
     # Insert compiled hourly statistics in the database
-    for metadata_id, stat in summary.items():
-        session.add(Statistics.from_stats(metadata_id, stat))
+    for metadata_id, summary_item in summary.items():
+        session.add(Statistics.from_stats(metadata_id, summary_item))
 
 
 @retryable_database_job("statistics")
@@ -743,7 +751,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
 
 def _adjust_sum_statistics(
     session: Session,
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
     metadata_id: int,
     start_time: datetime,
     adj: float,
@@ -767,7 +775,7 @@ def _adjust_sum_statistics(
 
 def _insert_statistics(
     session: Session,
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
     metadata_id: int,
     statistic: StatisticData,
 ) -> None:
@@ -784,7 +792,7 @@ def _update_statistics(
     session: Session,
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
     stat_id: int,
     statistic: StatisticData,
 ) -> None:
@@ -816,8 +824,11 @@ def _generate_get_metadata_stmt(
 ) -> StatementLambdaElement:
     """Generate a statement to fetch metadata."""
     stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
-    if statistic_ids is not None:
-        stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids))
+    if statistic_ids:
+        stmt += lambda q: q.where(
+            # https://github.com/python/mypy/issues/2608
+            StatisticsMeta.statistic_id.in_(statistic_ids)  # type:ignore[arg-type]
+        )
     if statistic_source is not None:
         stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
     if statistic_type == "mean":
@@ -849,15 +860,15 @@ def get_metadata_with_session(
         return {}
 
     return {
-        meta["statistic_id"]: (
-            meta["id"],
+        meta.statistic_id: (
+            meta.id,
             {
-                "has_mean": meta["has_mean"],
-                "has_sum": meta["has_sum"],
-                "name": meta["name"],
-                "source": meta["source"],
-                "statistic_id": meta["statistic_id"],
-                "unit_of_measurement": meta["unit_of_measurement"],
+                "has_mean": meta.has_mean,
+                "has_sum": meta.has_sum,
+                "name": meta.name,
+                "source": meta.source,
+                "statistic_id": meta.statistic_id,
+                "unit_of_measurement": meta.unit_of_measurement,
             },
         )
         for meta in result
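`get_metadata_with_session` switches from `meta["statistic_id"]` to `meta.statistic_id` because 2.0 `Row` objects are named-tuple-like: attribute access is the primary style, and dict-style access lives on the `row._mapping` view. A minimal illustration:

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    row = conn.execute(text("SELECT 1 AS id, 'sensor.power' AS statistic_id")).one()

# 2.0-style access: rows behave like named tuples, so attributes work...
assert row.statistic_id == "sensor.power"
# ...while mapping-style access moved to the ._mapping view.
assert row._mapping["statistic_id"] == "sensor.power"
```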
@@ -1132,7 +1143,7 @@ def _statistics_during_period_stmt(
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
 ) -> StatementLambdaElement:
     """Prepare a database query for statistics during a given period.
@@ -1140,25 +1151,28 @@ def _statistics_during_period_stmt(
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
-    columns = [table.metadata_id, table.start]
+    columns = select(table.metadata_id, table.start)
     if "last_reset" in types:
-        columns.append(table.last_reset)
+        columns = columns.add_columns(table.last_reset)
     if "max" in types:
-        columns.append(table.max)
+        columns = columns.add_columns(table.max)
     if "mean" in types:
-        columns.append(table.mean)
+        columns = columns.add_columns(table.mean)
     if "min" in types:
-        columns.append(table.min)
+        columns = columns.add_columns(table.min)
     if "state" in types:
-        columns.append(table.state)
+        columns = columns.add_columns(table.state)
     if "sum" in types:
-        columns.append(table.sum)
+        columns = columns.add_columns(table.sum)
 
-    stmt = lambda_stmt(lambda: select(columns).filter(table.start >= start_time))
+    stmt = lambda_stmt(lambda: columns.filter(table.start >= start_time))
     if end_time is not None:
         stmt += lambda q: q.filter(table.start < end_time)
     if metadata_ids:
-        stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
+        stmt += lambda q: q.filter(
+            # https://github.com/python/mypy/issues/2608
+            table.metadata_id.in_(metadata_ids)  # type:ignore[arg-type]
+        )
     stmt += lambda q: q.order_by(table.metadata_id, table.start)
     return stmt
@@ -1168,34 +1182,43 @@ def _get_max_mean_min_statistic_in_sub_period(
     result: dict[str, float],
     start_time: datetime | None,
     end_time: datetime | None,
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
     types: set[Literal["max", "mean", "min", "change"]],
     metadata_id: int,
 ) -> None:
     """Return max, mean and min during the period."""
     # Calculate max, mean, min
-    columns = []
+    columns = select()
     if "max" in types:
-        columns.append(func.max(table.max))
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+        # pylint: disable-next=not-callable
+        columns = columns.add_columns(func.max(table.max))
     if "mean" in types:
-        columns.append(func.avg(table.mean))
-        columns.append(func.count(table.mean))
+        columns = columns.add_columns(func.avg(table.mean))
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+        # pylint: disable-next=not-callable
+        columns = columns.add_columns(func.count(table.mean))
     if "min" in types:
-        columns.append(func.min(table.min))
-    stmt = lambda_stmt(lambda: select(columns).filter(table.metadata_id == metadata_id))
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+        # pylint: disable-next=not-callable
+        columns = columns.add_columns(func.min(table.min))
+    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
     if start_time is not None:
         stmt += lambda q: q.filter(table.start >= start_time)
     if end_time is not None:
         stmt += lambda q: q.filter(table.start < end_time)
-    stats = execute_stmt_lambda_element(session, stmt)
-    if "max" in types and stats and (new_max := stats[0].max) is not None:
+    stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
+    if not stats:
+        return
+    if "max" in types and (new_max := stats[0].max) is not None:
         old_max = result.get("max")
         result["max"] = max(new_max, old_max) if old_max is not None else new_max
-    if "mean" in types and stats and stats[0].avg is not None:
-        duration = stats[0].count * table.duration.total_seconds()
+    if "mean" in types and stats[0].avg is not None:
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9127
+        duration = stats[0].count * table.duration.total_seconds()  # type: ignore[operator]
         result["duration"] = result.get("duration", 0.0) + duration
         result["mean_acc"] = result.get("mean_acc", 0.0) + stats[0].avg * duration
-    if "min" in types and stats and (new_min := stats[0].min) is not None:
+    if "min" in types and (new_min := stats[0].min) is not None:
         old_min = result.get("min")
         result["min"] = min(new_min, old_min) if old_min is not None else new_min
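`select([list_of_columns])` is gone in 2.0, so these helpers now start from a `select(...)` and grow it with `add_columns()`; since `Select` is immutable, each call returns a new statement that must be reassigned. A reduced sketch of the pattern (the `Stats` model is hypothetical):

```python
from typing import Optional

from sqlalchemy import Integer, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Stats(Base):  # hypothetical table for illustration
    __tablename__ = "stats"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    metadata_id: Mapped[int] = mapped_column(Integer)
    mean: Mapped[Optional[float]] = mapped_column()
    state: Mapped[Optional[float]] = mapped_column()


wanted = {"mean"}
# Select is immutable in 2.0: add_columns() returns a new statement,
# so the result has to be reassigned each time.
stmt = select(Stats.metadata_id)
if "mean" in wanted:
    stmt = stmt.add_columns(Stats.mean)
if "state" in wanted:
    stmt = stmt.add_columns(Stats.state)
print(stmt)  # SELECT stats.metadata_id, stats.mean FROM stats
```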
""" - columns = [table.metadata_id, table.start] + columns = select(table.metadata_id, table.start) if "last_reset" in types: - columns.append(table.last_reset) + columns = columns.add_columns(table.last_reset) if "max" in types: - columns.append(table.max) + columns = columns.add_columns(table.max) if "mean" in types: - columns.append(table.mean) + columns = columns.add_columns(table.mean) if "min" in types: - columns.append(table.min) + columns = columns.add_columns(table.min) if "state" in types: - columns.append(table.state) + columns = columns.add_columns(table.state) if "sum" in types: - columns.append(table.sum) + columns = columns.add_columns(table.sum) - stmt = lambda_stmt(lambda: select(columns).filter(table.start >= start_time)) + stmt = lambda_stmt(lambda: columns.filter(table.start >= start_time)) if end_time is not None: stmt += lambda q: q.filter(table.start < end_time) if metadata_ids: - stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids)) + stmt += lambda q: q.filter( + # https://github.com/python/mypy/issues/2608 + table.metadata_id.in_(metadata_ids) # type:ignore[arg-type] + ) stmt += lambda q: q.order_by(table.metadata_id, table.start) return stmt @@ -1168,34 +1182,43 @@ def _get_max_mean_min_statistic_in_sub_period( result: dict[str, float], start_time: datetime | None, end_time: datetime | None, - table: type[Statistics | StatisticsShortTerm], + table: type[StatisticsBase], types: set[Literal["max", "mean", "min", "change"]], metadata_id: int, ) -> None: """Return max, mean and min during the period.""" # Calculate max, mean, min - columns = [] + columns = select() if "max" in types: - columns.append(func.max(table.max)) + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + columns = columns.add_columns(func.max(table.max)) if "mean" in types: - columns.append(func.avg(table.mean)) - columns.append(func.count(table.mean)) + columns = columns.add_columns(func.avg(table.mean)) + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + columns = columns.add_columns(func.count(table.mean)) if "min" in types: - columns.append(func.min(table.min)) - stmt = lambda_stmt(lambda: select(columns).filter(table.metadata_id == metadata_id)) + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + columns = columns.add_columns(func.min(table.min)) + stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) if start_time is not None: stmt += lambda q: q.filter(table.start >= start_time) if end_time is not None: stmt += lambda q: q.filter(table.start < end_time) - stats = execute_stmt_lambda_element(session, stmt) - if "max" in types and stats and (new_max := stats[0].max) is not None: + stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt)) + if not stats: + return + if "max" in types and (new_max := stats[0].max) is not None: old_max = result.get("max") result["max"] = max(new_max, old_max) if old_max is not None else new_max - if "mean" in types and stats and stats[0].avg is not None: - duration = stats[0].count * table.duration.total_seconds() + if "mean" in types and stats[0].avg is not None: + # https://github.com/sqlalchemy/sqlalchemy/issues/9127 + duration = stats[0].count * table.duration.total_seconds() # type: ignore[operator] result["duration"] = result.get("duration", 0.0) + duration result["mean_acc"] = result.get("mean_acc", 0.0) + stats[0].avg * duration - if "min" in types and stats and (new_min := 
@@ -2067,11 +2093,16 @@ def _filter_unique_constraint_integrity_error(
             ignore = True
         if (
             dialect_name == SupportedDialect.POSTGRESQL
+            and err.orig
             and hasattr(err.orig, "pgcode")
             and err.orig.pgcode == "23505"
         ):
             ignore = True
-        if dialect_name == SupportedDialect.MYSQL and hasattr(err.orig, "args"):
+        if (
+            dialect_name == SupportedDialect.MYSQL
+            and err.orig
+            and hasattr(err.orig, "args")
+        ):
             with contextlib.suppress(TypeError):
                 if err.orig.args[0] == 1062:
                     ignore = True
@@ -2095,7 +2126,7 @@ def _import_statistics_with_session(
     session: Session,
     metadata: StatisticMetaData,
     statistics: Iterable[StatisticData],
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
 ) -> bool:
     """Import statistics to the database."""
     old_metadata_dict = get_metadata_with_session(
@@ -2116,7 +2147,7 @@ def import_statistics(
     instance: Recorder,
     metadata: StatisticMetaData,
     statistics: Iterable[StatisticData],
-    table: type[Statistics | StatisticsShortTerm],
+    table: type[StatisticsBase],
 ) -> bool:
     """Process an import_statistics job."""
 
@@ -2174,7 +2205,7 @@ def _change_statistics_unit_for_table(
     convert: Callable[[float | None], float | None],
 ) -> None:
     """Insert statistics in the database."""
-    columns = [table.id, table.mean, table.min, table.max, table.state, table.sum]
+    columns = (table.id, table.mean, table.min, table.max, table.state, table.sum)
     query = session.query(*columns).filter_by(metadata_id=bindparam("metadata_id"))
     rows = execute(query.params(metadata_id=metadata_id))
     for row in rows:
@@ -2215,7 +2246,11 @@ def change_statistics_unit(
         metadata_id = metadata[0]
 
         convert = _get_unit_converter(old_unit, new_unit)
-        for table in (StatisticsShortTerm, Statistics):
+        tables: tuple[type[StatisticsBase], ...] = (
+            Statistics,
+            StatisticsShortTerm,
+        )
+        for table in tables:
             _change_statistics_unit_for_table(session, table, metadata_id, convert)
         session.query(StatisticsMeta).filter(
             StatisticsMeta.statistic_id == statistic_id
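The recurring `type[Statistics | StatisticsShortTerm]` → `type[StatisticsBase]` change trades a spelled-out union for the shared base class, so helpers accept any statistics table, present or future. A toy sketch of the idea with the ORM details stripped out (class bodies are reduced to what the example needs):

```python
class StatisticsBase:
    duration_seconds = 3600.0


class Statistics(StatisticsBase):
    duration_seconds = 3600.0


class StatisticsShortTerm(StatisticsBase):
    duration_seconds = 300.0


def describe(table: type[StatisticsBase]) -> str:
    # One annotation covers both concrete tables and any later subclass,
    # instead of repeating a union of every subclass at each call site.
    return f"{table.__name__}: {table.duration_seconds}s buckets"


for table in (Statistics, StatisticsShortTerm):
    print(describe(table))
```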
diff --git a/homeassistant/components/recorder/system_health/mysql.py b/homeassistant/components/recorder/system_health/mysql.py
index 747a806c227..1ade699eaf1 100644
--- a/homeassistant/components/recorder/system_health/mysql.py
+++ b/homeassistant/components/recorder/system_health/mysql.py
@@ -14,7 +14,7 @@ def db_size_bytes(session: Session, database_name: str) -> float | None:
             "TABLE_SCHEMA=:database_name"
         ),
         {"database_name": database_name},
-    ).first()[0]
+    ).scalar()
 
     if size is None:
         return None
diff --git a/homeassistant/components/recorder/system_health/postgresql.py b/homeassistant/components/recorder/system_health/postgresql.py
index 3e0667b1f4f..aa9197a8e85 100644
--- a/homeassistant/components/recorder/system_health/postgresql.py
+++ b/homeassistant/components/recorder/system_health/postgresql.py
@@ -5,11 +5,14 @@ from sqlalchemy import text
 from sqlalchemy.orm.session import Session
 
 
-def db_size_bytes(session: Session, database_name: str) -> float:
+def db_size_bytes(session: Session, database_name: str) -> float | None:
     """Get the mysql database size."""
-    return float(
-        session.execute(
-            text("select pg_database_size(:database_name);"),
-            {"database_name": database_name},
-        ).first()[0]
-    )
+    size = session.execute(
+        text("select pg_database_size(:database_name);"),
+        {"database_name": database_name},
+    ).scalar()
+
+    if not size:
+        return None
+
+    return float(size)
diff --git a/homeassistant/components/recorder/system_health/sqlite.py b/homeassistant/components/recorder/system_health/sqlite.py
index 5a5901d2cb3..01c601aa9e9 100644
--- a/homeassistant/components/recorder/system_health/sqlite.py
+++ b/homeassistant/components/recorder/system_health/sqlite.py
@@ -5,13 +5,16 @@ from sqlalchemy import text
 from sqlalchemy.orm.session import Session
 
 
-def db_size_bytes(session: Session, database_name: str) -> float:
+def db_size_bytes(session: Session, database_name: str) -> float | None:
     """Get the mysql database size."""
-    return float(
-        session.execute(
-            text(
-                "SELECT page_count * page_size as size "
-                "FROM pragma_page_count(), pragma_page_size();"
-            )
-        ).first()[0]
-    )
+    size = session.execute(
+        text(
+            "SELECT page_count * page_size as size "
+            "FROM pragma_page_count(), pragma_page_size();"
+        )
+    ).scalar()
+
+    if not size:
+        return None
+
+    return float(size)
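`Result.scalar()` returns the first column of the first row, or `None` when the query matched nothing, which is why the size helpers can drop the `.first()[0]` indexing that would raise `TypeError` on an empty result. A minimal illustration:

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    # One row: scalar() unwraps the single value.
    assert conn.execute(text("SELECT 42")).scalar() == 42
    # No rows: first() returns None, so first()[0] would raise TypeError,
    # while scalar() simply returns None.
    assert conn.execute(text("SELECT 42 WHERE 1 = 0")).scalar() is None
```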
diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py
index 90c096e75e2..bd5bcf5f20f 100644
--- a/homeassistant/components/recorder/util.py
+++ b/homeassistant/components/recorder/util.py
@@ -1,7 +1,7 @@
 """SQLAlchemy util functions."""
 from __future__ import annotations
 
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
 from datetime import date, datetime, timedelta
 import functools
@@ -17,8 +17,7 @@ from awesomeversion import (
 )
 import ciso8601
 from sqlalchemy import text
-from sqlalchemy.engine.cursor import CursorFetchStrategy
-from sqlalchemy.engine.row import Row
+from sqlalchemy.engine import Result, Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -45,6 +44,8 @@ from .models import (
 )
 
 if TYPE_CHECKING:
+    from sqlite3.dbapi2 import Cursor as SQLiteCursor
+
     from . import Recorder
 
 _RecorderT = TypeVar("_RecorderT", bound="Recorder")
@@ -202,8 +203,8 @@ def execute_stmt_lambda_element(
     stmt: StatementLambdaElement,
     start_time: datetime | None = None,
     end_time: datetime | None = None,
-    yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
-) -> list[Row]:
+    yield_per: int = DEFAULT_YIELD_STATES_ROWS,
+) -> Sequence[Row] | Result:
     """Execute a StatementLambdaElement.
 
     If the time window passed is greater than one day
@@ -220,8 +221,8 @@ def execute_stmt_lambda_element(
     for tryno in range(RETRIES):
         try:
             if use_all:
-                return executed.all()  # type: ignore[no-any-return]
-            return executed.yield_per(yield_per)  # type: ignore[no-any-return]
+                return executed.all()
+            return executed.yield_per(yield_per)
         except SQLAlchemyError as err:
             _LOGGER.error("Error executing query: %s", err)
             if tryno == RETRIES - 1:
@@ -252,7 +253,7 @@ def dburl_to_path(dburl: str) -> str:
     return dburl.removeprefix(SQLITE_URL_PREFIX)
 
 
-def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool:
+def last_run_was_recently_clean(cursor: SQLiteCursor) -> bool:
     """Verify the last recorder run was recently clean."""
 
     cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
@@ -273,7 +274,7 @@ def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool:
     return True
 
 
-def basic_sanity_check(cursor: CursorFetchStrategy) -> bool:
+def basic_sanity_check(cursor: SQLiteCursor) -> bool:
     """Check tables to make sure select does not fail."""
 
     for table in TABLES_TO_CHECK:
@@ -300,7 +301,7 @@ def validate_sqlite_database(dbpath: str) -> bool:
     return True
 
 
-def run_checks_on_open_db(dbpath: str, cursor: CursorFetchStrategy) -> None:
+def run_checks_on_open_db(dbpath: str, cursor: SQLiteCursor) -> None:
     """Run checks that will generate a sqlite3 exception if there is corruption."""
     sanity_check_passed = basic_sanity_check(cursor)
     last_run_was_clean = last_run_was_recently_clean(cursor)
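Typing the corruption checks against `sqlite3.dbapi2.Cursor` instead of SQLAlchemy's internal `CursorFetchStrategy` matches what these functions actually receive: a raw DBAPI cursor. A standalone sketch under that assumption (the helper and table layout are hypothetical, with the table name borrowed from the recorder schema):

```python
import sqlite3
from sqlite3.dbapi2 import Cursor as SQLiteCursor


def table_exists(cursor: SQLiteCursor, table: str) -> bool:
    # Raw DBAPI access: the cursor type comes from the sqlite3 stdlib
    # module, not from SQLAlchemy internals.
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (table,)
    )
    return cursor.fetchone() is not None


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute('CREATE TABLE recorder_runs (start TEXT, "end" TEXT);')
assert table_exists(cur, "recorder_runs")
```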
diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py
index f33aad68c3e..a6b1afe4049 100644
--- a/homeassistant/components/sql/config_flow.py
+++ b/homeassistant/components/sql/config_flow.py
@@ -7,7 +7,7 @@ from typing import Any
 import sqlalchemy
 from sqlalchemy.engine import Result
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.orm import Session, scoped_session, sessionmaker
 import voluptuous as vol
 
 from homeassistant import config_entries
@@ -47,7 +47,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
 
     engine = sqlalchemy.create_engine(db_url, future=True)
     sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
-    sess: scoped_session = sessmaker()
+    sess: Session = sessmaker()
 
     try:
         result: Result = sess.execute(sqlalchemy.text(query))
diff --git a/homeassistant/components/sql/manifest.json b/homeassistant/components/sql/manifest.json
index db3c20b2fc3..d81528b08e4 100644
--- a/homeassistant/components/sql/manifest.json
+++ b/homeassistant/components/sql/manifest.json
@@ -2,7 +2,7 @@
   "domain": "sql",
   "name": "SQL",
   "documentation": "https://www.home-assistant.io/integrations/sql",
-  "requirements": ["sqlalchemy==1.4.45"],
+  "requirements": ["sqlalchemy==2.0.2"],
   "codeowners": ["@dgomes", "@gjohansson-ST"],
   "config_flow": true,
   "iot_class": "local_polling"
diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py
index 4469c4c8057..5d51087a9dd 100644
--- a/homeassistant/components/sql/sensor.py
+++ b/homeassistant/components/sql/sensor.py
@@ -8,7 +8,7 @@ import logging
 import sqlalchemy
 from sqlalchemy.engine import Result
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.orm import Session, scoped_session, sessionmaker
 
 from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
 from homeassistant.components.sensor import SensorEntity
@@ -125,14 +125,14 @@ async def async_setup_sensor(
     if not db_url:
         db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
 
-    sess: scoped_session | None = None
+    sess: Session | None = None
     try:
         engine = sqlalchemy.create_engine(db_url, future=True)
         sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
 
         # Run a dummy query just to test the db_url
         sess = sessmaker()
-        sess.execute("SELECT 1;")
+        sess.execute(sqlalchemy.text("SELECT 1;"))
 
     except SQLAlchemyError as err:
         _LOGGER.error(
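Two changes above come directly from 2.0's stricter API: calling a `sessionmaker`/`scoped_session` factory is typed as returning a plain `Session`, and `Session.execute()` no longer accepts a bare SQL string. A minimal illustration of both:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, scoped_session, sessionmaker

engine = create_engine("sqlite://", future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))

# sessmaker() yields a plain Session under the 2.0 typing,
# hence the `sess: Session` annotations in the diff.
sess: Session = sessmaker()
try:
    # sess.execute("SELECT 1;") now raises ArgumentError;
    # raw SQL has to be wrapped in text().
    assert sess.execute(text("SELECT 1;")).scalar() == 1
finally:
    sess.close()
```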
["sqlalchemy==2.0.2"], "codeowners": ["@dgomes", "@gjohansson-ST"], "config_flow": true, "iot_class": "local_polling" diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 4469c4c8057..5d51087a9dd 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -8,7 +8,7 @@ import logging import sqlalchemy from sqlalchemy.engine import Result from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL from homeassistant.components.sensor import SensorEntity @@ -125,14 +125,14 @@ async def async_setup_sensor( if not db_url: db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) - sess: scoped_session | None = None + sess: Session | None = None try: engine = sqlalchemy.create_engine(db_url, future=True) sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) # Run a dummy query just to test the db_url sess = sessmaker() - sess.execute("SELECT 1;") + sess.execute(sqlalchemy.text("SELECT 1;")) except SQLAlchemyError as err: _LOGGER.error( diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 3dad786bf2f..953049da035 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -42,7 +42,7 @@ pyudev==0.23.2 pyyaml==6.0 requests==2.28.1 scapy==2.5.0 -sqlalchemy==1.4.45 +sqlalchemy==2.0.2 typing-extensions>=4.4.0,<5.0 voluptuous-serialize==2.5.0 voluptuous==0.13.1 diff --git a/requirements_all.txt b/requirements_all.txt index 124eea4e952..10f9d0d27ed 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2392,7 +2392,7 @@ spotipy==2.22.1 # homeassistant.components.recorder # homeassistant.components.sql -sqlalchemy==1.4.45 +sqlalchemy==2.0.2 # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 9c605ab7e22..46cfe8cca3d 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1689,7 +1689,7 @@ spotipy==2.22.1 # homeassistant.components.recorder # homeassistant.components.sql -sqlalchemy==1.4.45 +sqlalchemy==2.0.2 # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/tests/components/history/test_init.py b/tests/components/history/test_init.py index 22f10d1893c..83901bec7ec 100644 --- a/tests/components/history/test_init.py +++ b/tests/components/history/test_init.py @@ -782,9 +782,8 @@ async def test_fetch_period_api_with_entity_glob_exclude( assert response.status == HTTPStatus.OK response_json = await response.json() assert len(response_json) == 3 - assert response_json[0][0]["entity_id"] == "binary_sensor.sensor" - assert response_json[1][0]["entity_id"] == "light.cow" - assert response_json[2][0]["entity_id"] == "light.match" + entities = {state[0]["entity_id"] for state in response_json} + assert entities == {"binary_sensor.sensor", "light.cow", "light.match"} async def test_fetch_period_api_with_entity_glob_include_and_exclude( @@ -824,10 +823,13 @@ async def test_fetch_period_api_with_entity_glob_include_and_exclude( assert response.status == HTTPStatus.OK response_json = await response.json() assert len(response_json) == 4 - assert response_json[0][0]["entity_id"] == "light.many_state_changes" - assert response_json[1][0]["entity_id"] == "light.match" - assert response_json[2][0]["entity_id"] 
== "media_player.test" - assert response_json[3][0]["entity_id"] == "switch.match" + entities = {state[0]["entity_id"] for state in response_json} + assert entities == { + "light.many_state_changes", + "light.match", + "media_player.test", + "switch.match", + } async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client): diff --git a/tests/components/history/test_init_db_schema_30.py b/tests/components/history/test_init_db_schema_30.py index 47340abfcfd..344cf436792 100644 --- a/tests/components/history/test_init_db_schema_30.py +++ b/tests/components/history/test_init_db_schema_30.py @@ -822,9 +822,8 @@ async def test_fetch_period_api_with_entity_glob_exclude( assert response.status == HTTPStatus.OK response_json = await response.json() assert len(response_json) == 3 - assert response_json[0][0]["entity_id"] == "binary_sensor.sensor" - assert response_json[1][0]["entity_id"] == "light.cow" - assert response_json[2][0]["entity_id"] == "light.match" + entities = {state[0]["entity_id"] for state in response_json} + assert entities == {"binary_sensor.sensor", "light.cow", "light.match"} async def test_fetch_period_api_with_entity_glob_include_and_exclude( @@ -864,10 +863,13 @@ async def test_fetch_period_api_with_entity_glob_include_and_exclude( assert response.status == HTTPStatus.OK response_json = await response.json() assert len(response_json) == 4 - assert response_json[0][0]["entity_id"] == "light.many_state_changes" - assert response_json[1][0]["entity_id"] == "light.match" - assert response_json[2][0]["entity_id"] == "media_player.test" - assert response_json[3][0]["entity_id"] == "switch.match" + entities = {state[0]["entity_id"] for state in response_json} + assert entities == { + "light.many_state_changes", + "light.match", + "media_player.test", + "switch.match", + } async def test_entity_ids_limit_via_api(recorder_mock, hass, hass_client): diff --git a/tests/components/recorder/db_schema_25.py b/tests/components/recorder/db_schema_25.py index 43aa245a761..7f276d42df8 100644 --- a/tests/components/recorder/db_schema_25.py +++ b/tests/components/recorder/db_schema_25.py @@ -23,7 +23,6 @@ from sqlalchemy import ( ) from sqlalchemy.dialects import mysql, oracle, postgresql from sqlalchemy.engine.row import Row -from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import declarative_base, relationship from sqlalchemy.orm.session import Session @@ -322,16 +321,11 @@ class StatisticsBase: id = Column(Integer, Identity(), primary_key=True) created = Column(DATETIME_TYPE, default=dt_util.utcnow) - - @declared_attr # type: ignore[misc] - def metadata_id(self) -> Column: - """Define the metadata_id column for sub classes.""" - return Column( - Integer, - ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), - index=True, - ) - + metadata_id = Column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + index=True, + ) start = Column(DATETIME_TYPE, index=True) mean = Column(DOUBLE_TYPE) min = Column(DOUBLE_TYPE) diff --git a/tests/components/recorder/db_schema_28.py b/tests/components/recorder/db_schema_28.py index 8d2de0432ac..422f317a6f1 100644 --- a/tests/components/recorder/db_schema_28.py +++ b/tests/components/recorder/db_schema_28.py @@ -28,7 +28,6 @@ from sqlalchemy import ( ) from sqlalchemy.dialects import mysql, oracle, postgresql from sqlalchemy.engine.row import Row -from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import declarative_base, relationship from sqlalchemy.orm.session import 
diff --git a/tests/components/recorder/db_schema_30.py b/tests/components/recorder/db_schema_30.py
index 451cc94d9fa..78715480297 100644
--- a/tests/components/recorder/db_schema_30.py
+++ b/tests/components/recorder/db_schema_30.py
@@ -30,7 +30,6 @@ from sqlalchemy import (
     type_coerce,
 )
 from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
-from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import aliased, declarative_base, relationship
 from sqlalchemy.orm.session import Session
 from typing_extensions import Self
@@ -477,16 +476,11 @@ class StatisticsBase:
 
     id = Column(Integer, Identity(), primary_key=True)
     created = Column(DATETIME_TYPE, default=dt_util.utcnow)
-
-    @declared_attr  # type: ignore[misc]
-    def metadata_id(self) -> Column:
-        """Define the metadata_id column for sub classes."""
-        return Column(
-            Integer,
-            ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
-            index=True,
-        )
-
+    metadata_id = Column(
+        Integer,
+        ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
+        index=True,
+    )
     start = Column(DATETIME_TYPE, index=True)
     mean = Column(DOUBLE_TYPE)
     min = Column(DOUBLE_TYPE)
diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py
index fc5cf8ebd9b..9fccf055855 100644
--- a/tests/components/recorder/test_init.py
+++ b/tests/components/recorder/test_init.py
@@ -1904,6 +1904,9 @@ async def test_connect_args_priority(hass, config_url):
         def on_connect_url(self, url):
             return False
 
+        def _builtin_onconnect(self):
+            ...
+
     class MockEntrypoint:
         def engine_created(*_):
             ...
diff --git a/tests/components/recorder/test_system_health.py b/tests/components/recorder/test_system_health.py
index a90f3b9f24d..0ef7adca14c 100644
--- a/tests/components/recorder/test_system_health.py
+++ b/tests/components/recorder/test_system_health.py
@@ -44,7 +44,7 @@ async def test_recorder_system_health_alternate_dbms(recorder_mock, hass, dialec
         "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
     ), patch(
         "sqlalchemy.orm.session.Session.execute",
-        return_value=Mock(first=Mock(return_value=("1048576",))),
+        return_value=Mock(scalar=Mock(return_value=("1048576"))),
    ):
         info = await get_system_health_info(hass, "recorder")
     instance = get_instance(hass)
@@ -76,7 +76,7 @@ async def test_recorder_system_health_db_url_missing_host(
             "postgresql://homeassistant:blabla@/home_assistant?host=/config/socket",
         ), patch(
             "sqlalchemy.orm.session.Session.execute",
-            return_value=Mock(first=Mock(return_value=("1048576",))),
+            return_value=Mock(scalar=Mock(return_value=("1048576"))),
        ):
             info = await get_system_health_info(hass, "recorder")
     assert info == {
diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py
index d586ccbe3b1..c0a046e1f8d 100644
--- a/tests/components/sql/test_sensor.py
+++ b/tests/components/sql/test_sensor.py
@@ -5,6 +5,7 @@ from datetime import timedelta
 from unittest.mock import AsyncMock, patch
 
 import pytest
+from sqlalchemy import text as sql_text
 from sqlalchemy.exc import SQLAlchemyError
 
 from homeassistant.components.sql.const import DOMAIN
@@ -114,7 +115,7 @@ async def test_query_mssql_no_result(
     }
     with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
         "homeassistant.components.sql.sensor.sqlalchemy.text",
-        return_value="SELECT TOP 1 5 as value where 1=2",
+        return_value=sql_text("SELECT TOP 1 5 as value where 1=2"),
    ):
         await init_integration(hass, config)
diff --git a/tests/conftest.py b/tests/conftest.py
index 410cc4f1303..7e732468f6d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -998,6 +998,23 @@ def recorder_config():
 def recorder_db_url(pytestconfig):
     """Prepare a default database for tests and return a connection URL."""
     db_url: str = pytestconfig.getoption("dburl")
+    if db_url.startswith(("postgresql://", "mysql://")):
+        import sqlalchemy_utils
+
+        def _ha_orm_quote(mixed, ident):
+            """Conditionally quote an identifier.
+
+            Modified to include https://github.com/kvesteri/sqlalchemy-utils/pull/677
+            """
+            if isinstance(mixed, sqlalchemy_utils.functions.orm.Dialect):
+                dialect = mixed
+            elif hasattr(mixed, "dialect"):
+                dialect = mixed.dialect
+            else:
+                dialect = sqlalchemy_utils.functions.orm.get_bind(mixed).dialect
+            return dialect.preparer(dialect).quote(ident)
+
+        sqlalchemy_utils.functions.database.quote = _ha_orm_quote
     if db_url.startswith("mysql://"):
         import sqlalchemy_utils
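The system-health test mocks have to mirror the production move from `.first()[0]` to `.scalar()`: whatever `Session.execute` is patched to return must now expose a `scalar()` method. A minimal illustration of the two mock shapes:

```python
from unittest.mock import Mock

# New shape: the patched execute() result exposes scalar(), matching the
# production db_size_bytes() helpers after this change.
execute_result = Mock(scalar=Mock(return_value="1048576"))
assert execute_result.scalar() == "1048576"

# Old shape, kept for contrast: first() returned a one-element row tuple.
legacy_result = Mock(first=Mock(return_value=("1048576",)))
assert legacy_result.first()[0] == "1048576"
```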