diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 76daa1266dd..c9a9016560c 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -491,10 +491,71 @@ class DeviceRegistryItems(BaseRegistryItems[_EntryTypeT]): return None +class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]): + """Container for active (non-deleted) device registry entries.""" + + def __init__(self) -> None: + """Initialize the container. + + Maintains three additional indexes: + + - area_id -> dict[key, True] + - config_entry_id -> dict[key, True] + - label -> dict[key, True] + """ + super().__init__() + self._area_id_index: dict[str, dict[str, Literal[True]]] = {} + self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {} + self._labels_index: dict[str, dict[str, Literal[True]]] = {} + + def _index_entry(self, key: str, entry: DeviceEntry) -> None: + """Index an entry.""" + super()._index_entry(key, entry) + if (area_id := entry.area_id) is not None: + self._area_id_index.setdefault(area_id, {})[key] = True + for label in entry.labels: + self._labels_index.setdefault(label, {})[key] = True + for config_entry_id in entry.config_entries: + self._config_entry_id_index.setdefault(config_entry_id, {})[key] = True + + def _unindex_entry( + self, key: str, replacement_entry: DeviceEntry | None = None + ) -> None: + """Unindex an entry.""" + entry = self.data[key] + if area_id := entry.area_id: + self._unindex_entry_value(key, area_id, self._area_id_index) + if labels := entry.labels: + for label in labels: + self._unindex_entry_value(key, label, self._labels_index) + for config_entry_id in entry.config_entries: + self._unindex_entry_value(key, config_entry_id, self._config_entry_id_index) + super()._unindex_entry(key, replacement_entry) + + def get_devices_for_area_id(self, area_id: str) -> list[DeviceEntry]: + """Get devices for area.""" + data = self.data + return [data[key] for key in self._area_id_index.get(area_id, ())] + + def get_devices_for_label(self, label: str) -> list[DeviceEntry]: + """Get devices for label.""" + data = self.data + return [data[key] for key in self._labels_index.get(label, ())] + + def get_devices_for_config_entry_id( + self, config_entry_id: str + ) -> list[DeviceEntry]: + """Get devices for config entry.""" + data = self.data + return [ + data[key] for key in self._config_entry_id_index.get(config_entry_id, ()) + ] + + class DeviceRegistry(BaseRegistry): """Class to hold a registry of devices.""" - devices: DeviceRegistryItems[DeviceEntry] + devices: ActiveDeviceRegistryItems deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] _device_data: dict[str, DeviceEntry] @@ -884,7 +945,7 @@ class DeviceRegistry(BaseRegistry): data = await self._store.async_load() - devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems() + devices = ActiveDeviceRegistryItems() deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems() if data is not None: @@ -1018,7 +1079,7 @@ async def async_load(hass: HomeAssistant) -> None: @callback def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> list[DeviceEntry]: """Return entries that match an area.""" - return [device for device in registry.devices.values() if device.area_id == area_id] + return registry.devices.get_devices_for_area_id(area_id) @callback @@ -1026,7 +1087,7 @@ def async_entries_for_label( registry: DeviceRegistry, label_id: str ) -> list[DeviceEntry]: """Return entries that match a label.""" - return [device for device in registry.devices.values() if label_id in device.labels] + return registry.devices.get_devices_for_label(label_id) @callback @@ -1034,11 +1095,7 @@ def async_entries_for_config_entry( registry: DeviceRegistry, config_entry_id: str ) -> list[DeviceEntry]: """Return entries that match a config entry.""" - return [ - device - for device in registry.devices.values() - if config_entry_id in device.config_entries - ] + return registry.devices.get_devices_for_config_entry_id(config_entry_id) @callback diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 43942458233..00dfea23549 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -534,15 +534,14 @@ def async_extract_referenced_entity_ids( # noqa: C901 ): selected.indirectly_referenced.add(entity_entry.entity_id) - # Find areas, devices & entities for targeted labels + for device_entry in dev_reg.devices.get_devices_for_label(label_id): + selected.referenced_devices.add(device_entry.id) + + # Find areas for targeted labels for area_entry in area_reg.areas.values(): if area_entry.labels.intersection(selector.label_ids): selected.referenced_areas.add(area_entry.id) - for device_entry in dev_reg.devices.values(): - if device_entry.labels.intersection(selector.label_ids): - selected.referenced_devices.add(device_entry.id) - # Find areas for targeted floors if selector.floor_ids: for area_entry in area_reg.areas.values(): @@ -554,9 +553,11 @@ def async_extract_referenced_entity_ids( # noqa: C901 selected.referenced_areas.update(selector.area_ids) if selected.referenced_areas: - for device_entry in dev_reg.devices.values(): - if device_entry.area_id in selected.referenced_areas: - selected.referenced_devices.add(device_entry.id) + for area_id in selected.referenced_areas: + selected.referenced_devices.update( + device_entry.id + for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) + ) if not selected.referenced_areas and not selected.referenced_devices: return selected diff --git a/tests/common.py b/tests/common.py index 210eb07d668..d3bcdcbd004 100644 --- a/tests/common.py +++ b/tests/common.py @@ -671,7 +671,7 @@ def mock_device_registry( fixture instead. """ registry = dr.DeviceRegistry(hass) - registry.devices = dr.DeviceRegistryItems() + registry.devices = dr.ActiveDeviceRegistryItems() registry._device_data = registry.devices.data if mock_entries is None: mock_entries = {}