diff --git a/homeassistant/components/application_credentials/__init__.py b/homeassistant/components/application_credentials/__init__.py index f1471f29666..2085e9a6f37 100644 --- a/homeassistant/components/application_credentials/__init__.py +++ b/homeassistant/components/application_credentials/__init__.py @@ -20,7 +20,6 @@ from homeassistant.const import ( CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN, - CONF_ID, CONF_NAME, ) from homeassistant.core import HomeAssistant, callback @@ -125,10 +124,10 @@ class ApplicationCredentialsStorageCollection(collection.DictStorageCollection): def async_client_credentials(self, domain: str) -> dict[str, ClientCredential]: """Return ClientCredentials in storage for the specified domain.""" credentials = {} - for item in self.async_items(): + for item_id, item in self.data.items(): if item[CONF_DOMAIN] != domain: continue - auth_domain = item.get(CONF_AUTH_DOMAIN, item[CONF_ID]) + auth_domain = item.get(CONF_AUTH_DOMAIN, item_id) credentials[auth_domain] = ClientCredential( client_id=item[CONF_CLIENT_ID], client_secret=item[CONF_CLIENT_SECRET], @@ -244,8 +243,7 @@ async def _async_config_entry_app_credentials( return None storage_collection = hass.data[DOMAIN][DATA_STORAGE] - for item in storage_collection.async_items(): - item_id = item[CONF_ID] + for item_id, item in storage_collection.async_items().items(): if ( item[CONF_DOMAIN] == config_entry.domain and item.get(CONF_AUTH_DOMAIN, item_id) == auth_domain diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 85bd92563db..d18c531dcd2 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -63,7 +63,7 @@ async def async_get_pipeline( # Construct a pipeline for the required/configured language language = language or hass.config.language - return await pipeline_data.pipeline_store.async_create_item( + _, pipeline = await pipeline_data.pipeline_store.async_create_item( { "name": language, "language": language, @@ -72,6 +72,7 @@ async def async_get_pipeline( "tts_engine": None, # first engine } ) + return pipeline class PipelineEventType(StrEnum): @@ -610,7 +611,7 @@ class PipelineStorageCollection( """Create an item from its serialized representation.""" return Pipeline(**data) - def _serialize_item(self, item_id: str, item: Pipeline) -> dict: + def _serialize_item(self, item: Pipeline) -> dict: """Return the serialized representation of an item for storing.""" return item.to_json() diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 768491f6085..471d8bea3b6 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -139,7 +139,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class CounterStorageCollection(collection.DictStorageCollection): +class CounterStorageCollection(collection.LegacyDictStorageCollection): """Input storage based collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) diff --git a/homeassistant/components/image_upload/__init__.py b/homeassistant/components/image_upload/__init__.py index 17c40cfc875..5f66412a303 100644 --- a/homeassistant/components/image_upload/__init__.py +++ b/homeassistant/components/image_upload/__init__.py @@ -57,7 +57,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class ImageStorageCollection(collection.DictStorageCollection): +class ImageStorageCollection(collection.LegacyDictStorageCollection): """Image collection stored in storage.""" CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) @@ -156,7 +156,7 @@ class ImageUploadView(HomeAssistantView): request._client_max_size = MAX_SIZE # pylint: disable=protected-access data = await request.post() - item = await request.app["hass"].data[DOMAIN].async_create_item(data) + _, item = await request.app["hass"].data[DOMAIN].async_create_item(data) return self.json(item) diff --git a/homeassistant/components/input_boolean/__init__.py b/homeassistant/components/input_boolean/__init__.py index 33cb4b9e576..1bbf08f029e 100644 --- a/homeassistant/components/input_boolean/__init__.py +++ b/homeassistant/components/input_boolean/__init__.py @@ -65,7 +65,7 @@ STORAGE_KEY = DOMAIN STORAGE_VERSION = 1 -class InputBooleanStorageCollection(collection.DictStorageCollection): +class InputBooleanStorageCollection(collection.LegacyDictStorageCollection): """Input boolean collection stored in storage.""" CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) diff --git a/homeassistant/components/input_button/__init__.py b/homeassistant/components/input_button/__init__.py index 8a1f0785435..fb3752ae654 100644 --- a/homeassistant/components/input_button/__init__.py +++ b/homeassistant/components/input_button/__init__.py @@ -56,7 +56,7 @@ STORAGE_KEY = DOMAIN STORAGE_VERSION = 1 -class InputButtonStorageCollection(collection.DictStorageCollection): +class InputButtonStorageCollection(collection.LegacyDictStorageCollection): """Input button collection stored in storage.""" CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) diff --git a/homeassistant/components/input_datetime/__init__.py b/homeassistant/components/input_datetime/__init__.py index c51c0fdd67c..414a5bb196d 100644 --- a/homeassistant/components/input_datetime/__init__.py +++ b/homeassistant/components/input_datetime/__init__.py @@ -203,7 +203,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class DateTimeStorageCollection(collection.DictStorageCollection): +class DateTimeStorageCollection(collection.LegacyDictStorageCollection): """Input storage based collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, has_date_or_time)) diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index 061b388ace5..8ecd92824d4 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -170,7 +170,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class NumberStorageCollection(collection.DictStorageCollection): +class NumberStorageCollection(collection.LegacyDictStorageCollection): """Input storage based collection.""" SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number)) @@ -184,7 +184,9 @@ class NumberStorageCollection(collection.DictStorageCollection): """Suggest an ID based on the config.""" return info[CONF_NAME] - async def _async_load_data(self) -> collection.SerializedStorageCollection | None: + async def _async_load_data( # type: ignore[override] + self, + ) -> collection.LegacySerializedStorageCollection | None: """Load the data. A past bug caused frontend to add initial value to all input numbers. diff --git a/homeassistant/components/input_select/__init__.py b/homeassistant/components/input_select/__init__.py index 186ab84fb81..19a62b09306 100644 --- a/homeassistant/components/input_select/__init__.py +++ b/homeassistant/components/input_select/__init__.py @@ -231,7 +231,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class InputSelectStorageCollection(collection.DictStorageCollection): +class InputSelectStorageCollection(collection.LegacyDictStorageCollection): """Input storage based collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_select)) diff --git a/homeassistant/components/input_text/__init__.py b/homeassistant/components/input_text/__init__.py index efd58e38e72..ed86e6d7ff3 100644 --- a/homeassistant/components/input_text/__init__.py +++ b/homeassistant/components/input_text/__init__.py @@ -164,7 +164,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class InputTextStorageCollection(collection.DictStorageCollection): +class InputTextStorageCollection(collection.LegacyDictStorageCollection): """Input storage based collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text)) diff --git a/homeassistant/components/lovelace/dashboard.py b/homeassistant/components/lovelace/dashboard.py index 054aaf9b24c..371d2ac8669 100644 --- a/homeassistant/components/lovelace/dashboard.py +++ b/homeassistant/components/lovelace/dashboard.py @@ -217,7 +217,7 @@ def _config_info(mode, config): } -class DashboardsCollection(collection.DictStorageCollection): +class DashboardsCollection(collection.LegacyDictStorageCollection): """Collection of dashboards.""" CREATE_SCHEMA = vol.Schema(STORAGE_DASHBOARD_CREATE_FIELDS) @@ -229,9 +229,11 @@ class DashboardsCollection(collection.DictStorageCollection): storage.Store(hass, DASHBOARDS_STORAGE_VERSION, DASHBOARDS_STORAGE_KEY), ) - async def _async_load_data(self) -> collection.SerializedStorageCollection | None: + async def _async_load_data( # type: ignore[override] + self, + ) -> collection.LegacySerializedStorageCollection | None: """Load the data.""" - if (data := await self.store.async_load()) is None: + if (data := await super()._async_load_data()) is None: return data updated = False @@ -242,7 +244,7 @@ class DashboardsCollection(collection.DictStorageCollection): item[CONF_URL_PATH] = f"lovelace-{item[CONF_URL_PATH]}" if updated: - await self.store.async_save(data) + await self.store.async_save(data) # type: ignore[arg-type] return data diff --git a/homeassistant/components/lovelace/resources.py b/homeassistant/components/lovelace/resources.py index b6d0c939fec..7bd0cd0c584 100644 --- a/homeassistant/components/lovelace/resources.py +++ b/homeassistant/components/lovelace/resources.py @@ -37,15 +37,15 @@ class ResourceYAMLCollection: async def async_get_info(self): """Return the resources info for YAML mode.""" - return {"resources": len(self.async_items() or [])} + return {"resources": len(self.async_values() or [])} @callback - def async_items(self) -> list[dict]: + def async_values(self) -> list[dict]: """Return list of items in collection.""" return self.data -class ResourceStorageCollection(collection.DictStorageCollection): +class ResourceStorageCollection(collection.LegacyDictStorageCollection): """Collection to store resources.""" loaded = False @@ -67,9 +67,11 @@ class ResourceStorageCollection(collection.DictStorageCollection): return {"resources": len(self.async_items() or [])} - async def _async_load_data(self) -> collection.SerializedStorageCollection | None: + async def _async_load_data( # type: ignore[override] + self, + ) -> collection.LegacySerializedStorageCollection | None: """Load the data.""" - if (store_data := await self.store.async_load()) is not None: + if (store_data := await super()._async_load_data()) is not None: return store_data # Import it from config. @@ -95,9 +97,9 @@ class ResourceStorageCollection(collection.DictStorageCollection): for item in resources: item[CONF_ID] = uuid.uuid4().hex - data: collection.SerializedStorageCollection = {"items": resources} + data: collection.LegacySerializedStorageCollection = {"items": resources} - await self.store.async_save(data) + await self.store.async_save(data) # type: ignore[arg-type] await self.ll_config.async_save(conf) return data diff --git a/homeassistant/components/lovelace/websocket.py b/homeassistant/components/lovelace/websocket.py index 423ba3117ea..6eddbc4a6f9 100644 --- a/homeassistant/components/lovelace/websocket.py +++ b/homeassistant/components/lovelace/websocket.py @@ -64,7 +64,7 @@ async def websocket_lovelace_resources( await resources.async_load() resources.loaded = True - connection.send_result(msg["id"], resources.async_items()) + connection.send_result(msg["id"], resources.async_values()) @websocket_api.websocket_command( diff --git a/homeassistant/components/person/__init__.py b/homeassistant/components/person/__init__.py index c1373ce1df9..ad78c982b26 100644 --- a/homeassistant/components/person/__init__.py +++ b/homeassistant/components/person/__init__.py @@ -109,7 +109,7 @@ async def async_add_user_device_tracker( """Add a device tracker to a person linked to a user.""" coll: PersonStorageCollection = hass.data[DOMAIN][1] - for person in coll.async_items(): + for person in coll.async_values(): if person.get(ATTR_USER_ID) != user_id: continue @@ -188,7 +188,7 @@ class PersonStore(Store): return {"items": old_data["persons"]} -class PersonStorageCollection(collection.DictStorageCollection): +class PersonStorageCollection(collection.LegacyDictStorageCollection): """Person collection stored in storage.""" CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) @@ -204,7 +204,9 @@ class PersonStorageCollection(collection.DictStorageCollection): super().__init__(store, id_manager) self.yaml_collection = yaml_collection - async def _async_load_data(self) -> collection.SerializedStorageCollection | None: + async def _async_load_data( # type: ignore[override] + self, + ) -> collection.LegacySerializedStorageCollection | None: """Load the data. A past bug caused onboarding to create invalid person objects. @@ -286,7 +288,10 @@ class PersonStorageCollection(collection.DictStorageCollection): if await self.hass.auth.async_get_user(user_id) is None: raise ValueError("User does not exist") - for persons in (self.data.values(), self.yaml_collection.async_items()): + for persons in ( + self.data.values(), + self.yaml_collection.async_values(), + ): if any(person for person in persons if person.get(CONF_USER_ID) == user_id): raise ValueError("User already taken") @@ -363,7 +368,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def _handle_user_removed(event: Event) -> None: """Handle a user being removed.""" user_id = event.data[ATTR_USER_ID] - for person in storage_collection.async_items(): + for person in storage_collection.async_values(): if person[CONF_USER_ID] == user_id: await storage_collection.async_update_item( person[CONF_ID], {CONF_USER_ID: None} @@ -556,7 +561,11 @@ def ws_list_person( """List persons.""" yaml, storage, _ = hass.data[DOMAIN] connection.send_result( - msg[ATTR_ID], {"storage": storage.async_items(), "config": yaml.async_items()} + msg[ATTR_ID], + { + "storage": storage.async_items(), + "config": yaml.async_values(), + }, ) diff --git a/homeassistant/components/schedule/__init__.py b/homeassistant/components/schedule/__init__.py index 2e5fcc27715..1b2e380b520 100644 --- a/homeassistant/components/schedule/__init__.py +++ b/homeassistant/components/schedule/__init__.py @@ -20,10 +20,10 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers.collection import ( CollectionEntity, - DictStorageCollection, DictStorageCollectionWebsocket, IDManager, - SerializedStorageCollection, + LegacyDictStorageCollection, + LegacySerializedStorageCollection, YamlCollection, sync_entity_lifecycle, ) @@ -209,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class ScheduleStorageCollection(DictStorageCollection): +class ScheduleStorageCollection(LegacyDictStorageCollection): """Schedules stored in storage.""" SCHEMA = vol.Schema(BASE_SCHEMA | STORAGE_SCHEDULE_SCHEMA) @@ -230,7 +230,7 @@ class ScheduleStorageCollection(DictStorageCollection): self.SCHEMA(update_data) return item | update_data - async def _async_load_data(self) -> SerializedStorageCollection | None: + async def _async_load_data(self) -> LegacySerializedStorageCollection | None: # type: ignore[override] """Load the data.""" if data := await super()._async_load_data(): data["items"] = [STORAGE_SCHEMA(item) for item in data["items"]] diff --git a/homeassistant/components/tag/__init__.py b/homeassistant/components/tag/__init__.py index cd0dd00afe5..ac93f611617 100644 --- a/homeassistant/components/tag/__init__.py +++ b/homeassistant/components/tag/__init__.py @@ -59,7 +59,7 @@ class TagIDManager(collection.IDManager): return suggestion -class TagStorageCollection(collection.DictStorageCollection): +class TagStorageCollection(collection.LegacyDictStorageCollection): """Tag collection stored in storage.""" CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) diff --git a/homeassistant/components/timer/__init__.py b/homeassistant/components/timer/__init__.py index 7cb2c10425e..c6b1cbeb6f4 100644 --- a/homeassistant/components/timer/__init__.py +++ b/homeassistant/components/timer/__init__.py @@ -162,7 +162,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -class TimerStorageCollection(collection.DictStorageCollection): +class TimerStorageCollection(collection.LegacyDictStorageCollection): """Timer storage based collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) diff --git a/homeassistant/components/zone/__init__.py b/homeassistant/components/zone/__init__.py index 2133c8550da..140fb8f37c4 100644 --- a/homeassistant/components/zone/__init__.py +++ b/homeassistant/components/zone/__init__.py @@ -163,7 +163,7 @@ def in_zone(zone: State, latitude: float, longitude: float, radius: float = 0) - return zone_dist - radius < cast(float, zone.attributes[ATTR_RADIUS]) -class ZoneStorageCollection(collection.DictStorageCollection): +class ZoneStorageCollection(collection.LegacyDictStorageCollection): """Zone collection stored in storage.""" CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index cfc1750f7e5..b19b6d7eaa2 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from itertools import groupby import logging from operator import attrgetter -from typing import Any, Generic, TypedDict, TypeVar +from typing import Any, Generic, TypedDict, TypeVar, cast import voluptuous as vol from voluptuous.humanize import humanize_error @@ -138,7 +138,12 @@ class ObservableCollection(ABC, Generic[_ItemT]): self.id_manager.add_collection(self.data) @callback - def async_items(self) -> list[_ItemT]: + def async_items(self) -> dict[str, _ItemT]: + """Return a shallow copy of the collection.""" + return dict(self.data) + + @callback + def async_values(self) -> list[_ItemT]: """Return list of items in collection.""" return list(self.data.values()) @@ -229,9 +234,21 @@ class YamlCollection(ObservableCollection[dict]): class SerializedStorageCollection(TypedDict): """Serialized storage collection.""" + items: list[dict[str, Any]] | dict[str, dict[str, Any]] + + +class LegacySerializedStorageCollection(TypedDict): + """Serialized storage collection.""" + items: list[dict[str, Any]] +class ModernSerializedStorageCollection(TypedDict): + """Serialized storage collection.""" + + items: dict[str, dict[str, Any]] + + class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): """Offer a CRUD interface on top of JSON storage.""" @@ -258,20 +275,29 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): async def _async_load_data(self) -> _StoreT | None: """Load the data.""" - return await self.store.async_load() + if (raw_storage := await self.store.async_load()) is None: + return raw_storage + + if isinstance(raw_storage["items"], list): + raw_storage["items"] = { + item.pop(CONF_ID): item for item in raw_storage["items"] + } + await self.store.async_save(raw_storage) + + return raw_storage async def async_load(self) -> None: - """Load the storage Manager.""" + """Load the collection.""" if not (raw_storage := await self._async_load_data()): return - for item in raw_storage["items"]: - self.data[item[CONF_ID]] = self._deserialize_item(item) + for item_id, item in raw_storage["items"].items(): + self.data[item_id] = self._deserialize_item(item) await self.notify_changes( [ - CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item) - for item in raw_storage["items"] + CollectionChangeSet(CHANGE_ADDED, item_id, item) + for item_id, item in raw_storage["items"].items() ] ) @@ -297,13 +323,10 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): """Create an item from its serialized representation.""" @abstractmethod - def _serialize_item(self, item_id: str, item: _ItemT) -> dict: - """Return the serialized representation of an item for storing. + def _serialize_item(self, item: _ItemT) -> dict: + """Return the serialized representation of an item for storing.""" - The serialized representation must include the item_id in the "id" key. - """ - - async def async_create_item(self, data: dict) -> _ItemT: + async def async_create_item(self, data: dict) -> tuple[str, _ItemT]: """Create a new item.""" validated_data = await self._process_create_data(data) item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data)) @@ -311,7 +334,7 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): self.data[item_id] = item self._async_schedule_save() await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)]) - return item + return item_id, item async def async_update_item(self, item_id: str, updates: dict) -> _ItemT: """Update item.""" @@ -353,10 +376,10 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): def _base_data_to_save(self) -> SerializedStorageCollection: """Return JSON-compatible data for storing to file.""" return { - "items": [ - self._serialize_item(item_id, item) + "items": { + item_id: self._serialize_item(item) for item_id, item in self.data.items() - ] + } } @abstractmethod @@ -370,13 +393,13 @@ class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection] def _create_item(self, item_id: str, data: dict) -> dict: """Create an item from its validated, serialized representation.""" - return {CONF_ID: item_id} | data + return data def _deserialize_item(self, data: dict) -> dict: """Create an item from its validated, serialized representation.""" return data - def _serialize_item(self, item_id: str, item: dict) -> dict: + def _serialize_item(self, item: dict) -> dict: """Return the serialized representation of an item for storing.""" return item @@ -386,6 +409,43 @@ class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection] return self._base_data_to_save() +class LegacyDictStorageCollection(DictStorageCollection): + """A specialized StorageCollection where the items are untyped dicts.""" + + async def _async_load_data(self) -> LegacySerializedStorageCollection | None: # type: ignore[override] + """Load the data.""" + return cast( + LegacySerializedStorageCollection | None, + await self.store.async_load(), + ) + + async def async_load(self) -> None: + """Load the collection.""" + raw_storage = await self._async_load_data() + + if raw_storage is None: + raw_storage = {"items": []} + + for item in raw_storage["items"]: + self.data[item[CONF_ID]] = self._deserialize_item(item) + + await self.notify_changes( + [ + CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item) + for item in raw_storage["items"] + ] + ) + + def _create_item(self, item_id: str, data: dict) -> dict: + """Create an item from its validated, serialized representation.""" + return {CONF_ID: item_id} | data + + @callback + def _data_to_save(self) -> SerializedStorageCollection: + """Return JSON-compatible data for storing to file.""" + return {"items": [self._serialize_item(item) for item in self.data.values()]} + + class IDLessCollection(YamlCollection): """A collection without IDs.""" @@ -583,7 +643,13 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]): self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """List items.""" - connection.send_result(msg["id"], self.storage_collection.async_items()) + connection.send_result( + msg["id"], + [ + {CONF_ID: item_id} | item + for item_id, item in self.storage_collection.data.items() + ], + ) async def ws_create_item( self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict @@ -593,8 +659,8 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]): data = dict(msg) data.pop("id") data.pop("type") - item = await self.storage_collection.async_create_item(data) - connection.send_result(msg["id"], item) + item_id, item = await self.storage_collection.async_create_item(data) + connection.send_result(msg["id"], {CONF_ID: item_id} | item) except vol.Invalid as err: connection.send_error( msg["id"], @@ -617,7 +683,7 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]): try: item = await self.storage_collection.async_update_item(item_id, data) - connection.send_result(msg_id, item) + connection.send_result(msg_id, {CONF_ID: item_id} | item) except ItemNotFound: connection.send_error( msg["id"], diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index f84fb2fa1d1..f17ff6df8c5 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -15,7 +15,7 @@ from homeassistant.setup import async_setup_component from tests.common import flush_store -async def test_load_datasets(hass: HomeAssistant, init_components) -> None: +async def test_load_pipelines(hass: HomeAssistant, init_components) -> None: """Make sure that we can load/save data correctly.""" pipelines = [ @@ -46,7 +46,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: pipeline_data: PipelineData = hass.data[DOMAIN] store1 = pipeline_data.pipeline_store for pipeline in pipelines: - pipeline_ids.append((await store1.async_create_item(pipeline)).id) + pipeline_ids.append((await store1.async_create_item(pipeline))[1].id) assert len(store1.data) == 3 assert store1.async_get_preferred_item() == list(store1.data)[0] @@ -64,10 +64,10 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: assert store1.async_get_preferred_item() == store2.async_get_preferred_item() -async def test_loading_datasets_from_storage( +async def test_loading_pipelines_from_storage( hass: HomeAssistant, hass_storage: dict[str, Any] ) -> None: - """Test loading stored datasets on start.""" + """Test loading stored pipelines on start.""" hass_storage[STORAGE_KEY] = { "version": 1, "minor_version": 1, diff --git a/tests/components/person/test_init.py b/tests/components/person/test_init.py index d22de580c2a..741ee01da67 100644 --- a/tests/components/person/test_init.py +++ b/tests/components/person/test_init.py @@ -473,7 +473,7 @@ async def test_ws_create( ) resp = await client.receive_json() - persons = manager.async_items() + persons = manager.async_values() assert len(persons) == 2 assert resp["success"] @@ -504,7 +504,7 @@ async def test_ws_create_requires_admin( ) resp = await client.receive_json() - persons = manager.async_items() + persons = manager.async_values() assert len(persons) == 1 assert not resp["success"] @@ -517,7 +517,7 @@ async def test_ws_update( manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - persons = manager.async_items() + persons = manager.async_values() resp = await client.send_json( { @@ -544,7 +544,7 @@ async def test_ws_update( ) resp = await client.receive_json() - persons = manager.async_items() + persons = manager.async_values() assert len(persons) == 1 assert resp["success"] @@ -570,7 +570,7 @@ async def test_ws_update_require_admin( manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - original = dict(manager.async_items()[0]) + original = dict(manager.async_values()[0]) resp = await client.send_json( { @@ -585,7 +585,7 @@ async def test_ws_update_require_admin( resp = await client.receive_json() assert not resp["success"] - not_updated = dict(manager.async_items()[0]) + not_updated = dict(manager.async_values()[0]) assert original == not_updated @@ -596,14 +596,14 @@ async def test_ws_delete( manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - persons = manager.async_items() + persons = manager.async_values() resp = await client.send_json( {"id": 6, "type": "person/delete", "person_id": persons[0]["id"]} ) resp = await client.receive_json() - persons = manager.async_items() + persons = manager.async_values() assert len(persons) == 0 assert resp["success"] @@ -628,7 +628,7 @@ async def test_ws_delete_require_admin( { "id": 6, "type": "person/delete", - "person_id": manager.async_items()[0]["id"], + "person_id": manager.async_values()[0]["id"], "name": "Updated Name", "device_trackers": [DEVICE_TRACKER_2], "user_id": None, @@ -637,7 +637,7 @@ async def test_ws_delete_require_admin( resp = await client.receive_json() assert not resp["success"] - persons = manager.async_items() + persons = manager.async_values() assert len(persons) == 1 @@ -670,7 +670,7 @@ async def test_update_double_user_id( await storage_collection.async_create_item( {"name": "Hello", "user_id": hass_admin_user.id} ) - person = await storage_collection.async_create_item({"name": "Hello"}) + _, person = await storage_collection.async_create_item({"name": "Hello"}) with pytest.raises(ValueError): await storage_collection.async_update_item( @@ -680,7 +680,7 @@ async def test_update_double_user_id( async def test_update_invalid_user_id(hass: HomeAssistant, storage_collection) -> None: """Test updating to invalid user ID.""" - person = await storage_collection.async_create_item({"name": "Hello"}) + _, person = await storage_collection.async_create_item({"name": "Hello"}) with pytest.raises(ValueError): await storage_collection.async_update_item( @@ -694,7 +694,7 @@ async def test_update_person_when_user_removed( """Update person when user is removed.""" storage_collection = hass.data[DOMAIN][1] - person = await storage_collection.async_create_item( + _, person = await storage_collection.async_create_item( {"name": "Hello", "user_id": hass_read_only_user.id} ) @@ -712,7 +712,7 @@ async def test_removing_device_tracker(hass: HomeAssistant, storage_setup) -> No "device_tracker", "mobile_app", "bla", suggested_object_id="pixel" ) - person = await storage_collection.async_create_item( + _, person = await storage_collection.async_create_item( {"name": "Hello", "device_trackers": [entry.entity_id]} ) @@ -727,7 +727,7 @@ async def test_add_user_device_tracker( ) -> None: """Test adding a device tracker to a person tied to a user.""" storage_collection = hass.data[DOMAIN][1] - pers = await storage_collection.async_create_item( + _, pers = await storage_collection.async_create_item( { "name": "Hello", "user_id": hass_read_only_user.id, diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index 7969e02ab2f..2821b7c2fb3 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -117,9 +117,9 @@ def test_id_manager() -> None: async def test_observable_collection() -> None: """Test observerable collection.""" coll = collection.ObservableCollection(None) - assert coll.async_items() == [] + assert coll.async_items() == {} coll.data["bla"] = 1 - assert coll.async_items() == [1] + assert coll.async_items() == {"bla": 1} changes = track_changes(coll) await coll.notify_changes( @@ -193,6 +193,69 @@ async def test_yaml_collection_skipping_duplicate_ids() -> None: async def test_storage_collection(hass: HomeAssistant) -> None: """Test storage collection.""" store = storage.Store(hass, 1, "test-data") + await store.async_save( + { + "items": { + "mock-1": {"name": "Mock 1", "data": 1}, + "mock-2": {"name": "Mock 2", "data": 2}, + } + } + ) + id_manager = collection.IDManager() + coll = MockStorageCollection(store, id_manager) + changes = track_changes(coll) + + await coll.async_load() + assert id_manager.has_id("mock-1") + assert id_manager.has_id("mock-2") + assert len(changes) == 2 + assert changes[0] == ( + collection.CHANGE_ADDED, + "mock-1", + {"name": "Mock 1", "data": 1}, + ) + assert changes[1] == ( + collection.CHANGE_ADDED, + "mock-2", + {"name": "Mock 2", "data": 2}, + ) + + item_id, item = await coll.async_create_item({"name": "Mock 3"}) + assert item_id == "mock_3" + assert len(changes) == 3 + assert changes[2] == ( + collection.CHANGE_ADDED, + "mock_3", + {"name": "Mock 3"}, + ) + + updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"}) + assert id_manager.has_id("mock-2") + assert updated_item == {"name": "Mock 2 updated", "data": 2} + assert len(changes) == 4 + assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item) + + with pytest.raises(ValueError): + await coll.async_update_item("mock-2", {"id": "mock-2-updated"}) + + assert id_manager.has_id("mock-2") + assert not id_manager.has_id("mock-2-updated") + assert len(changes) == 4 + + await flush_store(store) + + assert await storage.Store(hass, 1, "test-data").async_load() == { + "items": { + "mock-1": {"name": "Mock 1", "data": 1}, + "mock-2": {"name": "Mock 2 updated", "data": 2}, + "mock_3": {"name": "Mock 3"}, + } + } + + +async def test_storage_collection_migration(hass: HomeAssistant) -> None: + """Test storage collection migration from old store format.""" + store = storage.Store(hass, 1, "test-data") await store.async_save( { "items": [ @@ -212,44 +275,19 @@ async def test_storage_collection(hass: HomeAssistant) -> None: assert changes[0] == ( collection.CHANGE_ADDED, "mock-1", - {"id": "mock-1", "name": "Mock 1", "data": 1}, + {"name": "Mock 1", "data": 1}, ) assert changes[1] == ( collection.CHANGE_ADDED, "mock-2", - {"id": "mock-2", "name": "Mock 2", "data": 2}, + {"name": "Mock 2", "data": 2}, ) - item = await coll.async_create_item({"name": "Mock 3"}) - assert item["id"] == "mock_3" - assert len(changes) == 3 - assert changes[2] == ( - collection.CHANGE_ADDED, - "mock_3", - {"id": "mock_3", "name": "Mock 3"}, - ) - - updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"}) - assert id_manager.has_id("mock-2") - assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2} - assert len(changes) == 4 - assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item) - - with pytest.raises(ValueError): - await coll.async_update_item("mock-2", {"id": "mock-2-updated"}) - - assert id_manager.has_id("mock-2") - assert not id_manager.has_id("mock-2-updated") - assert len(changes) == 4 - - await flush_store(store) - assert await storage.Store(hass, 1, "test-data").async_load() == { - "items": [ - {"id": "mock-1", "name": "Mock 1", "data": 1}, - {"id": "mock-2", "name": "Mock 2 updated", "data": 2}, - {"id": "mock_3", "name": "Mock 3"}, - ] + "items": { + "mock-1": {"name": "Mock 1", "data": 1}, + "mock-2": {"name": "Mock 2", "data": 2}, + } } @@ -477,6 +515,7 @@ async def test_storage_collection_websocket( "immutable_string": "no-changes", } assert len(changes) == 1 + response["result"].pop("id") assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"]) # List @@ -537,6 +576,7 @@ async def test_storage_collection_websocket( "immutable_string": "no-changes", } assert len(changes) == 2 + response["result"].pop("id") assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"]) # Delete invalid ID @@ -560,7 +600,6 @@ async def test_storage_collection_websocket( collection.CHANGE_REMOVED, "initial_name", { - "id": "initial_name", "immutable_string": "no-changes", "name": "Updated name", },