diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 2ba07e42f49..7df4cf57e56 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -35,6 +35,7 @@ from homeassistant.helpers.event import (
     async_track_time_interval,
     async_track_utc_time_change,
 )
+from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 import homeassistant.util.dt as dt_util
 
 from . import migration, statistics
@@ -461,10 +462,18 @@ class Recorder(threading.Thread):
 
     @callback
     def async_update_statistics_metadata(
-        self, statistic_id: str, unit_of_measurement: str | None
+        self,
+        statistic_id: str,
+        *,
+        new_statistic_id: str | UndefinedType = UNDEFINED,
+        new_unit_of_measurement: str | None | UndefinedType = UNDEFINED,
     ) -> None:
         """Update statistics metadata for a statistic_id."""
-        self.queue_task(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))
+        self.queue_task(
+            UpdateStatisticsMetadataTask(
+                statistic_id, new_statistic_id, new_unit_of_measurement
+            )
+        )
 
     @callback
     def async_external_statistics(
diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index 5da31f18781..eadfc543b59 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -33,7 +33,11 @@ from .models import (
     StatisticsShortTerm,
     process_timestamp,
 )
-from .statistics import delete_duplicates, get_start_time
+from .statistics import (
+    delete_statistics_duplicates,
+    delete_statistics_meta_duplicates,
+    get_start_time,
+)
 from .util import session_scope
 
 _LOGGER = logging.getLogger(__name__)
@@ -670,7 +674,7 @@ def _apply_update(  # noqa: C901
             # There may be duplicated statistics entries, delete duplicated statistics
             # and try again
             with session_scope(session=session_maker()) as session:
-                delete_duplicates(hass, session)
+                delete_statistics_duplicates(hass, session)
             _create_index(
                 session_maker, "statistics", "ix_statistics_statistic_id_start"
             )
@@ -705,6 +709,21 @@
         _create_index(session_maker, "states", "ix_states_context_id")
         # Once there are no longer any state_changed events
        # in the events table we can drop the index on states.event_id
+    elif new_version == 29:
+        # Recreate statistics_meta index to block duplicated statistic_id
+        _drop_index(session_maker, "statistics_meta", "ix_statistics_meta_statistic_id")
+        try:
+            _create_index(
+                session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
+            )
+        except DatabaseError:
+            # There may be duplicated statistics_meta entries, delete duplicates
+            # and try again
+            with session_scope(session=session_maker()) as session:
+                delete_statistics_meta_duplicates(session)
+            _create_index(
+                session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
+            )
     else:
         raise ValueError(f"No schema migration defined for version {new_version}")
 
diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py
index 4dabd7899e0..90c2e5e5616 100644
--- a/homeassistant/components/recorder/models.py
+++ b/homeassistant/components/recorder/models.py
@@ -51,7 +51,7 @@ from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
 # pylint: disable=invalid-name
 Base = declarative_base()
 
-SCHEMA_VERSION = 28
+SCHEMA_VERSION = 29
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -515,7 +515,7 @@ class StatisticsMeta(Base):  # type: ignore[misc,valid-type]
     )
     __tablename__ = TABLE_STATISTICS_META
     id = Column(Integer, Identity(), primary_key=True)
-    statistic_id = Column(String(255), index=True)
+    statistic_id = Column(String(255), index=True, unique=True)
     source = Column(String(32))
     unit_of_measurement = Column(String(255))
     has_mean = Column(Boolean)
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 2c6c51a31f4..4bed39fee4a 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -33,6 +33,7 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import entity_registry
 from homeassistant.helpers.json import JSONEncoder
 from homeassistant.helpers.storage import STORAGE_DIR
+from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 import homeassistant.util.dt as dt_util
 import homeassistant.util.pressure as pressure_util
 import homeassistant.util.temperature as temperature_util
@@ -208,18 +209,11 @@ class ValidationIssue:
 def async_setup(hass: HomeAssistant) -> None:
     """Set up the history hooks."""
 
-    def _entity_id_changed(event: Event) -> None:
-        """Handle entity_id changed."""
-        old_entity_id = event.data["old_entity_id"]
-        entity_id = event.data["entity_id"]
-        with session_scope(hass=hass) as session:
-            session.query(StatisticsMeta).filter(
-                (StatisticsMeta.statistic_id == old_entity_id)
-                & (StatisticsMeta.source == DOMAIN)
-            ).update({StatisticsMeta.statistic_id: entity_id})
-
-    async def _async_entity_id_changed(event: Event) -> None:
-        await hass.data[DATA_INSTANCE].async_add_executor_job(_entity_id_changed, event)
+    @callback
+    def _async_entity_id_changed(event: Event) -> None:
+        hass.data[DATA_INSTANCE].async_update_statistics_metadata(
+            event.data["old_entity_id"], new_statistic_id=event.data["entity_id"]
+        )
 
     @callback
     def entity_registry_changed_filter(event: Event) -> bool:
@@ -380,7 +374,7 @@ def _delete_duplicates_from_table(
     return (total_deleted_rows, all_non_identical_duplicates)
 
 
-def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
+def delete_statistics_duplicates(hass: HomeAssistant, session: Session) -> None:
     """Identify and delete duplicated statistics.
 
     A backup will be made of duplicated statistics before it is deleted.
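Taken together with the core.py hunk at the top of this diff, the statistics.py change above replaces the executor-based rename handler with a callback that forwards to Recorder.async_update_statistics_metadata, which now takes keyword-only new_statistic_id / new_unit_of_measurement arguments defaulting to UNDEFINED so a caller can change one field without touching the other. A minimal usage sketch, not part of the patch (the statistic ids are hypothetical; the method, DATA_INSTANCE and the UNDEFINED semantics come from the hunks in this diff):

    instance = hass.data[DATA_INSTANCE]
    # Entity registry rename: move existing statistics to the new statistic_id only.
    instance.async_update_statistics_metadata(
        "sensor.old_name", new_statistic_id="sensor.new_name"
    )
    # Websocket update_statistics_metadata handler: change only the stored unit.
    instance.async_update_statistics_metadata(
        "sensor.energy", new_unit_of_measurement="kWh"
    )
    # Any field left as UNDEFINED is not written to the statistics_meta table.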
@@ -423,6 +417,69 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
     )
 
 
+def _find_statistics_meta_duplicates(session: Session) -> list[int]:
+    """Find duplicated statistics_meta."""
+    subquery = (
+        session.query(
+            StatisticsMeta.statistic_id,
+            literal_column("1").label("is_duplicate"),
+        )
+        .group_by(StatisticsMeta.statistic_id)
+        .having(func.count() > 1)
+        .subquery()
+    )
+    query = (
+        session.query(StatisticsMeta)
+        .outerjoin(
+            subquery,
+            (subquery.c.statistic_id == StatisticsMeta.statistic_id),
+        )
+        .filter(subquery.c.is_duplicate == 1)
+        .order_by(StatisticsMeta.statistic_id, StatisticsMeta.id.desc())
+        .limit(1000 * MAX_ROWS_TO_PURGE)
+    )
+    duplicates = execute(query)
+    statistic_id = None
+    duplicate_ids: list[int] = []
+
+    if not duplicates:
+        return duplicate_ids
+
+    for duplicate in duplicates:
+        if statistic_id != duplicate.statistic_id:
+            statistic_id = duplicate.statistic_id
+            continue
+        duplicate_ids.append(duplicate.id)
+
+    return duplicate_ids
+
+
+def _delete_statistics_meta_duplicates(session: Session) -> int:
+    """Identify and delete duplicated statistics_meta rows."""
+    total_deleted_rows = 0
+    while True:
+        duplicate_ids = _find_statistics_meta_duplicates(session)
+        if not duplicate_ids:
+            break
+        for i in range(0, len(duplicate_ids), MAX_ROWS_TO_PURGE):
+            deleted_rows = (
+                session.query(StatisticsMeta)
+                .filter(StatisticsMeta.id.in_(duplicate_ids[i : i + MAX_ROWS_TO_PURGE]))
+                .delete(synchronize_session=False)
+            )
+            total_deleted_rows += deleted_rows
+    return total_deleted_rows
+
+
+def delete_statistics_meta_duplicates(session: Session) -> None:
+    """Identify and delete duplicated statistics_meta."""
+    deleted_statistics_rows = _delete_statistics_meta_duplicates(session)
+    if deleted_statistics_rows:
+        _LOGGER.info(
+            "Deleted %s duplicated statistics_meta rows", deleted_statistics_rows
+        )
+
+
 def _compile_hourly_statistics_summary_mean_stmt(
     start_time: datetime, end_time: datetime
 ) -> StatementLambdaElement:
@@ -736,13 +793,26 @@ def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
 
 
 def update_statistics_metadata(
-    instance: Recorder, statistic_id: str, unit_of_measurement: str | None
+    instance: Recorder,
+    statistic_id: str,
+    new_statistic_id: str | None | UndefinedType,
+    new_unit_of_measurement: str | None | UndefinedType,
 ) -> None:
     """Update statistics metadata for a statistic_id."""
-    with session_scope(session=instance.get_session()) as session:
-        session.query(StatisticsMeta).filter(
-            StatisticsMeta.statistic_id == statistic_id
-        ).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
+    if new_unit_of_measurement is not UNDEFINED:
+        with session_scope(session=instance.get_session()) as session:
+            session.query(StatisticsMeta).filter(
+                StatisticsMeta.statistic_id == statistic_id
+            ).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement})
+    if new_statistic_id is not UNDEFINED:
+        with session_scope(
+            session=instance.get_session(),
+            exception_filter=_filter_unique_constraint_integrity_error(instance),
+        ) as session:
+            session.query(StatisticsMeta).filter(
+                (StatisticsMeta.statistic_id == statistic_id)
+                & (StatisticsMeta.source == DOMAIN)
+            ).update({StatisticsMeta.statistic_id: new_statistic_id})
 
 
 def list_statistic_ids(
diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py
index 07855d27dff..5ec83a3cefc 100644
--- a/homeassistant/components/recorder/tasks.py
+++ b/homeassistant/components/recorder/tasks.py
@@
-10,6 +10,7 @@ import threading from typing import TYPE_CHECKING, Any from homeassistant.core import Event +from homeassistant.helpers.typing import UndefinedType from . import purge, statistics from .const import DOMAIN, EXCLUDE_ATTRIBUTES @@ -46,12 +47,16 @@ class UpdateStatisticsMetadataTask(RecorderTask): """Object to store statistics_id and unit for update of statistics metadata.""" statistic_id: str - unit_of_measurement: str | None + new_statistic_id: str | None | UndefinedType + new_unit_of_measurement: str | None | UndefinedType def run(self, instance: Recorder) -> None: """Handle the task.""" statistics.update_statistics_metadata( - instance, self.statistic_id, self.unit_of_measurement + instance, + self.statistic_id, + self.new_statistic_id, + self.new_unit_of_measurement, ) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index a851d2681f4..d0499fbf9cb 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -103,7 +103,7 @@ def ws_update_statistics_metadata( ) -> None: """Update statistics metadata for a statistic_id.""" hass.data[DATA_INSTANCE].async_update_statistics_metadata( - msg["statistic_id"], msg["unit_of_measurement"] + msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"] ) connection.send_result(msg["id"]) diff --git a/tests/components/recorder/models_schema_28.py b/tests/components/recorder/models_schema_28.py new file mode 100644 index 00000000000..8d2de0432ac --- /dev/null +++ b/tests/components/recorder/models_schema_28.py @@ -0,0 +1,753 @@ +"""Models for SQLAlchemy. + +This file contains the model definitions for schema version 28. +It is used to test the schema migration logic. 
+""" +from __future__ import annotations + +from datetime import datetime, timedelta +import json +import logging +from typing import Any, TypedDict, cast, overload + +from fnvhash import fnv1a_32 +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + SmallInteger, + String, + Text, + distinct, +) +from sqlalchemy.dialects import mysql, oracle, postgresql +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm.session import Session + +from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP +from homeassistant.const import ( + MAX_LENGTH_EVENT_CONTEXT_ID, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_EVENT_ORIGIN, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +import homeassistant.util.dt as dt_util + +# SQLAlchemy Schema +# pylint: disable=invalid-name +Base = declarative_base() + +SCHEMA_VERSION = 28 + +_LOGGER = logging.getLogger(__name__) + +DB_TIMEZONE = "+00:00" + +TABLE_EVENTS = "events" +TABLE_EVENT_DATA = "event_data" +TABLE_STATES = "states" +TABLE_STATE_ATTRIBUTES = "state_attributes" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" +TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" + +ALL_TABLES = [ + TABLE_STATES, + TABLE_STATE_ATTRIBUTES, + TABLE_EVENTS, + TABLE_EVENT_DATA, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +TABLES_TO_CHECK = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, +] + + +EMPTY_JSON_OBJECT = "{}" + + +DATETIME_TYPE = DateTime(timezone=True).with_variant( + mysql.DATETIME(timezone=True, fsp=6), "mysql" +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql") + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) +EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote] +EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)} + + +class Events(Base): # type: ignore[misc,valid-type] + """Event history data.""" + + __table_args__ = ( + # Used for fetching events at a specific time + # see logbook + Index("ix_events_event_type_time_fired", "event_type", "time_fired"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENTS + event_id = Column(Integer, Identity(), primary_key=True) # no longer used + event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE)) + event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) # no longer used + origin_idx = Column(SmallInteger) + time_fired = Column(DATETIME_TYPE, index=True) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True) + event_data_rel = relationship("EventData") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + 
return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> Events: + """Create an event database object from a native event.""" + return Events( + event_type=event.event_type, + event_data=None, + origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin), + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + ) + + def to_native(self, validate_entity_id: bool = True) -> Event | None: + """Convert to a native HA Event.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + parent_id=self.context_parent_id, + ) + try: + return Event( + self.event_type, + json.loads(self.event_data) if self.event_data else {}, + EventOrigin(self.origin) + if self.origin + else EVENT_ORIGIN_ORDER[self.origin_idx], + process_timestamp(self.time_fired), + context=context, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting to event: %s", self) + return None + + +class EventData(Base): # type: ignore[misc,valid-type] + """Event data history.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENT_DATA + data_id = Column(Integer, Identity(), primary_key=True) + hash = Column(BigInteger, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> EventData: + """Create object from an event.""" + shared_data = JSON_DUMP(event.data) + return EventData( + shared_data=shared_data, hash=EventData.hash_shared_data(shared_data) + ) + + @staticmethod + def shared_data_from_event(event: Event) -> str: + """Create shared_attrs from an event.""" + return JSON_DUMP(event.data) + + @staticmethod + def hash_shared_data(shared_data: str) -> int: + """Return the hash of json encoded shared data.""" + return cast(int, fnv1a_32(shared_data.encode("utf-8"))) + + def to_native(self) -> dict[str, Any]: + """Convert to an HA state object.""" + try: + return cast(dict[str, Any], json.loads(self.shared_data)) + except ValueError: + _LOGGER.exception("Error converting row to event data: %s", self) + return {} + + +class States(Base): # type: ignore[misc,valid-type] + """State change history.""" + + __table_args__ = ( + # Used for fetching the state of entities at a specific time + # (get_states in history.py) + Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATES + state_id = Column(Integer, Identity(), primary_key=True) + entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID)) + state = Column(String(MAX_LENGTH_STATE_STATE)) + attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + event_id = Column( + Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True + ) + last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow) + last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True) + old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True) + attributes_id = Column( + Integer, ForeignKey("state_attributes.attributes_id"), index=True + ) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = 
Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + origin_idx = Column(SmallInteger) # 0 is local, 1 is remote + old_state = relationship("States", remote_side=[state_id]) + state_attributes = relationship("StateAttributes") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> States: + """Create object from a state_changed event.""" + entity_id = event.data["entity_id"] + state: State | None = event.data.get("new_state") + dbstate = States( + entity_id=entity_id, + attributes=None, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin), + ) + + # None state means the state was removed from the state machine + if state is None: + dbstate.state = "" + dbstate.last_changed = event.time_fired + dbstate.last_updated = event.time_fired + else: + dbstate.state = state.state + dbstate.last_changed = state.last_changed + dbstate.last_updated = state.last_updated + + return dbstate + + def to_native(self, validate_entity_id: bool = True) -> State | None: + """Convert to an HA state object.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + parent_id=self.context_parent_id, + ) + try: + return State( + self.entity_id, + self.state, + # Join the state_attributes table on attributes_id to get the attributes + # for newer states + json.loads(self.attributes) if self.attributes else {}, + process_timestamp(self.last_changed), + process_timestamp(self.last_updated), + context=context, + validate_entity_id=validate_entity_id, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state: %s", self) + return None + + +class StateAttributes(Base): # type: ignore[misc,valid-type] + """State attribute change history.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATE_ATTRIBUTES + attributes_id = Column(Integer, Identity(), primary_key=True) + hash = Column(BigInteger, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> StateAttributes: + """Create object from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + dbstate = StateAttributes( + shared_attrs="{}" if state is None else JSON_DUMP(state.attributes) + ) + dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs) + return dbstate + + @staticmethod + def shared_attrs_from_event( + event: Event, exclude_attrs_by_domain: dict[str, set[str]] + ) -> str: + """Create shared_attrs from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + if state is None: + return "{}" + domain = split_entity_id(state.entity_id)[0] + exclude_attrs = ( + exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS + ) + return JSON_DUMP( + {k: v for k, v in state.attributes.items() if k not in exclude_attrs} + ) + + @staticmethod + def 
hash_shared_attrs(shared_attrs: str) -> int: + """Return the hash of json encoded shared attributes.""" + return cast(int, fnv1a_32(shared_attrs.encode("utf-8"))) + + def to_native(self) -> dict[str, Any]: + """Convert to an HA state object.""" + try: + return cast(dict[str, Any], json.loads(self.shared_attrs)) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state attributes: %s", self) + return {} + + +class StatisticResult(TypedDict): + """Statistic result data class. + + Allows multiple datapoints for the same statistic_id. + """ + + meta: StatisticMetaData + stat: StatisticData + + +class StatisticDataBase(TypedDict): + """Mandatory fields for statistic data class.""" + + start: datetime + + +class StatisticData(StatisticDataBase, total=False): + """Statistic data class.""" + + mean: float + min: float + max: float + last_reset: datetime | None + state: float + sum: float + + +class StatisticsBase: + """Statistics base class.""" + + id = Column(Integer, Identity(), primary_key=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + + @declared_attr # type: ignore[misc] + def metadata_id(self) -> Column: + """Define the metadata_id column for sub classes.""" + return Column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + index=True, + ) + + start = Column(DATETIME_TYPE, index=True) + mean = Column(DOUBLE_TYPE) + min = Column(DOUBLE_TYPE) + max = Column(DOUBLE_TYPE) + last_reset = Column(DATETIME_TYPE) + state = Column(DOUBLE_TYPE) + sum = Column(DOUBLE_TYPE) + + @classmethod + def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase: + """Create object from a statistics.""" + return cls( # type: ignore[call-arg,misc] + metadata_id=metadata_id, + **stats, + ) + + +class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True), + ) + __tablename__ = TABLE_STATISTICS + + +class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start", + "metadata_id", + "start", + unique=True, + ), + ) + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticMetaData(TypedDict): + """Statistic meta data class.""" + + has_mean: bool + has_sum: bool + name: str | None + source: str + statistic_id: str + unit_of_measurement: str | None + + +class StatisticsMeta(Base): # type: ignore[misc,valid-type] + """Statistics meta data.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATISTICS_META + id = Column(Integer, Identity(), primary_key=True) + statistic_id = Column(String(255), index=True) + source = Column(String(32)) + unit_of_measurement = Column(String(255)) + has_mean = Column(Boolean) + has_sum = Column(Boolean) + name = Column(String(255)) + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class RecorderRuns(Base): # type: ignore[misc,valid-type] + """Representation of recorder run.""" + + __table_args__ = 
(Index("ix_recorder_runs_start_end", "start", "end"),) + __tablename__ = TABLE_RECORDER_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), default=dt_util.utcnow) + end = Column(DateTime(timezone=True)) + closed_incorrect = Column(Boolean, default=False) + created = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + def entity_ids(self, point_in_time: datetime | None = None) -> list[str]: + """Return the entity ids that existed in this run. + + Specify point_in_time if you want to know which existed at that point + in time inside the run. + """ + session = Session.object_session(self) + + assert session is not None, "RecorderRuns need to be persisted" + + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start + ) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] + + def to_native(self, validate_entity_id: bool = True) -> RecorderRuns: + """Return self, native format is this model.""" + return self + + +class SchemaChanges(Base): # type: ignore[misc,valid-type] + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + change_id = Column(Integer, Identity(), primary_key=True) + schema_version = Column(Integer) + changed = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +class StatisticsRuns(Base): # type: ignore[misc,valid-type] + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +@overload +def process_timestamp(ts: None) -> None: + ... + + +@overload +def process_timestamp(ts: datetime) -> datetime: + ... + + +def process_timestamp(ts: datetime | None) -> datetime | None: + """Process a timestamp into datetime object.""" + if ts is None: + return None + if ts.tzinfo is None: + return ts.replace(tzinfo=dt_util.UTC) + + return dt_util.as_utc(ts) + + +@overload +def process_timestamp_to_utc_isoformat(ts: None) -> None: + ... + + +@overload +def process_timestamp_to_utc_isoformat(ts: datetime) -> str: + ... 
+ + +def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None: + """Process a timestamp into UTC isotime.""" + if ts is None: + return None + if ts.tzinfo == dt_util.UTC: + return ts.isoformat() + if ts.tzinfo is None: + return f"{ts.isoformat()}{DB_TIMEZONE}" + return ts.astimezone(dt_util.UTC).isoformat() + + +class LazyState(State): + """A lazy version of core State.""" + + __slots__ = [ + "_row", + "_attributes", + "_last_changed", + "_last_updated", + "_context", + "_attr_cache", + ] + + def __init__( # pylint: disable=super-init-not-called + self, row: Row, attr_cache: dict[str, dict[str, Any]] | None = None + ) -> None: + """Init the lazy state.""" + self._row = row + self.entity_id: str = self._row.entity_id + self.state = self._row.state or "" + self._attributes: dict[str, Any] | None = None + self._last_changed: datetime | None = None + self._last_updated: datetime | None = None + self._context: Context | None = None + self._attr_cache = attr_cache + + @property # type: ignore[override] + def attributes(self) -> dict[str, Any]: # type: ignore[override] + """State attributes.""" + if self._attributes is None: + source = self._row.shared_attrs or self._row.attributes + if self._attr_cache is not None and ( + attributes := self._attr_cache.get(source) + ): + self._attributes = attributes + return attributes + if source == EMPTY_JSON_OBJECT or source is None: + self._attributes = {} + return self._attributes + try: + self._attributes = json.loads(source) + except ValueError: + # When json.loads fails + _LOGGER.exception( + "Error converting row to state attributes: %s", self._row + ) + self._attributes = {} + if self._attr_cache is not None: + self._attr_cache[source] = self._attributes + return self._attributes + + @attributes.setter + def attributes(self, value: dict[str, Any]) -> None: + """Set attributes.""" + self._attributes = value + + @property # type: ignore[override] + def context(self) -> Context: # type: ignore[override] + """State context.""" + if self._context is None: + self._context = Context(id=None) # type: ignore[arg-type] + return self._context + + @context.setter + def context(self, value: Context) -> None: + """Set context.""" + self._context = value + + @property # type: ignore[override] + def last_changed(self) -> datetime: # type: ignore[override] + """Last changed datetime.""" + if self._last_changed is None: + self._last_changed = process_timestamp(self._row.last_changed) + return self._last_changed + + @last_changed.setter + def last_changed(self, value: datetime) -> None: + """Set last changed datetime.""" + self._last_changed = value + + @property # type: ignore[override] + def last_updated(self) -> datetime: # type: ignore[override] + """Last updated datetime.""" + if self._last_updated is None: + if (last_updated := self._row.last_updated) is not None: + self._last_updated = process_timestamp(last_updated) + else: + self._last_updated = self.last_changed + return self._last_updated + + @last_updated.setter + def last_updated(self, value: datetime) -> None: + """Set last updated datetime.""" + self._last_updated = value + + def as_dict(self) -> dict[str, Any]: # type: ignore[override] + """Return a dict representation of the LazyState. + + Async friendly. + + To be used for JSON serialization. 
+ """ + if self._last_changed is None and self._last_updated is None: + last_changed_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_changed + ) + if ( + self._row.last_updated is None + or self._row.last_changed == self._row.last_updated + ): + last_updated_isoformat = last_changed_isoformat + else: + last_updated_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_updated + ) + else: + last_changed_isoformat = self.last_changed.isoformat() + if self.last_changed == self.last_updated: + last_updated_isoformat = last_changed_isoformat + else: + last_updated_isoformat = self.last_updated.isoformat() + return { + "entity_id": self.entity_id, + "state": self.state, + "attributes": self._attributes or self.attributes, + "last_changed": last_changed_isoformat, + "last_updated": last_updated_isoformat, + } + + def __eq__(self, other: Any) -> bool: + """Return the comparison.""" + return ( + other.__class__ in [self.__class__, State] + and self.entity_id == other.entity_id + and self.state == other.state + and self.attributes == other.attributes + ) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 28eba51a4f3..882f00d2940 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -1,21 +1,26 @@ """The tests for sensor recorder platform.""" # pylint: disable=protected-access,invalid-name from datetime import timedelta +import importlib +import sys from unittest.mock import patch, sentinel import pytest from pytest import approx +from sqlalchemy import create_engine +from sqlalchemy.orm import Session from homeassistant.components import recorder from homeassistant.components.recorder import history, statistics -from homeassistant.components.recorder.const import DATA_INSTANCE +from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.models import ( StatisticsShortTerm, process_timestamp_to_utc_isoformat, ) from homeassistant.components.recorder.statistics import ( async_add_external_statistics, - delete_duplicates, + delete_statistics_duplicates, + delete_statistics_meta_duplicates, get_last_short_term_statistics, get_last_statistics, get_latest_short_term_statistics, @@ -32,7 +37,7 @@ import homeassistant.util.dt as dt_util from .common import async_wait_recording_done, do_adhoc_statistics -from tests.common import mock_registry +from tests.common import get_test_home_assistant, mock_registry from tests.components.recorder.common import wait_recording_done ORIG_TZ = dt_util.DEFAULT_TIME_ZONE @@ -306,12 +311,92 @@ def test_rename_entity(hass_recorder): entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99") hass.add_job(rename_entry) - hass.block_till_done() + wait_recording_done(hass) stats = statistics_during_period(hass, zero, period="5minute") assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2} +def test_rename_entity_collision(hass_recorder, caplog): + """Test statistics is migrated when entity_id is changed.""" + hass = hass_recorder() + setup_component(hass, "sensor", {}) + + entity_reg = mock_registry(hass) + + @callback + def add_entry(): + reg_entry = entity_reg.async_get_or_create( + "sensor", + "test", + "unique_0000", + suggested_object_id="test1", + ) + assert reg_entry.entity_id == "sensor.test1" + + hass.add_job(add_entry) + hass.block_till_done() + + zero, four, states = record_states(hass) + hist = 
history.get_significant_states(hass, zero, four) + assert dict(states) == dict(hist) + + for kwargs in ({}, {"statistic_ids": ["sensor.test1"]}): + stats = statistics_during_period(hass, zero, period="5minute", **kwargs) + assert stats == {} + stats = get_last_short_term_statistics(hass, 0, "sensor.test1", True) + assert stats == {} + + do_adhoc_statistics(hass, start=zero) + wait_recording_done(hass) + expected_1 = { + "statistic_id": "sensor.test1", + "start": process_timestamp_to_utc_isoformat(zero), + "end": process_timestamp_to_utc_isoformat(zero + timedelta(minutes=5)), + "mean": approx(14.915254237288135), + "min": approx(10.0), + "max": approx(20.0), + "last_reset": None, + "state": None, + "sum": None, + } + expected_stats1 = [ + {**expected_1, "statistic_id": "sensor.test1"}, + ] + expected_stats2 = [ + {**expected_1, "statistic_id": "sensor.test2"}, + ] + + stats = statistics_during_period(hass, zero, period="5minute") + assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2} + + # Insert metadata for sensor.test99 + metadata_1 = { + "has_mean": True, + "has_sum": False, + "name": "Total imported energy", + "source": "test", + "statistic_id": "sensor.test99", + "unit_of_measurement": "kWh", + } + + with session_scope(hass=hass) as session: + session.add(recorder.models.StatisticsMeta.from_meta(metadata_1)) + + # Rename entity sensor.test1 to sensor.test99 + @callback + def rename_entry(): + entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99") + + hass.add_job(rename_entry) + wait_recording_done(hass) + + # Statistics failed to migrate due to the collision + stats = statistics_during_period(hass, zero, period="5minute") + assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2} + assert "Blocked attempt to insert duplicated statistic rows" in caplog.text + + def test_statistics_duplicated(hass_recorder, caplog): """Test statistics with same start time is not compiled.""" hass = hass_recorder() @@ -737,7 +822,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog): hass = hass_recorder() wait_recording_done(hass) with session_scope(hass=hass) as session: - delete_duplicates(hass, session) + delete_statistics_duplicates(hass, session) assert "duplicated statistics rows" not in caplog.text assert "Found non identical" not in caplog.text assert "Found duplicated" not in caplog.text @@ -800,6 +885,208 @@ def test_duplicate_statistics_handle_integrity_error(hass_recorder, caplog): assert "Blocked attempt to insert duplicated statistic rows" in caplog.text +def _create_engine_28(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. 
+ """ + module = "tests.components.recorder.models_schema_28" + importlib.import_module(module) + old_models = sys.modules[module] + engine = create_engine(*args, **kwargs) + old_models.Base.metadata.create_all(engine) + with Session(engine) as session: + session.add(recorder.models.StatisticsRuns(start=statistics.get_start_time())) + session.add( + recorder.models.SchemaChanges(schema_version=old_models.SCHEMA_VERSION) + ) + session.commit() + return engine + + +def test_delete_metadata_duplicates(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.models_schema_28" + importlib.import_module(module) + old_models = sys.modules[module] + + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_metadata_2 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_co2_metadata = { + "has_mean": True, + "has_sum": False, + "name": "Fossil percentage", + "source": "test", + "statistic_id": "test:fossil_percentage", + "unit_of_measurement": "%", + } + + # Create some duplicated statistics_meta with schema version 28 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.core.create_engine", new=_create_engine_28 + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) + ) + session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata)) + + with session_scope(hass=hass) as session: + tmp = session.query(recorder.models.StatisticsMeta).all() + assert len(tmp) == 3 + assert tmp[0].id == 1 + assert tmp[0].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[1].id == 2 + assert tmp[1].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[2].id == 3 + assert tmp[2].statistic_id == "test:fossil_percentage" + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + # Test that the duplicates are removed during migration from schema 28 + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + + assert "Deleted 1 duplicated statistics_meta rows" in caplog.text + with session_scope(hass=hass) as session: + tmp = session.query(recorder.models.StatisticsMeta).all() + assert len(tmp) == 2 + assert tmp[0].id == 2 + assert tmp[0].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[1].id == 3 + assert tmp[1].statistic_id == "test:fossil_percentage" + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + +def test_delete_metadata_duplicates_many(caplog, tmpdir): + """Test removal of duplicated statistics.""" + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" 
+ + module = "tests.components.recorder.models_schema_28" + importlib.import_module(module) + old_models = sys.modules[module] + + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_energy_metadata_2 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_2", + "unit_of_measurement": "kWh", + } + external_co2_metadata = { + "has_mean": True, + "has_sum": False, + "name": "Fossil percentage", + "source": "test", + "statistic_id": "test:fossil_percentage", + "unit_of_measurement": "%", + } + + # Create some duplicated statistics with schema version 28 + with patch.object(recorder, "models", old_models), patch.object( + recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.core.create_engine", new=_create_engine_28 + ): + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + for _ in range(3000): + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) + ) + session.add( + recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2) + ) + session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata)) + session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata)) + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + # Test that the duplicates are removed during migration from schema 28 + hass = get_test_home_assistant() + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + + assert "Deleted 3002 duplicated statistics_meta rows" in caplog.text + with session_scope(hass=hass) as session: + tmp = session.query(recorder.models.StatisticsMeta).all() + assert len(tmp) == 3 + assert tmp[0].id == 3001 + assert tmp[0].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[1].id == 3003 + assert tmp[1].statistic_id == "test:total_energy_import_tariff_2" + assert tmp[2].id == 3005 + assert tmp[2].statistic_id == "test:fossil_percentage" + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + +def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog): + """Test removal of duplicated statistics.""" + hass = hass_recorder() + wait_recording_done(hass) + with session_scope(hass=hass) as session: + delete_statistics_meta_duplicates(session) + assert "duplicated statistics_meta rows" not in caplog.text + + def record_states(hass): """Record some test states.