Prevent duplication of statistics metadata (#71637)

* Prevent duplication of statistics metadata

* Add models_schema_28.py

* Handle entity renaming as a recorder job

* Improve tests
This commit is contained in:
Erik Montnemery 2022-05-24 15:34:46 +02:00 committed by GitHub
parent d620072585
commit 23bd64b7a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 1175 additions and 32 deletions

View file

@ -35,6 +35,7 @@ from homeassistant.helpers.event import (
async_track_time_interval,
async_track_utc_time_change,
)
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
from . import migration, statistics
@ -461,10 +462,18 @@ class Recorder(threading.Thread):
@callback
def async_update_statistics_metadata(
self, statistic_id: str, unit_of_measurement: str | None
self,
statistic_id: str,
*,
new_statistic_id: str | UndefinedType = UNDEFINED,
new_unit_of_measurement: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Update statistics metadata for a statistic_id."""
self.queue_task(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))
self.queue_task(
UpdateStatisticsMetadataTask(
statistic_id, new_statistic_id, new_unit_of_measurement
)
)
@callback
def async_external_statistics(

View file

@ -33,7 +33,11 @@ from .models import (
StatisticsShortTerm,
process_timestamp,
)
from .statistics import delete_duplicates, get_start_time
from .statistics import (
delete_statistics_duplicates,
delete_statistics_meta_duplicates,
get_start_time,
)
from .util import session_scope
_LOGGER = logging.getLogger(__name__)
@ -670,7 +674,7 @@ def _apply_update( # noqa: C901
# There may be duplicated statistics entries, delete duplicated statistics
# and try again
with session_scope(session=session_maker()) as session:
delete_duplicates(hass, session)
delete_statistics_duplicates(hass, session)
_create_index(
session_maker, "statistics", "ix_statistics_statistic_id_start"
)
@ -705,6 +709,21 @@ def _apply_update( # noqa: C901
_create_index(session_maker, "states", "ix_states_context_id")
# Once there are no longer any state_changed events
# in the events table we can drop the index on states.event_id
elif new_version == 29:
# Recreate statistics_meta index to block duplicated statistic_id
_drop_index(session_maker, "statistics_meta", "ix_statistics_meta_statistic_id")
try:
_create_index(
session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
)
except DatabaseError:
# There may be duplicated statistics_meta entries, delete duplicates
# and try again
with session_scope(session=session_maker()) as session:
delete_statistics_meta_duplicates(session)
_create_index(
session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
)
else:
raise ValueError(f"No schema migration defined for version {new_version}")

View file

@ -51,7 +51,7 @@ from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 28
SCHEMA_VERSION = 29
_LOGGER = logging.getLogger(__name__)
@ -515,7 +515,7 @@ class StatisticsMeta(Base): # type: ignore[misc,valid-type]
)
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, Identity(), primary_key=True)
statistic_id = Column(String(255), index=True)
statistic_id = Column(String(255), index=True, unique=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)

View file

@ -33,6 +33,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.storage import STORAGE_DIR
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
import homeassistant.util.pressure as pressure_util
import homeassistant.util.temperature as temperature_util
@ -208,18 +209,11 @@ class ValidationIssue:
def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks."""
def _entity_id_changed(event: Event) -> None:
"""Handle entity_id changed."""
old_entity_id = event.data["old_entity_id"]
entity_id = event.data["entity_id"]
with session_scope(hass=hass) as session:
session.query(StatisticsMeta).filter(
(StatisticsMeta.statistic_id == old_entity_id)
& (StatisticsMeta.source == DOMAIN)
).update({StatisticsMeta.statistic_id: entity_id})
async def _async_entity_id_changed(event: Event) -> None:
await hass.data[DATA_INSTANCE].async_add_executor_job(_entity_id_changed, event)
@callback
def _async_entity_id_changed(event: Event) -> None:
hass.data[DATA_INSTANCE].async_update_statistics_metadata(
event.data["old_entity_id"], new_statistic_id=event.data["entity_id"]
)
@callback
def entity_registry_changed_filter(event: Event) -> bool:
@ -380,7 +374,7 @@ def _delete_duplicates_from_table(
return (total_deleted_rows, all_non_identical_duplicates)
def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
def delete_statistics_duplicates(hass: HomeAssistant, session: Session) -> None:
"""Identify and delete duplicated statistics.
A backup will be made of duplicated statistics before it is deleted.
@ -423,6 +417,69 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
)
def _find_statistics_meta_duplicates(session: Session) -> list[int]:
"""Find duplicated statistics_meta."""
subquery = (
session.query(
StatisticsMeta.statistic_id,
literal_column("1").label("is_duplicate"),
)
.group_by(StatisticsMeta.statistic_id)
.having(func.count() > 1)
.subquery()
)
query = (
session.query(StatisticsMeta)
.outerjoin(
subquery,
(subquery.c.statistic_id == StatisticsMeta.statistic_id),
)
.filter(subquery.c.is_duplicate == 1)
.order_by(StatisticsMeta.statistic_id, StatisticsMeta.id.desc())
.limit(1000 * MAX_ROWS_TO_PURGE)
)
duplicates = execute(query)
statistic_id = None
duplicate_ids: list[int] = []
if not duplicates:
return duplicate_ids
for duplicate in duplicates:
if statistic_id != duplicate.statistic_id:
statistic_id = duplicate.statistic_id
continue
duplicate_ids.append(duplicate.id)
return duplicate_ids
def _delete_statistics_meta_duplicates(session: Session) -> int:
"""Identify and delete duplicated statistics from a specified table."""
total_deleted_rows = 0
while True:
duplicate_ids = _find_statistics_meta_duplicates(session)
if not duplicate_ids:
break
for i in range(0, len(duplicate_ids), MAX_ROWS_TO_PURGE):
deleted_rows = (
session.query(StatisticsMeta)
.filter(StatisticsMeta.id.in_(duplicate_ids[i : i + MAX_ROWS_TO_PURGE]))
.delete(synchronize_session=False)
)
total_deleted_rows += deleted_rows
return total_deleted_rows
def delete_statistics_meta_duplicates(session: Session) -> None:
"""Identify and delete duplicated statistics_meta."""
deleted_statistics_rows = _delete_statistics_meta_duplicates(session)
if deleted_statistics_rows:
_LOGGER.info(
"Deleted %s duplicated statistics_meta rows", deleted_statistics_rows
)
def _compile_hourly_statistics_summary_mean_stmt(
start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
@ -736,13 +793,26 @@ def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
def update_statistics_metadata(
instance: Recorder, statistic_id: str, unit_of_measurement: str | None
instance: Recorder,
statistic_id: str,
new_statistic_id: str | None | UndefinedType,
new_unit_of_measurement: str | None | UndefinedType,
) -> None:
"""Update statistics metadata for a statistic_id."""
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
if new_unit_of_measurement is not UNDEFINED:
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement})
if new_statistic_id is not UNDEFINED:
with session_scope(
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
session.query(StatisticsMeta).filter(
(StatisticsMeta.statistic_id == statistic_id)
& (StatisticsMeta.source == DOMAIN)
).update({StatisticsMeta.statistic_id: new_statistic_id})
def list_statistic_ids(

View file

@ -10,6 +10,7 @@ import threading
from typing import TYPE_CHECKING, Any
from homeassistant.core import Event
from homeassistant.helpers.typing import UndefinedType
from . import purge, statistics
from .const import DOMAIN, EXCLUDE_ATTRIBUTES
@ -46,12 +47,16 @@ class UpdateStatisticsMetadataTask(RecorderTask):
"""Object to store statistics_id and unit for update of statistics metadata."""
statistic_id: str
unit_of_measurement: str | None
new_statistic_id: str | None | UndefinedType
new_unit_of_measurement: str | None | UndefinedType
def run(self, instance: Recorder) -> None:
"""Handle the task."""
statistics.update_statistics_metadata(
instance, self.statistic_id, self.unit_of_measurement
instance,
self.statistic_id,
self.new_statistic_id,
self.new_unit_of_measurement,
)

View file

@ -103,7 +103,7 @@ def ws_update_statistics_metadata(
) -> None:
"""Update statistics metadata for a statistic_id."""
hass.data[DATA_INSTANCE].async_update_statistics_metadata(
msg["statistic_id"], msg["unit_of_measurement"]
msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"]
)
connection.send_result(msg["id"])

View file

@ -0,0 +1,753 @@
"""Models for SQLAlchemy.
This file contains the model definitions for schema version 28.
It is used to test the schema migration logic.
"""
from __future__ import annotations
from datetime import datetime, timedelta
import json
import logging
from typing import Any, TypedDict, cast, overload
from fnvhash import fnv1a_32
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
SmallInteger,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session
from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 28
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_EVENT_DATA = "event_data"
TABLE_STATES = "states"
TABLE_STATE_ATTRIBUTES = "state_attributes"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
TABLE_STATISTICS_META = "statistics_meta"
TABLE_STATISTICS_RUNS = "statistics_runs"
TABLE_STATISTICS_SHORT_TERM = "statistics_short_term"
ALL_TABLES = [
TABLE_STATES,
TABLE_STATE_ATTRIBUTES,
TABLE_EVENTS,
TABLE_EVENT_DATA,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
TABLE_STATISTICS_META,
TABLE_STATISTICS_RUNS,
TABLE_STATISTICS_SHORT_TERM,
]
TABLES_TO_CHECK = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
]
EMPTY_JSON_OBJECT = "{}"
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
DOUBLE_TYPE = (
Float()
.with_variant(mysql.DOUBLE(asdecimal=False), "mysql")
.with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
)
EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
class Events(Base): # type: ignore[misc,valid-type]
"""Event history data."""
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True) # no longer used
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) # no longer used
origin_idx = Column(SmallInteger)
time_fired = Column(DATETIME_TYPE, index=True)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True)
event_data_rel = relationship("EventData")
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', "
f"origin_idx='{self.origin_idx}', time_fired='{self.time_fired}'"
f", data_id={self.data_id})>"
)
@staticmethod
def from_event(event: Event) -> Events:
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=None,
origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id: bool = True) -> Event | None:
"""Convert to a native HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class EventData(Base): # type: ignore[misc,valid-type]
"""Event data history."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENT_DATA
data_id = Column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table
shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.EventData("
f"id={self.data_id}, hash='{self.hash}', data='{self.shared_data}'"
f")>"
)
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
)
@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
@staticmethod
def hash_shared_data(shared_data: str) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
class States(Base): # type: ignore[misc,valid-type]
"""State change history."""
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
attributes_id = Column(
Integer, ForeignKey("state_attributes.attributes_id"), index=True
)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
origin_idx = Column(SmallInteger) # 0 is local, 1 is remote
old_state = relationship("States", remote_side=[state_id])
state_attributes = relationship("StateAttributes")
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}, attributes_id={self.attributes_id}"
f")>"
)
@staticmethod
def from_event(event: Event) -> States:
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state")
dbstate = States(
entity_id=entity_id,
attributes=None,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
)
# None state means the state was removed from the state machine
if state is None:
dbstate.state = ""
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.state = state.state
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id: bool = True) -> State | None:
"""Convert to an HA state object."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return State(
self.entity_id,
self.state,
# Join the state_attributes table on attributes_id to get the attributes
# for newer states
json.loads(self.attributes) if self.attributes else {},
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
context=context,
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class StateAttributes(Base): # type: ignore[misc,valid-type]
"""State attribute change history."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATE_ATTRIBUTES
attributes_id = Column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table
shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StateAttributes("
f"id={self.attributes_id}, hash='{self.hash}', attributes='{self.shared_attrs}'"
f")>"
)
@staticmethod
def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
return dbstate
@staticmethod
def shared_attrs_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}
class StatisticResult(TypedDict):
"""Statistic result data class.
Allows multiple datapoints for the same statistic_id.
"""
meta: StatisticMetaData
stat: StatisticData
class StatisticDataBase(TypedDict):
"""Mandatory fields for statistic data class."""
start: datetime
class StatisticData(StatisticDataBase, total=False):
"""Statistic data class."""
mean: float
min: float
max: float
last_reset: datetime | None
state: float
sum: float
class StatisticsBase:
"""Statistics base class."""
id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr # type: ignore[misc]
def metadata_id(self) -> Column:
"""Define the metadata_id column for sub classes."""
return Column(
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE)
max = Column(DOUBLE_TYPE)
last_reset = Column(DATETIME_TYPE)
state = Column(DOUBLE_TYPE)
sum = Column(DOUBLE_TYPE)
@classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
"""Create object from a statistics."""
return cls( # type: ignore[call-arg,misc]
metadata_id=metadata_id,
**stats,
)
class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Long term statistics."""
duration = timedelta(hours=1)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True),
)
__tablename__ = TABLE_STATISTICS
class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Short term statistics."""
duration = timedelta(minutes=5)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index(
"ix_statistics_short_term_statistic_id_start",
"metadata_id",
"start",
unique=True,
),
)
__tablename__ = TABLE_STATISTICS_SHORT_TERM
class StatisticMetaData(TypedDict):
"""Statistic meta data class."""
has_mean: bool
has_sum: bool
name: str | None
source: str
statistic_id: str
unit_of_measurement: str | None
class StatisticsMeta(Base): # type: ignore[misc,valid-type]
"""Statistics meta data."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, Identity(), primary_key=True)
statistic_id = Column(String(255), index=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)
has_sum = Column(Boolean)
name = Column(String(255))
@staticmethod
def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
"""Create object from meta data."""
return StatisticsMeta(**meta)
class RecorderRuns(Base): # type: ignore[misc,valid-type]
"""Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore[misc,valid-type]
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
class StatisticsRuns(Base): # type: ignore[misc,valid-type]
"""Representation of statistics run."""
__tablename__ = TABLE_STATISTICS_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), index=True)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StatisticsRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f")>"
)
@overload
def process_timestamp(ts: None) -> None:
...
@overload
def process_timestamp(ts: datetime) -> datetime:
...
def process_timestamp(ts: datetime | None) -> datetime | None:
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
@overload
def process_timestamp_to_utc_isoformat(ts: None) -> None:
...
@overload
def process_timestamp_to_utc_isoformat(ts: datetime) -> str:
...
def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
"_attr_cache",
]
def __init__( # pylint: disable=super-init-not-called
self, row: Row, attr_cache: dict[str, dict[str, Any]] | None = None
) -> None:
"""Init the lazy state."""
self._row = row
self.entity_id: str = self._row.entity_id
self.state = self._row.state or ""
self._attributes: dict[str, Any] | None = None
self._last_changed: datetime | None = None
self._last_updated: datetime | None = None
self._context: Context | None = None
self._attr_cache = attr_cache
@property # type: ignore[override]
def attributes(self) -> dict[str, Any]: # type: ignore[override]
"""State attributes."""
if self._attributes is None:
source = self._row.shared_attrs or self._row.attributes
if self._attr_cache is not None and (
attributes := self._attr_cache.get(source)
):
self._attributes = attributes
return attributes
if source == EMPTY_JSON_OBJECT or source is None:
self._attributes = {}
return self._attributes
try:
self._attributes = json.loads(source)
except ValueError:
# When json.loads fails
_LOGGER.exception(
"Error converting row to state attributes: %s", self._row
)
self._attributes = {}
if self._attr_cache is not None:
self._attr_cache[source] = self._attributes
return self._attributes
@attributes.setter
def attributes(self, value: dict[str, Any]) -> None:
"""Set attributes."""
self._attributes = value
@property # type: ignore[override]
def context(self) -> Context: # type: ignore[override]
"""State context."""
if self._context is None:
self._context = Context(id=None) # type: ignore[arg-type]
return self._context
@context.setter
def context(self, value: Context) -> None:
"""Set context."""
self._context = value
@property # type: ignore[override]
def last_changed(self) -> datetime: # type: ignore[override]
"""Last changed datetime."""
if self._last_changed is None:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value: datetime) -> None:
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore[override]
def last_updated(self) -> datetime: # type: ignore[override]
"""Last updated datetime."""
if self._last_updated is None:
if (last_updated := self._row.last_updated) is not None:
self._last_updated = process_timestamp(last_updated)
else:
self._last_updated = self.last_changed
return self._last_updated
@last_updated.setter
def last_updated(self, value: datetime) -> None:
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self) -> dict[str, Any]: # type: ignore[override]
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed is None and self._last_updated is None:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if (
self._row.last_updated is None
or self._row.last_changed == self._row.last_updated
):
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
else:
last_changed_isoformat = self.last_changed.isoformat()
if self.last_changed == self.last_updated:
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = self.last_updated.isoformat()
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)

View file

@ -1,21 +1,26 @@
"""The tests for sensor recorder platform."""
# pylint: disable=protected-access,invalid-name
from datetime import timedelta
import importlib
import sys
from unittest.mock import patch, sentinel
import pytest
from pytest import approx
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from homeassistant.components import recorder
from homeassistant.components.recorder import history, statistics
from homeassistant.components.recorder.const import DATA_INSTANCE
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import (
StatisticsShortTerm,
process_timestamp_to_utc_isoformat,
)
from homeassistant.components.recorder.statistics import (
async_add_external_statistics,
delete_duplicates,
delete_statistics_duplicates,
delete_statistics_meta_duplicates,
get_last_short_term_statistics,
get_last_statistics,
get_latest_short_term_statistics,
@ -32,7 +37,7 @@ import homeassistant.util.dt as dt_util
from .common import async_wait_recording_done, do_adhoc_statistics
from tests.common import mock_registry
from tests.common import get_test_home_assistant, mock_registry
from tests.components.recorder.common import wait_recording_done
ORIG_TZ = dt_util.DEFAULT_TIME_ZONE
@ -306,12 +311,92 @@ def test_rename_entity(hass_recorder):
entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99")
hass.add_job(rename_entry)
hass.block_till_done()
wait_recording_done(hass)
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2}
def test_rename_entity_collision(hass_recorder, caplog):
"""Test statistics is migrated when entity_id is changed."""
hass = hass_recorder()
setup_component(hass, "sensor", {})
entity_reg = mock_registry(hass)
@callback
def add_entry():
reg_entry = entity_reg.async_get_or_create(
"sensor",
"test",
"unique_0000",
suggested_object_id="test1",
)
assert reg_entry.entity_id == "sensor.test1"
hass.add_job(add_entry)
hass.block_till_done()
zero, four, states = record_states(hass)
hist = history.get_significant_states(hass, zero, four)
assert dict(states) == dict(hist)
for kwargs in ({}, {"statistic_ids": ["sensor.test1"]}):
stats = statistics_during_period(hass, zero, period="5minute", **kwargs)
assert stats == {}
stats = get_last_short_term_statistics(hass, 0, "sensor.test1", True)
assert stats == {}
do_adhoc_statistics(hass, start=zero)
wait_recording_done(hass)
expected_1 = {
"statistic_id": "sensor.test1",
"start": process_timestamp_to_utc_isoformat(zero),
"end": process_timestamp_to_utc_isoformat(zero + timedelta(minutes=5)),
"mean": approx(14.915254237288135),
"min": approx(10.0),
"max": approx(20.0),
"last_reset": None,
"state": None,
"sum": None,
}
expected_stats1 = [
{**expected_1, "statistic_id": "sensor.test1"},
]
expected_stats2 = [
{**expected_1, "statistic_id": "sensor.test2"},
]
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
# Insert metadata for sensor.test99
metadata_1 = {
"has_mean": True,
"has_sum": False,
"name": "Total imported energy",
"source": "test",
"statistic_id": "sensor.test99",
"unit_of_measurement": "kWh",
}
with session_scope(hass=hass) as session:
session.add(recorder.models.StatisticsMeta.from_meta(metadata_1))
# Rename entity sensor.test1 to sensor.test99
@callback
def rename_entry():
entity_reg.async_update_entity("sensor.test1", new_entity_id="sensor.test99")
hass.add_job(rename_entry)
wait_recording_done(hass)
# Statistics failed to migrate due to the collision
stats = statistics_during_period(hass, zero, period="5minute")
assert stats == {"sensor.test1": expected_stats1, "sensor.test2": expected_stats2}
assert "Blocked attempt to insert duplicated statistic rows" in caplog.text
def test_statistics_duplicated(hass_recorder, caplog):
"""Test statistics with same start time is not compiled."""
hass = hass_recorder()
@ -737,7 +822,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog):
hass = hass_recorder()
wait_recording_done(hass)
with session_scope(hass=hass) as session:
delete_duplicates(hass, session)
delete_statistics_duplicates(hass, session)
assert "duplicated statistics rows" not in caplog.text
assert "Found non identical" not in caplog.text
assert "Found duplicated" not in caplog.text
@ -800,6 +885,208 @@ def test_duplicate_statistics_handle_integrity_error(hass_recorder, caplog):
assert "Blocked attempt to insert duplicated statistic rows" in caplog.text
def _create_engine_28(*args, **kwargs):
"""Test version of create_engine that initializes with old schema.
This simulates an existing db with the old schema.
"""
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
engine = create_engine(*args, **kwargs)
old_models.Base.metadata.create_all(engine)
with Session(engine) as session:
session.add(recorder.models.StatisticsRuns(start=statistics.get_start_time()))
session.add(
recorder.models.SchemaChanges(schema_version=old_models.SCHEMA_VERSION)
)
session.commit()
return engine
def test_delete_metadata_duplicates(caplog, tmpdir):
"""Test removal of duplicated statistics."""
test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db")
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
external_energy_metadata_1 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_energy_metadata_2 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_co2_metadata = {
"has_mean": True,
"has_sum": False,
"name": "Fossil percentage",
"source": "test",
"statistic_id": "test:fossil_percentage",
"unit_of_measurement": "%",
}
# Create some duplicated statistics_meta with schema version 28
with patch.object(recorder, "models", old_models), patch.object(
recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION
), patch(
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 3
assert tmp[0].id == 1
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 2
assert tmp[1].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[2].id == 3
assert tmp[2].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)
assert "Deleted 1 duplicated statistics_meta rows" in caplog.text
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 2
assert tmp[0].id == 2
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 3
assert tmp[1].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
def test_delete_metadata_duplicates_many(caplog, tmpdir):
"""Test removal of duplicated statistics."""
test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db")
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
module = "tests.components.recorder.models_schema_28"
importlib.import_module(module)
old_models = sys.modules[module]
external_energy_metadata_1 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_1",
"unit_of_measurement": "kWh",
}
external_energy_metadata_2 = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "test",
"statistic_id": "test:total_energy_import_tariff_2",
"unit_of_measurement": "kWh",
}
external_co2_metadata = {
"has_mean": True,
"has_sum": False,
"name": "Fossil percentage",
"source": "test",
"statistic_id": "test:fossil_percentage",
"unit_of_measurement": "%",
}
# Create some duplicated statistics with schema version 28
with patch.object(recorder, "models", old_models), patch.object(
recorder.migration, "SCHEMA_VERSION", old_models.SCHEMA_VERSION
), patch(
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
for _ in range(3000):
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_1)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(
recorder.models.StatisticsMeta.from_meta(external_energy_metadata_2)
)
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
session.add(recorder.models.StatisticsMeta.from_meta(external_co2_metadata))
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
wait_recording_done(hass)
assert "Deleted 3002 duplicated statistics_meta rows" in caplog.text
with session_scope(hass=hass) as session:
tmp = session.query(recorder.models.StatisticsMeta).all()
assert len(tmp) == 3
assert tmp[0].id == 3001
assert tmp[0].statistic_id == "test:total_energy_import_tariff_1"
assert tmp[1].id == 3003
assert tmp[1].statistic_id == "test:total_energy_import_tariff_2"
assert tmp[2].id == 3005
assert tmp[2].statistic_id == "test:fossil_percentage"
hass.stop()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog):
"""Test removal of duplicated statistics."""
hass = hass_recorder()
wait_recording_done(hass)
with session_scope(hass=hass) as session:
delete_statistics_meta_duplicates(session)
assert "duplicated statistics_meta rows" not in caplog.text
def record_states(hass):
"""Record some test states.