From 5e166025953b02c7c75b888d602eef91081a6a91 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 Feb 2024 08:50:25 -1000 Subject: [PATCH] Refactor storage collections to reduce tasks during startup (#111182) * Make adding entities in storage collection a normal function Nothing is awaited when adding * cleanup * cleanup * cleanup * cleanup * reduce * reduce * reduce * reduce * tweak --- homeassistant/helpers/collection.py | 146 ++++++++++++++++------------ 1 file changed, 82 insertions(+), 64 deletions(-) diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 80b40cf4fa0..c3c2ae4ec37 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable, Coroutine, Iterable from dataclasses import dataclass +from functools import partial from itertools import groupby import logging from operator import attrgetter @@ -151,7 +152,7 @@ class ObservableCollection(ABC, Generic[_ItemT]): Will be called with (change_type, item_id, updated_config). """ self.listeners.append(listener) - return lambda: self.listeners.remove(listener) + return partial(self.listeners.remove, listener) @callback def async_add_change_set_listener( @@ -162,7 +163,7 @@ class ObservableCollection(ABC, Generic[_ItemT]): Will be called with [(change_type, item_id, updated_config), ...] """ self.change_set_listeners.append(listener) - return lambda: self.change_set_listeners.remove(listener) + return partial(self.change_set_listeners.remove, listener) async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: """Notify listeners of a change.""" @@ -418,6 +419,82 @@ class IDLessCollection(YamlCollection): ) +_GROUP_BY_KEY = attrgetter("change_type") + + +@dataclass(slots=True, frozen=True) +class _CollectionLifeCycle(Generic[_EntityT]): + """Life cycle for a collection of entities.""" + + domain: str + platform: str + entity_component: EntityComponent[_EntityT] + collection: StorageCollection | YamlCollection + entity_class: type[CollectionEntity] + ent_reg: entity_registry.EntityRegistry + entities: dict[str, CollectionEntity] + + @callback + def async_setup(self) -> None: + """Set up the collection life cycle.""" + self.collection.async_add_change_set_listener(self._collection_changed) + + def _entity_removed(self, item_id: str) -> None: + """Remove entity from entities if it's removed or not added.""" + self.entities.pop(item_id, None) + + @callback + def _add_entity(self, change_set: CollectionChangeSet) -> CollectionEntity: + item_id = change_set.item_id + entity = self.collection.create_entity(self.entity_class, change_set.item) + self.entities[item_id] = entity + entity.async_on_remove(partial(self._entity_removed, item_id)) + return entity + + async def _remove_entity(self, change_set: CollectionChangeSet) -> None: + item_id = change_set.item_id + ent_reg = self.ent_reg + entities = self.entities + ent_to_remove = ent_reg.async_get_entity_id(self.domain, self.platform, item_id) + if ent_to_remove is not None: + ent_reg.async_remove(ent_to_remove) + elif entity := entities.get(item_id): + await entity.async_remove(force_remove=True) + # Unconditionally pop the entity from the entity list to avoid racing against + # the entity registry event handled by Entity._async_registry_updated + entities.pop(item_id, None) + + async def _update_entity(self, change_set: CollectionChangeSet) -> None: + if entity := self.entities.get(change_set.item_id): + await entity.async_update_config(change_set.item) + + async def _collection_changed( + self, change_sets: Iterable[CollectionChangeSet] + ) -> None: + """Handle a collection change.""" + # Create a new bucket every time we have a different change type + # to ensure operations happen in order. We only group + # the same change type. + new_entities: list[CollectionEntity] = [] + coros: list[Coroutine[Any, Any, CollectionEntity | None]] = [] + grouped: Iterable[CollectionChangeSet] + for _, grouped in groupby(change_sets, _GROUP_BY_KEY): + for change_set in grouped: + change_type = change_set.change_type + if change_type == CHANGE_ADDED: + new_entities.append(self._add_entity(change_set)) + elif change_type == CHANGE_REMOVED: + coros.append(self._remove_entity(change_set)) + elif change_type == CHANGE_UPDATED: + coros.append(self._update_entity(change_set)) + + if coros: + await asyncio.gather(*coros) + + if new_entities: + await self.entity_component.async_add_entities(new_entities) + + @callback def sync_entity_lifecycle( hass: HomeAssistant, @@ -428,69 +505,10 @@ def sync_entity_lifecycle( entity_class: type[CollectionEntity], ) -> None: """Map a collection to an entity component.""" - entities: dict[str, CollectionEntity] = {} ent_reg = entity_registry.async_get(hass) - - async def _add_entity(change_set: CollectionChangeSet) -> CollectionEntity: - def entity_removed() -> None: - """Remove entity from entities if it's removed or not added.""" - if change_set.item_id in entities: - entities.pop(change_set.item_id) - - entities[change_set.item_id] = collection.create_entity( - entity_class, change_set.item - ) - entities[change_set.item_id].async_on_remove(entity_removed) - return entities[change_set.item_id] - - async def _remove_entity(change_set: CollectionChangeSet) -> None: - ent_to_remove = ent_reg.async_get_entity_id( - domain, platform, change_set.item_id - ) - if ent_to_remove is not None: - ent_reg.async_remove(ent_to_remove) - elif change_set.item_id in entities: - await entities[change_set.item_id].async_remove(force_remove=True) - # Unconditionally pop the entity from the entity list to avoid racing against - # the entity registry event handled by Entity._async_registry_updated - if change_set.item_id in entities: - entities.pop(change_set.item_id) - - async def _update_entity(change_set: CollectionChangeSet) -> None: - if change_set.item_id not in entities: - return - await entities[change_set.item_id].async_update_config(change_set.item) - - _func_map: dict[ - str, - Callable[[CollectionChangeSet], Coroutine[Any, Any, CollectionEntity | None]], - ] = { - CHANGE_ADDED: _add_entity, - CHANGE_REMOVED: _remove_entity, - CHANGE_UPDATED: _update_entity, - } - - async def _collection_changed(change_sets: Iterable[CollectionChangeSet]) -> None: - """Handle a collection change.""" - # Create a new bucket every time we have a different change type - # to ensure operations happen in order. We only group - # the same change type. - groupby_key = attrgetter("change_type") - for _, grouped in groupby(change_sets, groupby_key): - new_entities = [ - entity - for entity in await asyncio.gather( - *( - _func_map[change_set.change_type](change_set) - for change_set in grouped - ) - ) - if entity is not None - ] - if new_entities: - await entity_component.async_add_entities(new_entities) - - collection.async_add_change_set_listener(_collection_changed) + _CollectionLifeCycle( + domain, platform, entity_component, collection, entity_class, ent_reg, {} + ).async_setup() class StorageCollectionWebsocket(Generic[_StorageCollectionT]):