Refactor recorder migration (#80175)

* Refactor recorder migration

* Improve test coverage
This commit is contained in:
Erik Montnemery 2022-10-13 08:11:54 +02:00 committed by GitHub
parent ca4c4774ca
commit 466c4656ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 39 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
from datetime import timedelta
import logging
from typing import TYPE_CHECKING, cast
@ -61,33 +62,65 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex
def get_schema_version(session_maker: Callable[[], Session]) -> int:
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version."""
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
)
current_version = getattr(res, "schema_version", None)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
try:
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
)
current_version = getattr(res, "schema_version", None)
return cast(int, current_version)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
)
return cast(int, current_version)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when determining DB schema version: %s", err)
return None
def schema_is_current(current_version: int) -> bool:
@dataclass
class SchemaValidationStatus:
"""Store schema validation status."""
current_version: int
def _schema_is_current(current_version: int) -> bool:
"""Check if the schema is current."""
return current_version == SCHEMA_VERSION
def live_migration(current_version: int) -> bool:
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
"""Check if the schema is valid."""
return _schema_is_current(schema_status.current_version)
def validate_db_schema(
hass: HomeAssistant, session_maker: Callable[[], Session]
) -> SchemaValidationStatus | None:
"""Check if the schema is valid.
This checks that the schema is the current version as well as for some common schema
errors caused by manual migration between database engines, for example importing an
SQLite database to MariaDB.
"""
current_version = get_schema_version(session_maker)
if current_version is None:
return None
return SchemaValidationStatus(current_version)
def live_migration(schema_status: SchemaValidationStatus) -> bool:
"""Check if live migration is possible."""
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
def migrate_schema(
@ -95,13 +128,14 @@ def migrate_schema(
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
current_version: int,
schema_status: SchemaValidationStatus,
) -> None:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if live_migration(version) and not db_ready:
if live_migration(SchemaValidationStatus(version)) and not db_ready:
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)