Fix race in tracking pending writes in recorder (#93414)
This commit is contained in:
parent
f6e7b727b0
commit
f09abb0f2e
1 changed files with 19 additions and 15 deletions
|
@ -215,6 +215,7 @@ class Recorder(threading.Thread):
|
||||||
|
|
||||||
self.schema_version = 0
|
self.schema_version = 0
|
||||||
self._commits_without_expire = 0
|
self._commits_without_expire = 0
|
||||||
|
self._event_session_has_pending_writes = False
|
||||||
|
|
||||||
self.recorder_runs_manager = RecorderRunsManager()
|
self.recorder_runs_manager = RecorderRunsManager()
|
||||||
self.states_manager = StatesManager()
|
self.states_manager = StatesManager()
|
||||||
|
@ -322,7 +323,7 @@ class Recorder(threading.Thread):
|
||||||
if (
|
if (
|
||||||
self._event_listener
|
self._event_listener
|
||||||
and not self._database_lock_task
|
and not self._database_lock_task
|
||||||
and self._event_session_has_pending_writes()
|
and self._event_session_has_pending_writes
|
||||||
):
|
):
|
||||||
self.queue_task(COMMIT_TASK)
|
self.queue_task(COMMIT_TASK)
|
||||||
|
|
||||||
|
@ -688,6 +689,11 @@ class Recorder(threading.Thread):
|
||||||
# anything goes wrong in the run loop
|
# anything goes wrong in the run loop
|
||||||
self._shutdown()
|
self._shutdown()
|
||||||
|
|
||||||
|
def _add_to_session(self, session: Session, obj: object) -> None:
|
||||||
|
"""Add an object to the session."""
|
||||||
|
self._event_session_has_pending_writes = True
|
||||||
|
session.add(obj)
|
||||||
|
|
||||||
def _run(self) -> None:
|
def _run(self) -> None:
|
||||||
"""Start processing events to save."""
|
"""Start processing events to save."""
|
||||||
self.thread_id = threading.get_ident()
|
self.thread_id = threading.get_ident()
|
||||||
|
@ -1016,11 +1022,11 @@ class Recorder(threading.Thread):
|
||||||
else:
|
else:
|
||||||
event_types = EventTypes(event_type=event.event_type)
|
event_types = EventTypes(event_type=event.event_type)
|
||||||
event_type_manager.add_pending(event_types)
|
event_type_manager.add_pending(event_types)
|
||||||
session.add(event_types)
|
self._add_to_session(session, event_types)
|
||||||
dbevent.event_type_rel = event_types
|
dbevent.event_type_rel = event_types
|
||||||
|
|
||||||
if not event.data:
|
if not event.data:
|
||||||
session.add(dbevent)
|
self._add_to_session(session, dbevent)
|
||||||
return
|
return
|
||||||
|
|
||||||
event_data_manager = self.event_data_manager
|
event_data_manager = self.event_data_manager
|
||||||
|
@ -1042,10 +1048,10 @@ class Recorder(threading.Thread):
|
||||||
# No matching attributes found, save them in the DB
|
# No matching attributes found, save them in the DB
|
||||||
dbevent_data = EventData(shared_data=shared_data, hash=hash_)
|
dbevent_data = EventData(shared_data=shared_data, hash=hash_)
|
||||||
event_data_manager.add_pending(dbevent_data)
|
event_data_manager.add_pending(dbevent_data)
|
||||||
session.add(dbevent_data)
|
self._add_to_session(session, dbevent_data)
|
||||||
dbevent.event_data_rel = dbevent_data
|
dbevent.event_data_rel = dbevent_data
|
||||||
|
|
||||||
session.add(dbevent)
|
self._add_to_session(session, dbevent)
|
||||||
|
|
||||||
def _process_state_changed_event_into_session(self, event: Event) -> None:
|
def _process_state_changed_event_into_session(self, event: Event) -> None:
|
||||||
"""Process a state_changed event into the session."""
|
"""Process a state_changed event into the session."""
|
||||||
|
@ -1090,7 +1096,7 @@ class Recorder(threading.Thread):
|
||||||
else:
|
else:
|
||||||
states_meta = StatesMeta(entity_id=entity_id)
|
states_meta = StatesMeta(entity_id=entity_id)
|
||||||
states_meta_manager.add_pending(states_meta)
|
states_meta_manager.add_pending(states_meta)
|
||||||
session.add(states_meta)
|
self._add_to_session(session, states_meta)
|
||||||
dbstate.states_meta_rel = states_meta
|
dbstate.states_meta_rel = states_meta
|
||||||
|
|
||||||
# Map the event data to the StateAttributes table
|
# Map the event data to the StateAttributes table
|
||||||
|
@ -1115,10 +1121,10 @@ class Recorder(threading.Thread):
|
||||||
# No matching attributes found, save them in the DB
|
# No matching attributes found, save them in the DB
|
||||||
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
|
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
|
||||||
state_attributes_manager.add_pending(dbstate_attributes)
|
state_attributes_manager.add_pending(dbstate_attributes)
|
||||||
session.add(dbstate_attributes)
|
self._add_to_session(session, dbstate_attributes)
|
||||||
dbstate.state_attributes = dbstate_attributes
|
dbstate.state_attributes = dbstate_attributes
|
||||||
|
|
||||||
session.add(dbstate)
|
self._add_to_session(session, dbstate)
|
||||||
|
|
||||||
def _handle_database_error(self, err: Exception) -> bool:
|
def _handle_database_error(self, err: Exception) -> bool:
|
||||||
"""Handle a database error that may result in moving away the corrupt db."""
|
"""Handle a database error that may result in moving away the corrupt db."""
|
||||||
|
@ -1130,14 +1136,9 @@ class Recorder(threading.Thread):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _event_session_has_pending_writes(self) -> bool:
|
|
||||||
"""Return True if there are pending writes in the event session."""
|
|
||||||
session = self.event_session
|
|
||||||
return bool(session and (session.new or session.dirty))
|
|
||||||
|
|
||||||
def _commit_event_session_or_retry(self) -> None:
|
def _commit_event_session_or_retry(self) -> None:
|
||||||
"""Commit the event session if there is work to do."""
|
"""Commit the event session if there is work to do."""
|
||||||
if not self._event_session_has_pending_writes():
|
if not self._event_session_has_pending_writes:
|
||||||
return
|
return
|
||||||
tries = 1
|
tries = 1
|
||||||
while tries <= self.db_max_retries:
|
while tries <= self.db_max_retries:
|
||||||
|
@ -1163,6 +1164,7 @@ class Recorder(threading.Thread):
|
||||||
self._commits_without_expire += 1
|
self._commits_without_expire += 1
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
self._event_session_has_pending_writes = False
|
||||||
# We just committed the state attributes to the database
|
# We just committed the state attributes to the database
|
||||||
# and we now know the attributes_ids. We can save
|
# and we now know the attributes_ids. We can save
|
||||||
# many selects for matching attributes by loading them
|
# many selects for matching attributes by loading them
|
||||||
|
@ -1263,7 +1265,7 @@ class Recorder(threading.Thread):
|
||||||
|
|
||||||
async def async_block_till_done(self) -> None:
|
async def async_block_till_done(self) -> None:
|
||||||
"""Async version of block_till_done."""
|
"""Async version of block_till_done."""
|
||||||
if self._queue.empty() and not self._event_session_has_pending_writes():
|
if self._queue.empty() and not self._event_session_has_pending_writes:
|
||||||
return
|
return
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
self.queue_task(SynchronizeTask(event))
|
self.queue_task(SynchronizeTask(event))
|
||||||
|
@ -1417,6 +1419,8 @@ class Recorder(threading.Thread):
|
||||||
if self.event_session is None:
|
if self.event_session is None:
|
||||||
return
|
return
|
||||||
if self.recorder_runs_manager.active:
|
if self.recorder_runs_manager.active:
|
||||||
|
# .end will add to the event session
|
||||||
|
self._event_session_has_pending_writes = True
|
||||||
self.recorder_runs_manager.end(self.event_session)
|
self.recorder_runs_manager.end(self.event_session)
|
||||||
try:
|
try:
|
||||||
self._commit_event_session_or_retry()
|
self._commit_event_session_or_retry()
|
||||||
|
|
Loading…
Add table
Reference in a new issue