From 9970b7eece9827de35c77c3239739af25ef6821d Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 16 Jul 2024 21:50:19 +0200 Subject: [PATCH] Refactor recorder data migration (#121009) * Refactor recorder data migration * Fix stale docstrings * Don't store a session object in BaseRunTimeMigration instances * Simplify logic in EntityIDMigration.migration_done * Fix tests --- homeassistant/components/recorder/core.py | 58 ++------- .../components/recorder/migration.py | 120 +++++++++++++++--- homeassistant/components/recorder/tasks.py | 69 ---------- tests/components/recorder/common.py | 10 +- .../recorder/test_migration_from_schema_32.py | 29 +++-- ..._migration_run_time_migrations_remember.py | 12 +- .../components/recorder/test_v32_migration.py | 12 -- tests/conftest.py | 22 ++-- 8 files changed, 149 insertions(+), 183 deletions(-) diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 09c85105121..2b8f45703b5 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -75,7 +75,6 @@ from .const import ( SupportedDialect, ) from .db_schema import ( - LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX, LEGACY_STATES_EVENT_ID_INDEX, SCHEMA_VERSION, TABLE_STATES, @@ -91,7 +90,6 @@ from .db_schema import ( ) from .executor import DBInterruptibleThreadPoolExecutor from .migration import ( - BaseRunTimeMigration, EntityIDMigration, EventsContextIDMigration, EventTypeIDMigration, @@ -115,7 +113,6 @@ from .tasks import ( CommitTask, CompileMissingStatisticsTask, DatabaseLockTask, - EntityIDPostMigrationTask, EventIdMigrationTask, ImportStatisticsTask, KeepAliveTask, @@ -804,37 +801,14 @@ class Recorder(threading.Thread): for row in execute_stmt_lambda_element(session, get_migration_changes()) } - migrator: BaseRunTimeMigration - for migrator_cls in (StatesContextIDMigration, EventsContextIDMigration): - migrator = migrator_cls(session, schema_version, migration_changes) - if 
migrator.needs_migrate(): - self.queue_task(migrator.task()) - - migrator = EventTypeIDMigration(session, schema_version, migration_changes) - if migrator.needs_migrate(): - self.queue_task(migrator.task()) - else: - _LOGGER.debug("Activating event_types manager as all data is migrated") - self.event_type_manager.active = True - - migrator = EntityIDMigration(session, schema_version, migration_changes) - if migrator.needs_migrate(): - self.queue_task(migrator.task()) - else: - _LOGGER.debug("Activating states_meta manager as all data is migrated") - self.states_meta_manager.active = True - with contextlib.suppress(SQLAlchemyError): - # If ix_states_entity_id_last_updated_ts still exists - # on the states table it means the entity id migration - # finished by the EntityIDPostMigrationTask did not - # complete because they restarted in the middle of it. We need - # to pick back up where we left off. - if get_index_by_name( - session, - TABLE_STATES, - LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX, - ): - self.queue_task(EntityIDPostMigrationTask()) + for migrator_cls in ( + StatesContextIDMigration, + EventsContextIDMigration, + EventTypeIDMigration, + EntityIDMigration, + ): + migrator = migrator_cls(schema_version, migration_changes) + migrator.do_migrate(self, session) if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION: with contextlib.suppress(SQLAlchemyError): @@ -1319,22 +1293,6 @@ class Recorder(threading.Thread): ) ) - def _migrate_states_context_ids(self) -> bool: - """Migrate states context ids if needed.""" - return migration.migrate_states_context_ids(self) - - def _migrate_events_context_ids(self) -> bool: - """Migrate events context ids if needed.""" - return migration.migrate_events_context_ids(self) - - def _migrate_event_type_ids(self) -> bool: - """Migrate event type ids if needed.""" - return migration.migrate_event_type_ids(self) - - def _migrate_entity_ids(self) -> bool: - """Migrate entity_ids if needed.""" - return 
migration.migrate_entity_ids(self) - def _post_migrate_entity_ids(self) -> bool: """Post migrate entity_ids if needed.""" return migration.post_migrate_entity_ids(self) diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index e93a3677e74..574129ca019 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -102,12 +102,9 @@ from .queries import ( from .statistics import get_start_time from .tasks import ( CommitTask, - EntityIDMigrationTask, - EventsContextIDMigrationTask, - EventTypeIDMigrationTask, + EntityIDPostMigrationTask, PostSchemaMigrationTask, RecorderTask, - StatesContextIDMigrationTask, StatisticsTimestampMigrationCleanupTask, ) from .util import ( @@ -2001,9 +1998,6 @@ def migrate_event_type_ids(instance: Recorder) -> bool: if is_done := not events: _mark_migration_done(session, EventTypeIDMigration) - if is_done: - instance.event_type_manager.active = True - _LOGGER.debug("Migrating event_types done=%s", is_done) return is_done @@ -2182,27 +2176,62 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool: return False +@dataclass(slots=True) +class MigrationTask(RecorderTask): + """Base class for migration tasks.""" + + migrator: BaseRunTimeMigration + commit_before = False + + def run(self, instance: Recorder) -> None: + """Run migration task.""" + if not self.migrator.migrate_data(instance): + # Not finished: re-queue the same task class to keep commit_before + instance.queue_task(type(self)(self.migrator)) + else: + self.migrator.migration_done(instance) + + +@dataclass(slots=True) +class CommitBeforeMigrationTask(MigrationTask): + """Base class for migration tasks which commit first.""" + + commit_before = True + + class BaseRunTimeMigration(ABC): """Base class for run time migrations.""" required_schema_version = 0 migration_version = 1 migration_id: str - task: Callable[[], RecorderTask] + task = MigrationTask - def 
__init__( - self, session: Session, schema_version: int, migration_changes: dict[str, int] - ) -> None: + def __init__(self, schema_version: int, migration_changes: dict[str, int]) -> None: """Initialize a new BaseRunTimeMigration.""" self.schema_version = schema_version - self.session = session self.migration_changes = migration_changes + def do_migrate(self, instance: Recorder, session: Session) -> None: + """Start migration if needed.""" + if self.needs_migrate(session): + instance.queue_task(self.task(self)) + else: + self.migration_done(instance) + + @staticmethod + @abstractmethod + def migrate_data(instance: Recorder) -> bool: + """Migrate some data, returns True if migration is completed.""" + + def migration_done(self, instance: Recorder) -> None: + """Will be called after migrate returns True or if migration is not needed.""" + @abstractmethod def needs_migrate_query(self) -> StatementLambdaElement: """Return the query to check if the migration needs to run.""" - def needs_migrate(self) -> bool: + def needs_migrate(self, session: Session) -> bool: """Return if the migration needs to run. If the migration needs to run, it will return True. 
@@ -2220,8 +2249,8 @@ class BaseRunTimeMigration(ABC): # We do not know if the migration is done from the # migration changes table so we must check the data # This is the slow path - if not execute_stmt_lambda_element(self.session, self.needs_migrate_query()): - _mark_migration_done(self.session, self.__class__) + if not execute_stmt_lambda_element(session, self.needs_migrate_query()): + _mark_migration_done(session, self.__class__) return False return True @@ -2231,7 +2260,11 @@ class StatesContextIDMigration(BaseRunTimeMigration): required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION migration_id = "state_context_id_as_binary" - task = StatesContextIDMigrationTask + + @staticmethod + def migrate_data(instance: Recorder) -> bool: + """Migrate some data, returns True if migration is completed.""" + return migrate_states_context_ids(instance) def needs_migrate_query(self) -> StatementLambdaElement: """Return the query to check if the migration needs to run.""" @@ -2243,7 +2276,11 @@ class EventsContextIDMigration(BaseRunTimeMigration): required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION migration_id = "event_context_id_as_binary" - task = EventsContextIDMigrationTask + + @staticmethod + def migrate_data(instance: Recorder) -> bool: + """Migrate some data, returns True if migration is completed.""" + return migrate_events_context_ids(instance) def needs_migrate_query(self) -> StatementLambdaElement: """Return the query to check if the migration needs to run.""" @@ -2255,7 +2292,20 @@ class EventTypeIDMigration(BaseRunTimeMigration): required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION migration_id = "event_type_id_migration" - task = EventTypeIDMigrationTask + task = CommitBeforeMigrationTask + # We have to commit before to make sure there are + # no new pending event_types about to be added to + # the db since this happens live + + @staticmethod + def migrate_data(instance: Recorder) -> bool: + """Migrate some data, returns True if migration is 
completed.""" + return migrate_event_type_ids(instance) + + def migration_done(self, instance: Recorder) -> None: + """Will be called after migrate returns True or if migration is not needed.""" + _LOGGER.debug("Activating event_types manager as all data is migrated") + instance.event_type_manager.active = True def needs_migrate_query(self) -> StatementLambdaElement: """Check if the data is migrated.""" @@ -2267,7 +2317,39 @@ class EntityIDMigration(BaseRunTimeMigration): required_schema_version = STATES_META_SCHEMA_VERSION migration_id = "entity_id_migration" - task = EntityIDMigrationTask + task = CommitBeforeMigrationTask + # We have to commit before to make sure there are + # no new pending states_meta about to be added to + # the db since this happens live + + @staticmethod + def migrate_data(instance: Recorder) -> bool: + """Migrate some data, returns True if migration is completed.""" + return migrate_entity_ids(instance) + + def migration_done(self, instance: Recorder) -> None: + """Will be called after migrate returns True or if migration is not needed.""" + # The migration has finished, now we start the post migration + # to remove the old entity_id data from the states table + # at this point we can also start using the StatesMeta table + # so we set active to True + _LOGGER.debug("Activating states_meta manager as all data is migrated") + instance.states_meta_manager.active = True + with ( + contextlib.suppress(SQLAlchemyError), + session_scope(session=instance.get_session()) as session, + ): + # If ix_states_entity_id_last_updated_ts still exists + # on the states table, the cleanup done by the + # EntityIDPostMigrationTask did not complete, e.g. because + # of a restart in the middle of it. We need + # to pick back up where we left off. 
+ if get_index_by_name( + session, + TABLE_STATES, + LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX, + ): + instance.queue_task(EntityIDPostMigrationTask()) def needs_migrate_query(self) -> StatementLambdaElement: """Check if the data is migrated.""" diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index b4fe148a229..6072c5cdde7 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -358,75 +358,6 @@ class AdjustLRUSizeTask(RecorderTask): instance._adjust_lru_size() # noqa: SLF001 -@dataclass(slots=True) -class StatesContextIDMigrationTask(RecorderTask): - """An object to insert into the recorder queue to migrate states context ids.""" - - commit_before = False - - def run(self, instance: Recorder) -> None: - """Run context id migration task.""" - if ( - not instance._migrate_states_context_ids() # noqa: SLF001 - ): - # Schedule a new migration task if this one didn't finish - instance.queue_task(StatesContextIDMigrationTask()) - - -@dataclass(slots=True) -class EventsContextIDMigrationTask(RecorderTask): - """An object to insert into the recorder queue to migrate events context ids.""" - - commit_before = False - - def run(self, instance: Recorder) -> None: - """Run context id migration task.""" - if ( - not instance._migrate_events_context_ids() # noqa: SLF001 - ): - # Schedule a new migration task if this one didn't finish - instance.queue_task(EventsContextIDMigrationTask()) - - -@dataclass(slots=True) -class EventTypeIDMigrationTask(RecorderTask): - """An object to insert into the recorder queue to migrate event type ids.""" - - commit_before = True - # We have to commit before to make sure there are - # no new pending event_types about to be added to - # the db since this happens live - - def run(self, instance: Recorder) -> None: - """Run event type id migration task.""" - if not instance._migrate_event_type_ids(): # noqa: SLF001 - # Schedule a new migration 
task if this one didn't finish - instance.queue_task(EventTypeIDMigrationTask()) - - -@dataclass(slots=True) -class EntityIDMigrationTask(RecorderTask): - """An object to insert into the recorder queue to migrate entity_ids to StatesMeta.""" - - commit_before = True - # We have to commit before to make sure there are - # no new pending states_meta about to be added to - # the db since this happens live - - def run(self, instance: Recorder) -> None: - """Run entity_id migration task.""" - if not instance._migrate_entity_ids(): # noqa: SLF001 - # Schedule a new migration task if this one didn't finish - instance.queue_task(EntityIDMigrationTask()) - else: - # The migration has finished, now we start the post migration - # to remove the old entity_id data from the states table - # at this point we can also start using the StatesMeta table - # so we set active to True - instance.states_meta_manager.active = True - instance.queue_task(EntityIDPostMigrationTask()) - - @dataclass(slots=True) class EntityIDPostMigrationTask(RecorderTask): """An object to insert into the recorder queue to cleanup after entity_ids migration.""" diff --git a/tests/components/recorder/common.py b/tests/components/recorder/common.py index c72b1ac830b..003b07ab80f 100644 --- a/tests/components/recorder/common.py +++ b/tests/components/recorder/common.py @@ -416,6 +416,14 @@ def get_schema_module_path(schema_version_postfix: str) -> str: return f"tests.components.recorder.db_schema_{schema_version_postfix}" +@dataclass(slots=True) +class MockMigrationTask(migration.MigrationTask): + """Mock migration task which does nothing.""" + + def run(self, instance: Recorder) -> None: + """Run migration task.""" + + @contextmanager def old_db_schema(schema_version_postfix: str) -> Iterator[None]: """Fixture to initialize the db with the old schema.""" @@ -434,7 +442,7 @@ def old_db_schema(schema_version_postfix: str) -> Iterator[None]: patch.object(core, "States", old_db_schema.States), patch.object(core, 
"Events", old_db_schema.Events), patch.object(core, "StateAttributes", old_db_schema.StateAttributes), - patch.object(migration.EntityIDMigration, "task", core.RecorderTask), + patch.object(migration.EntityIDMigration, "task", MockMigrationTask), patch( CREATE_ENGINE_TARGET, new=partial( diff --git a/tests/components/recorder/test_migration_from_schema_32.py b/tests/components/recorder/test_migration_from_schema_32.py index 91358388614..8a542ed8764 100644 --- a/tests/components/recorder/test_migration_from_schema_32.py +++ b/tests/components/recorder/test_migration_from_schema_32.py @@ -32,13 +32,7 @@ from homeassistant.components.recorder.queries import ( get_migration_changes, select_event_type_ids, ) -from homeassistant.components.recorder.tasks import ( - EntityIDMigrationTask, - EntityIDPostMigrationTask, - EventsContextIDMigrationTask, - EventTypeIDMigrationTask, - StatesContextIDMigrationTask, -) +from homeassistant.components.recorder.tasks import EntityIDPostMigrationTask from homeassistant.components.recorder.util import ( execute_stmt_lambda_element, session_scope, @@ -48,6 +42,7 @@ import homeassistant.util.dt as dt_util from homeassistant.util.ulid import bytes_to_ulid, ulid_at_time, ulid_to_bytes from .common import ( + MockMigrationTask, async_attach_db_engine, async_recorder_block_till_done, async_wait_recording_done, @@ -116,7 +111,7 @@ def db_schema_32(): patch.object(core, "States", old_db_schema.States), patch.object(core, "Events", old_db_schema.Events), patch.object(core, "StateAttributes", old_db_schema.StateAttributes), - patch.object(migration.EntityIDMigration, "task", core.RecorderTask), + patch.object(migration.EntityIDMigration, "task", MockMigrationTask), patch(CREATE_ENGINE_TARGET, new=_create_engine_test), ): yield @@ -229,7 +224,8 @@ async def test_migrate_events_context_ids( with freeze_time(now): # This is a threadsafe way to add a task to the recorder - recorder_mock.queue_task(EventsContextIDMigrationTask()) + migrator = 
migration.EventsContextIDMigration(None, None) + recorder_mock.queue_task(migrator.task(migrator)) await _async_wait_migration_done(hass) def _object_as_dict(obj): @@ -419,7 +415,8 @@ async def test_migrate_states_context_ids( await recorder_mock.async_add_executor_job(_insert_states) await async_wait_recording_done(hass) - recorder_mock.queue_task(StatesContextIDMigrationTask()) + migrator = migration.StatesContextIDMigration(None, None) + recorder_mock.queue_task(migrator.task(migrator)) await _async_wait_migration_done(hass) def _object_as_dict(obj): @@ -567,7 +564,8 @@ async def test_migrate_event_type_ids( await async_wait_recording_done(hass) # This is a threadsafe way to add a task to the recorder - recorder_mock.queue_task(EventTypeIDMigrationTask()) + migrator = migration.EventTypeIDMigration(None, None) + recorder_mock.queue_task(migrator.task(migrator)) await _async_wait_migration_done(hass) def _fetch_migrated_events(): @@ -655,7 +653,8 @@ async def test_migrate_entity_ids(hass: HomeAssistant, recorder_mock: Recorder) await _async_wait_migration_done(hass) # This is a threadsafe way to add a task to the recorder - recorder_mock.queue_task(EntityIDMigrationTask()) + migrator = migration.EntityIDMigration(None, None) + recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator)) await _async_wait_migration_done(hass) def _fetch_migrated_states(): @@ -788,7 +787,8 @@ async def test_migrate_null_entity_ids( await _async_wait_migration_done(hass) # This is a threadsafe way to add a task to the recorder - recorder_mock.queue_task(EntityIDMigrationTask()) + migrator = migration.EntityIDMigration(None, None) + recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator)) await _async_wait_migration_done(hass) def _fetch_migrated_states(): @@ -870,7 +870,8 @@ async def test_migrate_null_event_type_ids( await _async_wait_migration_done(hass) # This is a threadsafe way to add a task to the recorder - 
recorder_mock.queue_task(EventTypeIDMigrationTask()) + migrator = migration.EventTypeIDMigration(None, None) + recorder_mock.queue_task(migrator.task(migrator)) await _async_wait_migration_done(hass) def _fetch_migrated_events(): diff --git a/tests/components/recorder/test_migration_run_time_migrations_remember.py b/tests/components/recorder/test_migration_run_time_migrations_remember.py index ec81711c215..bdd881a3a7b 100644 --- a/tests/components/recorder/test_migration_run_time_migrations_remember.py +++ b/tests/components/recorder/test_migration_run_time_migrations_remember.py @@ -10,8 +10,8 @@ from sqlalchemy.orm import Session from homeassistant.components import recorder from homeassistant.components.recorder import core, migration, statistics +from homeassistant.components.recorder.migration import MigrationTask from homeassistant.components.recorder.queries import get_migration_changes -from homeassistant.components.recorder.tasks import StatesContextIDMigrationTask from homeassistant.components.recorder.util import ( execute_stmt_lambda_element, session_scope, @@ -19,7 +19,11 @@ from homeassistant.components.recorder.util import ( from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant -from .common import async_recorder_block_till_done, async_wait_recording_done +from .common import ( + MockMigrationTask, + async_recorder_block_till_done, + async_wait_recording_done, +) from tests.common import async_test_home_assistant from tests.typing import RecorderInstanceGenerator @@ -99,7 +103,7 @@ async def test_migration_changes_prevent_trying_to_migrate_again( patch.object(core, "States", old_db_schema.States), patch.object(core, "Events", old_db_schema.Events), patch.object(core, "StateAttributes", old_db_schema.StateAttributes), - patch.object(migration.EntityIDMigration, "task", core.RecorderTask), + patch.object(migration.EntityIDMigration, "task", MockMigrationTask), patch(CREATE_ENGINE_TARGET, 
new=_create_engine_test), ): async with ( @@ -169,4 +173,4 @@ async def test_migration_changes_prevent_trying_to_migrate_again( await hass.async_stop() for task in tasks: - assert not isinstance(task, StatesContextIDMigrationTask) + assert not isinstance(task, MigrationTask) diff --git a/tests/components/recorder/test_v32_migration.py b/tests/components/recorder/test_v32_migration.py index 666629d4bcf..2d3c339ae5c 100644 --- a/tests/components/recorder/test_v32_migration.py +++ b/tests/components/recorder/test_v32_migration.py @@ -106,10 +106,6 @@ async def test_migrate_times( patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch.multiple( "homeassistant.components.recorder.Recorder", - _migrate_events_context_ids=DEFAULT, - _migrate_states_context_ids=DEFAULT, - _migrate_event_type_ids=DEFAULT, - _migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT, ), @@ -262,10 +258,6 @@ async def test_migrate_can_resume_entity_id_post_migration( patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch.multiple( "homeassistant.components.recorder.Recorder", - _migrate_events_context_ids=DEFAULT, - _migrate_states_context_ids=DEFAULT, - _migrate_event_type_ids=DEFAULT, - _migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT, ), @@ -388,10 +380,6 @@ async def test_migrate_can_resume_ix_states_event_id_removed( patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch.multiple( "homeassistant.components.recorder.Recorder", - _migrate_events_context_ids=DEFAULT, - _migrate_states_context_ids=DEFAULT, - _migrate_event_type_ids=DEFAULT, - _migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT, ), diff --git a/tests/conftest.py b/tests/conftest.py index 85f4671f6c0..594d71fa165 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1451,22 +1451,16 @@ async def async_test_recorder( else None ) migrate_states_context_ids = ( 
- recorder.Recorder._migrate_states_context_ids - if enable_migrate_context_ids - else None + migration.migrate_states_context_ids if enable_migrate_context_ids else None ) migrate_events_context_ids = ( - recorder.Recorder._migrate_events_context_ids - if enable_migrate_context_ids - else None + migration.migrate_events_context_ids if enable_migrate_context_ids else None ) migrate_event_type_ids = ( - recorder.Recorder._migrate_event_type_ids - if enable_migrate_event_type_ids - else None + migration.migrate_event_type_ids if enable_migrate_event_type_ids else None ) migrate_entity_ids = ( - recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None + migration.migrate_entity_ids if enable_migrate_entity_ids else None ) legacy_event_id_foreign_key_exists = ( recorder.Recorder._legacy_event_id_foreign_key_exists @@ -1490,22 +1484,22 @@ async def async_test_recorder( autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_events_context_ids", + "homeassistant.components.recorder.migration.migrate_events_context_ids", side_effect=migrate_events_context_ids, autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_states_context_ids", + "homeassistant.components.recorder.migration.migrate_states_context_ids", side_effect=migrate_states_context_ids, autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_event_type_ids", + "homeassistant.components.recorder.migration.migrate_event_type_ids", side_effect=migrate_event_type_ids, autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_entity_ids", + "homeassistant.components.recorder.migration.migrate_entity_ids", side_effect=migrate_entity_ids, autospec=True, ),