Make Store a generic class (#74617)
This commit is contained in:
parent
d37ad20894
commit
16900dcef1
27 changed files with 106 additions and 97 deletions
|
@ -46,7 +46,7 @@ class AuthStore:
|
|||
self._users: dict[str, models.User] | None = None
|
||||
self._groups: dict[str, models.Group] | None = None
|
||||
self._perm_lookup: PermissionLookup | None = None
|
||||
self._store = Store(
|
||||
self._store = Store[dict[str, list[dict[str, Any]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._lock = asyncio.Lock()
|
||||
|
@ -483,9 +483,10 @@ class AuthStore:
|
|||
jwt_key=rt_dict["jwt_key"],
|
||||
last_used_at=last_used_at,
|
||||
last_used_ip=rt_dict.get("last_used_ip"),
|
||||
credential=credentials.get(rt_dict.get("credential_id")),
|
||||
version=rt_dict.get("version"),
|
||||
)
|
||||
if "credential_id" in rt_dict:
|
||||
token.credential = credentials.get(rt_dict["credential_id"])
|
||||
users[rt_dict["user_id"]].refresh_tokens[token.id] = token
|
||||
|
||||
self._groups = groups
|
||||
|
|
|
@ -7,7 +7,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
|
@ -100,7 +100,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
"""Initialize the user data store."""
|
||||
super().__init__(hass, config)
|
||||
self._user_settings: _UsersDict | None = None
|
||||
self._user_store = Store(
|
||||
self._user_store = Store[dict[str, dict[str, Any]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._include = config.get(CONF_INCLUDE, [])
|
||||
|
@ -119,10 +119,8 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
if self._user_settings is not None:
|
||||
return
|
||||
|
||||
if (data := await self._user_store.async_load()) is None or not isinstance(
|
||||
data, dict
|
||||
):
|
||||
data = {STORAGE_USERS: {}}
|
||||
if (data := await self._user_store.async_load()) is None:
|
||||
data = cast(dict[str, dict[str, Any]], {STORAGE_USERS: {}})
|
||||
|
||||
self._user_settings = {
|
||||
user_id: NotifySetting(**setting)
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -77,7 +77,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
"""Initialize the user data store."""
|
||||
super().__init__(hass, config)
|
||||
self._users: dict[str, str] | None = None
|
||||
self._user_store = Store(
|
||||
self._user_store = Store[dict[str, dict[str, str]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
@ -93,16 +93,14 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
if self._users is not None:
|
||||
return
|
||||
|
||||
if (data := await self._user_store.async_load()) is None or not isinstance(
|
||||
data, dict
|
||||
):
|
||||
data = {STORAGE_USERS: {}}
|
||||
if (data := await self._user_store.async_load()) is None:
|
||||
data = cast(dict[str, dict[str, str]], {STORAGE_USERS: {}})
|
||||
|
||||
self._users = data.get(STORAGE_USERS, {})
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
await self._user_store.async_save({STORAGE_USERS: self._users})
|
||||
await self._user_store.async_save({STORAGE_USERS: self._users or {}})
|
||||
|
||||
def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
|
||||
"""Create a ota_secret for user."""
|
||||
|
|
|
@ -61,10 +61,10 @@ class Data:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = Store(
|
||||
self._store = Store[dict[str, list[dict[str, str]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, Any] | None = None
|
||||
self._data: dict[str, list[dict[str, str]]] | None = None
|
||||
# Legacy mode will allow usernames to start/end with whitespace
|
||||
# and will compare usernames case-insensitive.
|
||||
# Remove in 2020 or when we launch 1.0.
|
||||
|
@ -80,10 +80,8 @@ class Data:
|
|||
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
if (data := await self._store.async_load()) is None or not isinstance(
|
||||
data, dict
|
||||
):
|
||||
data = {"users": []}
|
||||
if (data := await self._store.async_load()) is None:
|
||||
data = cast(dict[str, list[dict[str, str]]], {"users": []})
|
||||
|
||||
seen: set[str] = set()
|
||||
|
||||
|
@ -123,7 +121,8 @@ class Data:
|
|||
@property
|
||||
def users(self) -> list[dict[str, str]]:
|
||||
"""Return users."""
|
||||
return self._data["users"] # type: ignore[index,no-any-return]
|
||||
assert self._data is not None
|
||||
return self._data["users"]
|
||||
|
||||
def validate_login(self, username: str, password: str) -> None:
|
||||
"""Validate a username and password.
|
||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
from datetime import timedelta
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientError, ClientSession
|
||||
import async_timeout
|
||||
|
@ -167,8 +167,8 @@ async def _configure_almond_for_ha(
|
|||
return
|
||||
|
||||
_LOGGER.debug("Configuring Almond to connect to Home Assistant at %s", hass_url)
|
||||
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
data = cast(Optional[dict], await store.async_load())
|
||||
store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
data = await store.async_load()
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
|
|
|
@ -64,7 +64,7 @@ async def async_setup_entry(
|
|||
"""Set up the Ambiclimate device from config entry."""
|
||||
config = entry.data
|
||||
websession = async_get_clientsession(hass)
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
token_info = await store.async_load()
|
||||
|
||||
oauth = ambiclimate.AmbiclimateOAuth(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Analytics helper class for the analytics integration."""
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
|
@ -66,12 +66,12 @@ class Analytics:
|
|||
"""Initialize the Analytics class."""
|
||||
self.hass: HomeAssistant = hass
|
||||
self.session = async_get_clientsession(hass)
|
||||
self._data: dict = {
|
||||
self._data: dict[str, Any] = {
|
||||
ATTR_PREFERENCES: {},
|
||||
ATTR_ONBOARDED: False,
|
||||
ATTR_UUID: None,
|
||||
}
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@property
|
||||
def preferences(self) -> dict:
|
||||
|
@ -109,7 +109,7 @@ class Analytics:
|
|||
|
||||
async def load(self) -> None:
|
||||
"""Load preferences."""
|
||||
stored = cast(dict, await self._store.async_load())
|
||||
stored = await self._store.async_load()
|
||||
if stored:
|
||||
self._data = stored
|
||||
|
||||
|
|
|
@ -36,14 +36,14 @@ class CameraPreferences:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize camera prefs."""
|
||||
self._hass = hass
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._store = Store[dict[str, dict[str, bool]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY
|
||||
)
|
||||
self._prefs: dict[str, dict[str, bool]] | None = None
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Finish initializing the preferences."""
|
||||
if (prefs := await self._store.async_load()) is None or not isinstance(
|
||||
prefs, dict
|
||||
):
|
||||
if (prefs := await self._store.async_load()) is None:
|
||||
prefs = {}
|
||||
|
||||
self._prefs = prefs
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import Counter
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal, Optional, TypedDict, Union, cast
|
||||
from typing import Literal, TypedDict, Union
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -263,13 +263,15 @@ class EnergyManager:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize energy manager."""
|
||||
self._hass = hass
|
||||
self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._store = storage.Store[EnergyPreferences](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY
|
||||
)
|
||||
self.data: EnergyPreferences | None = None
|
||||
self._update_listeners: list[Callable[[], Awaitable]] = []
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the energy integration."""
|
||||
self.data = cast(Optional[EnergyPreferences], await self._store.async_load())
|
||||
self.data = await self._store.async_load()
|
||||
|
||||
@staticmethod
|
||||
def default_preferences() -> EnergyPreferences:
|
||||
|
@ -294,7 +296,7 @@ class EnergyManager:
|
|||
data[key] = update[key] # type: ignore[literal-required]
|
||||
|
||||
self.data = data
|
||||
self._store.async_delay_save(lambda: cast(dict, self.data), 60)
|
||||
self._store.async_delay_save(lambda: data, 60)
|
||||
|
||||
if not self._update_listeners:
|
||||
return
|
||||
|
|
|
@ -533,12 +533,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
|
|||
if not await hassio.is_connected():
|
||||
_LOGGER.warning("Not connected with the supervisor / system too busy!")
|
||||
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
store = Store[dict[str, str]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if (data := await store.async_load()) is None:
|
||||
data = {}
|
||||
|
||||
assert isinstance(data, dict)
|
||||
|
||||
refresh_token = None
|
||||
if "hassio_user" in data:
|
||||
user = await hass.auth.async_get_user(data["hassio_user"])
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.storage import Store
|
||||
|
@ -46,7 +46,9 @@ class EntityMapStorage:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Create a new entity map store."""
|
||||
self.hass = hass
|
||||
self.store = Store(hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY)
|
||||
self.store = Store[StorageLayout](
|
||||
hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY
|
||||
)
|
||||
self.storage_data: dict[str, Pairing] = {}
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
|
@ -55,8 +57,7 @@ class EntityMapStorage:
|
|||
# There is no cached data about HomeKit devices yet
|
||||
return
|
||||
|
||||
storage = cast(StorageLayout, raw_storage)
|
||||
self.storage_data = storage.get("pairings", {})
|
||||
self.storage_data = raw_storage.get("pairings", {})
|
||||
|
||||
def get_map(self, homekit_id: str) -> Pairing | None:
|
||||
"""Get a pairing cache item."""
|
||||
|
@ -87,6 +88,6 @@ class EntityMapStorage:
|
|||
self.store.async_delay_save(self._data_to_save, ENTITY_MAP_SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> dict[str, Any]:
|
||||
def _data_to_save(self) -> StorageLayout:
|
||||
"""Return data of entity map to store in a file."""
|
||||
return {"pairings": self.storage_data}
|
||||
return StorageLayout(pairings=self.storage_data)
|
||||
|
|
|
@ -7,7 +7,7 @@ import logging
|
|||
import os
|
||||
import ssl
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Final, Optional, TypedDict, Union, cast
|
||||
from typing import Any, Final, TypedDict, Union, cast
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import StrOrURL
|
||||
|
@ -125,10 +125,10 @@ class ConfData(TypedDict, total=False):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_last_config(hass: HomeAssistant) -> dict | None:
|
||||
async def async_get_last_config(hass: HomeAssistant) -> dict[str, Any] | None:
|
||||
"""Return the last known working config."""
|
||||
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
return cast(Optional[dict], await store.async_load())
|
||||
store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
return await store.async_load()
|
||||
|
||||
|
||||
class ApiConfig:
|
||||
|
@ -475,7 +475,9 @@ async def start_http_server_and_save_config(
|
|||
await server.start()
|
||||
|
||||
# If we are set up successful, we store the HTTP settings for safe mode.
|
||||
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
store: storage.Store[dict[str, Any]] = storage.Store(
|
||||
hass, STORAGE_VERSION, STORAGE_KEY
|
||||
)
|
||||
|
||||
if CONF_TRUSTED_PROXIES in conf:
|
||||
conf[CONF_TRUSTED_PROXIES] = [
|
||||
|
|
|
@ -6,7 +6,7 @@ from datetime import timedelta
|
|||
from ipaddress import ip_address
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Final
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
|
@ -118,8 +118,8 @@ def async_user_not_allowed_do_auth(
|
|||
|
||||
async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||
"""Create auth middleware for the app."""
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if (data := await store.async_load()) is None or not isinstance(data, dict):
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if (data := await store.async_load()) is None:
|
||||
data = {}
|
||||
|
||||
refresh_token = None
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Integrates Native Apps to Home Assistant."""
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components import cloud, notify as hass_notify
|
||||
from homeassistant.components.webhook import (
|
||||
|
@ -38,7 +39,7 @@ PLATFORMS = [Platform.SENSOR, Platform.BINARY_SENSOR, Platform.DEVICE_TRACKER]
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the mobile app component."""
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if (app_config := await store.async_load()) is None or not isinstance(
|
||||
app_config, dict
|
||||
):
|
||||
|
|
|
@ -22,6 +22,7 @@ from collections.abc import Mapping
|
|||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from google_nest_sdm.camera_traits import CameraClipPreviewTrait, CameraEventImageTrait
|
||||
from google_nest_sdm.device import Device
|
||||
|
@ -89,7 +90,7 @@ async def async_get_media_event_store(
|
|||
os.makedirs(media_path, exist_ok=True)
|
||||
|
||||
await hass.async_add_executor_job(mkdir)
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY, private=True)
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY, private=True)
|
||||
return NestEventMediaStore(hass, subscriber, store, media_path)
|
||||
|
||||
|
||||
|
@ -119,7 +120,7 @@ class NestEventMediaStore(EventMediaStore):
|
|||
self,
|
||||
hass: HomeAssistant,
|
||||
subscriber: GoogleNestSubscriber,
|
||||
store: Store,
|
||||
store: Store[dict[str, Any]],
|
||||
media_path: str,
|
||||
) -> None:
|
||||
"""Initialize NestEventMediaStore."""
|
||||
|
@ -127,7 +128,7 @@ class NestEventMediaStore(EventMediaStore):
|
|||
self._subscriber = subscriber
|
||||
self._store = store
|
||||
self._media_path = media_path
|
||||
self._data: dict | None = None
|
||||
self._data: dict[str, Any] | None = None
|
||||
self._devices: Mapping[str, str] | None = {}
|
||||
|
||||
async def async_load(self) -> dict | None:
|
||||
|
@ -137,15 +138,9 @@ class NestEventMediaStore(EventMediaStore):
|
|||
if (data := await self._store.async_load()) is None:
|
||||
_LOGGER.debug("Loaded empty event store")
|
||||
self._data = {}
|
||||
elif isinstance(data, dict):
|
||||
else:
|
||||
_LOGGER.debug("Loaded event store with %d records", len(data))
|
||||
self._data = data
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unexpected data in storage version={}, key={}".format(
|
||||
STORAGE_VERSION, STORAGE_KEY
|
||||
)
|
||||
)
|
||||
return self._data
|
||||
|
||||
async def async_save(self, data: dict) -> None:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.singleton import singleton
|
||||
|
@ -38,8 +38,10 @@ class Network:
|
|||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the Network class."""
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True)
|
||||
self._data: dict[str, Any] = {}
|
||||
self._store = Store[dict[str, list[str]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, list[str]] = {}
|
||||
self.adapters: list[Adapter] = []
|
||||
|
||||
@property
|
||||
|
@ -67,7 +69,7 @@ class Network:
|
|||
async def async_load(self) -> None:
|
||||
"""Load config."""
|
||||
if stored := await self._store.async_load():
|
||||
self._data = cast(dict, stored)
|
||||
self._data = stored
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save preferences."""
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from homeassistant.const import __version__ as ha_version
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
@ -39,7 +39,9 @@ class IssueRegistry:
|
|||
"""Initialize the issue registry."""
|
||||
self.hass = hass
|
||||
self.issues: dict[tuple[str, str], IssueEntry] = {}
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True)
|
||||
self._store = Store[dict[str, list[dict[str, Optional[str]]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_get_issue(self, domain: str, issue_id: str) -> IssueEntry | None:
|
||||
|
@ -119,6 +121,7 @@ class IssueRegistry:
|
|||
|
||||
if isinstance(data, dict):
|
||||
for issue in data["issues"]:
|
||||
assert issue["domain"] and issue["issue_id"]
|
||||
issues[(issue["domain"], issue["issue_id"])] = IssueEntry(
|
||||
active=False,
|
||||
breaks_in_ha_version=None,
|
||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import functools
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
|
@ -211,8 +212,8 @@ async def setup_smartapp_endpoint(hass: HomeAssistant):
|
|||
return
|
||||
|
||||
# Get/create config to store a unique id for this hass instance.
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if not (config := await store.async_load()) or not isinstance(config, dict):
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
if not (config := await store.async_load()):
|
||||
# Create config
|
||||
config = {
|
||||
CONF_INSTANCE_ID: str(uuid4()),
|
||||
|
@ -283,7 +284,7 @@ async def unload_smartapp_endpoint(hass: HomeAssistant):
|
|||
if cloudhook_url and cloud.async_is_logged_in(hass):
|
||||
await cloud.async_delete_cloudhook(hass, hass.data[DOMAIN][CONF_WEBHOOK_ID])
|
||||
# Remove cloudhook from storage
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
await store.async_save(
|
||||
{
|
||||
CONF_INSTANCE_ID: hass.data[DOMAIN][CONF_INSTANCE_ID],
|
||||
|
|
|
@ -52,7 +52,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
"""Initialize the trace integration."""
|
||||
hass.data[DATA_TRACE] = {}
|
||||
websocket_api.async_setup(hass)
|
||||
store = Store(hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder)
|
||||
store = Store[dict[str, list]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder
|
||||
)
|
||||
hass.data[DATA_TRACE_STORE] = store
|
||||
|
||||
async def _async_store_traces_at_stop(*_) -> None:
|
||||
|
|
|
@ -40,7 +40,7 @@ class ZhaStorage:
|
|||
"""Initialize the zha device storage."""
|
||||
self.hass: HomeAssistant = hass
|
||||
self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_create_device(self, device: ZHADevice) -> ZhaDeviceEntry:
|
||||
|
@ -94,7 +94,7 @@ class ZhaStorage:
|
|||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the registry of zha device entries."""
|
||||
data = cast(dict[str, Any], await self._store.async_load())
|
||||
data = await self._store.async_load()
|
||||
|
||||
devices: OrderedDict[str, ZhaDeviceEntry] = OrderedDict()
|
||||
|
||||
|
|
|
@ -845,7 +845,9 @@ class ConfigEntries:
|
|||
self._hass_config = hass_config
|
||||
self._entries: dict[str, ConfigEntry] = {}
|
||||
self._domain_index: dict[str, list[str]] = {}
|
||||
self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self._store = storage.Store[dict[str, list[dict[str, Any]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY
|
||||
)
|
||||
EntityRegistryDisabledHandler(hass).async_setup()
|
||||
|
||||
@callback
|
||||
|
|
|
@ -1942,7 +1942,7 @@ class Config:
|
|||
# pylint: disable=import-outside-toplevel
|
||||
from .helpers.storage import Store
|
||||
|
||||
store = Store(
|
||||
store = Store[dict[str, Any]](
|
||||
self.hass,
|
||||
CORE_STORAGE_VERSION,
|
||||
CORE_STORAGE_KEY,
|
||||
|
@ -1950,7 +1950,7 @@ class Config:
|
|||
atomic_writes=True,
|
||||
)
|
||||
|
||||
if not (data := await store.async_load()) or not isinstance(data, dict):
|
||||
if not (data := await store.async_load()):
|
||||
return
|
||||
|
||||
# In 2021.9 we fixed validation to disallow a path (because that's never correct)
|
||||
|
@ -1998,7 +1998,7 @@ class Config:
|
|||
"currency": self.currency,
|
||||
}
|
||||
|
||||
store = Store(
|
||||
store: Store[dict[str, Any]] = Store(
|
||||
self.hass,
|
||||
CORE_STORAGE_VERSION,
|
||||
CORE_STORAGE_KEY,
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Container, Iterable, MutableMapping
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -49,7 +49,9 @@ class AreaRegistry:
|
|||
"""Initialize the area registry."""
|
||||
self.hass = hass
|
||||
self.areas: MutableMapping[str, AreaEntry] = {}
|
||||
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True)
|
||||
self._store = Store[dict[str, list[dict[str, Optional[str]]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||
)
|
||||
self._normalized_name_area_idx: dict[str, str] = {}
|
||||
|
||||
@callback
|
||||
|
@ -176,8 +178,9 @@ class AreaRegistry:
|
|||
|
||||
areas: MutableMapping[str, AreaEntry] = OrderedDict()
|
||||
|
||||
if isinstance(data, dict):
|
||||
if data is not None:
|
||||
for area in data["areas"]:
|
||||
assert area["name"] is not None and area["id"] is not None
|
||||
normalized_name = normalize_area_name(area["name"])
|
||||
areas[area["id"]] = AreaEntry(
|
||||
name=area["name"],
|
||||
|
|
|
@ -164,7 +164,7 @@ def _async_get_device_id_from_index(
|
|||
return None
|
||||
|
||||
|
||||
class DeviceRegistryStore(storage.Store):
|
||||
class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
|
||||
"""Store entity registry data."""
|
||||
|
||||
async def _async_migrate_func(
|
||||
|
@ -569,7 +569,6 @@ class DeviceRegistry:
|
|||
deleted_devices = OrderedDict()
|
||||
|
||||
if data is not None:
|
||||
data = cast("dict[str, Any]", data)
|
||||
for device in data["devices"]:
|
||||
devices[device["id"]] = DeviceEntry(
|
||||
area_id=device["area_id"],
|
||||
|
|
|
@ -16,7 +16,7 @@ LEGACY_UUID_FILE = ".uuid"
|
|||
@singleton.singleton(DATA_KEY)
|
||||
async def async_get(hass: HomeAssistant) -> str:
|
||||
"""Get unique ID for the hass instance."""
|
||||
store = storage.Store(hass, DATA_VERSION, DATA_KEY, True)
|
||||
store = storage.Store[dict[str, str]](hass, DATA_VERSION, DATA_KEY, True)
|
||||
|
||||
data: dict[str, str] | None = await storage.async_migrator(
|
||||
hass,
|
||||
|
|
|
@ -139,7 +139,7 @@ class RestoreStateData:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the restore state data class."""
|
||||
self.hass: HomeAssistant = hass
|
||||
self.store: Store = Store(
|
||||
self.store = Store[list[dict[str, Any]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
|
||||
)
|
||||
self.last_states: dict[str, StoredState] = {}
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
import inspect
|
||||
from json import JSONEncoder
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar, Union
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
|
||||
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||
|
@ -24,6 +24,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
STORAGE_SEMAPHORE = "storage_semaphore"
|
||||
|
||||
_T = TypeVar("_T", bound=Union[Mapping[str, Any], Sequence[Any]])
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_migrator(
|
||||
|
@ -66,7 +68,7 @@ async def async_migrator(
|
|||
|
||||
|
||||
@bind_hass
|
||||
class Store:
|
||||
class Store(Generic[_T]):
|
||||
"""Class to help storing data."""
|
||||
|
||||
def __init__(
|
||||
|
@ -90,7 +92,7 @@ class Store:
|
|||
self._unsub_delay_listener: CALLBACK_TYPE | None = None
|
||||
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._load_task: asyncio.Future | None = None
|
||||
self._load_task: asyncio.Future[_T | None] | None = None
|
||||
self._encoder = encoder
|
||||
self._atomic_writes = atomic_writes
|
||||
|
||||
|
@ -99,7 +101,7 @@ class Store:
|
|||
"""Return the config path."""
|
||||
return self.hass.config.path(STORAGE_DIR, self.key)
|
||||
|
||||
async def async_load(self) -> dict | list | None:
|
||||
async def async_load(self) -> _T | None:
|
||||
"""Load data.
|
||||
|
||||
If the expected version and minor version do not match the given versions, the
|
||||
|
@ -113,7 +115,7 @@ class Store:
|
|||
|
||||
return await self._load_task
|
||||
|
||||
async def _async_load(self):
|
||||
async def _async_load(self) -> _T | None:
|
||||
"""Load the data and ensure the task is removed."""
|
||||
if STORAGE_SEMAPHORE not in self.hass.data:
|
||||
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
|
||||
|
@ -178,7 +180,7 @@ class Store:
|
|||
|
||||
return stored
|
||||
|
||||
async def async_save(self, data: dict | list) -> None:
|
||||
async def async_save(self, data: _T) -> None:
|
||||
"""Save data."""
|
||||
self._data = {
|
||||
"version": self.version,
|
||||
|
@ -196,7 +198,7 @@ class Store:
|
|||
@callback
|
||||
def async_delay_save(
|
||||
self,
|
||||
data_func: Callable[[], dict | list],
|
||||
data_func: Callable[[], _T],
|
||||
delay: float = 0,
|
||||
) -> None:
|
||||
"""Save data with an optional delay."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue