Add index for floor/label to the area registry (#114777)

This commit is contained in:
J. Nick Koston 2024-04-03 21:04:26 -10:00 committed by GitHub
parent aa52688d4b
commit aedfd6c983
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 13 deletions

View file

@ -87,10 +87,49 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
return old_data
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
"""Class to hold area registry items."""
def __init__(self) -> None:
"""Initialize the area registry items."""
super().__init__()
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
self._floors_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: AreaEntry) -> None:
"""Index an entry."""
if entry.floor_id is not None:
self._floors_index.setdefault(entry.floor_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
super()._index_entry(key, entry)
def _unindex_entry(
self, key: str, replacement_entry: AreaEntry | None = None
) -> None:
entry = self.data[key]
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
if floor_id := entry.floor_id:
self._unindex_entry_value(key, floor_id, self._floors_index)
return super()._unindex_entry(key, replacement_entry)
def get_areas_for_label(self, label: str) -> list[AreaEntry]:
"""Get areas for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
def get_areas_for_floor(self, floor: str) -> list[AreaEntry]:
"""Get areas for floor."""
data = self.data
return [data[key] for key in self._floors_index.get(floor, ())]
class AreaRegistry(BaseRegistry):
"""Class to hold a registry of areas."""
areas: NormalizedNameBaseRegistryItems[AreaEntry]
areas: AreaRegistryItems
_area_data: dict[str, AreaEntry]
def __init__(self, hass: HomeAssistant) -> None:
@ -254,7 +293,7 @@ class AreaRegistry(BaseRegistry):
data = await self._store.async_load()
areas = NormalizedNameBaseRegistryItems[AreaEntry]()
areas = AreaRegistryItems()
if data is not None:
for area in data["areas"]:
@ -369,10 +408,10 @@ async def async_load(hass: HomeAssistant) -> None:
@callback
def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]:
"""Return entries that match a floor."""
return [area for area in registry.areas.values() if floor_id == area.floor_id]
return registry.areas.get_areas_for_floor(floor_id)
@callback
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
"""Return entries that match a label."""
return [area for area in registry.areas.values() if label_id in area.labels]
return registry.areas.get_areas_for_label(label_id)

View file

@ -537,16 +537,16 @@ def async_extract_referenced_entity_ids( # noqa: C901
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
# Find areas for targeted labels
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
if area_entry.id and area_entry.floor_id in selector.floor_ids:
selected.referenced_areas.add(area_entry.id)
selected.referenced_areas.update(
area_entry.id
for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
# Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids)

View file

@ -3,7 +3,6 @@
from __future__ import annotations
import asyncio
from collections import OrderedDict
from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
from contextlib import asynccontextmanager, contextmanager
from datetime import UTC, datetime, timedelta
@ -649,7 +648,9 @@ def mock_area_registry(
fixture instead.
"""
registry = ar.AreaRegistry(hass)
registry.areas = mock_entries or OrderedDict()
registry.areas = ar.AreaRegistryItems()
for key, entry in mock_entries.items():
registry.areas[key] = entry
hass.data[ar.DATA_REGISTRY] = registry
return registry

View file

@ -823,7 +823,7 @@ async def test_empty_aliases(
area_kitchen = area_registry.async_get_or_create("kitchen_id")
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
area_kitchen = area_registry.async_update(
area_kitchen.id, aliases={" "}, floor_id=floor_1
area_kitchen.id, aliases={" "}, floor_id=floor_1.floor_id
)
entry = MockConfigEntry()