Make Store a generic class (#74617)

This commit is contained in:
epenet 2022-07-09 22:32:57 +02:00 committed by GitHub
parent d37ad20894
commit 16900dcef1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 106 additions and 97 deletions

View file

@ -2,14 +2,14 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
from contextlib import suppress
from copy import deepcopy
import inspect
from json import JSONEncoder
import logging
import os
from typing import Any
from typing import Any, Generic, TypeVar, Union
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
@ -24,6 +24,8 @@ _LOGGER = logging.getLogger(__name__)
STORAGE_SEMAPHORE = "storage_semaphore"
_T = TypeVar("_T", bound=Union[Mapping[str, Any], Sequence[Any]])
@bind_hass
async def async_migrator(
@ -66,7 +68,7 @@ async def async_migrator(
@bind_hass
class Store:
class Store(Generic[_T]):
"""Class to help storing data."""
def __init__(
@ -90,7 +92,7 @@ class Store:
self._unsub_delay_listener: CALLBACK_TYPE | None = None
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock()
self._load_task: asyncio.Future | None = None
self._load_task: asyncio.Future[_T | None] | None = None
self._encoder = encoder
self._atomic_writes = atomic_writes
@ -99,7 +101,7 @@ class Store:
"""Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self) -> dict | list | None:
async def async_load(self) -> _T | None:
"""Load data.
If the expected version and minor version do not match the given versions, the
@ -113,7 +115,7 @@ class Store:
return await self._load_task
async def _async_load(self):
async def _async_load(self) -> _T | None:
"""Load the data and ensure the task is removed."""
if STORAGE_SEMAPHORE not in self.hass.data:
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
@ -178,7 +180,7 @@ class Store:
return stored
async def async_save(self, data: dict | list) -> None:
async def async_save(self, data: _T) -> None:
"""Save data."""
self._data = {
"version": self.version,
@ -196,7 +198,7 @@ class Store:
@callback
def async_delay_save(
self,
data_func: Callable[[], dict | list],
data_func: Callable[[], _T],
delay: float = 0,
) -> None:
"""Save data with an optional delay."""