Add index for floor/label to the area registry (#114777)

This commit is contained in:
J. Nick Koston 2024-04-03 21:04:26 -10:00 committed by GitHub
parent aa52688d4b
commit aedfd6c983
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 13 deletions

View file

@ -87,10 +87,49 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
return old_data return old_data
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
"""Class to hold area registry items."""
def __init__(self) -> None:
"""Initialize the area registry items."""
super().__init__()
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
self._floors_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: AreaEntry) -> None:
"""Index an entry."""
if entry.floor_id is not None:
self._floors_index.setdefault(entry.floor_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
super()._index_entry(key, entry)
def _unindex_entry(
self, key: str, replacement_entry: AreaEntry | None = None
) -> None:
entry = self.data[key]
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
if floor_id := entry.floor_id:
self._unindex_entry_value(key, floor_id, self._floors_index)
return super()._unindex_entry(key, replacement_entry)
def get_areas_for_label(self, label: str) -> list[AreaEntry]:
"""Get areas for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
def get_areas_for_floor(self, floor: str) -> list[AreaEntry]:
"""Get areas for floor."""
data = self.data
return [data[key] for key in self._floors_index.get(floor, ())]
class AreaRegistry(BaseRegistry): class AreaRegistry(BaseRegistry):
"""Class to hold a registry of areas.""" """Class to hold a registry of areas."""
areas: NormalizedNameBaseRegistryItems[AreaEntry] areas: AreaRegistryItems
_area_data: dict[str, AreaEntry] _area_data: dict[str, AreaEntry]
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -254,7 +293,7 @@ class AreaRegistry(BaseRegistry):
data = await self._store.async_load() data = await self._store.async_load()
areas = NormalizedNameBaseRegistryItems[AreaEntry]() areas = AreaRegistryItems()
if data is not None: if data is not None:
for area in data["areas"]: for area in data["areas"]:
@ -369,10 +408,10 @@ async def async_load(hass: HomeAssistant) -> None:
@callback @callback
def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]: def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]:
"""Return entries that match a floor.""" """Return entries that match a floor."""
return [area for area in registry.areas.values() if floor_id == area.floor_id] return registry.areas.get_areas_for_floor(floor_id)
@callback @callback
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]: def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
"""Return entries that match a label.""" """Return entries that match a label."""
return [area for area in registry.areas.values() if label_id in area.labels] return registry.areas.get_areas_for_label(label_id)

View file

@ -537,16 +537,16 @@ def async_extract_referenced_entity_ids( # noqa: C901
for device_entry in dev_reg.devices.get_devices_for_label(label_id): for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id) selected.referenced_devices.add(device_entry.id)
# Find areas for targeted labels for area_entry in area_reg.areas.get_areas_for_label(label_id):
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
selected.referenced_areas.add(area_entry.id) selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors # Find areas for targeted floors
if selector.floor_ids: if selector.floor_ids:
for area_entry in area_reg.areas.values(): selected.referenced_areas.update(
if area_entry.id and area_entry.floor_id in selector.floor_ids: area_entry.id
selected.referenced_areas.add(area_entry.id) for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
# Find devices for targeted areas # Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids) selected.referenced_devices.update(selector.device_ids)

View file

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict
from collections.abc import AsyncGenerator, Generator, Mapping, Sequence from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@ -649,7 +648,9 @@ def mock_area_registry(
fixture instead. fixture instead.
""" """
registry = ar.AreaRegistry(hass) registry = ar.AreaRegistry(hass)
registry.areas = mock_entries or OrderedDict() registry.areas = ar.AreaRegistryItems()
for key, entry in mock_entries.items():
registry.areas[key] = entry
hass.data[ar.DATA_REGISTRY] = registry hass.data[ar.DATA_REGISTRY] = registry
return registry return registry

View file

@ -823,7 +823,7 @@ async def test_empty_aliases(
area_kitchen = area_registry.async_get_or_create("kitchen_id") area_kitchen = area_registry.async_get_or_create("kitchen_id")
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
area_kitchen = area_registry.async_update( area_kitchen = area_registry.async_update(
area_kitchen.id, aliases={" "}, floor_id=floor_1 area_kitchen.id, aliases={" "}, floor_id=floor_1.floor_id
) )
entry = MockConfigEntry() entry = MockConfigEntry()