Avoid creating inner tasks to load storage (#117099)

This commit is contained in:
J. Nick Koston 2024-05-08 16:41:20 -05:00 committed by GitHub
parent ead69af27c
commit 03dcede211
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 39 additions and 16 deletions

View file

@ -254,7 +254,7 @@ class Store(Generic[_T]):
self._delay_handle: asyncio.TimerHandle | None = None
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock()
self._load_task: asyncio.Future[_T | None] | None = None
self._load_future: asyncio.Future[_T | None] | None = None
self._encoder = encoder
self._atomic_writes = atomic_writes
self._read_only = read_only
@ -276,27 +276,32 @@ class Store(Generic[_T]):
Will ensure that when a call comes in while another one is in progress,
the second call will wait and return the result of the first call.
"""
if self._load_task:
return await self._load_task
if self._load_future:
return await self._load_future
load_task = self.hass.async_create_background_task(
self._async_load(), f"Storage load {self.key}", eager_start=True
)
if not load_task.done():
# Only set the load task if it didn't complete immediately
self._load_task = load_task
return await load_task
self._load_future = self.hass.loop.create_future()
try:
result = await self._async_load()
except BaseException as ex:
self._load_future.set_exception(ex)
# Ensure the future is marked as retrieved
# since if there is no concurrent call it
# will otherwise never be retrieved.
self._load_future.exception()
raise
else:
self._load_future.set_result(result)
finally:
self._load_future = None
return result
async def _async_load(self) -> _T | None:
"""Load the data and ensure the task is removed."""
if STORAGE_SEMAPHORE not in self.hass.data:
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
try:
async with self.hass.data[STORAGE_SEMAPHORE]:
return await self._async_load_data()
finally:
self._load_task = None
async with self.hass.data[STORAGE_SEMAPHORE]:
return await self._async_load_data()
async def _async_load_data(self):
"""Load the data."""

View file

@ -1159,3 +1159,21 @@ async def test_store_manager_cleanup_after_stop(
assert store_manager.async_fetch("integration1") is None
assert store_manager.async_fetch("integration2") is None
await hass.async_stop(force=True)
async def test_storage_concurrent_load(hass: HomeAssistant) -> None:
"""Test that we can load the store concurrently."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
async def _load_store():
await asyncio.sleep(0)
return "data"
with patch.object(store, "_async_load", side_effect=_load_store):
# Test that we can load the store concurrently
loads = await asyncio.gather(
store.async_load(), store.async_load(), store.async_load()
)
for load in loads:
assert load == "data"