Make the recorder LRU adjustment threadsafe (#88443)

J. Nick Koston 2023-02-19 12:30:08 -06:00 committed by GitHub
parent 0d832c0a5a
commit a9731a7b26
3 changed files with 20 additions and 7 deletions


@@ -78,6 +78,7 @@ from .pool import POOL_SIZE, MutexPool, RecorderPool
 from .queries import find_shared_attributes_id, find_shared_data_id
 from .run_history import RunHistory
 from .tasks import (
+    AdjustLRUSizeTask,
     AdjustStatisticsTask,
     ChangeStatisticsUnitTask,
     ClearStatisticsTask,
@@ -131,6 +132,7 @@ SHUTDOWN_TASK = object()
 COMMIT_TASK = CommitTask()
 KEEP_ALIVE_TASK = KeepAliveTask()
 WAIT_TASK = WaitTask()
+ADJUST_LRU_SIZE_TASK = AdjustLRUSizeTask()
 
 DB_LOCK_TIMEOUT = 30
 DB_LOCK_QUEUE_CHECK_TIMEOUT = 1
@@ -411,7 +413,6 @@ class Recorder(threading.Thread):
     @callback
     def _async_hass_started(self, hass: HomeAssistant) -> None:
         """Notify that hass has started."""
-        self.async_adjust_lru()
         self._hass_started.set_result(None)
 
     @callback
@@ -478,20 +479,20 @@ class Recorder(threading.Thread):
     @callback
     def _async_five_minute_tasks(self, now: datetime) -> None:
         """Run tasks every five minutes."""
-        self.async_adjust_lru()
+        self.queue_task(ADJUST_LRU_SIZE_TASK)
         self.async_periodic_statistics()
 
-    @callback
-    def async_adjust_lru(self) -> None:
+    def _adjust_lru_size(self) -> None:
         """Trigger the LRU adjustment.
 
         If the number of entities has increased, increase the size of the LRU
         cache to avoid thrashing.
         """
-        current_size = self._state_attributes_ids.get_size()
+        state_attributes_lru = self._state_attributes_ids
+        current_size = state_attributes_lru.get_size()
         new_size = self.hass.states.async_entity_ids_count() * 2
         if new_size > current_size:
-            self._state_attributes_ids.set_size(new_size)
+            state_attributes_lru.set_size(new_size)
 
     @callback
     def async_periodic_statistics(self) -> None:
@@ -677,6 +678,7 @@ class Recorder(threading.Thread):
             self._schedule_compile_missing_statistics(session)
 
         _LOGGER.debug("Recorder processing the queue")
+        self._adjust_lru_size()
         self.hass.add_job(self._async_set_recorder_ready_migration_done)
         self._run_event_loop()
 
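The net effect of the core.py changes: the event loop no longer resizes the state-attributes LRU directly. The five-minute timer only queues ADJUST_LRU_SIZE_TASK, and the resize runs in _adjust_lru_size() on the recorder thread, the same thread that uses the cache, which is what makes the adjustment threadsafe. Below is a minimal, generic sketch of that handoff pattern; the names (Worker, AdjustSizeTask) and the plain queue.Queue are illustrative assumptions, not the Home Assistant API.

# Generic sketch of the queue-handoff pattern (hypothetical names, not
# the Home Assistant API): callers on any thread enqueue a task, and only
# the worker thread that owns the cache ever resizes it.
import queue
import threading
from dataclasses import dataclass


class Worker(threading.Thread):
    def __init__(self) -> None:
        super().__init__()
        self._tasks: queue.Queue = queue.Queue()
        self._cache_size = 16  # state owned exclusively by the worker thread

    def queue_task(self, task: "AdjustSizeTask | None") -> None:
        """Safe to call from any thread (e.g. the event loop)."""
        self._tasks.put(task)

    def _adjust_size(self, new_size: int) -> None:
        """Only ever runs on the worker thread."""
        if new_size > self._cache_size:
            self._cache_size = new_size

    def run(self) -> None:
        # Drain the queue; a None sentinel stops the thread.
        while (task := self._tasks.get()) is not None:
            task.run(self)


@dataclass
class AdjustSizeTask:
    new_size: int

    def run(self, instance: Worker) -> None:
        instance._adjust_size(self.new_size)


worker = Worker()
worker.start()
worker.queue_task(AdjustSizeTask(new_size=128))  # called from any thread
worker.queue_task(None)  # sentinel: shut the worker down
worker.join()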


@@ -328,3 +328,14 @@ class StatisticsTimestampMigrationCleanupTask(RecorderTask):
         if not statistics.cleanup_statistics_timestamp_migration(instance):
             # Schedule a new statistics migration task if this one didn't finish
             instance.queue_task(StatisticsTimestampMigrationCleanupTask())
+
+
+@dataclass
+class AdjustLRUSizeTask(RecorderTask):
+    """An object to insert into the recorder queue to adjust the LRU size."""
+
+    commit_before = False
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task to adjust the size."""
+        instance._adjust_lru_size()  # pylint: disable=[protected-access]
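AdjustLRUSizeTask has the same shape as the other recorder tasks: a dataclass whose run() hook the recorder thread calls while draining its queue. commit_before = False indicates the queue handler does not need to flush pending database writes before running it, which fits a task that only touches the in-memory cache. A small isolation check is sketched below; it is not part of this PR and relies only on the contract visible in the diff, namely that run() forwards to the instance's _adjust_lru_size().

# Sketch of exercising the new task against a mock recorder instance
# (illustrative only, not a test from this PR).
from unittest.mock import MagicMock

from homeassistant.components.recorder.tasks import AdjustLRUSizeTask

task = AdjustLRUSizeTask()
instance = MagicMock()
task.run(instance)
instance._adjust_lru_size.assert_called_once_with()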


@@ -2087,6 +2087,6 @@ async def test_lru_increases_with_many_entities(
         hass.states, "async_entity_ids_count", return_value=mock_entity_count
     ):
         async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10))
-        await hass.async_block_till_done()
+        await async_wait_recording_done(hass)
 
     assert recorder_mock._state_attributes_ids.get_size() == mock_entity_count * 2
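The test tweak follows from the threading change: hass.async_block_till_done() only waits for work scheduled on the event loop, but the LRU resize now happens when the recorder thread processes ADJUST_LRU_SIZE_TASK, so the test waits with async_wait_recording_done(hass), which also waits for the recorder queue to be drained. A generic analogue of "wait until the worker has processed everything queued so far" is sketched below with a plain queue.Queue; this is an assumption for illustration, not the Home Assistant test helper.

# Generic analogue of waiting for a queue-based worker to catch up:
# Queue.join() blocks until every queued item has been marked done.
import queue
import threading

tasks: queue.Queue = queue.Queue()

def worker() -> None:
    while True:
        task = tasks.get()
        try:
            if task is not None:
                task()  # e.g. resize the LRU cache
        finally:
            tasks.task_done()
        if task is None:
            break

threading.Thread(target=worker, daemon=True).start()
tasks.put(lambda: print("resize LRU"))
tasks.join()  # returns only after the queued work has finished
tasks.put(None)  # sentinel: stop the worker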