diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 2e6e5a7bd12..32119b85597 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -25,7 +25,7 @@ from .models import ( StatisticsShortTerm, process_timestamp, ) -from .statistics import get_start_time +from .statistics import delete_duplicates, get_start_time from .util import session_scope _LOGGER = logging.getLogger(__name__) @@ -587,6 +587,22 @@ def _apply_update(instance, session, new_version, old_version): # noqa: C901 elif new_version == 23: # Add name column to StatisticsMeta _add_columns(session, "statistics_meta", ["name VARCHAR(255)"]) + elif new_version == 24: + # Delete duplicated statistics + delete_duplicates(instance, session) + # Recreate statistics indices to block duplicated statistics + _drop_index(connection, "statistics", "ix_statistics_statistic_id_start") + _create_index(connection, "statistics", "ix_statistics_statistic_id_start") + _drop_index( + connection, + "statistics_short_term", + "ix_statistics_short_term_statistic_id_start", + ) + _create_index( + connection, + "statistics_short_term", + "ix_statistics_short_term_statistic_id_start", + ) else: raise ValueError(f"No schema migration defined for version {new_version}") diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 6998c8e5f53..55d6f73108c 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -40,7 +40,7 @@ import homeassistant.util.dt as dt_util # pylint: disable=invalid-name Base = declarative_base() -SCHEMA_VERSION = 23 +SCHEMA_VERSION = 24 _LOGGER = logging.getLogger(__name__) @@ -289,7 +289,7 @@ class Statistics(Base, StatisticsBase): # type: ignore __table_args__ = ( # Used for fetching statistics for a certain entity at a specific time - Index("ix_statistics_statistic_id_start", "metadata_id", "start"), + Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True), ) __tablename__ = TABLE_STATISTICS @@ -301,7 +301,12 @@ class StatisticsShortTerm(Base, StatisticsBase): # type: ignore __table_args__ = ( # Used for fetching statistics for a certain entity at a specific time - Index("ix_statistics_short_term_statistic_id_start", "metadata_id", "start"), + Index( + "ix_statistics_short_term_statistic_id_start", + "metadata_id", + "start", + unique=True, + ), ) __tablename__ = TABLE_STATISTICS_SHORT_TERM diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 02c00722e72..5310c8ed9f3 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -3,19 +3,21 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Callable, Iterable +import contextlib import dataclasses from datetime import datetime, timedelta from itertools import chain, groupby +import json import logging import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import bindparam, func -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import SQLAlchemyError, StatementError from sqlalchemy.ext import baked from sqlalchemy.orm.scoping import scoped_session -from sqlalchemy.sql.expression import true +from sqlalchemy.sql.expression import literal_column, true from homeassistant.const import ( PRESSURE_PA, @@ -26,13 +28,14 @@ 
from homeassistant.const import ( from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry +from homeassistant.helpers.json import JSONEncoder import homeassistant.util.dt as dt_util import homeassistant.util.pressure as pressure_util import homeassistant.util.temperature as temperature_util from homeassistant.util.unit_system import UnitSystem import homeassistant.util.volume as volume_util -from .const import DATA_INSTANCE, DOMAIN +from .const import DATA_INSTANCE, DOMAIN, MAX_ROWS_TO_PURGE from .models import ( StatisticData, StatisticMetaData, @@ -114,6 +117,8 @@ QUERY_STATISTIC_META_ID = [ StatisticsMeta.statistic_id, ] +MAX_DUPLICATES = 1000000 + STATISTICS_BAKERY = "recorder_statistics_bakery" STATISTICS_META_BAKERY = "recorder_statistics_meta_bakery" STATISTICS_SHORT_TERM_BAKERY = "recorder_statistics_short_term_bakery" @@ -262,6 +267,139 @@ def _update_or_add_metadata( return metadata_id +def _find_duplicates( + session: scoped_session, table: type[Statistics | StatisticsShortTerm] +) -> tuple[list[int], list[dict]]: + """Find duplicated statistics.""" + subquery = ( + session.query( + table.start, + table.metadata_id, + literal_column("1").label("is_duplicate"), + ) + .group_by(table.metadata_id, table.start) + .having(func.count() > 1) + .subquery() + ) + query = ( + session.query(table) + .outerjoin( + subquery, + (subquery.c.metadata_id == table.metadata_id) + & (subquery.c.start == table.start), + ) + .filter(subquery.c.is_duplicate == 1) + .order_by(table.metadata_id, table.start, table.id.desc()) + .limit(MAX_ROWS_TO_PURGE) + ) + duplicates = execute(query) + original_as_dict = {} + start = None + metadata_id = None + duplicate_ids: list[int] = [] + non_identical_duplicates_as_dict: list[dict] = [] + + if not duplicates: + return (duplicate_ids, non_identical_duplicates_as_dict) + + def columns_to_dict(duplicate: type[Statistics | StatisticsShortTerm]) -> dict: + """Convert a SQLAlchemy row to dict.""" + dict_ = {} + for key in duplicate.__mapper__.c.keys(): + dict_[key] = getattr(duplicate, key) + return dict_ + + def compare_statistic_rows(row1: dict, row2: dict) -> bool: + """Compare two statistics rows, ignoring id and created.""" + ignore_keys = ["id", "created"] + keys1 = set(row1).difference(ignore_keys) + keys2 = set(row2).difference(ignore_keys) + return keys1 == keys2 and all(row1[k] == row2[k] for k in keys1) + + for duplicate in duplicates: + if start != duplicate.start or metadata_id != duplicate.metadata_id: + original_as_dict = columns_to_dict(duplicate) + start = duplicate.start + metadata_id = duplicate.metadata_id + continue + duplicate_as_dict = columns_to_dict(duplicate) + duplicate_ids.append(duplicate.id) + if not compare_statistic_rows(original_as_dict, duplicate_as_dict): + non_identical_duplicates_as_dict.append(duplicate_as_dict) + + return (duplicate_ids, non_identical_duplicates_as_dict) + + +def _delete_duplicates_from_table( + session: scoped_session, table: type[Statistics | StatisticsShortTerm] +) -> tuple[int, list[dict]]: + """Identify and delete duplicated statistics from a specified table.""" + all_non_identical_duplicates: list[dict] = [] + total_deleted_rows = 0 + while True: + duplicate_ids, non_identical_duplicates = _find_duplicates(session, table) + if not duplicate_ids: + break + all_non_identical_duplicates.extend(non_identical_duplicates) + deleted_rows = ( + session.query(table) + .filter(table.id.in_(duplicate_ids)) + 
.delete(synchronize_session=False) + ) + total_deleted_rows += deleted_rows + if total_deleted_rows >= MAX_DUPLICATES: + break + return (total_deleted_rows, all_non_identical_duplicates) + + +def delete_duplicates(instance: Recorder, session: scoped_session) -> None: + """Identify and delete duplicated statistics. + + A backup will be made of duplicated statistics before it is deleted. + """ + deleted_statistics_rows, non_identical_duplicates = _delete_duplicates_from_table( + session, Statistics + ) + if deleted_statistics_rows: + _LOGGER.info("Deleted %s duplicated statistics rows", deleted_statistics_rows) + + if non_identical_duplicates: + isotime = dt_util.utcnow().isoformat() + backup_file_name = f"deleted_statistics.{isotime}.json" + backup_path = instance.hass.config.path(backup_file_name) + with open(backup_path, "w", encoding="utf8") as backup_file: + json.dump( + non_identical_duplicates, + backup_file, + indent=4, + sort_keys=True, + cls=JSONEncoder, + ) + _LOGGER.warning( + "Deleted %s non identical duplicated %s rows, a backup of the deleted rows " + "has been saved to %s", + len(non_identical_duplicates), + Statistics.__tablename__, + backup_path, + ) + + if deleted_statistics_rows >= MAX_DUPLICATES: + _LOGGER.warning( + "Found more than %s duplicated statistic rows, please report at " + 'https://github.com/home-assistant/core/issues?q=is%%3Aissue+label%%3A"integration%%3A+recorder"+', + MAX_DUPLICATES - 1, + ) + + deleted_short_term_statistics_rows, _ = _delete_duplicates_from_table( + session, StatisticsShortTerm + ) + if deleted_short_term_statistics_rows: + _LOGGER.warning( + "Deleted duplicated short term statistic rows, please report at " + 'https://github.com/home-assistant/core/issues?q=is%%3Aissue+label%%3A"integration%%3A+recorder"+' + ) + + def compile_hourly_statistics( instance: Recorder, session: scoped_session, start: datetime ) -> None: @@ -411,7 +549,10 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool: platform_stats.extend(platform_stat) # Insert collected statistics in the database - with session_scope(session=instance.get_session()) as session: # type: ignore + with session_scope( + session=instance.get_session(), # type: ignore + exception_filter=_filter_unique_constraint_integrity_error(instance), + ) as session: for stats in platform_stats: metadata_id = _update_or_add_metadata(instance.hass, session, stats["meta"]) _insert_statistics( @@ -1066,6 +1207,43 @@ def async_add_external_statistics( hass.data[DATA_INSTANCE].async_external_statistics(metadata, statistics) +def _filter_unique_constraint_integrity_error( + instance: Recorder, +) -> Callable[[Exception], bool]: + def _filter_unique_constraint_integrity_error(err: Exception) -> bool: + """Handle unique constraint integrity errors.""" + if not isinstance(err, StatementError): + return False + + ignore = False + if ( + instance.engine.dialect.name == "sqlite" + and "UNIQUE constraint failed" in str(err) + ): + ignore = True + if ( + instance.engine.dialect.name == "postgresql" + and hasattr(err.orig, "pgcode") + and err.orig.pgcode == "23505" + ): + ignore = True + if instance.engine.dialect.name == "mysql" and hasattr(err.orig, "args"): + with contextlib.suppress(TypeError): + if err.orig.args[0] == 1062: + ignore = True + + if ignore: + _LOGGER.warning( + "Blocked attempt to insert duplicated statistic rows, please report at " + 'https://github.com/home-assistant/core/issues?q=is%%3Aissue+label%%3A"integration%%3A+recorder"+', + exc_info=err, + ) + + return ignore + + return 
_filter_unique_constraint_integrity_error + + @retryable_database_job("statistics") def add_external_statistics( instance: Recorder, @@ -1073,7 +1251,11 @@ def add_external_statistics( statistics: Iterable[StatisticData], ) -> bool: """Process an add_statistics job.""" - with session_scope(session=instance.get_session()) as session: # type: ignore + + with session_scope( + session=instance.get_session(), # type: ignore + exception_filter=_filter_unique_constraint_integrity_error(instance), + ) as session: metadata_id = _update_or_add_metadata(instance.hass, session, metadata) for stat in statistics: if stat_id := _statistics_exists( diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 3900641db63..734694b8224 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -66,7 +66,10 @@ RETRYABLE_MYSQL_ERRORS = (1205, 1206, 1213) @contextmanager def session_scope( - *, hass: HomeAssistant | None = None, session: Session | None = None + *, + hass: HomeAssistant | None = None, + session: Session | None = None, + exception_filter: Callable[[Exception], bool] | None = None, ) -> Generator[Session, None, None]: """Provide a transactional scope around a series of operations.""" if session is None and hass is not None: @@ -81,11 +84,12 @@ def session_scope( if session.get_transaction(): need_rollback = True session.commit() - except Exception as err: + except Exception as err: # pylint: disable=broad-except _LOGGER.error("Error executing query: %s", err) if need_rollback: session.rollback() - raise + if not exception_filter or not exception_filter(err): + raise finally: session.close() diff --git a/tests/components/recorder/models_schema_23.py b/tests/components/recorder/models_schema_23.py new file mode 100644 index 00000000000..50839f41906 --- /dev/null +++ b/tests/components/recorder/models_schema_23.py @@ -0,0 +1,582 @@ +"""Models for SQLAlchemy. + +This file contains the model definitions for schema version 23, +used by Home Assistant Core 2021.11.0, which adds the name column +to statistics_meta. +It is used to test the schema migration logic. 
+""" +from __future__ import annotations + +from datetime import datetime, timedelta +import json +import logging +from typing import TypedDict, overload + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + String, + Text, + distinct, +) +from sqlalchemy.dialects import mysql, oracle, postgresql +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm.session import Session + +from homeassistant.const import ( + MAX_LENGTH_EVENT_CONTEXT_ID, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_EVENT_ORIGIN, + MAX_LENGTH_STATE_DOMAIN, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers.json import JSONEncoder +import homeassistant.util.dt as dt_util + +# SQLAlchemy Schema +# pylint: disable=invalid-name +Base = declarative_base() + +SCHEMA_VERSION = 23 + +_LOGGER = logging.getLogger(__name__) + +DB_TIMEZONE = "+00:00" + +TABLE_EVENTS = "events" +TABLE_STATES = "states" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" +TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" + +ALL_TABLES = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +DATETIME_TYPE = DateTime(timezone=True).with_variant( + mysql.DATETIME(timezone=True, fsp=6), "mysql" +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql") + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) + + +class Events(Base): # type: ignore + """Event history data.""" + + __table_args__ = ( + # Used for fetching events at a specific time + # see logbook + Index("ix_events_event_type_time_fired", "event_type", "time_fired"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENTS + event_id = Column(Integer, Identity(), primary_key=True) + event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE)) + event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) + time_fired = Column(DATETIME_TYPE, index=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event, event_data=None): + """Create an event database object from a native event.""" + return Events( + event_type=event.event_type, + event_data=event_data + or json.dumps(event.data, cls=JSONEncoder, separators=(",", ":")), + origin=str(event.origin.value), + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + ) + + def to_native(self, validate_entity_id=True): + """Convert to a native HA Event.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + 
parent_id=self.context_parent_id, + ) + try: + return Event( + self.event_type, + json.loads(self.event_data), + EventOrigin(self.origin), + process_timestamp(self.time_fired), + context=context, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting to event: %s", self) + return None + + +class States(Base): # type: ignore + """State change history.""" + + __table_args__ = ( + # Used for fetching the state of entities at a specific time + # (get_states in history.py) + Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATES + state_id = Column(Integer, Identity(), primary_key=True) + domain = Column(String(MAX_LENGTH_STATE_DOMAIN)) + entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID)) + state = Column(String(MAX_LENGTH_STATE_STATE)) + attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + event_id = Column( + Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True + ) + last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow) + last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True) + event = relationship("Events", uselist=False) + old_state = relationship("States", remote_side=[state_id]) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event): + """Create object from a state_changed event.""" + entity_id = event.data["entity_id"] + state = event.data.get("new_state") + + dbstate = States(entity_id=entity_id) + + # State got deleted + if state is None: + dbstate.state = "" + dbstate.domain = split_entity_id(entity_id)[0] + dbstate.attributes = "{}" + dbstate.last_changed = event.time_fired + dbstate.last_updated = event.time_fired + else: + dbstate.domain = state.domain + dbstate.state = state.state + dbstate.attributes = json.dumps( + dict(state.attributes), cls=JSONEncoder, separators=(",", ":") + ) + dbstate.last_changed = state.last_changed + dbstate.last_updated = state.last_updated + + return dbstate + + def to_native(self, validate_entity_id=True): + """Convert to an HA state object.""" + try: + return State( + self.entity_id, + self.state, + json.loads(self.attributes), + process_timestamp(self.last_changed), + process_timestamp(self.last_updated), + # Join the events table on event_id to get the context instead + # as it will always be there for state_changed events + context=Context(id=None), + validate_entity_id=validate_entity_id, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state: %s", self) + return None + + +class StatisticResult(TypedDict): + """Statistic result data class. + + Allows multiple datapoints for the same statistic_id. 
+ """ + + meta: StatisticMetaData + stat: StatisticData + + +class StatisticDataBase(TypedDict): + """Mandatory fields for statistic data class.""" + + start: datetime + + +class StatisticData(StatisticDataBase, total=False): + """Statistic data class.""" + + mean: float + min: float + max: float + last_reset: datetime | None + state: float + sum: float + + +class StatisticsBase: + """Statistics base class.""" + + id = Column(Integer, Identity(), primary_key=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + + @declared_attr + def metadata_id(self): + """Define the metadata_id column for sub classes.""" + return Column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + index=True, + ) + + start = Column(DATETIME_TYPE, index=True) + mean = Column(DOUBLE_TYPE) + min = Column(DOUBLE_TYPE) + max = Column(DOUBLE_TYPE) + last_reset = Column(DATETIME_TYPE) + state = Column(DOUBLE_TYPE) + sum = Column(DOUBLE_TYPE) + + @classmethod + def from_stats(cls, metadata_id: int, stats: StatisticData): + """Create object from a statistics.""" + return cls( # type: ignore + metadata_id=metadata_id, + **stats, + ) + + +class Statistics(Base, StatisticsBase): # type: ignore + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index("ix_statistics_statistic_id_start", "metadata_id", "start"), + ) + __tablename__ = TABLE_STATISTICS + + +class StatisticsShortTerm(Base, StatisticsBase): # type: ignore + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index("ix_statistics_short_term_statistic_id_start", "metadata_id", "start"), + ) + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticMetaData(TypedDict): + """Statistic meta data class.""" + + has_mean: bool + has_sum: bool + name: str | None + source: str + statistic_id: str + unit_of_measurement: str | None + + +class StatisticsMeta(Base): # type: ignore + """Statistics meta data.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATISTICS_META + id = Column(Integer, Identity(), primary_key=True) + statistic_id = Column(String(255), index=True) + source = Column(String(32)) + unit_of_measurement = Column(String(255)) + has_mean = Column(Boolean) + has_sum = Column(Boolean) + name = Column(String(255)) + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class RecorderRuns(Base): # type: ignore + """Representation of recorder run.""" + + __table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),) + __tablename__ = TABLE_RECORDER_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), default=dt_util.utcnow) + end = Column(DateTime(timezone=True)) + closed_incorrect = Column(Boolean, default=False) + created = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + def entity_ids(self, point_in_time=None): + """Return the entity ids that existed in this run. 
+ + Specify point_in_time if you want to know which existed at that point + in time inside the run. + """ + session = Session.object_session(self) + + assert session is not None, "RecorderRuns need to be persisted" + + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start + ) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] + + def to_native(self, validate_entity_id=True): + """Return self, native format is this model.""" + return self + + +class SchemaChanges(Base): # type: ignore + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + change_id = Column(Integer, Identity(), primary_key=True) + schema_version = Column(Integer) + changed = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +class StatisticsRuns(Base): # type: ignore + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True)) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +@overload +def process_timestamp(ts: None) -> None: + ... + + +@overload +def process_timestamp(ts: datetime) -> datetime: + ... + + +def process_timestamp(ts: datetime | None) -> datetime | None: + """Process a timestamp into datetime object.""" + if ts is None: + return None + if ts.tzinfo is None: + return ts.replace(tzinfo=dt_util.UTC) + + return dt_util.as_utc(ts) + + +@overload +def process_timestamp_to_utc_isoformat(ts: None) -> None: + ... + + +@overload +def process_timestamp_to_utc_isoformat(ts: datetime) -> str: + ... 
+ + +def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None: + """Process a timestamp into UTC isotime.""" + if ts is None: + return None + if ts.tzinfo == dt_util.UTC: + return ts.isoformat() + if ts.tzinfo is None: + return f"{ts.isoformat()}{DB_TIMEZONE}" + return ts.astimezone(dt_util.UTC).isoformat() + + +class LazyState(State): + """A lazy version of core State.""" + + __slots__ = [ + "_row", + "entity_id", + "state", + "_attributes", + "_last_changed", + "_last_updated", + "_context", + ] + + def __init__(self, row): # pylint: disable=super-init-not-called + """Init the lazy state.""" + self._row = row + self.entity_id = self._row.entity_id + self.state = self._row.state or "" + self._attributes = None + self._last_changed = None + self._last_updated = None + self._context = None + + @property # type: ignore + def attributes(self): + """State attributes.""" + if not self._attributes: + try: + self._attributes = json.loads(self._row.attributes) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state: %s", self._row) + self._attributes = {} + return self._attributes + + @attributes.setter + def attributes(self, value): + """Set attributes.""" + self._attributes = value + + @property # type: ignore + def context(self): + """State context.""" + if not self._context: + self._context = Context(id=None) + return self._context + + @context.setter + def context(self, value): + """Set context.""" + self._context = value + + @property # type: ignore + def last_changed(self): + """Last changed datetime.""" + if not self._last_changed: + self._last_changed = process_timestamp(self._row.last_changed) + return self._last_changed + + @last_changed.setter + def last_changed(self, value): + """Set last changed datetime.""" + self._last_changed = value + + @property # type: ignore + def last_updated(self): + """Last updated datetime.""" + if not self._last_updated: + self._last_updated = process_timestamp(self._row.last_updated) + return self._last_updated + + @last_updated.setter + def last_updated(self, value): + """Set last updated datetime.""" + self._last_updated = value + + def as_dict(self): + """Return a dict representation of the LazyState. + + Async friendly. + + To be used for JSON serialization. 
+ """ + if self._last_changed: + last_changed_isoformat = self._last_changed.isoformat() + else: + last_changed_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_changed + ) + if self._last_updated: + last_updated_isoformat = self._last_updated.isoformat() + else: + last_updated_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_updated + ) + return { + "entity_id": self.entity_id, + "state": self.state, + "attributes": self._attributes or self.attributes, + "last_changed": last_changed_isoformat, + "last_updated": last_updated_isoformat, + } + + def __eq__(self, other): + """Return the comparison.""" + return ( + other.__class__ in [self.__class__, State] + and self.entity_id == other.entity_id + and self.state == other.state + and self.attributes == other.attributes + ) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index c4dd33ce840..0f4468019e9 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -1,12 +1,18 @@ """The tests for sensor recorder platform.""" # pylint: disable=protected-access,invalid-name from datetime import timedelta +import importlib +import json +import sys from unittest.mock import patch, sentinel import pytest from pytest import approx +from sqlalchemy import create_engine +from sqlalchemy.orm import Session -from homeassistant.components.recorder import history +from homeassistant.components import recorder +from homeassistant.components.recorder import SQLITE_URL_PREFIX, history, statistics from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.models import ( StatisticsShortTerm, @@ -14,18 +20,20 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.components.recorder.statistics import ( async_add_external_statistics, + delete_duplicates, get_last_short_term_statistics, get_last_statistics, get_metadata, list_statistic_ids, statistics_during_period, ) +from homeassistant.components.recorder.util import session_scope from homeassistant.const import TEMP_CELSIUS from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import setup_component import homeassistant.util.dt as dt_util -from tests.common import mock_registry +from tests.common import get_test_home_assistant, mock_registry from tests.components.recorder.common import wait_recording_done @@ -650,6 +658,624 @@ def test_monthly_statistics(hass_recorder, caplog, timezone): dt_util.set_default_time_zone(dt_util.get_time_zone("UTC")) +def _create_engine_test(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. 
+ """ + module = "tests.components.recorder.models_schema_23" + importlib.import_module(module) + old_models = sys.modules[module] + engine = create_engine(*args, **kwargs) + old_models.Base.metadata.create_all(engine) + with Session(engine) as session: + session.add(recorder.models.StatisticsRuns(start=statistics.get_start_time())) + session.add( + recorder.models.SchemaChanges(schema_version=old_models.SCHEMA_VERSION) + ) + session.commit() + return engine + + +def test_delete_duplicates(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.models_schema_23" + importlib.import_module(module) + old_models = sys.modules[module] + + period1 = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2021-09-30 23:00:00")) + period3 = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00")) + period4 = dt_util.as_utc(dt_util.parse_datetime("2021-10-31 23:00:00")) + + external_energy_statistics_1 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 2, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 3, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 4, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + }, + ) + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_statistics_2 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 20, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 30, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 40, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + ) + external_energy_metadata_2 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_2", + "unit_of_measurement": "kWh", + } + external_co2_statistics = ( + { + "start": period1, + "last_reset": None, + "mean": 10, + }, + { + "start": period2, + "last_reset": None, + "mean": 30, + }, + { + "start": period3, + "last_reset": None, + "mean": 60, + }, + { + "start": period4, + "last_reset": None, + "mean": 90, + }, + ) + external_co2_metadata = { + "has_mean": True, + "has_sum": False, + "name": "Fossil percentage", + "source": "test", + "statistic_id": "test:fossil_percentage", + "unit_of_measurement": "%", + } + + # Create some duplicated statistics with schema version 23 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.create_engine", new=_create_engine_test + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) 
+ ) + session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata)) + with session_scope(hass=hass) as session: + for stat in external_energy_statistics_1: + session.add(recorder.models.Statistics.from_stats(1, stat)) + for stat in external_energy_statistics_2: + session.add(recorder.models.Statistics.from_stats(2, stat)) + for stat in external_co2_statistics: + session.add(recorder.models.Statistics.from_stats(3, stat)) + + hass.stop() + + # Test that the duplicates are removed during migration from schema 23 + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + hass.stop() + + assert "Deleted 2 duplicated statistics rows" in caplog.text + assert "Found non identical" not in caplog.text + assert "Found more than" not in caplog.text + assert "Found duplicated" not in caplog.text + + +@pytest.mark.freeze_time("2021-08-01 00:00:00+00:00") +def test_delete_duplicates_non_identical(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.models_schema_23" + importlib.import_module(module) + old_models = sys.modules[module] + + period1 = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2021-09-30 23:00:00")) + period3 = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00")) + period4 = dt_util.as_utc(dt_util.parse_datetime("2021-10-31 23:00:00")) + + external_energy_statistics_1 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 2, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 3, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 4, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 6, + }, + ) + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_statistics_2 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 20, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 30, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 40, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + ) + external_energy_metadata_2 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_2", + "unit_of_measurement": "kWh", + } + + # Create some duplicated statistics with schema version 23 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.create_engine", new=_create_engine_test + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + 
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) + ) + with session_scope(hass=hass) as session: + for stat in external_energy_statistics_1: + session.add(recorder.models.Statistics.from_stats(1, stat)) + for stat in external_energy_statistics_2: + session.add(recorder.models.Statistics.from_stats(2, stat)) + + hass.stop() + + # Test that the duplicates are removed during migration from schema 23 + hass = get_test_home_assistant() + hass.config.config_dir = tmpdir + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + hass.stop() + + assert "Deleted 2 duplicated statistics rows" in caplog.text + assert "Deleted 1 non identical" in caplog.text + assert "Found more than" not in caplog.text + assert "Found duplicated" not in caplog.text + + isotime = dt_util.utcnow().isoformat() + backup_file_name = f"deleted_statistics.{isotime}.json" + + with open(hass.config.path(backup_file_name)) as backup_file: + backup = json.load(backup_file) + + assert backup == [ + { + "created": "2021-08-01T00:00:00", + "id": 4, + "last_reset": None, + "max": None, + "mean": None, + "metadata_id": 1, + "min": None, + "start": "2021-10-31T23:00:00", + "state": 3.0, + "sum": 5.0, + } + ] + + +@patch.object(statistics, "MAX_DUPLICATES", 2) +def test_delete_duplicates_too_many(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.models_schema_23" + importlib.import_module(module) + old_models = sys.modules[module] + + period1 = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2021-09-30 23:00:00")) + period3 = dt_util.as_utc(dt_util.parse_datetime("2021-10-01 00:00:00")) + period4 = dt_util.as_utc(dt_util.parse_datetime("2021-10-31 23:00:00")) + + external_energy_statistics_1 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 2, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 3, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 4, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + }, + ) + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_statistics_2 = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 20, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 30, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 40, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 50, + }, + ) + external_energy_metadata_2 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_2", + "unit_of_measurement": "kWh", + } + + # Create some duplicated statistics with schema version 23 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.create_engine", 
new=_create_engine_test + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) + ) + with session_scope(hass=hass) as session: + for stat in external_energy_statistics_1: + session.add(recorder.models.Statistics.from_stats(1, stat)) + for stat in external_energy_statistics_2: + session.add(recorder.models.Statistics.from_stats(2, stat)) + + hass.stop() + + # Test that the duplicates are removed during migration from schema 23 + hass = get_test_home_assistant() + hass.config.config_dir = tmpdir + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + hass.stop() + + assert "Deleted 2 duplicated statistics rows" in caplog.text + assert "Found non identical" not in caplog.text + assert "Found more than 1 duplicated statistic rows" in caplog.text + assert "Found duplicated" not in caplog.text + + +@patch.object(statistics, "MAX_DUPLICATES", 2) +def test_delete_duplicates_short_term(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.models_schema_23" + importlib.import_module(module) + old_models = sys.modules[module] + + period4 = dt_util.as_utc(dt_util.parse_datetime("2021-10-31 23:00:00")) + + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + statistic_row = { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 5, + } + + # Create some duplicated statistics with schema version 23 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.create_engine", new=_create_engine_test + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsShortTerm.from_stats(1, statistic_row) + ) + session.add( + recorder.models.StatisticsShortTerm.from_stats(1, statistic_row) + ) + + hass.stop() + + # Test that the duplicates are removed during migration from schema 23 + hass = get_test_home_assistant() + hass.config.config_dir = tmpdir + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + hass.stop() + + assert "duplicated statistics rows" not in caplog.text + assert "Found non identical" not in caplog.text + assert "Found more than" not in caplog.text + assert "Deleted duplicated short term statistic" in caplog.text + + +def test_delete_duplicates_no_duplicates(hass_recorder, caplog): + """Test removal of duplicated statistics.""" + hass = hass_recorder() + wait_recording_done(hass) + with session_scope(hass=hass) as session: + 
delete_duplicates(hass.data[DATA_INSTANCE], session) + assert "duplicated statistics rows" not in caplog.text + assert "Found non identical" not in caplog.text + assert "Found more than" not in caplog.text + assert "Found duplicated" not in caplog.text + + +def test_duplicate_statistics_handle_integrity_error(hass_recorder, caplog): + """Test the recorder does not blow up if statistics is duplicated.""" + hass = hass_recorder() + wait_recording_done(hass) + + period1 = dt_util.as_utc(dt_util.parse_datetime("2021-09-01 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2021-09-30 23:00:00")) + + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_statistics_1 = [ + { + "start": period1, + "last_reset": None, + "state": 3, + "sum": 5, + }, + ] + external_energy_statistics_2 = [ + { + "start": period2, + "last_reset": None, + "state": 3, + "sum": 6, + } + ] + + with patch.object( + statistics, "_statistics_exists", return_value=False + ), patch.object( + statistics, "_insert_statistics", wraps=statistics._insert_statistics + ) as insert_statistics_mock: + async_add_external_statistics( + hass, external_energy_metadata_1, external_energy_statistics_1 + ) + async_add_external_statistics( + hass, external_energy_metadata_1, external_energy_statistics_1 + ) + async_add_external_statistics( + hass, external_energy_metadata_1, external_energy_statistics_2 + ) + wait_recording_done(hass) + assert insert_statistics_mock.call_count == 3 + + with session_scope(hass=hass) as session: + tmp = session.query(recorder.models.Statistics).all() + assert len(tmp) == 2 + + assert "Blocked attempt to insert duplicated statistic rows" in caplog.text + + def record_states(hass): """Record some test states.