Refactor storage collections to reduce tasks during startup (#111182)

* Make adding entities in storage collection a normal function

Nothing is awaited when adding

* cleanup

* cleanup

* cleanup

* cleanup

* reduce

* reduce

* reduce

* reduce

* tweak
This commit is contained in:
J. Nick Koston 2024-02-23 08:50:25 -10:00 committed by GitHub
parent b3a8a75e75
commit 5e16602595
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from dataclasses import dataclass
from functools import partial
from itertools import groupby
import logging
from operator import attrgetter
@ -151,7 +152,7 @@ class ObservableCollection(ABC, Generic[_ItemT]):
Will be called with (change_type, item_id, updated_config).
"""
self.listeners.append(listener)
return lambda: self.listeners.remove(listener)
return partial(self.listeners.remove, listener)
@callback
def async_add_change_set_listener(
@ -162,7 +163,7 @@ class ObservableCollection(ABC, Generic[_ItemT]):
Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)
return lambda: self.change_set_listeners.remove(listener)
return partial(self.change_set_listeners.remove, listener)
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
@ -418,6 +419,82 @@ class IDLessCollection(YamlCollection):
)
_GROUP_BY_KEY = attrgetter("change_type")
@dataclass(slots=True, frozen=True)
class _CollectionLifeCycle(Generic[_EntityT]):
"""Life cycle for a collection of entities."""
domain: str
platform: str
entity_component: EntityComponent[_EntityT]
collection: StorageCollection | YamlCollection
entity_class: type[CollectionEntity]
ent_reg: entity_registry.EntityRegistry
entities: dict[str, CollectionEntity]
@callback
def async_setup(self) -> None:
"""Set up the collection life cycle."""
self.collection.async_add_change_set_listener(self._collection_changed)
def _entity_removed(self, item_id: str) -> None:
"""Remove entity from entities if it's removed or not added."""
self.entities.pop(item_id, None)
@callback
def _add_entity(self, change_set: CollectionChangeSet) -> CollectionEntity:
item_id = change_set.item_id
entity = self.collection.create_entity(self.entity_class, change_set.item)
self.entities[item_id] = entity
entity.async_on_remove(partial(self._entity_removed, item_id))
return entity
async def _remove_entity(self, change_set: CollectionChangeSet) -> None:
item_id = change_set.item_id
ent_reg = self.ent_reg
entities = self.entities
ent_to_remove = ent_reg.async_get_entity_id(self.domain, self.platform, item_id)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
elif entity := entities.get(item_id):
await entity.async_remove(force_remove=True)
# Unconditionally pop the entity from the entity list to avoid racing against
# the entity registry event handled by Entity._async_registry_updated
entities.pop(item_id, None)
async def _update_entity(self, change_set: CollectionChangeSet) -> None:
if entity := self.entities.get(change_set.item_id):
await entity.async_update_config(change_set.item)
async def _collection_changed(
self, change_sets: Iterable[CollectionChangeSet]
) -> None:
"""Handle a collection change."""
# Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group
# the same change type.
new_entities: list[CollectionEntity] = []
coros: list[Coroutine[Any, Any, CollectionEntity | None]] = []
grouped: Iterable[CollectionChangeSet]
for _, grouped in groupby(change_sets, _GROUP_BY_KEY):
for change_set in grouped:
change_type = change_set.change_type
if change_type == CHANGE_ADDED:
new_entities.append(self._add_entity(change_set))
elif change_type == CHANGE_REMOVED:
coros.append(self._remove_entity(change_set))
elif change_type == CHANGE_UPDATED:
coros.append(self._update_entity(change_set))
if coros:
await asyncio.gather(*coros)
if new_entities:
await self.entity_component.async_add_entities(new_entities)
@callback
def sync_entity_lifecycle(
hass: HomeAssistant,
@ -428,69 +505,10 @@ def sync_entity_lifecycle(
entity_class: type[CollectionEntity],
) -> None:
"""Map a collection to an entity component."""
entities: dict[str, CollectionEntity] = {}
ent_reg = entity_registry.async_get(hass)
async def _add_entity(change_set: CollectionChangeSet) -> CollectionEntity:
def entity_removed() -> None:
"""Remove entity from entities if it's removed or not added."""
if change_set.item_id in entities:
entities.pop(change_set.item_id)
entities[change_set.item_id] = collection.create_entity(
entity_class, change_set.item
)
entities[change_set.item_id].async_on_remove(entity_removed)
return entities[change_set.item_id]
async def _remove_entity(change_set: CollectionChangeSet) -> None:
ent_to_remove = ent_reg.async_get_entity_id(
domain, platform, change_set.item_id
)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
elif change_set.item_id in entities:
await entities[change_set.item_id].async_remove(force_remove=True)
# Unconditionally pop the entity from the entity list to avoid racing against
# the entity registry event handled by Entity._async_registry_updated
if change_set.item_id in entities:
entities.pop(change_set.item_id)
async def _update_entity(change_set: CollectionChangeSet) -> None:
if change_set.item_id not in entities:
return
await entities[change_set.item_id].async_update_config(change_set.item)
_func_map: dict[
str,
Callable[[CollectionChangeSet], Coroutine[Any, Any, CollectionEntity | None]],
] = {
CHANGE_ADDED: _add_entity,
CHANGE_REMOVED: _remove_entity,
CHANGE_UPDATED: _update_entity,
}
async def _collection_changed(change_sets: Iterable[CollectionChangeSet]) -> None:
"""Handle a collection change."""
# Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group
# the same change type.
groupby_key = attrgetter("change_type")
for _, grouped in groupby(change_sets, groupby_key):
new_entities = [
entity
for entity in await asyncio.gather(
*(
_func_map[change_set.change_type](change_set)
for change_set in grouped
)
)
if entity is not None
]
if new_entities:
await entity_component.async_add_entities(new_entities)
collection.async_add_change_set_listener(_collection_changed)
_CollectionLifeCycle(
domain, platform, entity_component, collection, entity_class, ent_reg, {}
).async_setup()
class StorageCollectionWebsocket(Generic[_StorageCollectionT]):