Improve registry store data typing (#115066)

This commit is contained in:
Marc Mueller 2024-04-07 02:34:49 +02:00 committed by GitHub
parent a0936902c2
commit cb9352110c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 104 additions and 54 deletions

View file

@ -26,6 +26,24 @@ STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 6 STORAGE_VERSION_MINOR = 6
class _AreaStoreData(TypedDict):
"""Data type for individual area. Used in AreasRegistryStoreData."""
aliases: list[str]
floor_id: str | None
icon: str | None
id: str
labels: list[str]
name: str
picture: str | None
class AreasRegistryStoreData(TypedDict):
"""Store data type for AreaRegistry."""
areas: list[_AreaStoreData]
class EventAreaRegistryUpdatedData(TypedDict): class EventAreaRegistryUpdatedData(TypedDict):
"""EventAreaRegistryUpdated data.""" """EventAreaRegistryUpdated data."""
@ -45,7 +63,7 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
picture: str | None picture: str | None
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): class AreaRegistryStore(Store[AreasRegistryStoreData]):
"""Store area registry data.""" """Store area registry data."""
async def _async_migrate_func( async def _async_migrate_func(
@ -53,7 +71,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
old_major_version: int, old_major_version: int,
old_minor_version: int, old_minor_version: int,
old_data: dict[str, list[dict[str, Any]]], old_data: dict[str, list[dict[str, Any]]],
) -> dict[str, Any]: ) -> AreasRegistryStoreData:
"""Migrate to the new version.""" """Migrate to the new version."""
if old_major_version < 2: if old_major_version < 2:
if old_minor_version < 2: if old_minor_version < 2:
@ -84,7 +102,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
if old_major_version > 1: if old_major_version > 1:
raise NotImplementedError raise NotImplementedError
return old_data return old_data # type: ignore[return-value]
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]): class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
@ -126,7 +144,7 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
return [data[key] for key in self._floors_index.get(floor, ())] return [data[key] for key in self._floors_index.get(floor, ())]
class AreaRegistry(BaseRegistry): class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Class to hold a registry of areas.""" """Class to hold a registry of areas."""
areas: AreaRegistryItems areas: AreaRegistryItems
@ -314,11 +332,10 @@ class AreaRegistry(BaseRegistry):
self._area_data = areas.data self._area_data = areas.data
@callback @callback
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]: def _data_to_save(self) -> AreasRegistryStoreData:
"""Return data of area registry to store in a file.""" """Return data of area registry to store in a file."""
data = {} return {
"areas": [
data["areas"] = [
{ {
"aliases": list(entry.aliases), "aliases": list(entry.aliases),
"floor_id": entry.floor_id, "floor_id": entry.floor_id,
@ -330,8 +347,7 @@ class AreaRegistry(BaseRegistry):
} }
for entry in self.areas.values() for entry in self.areas.values()
] ]
}
return data
def _generate_area_id(self, name: str) -> str: def _generate_area_id(self, name: str) -> str:
"""Generate area ID.""" """Generate area ID."""

View file

@ -20,6 +20,20 @@ STORAGE_KEY = "core.category_registry"
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
class _CategoryStoreData(TypedDict):
"""Data type for individual category. Used in CategoryRegistryStoreData."""
category_id: str
icon: str | None
name: str
class CategoryRegistryStoreData(TypedDict):
"""Store data type for CategoryRegistry."""
categories: dict[str, list[_CategoryStoreData]]
class EventCategoryRegistryUpdatedData(TypedDict): class EventCategoryRegistryUpdatedData(TypedDict):
"""Event data for when the category registry is updated.""" """Event data for when the category registry is updated."""
@ -40,14 +54,14 @@ class CategoryEntry:
name: str name: str
class CategoryRegistry(BaseRegistry): class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
"""Class to hold a registry of categories by scope.""" """Class to hold a registry of categories by scope."""
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the category registry.""" """Initialize the category registry."""
self.hass = hass self.hass = hass
self.categories: dict[str, dict[str, CategoryEntry]] = {} self.categories: dict[str, dict[str, CategoryEntry]] = {}
self._store: Store[dict[str, dict[str, list[dict[str, str]]]]] = Store( self._store = Store(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
STORAGE_KEY, STORAGE_KEY,
@ -167,7 +181,7 @@ class CategoryRegistry(BaseRegistry):
self.categories = category_entries self.categories = category_entries
@callback @callback
def _data_to_save(self) -> dict[str, dict[str, list[dict[str, str | None]]]]: def _data_to_save(self) -> CategoryRegistryStoreData:
"""Return data of category registry to store in a file.""" """Return data of category registry to store in a file."""
return { return {
"categories": { "categories": {

View file

@ -551,7 +551,7 @@ class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]):
] ]
class DeviceRegistry(BaseRegistry): class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
"""Class to hold a registry of devices.""" """Class to hold a registry of devices."""
devices: ActiveDeviceRegistryItems devices: ActiveDeviceRegistryItems

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, TypedDict, cast from typing import Literal, TypedDict, cast
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify from homeassistant.util import slugify
@ -25,6 +25,22 @@ STORAGE_KEY = "core.floor_registry"
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
class _FloorStoreData(TypedDict):
"""Data type for individual floor. Used in FloorRegistryStoreData."""
aliases: list[str]
floor_id: str
icon: str | None
level: int | None
name: str
class FloorRegistryStoreData(TypedDict):
"""Store data type for FloorRegistry."""
floors: list[_FloorStoreData]
class EventFloorRegistryUpdatedData(TypedDict): class EventFloorRegistryUpdatedData(TypedDict):
"""Event data for when the floor registry is updated.""" """Event data for when the floor registry is updated."""
@ -45,7 +61,7 @@ class FloorEntry(NormalizedNameBaseRegistryEntry):
level: int | None = None level: int | None = None
class FloorRegistry(BaseRegistry): class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Class to hold a registry of floors.""" """Class to hold a registry of floors."""
floors: NormalizedNameBaseRegistryItems[FloorEntry] floors: NormalizedNameBaseRegistryItems[FloorEntry]
@ -54,14 +70,12 @@ class FloorRegistry(BaseRegistry):
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the floor registry.""" """Initialize the floor registry."""
self.hass = hass self.hass = hass
self._store: Store[dict[str, list[dict[str, str | int | list[str] | None]]]] = ( self._store = Store(
Store(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
STORAGE_KEY, STORAGE_KEY,
atomic_writes=True, atomic_writes=True,
) )
)
@callback @callback
def async_get_floor(self, floor_id: str) -> FloorEntry | None: def async_get_floor(self, floor_id: str) -> FloorEntry | None:
@ -190,13 +204,6 @@ class FloorRegistry(BaseRegistry):
if data is not None: if data is not None:
for floor in data["floors"]: for floor in data["floors"]:
if TYPE_CHECKING:
assert isinstance(floor["aliases"], list)
assert isinstance(floor["icon"], str)
assert isinstance(floor["level"], int)
assert isinstance(floor["name"], str)
assert isinstance(floor["floor_id"], str)
normalized_name = normalize_name(floor["name"]) normalized_name = normalize_name(floor["name"])
floors[floor["floor_id"]] = FloorEntry( floors[floor["floor_id"]] = FloorEntry(
aliases=set(floor["aliases"]), aliases=set(floor["aliases"]),
@ -211,7 +218,7 @@ class FloorRegistry(BaseRegistry):
self._floor_data = floors.data self._floor_data = floors.data
@callback @callback
def _data_to_save(self) -> dict[str, list[dict[str, str | int | list[str] | None]]]: def _data_to_save(self) -> FloorRegistryStoreData:
"""Return data of floor registry to store in a file.""" """Return data of floor registry to store in a file."""
return { return {
"floors": [ "floors": [

View file

@ -25,6 +25,22 @@ STORAGE_KEY = "core.label_registry"
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
class _LabelStoreData(TypedDict):
"""Data type for individual label. Used in LabelRegistryStoreData."""
color: str | None
description: str | None
icon: str | None
label_id: str
name: str
class LabelRegistryStoreData(TypedDict):
"""Store data type for LabelRegistry."""
labels: list[_LabelStoreData]
class EventLabelRegistryUpdatedData(TypedDict): class EventLabelRegistryUpdatedData(TypedDict):
"""Event data for when the label registry is updated.""" """Event data for when the label registry is updated."""
@ -45,7 +61,7 @@ class LabelEntry(NormalizedNameBaseRegistryEntry):
icon: str | None = None icon: str | None = None
class LabelRegistry(BaseRegistry): class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"""Class to hold a registry of labels.""" """Class to hold a registry of labels."""
labels: NormalizedNameBaseRegistryItems[LabelEntry] labels: NormalizedNameBaseRegistryItems[LabelEntry]
@ -54,7 +70,7 @@ class LabelRegistry(BaseRegistry):
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the label registry.""" """Initialize the label registry."""
self.hass = hass self.hass = hass
self._store: Store[dict[str, list[dict[str, str | None]]]] = Store( self._store = Store(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
STORAGE_KEY, STORAGE_KEY,
@ -189,10 +205,6 @@ class LabelRegistry(BaseRegistry):
if data is not None: if data is not None:
for label in data["labels"]: for label in data["labels"]:
# Check if the necessary keys are present
if label["label_id"] is None or label["name"] is None:
continue
normalized_name = normalize_name(label["name"]) normalized_name = normalize_name(label["name"])
labels[label["label_id"]] = LabelEntry( labels[label["label_id"]] = LabelEntry(
color=label["color"], color=label["color"],
@ -207,7 +219,7 @@ class LabelRegistry(BaseRegistry):
self._label_data = labels.data self._label_data = labels.data
@callback @callback
def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]: def _data_to_save(self) -> LabelRegistryStoreData:
"""Return data of label registry to store in a file.""" """Return data of label registry to store in a file."""
return { return {
"labels": [ "labels": [

View file

@ -4,8 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from collections.abc import ValuesView from collections.abc import Mapping, Sequence, ValuesView
from typing import TYPE_CHECKING, Any, Literal, TypeVar from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.core import CoreState, HomeAssistant, callback
@ -17,6 +17,7 @@ SAVE_DELAY_LONG = 180
_DataT = TypeVar("_DataT") _DataT = TypeVar("_DataT")
_StoreDataT = TypeVar("_StoreDataT", bound=Mapping[str, Any] | Sequence[Any])
class BaseRegistryItems(UserDict[str, _DataT], ABC): class BaseRegistryItems(UserDict[str, _DataT], ABC):
@ -64,11 +65,11 @@ class BaseRegistryItems(UserDict[str, _DataT], ABC):
super().__delitem__(key) super().__delitem__(key)
class BaseRegistry(ABC): class BaseRegistry(ABC, Generic[_StoreDataT]):
"""Class to implement a registry.""" """Class to implement a registry."""
hass: HomeAssistant hass: HomeAssistant
_store: Store _store: Store[_StoreDataT]
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -80,5 +81,5 @@ class BaseRegistry(ABC):
@callback @callback
@abstractmethod @abstractmethod
def _data_to_save(self) -> dict[str, Any]: def _data_to_save(self) -> _StoreDataT:
"""Return data of registry to store in a file.""" """Return data of registry to store in a file."""