Refactor recorder migration (#80175)
* Refactor recorder migration * Improve test coverage
This commit is contained in:
parent
ca4c4774ca
commit
466c4656ca
4 changed files with 101 additions and 39 deletions
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable, Iterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
@ -61,33 +62,65 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
|
|||
raise ex
|
||||
|
||||
|
||||
def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
||||
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
|
||||
"""Get the schema version."""
|
||||
with session_scope(session=session_maker()) as session:
|
||||
res = (
|
||||
session.query(SchemaChanges)
|
||||
.order_by(SchemaChanges.change_id.desc())
|
||||
.first()
|
||||
)
|
||||
current_version = getattr(res, "schema_version", None)
|
||||
|
||||
if current_version is None:
|
||||
current_version = _inspect_schema_version(session)
|
||||
_LOGGER.debug(
|
||||
"No schema version found. Inspected version: %s", current_version
|
||||
try:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
res = (
|
||||
session.query(SchemaChanges)
|
||||
.order_by(SchemaChanges.change_id.desc())
|
||||
.first()
|
||||
)
|
||||
current_version = getattr(res, "schema_version", None)
|
||||
|
||||
return cast(int, current_version)
|
||||
if current_version is None:
|
||||
current_version = _inspect_schema_version(session)
|
||||
_LOGGER.debug(
|
||||
"No schema version found. Inspected version: %s", current_version
|
||||
)
|
||||
|
||||
return cast(int, current_version)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error when determining DB schema version: %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def schema_is_current(current_version: int) -> bool:
|
||||
@dataclass
|
||||
class SchemaValidationStatus:
|
||||
"""Store schema validation status."""
|
||||
|
||||
current_version: int
|
||||
|
||||
|
||||
def _schema_is_current(current_version: int) -> bool:
|
||||
"""Check if the schema is current."""
|
||||
return current_version == SCHEMA_VERSION
|
||||
|
||||
|
||||
def live_migration(current_version: int) -> bool:
|
||||
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
|
||||
"""Check if the schema is valid."""
|
||||
return _schema_is_current(schema_status.current_version)
|
||||
|
||||
|
||||
def validate_db_schema(
|
||||
hass: HomeAssistant, session_maker: Callable[[], Session]
|
||||
) -> SchemaValidationStatus | None:
|
||||
"""Check if the schema is valid.
|
||||
|
||||
This checks that the schema is the current version as well as for some common schema
|
||||
errors caused by manual migration between database engines, for example importing an
|
||||
SQLite database to MariaDB.
|
||||
"""
|
||||
current_version = get_schema_version(session_maker)
|
||||
if current_version is None:
|
||||
return None
|
||||
|
||||
return SchemaValidationStatus(current_version)
|
||||
|
||||
|
||||
def live_migration(schema_status: SchemaValidationStatus) -> bool:
|
||||
"""Check if live migration is possible."""
|
||||
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
||||
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
||||
|
||||
|
||||
def migrate_schema(
|
||||
|
@ -95,13 +128,14 @@ def migrate_schema(
|
|||
hass: HomeAssistant,
|
||||
engine: Engine,
|
||||
session_maker: Callable[[], Session],
|
||||
current_version: int,
|
||||
schema_status: SchemaValidationStatus,
|
||||
) -> None:
|
||||
"""Check if the schema needs to be upgraded."""
|
||||
current_version = schema_status.current_version
|
||||
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
|
||||
db_ready = False
|
||||
for version in range(current_version, SCHEMA_VERSION):
|
||||
if live_migration(version) and not db_ready:
|
||||
if live_migration(SchemaValidationStatus(version)) and not db_ready:
|
||||
db_ready = True
|
||||
instance.migration_is_live = True
|
||||
hass.add_job(instance.async_set_db_ready)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue