Make Store a generic class (#74617)

This commit is contained in:
epenet 2022-07-09 22:32:57 +02:00 committed by GitHub
parent d37ad20894
commit 16900dcef1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 106 additions and 97 deletions

View file

@ -46,7 +46,7 @@ class AuthStore:
self._users: dict[str, models.User] | None = None self._users: dict[str, models.User] | None = None
self._groups: dict[str, models.Group] | None = None self._groups: dict[str, models.Group] | None = None
self._perm_lookup: PermissionLookup | None = None self._perm_lookup: PermissionLookup | None = None
self._store = Store( self._store = Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@ -483,9 +483,10 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"], jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at, last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"), last_used_ip=rt_dict.get("last_used_ip"),
credential=credentials.get(rt_dict.get("credential_id")),
version=rt_dict.get("version"), version=rt_dict.get("version"),
) )
if "credential_id" in rt_dict:
token.credential = credentials.get(rt_dict["credential_id"])
users[rt_dict["user_id"]].refresh_tokens[token.id] = token users[rt_dict["user_id"]].refresh_tokens[token.id] = token
self._groups = groups self._groups = groups

View file

@ -7,7 +7,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import Any from typing import Any, cast
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -100,7 +100,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._user_settings: _UsersDict | None = None self._user_settings: _UsersDict | None = None
self._user_store = Store( self._user_store = Store[dict[str, dict[str, Any]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._include = config.get(CONF_INCLUDE, []) self._include = config.get(CONF_INCLUDE, [])
@ -119,10 +119,8 @@ class NotifyAuthModule(MultiFactorAuthModule):
if self._user_settings is not None: if self._user_settings is not None:
return return
if (data := await self._user_store.async_load()) is None or not isinstance( if (data := await self._user_store.async_load()) is None:
data, dict data = cast(dict[str, dict[str, Any]], {STORAGE_USERS: {}})
):
data = {STORAGE_USERS: {}}
self._user_settings = { self._user_settings = {
user_id: NotifySetting(**setting) user_id: NotifySetting(**setting)

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from io import BytesIO from io import BytesIO
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -77,7 +77,7 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._users: dict[str, str] | None = None self._users: dict[str, str] | None = None
self._user_store = Store( self._user_store = Store[dict[str, dict[str, str]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
@ -93,16 +93,14 @@ class TotpAuthModule(MultiFactorAuthModule):
if self._users is not None: if self._users is not None:
return return
if (data := await self._user_store.async_load()) is None or not isinstance( if (data := await self._user_store.async_load()) is None:
data, dict data = cast(dict[str, dict[str, str]], {STORAGE_USERS: {}})
):
data = {STORAGE_USERS: {}}
self._users = data.get(STORAGE_USERS, {}) self._users = data.get(STORAGE_USERS, {})
async def _async_save(self) -> None: async def _async_save(self) -> None:
"""Save data.""" """Save data."""
await self._user_store.async_save({STORAGE_USERS: self._users}) await self._user_store.async_save({STORAGE_USERS: self._users or {}})
def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str: def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
"""Create a ota_secret for user.""" """Create a ota_secret for user."""

View file

@ -61,10 +61,10 @@ class Data:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
self.hass = hass self.hass = hass
self._store = Store( self._store = Store[dict[str, list[dict[str, str]]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._data: dict[str, Any] | None = None self._data: dict[str, list[dict[str, str]]] | None = None
# Legacy mode will allow usernames to start/end with whitespace # Legacy mode will allow usernames to start/end with whitespace
# and will compare usernames case-insensitive. # and will compare usernames case-insensitive.
# Remove in 2020 or when we launch 1.0. # Remove in 2020 or when we launch 1.0.
@ -80,10 +80,8 @@ class Data:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load stored data.""" """Load stored data."""
if (data := await self._store.async_load()) is None or not isinstance( if (data := await self._store.async_load()) is None:
data, dict data = cast(dict[str, list[dict[str, str]]], {"users": []})
):
data = {"users": []}
seen: set[str] = set() seen: set[str] = set()
@ -123,7 +121,8 @@ class Data:
@property @property
def users(self) -> list[dict[str, str]]: def users(self) -> list[dict[str, str]]:
"""Return users.""" """Return users."""
return self._data["users"] # type: ignore[index,no-any-return] assert self._data is not None
return self._data["users"]
def validate_login(self, username: str, password: str) -> None: def validate_login(self, username: str, password: str) -> None:
"""Validate a username and password. """Validate a username and password.

View file

@ -5,7 +5,7 @@ import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
import time import time
from typing import Optional, cast from typing import Any
from aiohttp import ClientError, ClientSession from aiohttp import ClientError, ClientSession
import async_timeout import async_timeout
@ -167,8 +167,8 @@ async def _configure_almond_for_ha(
return return
_LOGGER.debug("Configuring Almond to connect to Home Assistant at %s", hass_url) _LOGGER.debug("Configuring Almond to connect to Home Assistant at %s", hass_url)
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
data = cast(Optional[dict], await store.async_load()) data = await store.async_load()
if data is None: if data is None:
data = {} data = {}

View file

@ -64,7 +64,7 @@ async def async_setup_entry(
"""Set up the Ambiclimate device from config entry.""" """Set up the Ambiclimate device from config entry."""
config = entry.data config = entry.data
websession = async_get_clientsession(hass) websession = async_get_clientsession(hass)
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
token_info = await store.async_load() token_info = await store.async_load()
oauth = ambiclimate.AmbiclimateOAuth( oauth = ambiclimate.AmbiclimateOAuth(

View file

@ -1,6 +1,6 @@
"""Analytics helper class for the analytics integration.""" """Analytics helper class for the analytics integration."""
import asyncio import asyncio
from typing import cast from typing import Any
import uuid import uuid
import aiohttp import aiohttp
@ -66,12 +66,12 @@ class Analytics:
"""Initialize the Analytics class.""" """Initialize the Analytics class."""
self.hass: HomeAssistant = hass self.hass: HomeAssistant = hass
self.session = async_get_clientsession(hass) self.session = async_get_clientsession(hass)
self._data: dict = { self._data: dict[str, Any] = {
ATTR_PREFERENCES: {}, ATTR_PREFERENCES: {},
ATTR_ONBOARDED: False, ATTR_ONBOARDED: False,
ATTR_UUID: None, ATTR_UUID: None,
} }
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
@property @property
def preferences(self) -> dict: def preferences(self) -> dict:
@ -109,7 +109,7 @@ class Analytics:
async def load(self) -> None: async def load(self) -> None:
"""Load preferences.""" """Load preferences."""
stored = cast(dict, await self._store.async_load()) stored = await self._store.async_load()
if stored: if stored:
self._data = stored self._data = stored

View file

@ -36,14 +36,14 @@ class CameraPreferences:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize camera prefs.""" """Initialize camera prefs."""
self._hass = hass self._hass = hass
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) self._store = Store[dict[str, dict[str, bool]]](
hass, STORAGE_VERSION, STORAGE_KEY
)
self._prefs: dict[str, dict[str, bool]] | None = None self._prefs: dict[str, dict[str, bool]] | None = None
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
"""Finish initializing the preferences.""" """Finish initializing the preferences."""
if (prefs := await self._store.async_load()) is None or not isinstance( if (prefs := await self._store.async_load()) is None:
prefs, dict
):
prefs = {} prefs = {}
self._prefs = prefs self._prefs = prefs

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import Counter from collections import Counter
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Literal, Optional, TypedDict, Union, cast from typing import Literal, TypedDict, Union
import voluptuous as vol import voluptuous as vol
@ -263,13 +263,15 @@ class EnergyManager:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize energy manager.""" """Initialize energy manager."""
self._hass = hass self._hass = hass
self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) self._store = storage.Store[EnergyPreferences](
hass, STORAGE_VERSION, STORAGE_KEY
)
self.data: EnergyPreferences | None = None self.data: EnergyPreferences | None = None
self._update_listeners: list[Callable[[], Awaitable]] = [] self._update_listeners: list[Callable[[], Awaitable]] = []
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
"""Initialize the energy integration.""" """Initialize the energy integration."""
self.data = cast(Optional[EnergyPreferences], await self._store.async_load()) self.data = await self._store.async_load()
@staticmethod @staticmethod
def default_preferences() -> EnergyPreferences: def default_preferences() -> EnergyPreferences:
@ -294,7 +296,7 @@ class EnergyManager:
data[key] = update[key] # type: ignore[literal-required] data[key] = update[key] # type: ignore[literal-required]
self.data = data self.data = data
self._store.async_delay_save(lambda: cast(dict, self.data), 60) self._store.async_delay_save(lambda: data, 60)
if not self._update_listeners: if not self._update_listeners:
return return

View file

@ -533,12 +533,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
if not await hassio.is_connected(): if not await hassio.is_connected():
_LOGGER.warning("Not connected with the supervisor / system too busy!") _LOGGER.warning("Not connected with the supervisor / system too busy!")
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, str]](hass, STORAGE_VERSION, STORAGE_KEY)
if (data := await store.async_load()) is None: if (data := await store.async_load()) is None:
data = {} data = {}
assert isinstance(data, dict)
refresh_token = None refresh_token = None
if "hassio_user" in data: if "hassio_user" in data:
user = await hass.auth.async_get_user(data["hassio_user"]) user = await hass.auth.async_get_user(data["hassio_user"])

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, TypedDict, cast from typing import Any, TypedDict
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
@ -46,7 +46,9 @@ class EntityMapStorage:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Create a new entity map store.""" """Create a new entity map store."""
self.hass = hass self.hass = hass
self.store = Store(hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY) self.store = Store[StorageLayout](
hass, ENTITY_MAP_STORAGE_VERSION, ENTITY_MAP_STORAGE_KEY
)
self.storage_data: dict[str, Pairing] = {} self.storage_data: dict[str, Pairing] = {}
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
@ -55,8 +57,7 @@ class EntityMapStorage:
# There is no cached data about HomeKit devices yet # There is no cached data about HomeKit devices yet
return return
storage = cast(StorageLayout, raw_storage) self.storage_data = raw_storage.get("pairings", {})
self.storage_data = storage.get("pairings", {})
def get_map(self, homekit_id: str) -> Pairing | None: def get_map(self, homekit_id: str) -> Pairing | None:
"""Get a pairing cache item.""" """Get a pairing cache item."""
@ -87,6 +88,6 @@ class EntityMapStorage:
self.store.async_delay_save(self._data_to_save, ENTITY_MAP_SAVE_DELAY) self.store.async_delay_save(self._data_to_save, ENTITY_MAP_SAVE_DELAY)
@callback @callback
def _data_to_save(self) -> dict[str, Any]: def _data_to_save(self) -> StorageLayout:
"""Return data of entity map to store in a file.""" """Return data of entity map to store in a file."""
return {"pairings": self.storage_data} return StorageLayout(pairings=self.storage_data)

View file

@ -7,7 +7,7 @@ import logging
import os import os
import ssl import ssl
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Final, Optional, TypedDict, Union, cast from typing import Any, Final, TypedDict, Union, cast
from aiohttp import web from aiohttp import web
from aiohttp.typedefs import StrOrURL from aiohttp.typedefs import StrOrURL
@ -125,10 +125,10 @@ class ConfData(TypedDict, total=False):
@bind_hass @bind_hass
async def async_get_last_config(hass: HomeAssistant) -> dict | None: async def async_get_last_config(hass: HomeAssistant) -> dict[str, Any] | None:
"""Return the last known working config.""" """Return the last known working config."""
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) store = storage.Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
return cast(Optional[dict], await store.async_load()) return await store.async_load()
class ApiConfig: class ApiConfig:
@ -475,7 +475,9 @@ async def start_http_server_and_save_config(
await server.start() await server.start()
# If we are set up successful, we store the HTTP settings for safe mode. # If we are set up successful, we store the HTTP settings for safe mode.
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) store: storage.Store[dict[str, Any]] = storage.Store(
hass, STORAGE_VERSION, STORAGE_KEY
)
if CONF_TRUSTED_PROXIES in conf: if CONF_TRUSTED_PROXIES in conf:
conf[CONF_TRUSTED_PROXIES] = [ conf[CONF_TRUSTED_PROXIES] = [

View file

@ -6,7 +6,7 @@ from datetime import timedelta
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
import secrets import secrets
from typing import Final from typing import Any, Final
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import Application, Request, StreamResponse, middleware from aiohttp.web import Application, Request, StreamResponse, middleware
@ -118,8 +118,8 @@ def async_user_not_allowed_do_auth(
async def async_setup_auth(hass: HomeAssistant, app: Application) -> None: async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
"""Create auth middleware for the app.""" """Create auth middleware for the app."""
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
if (data := await store.async_load()) is None or not isinstance(data, dict): if (data := await store.async_load()) is None:
data = {} data = {}
refresh_token = None refresh_token = None

View file

@ -1,5 +1,6 @@
"""Integrates Native Apps to Home Assistant.""" """Integrates Native Apps to Home Assistant."""
from contextlib import suppress from contextlib import suppress
from typing import Any
from homeassistant.components import cloud, notify as hass_notify from homeassistant.components import cloud, notify as hass_notify
from homeassistant.components.webhook import ( from homeassistant.components.webhook import (
@ -38,7 +39,7 @@ PLATFORMS = [Platform.SENSOR, Platform.BINARY_SENSOR, Platform.DEVICE_TRACKER]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the mobile app component.""" """Set up the mobile app component."""
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
if (app_config := await store.async_load()) is None or not isinstance( if (app_config := await store.async_load()) is None or not isinstance(
app_config, dict app_config, dict
): ):

View file

@ -22,6 +22,7 @@ from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import os import os
from typing import Any
from google_nest_sdm.camera_traits import CameraClipPreviewTrait, CameraEventImageTrait from google_nest_sdm.camera_traits import CameraClipPreviewTrait, CameraEventImageTrait
from google_nest_sdm.device import Device from google_nest_sdm.device import Device
@ -89,7 +90,7 @@ async def async_get_media_event_store(
os.makedirs(media_path, exist_ok=True) os.makedirs(media_path, exist_ok=True)
await hass.async_add_executor_job(mkdir) await hass.async_add_executor_job(mkdir)
store = Store(hass, STORAGE_VERSION, STORAGE_KEY, private=True) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY, private=True)
return NestEventMediaStore(hass, subscriber, store, media_path) return NestEventMediaStore(hass, subscriber, store, media_path)
@ -119,7 +120,7 @@ class NestEventMediaStore(EventMediaStore):
self, self,
hass: HomeAssistant, hass: HomeAssistant,
subscriber: GoogleNestSubscriber, subscriber: GoogleNestSubscriber,
store: Store, store: Store[dict[str, Any]],
media_path: str, media_path: str,
) -> None: ) -> None:
"""Initialize NestEventMediaStore.""" """Initialize NestEventMediaStore."""
@ -127,7 +128,7 @@ class NestEventMediaStore(EventMediaStore):
self._subscriber = subscriber self._subscriber = subscriber
self._store = store self._store = store
self._media_path = media_path self._media_path = media_path
self._data: dict | None = None self._data: dict[str, Any] | None = None
self._devices: Mapping[str, str] | None = {} self._devices: Mapping[str, str] | None = {}
async def async_load(self) -> dict | None: async def async_load(self) -> dict | None:
@ -137,15 +138,9 @@ class NestEventMediaStore(EventMediaStore):
if (data := await self._store.async_load()) is None: if (data := await self._store.async_load()) is None:
_LOGGER.debug("Loaded empty event store") _LOGGER.debug("Loaded empty event store")
self._data = {} self._data = {}
elif isinstance(data, dict): else:
_LOGGER.debug("Loaded event store with %d records", len(data)) _LOGGER.debug("Loaded event store with %d records", len(data))
self._data = data self._data = data
else:
raise ValueError(
"Unexpected data in storage version={}, key={}".format(
STORAGE_VERSION, STORAGE_KEY
)
)
return self._data return self._data
async def async_save(self, data: dict) -> None: async def async_save(self, data: dict) -> None:

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, cast from typing import Any
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
@ -38,8 +38,10 @@ class Network:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the Network class.""" """Initialize the Network class."""
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) self._store = Store[dict[str, list[str]]](
self._data: dict[str, Any] = {} hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self._data: dict[str, list[str]] = {}
self.adapters: list[Adapter] = [] self.adapters: list[Adapter] = []
@property @property
@ -67,7 +69,7 @@ class Network:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load config.""" """Load config."""
if stored := await self._store.async_load(): if stored := await self._store.async_load():
self._data = cast(dict, stored) self._data = stored
async def _async_save(self) -> None: async def _async_save(self) -> None:
"""Save preferences.""" """Save preferences."""

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
from typing import cast from typing import Optional, cast
from homeassistant.const import __version__ as ha_version from homeassistant.const import __version__ as ha_version
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -39,7 +39,9 @@ class IssueRegistry:
"""Initialize the issue registry.""" """Initialize the issue registry."""
self.hass = hass self.hass = hass
self.issues: dict[tuple[str, str], IssueEntry] = {} self.issues: dict[tuple[str, str], IssueEntry] = {}
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) self._store = Store[dict[str, list[dict[str, Optional[str]]]]](
hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
@callback @callback
def async_get_issue(self, domain: str, issue_id: str) -> IssueEntry | None: def async_get_issue(self, domain: str, issue_id: str) -> IssueEntry | None:
@ -119,6 +121,7 @@ class IssueRegistry:
if isinstance(data, dict): if isinstance(data, dict):
for issue in data["issues"]: for issue in data["issues"]:
assert issue["domain"] and issue["issue_id"]
issues[(issue["domain"], issue["issue_id"])] = IssueEntry( issues[(issue["domain"], issue["issue_id"])] = IssueEntry(
active=False, active=False,
breaks_in_ha_version=None, breaks_in_ha_version=None,

View file

@ -3,6 +3,7 @@ import asyncio
import functools import functools
import logging import logging
import secrets import secrets
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@ -211,8 +212,8 @@ async def setup_smartapp_endpoint(hass: HomeAssistant):
return return
# Get/create config to store a unique id for this hass instance. # Get/create config to store a unique id for this hass instance.
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
if not (config := await store.async_load()) or not isinstance(config, dict): if not (config := await store.async_load()):
# Create config # Create config
config = { config = {
CONF_INSTANCE_ID: str(uuid4()), CONF_INSTANCE_ID: str(uuid4()),
@ -283,7 +284,7 @@ async def unload_smartapp_endpoint(hass: HomeAssistant):
if cloudhook_url and cloud.async_is_logged_in(hass): if cloudhook_url and cloud.async_is_logged_in(hass):
await cloud.async_delete_cloudhook(hass, hass.data[DOMAIN][CONF_WEBHOOK_ID]) await cloud.async_delete_cloudhook(hass, hass.data[DOMAIN][CONF_WEBHOOK_ID])
# Remove cloudhook from storage # Remove cloudhook from storage
store = Store(hass, STORAGE_VERSION, STORAGE_KEY) store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
await store.async_save( await store.async_save(
{ {
CONF_INSTANCE_ID: hass.data[DOMAIN][CONF_INSTANCE_ID], CONF_INSTANCE_ID: hass.data[DOMAIN][CONF_INSTANCE_ID],

View file

@ -52,7 +52,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Initialize the trace integration.""" """Initialize the trace integration."""
hass.data[DATA_TRACE] = {} hass.data[DATA_TRACE] = {}
websocket_api.async_setup(hass) websocket_api.async_setup(hass)
store = Store(hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder) store = Store[dict[str, list]](
hass, STORAGE_VERSION, STORAGE_KEY, encoder=ExtendedJSONEncoder
)
hass.data[DATA_TRACE_STORE] = store hass.data[DATA_TRACE_STORE] = store
async def _async_store_traces_at_stop(*_) -> None: async def _async_store_traces_at_stop(*_) -> None:

View file

@ -40,7 +40,7 @@ class ZhaStorage:
"""Initialize the zha device storage.""" """Initialize the zha device storage."""
self.hass: HomeAssistant = hass self.hass: HomeAssistant = hass
self.devices: MutableMapping[str, ZhaDeviceEntry] = {} self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) self._store = Store[dict[str, Any]](hass, STORAGE_VERSION, STORAGE_KEY)
@callback @callback
def async_create_device(self, device: ZHADevice) -> ZhaDeviceEntry: def async_create_device(self, device: ZHADevice) -> ZhaDeviceEntry:
@ -94,7 +94,7 @@ class ZhaStorage:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the registry of zha device entries.""" """Load the registry of zha device entries."""
data = cast(dict[str, Any], await self._store.async_load()) data = await self._store.async_load()
devices: OrderedDict[str, ZhaDeviceEntry] = OrderedDict() devices: OrderedDict[str, ZhaDeviceEntry] = OrderedDict()

View file

@ -845,7 +845,9 @@ class ConfigEntries:
self._hass_config = hass_config self._hass_config = hass_config
self._entries: dict[str, ConfigEntry] = {} self._entries: dict[str, ConfigEntry] = {}
self._domain_index: dict[str, list[str]] = {} self._domain_index: dict[str, list[str]] = {}
self._store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) self._store = storage.Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY
)
EntityRegistryDisabledHandler(hass).async_setup() EntityRegistryDisabledHandler(hass).async_setup()
@callback @callback

View file

@ -1942,7 +1942,7 @@ class Config:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .helpers.storage import Store from .helpers.storage import Store
store = Store( store = Store[dict[str, Any]](
self.hass, self.hass,
CORE_STORAGE_VERSION, CORE_STORAGE_VERSION,
CORE_STORAGE_KEY, CORE_STORAGE_KEY,
@ -1950,7 +1950,7 @@ class Config:
atomic_writes=True, atomic_writes=True,
) )
if not (data := await store.async_load()) or not isinstance(data, dict): if not (data := await store.async_load()):
return return
# In 2021.9 we fixed validation to disallow a path (because that's never correct) # In 2021.9 we fixed validation to disallow a path (because that's never correct)
@ -1998,7 +1998,7 @@ class Config:
"currency": self.currency, "currency": self.currency,
} }
store = Store( store: Store[dict[str, Any]] = Store(
self.hass, self.hass,
CORE_STORAGE_VERSION, CORE_STORAGE_VERSION,
CORE_STORAGE_KEY, CORE_STORAGE_KEY,

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Container, Iterable, MutableMapping from collections.abc import Container, Iterable, MutableMapping
from typing import cast from typing import Optional, cast
import attr import attr
@ -49,7 +49,9 @@ class AreaRegistry:
"""Initialize the area registry.""" """Initialize the area registry."""
self.hass = hass self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {} self.areas: MutableMapping[str, AreaEntry] = {}
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True) self._store = Store[dict[str, list[dict[str, Optional[str]]]]](
hass, STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self._normalized_name_area_idx: dict[str, str] = {} self._normalized_name_area_idx: dict[str, str] = {}
@callback @callback
@ -176,8 +178,9 @@ class AreaRegistry:
areas: MutableMapping[str, AreaEntry] = OrderedDict() areas: MutableMapping[str, AreaEntry] = OrderedDict()
if isinstance(data, dict): if data is not None:
for area in data["areas"]: for area in data["areas"]:
assert area["name"] is not None and area["id"] is not None
normalized_name = normalize_area_name(area["name"]) normalized_name = normalize_area_name(area["name"])
areas[area["id"]] = AreaEntry( areas[area["id"]] = AreaEntry(
name=area["name"], name=area["name"],

View file

@ -164,7 +164,7 @@ def _async_get_device_id_from_index(
return None return None
class DeviceRegistryStore(storage.Store): class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
"""Store entity registry data.""" """Store entity registry data."""
async def _async_migrate_func( async def _async_migrate_func(
@ -569,7 +569,6 @@ class DeviceRegistry:
deleted_devices = OrderedDict() deleted_devices = OrderedDict()
if data is not None: if data is not None:
data = cast("dict[str, Any]", data)
for device in data["devices"]: for device in data["devices"]:
devices[device["id"]] = DeviceEntry( devices[device["id"]] = DeviceEntry(
area_id=device["area_id"], area_id=device["area_id"],

View file

@ -16,7 +16,7 @@ LEGACY_UUID_FILE = ".uuid"
@singleton.singleton(DATA_KEY) @singleton.singleton(DATA_KEY)
async def async_get(hass: HomeAssistant) -> str: async def async_get(hass: HomeAssistant) -> str:
"""Get unique ID for the hass instance.""" """Get unique ID for the hass instance."""
store = storage.Store(hass, DATA_VERSION, DATA_KEY, True) store = storage.Store[dict[str, str]](hass, DATA_VERSION, DATA_KEY, True)
data: dict[str, str] | None = await storage.async_migrator( data: dict[str, str] | None = await storage.async_migrator(
hass, hass,

View file

@ -139,7 +139,7 @@ class RestoreStateData:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class.""" """Initialize the restore state data class."""
self.hass: HomeAssistant = hass self.hass: HomeAssistant = hass
self.store: Store = Store( self.store = Store[list[dict[str, Any]]](
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
) )
self.last_states: dict[str, StoredState] = {} self.last_states: dict[str, StoredState] = {}

View file

@ -2,14 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable, Mapping, Sequence
from contextlib import suppress from contextlib import suppress
from copy import deepcopy from copy import deepcopy
import inspect import inspect
from json import JSONEncoder from json import JSONEncoder
import logging import logging
import os import os
from typing import Any from typing import Any, Generic, TypeVar, Union
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
@ -24,6 +24,8 @@ _LOGGER = logging.getLogger(__name__)
STORAGE_SEMAPHORE = "storage_semaphore" STORAGE_SEMAPHORE = "storage_semaphore"
_T = TypeVar("_T", bound=Union[Mapping[str, Any], Sequence[Any]])
@bind_hass @bind_hass
async def async_migrator( async def async_migrator(
@ -66,7 +68,7 @@ async def async_migrator(
@bind_hass @bind_hass
class Store: class Store(Generic[_T]):
"""Class to help storing data.""" """Class to help storing data."""
def __init__( def __init__(
@ -90,7 +92,7 @@ class Store:
self._unsub_delay_listener: CALLBACK_TYPE | None = None self._unsub_delay_listener: CALLBACK_TYPE | None = None
self._unsub_final_write_listener: CALLBACK_TYPE | None = None self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._load_task: asyncio.Future | None = None self._load_task: asyncio.Future[_T | None] | None = None
self._encoder = encoder self._encoder = encoder
self._atomic_writes = atomic_writes self._atomic_writes = atomic_writes
@ -99,7 +101,7 @@ class Store:
"""Return the config path.""" """Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key) return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self) -> dict | list | None: async def async_load(self) -> _T | None:
"""Load data. """Load data.
If the expected version and minor version do not match the given versions, the If the expected version and minor version do not match the given versions, the
@ -113,7 +115,7 @@ class Store:
return await self._load_task return await self._load_task
async def _async_load(self): async def _async_load(self) -> _T | None:
"""Load the data and ensure the task is removed.""" """Load the data and ensure the task is removed."""
if STORAGE_SEMAPHORE not in self.hass.data: if STORAGE_SEMAPHORE not in self.hass.data:
self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY) self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)
@ -178,7 +180,7 @@ class Store:
return stored return stored
async def async_save(self, data: dict | list) -> None: async def async_save(self, data: _T) -> None:
"""Save data.""" """Save data."""
self._data = { self._data = {
"version": self.version, "version": self.version,
@ -196,7 +198,7 @@ class Store:
@callback @callback
def async_delay_save( def async_delay_save(
self, self,
data_func: Callable[[], dict | list], data_func: Callable[[], _T],
delay: float = 0, delay: float = 0,
) -> None: ) -> None:
"""Save data with an optional delay.""" """Save data with an optional delay."""