Reduce code for registry items with a base class (#114689)

This commit is contained in:
J. Nick Koston 2024-04-02 21:02:32 -10:00 committed by GitHub
parent d17f308c6a
commit adbaed2c6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 89 additions and 88 deletions

View file

@ -2,8 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections.abc import Mapping
from collections.abc import Mapping, ValuesView
from enum import StrEnum from enum import StrEnum
from functools import lru_cache, partial from functools import lru_cache, partial
import logging import logging
@ -31,7 +30,7 @@ from .deprecation import (
) )
from .frame import report from .frame import report
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry from .registry import BaseRegistry, BaseRegistryItems
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -443,7 +442,7 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry) _EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry)
class DeviceRegistryItems(UserDict[str, _EntryTypeT]): class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT]):
"""Container for device registry items, maps device id -> entry. """Container for device registry items, maps device id -> entry.
Maintains two additional indexes: Maintains two additional indexes:
@ -457,33 +456,22 @@ class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
self._connections: dict[tuple[str, str], _EntryTypeT] = {} self._connections: dict[tuple[str, str], _EntryTypeT] = {}
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
def values(self) -> ValuesView[_EntryTypeT]: def _index_entry(self, key: str, entry: _EntryTypeT) -> None:
"""Return the underlying values to avoid __iter__ overhead.""" """Index an entry."""
return self.data.values()
def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
"""Add an item."""
data = self.data
if key in data:
old_entry = data[key]
for connection in old_entry.connections:
del self._connections[connection]
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
data[key] = entry
for connection in entry.connections: for connection in entry.connections:
self._connections[connection] = entry self._connections[connection] = entry
for identifier in entry.identifiers: for identifier in entry.identifiers:
self._identifiers[identifier] = entry self._identifiers[identifier] = entry
def __delitem__(self, key: str) -> None: def _unindex_entry(
"""Remove an item.""" self, key: str, replacement_entry: _EntryTypeT | None = None
entry = self[key] ) -> None:
for connection in entry.connections: """Unindex an entry."""
old_entry = self.data[key]
for connection in old_entry.connections:
del self._connections[connection] del self._connections[connection]
for identifier in entry.identifiers: for identifier in old_entry.identifiers:
del self._identifiers[identifier] del self._identifiers[identifier]
super().__delitem__(key)
def get_entry( def get_entry(
self, self,

View file

@ -10,8 +10,7 @@ timer.
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections.abc import Callable, Iterable, KeysView, Mapping
from collections.abc import Callable, Iterable, KeysView, Mapping, ValuesView
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import StrEnum from enum import StrEnum
import logging import logging
@ -53,7 +52,7 @@ from homeassistant.util.read_only_dict import ReadOnlyDict
from . import device_registry as dr, storage from . import device_registry as dr, storage
from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry from .registry import BaseRegistry, BaseRegistryItems
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -510,7 +509,7 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
return data return data
class EntityRegistryItems(UserDict[str, RegistryEntry]): class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry. """Container for entity registry items, maps entity_id -> entry.
Maintains four additional indexes: Maintains four additional indexes:
@ -529,16 +528,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
self._device_id_index: dict[str, dict[str, Literal[True]]] = {} self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
self._area_id_index: dict[str, dict[str, Literal[True]]] = {} self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
def values(self) -> ValuesView[RegistryEntry]: def _index_entry(self, key: str, entry: RegistryEntry) -> None:
"""Return the underlying values to avoid __iter__ overhead.""" """Index an entry."""
return self.data.values()
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
"""Add an item."""
data = self.data
if key in data:
self._unindex_entry(key)
data[key] = entry
self._entry_ids[entry.id] = entry self._entry_ids[entry.id] = entry
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
# python has no ordered set, so we use a dict with True values # python has no ordered set, so we use a dict with True values
@ -550,21 +541,9 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
if (area_id := entry.area_id) is not None: if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True self._area_id_index.setdefault(area_id, {})[key] = True
def _unindex_entry_value( def _unindex_entry(
self, key: str, value: str, index: dict[str, dict[str, Literal[True]]] self, key: str, replacement_entry: RegistryEntry | None = None
) -> None: ) -> None:
"""Unindex an entry value.
key is the entry key
value is the value to unindex such as config_entry_id or device_id.
index is the index to unindex from.
"""
entries = index[value]
del entries[key]
if not entries:
del index[value]
def _unindex_entry(self, key: str) -> None:
"""Unindex an entry.""" """Unindex an entry."""
entry = self.data[key] entry = self.data[key]
del self._entry_ids[entry.id] del self._entry_ids[entry.id]
@ -576,11 +555,6 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
if area_id := entry.area_id: if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index) self._unindex_entry_value(key, area_id, self._area_id_index)
def __delitem__(self, key: str) -> None:
"""Remove an item."""
self._unindex_entry(key)
super().__delitem__(key)
def get_device_ids(self) -> KeysView[str]: def get_device_ids(self) -> KeysView[str]:
"""Return device ids.""" """Return device ids."""
return self._device_id_index.keys() return self._device_id_index.keys()

View file

@ -1,10 +1,11 @@
"""Provide a base class for registries that use a normalized name index.""" """Provide a base class for registries that use a normalized name index."""
from collections import UserDict
from collections.abc import ValuesView
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache
from typing import TypeVar from typing import TypeVar
from .registry import BaseRegistryItems
@dataclass(slots=True, frozen=True, kw_only=True) @dataclass(slots=True, frozen=True, kw_only=True)
class NormalizedNameBaseRegistryEntry: class NormalizedNameBaseRegistryEntry:
@ -17,12 +18,13 @@ class NormalizedNameBaseRegistryEntry:
_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry) _VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry)
@lru_cache(maxsize=1024)
def normalize_name(name: str) -> str: def normalize_name(name: str) -> str:
"""Normalize a name by removing whitespace and case folding.""" """Normalize a name by removing whitespace and case folding."""
return name.casefold().replace(" ", "") return name.casefold().replace(" ", "")
class NormalizedNameBaseRegistryItems(UserDict[str, _VT]): class NormalizedNameBaseRegistryItems(BaseRegistryItems[_VT]):
"""Base container for normalized name registry items, maps key -> entry. """Base container for normalized name registry items, maps key -> entry.
Maintains an additional index: Maintains an additional index:
@ -34,34 +36,21 @@ class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
super().__init__() super().__init__()
self._normalized_names: dict[str, _VT] = {} self._normalized_names: dict[str, _VT] = {}
def values(self) -> ValuesView[_VT]: def _unindex_entry(self, key: str, replacement_entry: _VT | None = None) -> None:
"""Return the underlying values to avoid __iter__ overhead.""" old_entry = self.data[key]
return self.data.values() if (
replacement_entry is not None
and (normalized_name := normalize_name(replacement_entry.name))
!= old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {replacement_entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
def __setitem__(self, key: str, entry: _VT) -> None: def _index_entry(self, key: str, entry: _VT) -> None:
"""Add an item.""" self._normalized_names[normalize_name(entry.name)] = entry
data = self.data
normalized_name = normalize_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_by_name(self, name: str) -> _VT | None: def get_by_name(self, name: str) -> _VT | None:
"""Get entry by name.""" """Get entry by name."""

View file

@ -3,7 +3,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from collections import UserDict
from collections.abc import ValuesView
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.core import CoreState, HomeAssistant, callback
@ -14,6 +16,54 @@ SAVE_DELAY = 10
SAVE_DELAY_LONG = 180 SAVE_DELAY_LONG = 180
_DataT = TypeVar("_DataT")
class BaseRegistryItems(UserDict[str, _DataT], ABC):
"""Base class for registry items."""
data: dict[str, _DataT]
def values(self) -> ValuesView[_DataT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
@abstractmethod
def _index_entry(self, key: str, entry: _DataT) -> None:
"""Index an entry."""
@abstractmethod
def _unindex_entry(self, key: str, replacement_entry: _DataT | None = None) -> None:
"""Unindex an entry."""
def __setitem__(self, key: str, entry: _DataT) -> None:
"""Add an item."""
data = self.data
if key in data:
self._unindex_entry(key, entry)
data[key] = entry
self._index_entry(key, entry)
def _unindex_entry_value(
self, key: str, value: str, index: dict[str, dict[str, Literal[True]]]
) -> None:
"""Unindex an entry value.
key is the entry key
value is the value to unindex such as config_entry_id or device_id.
index is the index to unindex from.
"""
entries = index[value]
del entries[key]
if not entries:
del index[value]
def __delitem__(self, key: str) -> None:
"""Remove an item."""
self._unindex_entry(key)
super().__delitem__(key)
class BaseRegistry(ABC): class BaseRegistry(ABC):
"""Class to implement a registry.""" """Class to implement a registry."""