Improve performance of extracting entities by label (#114720)
This commit is contained in:
parent
3d8a110908
commit
e86fec310b
3 changed files with 46 additions and 27 deletions
|
@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
|
||||||
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||||
"""Container for entity registry items, maps entity_id -> entry.
|
"""Container for entity registry items, maps entity_id -> entry.
|
||||||
|
|
||||||
Maintains four additional indexes:
|
Maintains six additional indexes:
|
||||||
- id -> entry
|
- id -> entry
|
||||||
- (domain, platform, unique_id) -> entity_id
|
- (domain, platform, unique_id) -> entity_id
|
||||||
- config_entry_id -> list[key]
|
- config_entry_id -> dict[key, True]
|
||||||
- device_id -> list[key]
|
- device_id -> dict[key, True]
|
||||||
|
- area_id -> dict[key, True]
|
||||||
|
- label -> dict[key, True]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||||
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
|
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
|
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
|
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
|
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
|
|
||||||
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
|
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
|
||||||
"""Index an entry."""
|
"""Index an entry."""
|
||||||
|
@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||||
self._device_id_index.setdefault(device_id, {})[key] = True
|
self._device_id_index.setdefault(device_id, {})[key] = True
|
||||||
if (area_id := entry.area_id) is not None:
|
if (area_id := entry.area_id) is not None:
|
||||||
self._area_id_index.setdefault(area_id, {})[key] = True
|
self._area_id_index.setdefault(area_id, {})[key] = True
|
||||||
|
for label in entry.labels:
|
||||||
|
self._labels_index.setdefault(label, {})[key] = True
|
||||||
|
|
||||||
def _unindex_entry(
|
def _unindex_entry(
|
||||||
self, key: str, replacement_entry: RegistryEntry | None = None
|
self, key: str, replacement_entry: RegistryEntry | None = None
|
||||||
|
@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||||
self._unindex_entry_value(key, device_id, self._device_id_index)
|
self._unindex_entry_value(key, device_id, self._device_id_index)
|
||||||
if area_id := entry.area_id:
|
if area_id := entry.area_id:
|
||||||
self._unindex_entry_value(key, area_id, self._area_id_index)
|
self._unindex_entry_value(key, area_id, self._area_id_index)
|
||||||
|
if labels := entry.labels:
|
||||||
|
for label in labels:
|
||||||
|
self._unindex_entry_value(key, label, self._labels_index)
|
||||||
|
|
||||||
def get_device_ids(self) -> KeysView[str]:
|
def get_device_ids(self) -> KeysView[str]:
|
||||||
"""Return device ids."""
|
"""Return device ids."""
|
||||||
|
@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
|
||||||
data = self.data
|
data = self.data
|
||||||
return [data[key] for key in self._area_id_index.get(area_id, ())]
|
return [data[key] for key in self._area_id_index.get(area_id, ())]
|
||||||
|
|
||||||
|
def get_entries_for_label(self, label: str) -> list[RegistryEntry]:
|
||||||
|
"""Get entries for label."""
|
||||||
|
data = self.data
|
||||||
|
return [data[key] for key in self._labels_index.get(label, ())]
|
||||||
|
|
||||||
|
|
||||||
class EntityRegistry(BaseRegistry):
|
class EntityRegistry(BaseRegistry):
|
||||||
"""Class to hold a registry of entities."""
|
"""Class to hold a registry of entities."""
|
||||||
|
@ -1317,7 +1330,7 @@ def async_entries_for_label(
|
||||||
registry: EntityRegistry, label_id: str
|
registry: EntityRegistry, label_id: str
|
||||||
) -> list[RegistryEntry]:
|
) -> list[RegistryEntry]:
|
||||||
"""Return entries that match a label."""
|
"""Return entries that match a label."""
|
||||||
return [entry for entry in registry.entities.values() if label_id in entry.labels]
|
return registry.entities.get_entries_for_label(label_id)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -503,15 +503,15 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
||||||
):
|
):
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
ent_reg = entity_registry.async_get(hass)
|
entities = entity_registry.async_get(hass).entities
|
||||||
dev_reg = device_registry.async_get(hass)
|
dev_reg = device_registry.async_get(hass)
|
||||||
area_reg = area_registry.async_get(hass)
|
area_reg = area_registry.async_get(hass)
|
||||||
floor_reg = floor_registry.async_get(hass)
|
|
||||||
label_reg = label_registry.async_get(hass)
|
|
||||||
|
|
||||||
for floor_id in selector.floor_ids:
|
if selector.floor_ids:
|
||||||
if floor_id not in floor_reg.floors:
|
floor_reg = floor_registry.async_get(hass)
|
||||||
selected.missing_floors.add(floor_id)
|
for floor_id in selector.floor_ids:
|
||||||
|
if floor_id not in floor_reg.floors:
|
||||||
|
selected.missing_floors.add(floor_id)
|
||||||
|
|
||||||
for area_id in selector.area_ids:
|
for area_id in selector.area_ids:
|
||||||
if area_id not in area_reg.areas:
|
if area_id not in area_reg.areas:
|
||||||
|
@ -521,12 +521,20 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
||||||
if device_id not in dev_reg.devices:
|
if device_id not in dev_reg.devices:
|
||||||
selected.missing_devices.add(device_id)
|
selected.missing_devices.add(device_id)
|
||||||
|
|
||||||
for label_id in selector.label_ids:
|
|
||||||
if label_id not in label_reg.labels:
|
|
||||||
selected.missing_labels.add(label_id)
|
|
||||||
|
|
||||||
# Find areas, devices & entities for targeted labels
|
|
||||||
if selector.label_ids:
|
if selector.label_ids:
|
||||||
|
label_reg = label_registry.async_get(hass)
|
||||||
|
for label_id in selector.label_ids:
|
||||||
|
if label_id not in label_reg.labels:
|
||||||
|
selected.missing_labels.add(label_id)
|
||||||
|
|
||||||
|
for entity_entry in entities.get_entries_for_label(label_id):
|
||||||
|
if (
|
||||||
|
entity_entry.entity_category is None
|
||||||
|
and entity_entry.hidden_by is None
|
||||||
|
):
|
||||||
|
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||||
|
|
||||||
|
# Find areas, devices & entities for targeted labels
|
||||||
for area_entry in area_reg.areas.values():
|
for area_entry in area_reg.areas.values():
|
||||||
if area_entry.labels.intersection(selector.label_ids):
|
if area_entry.labels.intersection(selector.label_ids):
|
||||||
selected.referenced_areas.add(area_entry.id)
|
selected.referenced_areas.add(area_entry.id)
|
||||||
|
@ -535,14 +543,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
||||||
if device_entry.labels.intersection(selector.label_ids):
|
if device_entry.labels.intersection(selector.label_ids):
|
||||||
selected.referenced_devices.add(device_entry.id)
|
selected.referenced_devices.add(device_entry.id)
|
||||||
|
|
||||||
for entity_entry in ent_reg.entities.values():
|
|
||||||
if (
|
|
||||||
entity_entry.entity_category is None
|
|
||||||
and entity_entry.hidden_by is None
|
|
||||||
and entity_entry.labels.intersection(selector.label_ids)
|
|
||||||
):
|
|
||||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
|
||||||
|
|
||||||
# Find areas for targeted floors
|
# Find areas for targeted floors
|
||||||
if selector.floor_ids:
|
if selector.floor_ids:
|
||||||
for area_entry in area_reg.areas.values():
|
for area_entry in area_reg.areas.values():
|
||||||
|
@ -561,7 +561,6 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
||||||
if not selected.referenced_areas and not selected.referenced_devices:
|
if not selected.referenced_areas and not selected.referenced_devices:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
entities = ent_reg.entities
|
|
||||||
# Add indirectly referenced by area
|
# Add indirectly referenced by area
|
||||||
selected.indirectly_referenced.update(
|
selected.indirectly_referenced.update(
|
||||||
entry.entity_id
|
entry.entity_id
|
||||||
|
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from unittest.mock import Mock, call, patch
|
from unittest.mock import call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from homeassistant.const import (
|
||||||
from homeassistant.core import HomeAssistant, State, callback
|
from homeassistant.core import HomeAssistant, State, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import discovery
|
from homeassistant.helpers import discovery
|
||||||
|
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||||
from homeassistant.helpers.json import JSONEncoder
|
from homeassistant.helpers.json import JSONEncoder
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
@ -396,10 +397,16 @@ async def test_see_service_guard_config_entry(
|
||||||
mock_device_tracker_conf: list[legacy.Device],
|
mock_device_tracker_conf: list[legacy.Device],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the guard if the device is registered in the entity registry."""
|
"""Test the guard if the device is registered in the entity registry."""
|
||||||
mock_entry = Mock()
|
|
||||||
dev_id = "test"
|
dev_id = "test"
|
||||||
entity_id = f"{const.DOMAIN}.{dev_id}"
|
entity_id = f"{const.DOMAIN}.{dev_id}"
|
||||||
mock_registry(hass, {entity_id: mock_entry})
|
mock_registry(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
entity_id: RegistryEntry(
|
||||||
|
entity_id=entity_id, unique_id=1, platform=const.DOMAIN
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
devices = mock_device_tracker_conf
|
devices = mock_device_tracker_conf
|
||||||
assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM)
|
assert await async_setup_component(hass, device_tracker.DOMAIN, TEST_PLATFORM)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue