Improve performance of extracting entities by label (#114720)

This commit is contained in:
J. Nick Koston 2024-04-03 10:24:44 -10:00 committed by GitHub
parent 3d8a110908
commit e86fec310b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 46 additions and 27 deletions

View file

@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry.
Maintains four additional indexes:
Maintains six additional indexes:
- id -> entry
- (domain, platform, unique_id) -> entity_id
- config_entry_id -> list[key]
- device_id -> list[key]
- config_entry_id -> dict[key, True]
- device_id -> dict[key, True]
- area_id -> dict[key, True]
- label -> dict[key, True]
"""
def __init__(self) -> None:
@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
"""Index an entry."""
@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._device_id_index.setdefault(device_id, {})[key] = True
if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
def _unindex_entry(
self, key: str, replacement_entry: RegistryEntry | None = None
@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._unindex_entry_value(key, device_id, self._device_id_index)
if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index)
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
def get_device_ids(self) -> KeysView[str]:
"""Return device ids."""
@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
data = self.data
return [data[key] for key in self._area_id_index.get(area_id, ())]
def get_entries_for_label(self, label: str) -> list[RegistryEntry]:
"""Get entries for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
class EntityRegistry(BaseRegistry):
"""Class to hold a registry of entities."""
@ -1317,7 +1330,7 @@ def async_entries_for_label(
registry: EntityRegistry, label_id: str
) -> list[RegistryEntry]:
"""Return entries that match a label."""
return [entry for entry in registry.entities.values() if label_id in entry.labels]
return registry.entities.get_entries_for_label(label_id)
@callback

View file

@ -503,15 +503,15 @@ def async_extract_referenced_entity_ids( # noqa: C901
):
return selected
ent_reg = entity_registry.async_get(hass)
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
floor_reg = floor_registry.async_get(hass)
label_reg = label_registry.async_get(hass)
for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
if selector.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector.area_ids:
if area_id not in area_reg.areas:
@ -521,12 +521,20 @@ def async_extract_referenced_entity_ids( # noqa: C901
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
# Find areas, devices & entities for targeted labels
if selector.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find areas, devices & entities for targeted labels
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
selected.referenced_areas.add(area_entry.id)
@ -535,14 +543,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
if device_entry.labels.intersection(selector.label_ids):
selected.referenced_devices.add(device_entry.id)
for entity_entry in ent_reg.entities.values():
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
and entity_entry.labels.intersection(selector.label_ids)
):
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
@ -561,7 +561,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
if not selected.referenced_areas and not selected.referenced_devices:
return selected
entities = ent_reg.entities
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id

View file

@ -5,7 +5,7 @@ import json
import logging
import os
from types import ModuleType
from unittest.mock import Mock, call, patch
from unittest.mock import call, patch
import pytest
@ -25,6 +25,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import discovery
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -396,10 +397,16 @@ async def test_see_service_guard_config_entry(
mock_device_tracker_conf: list[legacy.Device],
) -> None:
"""Test the guard if the device is registered in the entity registry."""
mock_entry = Mock()
dev_id = "test"
entity_id = f"{const.DOMAIN}.{dev_id}"
mock_registry(hass, {entity_id: mock_entry})
mock_registry(
hass,
{
entity_id: RegistryEntry(
entity_id=entity_id, unique_id=1, platform=const.DOMAIN
)
},
)
devices = mock_device_tracker_conf
assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM)
await hass.async_block_till_done()