From 03dcede211ca6103cb4f83517af4a98d33795e2f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 8 May 2024 16:41:20 -0500 Subject: [PATCH] Avoid creating inner tasks to load storage (#117099) --- homeassistant/helpers/storage.py | 37 ++++++++++++++++++-------------- tests/helpers/test_storage.py | 18 ++++++++++++++++ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 41c8cc32fd0..43540578429 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -254,7 +254,7 @@ class Store(Generic[_T]): self._delay_handle: asyncio.TimerHandle | None = None self._unsub_final_write_listener: CALLBACK_TYPE | None = None self._write_lock = asyncio.Lock() - self._load_task: asyncio.Future[_T | None] | None = None + self._load_future: asyncio.Future[_T | None] | None = None self._encoder = encoder self._atomic_writes = atomic_writes self._read_only = read_only @@ -276,27 +276,32 @@ class Store(Generic[_T]): Will ensure that when a call comes in while another one is in progress, the second call will wait and return the result of the first call. """ - if self._load_task: - return await self._load_task + if self._load_future: + return await self._load_future - load_task = self.hass.async_create_background_task( - self._async_load(), f"Storage load {self.key}", eager_start=True - ) - if not load_task.done(): - # Only set the load task if it didn't complete immediately - self._load_task = load_task - return await load_task + self._load_future = self.hass.loop.create_future() + try: + result = await self._async_load() + except BaseException as ex: + self._load_future.set_exception(ex) + # Ensure the future is marked as retrieved + # since if there is no concurrent call it + # will otherwise never be retrieved. + self._load_future.exception() + raise + else: + self._load_future.set_result(result) + finally: + self._load_future = None + + return result async def _async_load(self) -> _T | None: """Load the data and ensure the task is removed.""" if STORAGE_SEMAPHORE not in self.hass.data: self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY) - - try: - async with self.hass.data[STORAGE_SEMAPHORE]: - return await self._async_load_data() - finally: - self._load_task = None + async with self.hass.data[STORAGE_SEMAPHORE]: + return await self._async_load_data() async def _async_load_data(self): """Load the data.""" diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index 12dc56db85d..577e81d1a44 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -1159,3 +1159,21 @@ async def test_store_manager_cleanup_after_stop( assert store_manager.async_fetch("integration1") is None assert store_manager.async_fetch("integration2") is None await hass.async_stop(force=True) + + +async def test_storage_concurrent_load(hass: HomeAssistant) -> None: + """Test that we can load the store concurrently.""" + + store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) + + async def _load_store(): + await asyncio.sleep(0) + return "data" + + with patch.object(store, "_async_load", side_effect=_load_store): + # Test that we can load the store concurrently + loads = await asyncio.gather( + store.async_load(), store.async_load(), store.async_load() + ) + for load in loads: + assert load == "data"