diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 43915c0187b..67d3bff3b2a 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -215,6 +215,7 @@ class Recorder(threading.Thread): self.schema_version = 0 self._commits_without_expire = 0 + self._event_session_has_pending_writes = False self.recorder_runs_manager = RecorderRunsManager() self.states_manager = StatesManager() @@ -322,7 +323,7 @@ class Recorder(threading.Thread): if ( self._event_listener and not self._database_lock_task - and self._event_session_has_pending_writes() + and self._event_session_has_pending_writes ): self.queue_task(COMMIT_TASK) @@ -688,6 +689,11 @@ class Recorder(threading.Thread): # anything goes wrong in the run loop self._shutdown() + def _add_to_session(self, session: Session, obj: object) -> None: + """Add an object to the session.""" + self._event_session_has_pending_writes = True + session.add(obj) + def _run(self) -> None: """Start processing events to save.""" self.thread_id = threading.get_ident() @@ -1016,11 +1022,11 @@ class Recorder(threading.Thread): else: event_types = EventTypes(event_type=event.event_type) event_type_manager.add_pending(event_types) - session.add(event_types) + self._add_to_session(session, event_types) dbevent.event_type_rel = event_types if not event.data: - session.add(dbevent) + self._add_to_session(session, dbevent) return event_data_manager = self.event_data_manager @@ -1042,10 +1048,10 @@ class Recorder(threading.Thread): # No matching attributes found, save them in the DB dbevent_data = EventData(shared_data=shared_data, hash=hash_) event_data_manager.add_pending(dbevent_data) - session.add(dbevent_data) + self._add_to_session(session, dbevent_data) dbevent.event_data_rel = dbevent_data - session.add(dbevent) + self._add_to_session(session, dbevent) def _process_state_changed_event_into_session(self, event: Event) -> None: """Process a state_changed event into the session.""" @@ -1090,7 +1096,7 @@ class Recorder(threading.Thread): else: states_meta = StatesMeta(entity_id=entity_id) states_meta_manager.add_pending(states_meta) - session.add(states_meta) + self._add_to_session(session, states_meta) dbstate.states_meta_rel = states_meta # Map the event data to the StateAttributes table @@ -1115,10 +1121,10 @@ class Recorder(threading.Thread): # No matching attributes found, save them in the DB dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_) state_attributes_manager.add_pending(dbstate_attributes) - session.add(dbstate_attributes) + self._add_to_session(session, dbstate_attributes) dbstate.state_attributes = dbstate_attributes - session.add(dbstate) + self._add_to_session(session, dbstate) def _handle_database_error(self, err: Exception) -> bool: """Handle a database error that may result in moving away the corrupt db.""" @@ -1130,14 +1136,9 @@ class Recorder(threading.Thread): return True return False - def _event_session_has_pending_writes(self) -> bool: - """Return True if there are pending writes in the event session.""" - session = self.event_session - return bool(session and (session.new or session.dirty)) - def _commit_event_session_or_retry(self) -> None: """Commit the event session if there is work to do.""" - if not self._event_session_has_pending_writes(): + if not self._event_session_has_pending_writes: return tries = 1 while tries <= self.db_max_retries: @@ -1163,6 +1164,7 @@ class Recorder(threading.Thread): self._commits_without_expire += 1 session.commit() + self._event_session_has_pending_writes = False # We just committed the state attributes to the database # and we now know the attributes_ids. We can save # many selects for matching attributes by loading them @@ -1263,7 +1265,7 @@ class Recorder(threading.Thread): async def async_block_till_done(self) -> None: """Async version of block_till_done.""" - if self._queue.empty() and not self._event_session_has_pending_writes(): + if self._queue.empty() and not self._event_session_has_pending_writes: return event = asyncio.Event() self.queue_task(SynchronizeTask(event)) @@ -1417,6 +1419,8 @@ class Recorder(threading.Thread): if self.event_session is None: return if self.recorder_runs_manager.active: + # .end will add to the event session + self._event_session_has_pending_writes = True self.recorder_runs_manager.end(self.event_session) try: self._commit_event_session_or_retry()