diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 9666ad302ad..76daa1266dd 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -2,8 +2,7 @@ from __future__ import annotations -from collections import UserDict -from collections.abc import Mapping, ValuesView +from collections.abc import Mapping from enum import StrEnum from functools import lru_cache, partial import logging @@ -31,7 +30,7 @@ from .deprecation import ( ) from .frame import report from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment -from .registry import BaseRegistry +from .registry import BaseRegistry, BaseRegistryItems from .typing import UNDEFINED, UndefinedType if TYPE_CHECKING: @@ -443,7 +442,7 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): _EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry) -class DeviceRegistryItems(UserDict[str, _EntryTypeT]): +class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT]): """Container for device registry items, maps device id -> entry. Maintains two additional indexes: @@ -457,33 +456,22 @@ class DeviceRegistryItems(UserDict[str, _EntryTypeT]): self._connections: dict[tuple[str, str], _EntryTypeT] = {} self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} - def values(self) -> ValuesView[_EntryTypeT]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() - - def __setitem__(self, key: str, entry: _EntryTypeT) -> None: - """Add an item.""" - data = self.data - if key in data: - old_entry = data[key] - for connection in old_entry.connections: - del self._connections[connection] - for identifier in old_entry.identifiers: - del self._identifiers[identifier] - data[key] = entry + def _index_entry(self, key: str, entry: _EntryTypeT) -> None: + """Index an entry.""" for connection in entry.connections: self._connections[connection] = entry for identifier in entry.identifiers: self._identifiers[identifier] = entry - def __delitem__(self, key: str) -> None: - """Remove an item.""" - entry = self[key] - for connection in entry.connections: + def _unindex_entry( + self, key: str, replacement_entry: _EntryTypeT | None = None + ) -> None: + """Unindex an entry.""" + old_entry = self.data[key] + for connection in old_entry.connections: del self._connections[connection] - for identifier in entry.identifiers: + for identifier in old_entry.identifiers: del self._identifiers[identifier] - super().__delitem__(key) def get_entry( self, diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index ad9ddcd5c4c..e19c4290a1d 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -10,8 +10,7 @@ timer. from __future__ import annotations -from collections import UserDict -from collections.abc import Callable, Iterable, KeysView, Mapping, ValuesView +from collections.abc import Callable, Iterable, KeysView, Mapping from datetime import datetime, timedelta from enum import StrEnum import logging @@ -53,7 +52,7 @@ from homeassistant.util.read_only_dict import ReadOnlyDict from . import device_registry as dr, storage from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes, json_fragment -from .registry import BaseRegistry +from .registry import BaseRegistry, BaseRegistryItems from .typing import UNDEFINED, UndefinedType if TYPE_CHECKING: @@ -510,7 +509,7 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]): return data -class EntityRegistryItems(UserDict[str, RegistryEntry]): +class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): """Container for entity registry items, maps entity_id -> entry. Maintains four additional indexes: @@ -529,16 +528,8 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self._device_id_index: dict[str, dict[str, Literal[True]]] = {} self._area_id_index: dict[str, dict[str, Literal[True]]] = {} - def values(self) -> ValuesView[RegistryEntry]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() - - def __setitem__(self, key: str, entry: RegistryEntry) -> None: - """Add an item.""" - data = self.data - if key in data: - self._unindex_entry(key) - data[key] = entry + def _index_entry(self, key: str, entry: RegistryEntry) -> None: + """Index an entry.""" self._entry_ids[entry.id] = entry self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id # python has no ordered set, so we use a dict with True values @@ -550,21 +541,9 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): if (area_id := entry.area_id) is not None: self._area_id_index.setdefault(area_id, {})[key] = True - def _unindex_entry_value( - self, key: str, value: str, index: dict[str, dict[str, Literal[True]]] + def _unindex_entry( + self, key: str, replacement_entry: RegistryEntry | None = None ) -> None: - """Unindex an entry value. - - key is the entry key - value is the value to unindex such as config_entry_id or device_id. - index is the index to unindex from. - """ - entries = index[value] - del entries[key] - if not entries: - del index[value] - - def _unindex_entry(self, key: str) -> None: """Unindex an entry.""" entry = self.data[key] del self._entry_ids[entry.id] @@ -576,11 +555,6 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): if area_id := entry.area_id: self._unindex_entry_value(key, area_id, self._area_id_index) - def __delitem__(self, key: str) -> None: - """Remove an item.""" - self._unindex_entry(key) - super().__delitem__(key) - def get_device_ids(self) -> KeysView[str]: """Return device ids.""" return self._device_id_index.keys() diff --git a/homeassistant/helpers/normalized_name_base_registry.py b/homeassistant/helpers/normalized_name_base_registry.py index 16280a73750..f14d99b7831 100644 --- a/homeassistant/helpers/normalized_name_base_registry.py +++ b/homeassistant/helpers/normalized_name_base_registry.py @@ -1,10 +1,11 @@ """Provide a base class for registries that use a normalized name index.""" -from collections import UserDict -from collections.abc import ValuesView from dataclasses import dataclass +from functools import lru_cache from typing import TypeVar +from .registry import BaseRegistryItems + @dataclass(slots=True, frozen=True, kw_only=True) class NormalizedNameBaseRegistryEntry: @@ -17,12 +18,13 @@ class NormalizedNameBaseRegistryEntry: _VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry) +@lru_cache(maxsize=1024) def normalize_name(name: str) -> str: """Normalize a name by removing whitespace and case folding.""" return name.casefold().replace(" ", "") -class NormalizedNameBaseRegistryItems(UserDict[str, _VT]): +class NormalizedNameBaseRegistryItems(BaseRegistryItems[_VT]): """Base container for normalized name registry items, maps key -> entry. Maintains an additional index: @@ -34,34 +36,21 @@ class NormalizedNameBaseRegistryItems(UserDict[str, _VT]): super().__init__() self._normalized_names: dict[str, _VT] = {} - def values(self) -> ValuesView[_VT]: - """Return the underlying values to avoid __iter__ overhead.""" - return self.data.values() + def _unindex_entry(self, key: str, replacement_entry: _VT | None = None) -> None: + old_entry = self.data[key] + if ( + replacement_entry is not None + and (normalized_name := normalize_name(replacement_entry.name)) + != old_entry.normalized_name + and normalized_name in self._normalized_names + ): + raise ValueError( + f"The name {replacement_entry.name} ({normalized_name}) is already in use" + ) + del self._normalized_names[old_entry.normalized_name] - def __setitem__(self, key: str, entry: _VT) -> None: - """Add an item.""" - data = self.data - normalized_name = normalize_name(entry.name) - - if key in data: - old_entry = data[key] - if ( - normalized_name != old_entry.normalized_name - and normalized_name in self._normalized_names - ): - raise ValueError( - f"The name {entry.name} ({normalized_name}) is already in use" - ) - del self._normalized_names[old_entry.normalized_name] - data[key] = entry - self._normalized_names[normalized_name] = entry - - def __delitem__(self, key: str) -> None: - """Remove an item.""" - entry = self[key] - normalized_name = normalize_name(entry.name) - del self._normalized_names[normalized_name] - super().__delitem__(key) + def _index_entry(self, key: str, entry: _VT) -> None: + self._normalized_names[normalize_name(entry.name)] = entry def get_by_name(self, name: str) -> _VT | None: """Get entry by name.""" diff --git a/homeassistant/helpers/registry.py b/homeassistant/helpers/registry.py index d5b1035531a..0057190848a 100644 --- a/homeassistant/helpers/registry.py +++ b/homeassistant/helpers/registry.py @@ -3,7 +3,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from collections import UserDict +from collections.abc import ValuesView +from typing import TYPE_CHECKING, Any, Literal, TypeVar from homeassistant.core import CoreState, HomeAssistant, callback @@ -14,6 +16,54 @@ SAVE_DELAY = 10 SAVE_DELAY_LONG = 180 +_DataT = TypeVar("_DataT") + + +class BaseRegistryItems(UserDict[str, _DataT], ABC): + """Base class for registry items.""" + + data: dict[str, _DataT] + + def values(self) -> ValuesView[_DataT]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + + @abstractmethod + def _index_entry(self, key: str, entry: _DataT) -> None: + """Index an entry.""" + + @abstractmethod + def _unindex_entry(self, key: str, replacement_entry: _DataT | None = None) -> None: + """Unindex an entry.""" + + def __setitem__(self, key: str, entry: _DataT) -> None: + """Add an item.""" + data = self.data + if key in data: + self._unindex_entry(key, entry) + data[key] = entry + self._index_entry(key, entry) + + def _unindex_entry_value( + self, key: str, value: str, index: dict[str, dict[str, Literal[True]]] + ) -> None: + """Unindex an entry value. + + key is the entry key + value is the value to unindex such as config_entry_id or device_id. + index is the index to unindex from. + """ + entries = index[value] + del entries[key] + if not entries: + del index[value] + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + self._unindex_entry(key) + super().__delitem__(key) + + class BaseRegistry(ABC): """Class to implement a registry."""