Refactor recorder schema migration (#122372)

* Refactor recorder schema migration

* Simplify

* Remove unused imports

* Refactor _migrate_schema according to review comments

* Add comment
This commit is contained in:
Erik Montnemery 2024-07-22 16:53:54 +02:00 committed by GitHub
parent c73e7ae178
commit e8b88557ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 168 additions and 76 deletions

View file

@ -188,12 +188,13 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
return None
@dataclass
@dataclass(frozen=True)
class SchemaValidationStatus:
"""Store schema validation status."""
current_version: int
schema_errors: set[str]
start_version: int
valid: bool
@ -224,7 +225,9 @@ def validate_db_schema(
valid = is_current and not schema_errors
return SchemaValidationStatus(current_version, schema_errors, valid)
return SchemaValidationStatus(
current_version, schema_errors, current_version, valid
)
def _find_schema_errors(
@ -260,35 +263,30 @@ def pre_migrate_schema(engine: Engine) -> None:
)
def migrate_schema(
def _migrate_schema(
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
) -> None:
end_version: int,
) -> SchemaValidationStatus:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
if current_version != SCHEMA_VERSION:
start_version = schema_status.start_version
if current_version < end_version:
_LOGGER.warning(
"Database is about to upgrade from schema version: %s to: %s",
current_version,
SCHEMA_VERSION,
end_version,
)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if (
live_migration(dataclass_replace(schema_status, current_version=version))
and not db_ready
):
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)
schema_status = dataclass_replace(schema_status, current_version=end_version)
for version in range(current_version, end_version):
new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(
instance, hass, engine, session_maker, new_version, current_version
)
_apply_update(instance, hass, engine, session_maker, new_version, start_version)
with session_scope(session=session_maker()) as session:
session.add(SchemaChanges(schema_version=new_version))
@ -296,6 +294,37 @@ def migrate_schema(
# so its clear that the upgrade is done
_LOGGER.warning("Upgrade to version %s done", new_version)
return schema_status
def migrate_schema_non_live(
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
) -> SchemaValidationStatus:
"""Check if the schema needs to be upgraded."""
end_version = LIVE_MIGRATION_MIN_SCHEMA_VERSION - 1
return _migrate_schema(
instance, hass, engine, session_maker, schema_status, end_version
)
def migrate_schema_live(
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
) -> SchemaValidationStatus:
"""Check if the schema needs to be upgraded."""
end_version = SCHEMA_VERSION
schema_status = _migrate_schema(
instance, hass, engine, session_maker, schema_status, end_version
)
# Repairs are currently done during the live migration
if schema_errors := schema_status.schema_errors:
_LOGGER.warning(
"Database is about to correct DB schema errors: %s",
@ -305,12 +334,15 @@ def migrate_schema(
states_correct_db_schema(instance, schema_errors)
events_correct_db_schema(instance, schema_errors)
if current_version != SCHEMA_VERSION:
instance.queue_task(PostSchemaMigrationTask(current_version, SCHEMA_VERSION))
start_version = schema_status.start_version
if start_version != SCHEMA_VERSION:
instance.queue_task(PostSchemaMigrationTask(start_version, SCHEMA_VERSION))
# Make sure the post schema migration task is committed in case
# the next task does not have commit_before = True
instance.queue_task(CommitTask())
return schema_status
def _create_index(
session_maker: Callable[[], Session], table_name: str, index_name: str