Improve performance of extracting entities by label (#114720)

This commit is contained in:
J. Nick Koston 2024-04-03 10:24:44 -10:00 committed by GitHub
parent 3d8a110908
commit e86fec310b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 46 additions and 27 deletions

View file

@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]): class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry. """Container for entity registry items, maps entity_id -> entry.
Maintains four additional indexes: Maintains six additional indexes:
- id -> entry - id -> entry
- (domain, platform, unique_id) -> entity_id - (domain, platform, unique_id) -> entity_id
- config_entry_id -> list[key] - config_entry_id -> dict[key, True]
- device_id -> list[key] - device_id -> dict[key, True]
- area_id -> dict[key, True]
- label -> dict[key, True]
""" """
def __init__(self) -> None: def __init__(self) -> None:
@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {} self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
self._device_id_index: dict[str, dict[str, Literal[True]]] = {} self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
self._area_id_index: dict[str, dict[str, Literal[True]]] = {} self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: RegistryEntry) -> None: def _index_entry(self, key: str, entry: RegistryEntry) -> None:
"""Index an entry.""" """Index an entry."""
@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._device_id_index.setdefault(device_id, {})[key] = True self._device_id_index.setdefault(device_id, {})[key] = True
if (area_id := entry.area_id) is not None: if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True self._area_id_index.setdefault(area_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
def _unindex_entry( def _unindex_entry(
self, key: str, replacement_entry: RegistryEntry | None = None self, key: str, replacement_entry: RegistryEntry | None = None
@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._unindex_entry_value(key, device_id, self._device_id_index) self._unindex_entry_value(key, device_id, self._device_id_index)
if area_id := entry.area_id: if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index) self._unindex_entry_value(key, area_id, self._area_id_index)
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
def get_device_ids(self) -> KeysView[str]: def get_device_ids(self) -> KeysView[str]:
"""Return device ids.""" """Return device ids."""
@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
data = self.data data = self.data
return [data[key] for key in self._area_id_index.get(area_id, ())] return [data[key] for key in self._area_id_index.get(area_id, ())]
def get_entries_for_label(self, label: str) -> list[RegistryEntry]:
"""Get entries for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
class EntityRegistry(BaseRegistry): class EntityRegistry(BaseRegistry):
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
@ -1317,7 +1330,7 @@ def async_entries_for_label(
registry: EntityRegistry, label_id: str registry: EntityRegistry, label_id: str
) -> list[RegistryEntry]: ) -> list[RegistryEntry]:
"""Return entries that match a label.""" """Return entries that match a label."""
return [entry for entry in registry.entities.values() if label_id in entry.labels] return registry.entities.get_entries_for_label(label_id)
@callback @callback

View file

@ -503,15 +503,15 @@ def async_extract_referenced_entity_ids( # noqa: C901
): ):
return selected return selected
ent_reg = entity_registry.async_get(hass) entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass) area_reg = area_registry.async_get(hass)
floor_reg = floor_registry.async_get(hass)
label_reg = label_registry.async_get(hass)
for floor_id in selector.floor_ids: if selector.floor_ids:
if floor_id not in floor_reg.floors: floor_reg = floor_registry.async_get(hass)
selected.missing_floors.add(floor_id) for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector.area_ids: for area_id in selector.area_ids:
if area_id not in area_reg.areas: if area_id not in area_reg.areas:
@ -521,12 +521,20 @@ def async_extract_referenced_entity_ids( # noqa: C901
if device_id not in dev_reg.devices: if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id) selected.missing_devices.add(device_id)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
# Find areas, devices & entities for targeted labels
if selector.label_ids: if selector.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find areas, devices & entities for targeted labels
for area_entry in area_reg.areas.values(): for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids): if area_entry.labels.intersection(selector.label_ids):
selected.referenced_areas.add(area_entry.id) selected.referenced_areas.add(area_entry.id)
@ -535,14 +543,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
if device_entry.labels.intersection(selector.label_ids): if device_entry.labels.intersection(selector.label_ids):
selected.referenced_devices.add(device_entry.id) selected.referenced_devices.add(device_entry.id)
for entity_entry in ent_reg.entities.values():
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
and entity_entry.labels.intersection(selector.label_ids)
):
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find areas for targeted floors # Find areas for targeted floors
if selector.floor_ids: if selector.floor_ids:
for area_entry in area_reg.areas.values(): for area_entry in area_reg.areas.values():
@ -561,7 +561,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
if not selected.referenced_areas and not selected.referenced_devices: if not selected.referenced_areas and not selected.referenced_devices:
return selected return selected
entities = ent_reg.entities
# Add indirectly referenced by area # Add indirectly referenced by area
selected.indirectly_referenced.update( selected.indirectly_referenced.update(
entry.entity_id entry.entity_id

View file

@ -5,7 +5,7 @@ import json
import logging import logging
import os import os
from types import ModuleType from types import ModuleType
from unittest.mock import Mock, call, patch from unittest.mock import call, patch
import pytest import pytest
@ -25,6 +25,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -396,10 +397,16 @@ async def test_see_service_guard_config_entry(
mock_device_tracker_conf: list[legacy.Device], mock_device_tracker_conf: list[legacy.Device],
) -> None: ) -> None:
"""Test the guard if the device is registered in the entity registry.""" """Test the guard if the device is registered in the entity registry."""
mock_entry = Mock()
dev_id = "test" dev_id = "test"
entity_id = f"{const.DOMAIN}.{dev_id}" entity_id = f"{const.DOMAIN}.{dev_id}"
mock_registry(hass, {entity_id: mock_entry}) mock_registry(
hass,
{
entity_id: RegistryEntry(
entity_id=entity_id, unique_id=1, platform=const.DOMAIN
)
},
)
devices = mock_device_tracker_conf devices = mock_device_tracker_conf
assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM) assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM)
await hass.async_block_till_done() await hass.async_block_till_done()