Improve registry store data typing (#115066)
This commit is contained in:
parent
a0936902c2
commit
cb9352110c
6 changed files with 104 additions and 54 deletions
|
@ -26,6 +26,24 @@ STORAGE_VERSION_MAJOR = 1
|
||||||
STORAGE_VERSION_MINOR = 6
|
STORAGE_VERSION_MINOR = 6
|
||||||
|
|
||||||
|
|
||||||
|
class _AreaStoreData(TypedDict):
|
||||||
|
"""Data type for individual area. Used in AreasRegistryStoreData."""
|
||||||
|
|
||||||
|
aliases: list[str]
|
||||||
|
floor_id: str | None
|
||||||
|
icon: str | None
|
||||||
|
id: str
|
||||||
|
labels: list[str]
|
||||||
|
name: str
|
||||||
|
picture: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class AreasRegistryStoreData(TypedDict):
|
||||||
|
"""Store data type for AreaRegistry."""
|
||||||
|
|
||||||
|
areas: list[_AreaStoreData]
|
||||||
|
|
||||||
|
|
||||||
class EventAreaRegistryUpdatedData(TypedDict):
|
class EventAreaRegistryUpdatedData(TypedDict):
|
||||||
"""EventAreaRegistryUpdated data."""
|
"""EventAreaRegistryUpdated data."""
|
||||||
|
|
||||||
|
@ -45,7 +63,7 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
|
||||||
picture: str | None
|
picture: str | None
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
class AreaRegistryStore(Store[AreasRegistryStoreData]):
|
||||||
"""Store area registry data."""
|
"""Store area registry data."""
|
||||||
|
|
||||||
async def _async_migrate_func(
|
async def _async_migrate_func(
|
||||||
|
@ -53,7 +71,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
old_major_version: int,
|
old_major_version: int,
|
||||||
old_minor_version: int,
|
old_minor_version: int,
|
||||||
old_data: dict[str, list[dict[str, Any]]],
|
old_data: dict[str, list[dict[str, Any]]],
|
||||||
) -> dict[str, Any]:
|
) -> AreasRegistryStoreData:
|
||||||
"""Migrate to the new version."""
|
"""Migrate to the new version."""
|
||||||
if old_major_version < 2:
|
if old_major_version < 2:
|
||||||
if old_minor_version < 2:
|
if old_minor_version < 2:
|
||||||
|
@ -84,7 +102,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
|
|
||||||
if old_major_version > 1:
|
if old_major_version > 1:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return old_data
|
return old_data # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
||||||
|
@ -126,7 +144,7 @@ class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
||||||
return [data[key] for key in self._floors_index.get(floor, ())]
|
return [data[key] for key in self._floors_index.get(floor, ())]
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistry(BaseRegistry):
|
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
|
||||||
"""Class to hold a registry of areas."""
|
"""Class to hold a registry of areas."""
|
||||||
|
|
||||||
areas: AreaRegistryItems
|
areas: AreaRegistryItems
|
||||||
|
@ -314,11 +332,10 @@ class AreaRegistry(BaseRegistry):
|
||||||
self._area_data = areas.data
|
self._area_data = areas.data
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
|
def _data_to_save(self) -> AreasRegistryStoreData:
|
||||||
"""Return data of area registry to store in a file."""
|
"""Return data of area registry to store in a file."""
|
||||||
data = {}
|
return {
|
||||||
|
"areas": [
|
||||||
data["areas"] = [
|
|
||||||
{
|
{
|
||||||
"aliases": list(entry.aliases),
|
"aliases": list(entry.aliases),
|
||||||
"floor_id": entry.floor_id,
|
"floor_id": entry.floor_id,
|
||||||
|
@ -330,8 +347,7 @@ class AreaRegistry(BaseRegistry):
|
||||||
}
|
}
|
||||||
for entry in self.areas.values()
|
for entry in self.areas.values()
|
||||||
]
|
]
|
||||||
|
}
|
||||||
return data
|
|
||||||
|
|
||||||
def _generate_area_id(self, name: str) -> str:
|
def _generate_area_id(self, name: str) -> str:
|
||||||
"""Generate area ID."""
|
"""Generate area ID."""
|
||||||
|
|
|
@ -20,6 +20,20 @@ STORAGE_KEY = "core.category_registry"
|
||||||
STORAGE_VERSION_MAJOR = 1
|
STORAGE_VERSION_MAJOR = 1
|
||||||
|
|
||||||
|
|
||||||
|
class _CategoryStoreData(TypedDict):
|
||||||
|
"""Data type for individual category. Used in CategoryRegistryStoreData."""
|
||||||
|
|
||||||
|
category_id: str
|
||||||
|
icon: str | None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class CategoryRegistryStoreData(TypedDict):
|
||||||
|
"""Store data type for CategoryRegistry."""
|
||||||
|
|
||||||
|
categories: dict[str, list[_CategoryStoreData]]
|
||||||
|
|
||||||
|
|
||||||
class EventCategoryRegistryUpdatedData(TypedDict):
|
class EventCategoryRegistryUpdatedData(TypedDict):
|
||||||
"""Event data for when the category registry is updated."""
|
"""Event data for when the category registry is updated."""
|
||||||
|
|
||||||
|
@ -40,14 +54,14 @@ class CategoryEntry:
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class CategoryRegistry(BaseRegistry):
|
class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
|
||||||
"""Class to hold a registry of categories by scope."""
|
"""Class to hold a registry of categories by scope."""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the category registry."""
|
"""Initialize the category registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.categories: dict[str, dict[str, CategoryEntry]] = {}
|
self.categories: dict[str, dict[str, CategoryEntry]] = {}
|
||||||
self._store: Store[dict[str, dict[str, list[dict[str, str]]]]] = Store(
|
self._store = Store(
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
|
@ -167,7 +181,7 @@ class CategoryRegistry(BaseRegistry):
|
||||||
self.categories = category_entries
|
self.categories = category_entries
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> dict[str, dict[str, list[dict[str, str | None]]]]:
|
def _data_to_save(self) -> CategoryRegistryStoreData:
|
||||||
"""Return data of category registry to store in a file."""
|
"""Return data of category registry to store in a file."""
|
||||||
return {
|
return {
|
||||||
"categories": {
|
"categories": {
|
||||||
|
|
|
@ -551,7 +551,7 @@ class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class DeviceRegistry(BaseRegistry):
|
class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||||
"""Class to hold a registry of devices."""
|
"""Class to hold a registry of devices."""
|
||||||
|
|
||||||
devices: ActiveDeviceRegistryItems
|
devices: ActiveDeviceRegistryItems
|
||||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
from typing import Literal, TypedDict, cast
|
||||||
|
|
||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
@ -25,6 +25,22 @@ STORAGE_KEY = "core.floor_registry"
|
||||||
STORAGE_VERSION_MAJOR = 1
|
STORAGE_VERSION_MAJOR = 1
|
||||||
|
|
||||||
|
|
||||||
|
class _FloorStoreData(TypedDict):
|
||||||
|
"""Data type for individual floor. Used in FloorRegistryStoreData."""
|
||||||
|
|
||||||
|
aliases: list[str]
|
||||||
|
floor_id: str
|
||||||
|
icon: str | None
|
||||||
|
level: int | None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class FloorRegistryStoreData(TypedDict):
|
||||||
|
"""Store data type for FloorRegistry."""
|
||||||
|
|
||||||
|
floors: list[_FloorStoreData]
|
||||||
|
|
||||||
|
|
||||||
class EventFloorRegistryUpdatedData(TypedDict):
|
class EventFloorRegistryUpdatedData(TypedDict):
|
||||||
"""Event data for when the floor registry is updated."""
|
"""Event data for when the floor registry is updated."""
|
||||||
|
|
||||||
|
@ -45,7 +61,7 @@ class FloorEntry(NormalizedNameBaseRegistryEntry):
|
||||||
level: int | None = None
|
level: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class FloorRegistry(BaseRegistry):
|
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
|
||||||
"""Class to hold a registry of floors."""
|
"""Class to hold a registry of floors."""
|
||||||
|
|
||||||
floors: NormalizedNameBaseRegistryItems[FloorEntry]
|
floors: NormalizedNameBaseRegistryItems[FloorEntry]
|
||||||
|
@ -54,14 +70,12 @@ class FloorRegistry(BaseRegistry):
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the floor registry."""
|
"""Initialize the floor registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store: Store[dict[str, list[dict[str, str | int | list[str] | None]]]] = (
|
self._store = Store(
|
||||||
Store(
|
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
atomic_writes=True,
|
atomic_writes=True,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_floor(self, floor_id: str) -> FloorEntry | None:
|
def async_get_floor(self, floor_id: str) -> FloorEntry | None:
|
||||||
|
@ -190,13 +204,6 @@ class FloorRegistry(BaseRegistry):
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for floor in data["floors"]:
|
for floor in data["floors"]:
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert isinstance(floor["aliases"], list)
|
|
||||||
assert isinstance(floor["icon"], str)
|
|
||||||
assert isinstance(floor["level"], int)
|
|
||||||
assert isinstance(floor["name"], str)
|
|
||||||
assert isinstance(floor["floor_id"], str)
|
|
||||||
|
|
||||||
normalized_name = normalize_name(floor["name"])
|
normalized_name = normalize_name(floor["name"])
|
||||||
floors[floor["floor_id"]] = FloorEntry(
|
floors[floor["floor_id"]] = FloorEntry(
|
||||||
aliases=set(floor["aliases"]),
|
aliases=set(floor["aliases"]),
|
||||||
|
@ -211,7 +218,7 @@ class FloorRegistry(BaseRegistry):
|
||||||
self._floor_data = floors.data
|
self._floor_data = floors.data
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> dict[str, list[dict[str, str | int | list[str] | None]]]:
|
def _data_to_save(self) -> FloorRegistryStoreData:
|
||||||
"""Return data of floor registry to store in a file."""
|
"""Return data of floor registry to store in a file."""
|
||||||
return {
|
return {
|
||||||
"floors": [
|
"floors": [
|
||||||
|
|
|
@ -25,6 +25,22 @@ STORAGE_KEY = "core.label_registry"
|
||||||
STORAGE_VERSION_MAJOR = 1
|
STORAGE_VERSION_MAJOR = 1
|
||||||
|
|
||||||
|
|
||||||
|
class _LabelStoreData(TypedDict):
|
||||||
|
"""Data type for individual label. Used in LabelRegistryStoreData."""
|
||||||
|
|
||||||
|
color: str | None
|
||||||
|
description: str | None
|
||||||
|
icon: str | None
|
||||||
|
label_id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class LabelRegistryStoreData(TypedDict):
|
||||||
|
"""Store data type for LabelRegistry."""
|
||||||
|
|
||||||
|
labels: list[_LabelStoreData]
|
||||||
|
|
||||||
|
|
||||||
class EventLabelRegistryUpdatedData(TypedDict):
|
class EventLabelRegistryUpdatedData(TypedDict):
|
||||||
"""Event data for when the label registry is updated."""
|
"""Event data for when the label registry is updated."""
|
||||||
|
|
||||||
|
@ -45,7 +61,7 @@ class LabelEntry(NormalizedNameBaseRegistryEntry):
|
||||||
icon: str | None = None
|
icon: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LabelRegistry(BaseRegistry):
|
class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
||||||
"""Class to hold a registry of labels."""
|
"""Class to hold a registry of labels."""
|
||||||
|
|
||||||
labels: NormalizedNameBaseRegistryItems[LabelEntry]
|
labels: NormalizedNameBaseRegistryItems[LabelEntry]
|
||||||
|
@ -54,7 +70,7 @@ class LabelRegistry(BaseRegistry):
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the label registry."""
|
"""Initialize the label registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store: Store[dict[str, list[dict[str, str | None]]]] = Store(
|
self._store = Store(
|
||||||
hass,
|
hass,
|
||||||
STORAGE_VERSION_MAJOR,
|
STORAGE_VERSION_MAJOR,
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
|
@ -189,10 +205,6 @@ class LabelRegistry(BaseRegistry):
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for label in data["labels"]:
|
for label in data["labels"]:
|
||||||
# Check if the necessary keys are present
|
|
||||||
if label["label_id"] is None or label["name"] is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
normalized_name = normalize_name(label["name"])
|
normalized_name = normalize_name(label["name"])
|
||||||
labels[label["label_id"]] = LabelEntry(
|
labels[label["label_id"]] = LabelEntry(
|
||||||
color=label["color"],
|
color=label["color"],
|
||||||
|
@ -207,7 +219,7 @@ class LabelRegistry(BaseRegistry):
|
||||||
self._label_data = labels.data
|
self._label_data = labels.data
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
|
def _data_to_save(self) -> LabelRegistryStoreData:
|
||||||
"""Return data of label registry to store in a file."""
|
"""Return data of label registry to store in a file."""
|
||||||
return {
|
return {
|
||||||
"labels": [
|
"labels": [
|
||||||
|
|
|
@ -4,8 +4,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from collections.abc import ValuesView
|
from collections.abc import Mapping, Sequence, ValuesView
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
|
||||||
|
|
||||||
from homeassistant.core import CoreState, HomeAssistant, callback
|
from homeassistant.core import CoreState, HomeAssistant, callback
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ SAVE_DELAY_LONG = 180
|
||||||
|
|
||||||
|
|
||||||
_DataT = TypeVar("_DataT")
|
_DataT = TypeVar("_DataT")
|
||||||
|
_StoreDataT = TypeVar("_StoreDataT", bound=Mapping[str, Any] | Sequence[Any])
|
||||||
|
|
||||||
|
|
||||||
class BaseRegistryItems(UserDict[str, _DataT], ABC):
|
class BaseRegistryItems(UserDict[str, _DataT], ABC):
|
||||||
|
@ -64,11 +65,11 @@ class BaseRegistryItems(UserDict[str, _DataT], ABC):
|
||||||
super().__delitem__(key)
|
super().__delitem__(key)
|
||||||
|
|
||||||
|
|
||||||
class BaseRegistry(ABC):
|
class BaseRegistry(ABC, Generic[_StoreDataT]):
|
||||||
"""Class to implement a registry."""
|
"""Class to implement a registry."""
|
||||||
|
|
||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
_store: Store
|
_store: Store[_StoreDataT]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_save(self) -> None:
|
def async_schedule_save(self) -> None:
|
||||||
|
@ -80,5 +81,5 @@ class BaseRegistry(ABC):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _data_to_save(self) -> dict[str, Any]:
|
def _data_to_save(self) -> _StoreDataT:
|
||||||
"""Return data of registry to store in a file."""
|
"""Return data of registry to store in a file."""
|
||||||
|
|
Loading…
Add table
Reference in a new issue