Fix race in tracking pending writes in recorder (#93414)

This commit is contained in:
J. Nick Koston 2023-05-23 14:47:31 -05:00 committed by GitHub
parent f6e7b727b0
commit f09abb0f2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -215,6 +215,7 @@ class Recorder(threading.Thread):
self.schema_version = 0 self.schema_version = 0
self._commits_without_expire = 0 self._commits_without_expire = 0
self._event_session_has_pending_writes = False
self.recorder_runs_manager = RecorderRunsManager() self.recorder_runs_manager = RecorderRunsManager()
self.states_manager = StatesManager() self.states_manager = StatesManager()
@ -322,7 +323,7 @@ class Recorder(threading.Thread):
if ( if (
self._event_listener self._event_listener
and not self._database_lock_task and not self._database_lock_task
and self._event_session_has_pending_writes() and self._event_session_has_pending_writes
): ):
self.queue_task(COMMIT_TASK) self.queue_task(COMMIT_TASK)
@ -688,6 +689,11 @@ class Recorder(threading.Thread):
# anything goes wrong in the run loop # anything goes wrong in the run loop
self._shutdown() self._shutdown()
def _add_to_session(self, session: Session, obj: object) -> None:
"""Add an object to the session."""
self._event_session_has_pending_writes = True
session.add(obj)
def _run(self) -> None: def _run(self) -> None:
"""Start processing events to save.""" """Start processing events to save."""
self.thread_id = threading.get_ident() self.thread_id = threading.get_ident()
@ -1016,11 +1022,11 @@ class Recorder(threading.Thread):
else: else:
event_types = EventTypes(event_type=event.event_type) event_types = EventTypes(event_type=event.event_type)
event_type_manager.add_pending(event_types) event_type_manager.add_pending(event_types)
session.add(event_types) self._add_to_session(session, event_types)
dbevent.event_type_rel = event_types dbevent.event_type_rel = event_types
if not event.data: if not event.data:
session.add(dbevent) self._add_to_session(session, dbevent)
return return
event_data_manager = self.event_data_manager event_data_manager = self.event_data_manager
@ -1042,10 +1048,10 @@ class Recorder(threading.Thread):
# No matching attributes found, save them in the DB # No matching attributes found, save them in the DB
dbevent_data = EventData(shared_data=shared_data, hash=hash_) dbevent_data = EventData(shared_data=shared_data, hash=hash_)
event_data_manager.add_pending(dbevent_data) event_data_manager.add_pending(dbevent_data)
session.add(dbevent_data) self._add_to_session(session, dbevent_data)
dbevent.event_data_rel = dbevent_data dbevent.event_data_rel = dbevent_data
session.add(dbevent) self._add_to_session(session, dbevent)
def _process_state_changed_event_into_session(self, event: Event) -> None: def _process_state_changed_event_into_session(self, event: Event) -> None:
"""Process a state_changed event into the session.""" """Process a state_changed event into the session."""
@ -1090,7 +1096,7 @@ class Recorder(threading.Thread):
else: else:
states_meta = StatesMeta(entity_id=entity_id) states_meta = StatesMeta(entity_id=entity_id)
states_meta_manager.add_pending(states_meta) states_meta_manager.add_pending(states_meta)
session.add(states_meta) self._add_to_session(session, states_meta)
dbstate.states_meta_rel = states_meta dbstate.states_meta_rel = states_meta
# Map the event data to the StateAttributes table # Map the event data to the StateAttributes table
@ -1115,10 +1121,10 @@ class Recorder(threading.Thread):
# No matching attributes found, save them in the DB # No matching attributes found, save them in the DB
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_) dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
state_attributes_manager.add_pending(dbstate_attributes) state_attributes_manager.add_pending(dbstate_attributes)
session.add(dbstate_attributes) self._add_to_session(session, dbstate_attributes)
dbstate.state_attributes = dbstate_attributes dbstate.state_attributes = dbstate_attributes
session.add(dbstate) self._add_to_session(session, dbstate)
def _handle_database_error(self, err: Exception) -> bool: def _handle_database_error(self, err: Exception) -> bool:
"""Handle a database error that may result in moving away the corrupt db.""" """Handle a database error that may result in moving away the corrupt db."""
@ -1130,14 +1136,9 @@ class Recorder(threading.Thread):
return True return True
return False return False
def _event_session_has_pending_writes(self) -> bool:
"""Return True if there are pending writes in the event session."""
session = self.event_session
return bool(session and (session.new or session.dirty))
def _commit_event_session_or_retry(self) -> None: def _commit_event_session_or_retry(self) -> None:
"""Commit the event session if there is work to do.""" """Commit the event session if there is work to do."""
if not self._event_session_has_pending_writes(): if not self._event_session_has_pending_writes:
return return
tries = 1 tries = 1
while tries <= self.db_max_retries: while tries <= self.db_max_retries:
@ -1163,6 +1164,7 @@ class Recorder(threading.Thread):
self._commits_without_expire += 1 self._commits_without_expire += 1
session.commit() session.commit()
self._event_session_has_pending_writes = False
# We just committed the state attributes to the database # We just committed the state attributes to the database
# and we now know the attributes_ids. We can save # and we now know the attributes_ids. We can save
# many selects for matching attributes by loading them # many selects for matching attributes by loading them
@ -1263,7 +1265,7 @@ class Recorder(threading.Thread):
async def async_block_till_done(self) -> None: async def async_block_till_done(self) -> None:
"""Async version of block_till_done.""" """Async version of block_till_done."""
if self._queue.empty() and not self._event_session_has_pending_writes(): if self._queue.empty() and not self._event_session_has_pending_writes:
return return
event = asyncio.Event() event = asyncio.Event()
self.queue_task(SynchronizeTask(event)) self.queue_task(SynchronizeTask(event))
@ -1417,6 +1419,8 @@ class Recorder(threading.Thread):
if self.event_session is None: if self.event_session is None:
return return
if self.recorder_runs_manager.active: if self.recorder_runs_manager.active:
# .end will add to the event session
self._event_session_has_pending_writes = True
self.recorder_runs_manager.end(self.event_session) self.recorder_runs_manager.end(self.event_session)
try: try:
self._commit_event_session_or_retry() self._commit_event_session_or_retry()