Add index for floor/label to the area registry (#114777)
This commit is contained in:
parent
aa52688d4b
commit
aedfd6c983
4 changed files with 53 additions and 13 deletions
|
@ -87,10 +87,49 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
return old_data
|
return old_data
|
||||||
|
|
||||||
|
|
||||||
|
class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
|
||||||
|
"""Class to hold area registry items."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the area registry items."""
|
||||||
|
super().__init__()
|
||||||
|
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
|
self._floors_index: dict[str, dict[str, Literal[True]]] = {}
|
||||||
|
|
||||||
|
def _index_entry(self, key: str, entry: AreaEntry) -> None:
|
||||||
|
"""Index an entry."""
|
||||||
|
if entry.floor_id is not None:
|
||||||
|
self._floors_index.setdefault(entry.floor_id, {})[key] = True
|
||||||
|
for label in entry.labels:
|
||||||
|
self._labels_index.setdefault(label, {})[key] = True
|
||||||
|
super()._index_entry(key, entry)
|
||||||
|
|
||||||
|
def _unindex_entry(
|
||||||
|
self, key: str, replacement_entry: AreaEntry | None = None
|
||||||
|
) -> None:
|
||||||
|
entry = self.data[key]
|
||||||
|
if labels := entry.labels:
|
||||||
|
for label in labels:
|
||||||
|
self._unindex_entry_value(key, label, self._labels_index)
|
||||||
|
if floor_id := entry.floor_id:
|
||||||
|
self._unindex_entry_value(key, floor_id, self._floors_index)
|
||||||
|
return super()._unindex_entry(key, replacement_entry)
|
||||||
|
|
||||||
|
def get_areas_for_label(self, label: str) -> list[AreaEntry]:
|
||||||
|
"""Get areas for label."""
|
||||||
|
data = self.data
|
||||||
|
return [data[key] for key in self._labels_index.get(label, ())]
|
||||||
|
|
||||||
|
def get_areas_for_floor(self, floor: str) -> list[AreaEntry]:
|
||||||
|
"""Get areas for floor."""
|
||||||
|
data = self.data
|
||||||
|
return [data[key] for key in self._floors_index.get(floor, ())]
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistry(BaseRegistry):
|
class AreaRegistry(BaseRegistry):
|
||||||
"""Class to hold a registry of areas."""
|
"""Class to hold a registry of areas."""
|
||||||
|
|
||||||
areas: NormalizedNameBaseRegistryItems[AreaEntry]
|
areas: AreaRegistryItems
|
||||||
_area_data: dict[str, AreaEntry]
|
_area_data: dict[str, AreaEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
@ -254,7 +293,7 @@ class AreaRegistry(BaseRegistry):
|
||||||
|
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
|
||||||
areas = NormalizedNameBaseRegistryItems[AreaEntry]()
|
areas = AreaRegistryItems()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for area in data["areas"]:
|
for area in data["areas"]:
|
||||||
|
@ -369,10 +408,10 @@ async def async_load(hass: HomeAssistant) -> None:
|
||||||
@callback
|
@callback
|
||||||
def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]:
|
def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]:
|
||||||
"""Return entries that match a floor."""
|
"""Return entries that match a floor."""
|
||||||
return [area for area in registry.areas.values() if floor_id == area.floor_id]
|
return registry.areas.get_areas_for_floor(floor_id)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
|
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
|
||||||
"""Return entries that match a label."""
|
"""Return entries that match a label."""
|
||||||
return [area for area in registry.areas.values() if label_id in area.labels]
|
return registry.areas.get_areas_for_label(label_id)
|
||||||
|
|
|
@ -537,16 +537,16 @@ def async_extract_referenced_entity_ids( # noqa: C901
|
||||||
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
|
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
|
||||||
selected.referenced_devices.add(device_entry.id)
|
selected.referenced_devices.add(device_entry.id)
|
||||||
|
|
||||||
# Find areas for targeted labels
|
for area_entry in area_reg.areas.get_areas_for_label(label_id):
|
||||||
for area_entry in area_reg.areas.values():
|
|
||||||
if area_entry.labels.intersection(selector.label_ids):
|
|
||||||
selected.referenced_areas.add(area_entry.id)
|
selected.referenced_areas.add(area_entry.id)
|
||||||
|
|
||||||
# Find areas for targeted floors
|
# Find areas for targeted floors
|
||||||
if selector.floor_ids:
|
if selector.floor_ids:
|
||||||
for area_entry in area_reg.areas.values():
|
selected.referenced_areas.update(
|
||||||
if area_entry.id and area_entry.floor_id in selector.floor_ids:
|
area_entry.id
|
||||||
selected.referenced_areas.add(area_entry.id)
|
for floor_id in selector.floor_ids
|
||||||
|
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
|
||||||
|
)
|
||||||
|
|
||||||
# Find devices for targeted areas
|
# Find devices for targeted areas
|
||||||
selected.referenced_devices.update(selector.device_ids)
|
selected.referenced_devices.update(selector.device_ids)
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
|
||||||
from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
|
from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
@ -649,7 +648,9 @@ def mock_area_registry(
|
||||||
fixture instead.
|
fixture instead.
|
||||||
"""
|
"""
|
||||||
registry = ar.AreaRegistry(hass)
|
registry = ar.AreaRegistry(hass)
|
||||||
registry.areas = mock_entries or OrderedDict()
|
registry.areas = ar.AreaRegistryItems()
|
||||||
|
for key, entry in mock_entries.items():
|
||||||
|
registry.areas[key] = entry
|
||||||
|
|
||||||
hass.data[ar.DATA_REGISTRY] = registry
|
hass.data[ar.DATA_REGISTRY] = registry
|
||||||
return registry
|
return registry
|
||||||
|
|
|
@ -823,7 +823,7 @@ async def test_empty_aliases(
|
||||||
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
||||||
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
||||||
area_kitchen = area_registry.async_update(
|
area_kitchen = area_registry.async_update(
|
||||||
area_kitchen.id, aliases={" "}, floor_id=floor_1
|
area_kitchen.id, aliases={" "}, floor_id=floor_1.floor_id
|
||||||
)
|
)
|
||||||
|
|
||||||
entry = MockConfigEntry()
|
entry = MockConfigEntry()
|
||||||
|
|
Loading…
Add table
Reference in a new issue