Improve performance of extracting entities by label (#114720)
This commit is contained in:
parent
3d8a110908
commit
e86fec310b
3 changed files with 46 additions and 27 deletions
|
@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
|
|||
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||
"""Container for entity registry items, maps entity_id -> entry.
|
||||
|
||||
Maintains four additional indexes:
|
||||
Maintains six additional indexes:
|
||||
- id -> entry
|
||||
- (domain, platform, unique_id) -> entity_id
|
||||
- config_entry_id -> list[key]
|
||||
- device_id -> list[key]
|
||||
- config_entry_id -> dict[key, True]
|
||||
- device_id -> dict[key, True]
|
||||
- area_id -> dict[key, True]
|
||||
- label -> dict[key, True]
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
|||
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
|
||||
|
||||
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
|
||||
"""Index an entry."""
|
||||
|
@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
|||
self._device_id_index.setdefault(device_id, {})[key] = True
|
||||
if (area_id := entry.area_id) is not None:
|
||||
self._area_id_index.setdefault(area_id, {})[key] = True
|
||||
for label in entry.labels:
|
||||
self._labels_index.setdefault(label, {})[key] = True
|
||||
|
||||
def _unindex_entry(
|
||||
self, key: str, replacement_entry: RegistryEntry | None = None
|
||||
|
@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
|||
self._unindex_entry_value(key, device_id, self._device_id_index)
|
||||
if area_id := entry.area_id:
|
||||
self._unindex_entry_value(key, area_id, self._area_id_index)
|
||||
if labels := entry.labels:
|
||||
for label in labels:
|
||||
self._unindex_entry_value(key, label, self._labels_index)
|
||||
|
||||
def get_device_ids(self) -> KeysView[str]:
|
||||
"""Return device ids."""
|
||||
|
@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
|||
data = self.data
|
||||
return [data[key] for key in self._area_id_index.get(area_id, ())]
|
||||
|
||||
def get_entries_for_label(self, label: str) -> list[RegistryEntry]:
|
||||
"""Get entries for label."""
|
||||
data = self.data
|
||||
return [data[key] for key in self._labels_index.get(label, ())]
|
||||
|
||||
|
||||
class EntityRegistry(BaseRegistry):
|
||||
"""Class to hold a registry of entities."""
|
||||
|
@ -1317,7 +1330,7 @@ def async_entries_for_label(
|
|||
registry: EntityRegistry, label_id: str
|
||||
) -> list[RegistryEntry]:
|
||||
"""Return entries that match a label."""
|
||||
return [entry for entry in registry.entities.values() if label_id in entry.labels]
|
||||
return registry.entities.get_entries_for_label(label_id)
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -503,15 +503,15 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
|||
):
|
||||
return selected
|
||||
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
entities = entity_registry.async_get(hass).entities
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
area_reg = area_registry.async_get(hass)
|
||||
floor_reg = floor_registry.async_get(hass)
|
||||
label_reg = label_registry.async_get(hass)
|
||||
|
||||
for floor_id in selector.floor_ids:
|
||||
if floor_id not in floor_reg.floors:
|
||||
selected.missing_floors.add(floor_id)
|
||||
if selector.floor_ids:
|
||||
floor_reg = floor_registry.async_get(hass)
|
||||
for floor_id in selector.floor_ids:
|
||||
if floor_id not in floor_reg.floors:
|
||||
selected.missing_floors.add(floor_id)
|
||||
|
||||
for area_id in selector.area_ids:
|
||||
if area_id not in area_reg.areas:
|
||||
|
@ -521,12 +521,20 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
|||
if device_id not in dev_reg.devices:
|
||||
selected.missing_devices.add(device_id)
|
||||
|
||||
for label_id in selector.label_ids:
|
||||
if label_id not in label_reg.labels:
|
||||
selected.missing_labels.add(label_id)
|
||||
|
||||
# Find areas, devices & entities for targeted labels
|
||||
if selector.label_ids:
|
||||
label_reg = label_registry.async_get(hass)
|
||||
for label_id in selector.label_ids:
|
||||
if label_id not in label_reg.labels:
|
||||
selected.missing_labels.add(label_id)
|
||||
|
||||
for entity_entry in entities.get_entries_for_label(label_id):
|
||||
if (
|
||||
entity_entry.entity_category is None
|
||||
and entity_entry.hidden_by is None
|
||||
):
|
||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||
|
||||
# Find areas, devices & entities for targeted labels
|
||||
for area_entry in area_reg.areas.values():
|
||||
if area_entry.labels.intersection(selector.label_ids):
|
||||
selected.referenced_areas.add(area_entry.id)
|
||||
|
@ -535,14 +543,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
|||
if device_entry.labels.intersection(selector.label_ids):
|
||||
selected.referenced_devices.add(device_entry.id)
|
||||
|
||||
for entity_entry in ent_reg.entities.values():
|
||||
if (
|
||||
entity_entry.entity_category is None
|
||||
and entity_entry.hidden_by is None
|
||||
and entity_entry.labels.intersection(selector.label_ids)
|
||||
):
|
||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||
|
||||
# Find areas for targeted floors
|
||||
if selector.floor_ids:
|
||||
for area_entry in area_reg.areas.values():
|
||||
|
@ -561,7 +561,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
|||
if not selected.referenced_areas and not selected.referenced_devices:
|
||||
return selected
|
||||
|
||||
entities = ent_reg.entities
|
||||
# Add indirectly referenced by area
|
||||
selected.indirectly_referenced.update(
|
||||
entry.entity_id
|
||||
|
|
|
@ -5,7 +5,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from types import ModuleType
|
||||
from unittest.mock import Mock, call, patch
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -25,6 +25,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import discovery
|
||||
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
@ -396,10 +397,16 @@ async def test_see_service_guard_config_entry(
|
|||
mock_device_tracker_conf: list[legacy.Device],
|
||||
) -> None:
|
||||
"""Test the guard if the device is registered in the entity registry."""
|
||||
mock_entry = Mock()
|
||||
dev_id = "test"
|
||||
entity_id = f"{const.DOMAIN}.{dev_id}"
|
||||
mock_registry(hass, {entity_id: mock_entry})
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
entity_id: RegistryEntry(
|
||||
entity_id=entity_id, unique_id=1, platform=const.DOMAIN
|
||||
)
|
||||
},
|
||||
)
|
||||
devices = mock_device_tracker_conf
|
||||
assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM)
|
||||
await hass.async_block_till_done()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue