Validate common statistics DB schema errors on start (#79707)

* Validate common statistics db schema errors on start

* Fix test

* Add tests

* Adjust tests

* Disable statistics schema validation in tests

* Update after rebase
Erik Montnemery, 2022-11-29 10:16:08 +01:00, committed by GitHub
parent 724a79a8e8
commit f869ce9d06
5 changed files with 602 additions and 87 deletions

--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py

@@ -591,16 +591,14 @@ class Recorder(threading.Thread):
             self.hass.add_job(self.async_connection_failed)
             return

-        schema_status = migration.validate_db_schema(self.hass, self.get_session)
+        schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
         if schema_status is None:
             # Give up if we could not validate the schema
             self.hass.add_job(self.async_connection_failed)
             return
         self.schema_version = schema_status.current_version

-        schema_is_valid = migration.schema_is_valid(schema_status)
-        if schema_is_valid:
+        if schema_status.valid:
             self._setup_run()
         else:
             self.migration_in_progress = True
@@ -608,8 +606,8 @@
         self.hass.add_job(self.async_connection_success)

-        if self.migration_is_live or schema_is_valid:
-            # If the migration is live or the schema is current, we need to
+        if self.migration_is_live or schema_status.valid:
+            # If the migration is live or the schema is valid, we need to
             # wait for startup to complete. If it's not live, we need to continue
             # on.
             self.hass.add_job(self.async_set_db_ready)
@@ -626,7 +624,7 @@
             self.hass.add_job(self.async_set_db_ready)
             return

-        if not schema_is_valid:
+        if not schema_status.valid:
             if self._migrate_schema_and_setup_run(schema_status):
                 self.schema_version = SCHEMA_VERSION
                 if not self._event_listener:
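
A condensed sketch (not part of the diff) of the resulting startup decision in Recorder.run, using the names from the hunks above: validation now returns a status object whose valid flag gates normal setup versus the migration/repair path.

    schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
    if schema_status is None:
        # The schema version could not even be read: give up
        self.hass.add_job(self.async_connection_failed)
    elif schema_status.valid:
        # Schema is current and no statistics schema errors were detected
        self._setup_run()
    else:
        # Outdated schema and/or statistics schema errors: migrate and repair
        self.migration_in_progress = True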

--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections.abc import Callable, Iterable
 import contextlib
-from dataclasses import dataclass
+from dataclasses import dataclass, replace as dataclass_replace
 from datetime import timedelta
 import logging
 from typing import TYPE_CHECKING
@@ -37,9 +37,11 @@ from .db_schema import (
 )
 from .models import process_timestamp
 from .statistics import (
+    correct_db_schema as statistics_correct_db_schema,
     delete_statistics_duplicates,
     delete_statistics_meta_duplicates,
     get_start_time,
+    validate_db_schema as statistics_validate_db_schema,
 )
 from .util import session_scope
@@ -83,6 +85,8 @@ class SchemaValidationStatus:
     """Store schema validation status."""

     current_version: int
+    statistics_schema_errors: set[str]
+    valid: bool


 def _schema_is_current(current_version: int) -> bool:
@@ -90,13 +94,8 @@ def _schema_is_current(current_version: int) -> bool:
     return current_version == SCHEMA_VERSION


-def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
-    """Check if the schema is valid."""
-    return _schema_is_current(schema_status.current_version)
-
-
 def validate_db_schema(
-    hass: HomeAssistant, session_maker: Callable[[], Session]
+    hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session]
 ) -> SchemaValidationStatus | None:
     """Check if the schema is valid.
@@ -104,11 +103,20 @@ def validate_db_schema(
     errors caused by manual migration between database engines, for example importing an
     SQLite database to MariaDB.
     """
+    schema_errors: set[str] = set()
+
     current_version = get_schema_version(session_maker)
     if current_version is None:
         return None
-    return SchemaValidationStatus(current_version)
+
+    if is_current := _schema_is_current(current_version):
+        # We can only check for further errors if the schema is current, because
+        # columns may otherwise not exist etc.
+        schema_errors |= statistics_validate_db_schema(hass, engine, session_maker)
+
+    valid = is_current and not schema_errors
+    return SchemaValidationStatus(current_version, schema_errors, valid)


 def live_migration(schema_status: SchemaValidationStatus) -> bool:
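
An illustrative status as returned by validate_db_schema() when the schema version is current but the statistics_meta table lacks 4-byte UTF-8 support; the version number here is made up for the example:

    status = SchemaValidationStatus(
        current_version=30,
        statistics_schema_errors={"statistics_meta.4-byte UTF-8"},
        valid=False,
    )
    # migrate_schema() will skip the version-upgrade loop (the schema is
    # current) but still run statistics_correct_db_schema() for the errors.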
@@ -125,10 +133,18 @@
 ) -> None:
     """Check if the schema needs to be upgraded."""
     current_version = schema_status.current_version
-    _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
+    if current_version != SCHEMA_VERSION:
+        _LOGGER.warning(
+            "Database is about to upgrade from schema version: %s to: %s",
+            current_version,
+            SCHEMA_VERSION,
+        )
     db_ready = False
     for version in range(current_version, SCHEMA_VERSION):
-        if live_migration(SchemaValidationStatus(version)) and not db_ready:
+        if (
+            live_migration(dataclass_replace(schema_status, current_version=version))
+            and not db_ready
+        ):
             db_ready = True
             instance.migration_is_live = True
             hass.add_job(instance.async_set_db_ready)
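
dataclasses.replace builds a copy of a dataclass instance with selected fields overridden, so each loop iteration gets a status carrying the right version while statistics_schema_errors and valid are preserved unchanged. A standalone sketch of just that mechanism (Status is a stand-in, not the real class):

    from dataclasses import dataclass, replace as dataclass_replace

    @dataclass
    class Status:  # stand-in for SchemaValidationStatus
        current_version: int
        errors: set[str]
        valid: bool

    status = Status(25, set(), False)
    step = dataclass_replace(status, current_version=26)
    assert (step.current_version, step.errors, step.valid) == (26, set(), False)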
@@ -140,6 +156,13 @@
         _LOGGER.info("Upgrade to version %s done", new_version)

+    if schema_errors := schema_status.statistics_schema_errors:
+        _LOGGER.warning(
+            "Database is about to correct DB schema errors: %s",
+            ", ".join(sorted(schema_errors)),
+        )
+        statistics_correct_db_schema(engine, session_maker, schema_errors)
+

 def _create_index(
     session_maker: Callable[[], Session], table_name: str, index_name: str

--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Iterable, Mapping
 import contextlib
 import dataclasses
 from datetime import datetime, timedelta
@@ -15,9 +15,10 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal

-from sqlalchemy import bindparam, func, lambda_stmt, select
+from sqlalchemy import bindparam, func, lambda_stmt, select, text
+from sqlalchemy.engine import Engine
 from sqlalchemy.engine.row import Row
-from sqlalchemy.exc import SQLAlchemyError, StatementError
+from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
@@ -874,12 +875,17 @@ def get_metadata(
     )


+def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None:
+    """Clear statistics for a list of statistic_ids."""
+    session.query(StatisticsMeta).filter(
+        StatisticsMeta.statistic_id.in_(statistic_ids)
+    ).delete(synchronize_session=False)
+
+
 def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
     """Clear statistics for a list of statistic_ids."""
     with session_scope(session=instance.get_session()) as session:
-        session.query(StatisticsMeta).filter(
-            StatisticsMeta.statistic_id.in_(statistic_ids)
-        ).delete(synchronize_session=False)
+        _clear_statistics_with_session(session, statistic_ids)


 def update_statistics_metadata(
@@ -1562,6 +1568,78 @@ def statistic_during_period(
     return {key: convert(value) for key, value in result.items()}


+def _statistics_during_period_with_session(
+    hass: HomeAssistant,
+    session: Session,
+    start_time: datetime,
+    end_time: datetime | None,
+    statistic_ids: list[str] | None,
+    period: Literal["5minute", "day", "hour", "week", "month"],
+    units: dict[str, str] | None,
+    types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
+) -> dict[str, list[dict[str, Any]]]:
+    """Return statistic data points during UTC period start_time - end_time.
+
+    If end_time is omitted, returns statistics newer than or equal to start_time.
+    If statistic_ids is omitted, returns statistics for all statistics ids.
+    """
+    metadata = None
+    # Fetch metadata for the given (or all) statistic_ids
+    metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
+    if not metadata:
+        return {}
+
+    metadata_ids = None
+    if statistic_ids is not None:
+        metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
+
+    table: type[Statistics | StatisticsShortTerm] = (
+        Statistics if period != "5minute" else StatisticsShortTerm
+    )
+    stmt = _statistics_during_period_stmt(
+        start_time, end_time, metadata_ids, table, types
+    )
+    stats = execute_stmt_lambda_element(session, stmt)
+    if not stats:
+        return {}
+
+    # Return statistics combined with metadata
+    if period not in ("day", "week", "month"):
+        return _sorted_statistics_to_dict(
+            hass,
+            session,
+            stats,
+            statistic_ids,
+            metadata,
+            True,
+            table,
+            start_time,
+            units,
+            types,
+        )
+
+    result = _sorted_statistics_to_dict(
+        hass,
+        session,
+        stats,
+        statistic_ids,
+        metadata,
+        True,
+        table,
+        start_time,
+        units,
+        types,
+    )
+
+    if period == "day":
+        return _reduce_statistics_per_day(result, types)
+
+    if period == "week":
+        return _reduce_statistics_per_week(result, types)
+
+    return _reduce_statistics_per_month(result, types)
+
+
 def statistics_during_period(
     hass: HomeAssistant,
     start_time: datetime,
@@ -1576,63 +1654,18 @@ def statistics_during_period(
     If end_time is omitted, returns statistics newer than or equal to start_time.
     If statistic_ids is omitted, returns statistics for all statistics ids.
     """
-    metadata = None
     with session_scope(hass=hass) as session:
-        # Fetch metadata for the given (or all) statistic_ids
-        metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
-        if not metadata:
-            return {}
-
-        metadata_ids = None
-        if statistic_ids is not None:
-            metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
-
-        table: type[Statistics | StatisticsShortTerm] = (
-            Statistics if period != "5minute" else StatisticsShortTerm
-        )
-        stmt = _statistics_during_period_stmt(
-            start_time, end_time, metadata_ids, table, types
-        )
-        stats = execute_stmt_lambda_element(session, stmt)
-        if not stats:
-            return {}
-
-        # Return statistics combined with metadata
-        if period not in ("day", "week", "month"):
-            return _sorted_statistics_to_dict(
-                hass,
-                session,
-                stats,
-                statistic_ids,
-                metadata,
-                True,
-                table,
-                start_time,
-                units,
-                types,
-            )
-
-        result = _sorted_statistics_to_dict(
+        return _statistics_during_period_with_session(
             hass,
             session,
-            stats,
-            statistic_ids,
-            metadata,
-            True,
-            table,
             start_time,
+            end_time,
+            statistic_ids,
+            period,
             units,
             types,
         )
-
-    if period == "day":
-        return _reduce_statistics_per_day(result, types)
-
-    if period == "week":
-        return _reduce_statistics_per_week(result, types)
-
-    return _reduce_statistics_per_month(result, types)


 def _get_last_statistics_stmt(
     metadata_id: int,
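
Splitting statistics_during_period into a thin session_scope wrapper plus a _with_session helper lets callers that already hold a Session compose several reads and writes inside one transaction. The schema validator added further down does exactly that; a usage sketch (names as in this diff, concrete values assumed):

    with session_scope(session=session_maker()) as session:
        _import_statistics_with_session(session, metadata, (statistics,), table)
        stored = _statistics_during_period_with_session(
            hass, session, start_time, None, [statistic_id], "hour", None, types
        )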
@@ -2047,6 +2080,26 @@ def _filter_unique_constraint_integrity_error(
     return _filter_unique_constraint_integrity_error


+def _import_statistics_with_session(
+    session: Session,
+    metadata: StatisticMetaData,
+    statistics: Iterable[StatisticData],
+    table: type[Statistics | StatisticsShortTerm],
+) -> bool:
+    """Import statistics to the database."""
+    old_metadata_dict = get_metadata_with_session(
+        session, statistic_ids=[metadata["statistic_id"]]
+    )
+    metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
+    for stat in statistics:
+        if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]):
+            _update_statistics(session, table, stat_id, stat)
+        else:
+            _insert_statistics(session, table, metadata_id, stat)
+
+    return True
+
+
 @retryable_database_job("statistics")
 def import_statistics(
     instance: Recorder,
@@ -2060,19 +2113,7 @@ def import_statistics(
         session=instance.get_session(),
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        old_metadata_dict = get_metadata_with_session(
-            session, statistic_ids=[metadata["statistic_id"]]
-        )
-        metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
-        for stat in statistics:
-            if stat_id := _statistics_exists(
-                session, table, metadata_id, stat["start"]
-            ):
-                _update_statistics(session, table, stat_id, stat)
-            else:
-                _insert_statistics(session, table, metadata_id, stat)
-
-        return True
+        return _import_statistics_with_session(session, metadata, statistics, table)


 @retryable_database_job("adjust_statistics")
@@ -2189,3 +2230,232 @@ def async_change_statistics_unit(
         new_unit_of_measurement=new_unit_of_measurement,
         old_unit_of_measurement=old_unit_of_measurement,
     )
+
+
+def _validate_db_schema_utf8(
+    instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+
+    # Lack of full utf8 support is only an issue for MySQL / MariaDB
+    if instance.dialect_name != SupportedDialect.MYSQL:
+        return schema_errors
+
+    # This name can't be represented unless 4-byte UTF-8 unicode is supported
+    utf8_name = "𓆚𓃗"
+    statistic_id = f"{DOMAIN}.db_test"
+
+    metadata: StatisticMetaData = {
+        "has_mean": True,
+        "has_sum": True,
+        "name": utf8_name,
+        "source": DOMAIN,
+        "statistic_id": statistic_id,
+        "unit_of_measurement": None,
+    }
+
+    # Try inserting some metadata which needs utf8mb4 support
+    try:
+        with session_scope(session=session_maker()) as session:
+            old_metadata_dict = get_metadata_with_session(
+                session, statistic_ids=[statistic_id]
+            )
+            try:
+                _update_or_add_metadata(session, metadata, old_metadata_dict)
+                _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+            except OperationalError as err:
+                if err.orig and err.orig.args[0] == 1366:
+                    _LOGGER.debug(
+                        "Database table statistics_meta does not support 4-byte UTF-8"
+                    )
+                    schema_errors.add("statistics_meta.4-byte UTF-8")
+                    session.rollback()
+                else:
+                    raise
+    except Exception as exc:  # pylint: disable=broad-except
+        _LOGGER.exception("Error when validating DB schema: %s", exc)
+    return schema_errors
+
+
+def _validate_db_schema(
+    hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+
+    # Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL
+    if instance.dialect_name not in (
+        SupportedDialect.MYSQL,
+        SupportedDialect.POSTGRESQL,
+    ):
+        return schema_errors
+
+    # This number can't be accurately represented as a 32-bit float
+    precise_number = 1.000000000000001
+    # This time can't be accurately represented unless datetimes have µs precision
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+
+    start_time = datetime(2020, 10, 6, tzinfo=dt_util.UTC)
+    statistic_id = f"{DOMAIN}.db_test"
+
+    metadata: StatisticMetaData = {
+        "has_mean": True,
+        "has_sum": True,
+        "name": None,
+        "source": DOMAIN,
+        "statistic_id": statistic_id,
+        "unit_of_measurement": None,
+    }
+    statistics: StatisticData = {
+        "last_reset": precise_time,
+        "max": precise_number,
+        "mean": precise_number,
+        "min": precise_number,
+        "start": precise_time,
+        "state": precise_number,
+        "sum": precise_number,
+    }
+
+    def check_columns(
+        schema_errors: set[str],
+        stored: Mapping,
+        expected: Mapping,
+        columns: tuple[str, ...],
+        table_name: str,
+        supports: str,
+    ) -> None:
+        for column in columns:
+            if stored[column] != expected[column]:
+                schema_errors.add(f"{table_name}.{supports}")
+                _LOGGER.debug(
+                    "Column %s in database table %s does not support %s (%s != %s)",
+                    column,
+                    table_name,
+                    supports,
+                    stored[column],
+                    expected[column],
+                )
+
+    # Insert / adjust a test statistics row in each of the tables
+    tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
+        Statistics,
+        StatisticsShortTerm,
+    )
+    try:
+        with session_scope(session=session_maker()) as session:
+            for table in tables:
+                _import_statistics_with_session(session, metadata, (statistics,), table)
+                stored_statistics = _statistics_during_period_with_session(
+                    hass,
+                    session,
+                    start_time,
+                    None,
+                    [statistic_id],
+                    "hour" if table == Statistics else "5minute",
+                    None,
+                    {"last_reset", "max", "mean", "min", "state", "sum"},
+                )
+                if not (stored_statistic := stored_statistics.get(statistic_id)):
+                    _LOGGER.warning(
+                        "Schema validation failed for table: %s", table.__tablename__
+                    )
+                    continue
+                check_columns(
+                    schema_errors,
+                    stored_statistic[0],
+                    statistics,
+                    ("max", "mean", "min", "state", "sum"),
+                    table.__tablename__,
+                    "double precision",
+                )
+                assert statistics["last_reset"]
+                check_columns(
+                    schema_errors,
+                    stored_statistic[0],
+                    {
+                        "last_reset": statistics["last_reset"],
+                        "start": statistics["start"],
+                    },
+                    ("start", "last_reset"),
+                    table.__tablename__,
+                    "µs precision",
+                )
+            _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+    except Exception as exc:  # pylint: disable=broad-except
+        _LOGGER.exception("Error when validating DB schema: %s", exc)
+
+    return schema_errors
+
+
+def validate_db_schema(
+    hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+    schema_errors |= _validate_db_schema_utf8(instance, session_maker)
+    schema_errors |= _validate_db_schema(hass, instance, session_maker)
+    if schema_errors:
+        _LOGGER.debug(
+            "Detected statistics schema errors: %s", ", ".join(sorted(schema_errors))
+        )
+    return schema_errors
+
+
+def correct_db_schema(
+    engine: Engine, session_maker: Callable[[], Session], schema_errors: set[str]
+) -> None:
+    """Correct issues detected by validate_db_schema."""
+    from .migration import _modify_columns  # pylint: disable=import-outside-toplevel
+
+    if "statistics_meta.4-byte UTF-8" in schema_errors:
+        # Attempt to convert the table to utf8mb4
+        _LOGGER.warning(
+            "Updating character set and collation of table %s to utf8mb4. "
+            "Note: this can take several minutes on large databases and slow "
+            "computers. Please be patient!",
+            "statistics_meta",
+        )
+        with contextlib.suppress(SQLAlchemyError):
+            with session_scope(session=session_maker()) as session:
+                connection = session.connection()
+                connection.execute(
+                    # Using LOCK=EXCLUSIVE to prevent the database from corrupting
+                    # https://github.com/home-assistant/core/issues/56104
+                    text(
+                        "ALTER TABLE statistics_meta CONVERT TO "
+                        "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci, LOCK=EXCLUSIVE"
+                    )
+                )
+
+    tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
+        Statistics,
+        StatisticsShortTerm,
+    )
+    for table in tables:
+        if f"{table.__tablename__}.double precision" in schema_errors:
+            # Attempt to convert float columns to double precision
+            _modify_columns(
+                session_maker,
+                engine,
+                table.__tablename__,
+                [
+                    "mean DOUBLE PRECISION",
+                    "min DOUBLE PRECISION",
+                    "max DOUBLE PRECISION",
+                    "state DOUBLE PRECISION",
+                    "sum DOUBLE PRECISION",
+                ],
+            )
+        if f"{table.__tablename__}.µs precision" in schema_errors:
+            # Attempt to convert datetime columns to µs precision
+            _modify_columns(
+                session_maker,
+                engine,
+                table.__tablename__,
+                [
+                    "last_reset DATETIME(6)",
+                    "start DATETIME(6)",
+                ],
+            )
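
Background on the UTF-8 probe: MySQL raises error code 1366 ("Incorrect string value") when a 4-byte UTF-8 character such as 𓆚 is written to a utf8mb3 column, and SQLAlchemy wraps the DB-API error in OperationalError with the original exception in .orig. A minimal sketch of the detection predicate used above (the helper name is hypothetical):

    from sqlalchemy.exc import OperationalError

    def _is_missing_utf8mb4(err: OperationalError) -> bool:  # hypothetical helper
        """Return True if the error is MySQL 1366 (incorrect string value)."""
        return bool(err.orig) and err.orig.args[0] == 1366

Why the magic test values work: a 32-bit FLOAT column cannot represent 1.000000000000001 (the nearest single-precision value is exactly 1.0), and a MySQL DATETIME column without fractional seconds drops microsecond=1, so reading the test row back and comparing exposes both defects. A standalone demonstration using only the standard library:

    from datetime import datetime, timezone
    import struct

    precise_number = 1.000000000000001
    # Round-trip through single precision, as a FLOAT column would store it
    as_float32 = struct.unpack("f", struct.pack("f", precise_number))[0]
    assert as_float32 == 1.0  # precision lost -> "double precision" error

    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=timezone.utc)
    stored = precise_time.replace(microsecond=0)  # DATETIME without DATETIME(6)
    assert stored != precise_time  # precision lost -> "µs precision" error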

--- a/tests/components/recorder/test_statistics.py
+++ b/tests/components/recorder/test_statistics.py

@@ -1,13 +1,14 @@
 """The tests for sensor recorder platform."""
 # pylint: disable=protected-access,invalid-name
-from datetime import timedelta
+from datetime import datetime, timedelta
 import importlib
 import sys
-from unittest.mock import patch, sentinel
+from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel

 import pytest
 from pytest import approx
 from sqlalchemy import create_engine
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session

 from homeassistant.components import recorder
@@ -16,6 +17,8 @@ from homeassistant.components.recorder.const import SQLITE_URL_PREFIX
 from homeassistant.components.recorder.db_schema import StatisticsShortTerm
 from homeassistant.components.recorder.models import process_timestamp
 from homeassistant.components.recorder.statistics import (
+    _statistics_during_period_with_session,
+    _update_or_add_metadata,
     async_add_external_statistics,
     async_import_statistics,
     delete_statistics_duplicates,
@@ -1475,6 +1478,196 @@ def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog):
     assert "duplicated statistics_meta rows" not in caplog.text


+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+async def test_validate_db_schema(
+    async_setup_recorder_instance, hass, caplog, db_engine
+):
+    """Test validating DB schema with MySQL and PostgreSQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ):
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+    assert "Schema validation failed" not in caplog.text
+    assert "Detected statistics schema errors" not in caplog.text
+    assert "Database is about to correct DB schema errors" not in caplog.text
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+async def test_validate_db_schema_fix_utf8_issue(
+    async_setup_recorder_instance, hass, caplog
+):
+    """Test validating DB schema with MySQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    utf8_error = OperationalError("", "", orig=orig_error)
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"
+    ), patch(
+        "homeassistant.components.recorder.statistics._update_or_add_metadata",
+        side_effect=[utf8_error, DEFAULT, DEFAULT],
+        wraps=_update_or_add_metadata,
+    ):
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        "Database is about to correct DB schema errors: statistics_meta.4-byte UTF-8"
+        in caplog.text
+    )
+    assert (
+        "Updating character set and collation of table statistics_meta to utf8mb4"
+        in caplog.text
+    )
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+@pytest.mark.parametrize(
+    "table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
+)
+@pytest.mark.parametrize(
+    "column, value",
+    (("max", 1.0), ("mean", 1.0), ("min", 1.0), ("state", 1.0), ("sum", 1.0)),
+)
+async def test_validate_db_schema_fix_float_issue(
+    async_setup_recorder_instance,
+    hass,
+    caplog,
+    db_engine,
+    table,
+    replace_index,
+    column,
+    value,
+):
+    """Test validating DB schema with MySQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    precise_number = 1.000000000000001
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+    statistics = {
+        "recorder.db_test": [
+            {
+                "last_reset": precise_time,
+                "max": precise_number,
+                "mean": precise_number,
+                "min": precise_number,
+                "start": precise_time,
+                "state": precise_number,
+                "sum": precise_number,
+            }
+        ]
+    }
+    statistics["recorder.db_test"][0][column] = value
+    fake_statistics = [DEFAULT, DEFAULT]
+    fake_statistics[replace_index] = statistics
+
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ), patch(
+        "homeassistant.components.recorder.statistics._statistics_during_period_with_session",
+        side_effect=fake_statistics,
+        wraps=_statistics_during_period_with_session,
+    ), patch(
+        "homeassistant.components.recorder.migration._modify_columns"
+    ) as modify_columns_mock:
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        f"Database is about to correct DB schema errors: {table}.double precision"
+        in caplog.text
+    )
+    modification = [
+        "mean DOUBLE PRECISION",
+        "min DOUBLE PRECISION",
+        "max DOUBLE PRECISION",
+        "state DOUBLE PRECISION",
+        "sum DOUBLE PRECISION",
+    ]
+    modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+@pytest.mark.parametrize(
+    "table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
+)
+@pytest.mark.parametrize(
+    "column, value",
+    (
+        ("last_reset", "2020-10-06T00:00:00+00:00"),
+        ("start", "2020-10-06T00:00:00+00:00"),
+    ),
+)
+async def test_validate_db_schema_fix_statistics_datetime_issue(
+    async_setup_recorder_instance,
+    hass,
+    caplog,
+    db_engine,
+    table,
+    replace_index,
+    column,
+    value,
+):
+    """Test validating DB schema with MySQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    precise_number = 1.000000000000001
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+    statistics = {
+        "recorder.db_test": [
+            {
+                "last_reset": precise_time,
+                "max": precise_number,
+                "mean": precise_number,
+                "min": precise_number,
+                "start": precise_time,
+                "state": precise_number,
+                "sum": precise_number,
+            }
+        ]
+    }
+    statistics["recorder.db_test"][0][column] = value
+    fake_statistics = [DEFAULT, DEFAULT]
+    fake_statistics[replace_index] = statistics
+
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ), patch(
+        "homeassistant.components.recorder.statistics._statistics_during_period_with_session",
+        side_effect=fake_statistics,
+        wraps=_statistics_during_period_with_session,
+    ), patch(
+        "homeassistant.components.recorder.migration._modify_columns"
+    ) as modify_columns_mock:
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        f"Database is about to correct DB schema errors: {table}.µs precision"
+        in caplog.text
+    )
+    modification = ["last_reset DATETIME(6)", "start DATETIME(6)"]
+    modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)
+
+
 def record_states(hass):
     """Record some test states.

--- a/tests/conftest.py
+++ b/tests/conftest.py

@@ -5,6 +5,7 @@ import asyncio
 from collections.abc import AsyncGenerator, Callable, Generator
 from contextlib import asynccontextmanager
 import functools
+import itertools
 from json import JSONDecoder, loads
 import logging
 import sqlite3
@@ -860,6 +861,16 @@ def enable_statistics():
     return False


+@pytest.fixture
+def enable_statistics_table_validation():
+    """Fixture to control enabling of recorder's statistics table validation.
+
+    To enable statistics table validation, tests can be marked with:
+    @pytest.mark.parametrize("enable_statistics_table_validation", [True])
+    """
+    return False
+
+
 @pytest.fixture
 def enable_nightly_purge():
     """Fixture to control enabling of recorder's nightly purge job.
@@ -902,6 +913,7 @@ def hass_recorder(
     recorder_db_url,
     enable_nightly_purge,
     enable_statistics,
+    enable_statistics_table_validation,
     hass_storage,
 ):
     """Home Assistant fixture with in-memory recorder."""
@@ -910,6 +922,11 @@
     hass = get_test_home_assistant()
     nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
     stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
+    stats_validate = (
+        recorder.statistics.validate_db_schema
+        if enable_statistics_table_validation
+        else itertools.repeat(set())
+    )
     with patch(
         "homeassistant.components.recorder.Recorder.async_nightly_tasks",
         side_effect=nightly,
@@ -918,6 +935,10 @@
         "homeassistant.components.recorder.Recorder.async_periodic_statistics",
         side_effect=stats,
         autospec=True,
+    ), patch(
+        "homeassistant.components.recorder.migration.statistics_validate_db_schema",
+        side_effect=stats_validate,
+        autospec=True,
     ):

         def setup_recorder(config=None):
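
When side_effect is an iterator, unittest.mock draws the next value on every call, so itertools.repeat(set()) disables validation by making the patched statistics_validate_db_schema return an empty error set forever (note it yields the same set object each time, which is harmless here since callers only read it). A tiny standalone sketch:

    import itertools
    from unittest.mock import MagicMock

    validate = MagicMock(side_effect=itertools.repeat(set()))
    assert validate("any", "args") == set()
    assert validate() == set()  # an infinite iterator is never exhausted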
@@ -962,12 +983,18 @@ async def async_setup_recorder_instance(
     hass_fixture_setup,
     enable_nightly_purge,
     enable_statistics,
+    enable_statistics_table_validation,
 ) -> AsyncGenerator[SetupRecorderInstanceT, None]:
     """Yield callable to setup recorder instance."""
     assert not hass_fixture_setup

     nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
     stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
+    stats_validate = (
+        recorder.statistics.validate_db_schema
+        if enable_statistics_table_validation
+        else itertools.repeat(set())
+    )
     with patch(
         "homeassistant.components.recorder.Recorder.async_nightly_tasks",
         side_effect=nightly,
@@ -976,6 +1003,10 @@
         "homeassistant.components.recorder.Recorder.async_periodic_statistics",
         side_effect=stats,
         autospec=True,
+    ), patch(
+        "homeassistant.components.recorder.migration.statistics_validate_db_schema",
+        side_effect=stats_validate,
+        autospec=True,
     ):

         async def async_setup_recorder(