Improve reliability of context id migration (#89609)

* Split context id migration into states and events tasks

Since events can finish much earlier than states we
would keep looking at the table because states was not
done. Make them separate tasks

* add retry decorator

* fix migration happening twice

* another case
This commit is contained in:
J. Nick Koston 2023-03-12 15:41:48 -10:00 committed by GitHub
parent 85ca94e9d4
commit b9ac6b4a7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 258 additions and 47 deletions

View file

@ -97,9 +97,9 @@ from .tasks import (
ChangeStatisticsUnitTask, ChangeStatisticsUnitTask,
ClearStatisticsTask, ClearStatisticsTask,
CommitTask, CommitTask,
ContextIDMigrationTask,
DatabaseLockTask, DatabaseLockTask,
EntityIDMigrationTask, EntityIDMigrationTask,
EventsContextIDMigrationTask,
EventTask, EventTask,
EventTypeIDMigrationTask, EventTypeIDMigrationTask,
ImportStatisticsTask, ImportStatisticsTask,
@ -107,6 +107,7 @@ from .tasks import (
PerodicCleanupTask, PerodicCleanupTask,
PurgeTask, PurgeTask,
RecorderTask, RecorderTask,
StatesContextIDMigrationTask,
StatisticsTask, StatisticsTask,
StopTask, StopTask,
SynchronizeTask, SynchronizeTask,
@ -654,8 +655,9 @@ class Recorder(threading.Thread):
self.migration_is_live = migration.live_migration(schema_status) self.migration_is_live = migration.live_migration(schema_status)
self.hass.add_job(self.async_connection_success) self.hass.add_job(self.async_connection_success)
database_was_ready = self.migration_is_live or schema_status.valid
if self.migration_is_live or schema_status.valid: if database_was_ready:
# If the migrate is live or the schema is valid, we need to # If the migrate is live or the schema is valid, we need to
# wait for startup to complete. If its not live, we need to continue # wait for startup to complete. If its not live, we need to continue
# on. # on.
@ -670,7 +672,6 @@ class Recorder(threading.Thread):
# Make sure we cleanly close the run if # Make sure we cleanly close the run if
# we restart before startup finishes # we restart before startup finishes
self._shutdown() self._shutdown()
self._activate_and_set_db_ready()
return return
if not schema_status.valid: if not schema_status.valid:
@ -692,7 +693,8 @@ class Recorder(threading.Thread):
self._shutdown() self._shutdown()
return return
self._activate_and_set_db_ready() if not database_was_ready:
self._activate_and_set_db_ready()
# Catch up with missed statistics # Catch up with missed statistics
with session_scope(session=self.get_session()) as session: with session_scope(session=self.get_session()) as session:
@ -710,9 +712,14 @@ class Recorder(threading.Thread):
if ( if (
self.schema_version < 36 self.schema_version < 36
or session.execute(has_events_context_ids_to_migrate()).scalar() or session.execute(has_events_context_ids_to_migrate()).scalar()
):
self.queue_task(StatesContextIDMigrationTask())
if (
self.schema_version < 36
or session.execute(has_states_context_ids_to_migrate()).scalar() or session.execute(has_states_context_ids_to_migrate()).scalar()
): ):
self.queue_task(ContextIDMigrationTask()) self.queue_task(EventsContextIDMigrationTask())
if ( if (
self.schema_version < 37 self.schema_version < 37
@ -1236,9 +1243,13 @@ class Recorder(threading.Thread):
"""Run post schema migration tasks.""" """Run post schema migration tasks."""
migration.post_schema_migration(self, old_version, new_version) migration.post_schema_migration(self, old_version, new_version)
def _migrate_context_ids(self) -> bool: def _migrate_states_context_ids(self) -> bool:
"""Migrate context ids if needed.""" """Migrate states context ids if needed."""
return migration.migrate_context_ids(self) return migration.migrate_states_context_ids(self)
def _migrate_events_context_ids(self) -> bool:
"""Migrate events context ids if needed."""
return migration.migrate_events_context_ids(self)
def _migrate_event_type_ids(self) -> bool: def _migrate_event_type_ids(self) -> bool:
"""Migrate event type ids if needed.""" """Migrate event type ids if needed."""

View file

@ -64,7 +64,7 @@ from .tasks import (
PostSchemaMigrationTask, PostSchemaMigrationTask,
StatisticsTimestampMigrationCleanupTask, StatisticsTimestampMigrationCleanupTask,
) )
from .util import database_job_retry_wrapper, session_scope from .util import database_job_retry_wrapper, retryable_database_job, session_scope
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Recorder from . import Recorder
@ -1301,8 +1301,43 @@ def _context_id_to_bytes(context_id: str | None) -> bytes | None:
return None return None
def migrate_context_ids(instance: Recorder) -> bool: @retryable_database_job("migrate states context_ids to binary format")
"""Migrate context_ids to use binary format.""" def migrate_states_context_ids(instance: Recorder) -> bool:
"""Migrate states context_ids to use binary format."""
_to_bytes = _context_id_to_bytes
session_maker = instance.get_session
_LOGGER.debug("Migrating states context_ids to binary format")
with session_scope(session=session_maker()) as session:
if states := session.execute(find_states_context_ids_to_migrate()).all():
session.execute(
update(States),
[
{
"state_id": state_id,
"context_id": None,
"context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID,
"context_user_id": None,
"context_user_id_bin": _to_bytes(context_user_id),
"context_parent_id": None,
"context_parent_id_bin": _to_bytes(context_parent_id),
}
for state_id, context_id, context_user_id, context_parent_id in states
],
)
# If there is more work to do return False
# so that we can be called again
is_done = not states
if is_done:
_drop_index(session_maker, "states", "ix_states_context_id")
_LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done)
return is_done
@retryable_database_job("migrate events context_ids to binary format")
def migrate_events_context_ids(instance: Recorder) -> bool:
"""Migrate events context_ids to use binary format."""
_to_bytes = _context_id_to_bytes _to_bytes = _context_id_to_bytes
session_maker = instance.get_session session_maker = instance.get_session
_LOGGER.debug("Migrating context_ids to binary format") _LOGGER.debug("Migrating context_ids to binary format")
@ -1323,34 +1358,18 @@ def migrate_context_ids(instance: Recorder) -> bool:
for event_id, context_id, context_user_id, context_parent_id in events for event_id, context_id, context_user_id, context_parent_id in events
], ],
) )
if states := session.execute(find_states_context_ids_to_migrate()).all():
session.execute(
update(States),
[
{
"state_id": state_id,
"context_id": None,
"context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID,
"context_user_id": None,
"context_user_id_bin": _to_bytes(context_user_id),
"context_parent_id": None,
"context_parent_id_bin": _to_bytes(context_parent_id),
}
for state_id, context_id, context_user_id, context_parent_id in states
],
)
# If there is more work to do return False # If there is more work to do return False
# so that we can be called again # so that we can be called again
is_done = not (events or states) is_done = not events
if is_done: if is_done:
_drop_index(session_maker, "events", "ix_events_context_id") _drop_index(session_maker, "events", "ix_events_context_id")
_drop_index(session_maker, "states", "ix_states_context_id")
_LOGGER.debug("Migrating context_ids to binary format: done=%s", is_done) _LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done)
return is_done return is_done
@retryable_database_job("migrate events event_types to event_type_ids")
def migrate_event_type_ids(instance: Recorder) -> bool: def migrate_event_type_ids(instance: Recorder) -> bool:
"""Migrate event_type to event_type_ids.""" """Migrate event_type to event_type_ids."""
session_maker = instance.get_session session_maker = instance.get_session
@ -1407,6 +1426,7 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
return is_done return is_done
@retryable_database_job("migrate states entity_ids to states_meta")
def migrate_entity_ids(instance: Recorder) -> bool: def migrate_entity_ids(instance: Recorder) -> bool:
"""Migrate entity_ids to states_meta. """Migrate entity_ids to states_meta.
@ -1468,6 +1488,7 @@ def migrate_entity_ids(instance: Recorder) -> bool:
return is_done return is_done
@retryable_database_job("post migrate states entity_ids to states_meta")
def post_migrate_entity_ids(instance: Recorder) -> bool: def post_migrate_entity_ids(instance: Recorder) -> bool:
"""Remove old entity_id strings from states. """Remove old entity_id strings from states.

View file

@ -346,16 +346,33 @@ class AdjustLRUSizeTask(RecorderTask):
@dataclass @dataclass
class ContextIDMigrationTask(RecorderTask): class StatesContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate context ids.""" """An object to insert into the recorder queue to migrate states context ids."""
commit_before = False commit_before = False
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Run context id migration task.""" """Run context id migration task."""
if not instance._migrate_context_ids(): # pylint: disable=[protected-access] if (
not instance._migrate_states_context_ids() # pylint: disable=[protected-access]
):
# Schedule a new migration task if this one didn't finish # Schedule a new migration task if this one didn't finish
instance.queue_task(ContextIDMigrationTask()) instance.queue_task(StatesContextIDMigrationTask())
@dataclass
class EventsContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate events context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_events_context_ids() # pylint: disable=[protected-access]
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventsContextIDMigrationTask())
@dataclass @dataclass

View file

@ -32,10 +32,11 @@ from homeassistant.components.recorder.db_schema import (
) )
from homeassistant.components.recorder.queries import select_event_type_ids from homeassistant.components.recorder.queries import select_event_type_ids
from homeassistant.components.recorder.tasks import ( from homeassistant.components.recorder.tasks import (
ContextIDMigrationTask,
EntityIDMigrationTask, EntityIDMigrationTask,
EntityIDPostMigrationTask, EntityIDPostMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask, EventTypeIDMigrationTask,
StatesContextIDMigrationTask,
) )
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -558,7 +559,7 @@ def test_raise_if_exception_missing_empty_cause_str() -> None:
@pytest.mark.parametrize("enable_migrate_context_ids", [True]) @pytest.mark.parametrize("enable_migrate_context_ids", [True])
async def test_migrate_context_ids( async def test_migrate_events_context_ids(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None: ) -> None:
"""Test we can migrate old uuid context ids and ulid context ids to binary format.""" """Test we can migrate old uuid context ids and ulid context ids to binary format."""
@ -632,7 +633,7 @@ async def test_migrate_context_ids(
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
# This is a threadsafe way to add a task to the recorder # This is a threadsafe way to add a task to the recorder
instance.queue_task(ContextIDMigrationTask()) instance.queue_task(EventsContextIDMigrationTask())
await async_recorder_block_till_done(hass) await async_recorder_block_till_done(hass)
def _object_as_dict(obj): def _object_as_dict(obj):
@ -701,6 +702,137 @@ async def test_migrate_context_ids(
assert invalid_context_id_event["context_parent_id_bin"] is None assert invalid_context_id_event["context_parent_id_bin"] is None
@pytest.mark.parametrize("enable_migrate_context_ids", [True])
async def test_migrate_states_context_ids(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
"""Test we can migrate old uuid context ids and ulid context ids to binary format."""
instance = await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass)
test_uuid = uuid.uuid4()
uuid_hex = test_uuid.hex
uuid_bin = test_uuid.bytes
def _insert_events():
with session_scope(hass=hass) as session:
session.add_all(
(
States(
entity_id="state.old_uuid_context_id",
last_updated_ts=1677721632.452529,
context_id=uuid_hex,
context_id_bin=None,
context_user_id=None,
context_user_id_bin=None,
context_parent_id=None,
context_parent_id_bin=None,
),
States(
entity_id="state.empty_context_id",
last_updated_ts=1677721632.552529,
context_id=None,
context_id_bin=None,
context_user_id=None,
context_user_id_bin=None,
context_parent_id=None,
context_parent_id_bin=None,
),
States(
entity_id="state.ulid_context_id",
last_updated_ts=1677721632.552529,
context_id="01ARZ3NDEKTSV4RRFFQ69G5FAV",
context_id_bin=None,
context_user_id="9400facee45711eaa9308bfd3d19e474",
context_user_id_bin=None,
context_parent_id="01ARZ3NDEKTSV4RRFFQ69G5FA2",
context_parent_id_bin=None,
),
States(
entity_id="state.invalid_context_id",
last_updated_ts=1677721632.552529,
context_id="invalid",
context_id_bin=None,
context_user_id=None,
context_user_id_bin=None,
context_parent_id=None,
context_parent_id_bin=None,
),
)
)
await instance.async_add_executor_job(_insert_events)
await async_wait_recording_done(hass)
# This is a threadsafe way to add a task to the recorder
instance.queue_task(StatesContextIDMigrationTask())
await async_recorder_block_till_done(hass)
def _object_as_dict(obj):
return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs}
def _fetch_migrated_states():
with session_scope(hass=hass) as session:
events = (
session.query(States)
.filter(
States.entity_id.in_(
[
"state.old_uuid_context_id",
"state.empty_context_id",
"state.ulid_context_id",
"state.invalid_context_id",
]
)
)
.all()
)
assert len(events) == 4
return {state.entity_id: _object_as_dict(state) for state in events}
states_by_entity_id = await instance.async_add_executor_job(_fetch_migrated_states)
old_uuid_context_id = states_by_entity_id["state.old_uuid_context_id"]
assert old_uuid_context_id["context_id"] is None
assert old_uuid_context_id["context_user_id"] is None
assert old_uuid_context_id["context_parent_id"] is None
assert old_uuid_context_id["context_id_bin"] == uuid_bin
assert old_uuid_context_id["context_user_id_bin"] is None
assert old_uuid_context_id["context_parent_id_bin"] is None
empty_context_id = states_by_entity_id["state.empty_context_id"]
assert empty_context_id["context_id"] is None
assert empty_context_id["context_user_id"] is None
assert empty_context_id["context_parent_id"] is None
assert empty_context_id["context_id_bin"] == b"\x00" * 16
assert empty_context_id["context_user_id_bin"] is None
assert empty_context_id["context_parent_id_bin"] is None
ulid_context_id = states_by_entity_id["state.ulid_context_id"]
assert ulid_context_id["context_id"] is None
assert ulid_context_id["context_user_id"] is None
assert ulid_context_id["context_parent_id"] is None
assert (
bytes_to_ulid(ulid_context_id["context_id_bin"]) == "01ARZ3NDEKTSV4RRFFQ69G5FAV"
)
assert (
ulid_context_id["context_user_id_bin"]
== b"\x94\x00\xfa\xce\xe4W\x11\xea\xa90\x8b\xfd=\x19\xe4t"
)
assert (
bytes_to_ulid(ulid_context_id["context_parent_id_bin"])
== "01ARZ3NDEKTSV4RRFFQ69G5FA2"
)
invalid_context_id = states_by_entity_id["state.invalid_context_id"]
assert invalid_context_id["context_id"] is None
assert invalid_context_id["context_user_id"] is None
assert invalid_context_id["context_parent_id"] is None
assert invalid_context_id["context_id_bin"] == b"\x00" * 16
assert invalid_context_id["context_user_id_bin"] is None
assert invalid_context_id["context_parent_id_bin"] is None
@pytest.mark.parametrize("enable_migrate_event_type_ids", [True]) @pytest.mark.parametrize("enable_migrate_event_type_ids", [True])
async def test_migrate_event_type_ids( async def test_migrate_event_type_ids(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant

View file

@ -86,6 +86,7 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
EventOrigin.local, EventOrigin.local,
time_fired=now, time_fired=now,
) )
number_of_migrations = 5
with patch.object(recorder, "db_schema", old_db_schema), patch.object( with patch.object(recorder, "db_schema", old_db_schema), patch.object(
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
@ -100,11 +101,15 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
), patch( ), patch(
CREATE_ENGINE_TARGET, new=_create_engine_test CREATE_ENGINE_TARGET, new=_create_engine_test
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids", "homeassistant.components.recorder.Recorder._migrate_events_context_ids",
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids", "homeassistant.components.recorder.Recorder._migrate_event_type_ids",
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_entity_ids", "homeassistant.components.recorder.Recorder._migrate_entity_ids",
), patch(
"homeassistant.components.recorder.Recorder._post_migrate_entity_ids"
): ):
hass = await async_test_home_assistant(asyncio.get_running_loop()) hass = await async_test_home_assistant(asyncio.get_running_loop())
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
@ -122,8 +127,10 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
await recorder.get_instance(hass).async_add_executor_job(_add_data) await recorder.get_instance(hass).async_add_executor_job(_add_data)
await hass.async_block_till_done() await hass.async_block_till_done()
await recorder.get_instance(hass).async_block_till_done()
await hass.async_stop() await hass.async_stop()
await hass.async_block_till_done()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
@ -137,7 +144,8 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
# We need to wait for all the migration tasks to complete # We need to wait for all the migration tasks to complete
# before we can check the database. # before we can check the database.
for _ in range(5): for _ in range(number_of_migrations):
await recorder.get_instance(hass).async_block_till_done()
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
def _get_test_data_from_db(): def _get_test_data_from_db():

View file

@ -1250,8 +1250,15 @@ def hass_recorder(
if enable_statistics_table_validation if enable_statistics_table_validation
else itertools.repeat(set()) else itertools.repeat(set())
) )
migrate_context_ids = ( migrate_states_context_ids = (
recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None recorder.Recorder._migrate_states_context_ids
if enable_migrate_context_ids
else None
)
migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids
if enable_migrate_context_ids
else None
) )
migrate_event_type_ids = ( migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids recorder.Recorder._migrate_event_type_ids
@ -1274,8 +1281,12 @@ def hass_recorder(
side_effect=stats_validate, side_effect=stats_validate,
autospec=True, autospec=True,
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids", "homeassistant.components.recorder.Recorder._migrate_events_context_ids",
side_effect=migrate_context_ids, side_effect=migrate_events_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
side_effect=migrate_states_context_ids,
autospec=True, autospec=True,
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids", "homeassistant.components.recorder.Recorder._migrate_event_type_ids",
@ -1354,8 +1365,15 @@ async def async_setup_recorder_instance(
if enable_statistics_table_validation if enable_statistics_table_validation
else itertools.repeat(set()) else itertools.repeat(set())
) )
migrate_context_ids = ( migrate_states_context_ids = (
recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None recorder.Recorder._migrate_states_context_ids
if enable_migrate_context_ids
else None
)
migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids
if enable_migrate_context_ids
else None
) )
migrate_event_type_ids = ( migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids recorder.Recorder._migrate_event_type_ids
@ -1378,8 +1396,12 @@ async def async_setup_recorder_instance(
side_effect=stats_validate, side_effect=stats_validate,
autospec=True, autospec=True,
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids", "homeassistant.components.recorder.Recorder._migrate_events_context_ids",
side_effect=migrate_context_ids, side_effect=migrate_events_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
side_effect=migrate_states_context_ids,
autospec=True, autospec=True,
), patch( ), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids", "homeassistant.components.recorder.Recorder._migrate_event_type_ids",