diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 630efe19560..4ebd4703b65 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -97,9 +97,9 @@ from .tasks import ( ChangeStatisticsUnitTask, ClearStatisticsTask, CommitTask, - ContextIDMigrationTask, DatabaseLockTask, EntityIDMigrationTask, + EventsContextIDMigrationTask, EventTask, EventTypeIDMigrationTask, ImportStatisticsTask, @@ -107,6 +107,7 @@ from .tasks import ( PerodicCleanupTask, PurgeTask, RecorderTask, + StatesContextIDMigrationTask, StatisticsTask, StopTask, SynchronizeTask, @@ -654,8 +655,9 @@ class Recorder(threading.Thread): self.migration_is_live = migration.live_migration(schema_status) self.hass.add_job(self.async_connection_success) + database_was_ready = self.migration_is_live or schema_status.valid - if self.migration_is_live or schema_status.valid: + if database_was_ready: # If the migrate is live or the schema is valid, we need to # wait for startup to complete. If its not live, we need to continue # on. @@ -670,7 +672,6 @@ class Recorder(threading.Thread): # Make sure we cleanly close the run if # we restart before startup finishes self._shutdown() - self._activate_and_set_db_ready() return if not schema_status.valid: @@ -692,7 +693,8 @@ class Recorder(threading.Thread): self._shutdown() return - self._activate_and_set_db_ready() + if not database_was_ready: + self._activate_and_set_db_ready() # Catch up with missed statistics with session_scope(session=self.get_session()) as session: @@ -710,9 +712,14 @@ class Recorder(threading.Thread): if ( self.schema_version < 36 or session.execute(has_events_context_ids_to_migrate()).scalar() + ): + self.queue_task(StatesContextIDMigrationTask()) + + if ( + self.schema_version < 36 or session.execute(has_states_context_ids_to_migrate()).scalar() ): - self.queue_task(ContextIDMigrationTask()) + self.queue_task(EventsContextIDMigrationTask()) if ( self.schema_version < 37 @@ -1236,9 +1243,13 @@ class Recorder(threading.Thread): """Run post schema migration tasks.""" migration.post_schema_migration(self, old_version, new_version) - def _migrate_context_ids(self) -> bool: - """Migrate context ids if needed.""" - return migration.migrate_context_ids(self) + def _migrate_states_context_ids(self) -> bool: + """Migrate states context ids if needed.""" + return migration.migrate_states_context_ids(self) + + def _migrate_events_context_ids(self) -> bool: + """Migrate events context ids if needed.""" + return migration.migrate_events_context_ids(self) def _migrate_event_type_ids(self) -> bool: """Migrate event type ids if needed.""" diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index c13a0314577..b1b33dd29e2 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -64,7 +64,7 @@ from .tasks import ( PostSchemaMigrationTask, StatisticsTimestampMigrationCleanupTask, ) -from .util import database_job_retry_wrapper, session_scope +from .util import database_job_retry_wrapper, retryable_database_job, session_scope if TYPE_CHECKING: from . import Recorder @@ -1301,8 +1301,43 @@ def _context_id_to_bytes(context_id: str | None) -> bytes | None: return None -def migrate_context_ids(instance: Recorder) -> bool: - """Migrate context_ids to use binary format.""" +@retryable_database_job("migrate states context_ids to binary format") +def migrate_states_context_ids(instance: Recorder) -> bool: + """Migrate states context_ids to use binary format.""" + _to_bytes = _context_id_to_bytes + session_maker = instance.get_session + _LOGGER.debug("Migrating states context_ids to binary format") + with session_scope(session=session_maker()) as session: + if states := session.execute(find_states_context_ids_to_migrate()).all(): + session.execute( + update(States), + [ + { + "state_id": state_id, + "context_id": None, + "context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID, + "context_user_id": None, + "context_user_id_bin": _to_bytes(context_user_id), + "context_parent_id": None, + "context_parent_id_bin": _to_bytes(context_parent_id), + } + for state_id, context_id, context_user_id, context_parent_id in states + ], + ) + # If there is more work to do return False + # so that we can be called again + is_done = not states + + if is_done: + _drop_index(session_maker, "states", "ix_states_context_id") + + _LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done) + return is_done + + +@retryable_database_job("migrate events context_ids to binary format") +def migrate_events_context_ids(instance: Recorder) -> bool: + """Migrate events context_ids to use binary format.""" _to_bytes = _context_id_to_bytes session_maker = instance.get_session _LOGGER.debug("Migrating context_ids to binary format") @@ -1323,34 +1358,18 @@ def migrate_context_ids(instance: Recorder) -> bool: for event_id, context_id, context_user_id, context_parent_id in events ], ) - if states := session.execute(find_states_context_ids_to_migrate()).all(): - session.execute( - update(States), - [ - { - "state_id": state_id, - "context_id": None, - "context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID, - "context_user_id": None, - "context_user_id_bin": _to_bytes(context_user_id), - "context_parent_id": None, - "context_parent_id_bin": _to_bytes(context_parent_id), - } - for state_id, context_id, context_user_id, context_parent_id in states - ], - ) # If there is more work to do return False # so that we can be called again - is_done = not (events or states) + is_done = not events if is_done: _drop_index(session_maker, "events", "ix_events_context_id") - _drop_index(session_maker, "states", "ix_states_context_id") - _LOGGER.debug("Migrating context_ids to binary format: done=%s", is_done) + _LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done) return is_done +@retryable_database_job("migrate events event_types to event_type_ids") def migrate_event_type_ids(instance: Recorder) -> bool: """Migrate event_type to event_type_ids.""" session_maker = instance.get_session @@ -1407,6 +1426,7 @@ def migrate_event_type_ids(instance: Recorder) -> bool: return is_done +@retryable_database_job("migrate states entity_ids to states_meta") def migrate_entity_ids(instance: Recorder) -> bool: """Migrate entity_ids to states_meta. @@ -1468,6 +1488,7 @@ def migrate_entity_ids(instance: Recorder) -> bool: return is_done +@retryable_database_job("post migrate states entity_ids to states_meta") def post_migrate_entity_ids(instance: Recorder) -> bool: """Remove old entity_id strings from states. diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 0b99ca742b2..17b63aad229 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -346,16 +346,33 @@ class AdjustLRUSizeTask(RecorderTask): @dataclass -class ContextIDMigrationTask(RecorderTask): - """An object to insert into the recorder queue to migrate context ids.""" +class StatesContextIDMigrationTask(RecorderTask): + """An object to insert into the recorder queue to migrate states context ids.""" commit_before = False def run(self, instance: Recorder) -> None: """Run context id migration task.""" - if not instance._migrate_context_ids(): # pylint: disable=[protected-access] + if ( + not instance._migrate_states_context_ids() # pylint: disable=[protected-access] + ): # Schedule a new migration task if this one didn't finish - instance.queue_task(ContextIDMigrationTask()) + instance.queue_task(StatesContextIDMigrationTask()) + + +@dataclass +class EventsContextIDMigrationTask(RecorderTask): + """An object to insert into the recorder queue to migrate events context ids.""" + + commit_before = False + + def run(self, instance: Recorder) -> None: + """Run context id migration task.""" + if ( + not instance._migrate_events_context_ids() # pylint: disable=[protected-access] + ): + # Schedule a new migration task if this one didn't finish + instance.queue_task(EventsContextIDMigrationTask()) @dataclass diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 060d1bcb743..c9d0be5973f 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -32,10 +32,11 @@ from homeassistant.components.recorder.db_schema import ( ) from homeassistant.components.recorder.queries import select_event_type_ids from homeassistant.components.recorder.tasks import ( - ContextIDMigrationTask, EntityIDMigrationTask, EntityIDPostMigrationTask, + EventsContextIDMigrationTask, EventTypeIDMigrationTask, + StatesContextIDMigrationTask, ) from homeassistant.components.recorder.util import session_scope from homeassistant.core import HomeAssistant @@ -558,7 +559,7 @@ def test_raise_if_exception_missing_empty_cause_str() -> None: @pytest.mark.parametrize("enable_migrate_context_ids", [True]) -async def test_migrate_context_ids( +async def test_migrate_events_context_ids( async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant ) -> None: """Test we can migrate old uuid context ids and ulid context ids to binary format.""" @@ -632,7 +633,7 @@ async def test_migrate_context_ids( await async_wait_recording_done(hass) # This is a threadsafe way to add a task to the recorder - instance.queue_task(ContextIDMigrationTask()) + instance.queue_task(EventsContextIDMigrationTask()) await async_recorder_block_till_done(hass) def _object_as_dict(obj): @@ -701,6 +702,137 @@ async def test_migrate_context_ids( assert invalid_context_id_event["context_parent_id_bin"] is None +@pytest.mark.parametrize("enable_migrate_context_ids", [True]) +async def test_migrate_states_context_ids( + async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant +) -> None: + """Test we can migrate old uuid context ids and ulid context ids to binary format.""" + instance = await async_setup_recorder_instance(hass) + await async_wait_recording_done(hass) + + test_uuid = uuid.uuid4() + uuid_hex = test_uuid.hex + uuid_bin = test_uuid.bytes + + def _insert_events(): + with session_scope(hass=hass) as session: + session.add_all( + ( + States( + entity_id="state.old_uuid_context_id", + last_updated_ts=1677721632.452529, + context_id=uuid_hex, + context_id_bin=None, + context_user_id=None, + context_user_id_bin=None, + context_parent_id=None, + context_parent_id_bin=None, + ), + States( + entity_id="state.empty_context_id", + last_updated_ts=1677721632.552529, + context_id=None, + context_id_bin=None, + context_user_id=None, + context_user_id_bin=None, + context_parent_id=None, + context_parent_id_bin=None, + ), + States( + entity_id="state.ulid_context_id", + last_updated_ts=1677721632.552529, + context_id="01ARZ3NDEKTSV4RRFFQ69G5FAV", + context_id_bin=None, + context_user_id="9400facee45711eaa9308bfd3d19e474", + context_user_id_bin=None, + context_parent_id="01ARZ3NDEKTSV4RRFFQ69G5FA2", + context_parent_id_bin=None, + ), + States( + entity_id="state.invalid_context_id", + last_updated_ts=1677721632.552529, + context_id="invalid", + context_id_bin=None, + context_user_id=None, + context_user_id_bin=None, + context_parent_id=None, + context_parent_id_bin=None, + ), + ) + ) + + await instance.async_add_executor_job(_insert_events) + + await async_wait_recording_done(hass) + # This is a threadsafe way to add a task to the recorder + instance.queue_task(StatesContextIDMigrationTask()) + await async_recorder_block_till_done(hass) + + def _object_as_dict(obj): + return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs} + + def _fetch_migrated_states(): + with session_scope(hass=hass) as session: + events = ( + session.query(States) + .filter( + States.entity_id.in_( + [ + "state.old_uuid_context_id", + "state.empty_context_id", + "state.ulid_context_id", + "state.invalid_context_id", + ] + ) + ) + .all() + ) + assert len(events) == 4 + return {state.entity_id: _object_as_dict(state) for state in events} + + states_by_entity_id = await instance.async_add_executor_job(_fetch_migrated_states) + + old_uuid_context_id = states_by_entity_id["state.old_uuid_context_id"] + assert old_uuid_context_id["context_id"] is None + assert old_uuid_context_id["context_user_id"] is None + assert old_uuid_context_id["context_parent_id"] is None + assert old_uuid_context_id["context_id_bin"] == uuid_bin + assert old_uuid_context_id["context_user_id_bin"] is None + assert old_uuid_context_id["context_parent_id_bin"] is None + + empty_context_id = states_by_entity_id["state.empty_context_id"] + assert empty_context_id["context_id"] is None + assert empty_context_id["context_user_id"] is None + assert empty_context_id["context_parent_id"] is None + assert empty_context_id["context_id_bin"] == b"\x00" * 16 + assert empty_context_id["context_user_id_bin"] is None + assert empty_context_id["context_parent_id_bin"] is None + + ulid_context_id = states_by_entity_id["state.ulid_context_id"] + assert ulid_context_id["context_id"] is None + assert ulid_context_id["context_user_id"] is None + assert ulid_context_id["context_parent_id"] is None + assert ( + bytes_to_ulid(ulid_context_id["context_id_bin"]) == "01ARZ3NDEKTSV4RRFFQ69G5FAV" + ) + assert ( + ulid_context_id["context_user_id_bin"] + == b"\x94\x00\xfa\xce\xe4W\x11\xea\xa90\x8b\xfd=\x19\xe4t" + ) + assert ( + bytes_to_ulid(ulid_context_id["context_parent_id_bin"]) + == "01ARZ3NDEKTSV4RRFFQ69G5FA2" + ) + + invalid_context_id = states_by_entity_id["state.invalid_context_id"] + assert invalid_context_id["context_id"] is None + assert invalid_context_id["context_user_id"] is None + assert invalid_context_id["context_parent_id"] is None + assert invalid_context_id["context_id_bin"] == b"\x00" * 16 + assert invalid_context_id["context_user_id_bin"] is None + assert invalid_context_id["context_parent_id_bin"] is None + + @pytest.mark.parametrize("enable_migrate_event_type_ids", [True]) async def test_migrate_event_type_ids( async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant diff --git a/tests/components/recorder/test_v32_migration.py b/tests/components/recorder/test_v32_migration.py index 50029e56b21..467dc2961c6 100644 --- a/tests/components/recorder/test_v32_migration.py +++ b/tests/components/recorder/test_v32_migration.py @@ -86,6 +86,7 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None: EventOrigin.local, time_fired=now, ) + number_of_migrations = 5 with patch.object(recorder, "db_schema", old_db_schema), patch.object( recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION @@ -100,11 +101,15 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None: ), patch( CREATE_ENGINE_TARGET, new=_create_engine_test ), patch( - "homeassistant.components.recorder.Recorder._migrate_context_ids", + "homeassistant.components.recorder.Recorder._migrate_events_context_ids", + ), patch( + "homeassistant.components.recorder.Recorder._migrate_states_context_ids", ), patch( "homeassistant.components.recorder.Recorder._migrate_event_type_ids", ), patch( "homeassistant.components.recorder.Recorder._migrate_entity_ids", + ), patch( + "homeassistant.components.recorder.Recorder._post_migrate_entity_ids" ): hass = await async_test_home_assistant(asyncio.get_running_loop()) recorder_helper.async_initialize_recorder(hass) @@ -122,8 +127,10 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None: await recorder.get_instance(hass).async_add_executor_job(_add_data) await hass.async_block_till_done() + await recorder.get_instance(hass).async_block_till_done() await hass.async_stop() + await hass.async_block_till_done() dt_util.DEFAULT_TIME_ZONE = ORIG_TZ @@ -137,7 +144,8 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None: # We need to wait for all the migration tasks to complete # before we can check the database. - for _ in range(5): + for _ in range(number_of_migrations): + await recorder.get_instance(hass).async_block_till_done() await async_wait_recording_done(hass) def _get_test_data_from_db(): diff --git a/tests/conftest.py b/tests/conftest.py index 4f7b553955e..0307e65d272 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1250,8 +1250,15 @@ def hass_recorder( if enable_statistics_table_validation else itertools.repeat(set()) ) - migrate_context_ids = ( - recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None + migrate_states_context_ids = ( + recorder.Recorder._migrate_states_context_ids + if enable_migrate_context_ids + else None + ) + migrate_events_context_ids = ( + recorder.Recorder._migrate_events_context_ids + if enable_migrate_context_ids + else None ) migrate_event_type_ids = ( recorder.Recorder._migrate_event_type_ids @@ -1274,8 +1281,12 @@ def hass_recorder( side_effect=stats_validate, autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_context_ids", - side_effect=migrate_context_ids, + "homeassistant.components.recorder.Recorder._migrate_events_context_ids", + side_effect=migrate_events_context_ids, + autospec=True, + ), patch( + "homeassistant.components.recorder.Recorder._migrate_states_context_ids", + side_effect=migrate_states_context_ids, autospec=True, ), patch( "homeassistant.components.recorder.Recorder._migrate_event_type_ids", @@ -1354,8 +1365,15 @@ async def async_setup_recorder_instance( if enable_statistics_table_validation else itertools.repeat(set()) ) - migrate_context_ids = ( - recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None + migrate_states_context_ids = ( + recorder.Recorder._migrate_states_context_ids + if enable_migrate_context_ids + else None + ) + migrate_events_context_ids = ( + recorder.Recorder._migrate_events_context_ids + if enable_migrate_context_ids + else None ) migrate_event_type_ids = ( recorder.Recorder._migrate_event_type_ids @@ -1378,8 +1396,12 @@ async def async_setup_recorder_instance( side_effect=stats_validate, autospec=True, ), patch( - "homeassistant.components.recorder.Recorder._migrate_context_ids", - side_effect=migrate_context_ids, + "homeassistant.components.recorder.Recorder._migrate_events_context_ids", + side_effect=migrate_events_context_ids, + autospec=True, + ), patch( + "homeassistant.components.recorder.Recorder._migrate_states_context_ids", + side_effect=migrate_states_context_ids, autospec=True, ), patch( "homeassistant.components.recorder.Recorder._migrate_event_type_ids",