Refactor recorder data migration (#121009)

* Refactor recorder data migration

* Fix stale docstrings

* Don't store a session object in BaseRunTimeMigration instances

* Simplify logic in EntityIDMigration.migration_done

* Fix tests
This commit is contained in:
Erik Montnemery 2024-07-16 21:50:19 +02:00 committed by GitHub
parent baa97ca981
commit 9970b7eece
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 149 additions and 183 deletions

View file

@ -75,7 +75,6 @@ from .const import (
SupportedDialect,
)
from .db_schema import (
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
LEGACY_STATES_EVENT_ID_INDEX,
SCHEMA_VERSION,
TABLE_STATES,
@ -91,7 +90,6 @@ from .db_schema import (
)
from .executor import DBInterruptibleThreadPoolExecutor
from .migration import (
BaseRunTimeMigration,
EntityIDMigration,
EventsContextIDMigration,
EventTypeIDMigration,
@ -115,7 +113,6 @@ from .tasks import (
CommitTask,
CompileMissingStatisticsTask,
DatabaseLockTask,
EntityIDPostMigrationTask,
EventIdMigrationTask,
ImportStatisticsTask,
KeepAliveTask,
@ -804,37 +801,14 @@ class Recorder(threading.Thread):
for row in execute_stmt_lambda_element(session, get_migration_changes())
}
migrator: BaseRunTimeMigration
for migrator_cls in (StatesContextIDMigration, EventsContextIDMigration):
migrator = migrator_cls(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
migrator = EventTypeIDMigration(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating event_types manager as all data is migrated")
self.event_type_manager.active = True
migrator = EntityIDMigration(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating states_meta manager as all data is migrated")
self.states_meta_manager.active = True
with contextlib.suppress(SQLAlchemyError):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
# complete because they restarted in the middle of it. We need
# to pick back up where we left off.
if get_index_by_name(
session,
TABLE_STATES,
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
):
self.queue_task(EntityIDPostMigrationTask())
for migrator_cls in (
StatesContextIDMigration,
EventsContextIDMigration,
EventTypeIDMigration,
EntityIDMigration,
):
migrator = migrator_cls(schema_version, migration_changes)
migrator.do_migrate(self, session)
if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
with contextlib.suppress(SQLAlchemyError):
@ -1319,22 +1293,6 @@ class Recorder(threading.Thread):
)
)
def _migrate_states_context_ids(self) -> bool:
"""Migrate states context ids if needed."""
return migration.migrate_states_context_ids(self)
def _migrate_events_context_ids(self) -> bool:
"""Migrate events context ids if needed."""
return migration.migrate_events_context_ids(self)
def _migrate_event_type_ids(self) -> bool:
"""Migrate event type ids if needed."""
return migration.migrate_event_type_ids(self)
def _migrate_entity_ids(self) -> bool:
"""Migrate entity_ids if needed."""
return migration.migrate_entity_ids(self)
def _post_migrate_entity_ids(self) -> bool:
"""Post migrate entity_ids if needed."""
return migration.post_migrate_entity_ids(self)

View file

@ -102,12 +102,9 @@ from .queries import (
from .statistics import get_start_time
from .tasks import (
CommitTask,
EntityIDMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
EntityIDPostMigrationTask,
PostSchemaMigrationTask,
RecorderTask,
StatesContextIDMigrationTask,
StatisticsTimestampMigrationCleanupTask,
)
from .util import (
@ -2001,9 +1998,6 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
if is_done := not events:
_mark_migration_done(session, EventTypeIDMigration)
if is_done:
instance.event_type_manager.active = True
_LOGGER.debug("Migrating event_types done=%s", is_done)
return is_done
@ -2182,27 +2176,62 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool:
return False
@dataclass(slots=True)
class MigrationTask(RecorderTask):
    """Base class for migration tasks.

    Each run performs one call to the migrator's migrate_data. The task
    re-queues itself until migrate_data reports completion, then calls the
    migrator's migration_done hook.
    """

    migrator: BaseRunTimeMigration
    commit_before = False

    def run(self, instance: Recorder) -> None:
        """Run migration task."""
        if self.migrator.migrate_data(instance):
            self.migrator.migration_done(instance)
            return
        # Schedule a new migration task if this one didn't finish.
        # Re-queue the same task class rather than the base class so that
        # subclasses (for example ones overriding commit_before) keep their
        # behavior for every batch, not only the first one.
        instance.queue_task(type(self)(self.migrator))
@dataclass(slots=True)
class CommitBeforeMigrationTask(MigrationTask):
    """Base class for migration tasks which commit first."""

    # Overrides the MigrationTask default of False; used by migrations that
    # must not run while pending rows are still waiting to be written.
    commit_before = True
class BaseRunTimeMigration(ABC):
"""Base class for run time migrations."""
required_schema_version = 0
migration_version = 1
migration_id: str
task: Callable[[], RecorderTask]
task = MigrationTask
def __init__(
self, session: Session, schema_version: int, migration_changes: dict[str, int]
) -> None:
def __init__(self, schema_version: int, migration_changes: dict[str, int]) -> None:
"""Initialize a new BaseRunTimeMigration."""
self.schema_version = schema_version
self.session = session
self.migration_changes = migration_changes
def do_migrate(self, instance: Recorder, session: Session) -> None:
    """Queue the migration task, or finish right away if nothing is pending."""
    if not self.needs_migrate(session):
        # No data left to migrate: run the completion hook immediately.
        self.migration_done(instance)
        return
    instance.queue_task(self.task(self))
@staticmethod
@abstractmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True or if migration is not needed."""
@abstractmethod
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
def needs_migrate(self) -> bool:
def needs_migrate(self, session: Session) -> bool:
"""Return if the migration needs to run.
If the migration needs to run, it will return True.
@ -2220,8 +2249,8 @@ class BaseRunTimeMigration(ABC):
# We do not know if the migration is done from the
# migration changes table so we must check the data
# This is the slow path
if not execute_stmt_lambda_element(self.session, self.needs_migrate_query()):
_mark_migration_done(self.session, self.__class__)
if not execute_stmt_lambda_element(session, self.needs_migrate_query()):
_mark_migration_done(session, self.__class__)
return False
return True
@ -2231,7 +2260,11 @@ class StatesContextIDMigration(BaseRunTimeMigration):
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "state_context_id_as_binary"
task = StatesContextIDMigrationTask
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_states_context_ids(instance)
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
@ -2243,7 +2276,11 @@ class EventsContextIDMigration(BaseRunTimeMigration):
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "event_context_id_as_binary"
task = EventsContextIDMigrationTask
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_events_context_ids(instance)
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
@ -2255,7 +2292,20 @@ class EventTypeIDMigration(BaseRunTimeMigration):
required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION
migration_id = "event_type_id_migration"
task = EventTypeIDMigrationTask
task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_event_type_ids(instance)
def migration_done(self, instance: Recorder) -> None:
    """Activate the event_types manager.

    Will be called after migrate_data returns True or if migration is
    not needed.
    """
    _LOGGER.debug("Activating event_types manager as all data is migrated")
    instance.event_type_manager.active = True
def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated."""
@ -2267,7 +2317,39 @@ class EntityIDMigration(BaseRunTimeMigration):
required_schema_version = STATES_META_SCHEMA_VERSION
migration_id = "entity_id_migration"
task = EntityIDMigrationTask
task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_entity_ids(instance)
def migration_done(self, instance: Recorder) -> None:
    """Activate the states_meta manager and resume post migration if needed.

    Will be called after migrate_data returns True or if migration is
    not needed.
    """
    # All entity_id data is migrated at this point, so we can start
    # using the StatesMeta table and set active to True.
    _LOGGER.debug("Activating states_meta manager as all data is migrated")
    instance.states_meta_manager.active = True
    # Best effort: an SQLAlchemyError raised while checking for the index
    # is suppressed, in which case no post migration task is queued here.
    with (
        contextlib.suppress(SQLAlchemyError),
        session_scope(session=instance.get_session()) as session,
    ):
        # If ix_states_entity_id_last_updated_ts still exists on the
        # states table, the cleanup done by EntityIDPostMigrationTask
        # (removing the old entity_id data from the states table) has
        # not completed yet, for example because of a restart in the
        # middle of it. Queue the task to pick back up where we left off.
        if get_index_by_name(
            session,
            TABLE_STATES,
            LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
        ):
            instance.queue_task(EntityIDPostMigrationTask())
def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated."""

View file

@ -358,75 +358,6 @@ class AdjustLRUSizeTask(RecorderTask):
instance._adjust_lru_size() # noqa: SLF001
@dataclass(slots=True)
class StatesContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate states context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_states_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(StatesContextIDMigrationTask())
@dataclass(slots=True)
class EventsContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate events context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_events_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventsContextIDMigrationTask())
@dataclass(slots=True)
class EventTypeIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate event type ids."""
commit_before = True
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live
def run(self, instance: Recorder) -> None:
"""Run event type id migration task."""
if not instance._migrate_event_type_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventTypeIDMigrationTask())
@dataclass(slots=True)
class EntityIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate entity_ids to StatesMeta."""
commit_before = True
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live
def run(self, instance: Recorder) -> None:
"""Run entity_id migration task."""
if not instance._migrate_entity_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EntityIDMigrationTask())
else:
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
# at this point we can also start using the StatesMeta table
# so we set active to True
instance.states_meta_manager.active = True
instance.queue_task(EntityIDPostMigrationTask())
@dataclass(slots=True)
class EntityIDPostMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to cleanup after entity_ids migration."""

View file

@ -416,6 +416,14 @@ def get_schema_module_path(schema_version_postfix: str) -> str:
return f"tests.components.recorder.db_schema_{schema_version_postfix}"
@dataclass(slots=True)
class MockMigrationTask(migration.MigrationTask):
    """Mock migration task which does nothing."""

    def run(self, instance: Recorder) -> None:
        """Run migration task."""
        # Intentionally empty: replaces the inherited run so that no data
        # is migrated and no follow-up task is queued.
@contextmanager
def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
"""Fixture to initialize the db with the old schema."""
@ -434,7 +442,7 @@ def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask),
patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(
CREATE_ENGINE_TARGET,
new=partial(

View file

@ -32,13 +32,7 @@ from homeassistant.components.recorder.queries import (
get_migration_changes,
select_event_type_ids,
)
from homeassistant.components.recorder.tasks import (
EntityIDMigrationTask,
EntityIDPostMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
StatesContextIDMigrationTask,
)
from homeassistant.components.recorder.tasks import EntityIDPostMigrationTask
from homeassistant.components.recorder.util import (
execute_stmt_lambda_element,
session_scope,
@ -48,6 +42,7 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.ulid import bytes_to_ulid, ulid_at_time, ulid_to_bytes
from .common import (
MockMigrationTask,
async_attach_db_engine,
async_recorder_block_till_done,
async_wait_recording_done,
@ -116,7 +111,7 @@ def db_schema_32():
patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask),
patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
):
yield
@ -229,7 +224,8 @@ async def test_migrate_events_context_ids(
with freeze_time(now):
# This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventsContextIDMigrationTask())
migrator = migration.EventsContextIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass)
def _object_as_dict(obj):
@ -419,7 +415,8 @@ async def test_migrate_states_context_ids(
await recorder_mock.async_add_executor_job(_insert_states)
await async_wait_recording_done(hass)
recorder_mock.queue_task(StatesContextIDMigrationTask())
migrator = migration.StatesContextIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass)
def _object_as_dict(obj):
@ -567,7 +564,8 @@ async def test_migrate_event_type_ids(
await async_wait_recording_done(hass)
# This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventTypeIDMigrationTask())
migrator = migration.EventTypeIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass)
def _fetch_migrated_events():
@ -655,7 +653,8 @@ async def test_migrate_entity_ids(hass: HomeAssistant, recorder_mock: Recorder)
await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EntityIDMigrationTask())
migrator = migration.EntityIDMigration(None, None)
recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator))
await _async_wait_migration_done(hass)
def _fetch_migrated_states():
@ -788,7 +787,8 @@ async def test_migrate_null_entity_ids(
await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EntityIDMigrationTask())
migrator = migration.EntityIDMigration(None, None)
recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator))
await _async_wait_migration_done(hass)
def _fetch_migrated_states():
@ -870,7 +870,8 @@ async def test_migrate_null_event_type_ids(
await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventTypeIDMigrationTask())
migrator = migration.EventTypeIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass)
def _fetch_migrated_events():

View file

@ -10,8 +10,8 @@ from sqlalchemy.orm import Session
from homeassistant.components import recorder
from homeassistant.components.recorder import core, migration, statistics
from homeassistant.components.recorder.migration import MigrationTask
from homeassistant.components.recorder.queries import get_migration_changes
from homeassistant.components.recorder.tasks import StatesContextIDMigrationTask
from homeassistant.components.recorder.util import (
execute_stmt_lambda_element,
session_scope,
@ -19,7 +19,11 @@ from homeassistant.components.recorder.util import (
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
from .common import async_recorder_block_till_done, async_wait_recording_done
from .common import (
MockMigrationTask,
async_recorder_block_till_done,
async_wait_recording_done,
)
from tests.common import async_test_home_assistant
from tests.typing import RecorderInstanceGenerator
@ -99,7 +103,7 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask),
patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
):
async with (
@ -169,4 +173,4 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
await hass.async_stop()
for task in tasks:
assert not isinstance(task, StatesContextIDMigrationTask)
assert not isinstance(task, MigrationTask)

View file

@ -106,10 +106,6 @@ async def test_migrate_times(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple(
"homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT,
),
@ -262,10 +258,6 @@ async def test_migrate_can_resume_entity_id_post_migration(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple(
"homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT,
),
@ -388,10 +380,6 @@ async def test_migrate_can_resume_ix_states_event_id_removed(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple(
"homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT,
),

View file

@ -1451,22 +1451,16 @@ async def async_test_recorder(
else None
)
migrate_states_context_ids = (
recorder.Recorder._migrate_states_context_ids
if enable_migrate_context_ids
else None
migration.migrate_states_context_ids if enable_migrate_context_ids else None
)
migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids
if enable_migrate_context_ids
else None
migration.migrate_events_context_ids if enable_migrate_context_ids else None
)
migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids
if enable_migrate_event_type_ids
else None
migration.migrate_event_type_ids if enable_migrate_event_type_ids else None
)
migrate_entity_ids = (
recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None
migration.migrate_entity_ids if enable_migrate_entity_ids else None
)
legacy_event_id_foreign_key_exists = (
recorder.Recorder._legacy_event_id_foreign_key_exists
@ -1490,22 +1484,22 @@ async def async_test_recorder(
autospec=True,
),
patch(
"homeassistant.components.recorder.Recorder._migrate_events_context_ids",
"homeassistant.components.recorder.migration.migrate_events_context_ids",
side_effect=migrate_events_context_ids,
autospec=True,
),
patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
"homeassistant.components.recorder.migration.migrate_states_context_ids",
side_effect=migrate_states_context_ids,
autospec=True,
),
patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids",
"homeassistant.components.recorder.migration.migrate_event_type_ids",
side_effect=migrate_event_type_ids,
autospec=True,
),
patch(
"homeassistant.components.recorder.Recorder._migrate_entity_ids",
"homeassistant.components.recorder.migration.migrate_entity_ids",
side_effect=migrate_entity_ids,
autospec=True,
),