Parallelize collections helper (#34783)

This commit is contained in:
Paulus Schoutsen 2020-04-28 14:31:16 -07:00 committed by GitHub
parent 893f796df2
commit 28f6e79385
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,5 +1,6 @@
"""Helper to deal with YAML + storage.""" """Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
@ -107,8 +108,9 @@ class ObservableCollection(ABC):
async def notify_change(self, change_type: str, item_id: str, item: dict) -> None: async def notify_change(self, change_type: str, item_id: str, item: dict) -> None:
"""Notify listeners of a change.""" """Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item) self.logger.debug("%s %s: %s", change_type, item_id, item)
for listener in self.listeners: await asyncio.gather(
await listener(change_type, item_id, item) *[listener(change_type, item_id, item) for listener in self.listeners]
)
class YamlCollection(ObservableCollection): class YamlCollection(ObservableCollection):
@ -118,6 +120,8 @@ class YamlCollection(ObservableCollection):
"""Load the YAML collection. Overrides existing data.""" """Load the YAML collection. Overrides existing data."""
old_ids = set(self.data) old_ids = set(self.data)
tasks = []
for item in data: for item in data:
item_id = item[CONF_ID] item_id = item[CONF_ID]
@ -131,11 +135,15 @@ class YamlCollection(ObservableCollection):
event = CHANGE_ADDED event = CHANGE_ADDED
self.data[item_id] = item self.data[item_id] = item
await self.notify_change(event, item_id, item) tasks.append(self.notify_change(event, item_id, item))
for item_id in old_ids: for item_id in old_ids:
tasks.append(
self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id))
)
await self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id)) if tasks:
await asyncio.gather(*tasks)
class StorageCollection(ObservableCollection): class StorageCollection(ObservableCollection):
@ -169,7 +177,13 @@ class StorageCollection(ObservableCollection):
for item in raw_storage["items"]: for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item self.data[item[CONF_ID]] = item
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"]
]
)
@abstractmethod @abstractmethod
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict) -> dict:
@ -240,8 +254,12 @@ class IDLessCollection(ObservableCollection):
async def async_load(self, data: List[dict]) -> None: async def async_load(self, data: List[dict]) -> None:
"""Load the collection. Overrides existing data.""" """Load the collection. Overrides existing data."""
for item_id, item in list(self.data.items()): await asyncio.gather(
await self.notify_change(CHANGE_REMOVED, item_id, item) *[
self.notify_change(CHANGE_REMOVED, item_id, item)
for item_id, item in list(self.data.items())
]
)
self.data.clear() self.data.clear()
@ -250,7 +268,13 @@ class IDLessCollection(ObservableCollection):
item_id = f"fakeid-{self.counter}" item_id = f"fakeid-{self.counter}"
self.data[item_id] = item self.data[item_id] = item
await self.notify_change(CHANGE_ADDED, item_id, item)
await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item_id, item)
for item_id, item in self.data.items()
]
)
@callback @callback