Reduce code for registry items with a base class (#114689)

This commit is contained in:
J. Nick Koston 2024-04-02 21:02:32 -10:00 committed by GitHub
parent d17f308c6a
commit adbaed2c6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 89 additions and 88 deletions

View file

@ -2,8 +2,7 @@
from __future__ import annotations
from collections import UserDict
from collections.abc import Mapping, ValuesView
from collections.abc import Mapping
from enum import StrEnum
from functools import lru_cache, partial
import logging
@ -31,7 +30,7 @@ from .deprecation import (
)
from .frame import report
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry
from .registry import BaseRegistry, BaseRegistryItems
from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING:
@ -443,7 +442,7 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry)
class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT]):
"""Container for device registry items, maps device id -> entry.
Maintains two additional indexes:
@ -457,33 +456,22 @@ class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
self._connections: dict[tuple[str, str], _EntryTypeT] = {}
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
def values(self) -> ValuesView[_EntryTypeT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
"""Add an item."""
data = self.data
if key in data:
old_entry = data[key]
for connection in old_entry.connections:
del self._connections[connection]
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
data[key] = entry
def _index_entry(self, key: str, entry: _EntryTypeT) -> None:
"""Index an entry."""
for connection in entry.connections:
self._connections[connection] = entry
for identifier in entry.identifiers:
self._identifiers[identifier] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
for connection in entry.connections:
def _unindex_entry(
self, key: str, replacement_entry: _EntryTypeT | None = None
) -> None:
"""Unindex an entry."""
old_entry = self.data[key]
for connection in old_entry.connections:
del self._connections[connection]
for identifier in entry.identifiers:
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
super().__delitem__(key)
def get_entry(
self,

View file

@ -10,8 +10,7 @@ timer.
from __future__ import annotations
from collections import UserDict
from collections.abc import Callable, Iterable, KeysView, Mapping, ValuesView
from collections.abc import Callable, Iterable, KeysView, Mapping
from datetime import datetime, timedelta
from enum import StrEnum
import logging
@ -53,7 +52,7 @@ from homeassistant.util.read_only_dict import ReadOnlyDict
from . import device_registry as dr, storage
from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment
from .registry import BaseRegistry
from .registry import BaseRegistry, BaseRegistryItems
from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING:
@ -510,7 +509,7 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
return data
class EntityRegistryItems(UserDict[str, RegistryEntry]):
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry.
Maintains four additional indexes:
@ -529,16 +528,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
def values(self) -> ValuesView[RegistryEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
"""Add an item."""
data = self.data
if key in data:
self._unindex_entry(key)
data[key] = entry
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
"""Index an entry."""
self._entry_ids[entry.id] = entry
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
# python has no ordered set, so we use a dict with True values
@ -550,21 +541,9 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True
def _unindex_entry_value(
self, key: str, value: str, index: dict[str, dict[str, Literal[True]]]
def _unindex_entry(
self, key: str, replacement_entry: RegistryEntry | None = None
) -> None:
"""Unindex an entry value.
key is the entry key
value is the value to unindex such as config_entry_id or device_id.
index is the index to unindex from.
"""
entries = index[value]
del entries[key]
if not entries:
del index[value]
def _unindex_entry(self, key: str) -> None:
"""Unindex an entry."""
entry = self.data[key]
del self._entry_ids[entry.id]
@ -576,11 +555,6 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]):
if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index)
def __delitem__(self, key: str) -> None:
"""Remove an item."""
self._unindex_entry(key)
super().__delitem__(key)
def get_device_ids(self) -> KeysView[str]:
"""Return device ids."""
return self._device_id_index.keys()

View file

@ -1,10 +1,11 @@
"""Provide a base class for registries that use a normalized name index."""
from collections import UserDict
from collections.abc import ValuesView
from dataclasses import dataclass
from functools import lru_cache
from typing import TypeVar
from .registry import BaseRegistryItems
@dataclass(slots=True, frozen=True, kw_only=True)
class NormalizedNameBaseRegistryEntry:
@ -17,12 +18,13 @@ class NormalizedNameBaseRegistryEntry:
_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry)
@lru_cache(maxsize=1024)
def normalize_name(name: str) -> str:
"""Normalize a name by removing whitespace and case folding."""
return name.casefold().replace(" ", "")
class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
class NormalizedNameBaseRegistryItems(BaseRegistryItems[_VT]):
"""Base container for normalized name registry items, maps key -> entry.
Maintains an additional index:
@ -34,34 +36,21 @@ class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
super().__init__()
self._normalized_names: dict[str, _VT] = {}
def values(self) -> ValuesView[_VT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def _unindex_entry(self, key: str, replacement_entry: _VT | None = None) -> None:
old_entry = self.data[key]
if (
replacement_entry is not None
and (normalized_name := normalize_name(replacement_entry.name))
!= old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {replacement_entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
def __setitem__(self, key: str, entry: _VT) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def _index_entry(self, key: str, entry: _VT) -> None:
self._normalized_names[normalize_name(entry.name)] = entry
def get_by_name(self, name: str) -> _VT | None:
"""Get entry by name."""

View file

@ -3,7 +3,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from collections import UserDict
from collections.abc import ValuesView
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from homeassistant.core import CoreState, HomeAssistant, callback
@ -14,6 +16,54 @@ SAVE_DELAY = 10
SAVE_DELAY_LONG = 180
_DataT = TypeVar("_DataT")
class BaseRegistryItems(UserDict[str, _DataT], ABC):
"""Base class for registry items."""
data: dict[str, _DataT]
def values(self) -> ValuesView[_DataT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
@abstractmethod
def _index_entry(self, key: str, entry: _DataT) -> None:
"""Index an entry."""
@abstractmethod
def _unindex_entry(self, key: str, replacement_entry: _DataT | None = None) -> None:
"""Unindex an entry."""
def __setitem__(self, key: str, entry: _DataT) -> None:
"""Add an item."""
data = self.data
if key in data:
self._unindex_entry(key, entry)
data[key] = entry
self._index_entry(key, entry)
def _unindex_entry_value(
self, key: str, value: str, index: dict[str, dict[str, Literal[True]]]
) -> None:
"""Unindex an entry value.
key is the entry key
value is the value to unindex such as config_entry_id or device_id.
index is the index to unindex from.
"""
entries = index[value]
del entries[key]
if not entries:
del index[value]
def __delitem__(self, key: str) -> None:
"""Remove an item."""
self._unindex_entry(key)
super().__delitem__(key)
class BaseRegistry(ABC):
"""Class to implement a registry."""