Make Store a generic class (#74617)
This commit is contained in:
parent
d37ad20894
commit
16900dcef1
27 changed files with 106 additions and 97 deletions
|
@ -2,14 +2,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
import inspect
|
||||
from json import JSONEncoder
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar, Union
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
|
||||
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||
|
@ -24,6 +24,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
STORAGE_SEMAPHORE = "storage_semaphore"
|
||||
|
||||
_T = TypeVar("_T", bound=Union[Mapping[str, Any], Sequence[Any]])
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_migrator(
|
||||
|
@ -66,7 +68,7 @@ async def async_migrator(
|
|||
|
||||
|
||||
@bind_hass
|
||||
class Store:
|
||||
class Store(Generic[_T]):
|
||||
"""Class to help storing data."""
|
||||
|
||||
def __init__(
|
||||
|
@ -90,7 +92,7 @@ class Store:
|
|||
self._unsub_delay_listener: CALLBACK_TYPE | None = None
|
||||
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._load_task: asyncio.Future | None = None
|
||||
self._load_task: asyncio.Future[_T | None] | None = None
|
||||
self._encoder = encoder
|
||||
self._atomic_writes = atomic_writes
|
||||
|
||||
|
@ -99,7 +101,7 @@ class Store:
|
|||
"""Return the config path."""
|
||||
return self.hass.config.path(STORAGE_DIR, self.key)
|
||||
|
||||
async def async_load(self) -> dict | list | None:
|
||||
async def async_load(self) -> _T | None:
|
||||
"""Load data.
|
||||
|
||||
If the expected version and minor version do not match the given versions, the
|
||||
|
@ -113,7 +115,7 @@ class Store:
|
|||
|
||||
return await self._load_task
|
||||
|
||||
async def _async_load(self):
|
||||
async def _async_load(self) -> _T | None:
|
||||
"""Load the data and ensure the task is removed."""
|
||||
if STORAGE_SEMAPHORE not in self.hass.data:
|
||||
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
|
||||
|
@ -178,7 +180,7 @@ class Store:
|
|||
|
||||
return stored
|
||||
|
||||
async def async_save(self, data: dict | list) -> None:
|
||||
async def async_save(self, data: _T) -> None:
|
||||
"""Save data."""
|
||||
self._data = {
|
||||
"version": self.version,
|
||||
|
@ -196,7 +198,7 @@ class Store:
|
|||
@callback
|
||||
def async_delay_save(
|
||||
self,
|
||||
data_func: Callable[[], dict | list],
|
||||
data_func: Callable[[], _T],
|
||||
delay: float = 0,
|
||||
) -> None:
|
||||
"""Save data with an optional delay."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue