Avoid creating inner tasks to load storage (#117099)
This commit is contained in:
parent
ead69af27c
commit
03dcede211
2 changed files with 39 additions and 16 deletions
|
@ -254,7 +254,7 @@ class Store(Generic[_T]):
|
|||
self._delay_handle: asyncio.TimerHandle | None = None
|
||||
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._load_task: asyncio.Future[_T | None] | None = None
|
||||
self._load_future: asyncio.Future[_T | None] | None = None
|
||||
self._encoder = encoder
|
||||
self._atomic_writes = atomic_writes
|
||||
self._read_only = read_only
|
||||
|
@ -276,27 +276,32 @@ class Store(Generic[_T]):
|
|||
Will ensure that when a call comes in while another one is in progress,
|
||||
the second call will wait and return the result of the first call.
|
||||
"""
|
||||
if self._load_task:
|
||||
return await self._load_task
|
||||
if self._load_future:
|
||||
return await self._load_future
|
||||
|
||||
load_task = self.hass.async_create_background_task(
|
||||
self._async_load(), f"Storage load {self.key}", eager_start=True
|
||||
)
|
||||
if not load_task.done():
|
||||
# Only set the load task if it didn't complete immediately
|
||||
self._load_task = load_task
|
||||
return await load_task
|
||||
self._load_future = self.hass.loop.create_future()
|
||||
try:
|
||||
result = await self._async_load()
|
||||
except BaseException as ex:
|
||||
self._load_future.set_exception(ex)
|
||||
# Ensure the future is marked as retrieved
|
||||
# since if there is no concurrent call it
|
||||
# will otherwise never be retrieved.
|
||||
self._load_future.exception()
|
||||
raise
|
||||
else:
|
||||
self._load_future.set_result(result)
|
||||
finally:
|
||||
self._load_future = None
|
||||
|
||||
return result
|
||||
|
||||
async def _async_load(self) -> _T | None:
|
||||
"""Load the data and ensure the task is removed."""
|
||||
if STORAGE_SEMAPHORE not in self.hass.data:
|
||||
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
|
||||
|
||||
try:
|
||||
async with self.hass.data[STORAGE_SEMAPHORE]:
|
||||
return await self._async_load_data()
|
||||
finally:
|
||||
self._load_task = None
|
||||
async with self.hass.data[STORAGE_SEMAPHORE]:
|
||||
return await self._async_load_data()
|
||||
|
||||
async def _async_load_data(self):
|
||||
"""Load the data."""
|
||||
|
|
|
@ -1159,3 +1159,21 @@ async def test_store_manager_cleanup_after_stop(
|
|||
assert store_manager.async_fetch("integration1") is None
|
||||
assert store_manager.async_fetch("integration2") is None
|
||||
await hass.async_stop(force=True)
|
||||
|
||||
|
||||
async def test_storage_concurrent_load(hass: HomeAssistant) -> None:
|
||||
"""Test that we can load the store concurrently."""
|
||||
|
||||
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
|
||||
|
||||
async def _load_store():
|
||||
await asyncio.sleep(0)
|
||||
return "data"
|
||||
|
||||
with patch.object(store, "_async_load", side_effect=_load_store):
|
||||
# Test that we can load the store concurrently
|
||||
loads = await asyncio.gather(
|
||||
store.async_load(), store.async_load(), store.async_load()
|
||||
)
|
||||
for load in loads:
|
||||
assert load == "data"
|
||||
|
|
Loading…
Add table
Reference in a new issue