Refactor recorder data migration (#121009)

* Refactor recorder data migration

* Fix stale docstrings

* Don't store a session object in BaseRunTimeMigration instances

* Simplify logic in EntityIDMigration.migration_done

* Fix tests
This commit is contained in:
Erik Montnemery 2024-07-16 21:50:19 +02:00 committed by GitHub
parent baa97ca981
commit 9970b7eece
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 149 additions and 183 deletions

View file

@ -75,7 +75,6 @@ from .const import (
SupportedDialect, SupportedDialect,
) )
from .db_schema import ( from .db_schema import (
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
LEGACY_STATES_EVENT_ID_INDEX, LEGACY_STATES_EVENT_ID_INDEX,
SCHEMA_VERSION, SCHEMA_VERSION,
TABLE_STATES, TABLE_STATES,
@ -91,7 +90,6 @@ from .db_schema import (
) )
from .executor import DBInterruptibleThreadPoolExecutor from .executor import DBInterruptibleThreadPoolExecutor
from .migration import ( from .migration import (
BaseRunTimeMigration,
EntityIDMigration, EntityIDMigration,
EventsContextIDMigration, EventsContextIDMigration,
EventTypeIDMigration, EventTypeIDMigration,
@ -115,7 +113,6 @@ from .tasks import (
CommitTask, CommitTask,
CompileMissingStatisticsTask, CompileMissingStatisticsTask,
DatabaseLockTask, DatabaseLockTask,
EntityIDPostMigrationTask,
EventIdMigrationTask, EventIdMigrationTask,
ImportStatisticsTask, ImportStatisticsTask,
KeepAliveTask, KeepAliveTask,
@ -804,37 +801,14 @@ class Recorder(threading.Thread):
for row in execute_stmt_lambda_element(session, get_migration_changes()) for row in execute_stmt_lambda_element(session, get_migration_changes())
} }
migrator: BaseRunTimeMigration for migrator_cls in (
for migrator_cls in (StatesContextIDMigration, EventsContextIDMigration): StatesContextIDMigration,
migrator = migrator_cls(session, schema_version, migration_changes) EventsContextIDMigration,
if migrator.needs_migrate(): EventTypeIDMigration,
self.queue_task(migrator.task()) EntityIDMigration,
):
migrator = EventTypeIDMigration(session, schema_version, migration_changes) migrator = migrator_cls(schema_version, migration_changes)
if migrator.needs_migrate(): migrator.do_migrate(self, session)
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating event_types manager as all data is migrated")
self.event_type_manager.active = True
migrator = EntityIDMigration(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating states_meta manager as all data is migrated")
self.states_meta_manager.active = True
with contextlib.suppress(SQLAlchemyError):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
# complete because they restarted in the middle of it. We need
# to pick back up where we left off.
if get_index_by_name(
session,
TABLE_STATES,
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
):
self.queue_task(EntityIDPostMigrationTask())
if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION: if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
with contextlib.suppress(SQLAlchemyError): with contextlib.suppress(SQLAlchemyError):
@ -1319,22 +1293,6 @@ class Recorder(threading.Thread):
) )
) )
def _migrate_states_context_ids(self) -> bool:
"""Migrate states context ids if needed."""
return migration.migrate_states_context_ids(self)
def _migrate_events_context_ids(self) -> bool:
"""Migrate events context ids if needed."""
return migration.migrate_events_context_ids(self)
def _migrate_event_type_ids(self) -> bool:
"""Migrate event type ids if needed."""
return migration.migrate_event_type_ids(self)
def _migrate_entity_ids(self) -> bool:
"""Migrate entity_ids if needed."""
return migration.migrate_entity_ids(self)
def _post_migrate_entity_ids(self) -> bool: def _post_migrate_entity_ids(self) -> bool:
"""Post migrate entity_ids if needed.""" """Post migrate entity_ids if needed."""
return migration.post_migrate_entity_ids(self) return migration.post_migrate_entity_ids(self)

View file

@ -102,12 +102,9 @@ from .queries import (
from .statistics import get_start_time from .statistics import get_start_time
from .tasks import ( from .tasks import (
CommitTask, CommitTask,
EntityIDMigrationTask, EntityIDPostMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
PostSchemaMigrationTask, PostSchemaMigrationTask,
RecorderTask, RecorderTask,
StatesContextIDMigrationTask,
StatisticsTimestampMigrationCleanupTask, StatisticsTimestampMigrationCleanupTask,
) )
from .util import ( from .util import (
@ -2001,9 +1998,6 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
if is_done := not events: if is_done := not events:
_mark_migration_done(session, EventTypeIDMigration) _mark_migration_done(session, EventTypeIDMigration)
if is_done:
instance.event_type_manager.active = True
_LOGGER.debug("Migrating event_types done=%s", is_done) _LOGGER.debug("Migrating event_types done=%s", is_done)
return is_done return is_done
@ -2182,27 +2176,62 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool:
return False return False
@dataclass(slots=True)
class MigrationTask(RecorderTask):
"""Base class for migration tasks."""
migrator: BaseRunTimeMigration
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run migration task."""
if not self.migrator.migrate_data(instance):
# Schedule a new migration task if this one didn't finish
instance.queue_task(MigrationTask(self.migrator))
else:
self.migrator.migration_done(instance)
@dataclass(slots=True)
class CommitBeforeMigrationTask(MigrationTask):
"""Base class for migration tasks which commit first."""
commit_before = True
class BaseRunTimeMigration(ABC): class BaseRunTimeMigration(ABC):
"""Base class for run time migrations.""" """Base class for run time migrations."""
required_schema_version = 0 required_schema_version = 0
migration_version = 1 migration_version = 1
migration_id: str migration_id: str
task: Callable[[], RecorderTask] task = MigrationTask
def __init__( def __init__(self, schema_version: int, migration_changes: dict[str, int]) -> None:
self, session: Session, schema_version: int, migration_changes: dict[str, int]
) -> None:
"""Initialize a new BaseRunTimeMigration.""" """Initialize a new BaseRunTimeMigration."""
self.schema_version = schema_version self.schema_version = schema_version
self.session = session
self.migration_changes = migration_changes self.migration_changes = migration_changes
def do_migrate(self, instance: Recorder, session: Session) -> None:
"""Start migration if needed."""
if self.needs_migrate(session):
instance.queue_task(self.task(self))
else:
self.migration_done(instance)
@staticmethod
@abstractmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True or if migration is not needed."""
@abstractmethod @abstractmethod
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run.""" """Return the query to check if the migration needs to run."""
def needs_migrate(self) -> bool: def needs_migrate(self, session: Session) -> bool:
"""Return if the migration needs to run. """Return if the migration needs to run.
If the migration needs to run, it will return True. If the migration needs to run, it will return True.
@ -2220,8 +2249,8 @@ class BaseRunTimeMigration(ABC):
# We do not know if the migration is done from the # We do not know if the migration is done from the
# migration changes table so we must check the data # migration changes table so we must check the data
# This is the slow path # This is the slow path
if not execute_stmt_lambda_element(self.session, self.needs_migrate_query()): if not execute_stmt_lambda_element(session, self.needs_migrate_query()):
_mark_migration_done(self.session, self.__class__) _mark_migration_done(session, self.__class__)
return False return False
return True return True
@ -2231,7 +2260,11 @@ class StatesContextIDMigration(BaseRunTimeMigration):
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "state_context_id_as_binary" migration_id = "state_context_id_as_binary"
task = StatesContextIDMigrationTask
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_states_context_ids(instance)
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run.""" """Return the query to check if the migration needs to run."""
@ -2243,7 +2276,11 @@ class EventsContextIDMigration(BaseRunTimeMigration):
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "event_context_id_as_binary" migration_id = "event_context_id_as_binary"
task = EventsContextIDMigrationTask
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_events_context_ids(instance)
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run.""" """Return the query to check if the migration needs to run."""
@ -2255,7 +2292,20 @@ class EventTypeIDMigration(BaseRunTimeMigration):
required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION
migration_id = "event_type_id_migration" migration_id = "event_type_id_migration"
task = EventTypeIDMigrationTask task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_event_type_ids(instance)
def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True."""
_LOGGER.debug("Activating event_types manager as all data is migrated")
instance.event_type_manager.active = True
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated.""" """Check if the data is migrated."""
@ -2267,7 +2317,39 @@ class EntityIDMigration(BaseRunTimeMigration):
required_schema_version = STATES_META_SCHEMA_VERSION required_schema_version = STATES_META_SCHEMA_VERSION
migration_id = "entity_id_migration" migration_id = "entity_id_migration"
task = EntityIDMigrationTask task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_entity_ids(instance)
def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True."""
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
# at this point we can also start using the StatesMeta table
# so we set active to True
_LOGGER.debug("Activating states_meta manager as all data is migrated")
instance.states_meta_manager.active = True
with (
contextlib.suppress(SQLAlchemyError),
session_scope(session=instance.get_session()) as session,
):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
# complete because they restarted in the middle of it. We need
# to pick back up where we left off.
if get_index_by_name(
session,
TABLE_STATES,
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
):
instance.queue_task(EntityIDPostMigrationTask())
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated.""" """Check if the data is migrated."""

View file

@ -358,75 +358,6 @@ class AdjustLRUSizeTask(RecorderTask):
instance._adjust_lru_size() # noqa: SLF001 instance._adjust_lru_size() # noqa: SLF001
@dataclass(slots=True)
class StatesContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate states context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_states_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(StatesContextIDMigrationTask())
@dataclass(slots=True)
class EventsContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate events context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_events_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventsContextIDMigrationTask())
@dataclass(slots=True)
class EventTypeIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate event type ids."""
commit_before = True
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live
def run(self, instance: Recorder) -> None:
"""Run event type id migration task."""
if not instance._migrate_event_type_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventTypeIDMigrationTask())
@dataclass(slots=True)
class EntityIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate entity_ids to StatesMeta."""
commit_before = True
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live
def run(self, instance: Recorder) -> None:
"""Run entity_id migration task."""
if not instance._migrate_entity_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EntityIDMigrationTask())
else:
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
# at this point we can also start using the StatesMeta table
# so we set active to True
instance.states_meta_manager.active = True
instance.queue_task(EntityIDPostMigrationTask())
@dataclass(slots=True) @dataclass(slots=True)
class EntityIDPostMigrationTask(RecorderTask): class EntityIDPostMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to cleanup after entity_ids migration.""" """An object to insert into the recorder queue to cleanup after entity_ids migration."""

View file

@ -416,6 +416,14 @@ def get_schema_module_path(schema_version_postfix: str) -> str:
return f"tests.components.recorder.db_schema_{schema_version_postfix}" return f"tests.components.recorder.db_schema_{schema_version_postfix}"
@dataclass(slots=True)
class MockMigrationTask(migration.MigrationTask):
"""Mock migration task which does nothing."""
def run(self, instance: Recorder) -> None:
"""Run migration task."""
@contextmanager @contextmanager
def old_db_schema(schema_version_postfix: str) -> Iterator[None]: def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
"""Fixture to initialize the db with the old schema.""" """Fixture to initialize the db with the old schema."""
@ -434,7 +442,7 @@ def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes), patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch( patch(
CREATE_ENGINE_TARGET, CREATE_ENGINE_TARGET,
new=partial( new=partial(

View file

@ -32,13 +32,7 @@ from homeassistant.components.recorder.queries import (
get_migration_changes, get_migration_changes,
select_event_type_ids, select_event_type_ids,
) )
from homeassistant.components.recorder.tasks import ( from homeassistant.components.recorder.tasks import EntityIDPostMigrationTask
EntityIDMigrationTask,
EntityIDPostMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
StatesContextIDMigrationTask,
)
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
execute_stmt_lambda_element, execute_stmt_lambda_element,
session_scope, session_scope,
@ -48,6 +42,7 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.ulid import bytes_to_ulid, ulid_at_time, ulid_to_bytes from homeassistant.util.ulid import bytes_to_ulid, ulid_at_time, ulid_to_bytes
from .common import ( from .common import (
MockMigrationTask,
async_attach_db_engine, async_attach_db_engine,
async_recorder_block_till_done, async_recorder_block_till_done,
async_wait_recording_done, async_wait_recording_done,
@ -116,7 +111,7 @@ def db_schema_32():
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes), patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
): ):
yield yield
@ -229,7 +224,8 @@ async def test_migrate_events_context_ids(
with freeze_time(now): with freeze_time(now):
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventsContextIDMigrationTask()) migrator = migration.EventsContextIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _object_as_dict(obj): def _object_as_dict(obj):
@ -419,7 +415,8 @@ async def test_migrate_states_context_ids(
await recorder_mock.async_add_executor_job(_insert_states) await recorder_mock.async_add_executor_job(_insert_states)
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
recorder_mock.queue_task(StatesContextIDMigrationTask()) migrator = migration.StatesContextIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _object_as_dict(obj): def _object_as_dict(obj):
@ -567,7 +564,8 @@ async def test_migrate_event_type_ids(
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventTypeIDMigrationTask()) migrator = migration.EventTypeIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _fetch_migrated_events(): def _fetch_migrated_events():
@ -655,7 +653,8 @@ async def test_migrate_entity_ids(hass: HomeAssistant, recorder_mock: Recorder)
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EntityIDMigrationTask()) migrator = migration.EntityIDMigration(None, None)
recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _fetch_migrated_states(): def _fetch_migrated_states():
@ -788,7 +787,8 @@ async def test_migrate_null_entity_ids(
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EntityIDMigrationTask()) migrator = migration.EntityIDMigration(None, None)
recorder_mock.queue_task(migration.CommitBeforeMigrationTask(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _fetch_migrated_states(): def _fetch_migrated_states():
@ -870,7 +870,8 @@ async def test_migrate_null_event_type_ids(
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
recorder_mock.queue_task(EventTypeIDMigrationTask()) migrator = migration.EventTypeIDMigration(None, None)
recorder_mock.queue_task(migrator.task(migrator))
await _async_wait_migration_done(hass) await _async_wait_migration_done(hass)
def _fetch_migrated_events(): def _fetch_migrated_events():

View file

@ -10,8 +10,8 @@ from sqlalchemy.orm import Session
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import core, migration, statistics from homeassistant.components.recorder import core, migration, statistics
from homeassistant.components.recorder.migration import MigrationTask
from homeassistant.components.recorder.queries import get_migration_changes from homeassistant.components.recorder.queries import get_migration_changes
from homeassistant.components.recorder.tasks import StatesContextIDMigrationTask
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
execute_stmt_lambda_element, execute_stmt_lambda_element,
session_scope, session_scope,
@ -19,7 +19,11 @@ from homeassistant.components.recorder.util import (
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .common import async_recorder_block_till_done, async_wait_recording_done from .common import (
MockMigrationTask,
async_recorder_block_till_done,
async_wait_recording_done,
)
from tests.common import async_test_home_assistant from tests.common import async_test_home_assistant
from tests.typing import RecorderInstanceGenerator from tests.typing import RecorderInstanceGenerator
@ -99,7 +103,7 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes), patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
): ):
async with ( async with (
@ -169,4 +173,4 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
await hass.async_stop() await hass.async_stop()
for task in tasks: for task in tasks:
assert not isinstance(task, StatesContextIDMigrationTask) assert not isinstance(task, MigrationTask)

View file

@ -106,10 +106,6 @@ async def test_migrate_times(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch.multiple(
"homeassistant.components.recorder.Recorder", "homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT,
), ),
@ -262,10 +258,6 @@ async def test_migrate_can_resume_entity_id_post_migration(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch.multiple(
"homeassistant.components.recorder.Recorder", "homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT,
), ),
@ -388,10 +380,6 @@ async def test_migrate_can_resume_ix_states_event_id_removed(
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch.multiple(
"homeassistant.components.recorder.Recorder", "homeassistant.components.recorder.Recorder",
_migrate_events_context_ids=DEFAULT,
_migrate_states_context_ids=DEFAULT,
_migrate_event_type_ids=DEFAULT,
_migrate_entity_ids=DEFAULT,
_post_migrate_entity_ids=DEFAULT, _post_migrate_entity_ids=DEFAULT,
_cleanup_legacy_states_event_ids=DEFAULT, _cleanup_legacy_states_event_ids=DEFAULT,
), ),

View file

@ -1451,22 +1451,16 @@ async def async_test_recorder(
else None else None
) )
migrate_states_context_ids = ( migrate_states_context_ids = (
recorder.Recorder._migrate_states_context_ids migration.migrate_states_context_ids if enable_migrate_context_ids else None
if enable_migrate_context_ids
else None
) )
migrate_events_context_ids = ( migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids migration.migrate_events_context_ids if enable_migrate_context_ids else None
if enable_migrate_context_ids
else None
) )
migrate_event_type_ids = ( migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids migration.migrate_event_type_ids if enable_migrate_event_type_ids else None
if enable_migrate_event_type_ids
else None
) )
migrate_entity_ids = ( migrate_entity_ids = (
recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None migration.migrate_entity_ids if enable_migrate_entity_ids else None
) )
legacy_event_id_foreign_key_exists = ( legacy_event_id_foreign_key_exists = (
recorder.Recorder._legacy_event_id_foreign_key_exists recorder.Recorder._legacy_event_id_foreign_key_exists
@ -1490,22 +1484,22 @@ async def async_test_recorder(
autospec=True, autospec=True,
), ),
patch( patch(
"homeassistant.components.recorder.Recorder._migrate_events_context_ids", "homeassistant.components.recorder.migration.migrate_events_context_ids",
side_effect=migrate_events_context_ids, side_effect=migrate_events_context_ids,
autospec=True, autospec=True,
), ),
patch( patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids", "homeassistant.components.recorder.migration.migrate_states_context_ids",
side_effect=migrate_states_context_ids, side_effect=migrate_states_context_ids,
autospec=True, autospec=True,
), ),
patch( patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids", "homeassistant.components.recorder.migration.migrate_event_type_ids",
side_effect=migrate_event_type_ids, side_effect=migrate_event_type_ids,
autospec=True, autospec=True,
), ),
patch( patch(
"homeassistant.components.recorder.Recorder._migrate_entity_ids", "homeassistant.components.recorder.migration.migrate_entity_ids",
side_effect=migrate_entity_ids, side_effect=migrate_entity_ids,
autospec=True, autospec=True,
), ),