Prevent duplication of statistics metadata (#71637)

* Prevent duplication of statistics metadata

* Add models_schema_28.py

* Handle entity renaming as a recorder job

* Improve tests
Erik Montnemery 2022-05-24 15:34:46 +02:00 committed by GitHub
parent d620072585
commit 23bd64b7a2
8 changed files with 1175 additions and 32 deletions


@@ -35,6 +35,7 @@ from homeassistant.helpers.event import (
     async_track_time_interval,
     async_track_utc_time_change,
 )
+from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 import homeassistant.util.dt as dt_util

 from . import migration, statistics
@@ -461,10 +462,18 @@ class Recorder(threading.Thread):

     @callback
     def async_update_statistics_metadata(
-        self, statistic_id: str, unit_of_measurement: str | None
+        self,
+        statistic_id: str,
+        *,
+        new_statistic_id: str | UndefinedType = UNDEFINED,
+        new_unit_of_measurement: str | None | UndefinedType = UNDEFINED,
     ) -> None:
         """Update statistics metadata for a statistic_id."""
-        self.queue_task(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))
+        self.queue_task(
+            UpdateStatisticsMetadataTask(
+                statistic_id, new_statistic_id, new_unit_of_measurement
+            )
+        )

     @callback
     def async_external_statistics(
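
The new keyword-only parameters default to the UNDEFINED sentinel, so callers can rename a statistic, change its unit, or both, without ambiguity between "leave unchanged" and an explicit None. A minimal usage sketch (assuming `instance` is the running Recorder and the call is made from the event loop, since this is a @callback; the call only queues a task, the update itself runs on the recorder thread):

    # Rename only; the unit stays UNDEFINED and is left untouched
    instance.async_update_statistics_metadata(
        "sensor.test1", new_statistic_id="sensor.test99"
    )

    # Change only the unit; the statistic_id stays UNDEFINED and is left untouched
    instance.async_update_statistics_metadata(
        "sensor.energy", new_unit_of_measurement="kWh"
    )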


@@ -33,7 +33,11 @@ from .models import (
     StatisticsShortTerm,
     process_timestamp,
 )
-from .statistics import delete_duplicates, get_start_time
+from .statistics import (
+    delete_statistics_duplicates,
+    delete_statistics_meta_duplicates,
+    get_start_time,
+)
 from .util import session_scope

 _LOGGER = logging.getLogger(__name__)
@@ -670,7 +674,7 @@ def _apply_update(  # noqa: C901
             # There may be duplicated statistics entries, delete duplicated statistics
             # and try again
             with session_scope(session=session_maker()) as session:
-                delete_duplicates(hass, session)
+                delete_statistics_duplicates(hass, session)
             _create_index(
                 session_maker, "statistics", "ix_statistics_statistic_id_start"
             )
@@ -705,6 +709,21 @@ def _apply_update(  # noqa: C901
         _create_index(session_maker, "states", "ix_states_context_id")
         # Once there are no longer any state_changed events
         # in the events table we can drop the index on states.event_id
+    elif new_version == 29:
+        # Recreate statistics_meta index to block duplicated statistic_id
+        _drop_index(session_maker, "statistics_meta", "ix_statistics_meta_statistic_id")
+        try:
+            _create_index(
+                session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
+            )
+        except DatabaseError:
+            # There may be duplicated statistics_meta entries, delete duplicates
+            # and try again
+            with session_scope(session=session_maker()) as session:
+                delete_statistics_meta_duplicates(session)
+            _create_index(
+                session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
+            )
     else:
         raise ValueError(f"No schema migration defined for version {new_version}")
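
After migrating to schema 29, the recreated ix_statistics_meta_statistic_id index is unique, which is what blocks future duplicates at the database level. A quick sanity check, sketched with SQLAlchemy's inspector (this helper is not part of the commit; it assumes direct access to the recorder's engine):

    from sqlalchemy import inspect

    def statistics_meta_index_is_unique(engine) -> bool:
        """Return True if ix_statistics_meta_statistic_id was rebuilt as a unique index."""
        return any(
            index["name"] == "ix_statistics_meta_statistic_id" and index["unique"]
            for index in inspect(engine).get_indexes("statistics_meta")
        )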


@@ -51,7 +51,7 @@ from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
 # pylint: disable=invalid-name
 Base = declarative_base()

-SCHEMA_VERSION = 28
+SCHEMA_VERSION = 29

 _LOGGER = logging.getLogger(__name__)
@@ -515,7 +515,7 @@ class StatisticsMeta(Base):  # type: ignore[misc,valid-type]
     )
     __tablename__ = TABLE_STATISTICS_META
     id = Column(Integer, Identity(), primary_key=True)
-    statistic_id = Column(String(255), index=True)
+    statistic_id = Column(String(255), index=True, unique=True)
     source = Column(String(32))
     unit_of_measurement = Column(String(255))
     has_mean = Column(Boolean)
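
With unique=True on the column, a second row with the same statistic_id is rejected by the database itself. An illustrative sketch of the constraint in action, using plain SQLAlchemy sessions rather than any code from this commit (assumes an `engine` bound to a migrated database):

    from sqlalchemy.exc import IntegrityError
    from sqlalchemy.orm import Session

    with Session(engine) as session:
        session.add(StatisticsMeta(statistic_id="sensor.energy", source="recorder"))
        session.commit()

    with Session(engine) as session:
        session.add(StatisticsMeta(statistic_id="sensor.energy", source="recorder"))
        try:
            session.commit()
        except IntegrityError:
            # Duplicate statistic_id blocked by the unique index
            session.rollback()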


@@ -33,6 +33,7 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import entity_registry
 from homeassistant.helpers.json import JSONEncoder
 from homeassistant.helpers.storage import STORAGE_DIR
+from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 import homeassistant.util.dt as dt_util
 import homeassistant.util.pressure as pressure_util
 import homeassistant.util.temperature as temperature_util
@@ -208,18 +209,11 @@ class ValidationIssue:
 def async_setup(hass: HomeAssistant) -> None:
     """Set up the history hooks."""

-    def _entity_id_changed(event: Event) -> None:
-        """Handle entity_id changed."""
-        old_entity_id = event.data["old_entity_id"]
-        entity_id = event.data["entity_id"]
-        with session_scope(hass=hass) as session:
-            session.query(StatisticsMeta).filter(
-                (StatisticsMeta.statistic_id == old_entity_id)
-                & (StatisticsMeta.source == DOMAIN)
-            ).update({StatisticsMeta.statistic_id: entity_id})
-
-    async def _async_entity_id_changed(event: Event) -> None:
-        await hass.data[DATA_INSTANCE].async_add_executor_job(_entity_id_changed, event)
+    @callback
+    def _async_entity_id_changed(event: Event) -> None:
+        hass.data[DATA_INSTANCE].async_update_statistics_metadata(
+            event.data["old_entity_id"], new_statistic_id=event.data["entity_id"]
+        )

     @callback
     def entity_registry_changed_filter(event: Event) -> bool:
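
The rename handler no longer writes to the database from an executor job; it simply forwards the rename to the recorder, which performs the update later on its own thread via UpdateStatisticsMetadataTask. A sketch of the entity_registry_updated payload for a rename; only old_entity_id and entity_id are read above, while the "action" key is an assumption based on the entity registry's event format:

    event.data == {
        "action": "update",
        "entity_id": "sensor.test99",     # the new entity_id
        "old_entity_id": "sensor.test1",  # the previous entity_id
    }
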
@@ -380,7 +374,7 @@ def _delete_duplicates_from_table(
     return (total_deleted_rows, all_non_identical_duplicates)


-def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
+def delete_statistics_duplicates(hass: HomeAssistant, session: Session) -> None:
     """Identify and delete duplicated statistics.

     A backup will be made of duplicated statistics before it is deleted.
@@ -423,6 +417,69 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
         )


+def _find_statistics_meta_duplicates(session: Session) -> list[int]:
+    """Find duplicated statistics_meta."""
+    subquery = (
+        session.query(
+            StatisticsMeta.statistic_id,
+            literal_column("1").label("is_duplicate"),
+        )
+        .group_by(StatisticsMeta.statistic_id)
+        .having(func.count() > 1)
+        .subquery()
+    )
+    query = (
+        session.query(StatisticsMeta)
+        .outerjoin(
+            subquery,
+            (subquery.c.statistic_id == StatisticsMeta.statistic_id),
+        )
+        .filter(subquery.c.is_duplicate == 1)
+        .order_by(StatisticsMeta.statistic_id, StatisticsMeta.id.desc())
+        .limit(1000 * MAX_ROWS_TO_PURGE)
+    )
+    duplicates = execute(query)
+    statistic_id = None
+    duplicate_ids: list[int] = []
+
+    if not duplicates:
+        return duplicate_ids
+
+    for duplicate in duplicates:
+        if statistic_id != duplicate.statistic_id:
+            statistic_id = duplicate.statistic_id
+            continue
+        duplicate_ids.append(duplicate.id)
+
+    return duplicate_ids
+
+
+def _delete_statistics_meta_duplicates(session: Session) -> int:
+    """Identify and delete duplicated statistics from a specified table."""
+    total_deleted_rows = 0
+    while True:
+        duplicate_ids = _find_statistics_meta_duplicates(session)
+        if not duplicate_ids:
+            break
+        for i in range(0, len(duplicate_ids), MAX_ROWS_TO_PURGE):
+            deleted_rows = (
+                session.query(StatisticsMeta)
+                .filter(StatisticsMeta.id.in_(duplicate_ids[i : i + MAX_ROWS_TO_PURGE]))
+                .delete(synchronize_session=False)
+            )
+            total_deleted_rows += deleted_rows
+    return total_deleted_rows
+
+
+def delete_statistics_meta_duplicates(session: Session) -> None:
+    """Identify and delete duplicated statistics_meta."""
+    deleted_statistics_rows = _delete_statistics_meta_duplicates(session)
+    if deleted_statistics_rows:
+        _LOGGER.info(
+            "Deleted %s duplicated statistics_meta rows", deleted_statistics_rows
+        )
+
+
 def _compile_hourly_statistics_summary_mean_stmt(
     start_time: datetime, end_time: datetime
 ) -> StatementLambdaElement:
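
Because the query orders rows by statistic_id and then by id descending, the first row seen for each statistic_id (the newest one) is kept and every older row is collected for deletion. The same skip-first-keep-newest loop, illustrated standalone with plain tuples instead of ORM rows:

    rows = [  # (id, statistic_id), ordered by statistic_id, id DESC
        (7, "sensor.energy"),
        (3, "sensor.energy"),
        (1, "sensor.energy"),
        (5, "sensor.power"),
    ]
    seen = None
    duplicate_ids = []
    for row_id, statistic_id in rows:
        if statistic_id != seen:
            seen = statistic_id  # newest row for this statistic_id survives
            continue
        duplicate_ids.append(row_id)
    assert duplicate_ids == [3, 1]
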
@@ -736,13 +793,26 @@ def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
 def update_statistics_metadata(
-    instance: Recorder, statistic_id: str, unit_of_measurement: str | None
+    instance: Recorder,
+    statistic_id: str,
+    new_statistic_id: str | None | UndefinedType,
+    new_unit_of_measurement: str | None | UndefinedType,
 ) -> None:
     """Update statistics metadata for a statistic_id."""
-    with session_scope(session=instance.get_session()) as session:
-        session.query(StatisticsMeta).filter(
-            StatisticsMeta.statistic_id == statistic_id
-        ).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
+    if new_unit_of_measurement is not UNDEFINED:
+        with session_scope(session=instance.get_session()) as session:
+            session.query(StatisticsMeta).filter(
+                StatisticsMeta.statistic_id == statistic_id
+            ).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement})
+    if new_statistic_id is not UNDEFINED:
+        with session_scope(
+            session=instance.get_session(),
+            exception_filter=_filter_unique_constraint_integrity_error(instance),
+        ) as session:
+            session.query(StatisticsMeta).filter(
+                (StatisticsMeta.statistic_id == statistic_id)
+                & (StatisticsMeta.source == DOMAIN)
+            ).update({StatisticsMeta.statistic_id: new_statistic_id})


 def list_statistic_ids(
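
Because UNDEFINED is a distinct sentinel, an explicit None for new_unit_of_measurement clears the unit, while an omitted argument (UNDEFINED) skips that update entirely. The rename branch runs under an exception_filter, so a rename that collides with the new unique index is rolled back and logged (the tests below assert on the "Blocked attempt to insert duplicated statistic rows" message) instead of raising out of the session. A hedged illustration against this internal helper, which is normally only reached through UpdateStatisticsMetadataTask:

    # Clear the unit: None is not UNDEFINED, so the first UPDATE runs
    update_statistics_metadata(instance, "sensor.energy", UNDEFINED, None)

    # No-op: both fields are UNDEFINED, neither UPDATE runs
    update_statistics_metadata(instance, "sensor.energy", UNDEFINED, UNDEFINED)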


@@ -10,6 +10,7 @@ import threading
 from typing import TYPE_CHECKING, Any

 from homeassistant.core import Event
+from homeassistant.helpers.typing import UndefinedType

 from . import purge, statistics
 from .const import DOMAIN, EXCLUDE_ATTRIBUTES
@@ -46,12 +47,16 @@ class UpdateStatisticsMetadataTask(RecorderTask):
     """Object to store statistics_id and unit for update of statistics metadata."""

     statistic_id: str
-    unit_of_measurement: str | None
+    new_statistic_id: str | None | UndefinedType
+    new_unit_of_measurement: str | None | UndefinedType

     def run(self, instance: Recorder) -> None:
         """Handle the task."""
         statistics.update_statistics_metadata(
-            instance, self.statistic_id, self.unit_of_measurement
+            instance,
+            self.statistic_id,
+            self.new_statistic_id,
+            self.new_unit_of_measurement,
         )
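
UpdateStatisticsMetadataTask keeps the recorder's usual pattern: the event loop only constructs and queues the task, and run() does the database work on the recorder thread. A construction sketch matching the fields above (assuming the dataclass-style task definitions used in this module, with UNDEFINED imported from homeassistant.helpers.typing):

    instance.queue_task(
        UpdateStatisticsMetadataTask(
            statistic_id="sensor.test1",
            new_statistic_id="sensor.test99",
            new_unit_of_measurement=UNDEFINED,
        )
    )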


@@ -103,7 +103,7 @@ def ws_update_statistics_metadata(
 ) -> None:
     """Update statistics metadata for a statistic_id."""
     hass.data[DATA_INSTANCE].async_update_statistics_metadata(
-        msg["statistic_id"], msg["unit_of_measurement"]
+        msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"]
     )
     connection.send_result(msg["id"])
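
The websocket command still only exposes unit changes; renames are driven by entity registry events rather than this API. A sketch of the message shape the handler expects (the command type string is an assumption based on the recorder's websocket command registration, which is outside this diff):

    {
        "id": 1,
        "type": "recorder/update_statistics_metadata",
        "statistic_id": "sensor.energy_total",
        "unit_of_measurement": "kWh",
    }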


@@ -0,0 +1,753 @@
"""Models for SQLAlchemy.
This file contains the model definitions for schema version 28.
It is used to test the schema migration logic.
"""
from __future__ import annotations
from datetime import datetime, timedelta
import json
import logging
from typing import Any, TypedDict, cast, overload
from fnvhash import fnv1a_32
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
SmallInteger,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session
from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 28
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_EVENT_DATA = "event_data"
TABLE_STATES = "states"
TABLE_STATE_ATTRIBUTES = "state_attributes"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
TABLE_STATISTICS_META = "statistics_meta"
TABLE_STATISTICS_RUNS = "statistics_runs"
TABLE_STATISTICS_SHORT_TERM = "statistics_short_term"
ALL_TABLES = [
TABLE_STATES,
TABLE_STATE_ATTRIBUTES,
TABLE_EVENTS,
TABLE_EVENT_DATA,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
TABLE_STATISTICS_META,
TABLE_STATISTICS_RUNS,
TABLE_STATISTICS_SHORT_TERM,
]
TABLES_TO_CHECK = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
]
EMPTY_JSON_OBJECT = "{}"
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
DOUBLE_TYPE = (
Float()
.with_variant(mysql.DOUBLE(asdecimal=False), "mysql")
.with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
)
EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
class Events(Base): # type: ignore[misc,valid-type]
"""Event history data."""
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True) # no longer used
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) # no longer used
origin_idx = Column(SmallInteger)
time_fired = Column(DATETIME_TYPE, index=True)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True)
event_data_rel = relationship("EventData")
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', "
f"origin_idx='{self.origin_idx}', time_fired='{self.time_fired}'"
f", data_id={self.data_id})>"
)
@staticmethod
def from_event(event: Event) -> Events:
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=None,
origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id: bool = True) -> Event | None:
"""Convert to a native HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class EventData(Base): # type: ignore[misc,valid-type]
"""Event data history."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENT_DATA
data_id = Column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table
shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.EventData("
f"id={self.data_id}, hash='{self.hash}', data='{self.shared_data}'"
f")>"
)
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
)
@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
@staticmethod
def hash_shared_data(shared_data: str) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
class States(Base): # type: ignore[misc,valid-type]
"""State change history."""
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
attributes_id = Column(
Integer, ForeignKey("state_attributes.attributes_id"), index=True
)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
origin_idx = Column(SmallInteger) # 0 is local, 1 is remote
old_state = relationship("States", remote_side=[state_id])
state_attributes = relationship("StateAttributes")
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}, attributes_id={self.attributes_id}"
f")>"
)
@staticmethod
def from_event(event: Event) -> States:
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state")
dbstate = States(
entity_id=entity_id,
attributes=None,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
)
# None state means the state was removed from the state machine
if state is None:
dbstate.state = ""
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.state = state.state
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id: bool = True) -> State | None:
"""Convert to an HA state object."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return State(
self.entity_id,
self.state,
# Join the state_attributes table on attributes_id to get the attributes
# for newer states
json.loads(self.attributes) if self.attributes else {},
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
context=context,
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class StateAttributes(Base): # type: ignore[misc,valid-type]
"""State attribute change history."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATE_ATTRIBUTES
attributes_id = Column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table
shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StateAttributes("
f"id={self.attributes_id}, hash='{self.hash}', attributes='{self.shared_attrs}'"
f")>"
)
@staticmethod
def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
return dbstate
@staticmethod
def shared_attrs_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}
class StatisticResult(TypedDict):
"""Statistic result data class.
Allows multiple datapoints for the same statistic_id.
"""
meta: StatisticMetaData
stat: StatisticData
class StatisticDataBase(TypedDict):
"""Mandatory fields for statistic data class."""
start: datetime
class StatisticData(StatisticDataBase, total=False):
"""Statistic data class."""
mean: float
min: float
max: float
last_reset: datetime | None
state: float
sum: float
class StatisticsBase:
"""Statistics base class."""
id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr # type: ignore[misc]
def metadata_id(self) -> Column:
"""Define the metadata_id column for sub classes."""
return Column(
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE)
max = Column(DOUBLE_TYPE)
last_reset = Column(DATETIME_TYPE)
state = Column(DOUBLE_TYPE)
sum = Column(DOUBLE_TYPE)
@classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
"""Create object from a statistics."""
return cls( # type: ignore[call-arg,misc]
metadata_id=metadata_id,
**stats,
)
class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Long term statistics."""
duration = timedelta(hours=1)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True),
)
__tablename__ = TABLE_STATISTICS
class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Short term statistics."""
duration = timedelta(minutes=5)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index(
"ix_statistics_short_term_statistic_id_start",
"metadata_id",
"start",
unique=True,
),
)
__tablename__ = TABLE_STATISTICS_SHORT_TERM
class StatisticMetaData(TypedDict):
"""Statistic meta data class."""
has_mean: bool
has_sum: bool
name: str | None
source: str
statistic_id: str
unit_of_measurement: str | None
class StatisticsMeta(Base): # type: ignore[misc,valid-type]
"""Statistics meta data."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, Identity(), primary_key=True)
statistic_id = Column(String(255), index=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)
has_sum = Column(Boolean)
name = Column(String(255))
@staticmethod
def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
"""Create object from meta data."""
return StatisticsMeta(**meta)
class RecorderRuns(Base): # type: ignore[misc,valid-type]
"""Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore[misc,valid-type]
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
class StatisticsRuns(Base): # type: ignore[misc,valid-type]
"""Representation of statistics run."""
__tablename__ = TABLE_STATISTICS_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), index=True)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StatisticsRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f")>"
)
@overload
def process_timestamp(ts: None) -> None:
...
@overload
def process_timestamp(ts: datetime) -> datetime:
...
def process_timestamp(ts: datetime | None) -> datetime | None:
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
@overload
def process_timestamp_to_utc_isoformat(ts: None) -> None:
...
@overload
def process_timestamp_to_utc_isoformat(ts: datetime) -> str:
...
def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
"_attr_cache",
]
def __init__( # pylint: disable=super-init-not-called
self, row: Row, attr_cache: dict[str, dict[str, Any]] | None = None
) -> None:
"""Init the lazy state."""
self._row = row
self.entity_id: str = self._row.entity_id
self.state = self._row.state or ""
self._attributes: dict[str, Any] | None = None
self._last_changed: datetime | None = None
self._last_updated: datetime | None = None
self._context: Context | None = None
self._attr_cache = attr_cache
@property # type: ignore[override]
def attributes(self) -> dict[str, Any]: # type: ignore[override]
"""State attributes."""
if self._attributes is None:
source = self._row.shared_attrs or self._row.attributes
if self._attr_cache is not None and (
attributes := self._attr_cache.get(source)
):
self._attributes = attributes
return attributes
if source == EMPTY_JSON_OBJECT or source is None:
self._attributes = {}
return self._attributes
try:
self._attributes = json.loads(source)
except ValueError:
# When json.loads fails
_LOGGER.exception(
"Error converting row to state attributes: %s", self._row
)
self._attributes = {}
if self._attr_cache is not None:
self._attr_cache[source] = self._attributes
return self._attributes
@attributes.setter
def attributes(self, value: dict[str, Any]) -> None:
"""Set attributes."""
self._attributes = value
@property # type: ignore[override]
def context(self) -> Context: # type: ignore[override]
"""State context."""
if self._context is None:
self._context = Context(id=None) # type: ignore[arg-type]
return self._context
@context.setter
def context(self, value: Context) -> None:
"""Set context."""
self._context = value
@property # type: ignore[override]
def last_changed(self) -> datetime: # type: ignore[override]
"""Last changed datetime."""
if self._last_changed is None:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value: datetime) -> None:
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore[override]
def last_updated(self) -> datetime: # type: ignore[override]
"""Last updated datetime."""
if self._last_updated is None:
if (last_updated := self._row.last_updated) is not None:
self._last_updated = process_timestamp(last_updated)
else:
self._last_updated = self.last_changed
return self._last_updated
@last_updated.setter
def last_updated(self, value: datetime) -> None:
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self) -> dict[str, Any]: # type: ignore[override]
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed is None and self._last_updated is None:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if (
self._row.last_updated is None
or self._row.last_changed == self._row.last_updated
):
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
else:
last_changed_isoformat = self.last_changed.isoformat()
if self.last_changed == self.last_updated:
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = self.last_updated.isoformat()
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)


@@ -1,21 +1,26 @@
 """The tests for sensor recorder platform."""
 # pylint: disable=protected-access,invalid-name
 from datetime import timedelta
+import importlib
+import sys
 from unittest.mock import patch, sentinel

 import pytest
 from pytest import approx
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session

 from homeassistant.components import recorder
 from homeassistant.components.recorder import history, statistics
-from homeassistant.components.recorder.const import DATA_INSTANCE
+from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
 from homeassistant.components.recorder.models import (
     StatisticsShortTerm,
     process_timestamp_to_utc_isoformat,
 )
 from homeassistant.components.recorder.statistics import (
     async_add_external_statistics,
-    delete_duplicates,
+    delete_statistics_duplicates,
+    delete_statistics_meta_duplicates,
     get_last_short_term_statistics,
     get_last_statistics,
     get_latest_short_term_statistics,
@@ -32,7 +37,7 @@ import homeassistant.util.dt as dt_util
 from .common import async_wait_recording_done, do_adhoc_statistics

-from tests.common import mock_registry
+from tests.common import get_test_home_assistant, mock_registry
 from tests.components.recorder.common import wait_recording_done

 ORIG_TZ = dt_util.DEFAULT_TIME_ZONE
@@ -306,12 +311,92 @@ def test_rename_entity(hass_recorder):
         entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99")

     hass.add_job(rename_entry)
-    hass.block_till_done()
+    wait_recording_done(hass)

     stats = statistics_during_period(hass, zero, period="5minute")
     assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2}


def test_rename_entity_collision(hass_recorder, caplog):
"""Test statistics is migrated when entity_id is changed."""
hass = hass_recorder()
setup_component(hass, "sensor", {})
entity_reg = mock_registry(hass)
@callback
def add_entry():
reg_entry = entity_reg.async_get_or_create(
"sensor",
"test",
"unique_0000",
suggested_object_id="test1",
)
assert reg_entry.entity_id == "sensor.test1"
hass.add_job(add_entry)
hass.block_till_done()
zero, four, states = record_states(hass)
hist = history.get_significant_states(hass, zero, four)
assert dict(states) == dict(hist)
for kwargs in ({}, {"statistic_ids": ["sensor.test1"]}):
stats = statistics_during_period(hass, zero, period="5minute", **kwargs)
assert stats == {}
stats = get_last_short_term_statistics(hass, 0, "sensor.test1", True)
assert stats == {}
do_adhoc_statistics(hass, start=zero)
wait_recording_done(hass)
expected_1 = {
"statistic_id": "sensor.test1",
"start": process_timestamp_to_utc_isoformat(zero),
"end": process_timestamp_to_utc_isoformat(zero + timedelta(minutes=5)),
"mean": approx(14.915254237288135),
"min": approx(10.0),
"max": approx(20.0),
"last_reset": None,
"state": None,
"sum": None,
}
expected_stats1 = [
{**expected_1, "statistic_id": "sensor.test1"},
]
expected_stats2 = [
{**expected_1, "statistic_id": "sensor.test2"},
]
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
# Insert metadata for sensor.test99
metadata_1 = {
"has_mean": True,
"has_sum": False,
"name": "Total imported energy",
"source": "test",
"statistic_id": "sensor.test99",
"unit_of_measurement": "kWh",
}
with session_scope(hass=hass) as session:
session.add(recorder.models.StatisticsMeta.from_meta(metadata_1))
# Rename entity sensor.test1 to sensor.test99
@callback
def rename_entry():
entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99")
hass.add_job(rename_entry)
wait_recording_done(hass)
# Statistics failed to migrate due to the collision
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
assert "Blocked attempt to insert duplicated statistic rows" in caplog.text


def test_statistics_duplicated(hass_recorder, caplog):
    """Test statistics with same start time is not compiled."""
    hass = hass_recorder()
@@ -737,7 +822,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog):
     hass = hass_recorder()
     wait_recording_done(hass)
     with session_scope(hass=hass) as session:
-        delete_duplicates(hass, session)
+        delete_statistics_duplicates(hass, session)
     assert "duplicated statistics rows" not in caplog.text
     assert "Found non identical" not in caplog.text
     assert "Found duplicated" not in caplog.text
@@ -800,6 +885,208 @@ def test_duplicate_statistics_handle_integrity_error(hass_recorder, caplog):
    assert "Blocked attempt to insert duplicated statistic rows" in caplog.text


def _create_engine_28(*args, **kwargs):
"""Test version of create_engine that initializes with old schema.
This simulates an existing db with the old schema.
"""
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
engine = create_engine(*args, **kwargs)
old_models.Base.metadata.create_all(engine)
with Session(engine) as session:
session.add(recorder.models.StatisticsRuns(start=statistics.get_start_time()))
session.add(
recorder.models.SchemaChanges(schema_version=old_models.SCHEMA_VERSION)
)
session.commit()
return engine
def test_delete_metadata_duplicates(caplog, tmpdir):
"""Test removal of duplicated statistics."""
test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db")
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
external_energy_metadata_1 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_energy_metadata_2 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_co2_metadata = {
"has_mean": True,
"has_sum": False,
"name": "Fossil percentage",
"source": "test",
"statistic_id": "test:fossil_percentage",
"unit_of_measurement": "%",
}
# Create some duplicated statistics_meta with schema version 28
with patch.object(recorder, "models", old_models), patch.object(
recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION
), patch(
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 3
assert tmp[0].id == 1
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 2
assert tmp[1].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[2].id == 3
assert tmp[2].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)
assert "Deleted 1 duplicated statistics_meta rows" in caplog.text
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 2
assert tmp[0].id == 2
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 3
assert tmp[1].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
def test_delete_metadata_duplicates_many(caplog, tmpdir):
"""Test removal of duplicated statistics."""
test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db")
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
external_energy_metadata_1 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_energy_metadata_2 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_2",
"unit_of_measurement": "kWh",
}
external_co2_metadata = {
"has_mean": True,
"has_sum": False,
"name": "Fossil percentage",
"source": "test",
"statistic_id": "test:fossil_percentage",
"unit_of_measurement": "%",
}
# Create some duplicated statistics with schema version 28
with patch.object(recorder, "models", old_models), patch.object(
recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION
), patch(
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
for _ in range(3000):
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)
assert "Deleted 3002 duplicated statistics_meta rows" in caplog.text
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 3
assert tmp[0].id == 3001
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 3003
assert tmp[1].statistic_id == "test:total_energy_import_tariff_2"
assert tmp[2].id == 3005
assert tmp[2].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog):
"""Test removal of duplicated statistics."""
hass = hass_recorder()
wait_recording_done(hass)
with session_scope(hass=hass) as session:
delete_statistics_meta_duplicates(session)
assert "duplicated statistics_meta rows" not in caplog.text


def record_states(hass):
    """Record some test states.