Validate common statistics DB schema errors on start (#79707)
* Validate common statistics db schema errors on start
* Fix test
* Add tests
* Adjust tests
* Disable statistics schema validation in tests
* Update after rebase
This commit is contained in:
parent
724a79a8e8
commit
f869ce9d06
5 changed files with 602 additions and 87 deletions
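In short, the recorder now receives a richer `SchemaValidationStatus` from `migration.validate_db_schema` at startup: besides the current schema version it carries the set of statistics-table errors detected by `statistics.validate_db_schema` (missing 4-byte UTF-8 support, float columns instead of double precision, datetimes without µs precision) and a single `valid` flag, and `migrate_schema` corrects those errors after any version upgrade. The sketch below is only a simplified illustration of that flow, not the Home Assistant code itself; `SCHEMA_VERSION = 30`, `recorder_startup`, and the print calls are assumptions made for this example.

```python
from dataclasses import dataclass

SCHEMA_VERSION = 30  # assumed value, only for this sketch


@dataclass
class SchemaValidationStatus:
    """Mirrors the status object introduced by this change."""

    current_version: int
    statistics_schema_errors: set[str]
    valid: bool


def validate_db_schema(current_version: int, statistics_errors: set[str]) -> SchemaValidationStatus:
    # Statistics checks only make sense when the schema is current, because
    # otherwise the checked columns may not exist yet.
    is_current = current_version == SCHEMA_VERSION
    errors = set(statistics_errors) if is_current else set()
    return SchemaValidationStatus(current_version, errors, is_current and not errors)


def recorder_startup(status: SchemaValidationStatus) -> None:
    # Valid schema: start recording right away.
    if status.valid:
        print("schema is valid, starting recorder")
        return
    # Otherwise take the migration path: upgrade the version if needed,
    # then correct any statistics schema errors that were detected.
    for version in range(status.current_version, SCHEMA_VERSION):
        print(f"upgrading to schema version {version + 1}")
    for error in sorted(status.statistics_schema_errors):
        print(f"correcting statistics schema error: {error}")


recorder_startup(validate_db_schema(SCHEMA_VERSION, {"statistics.double precision"}))
```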
@@ -591,16 +591,14 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_connection_failed)
return

schema_status = migration.validate_db_schema(self.hass, self.get_session)
schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
if schema_status is None:
# Give up if we could not validate the schema
self.hass.add_job(self.async_connection_failed)
return
self.schema_version = schema_status.current_version

schema_is_valid = migration.schema_is_valid(schema_status)

if schema_is_valid:
if schema_status.valid:
self._setup_run()
else:
self.migration_in_progress = True

@@ -608,8 +606,8 @@ class Recorder(threading.Thread):

self.hass.add_job(self.async_connection_success)

if self.migration_is_live or schema_is_valid:
# If the migrate is live or the schema is current, we need to
if self.migration_is_live or schema_status.valid:
# If the migrate is live or the schema is valid, we need to
# wait for startup to complete. If its not live, we need to continue
# on.
self.hass.add_job(self.async_set_db_ready)

@@ -626,7 +624,7 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_set_db_ready)
return

if not schema_is_valid:
if not schema_status.valid:
if self._migrate_schema_and_setup_run(schema_status):
self.schema_version = SCHEMA_VERSION
if not self._event_listener:

@@ -3,7 +3,7 @@ from __future__ import annotations

from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
from dataclasses import dataclass, replace as dataclass_replace
from datetime import timedelta
import logging
from typing import TYPE_CHECKING

@@ -37,9 +37,11 @@ from .db_schema import (
)
from .models import process_timestamp
from .statistics import (
correct_db_schema as statistics_correct_db_schema,
delete_statistics_duplicates,
delete_statistics_meta_duplicates,
get_start_time,
validate_db_schema as statistics_validate_db_schema,
)
from .util import session_scope

@@ -83,6 +85,8 @@ class SchemaValidationStatus:
"""Store schema validation status."""

current_version: int
statistics_schema_errors: set[str]
valid: bool


def _schema_is_current(current_version: int) -> bool:

@@ -90,13 +94,8 @@ def _schema_is_current(current_version: int) -> bool:
return current_version == SCHEMA_VERSION


def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
"""Check if the schema is valid."""
return _schema_is_current(schema_status.current_version)


def validate_db_schema(
hass: HomeAssistant, session_maker: Callable[[], Session]
hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session]
) -> SchemaValidationStatus | None:
"""Check if the schema is valid.

@@ -104,11 +103,20 @@ def validate_db_schema(
errors caused by manual migration between database engines, for example importing an
SQLite database to MariaDB.
"""
schema_errors: set[str] = set()

current_version = get_schema_version(session_maker)
if current_version is None:
return None

return SchemaValidationStatus(current_version)
if is_current := _schema_is_current(current_version):
# We can only check for further errors if the schema is current, because
# columns may otherwise not exist etc.
schema_errors |= statistics_validate_db_schema(hass, engine, session_maker)

valid = is_current and not schema_errors

return SchemaValidationStatus(current_version, schema_errors, valid)


def live_migration(schema_status: SchemaValidationStatus) -> bool:

@@ -125,10 +133,18 @@ def migrate_schema(
) -> None:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
if current_version != SCHEMA_VERSION:
_LOGGER.warning(
"Database is about to upgrade from schema version: %s to: %s",
current_version,
SCHEMA_VERSION,
)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if live_migration(SchemaValidationStatus(version)) and not db_ready:
if (
live_migration(dataclass_replace(schema_status, current_version=version))
and not db_ready
):
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)

@@ -140,6 +156,13 @@ def migrate_schema(

_LOGGER.info("Upgrade to version %s done", new_version)

if schema_errors := schema_status.statistics_schema_errors:
_LOGGER.warning(
"Database is about to correct DB schema errors: %s",
", ".join(sorted(schema_errors)),
)
statistics_correct_db_schema(engine, session_maker, schema_errors)


def _create_index(
session_maker: Callable[[], Session], table_name: str, index_name: str

@@ -2,7 +2,7 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
import contextlib
import dataclasses
from datetime import datetime, timedelta

@@ -15,9 +15,10 @@ import re
from statistics import mean
from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import bindparam, func, lambda_stmt, select
from sqlalchemy import bindparam, func, lambda_stmt, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal_column, true
from sqlalchemy.sql.lambdas import StatementLambdaElement

@@ -874,12 +875,17 @@ def get_metadata(
)


def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False)


def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False)
_clear_statistics_with_session(session, statistic_ids)


def update_statistics_metadata(

@@ -1562,6 +1568,78 @@ def statistic_during_period(
return {key: convert(value) for key, value in result.items()}


def _statistics_during_period_with_session(
hass: HomeAssistant,
session: Session,
start_time: datetime,
end_time: datetime | None,
statistic_ids: list[str] | None,
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
"""Return statistic data points during UTC period start_time - end_time.

If end_time is omitted, returns statistics newer than or equal to start_time.
If statistic_ids is omitted, returns statistics for all statistics ids.
"""
metadata = None
# Fetch metadata for the given (or all) statistic_ids
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
return {}

metadata_ids = None
if statistic_ids is not None:
metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]

table: type[Statistics | StatisticsShortTerm] = (
Statistics if period != "5minute" else StatisticsShortTerm
)
stmt = _statistics_during_period_stmt(
start_time, end_time, metadata_ids, table, types
)
stats = execute_stmt_lambda_element(session, stmt)

if not stats:
return {}
# Return statistics combined with metadata
if period not in ("day", "week", "month"):
return _sorted_statistics_to_dict(
hass,
session,
stats,
statistic_ids,
metadata,
True,
table,
start_time,
units,
types,
)

result = _sorted_statistics_to_dict(
hass,
session,
stats,
statistic_ids,
metadata,
True,
table,
start_time,
units,
types,
)

if period == "day":
return _reduce_statistics_per_day(result, types)

if period == "week":
return _reduce_statistics_per_week(result, types)

return _reduce_statistics_per_month(result, types)


def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,

@@ -1576,63 +1654,18 @@ def statistics_during_period(
If end_time is omitted, returns statistics newer than or equal to start_time.
If statistic_ids is omitted, returns statistics for all statistics ids.
"""
metadata = None
with session_scope(hass=hass) as session:
# Fetch metadata for the given (or all) statistic_ids
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
return {}

metadata_ids = None
if statistic_ids is not None:
metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]

table: type[Statistics | StatisticsShortTerm] = (
Statistics if period != "5minute" else StatisticsShortTerm
)
stmt = _statistics_during_period_stmt(
start_time, end_time, metadata_ids, table, types
)
stats = execute_stmt_lambda_element(session, stmt)

if not stats:
return {}
# Return statistics combined with metadata
if period not in ("day", "week", "month"):
return _sorted_statistics_to_dict(
hass,
session,
stats,
statistic_ids,
metadata,
True,
table,
start_time,
units,
types,
)

result = _sorted_statistics_to_dict(
return _statistics_during_period_with_session(
hass,
session,
stats,
statistic_ids,
metadata,
True,
table,
start_time,
end_time,
statistic_ids,
period,
units,
types,
)

if period == "day":
return _reduce_statistics_per_day(result, types)

if period == "week":
return _reduce_statistics_per_week(result, types)

return _reduce_statistics_per_month(result, types)


def _get_last_statistics_stmt(
metadata_id: int,

@@ -2047,6 +2080,26 @@ def _filter_unique_constraint_integrity_error(
return _filter_unique_constraint_integrity_error


def _import_statistics_with_session(
session: Session,
metadata: StatisticMetaData,
statistics: Iterable[StatisticData],
table: type[Statistics | StatisticsShortTerm],
) -> bool:
"""Import statistics to the database."""
old_metadata_dict = get_metadata_with_session(
session, statistic_ids=[metadata["statistic_id"]]
)
metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
for stat in statistics:
if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]):
_update_statistics(session, table, stat_id, stat)
else:
_insert_statistics(session, table, metadata_id, stat)

return True


@retryable_database_job("statistics")
def import_statistics(
instance: Recorder,

@@ -2060,19 +2113,7 @@ def import_statistics(
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
old_metadata_dict = get_metadata_with_session(
session, statistic_ids=[metadata["statistic_id"]]
)
metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
for stat in statistics:
if stat_id := _statistics_exists(
session, table, metadata_id, stat["start"]
):
_update_statistics(session, table, stat_id, stat)
else:
_insert_statistics(session, table, metadata_id, stat)

return True
return _import_statistics_with_session(session, metadata, statistics, table)


@retryable_database_job("adjust_statistics")

@@ -2189,3 +2230,232 @@ def async_change_statistics_unit(
new_unit_of_measurement=new_unit_of_measurement,
old_unit_of_measurement=old_unit_of_measurement,
)


def _validate_db_schema_utf8(
instance: Recorder, session_maker: Callable[[], Session]
) -> set[str]:
"""Do some basic checks for common schema errors caused by manual migration."""
schema_errors: set[str] = set()

# Lack of full utf8 support is only an issue for MySQL / MariaDB
if instance.dialect_name != SupportedDialect.MYSQL:
return schema_errors

# This name can't be represented unless 4-byte UTF-8 unicode is supported
utf8_name = "𓆚𓃗"
statistic_id = f"{DOMAIN}.db_test"

metadata: StatisticMetaData = {
"has_mean": True,
"has_sum": True,
"name": utf8_name,
"source": DOMAIN,
"statistic_id": statistic_id,
"unit_of_measurement": None,
}

# Try inserting some metadata which needs utfmb4 support
try:
with session_scope(session=session_maker()) as session:
old_metadata_dict = get_metadata_with_session(
session, statistic_ids=[statistic_id]
)
try:
_update_or_add_metadata(session, metadata, old_metadata_dict)
_clear_statistics_with_session(session, statistic_ids=[statistic_id])
except OperationalError as err:
if err.orig and err.orig.args[0] == 1366:
_LOGGER.debug(
"Database table statistics_meta does not support 4-byte UTF-8"
)
schema_errors.add("statistics_meta.4-byte UTF-8")
session.rollback()
else:
raise
except Exception as exc: # pylint: disable=broad-except
_LOGGER.exception("Error when validating DB schema: %s", exc)
return schema_errors


def _validate_db_schema(
hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
) -> set[str]:
"""Do some basic checks for common schema errors caused by manual migration."""
schema_errors: set[str] = set()

# Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL
if instance.dialect_name not in (
SupportedDialect.MYSQL,
SupportedDialect.POSTGRESQL,
):
return schema_errors

# This number can't be accurately represented as a 32-bit float
precise_number = 1.000000000000001
# This time can't be accurately represented unless datetimes have µs precision
precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)

start_time = datetime(2020, 10, 6, tzinfo=dt_util.UTC)
statistic_id = f"{DOMAIN}.db_test"

metadata: StatisticMetaData = {
"has_mean": True,
"has_sum": True,
"name": None,
"source": DOMAIN,
"statistic_id": statistic_id,
"unit_of_measurement": None,
}
statistics: StatisticData = {
"last_reset": precise_time,
"max": precise_number,
"mean": precise_number,
"min": precise_number,
"start": precise_time,
"state": precise_number,
"sum": precise_number,
}

def check_columns(
schema_errors: set[str],
stored: Mapping,
expected: Mapping,
columns: tuple[str, ...],
table_name: str,
supports: str,
) -> None:
for column in columns:
if stored[column] != expected[column]:
schema_errors.add(f"{table_name}.{supports}")
_LOGGER.debug(
"Column %s in database table %s does not support %s (%s != %s)",
column,
table_name,
supports,
stored[column],
expected[column],
)

# Insert / adjust a test statistics row in each of the tables
tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
Statistics,
StatisticsShortTerm,
)
try:
with session_scope(session=session_maker()) as session:
for table in tables:
_import_statistics_with_session(session, metadata, (statistics,), table)
stored_statistics = _statistics_during_period_with_session(
hass,
session,
start_time,
None,
[statistic_id],
"hour" if table == Statistics else "5minute",
None,
{"last_reset", "max", "mean", "min", "state", "sum"},
)
if not (stored_statistic := stored_statistics.get(statistic_id)):
_LOGGER.warning(
"Schema validation failed for table: %s", table.__tablename__
)
continue

check_columns(
schema_errors,
stored_statistic[0],
statistics,
("max", "mean", "min", "state", "sum"),
table.__tablename__,
"double precision",
)
assert statistics["last_reset"]
check_columns(
schema_errors,
stored_statistic[0],
{
"last_reset": statistics["last_reset"],
"start": statistics["start"],
},
("start", "last_reset"),
table.__tablename__,
"µs precision",
)
_clear_statistics_with_session(session, statistic_ids=[statistic_id])
except Exception as exc: # pylint: disable=broad-except
_LOGGER.exception("Error when validating DB schema: %s", exc)

return schema_errors


def validate_db_schema(
hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
) -> set[str]:
"""Do some basic checks for common schema errors caused by manual migration."""
schema_errors: set[str] = set()
schema_errors |= _validate_db_schema_utf8(instance, session_maker)
schema_errors |= _validate_db_schema(hass, instance, session_maker)
if schema_errors:
_LOGGER.debug(
"Detected statistics schema errors: %s", ", ".join(sorted(schema_errors))
)
return schema_errors


def correct_db_schema(
engine: Engine, session_maker: Callable[[], Session], schema_errors: set[str]
) -> None:
"""Correct issues detected by validate_db_schema."""
from .migration import _modify_columns # pylint: disable=import-outside-toplevel

if "statistics_meta.4-byte UTF-8" in schema_errors:
# Attempt to convert the table to utf8mb4
_LOGGER.warning(
"Updating character set and collation of table %s to utf8mb4. "
"Note: this can take several minutes on large databases and slow "
"computers. Please be patient!",
"statistics_meta",
)
with contextlib.suppress(SQLAlchemyError):
with session_scope(session=session_maker()) as session:
connection = session.connection()
connection.execute(
# Using LOCK=EXCLUSIVE to prevent the database from corrupting
# https://github.com/home-assistant/core/issues/56104
text(
"ALTER TABLE statistics_meta CONVERT TO "
"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci, LOCK=EXCLUSIVE"
)
)

tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
Statistics,
StatisticsShortTerm,
)
for table in tables:
if f"{table.__tablename__}.double precision" in schema_errors:
# Attempt to convert float columns to double precision
_modify_columns(
session_maker,
engine,
table.__tablename__,
[
"mean DOUBLE PRECISION",
"min DOUBLE PRECISION",
"max DOUBLE PRECISION",
"state DOUBLE PRECISION",
"sum DOUBLE PRECISION",
],
)
if f"{table.__tablename__}.µs precision" in schema_errors:
# Attempt to convert datetime columns to µs precision
_modify_columns(
session_maker,
engine,
table.__tablename__,
[
"last_reset DATETIME(6)",
"start DATETIME(6)",
],
)

@@ -1,13 +1,14 @@
"""The tests for sensor recorder platform."""
# pylint: disable=protected-access,invalid-name
from datetime import timedelta
from datetime import datetime, timedelta
import importlib
import sys
from unittest.mock import patch, sentinel
from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel

import pytest
from pytest import approx
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session

from homeassistant.components import recorder

@@ -16,6 +17,8 @@ from homeassistant.components.recorder.const import SQLITE_URL_PREFIX
from homeassistant.components.recorder.db_schema import StatisticsShortTerm
from homeassistant.components.recorder.models import process_timestamp
from homeassistant.components.recorder.statistics import (
_statistics_during_period_with_session,
_update_or_add_metadata,
async_add_external_statistics,
async_import_statistics,
delete_statistics_duplicates,

@@ -1475,6 +1478,196 @@ def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog):
assert "duplicated statistics_meta rows" not in caplog.text


@pytest.mark.parametrize("enable_statistics_table_validation", [True])
@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
async def test_validate_db_schema(
async_setup_recorder_instance, hass, caplog, db_engine
):
"""Test validating DB schema with MySQL and PostgreSQL.

Note: The test uses SQLite, the purpose is only to exercise the code.
"""
with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
):
await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)
assert "Schema validation failed" not in caplog.text
assert "Detected statistics schema errors" not in caplog.text
assert "Database is about to correct DB schema errors" not in caplog.text


@pytest.mark.parametrize("enable_statistics_table_validation", [True])
async def test_validate_db_schema_fix_utf8_issue(
async_setup_recorder_instance, hass, caplog
):
"""Test validating DB schema with MySQL.

Note: The test uses SQLite, the purpose is only to exercise the code.
"""
orig_error = MagicMock()
orig_error.args = [1366]
utf8_error = OperationalError("", "", orig=orig_error)
with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"
), patch(
"homeassistant.components.recorder.statistics._update_or_add_metadata",
side_effect=[utf8_error, DEFAULT, DEFAULT],
wraps=_update_or_add_metadata,
):
await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)

assert "Schema validation failed" not in caplog.text
assert (
"Database is about to correct DB schema errors: statistics_meta.4-byte UTF-8"
in caplog.text
)
assert (
"Updating character set and collation of table statistics_meta to utf8mb4"
in caplog.text
)


@pytest.mark.parametrize("enable_statistics_table_validation", [True])
@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
@pytest.mark.parametrize(
"table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
)
@pytest.mark.parametrize(
"column, value",
(("max", 1.0), ("mean", 1.0), ("min", 1.0), ("state", 1.0), ("sum", 1.0)),
)
async def test_validate_db_schema_fix_float_issue(
async_setup_recorder_instance,
hass,
caplog,
db_engine,
table,
replace_index,
column,
value,
):
"""Test validating DB schema with MySQL.

Note: The test uses SQLite, the purpose is only to exercise the code.
"""
orig_error = MagicMock()
orig_error.args = [1366]
precise_number = 1.000000000000001
precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
statistics = {
"recorder.db_test": [
{
"last_reset": precise_time,
"max": precise_number,
"mean": precise_number,
"min": precise_number,
"start": precise_time,
"state": precise_number,
"sum": precise_number,
}
]
}
statistics["recorder.db_test"][0][column] = value
fake_statistics = [DEFAULT, DEFAULT]
fake_statistics[replace_index] = statistics

with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
), patch(
"homeassistant.components.recorder.statistics._statistics_during_period_with_session",
side_effect=fake_statistics,
wraps=_statistics_during_period_with_session,
), patch(
"homeassistant.components.recorder.migration._modify_columns"
) as modify_columns_mock:
await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)

assert "Schema validation failed" not in caplog.text
assert (
f"Database is about to correct DB schema errors: {table}.double precision"
in caplog.text
)
modification = [
"mean DOUBLE PRECISION",
"min DOUBLE PRECISION",
"max DOUBLE PRECISION",
"state DOUBLE PRECISION",
"sum DOUBLE PRECISION",
]
modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)


@pytest.mark.parametrize("enable_statistics_table_validation", [True])
@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
@pytest.mark.parametrize(
"table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
)
@pytest.mark.parametrize(
"column, value",
(
("last_reset", "2020-10-06T00:00:00+00:00"),
("start", "2020-10-06T00:00:00+00:00"),
),
)
async def test_validate_db_schema_fix_statistics_datetime_issue(
async_setup_recorder_instance,
hass,
caplog,
db_engine,
table,
replace_index,
column,
value,
):
"""Test validating DB schema with MySQL.

Note: The test uses SQLite, the purpose is only to exercise the code.
"""
orig_error = MagicMock()
orig_error.args = [1366]
precise_number = 1.000000000000001
precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
statistics = {
"recorder.db_test": [
{
"last_reset": precise_time,
"max": precise_number,
"mean": precise_number,
"min": precise_number,
"start": precise_time,
"state": precise_number,
"sum": precise_number,
}
]
}
statistics["recorder.db_test"][0][column] = value
fake_statistics = [DEFAULT, DEFAULT]
fake_statistics[replace_index] = statistics

with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
), patch(
"homeassistant.components.recorder.statistics._statistics_during_period_with_session",
side_effect=fake_statistics,
wraps=_statistics_during_period_with_session,
), patch(
"homeassistant.components.recorder.migration._modify_columns"
) as modify_columns_mock:
await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)

assert "Schema validation failed" not in caplog.text
assert (
f"Database is about to correct DB schema errors: {table}.µs precision"
in caplog.text
)
modification = ["last_reset DATETIME(6)", "start DATETIME(6)"]
modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)


def record_states(hass):
"""Record some test states.

@@ -5,6 +5,7 @@ import asyncio
from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import asynccontextmanager
import functools
import itertools
from json import JSONDecoder, loads
import logging
import sqlite3

@@ -860,6 +861,16 @@ def enable_statistics():
return False


@pytest.fixture
def enable_statistics_table_validation():
"""Fixture to control enabling of recorder's statistics table validation.

To enable statistics table validation, tests can be marked with:
@pytest.mark.parametrize("enable_statistics_table_validation", [True])
"""
return False


@pytest.fixture
def enable_nightly_purge():
"""Fixture to control enabling of recorder's nightly purge job.

@@ -902,6 +913,7 @@ def hass_recorder(
recorder_db_url,
enable_nightly_purge,
enable_statistics,
enable_statistics_table_validation,
hass_storage,
):
"""Home Assistant fixture with in-memory recorder."""

@@ -910,6 +922,11 @@ def hass_recorder(
hass = get_test_home_assistant()
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
stats_validate = (
recorder.statistics.validate_db_schema
if enable_statistics_table_validation
else itertools.repeat(set())
)
with patch(
"homeassistant.components.recorder.Recorder.async_nightly_tasks",
side_effect=nightly,

@@ -918,6 +935,10 @@ def hass_recorder(
"homeassistant.components.recorder.Recorder.async_periodic_statistics",
side_effect=stats,
autospec=True,
), patch(
"homeassistant.components.recorder.migration.statistics_validate_db_schema",
side_effect=stats_validate,
autospec=True,
):

def setup_recorder(config=None):

@@ -962,12 +983,18 @@ async def async_setup_recorder_instance(
hass_fixture_setup,
enable_nightly_purge,
enable_statistics,
enable_statistics_table_validation,
) -> AsyncGenerator[SetupRecorderInstanceT, None]:
"""Yield callable to setup recorder instance."""
assert not hass_fixture_setup

nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
stats_validate = (
recorder.statistics.validate_db_schema
if enable_statistics_table_validation
else itertools.repeat(set())
)
with patch(
"homeassistant.components.recorder.Recorder.async_nightly_tasks",
side_effect=nightly,

@@ -976,6 +1003,10 @@ async def async_setup_recorder_instance(
"homeassistant.components.recorder.Recorder.async_periodic_statistics",
side_effect=stats,
autospec=True,
), patch(
"homeassistant.components.recorder.migration.statistics_validate_db_schema",
side_effect=stats_validate,
autospec=True,
):

async def async_setup_recorder(
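A note on the test fixtures added above: `enable_statistics_table_validation` defaults to False, so the new schema validation is disabled in recorder tests unless a test opts in via parametrize, exactly as the new tests in this commit do. A minimal, hypothetical opt-in looks like the sketch below; the test name and assertion are illustrative only and assume the Home Assistant recorder test fixtures shown in this diff are available.

```python
import pytest


@pytest.mark.parametrize("enable_statistics_table_validation", [True])
async def test_schema_validation_runs_on_start(async_setup_recorder_instance, hass, caplog):
    """Start the recorder with statistics table validation enabled."""
    await async_setup_recorder_instance(hass)
    # A healthy SQLite test database should produce no statistics schema errors.
    assert "Detected statistics schema errors" not in caplog.text
```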