Don't add id to storage collection data

This commit is contained in:
Erik 2023-04-07 21:49:29 +02:00
parent 2819ad9a16
commit ea79580f58
22 changed files with 238 additions and 119 deletions

View file

@ -20,7 +20,6 @@ from homeassistant.const import (
CONF_CLIENT_ID,
CONF_CLIENT_SECRET,
CONF_DOMAIN,
CONF_ID,
CONF_NAME,
)
from homeassistant.core import HomeAssistant, callback
@ -125,10 +124,10 @@ class ApplicationCredentialsStorageCollection(collection.DictStorageCollection):
def async_client_credentials(self, domain: str) -> dict[str, ClientCredential]:
"""Return ClientCredentials in storage for the specified domain."""
credentials = {}
for item in self.async_items():
for item_id, item in self.data.items():
if item[CONF_DOMAIN] != domain:
continue
auth_domain = item.get(CONF_AUTH_DOMAIN, item[CONF_ID])
auth_domain = item.get(CONF_AUTH_DOMAIN, item_id)
credentials[auth_domain] = ClientCredential(
client_id=item[CONF_CLIENT_ID],
client_secret=item[CONF_CLIENT_SECRET],
@ -244,8 +243,7 @@ async def _async_config_entry_app_credentials(
return None
storage_collection = hass.data[DOMAIN][DATA_STORAGE]
for item in storage_collection.async_items():
item_id = item[CONF_ID]
for item_id, item in storage_collection.async_items().items():
if (
item[CONF_DOMAIN] == config_entry.domain
and item.get(CONF_AUTH_DOMAIN, item_id) == auth_domain

View file

@ -63,7 +63,7 @@ async def async_get_pipeline(
# Construct a pipeline for the required/configured language
language = language or hass.config.language
return await pipeline_data.pipeline_store.async_create_item(
_, pipeline = await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
@ -72,6 +72,7 @@ async def async_get_pipeline(
"tts_engine": None, # first engine
}
)
return pipeline
class PipelineEventType(StrEnum):
@ -610,7 +611,7 @@ class PipelineStorageCollection(
"""Create an item from its serialized representation."""
return Pipeline(**data)
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
def _serialize_item(self, item: Pipeline) -> dict:
"""Return the serialized representation of an item for storing."""
return item.to_json()

View file

@ -139,7 +139,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class CounterStorageCollection(collection.DictStorageCollection):
class CounterStorageCollection(collection.LegacyDictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

View file

@ -57,7 +57,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class ImageStorageCollection(collection.DictStorageCollection):
class ImageStorageCollection(collection.LegacyDictStorageCollection):
"""Image collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -156,7 +156,7 @@ class ImageUploadView(HomeAssistantView):
request._client_max_size = MAX_SIZE # pylint: disable=protected-access
data = await request.post()
item = await request.app["hass"].data[DOMAIN].async_create_item(data)
_, item = await request.app["hass"].data[DOMAIN].async_create_item(data)
return self.json(item)

View file

@ -65,7 +65,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
class InputBooleanStorageCollection(collection.DictStorageCollection):
class InputBooleanStorageCollection(collection.LegacyDictStorageCollection):
"""Input boolean collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

View file

@ -56,7 +56,7 @@ STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
class InputButtonStorageCollection(collection.DictStorageCollection):
class InputButtonStorageCollection(collection.LegacyDictStorageCollection):
"""Input button collection stored in storage."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

View file

@ -203,7 +203,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class DateTimeStorageCollection(collection.DictStorageCollection):
class DateTimeStorageCollection(collection.LegacyDictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, has_date_or_time))

View file

@ -170,7 +170,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class NumberStorageCollection(collection.DictStorageCollection):
class NumberStorageCollection(collection.LegacyDictStorageCollection):
"""Input storage based collection."""
SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number))
@ -184,7 +184,9 @@ class NumberStorageCollection(collection.DictStorageCollection):
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
async def _async_load_data( # type: ignore[override]
self,
) -> collection.LegacySerializedStorageCollection | None:
"""Load the data.
A past bug caused frontend to add initial value to all input numbers.

View file

@ -231,7 +231,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class InputSelectStorageCollection(collection.DictStorageCollection):
class InputSelectStorageCollection(collection.LegacyDictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_select))

View file

@ -164,7 +164,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class InputTextStorageCollection(collection.DictStorageCollection):
class InputTextStorageCollection(collection.LegacyDictStorageCollection):
"""Input storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))

View file

@ -217,7 +217,7 @@ def _config_info(mode, config):
}
class DashboardsCollection(collection.DictStorageCollection):
class DashboardsCollection(collection.LegacyDictStorageCollection):
"""Collection of dashboards."""
CREATE_SCHEMA = vol.Schema(STORAGE_DASHBOARD_CREATE_FIELDS)
@ -229,9 +229,11 @@ class DashboardsCollection(collection.DictStorageCollection):
storage.Store(hass, DASHBOARDS_STORAGE_VERSION, DASHBOARDS_STORAGE_KEY),
)
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
async def _async_load_data( # type: ignore[override]
self,
) -> collection.LegacySerializedStorageCollection | None:
"""Load the data."""
if (data := await self.store.async_load()) is None:
if (data := await super()._async_load_data()) is None:
return data
updated = False
@ -242,7 +244,7 @@ class DashboardsCollection(collection.DictStorageCollection):
item[CONF_URL_PATH] = f"lovelace-{item[CONF_URL_PATH]}"
if updated:
await self.store.async_save(data)
await self.store.async_save(data) # type: ignore[arg-type]
return data

View file

@ -37,15 +37,15 @@ class ResourceYAMLCollection:
async def async_get_info(self):
"""Return the resources info for YAML mode."""
return {"resources": len(self.async_items() or [])}
return {"resources": len(self.async_values() or [])}
@callback
def async_items(self) -> list[dict]:
def async_values(self) -> list[dict]:
"""Return list of items in collection."""
return self.data
class ResourceStorageCollection(collection.DictStorageCollection):
class ResourceStorageCollection(collection.LegacyDictStorageCollection):
"""Collection to store resources."""
loaded = False
@ -67,9 +67,11 @@ class ResourceStorageCollection(collection.DictStorageCollection):
return {"resources": len(self.async_items() or [])}
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
async def _async_load_data( # type: ignore[override]
self,
) -> collection.LegacySerializedStorageCollection | None:
"""Load the data."""
if (store_data := await self.store.async_load()) is not None:
if (store_data := await super()._async_load_data()) is not None:
return store_data
# Import it from config.
@ -95,9 +97,9 @@ class ResourceStorageCollection(collection.DictStorageCollection):
for item in resources:
item[CONF_ID] = uuid.uuid4().hex
data: collection.SerializedStorageCollection = {"items": resources}
data: collection.LegacySerializedStorageCollection = {"items": resources}
await self.store.async_save(data)
await self.store.async_save(data) # type: ignore[arg-type]
await self.ll_config.async_save(conf)
return data

View file

@ -64,7 +64,7 @@ async def websocket_lovelace_resources(
await resources.async_load()
resources.loaded = True
connection.send_result(msg["id"], resources.async_items())
connection.send_result(msg["id"], resources.async_values())
@websocket_api.websocket_command(

View file

@ -109,7 +109,7 @@ async def async_add_user_device_tracker(
"""Add a device tracker to a person linked to a user."""
coll: PersonStorageCollection = hass.data[DOMAIN][1]
for person in coll.async_items():
for person in coll.async_values():
if person.get(ATTR_USER_ID) != user_id:
continue
@ -188,7 +188,7 @@ class PersonStore(Store):
return {"items": old_data["persons"]}
class PersonStorageCollection(collection.DictStorageCollection):
class PersonStorageCollection(collection.LegacyDictStorageCollection):
"""Person collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
@ -204,7 +204,9 @@ class PersonStorageCollection(collection.DictStorageCollection):
super().__init__(store, id_manager)
self.yaml_collection = yaml_collection
async def _async_load_data(self) -> collection.SerializedStorageCollection | None:
async def _async_load_data( # type: ignore[override]
self,
) -> collection.LegacySerializedStorageCollection | None:
"""Load the data.
A past bug caused onboarding to create invalid person objects.
@ -286,7 +288,10 @@ class PersonStorageCollection(collection.DictStorageCollection):
if await self.hass.auth.async_get_user(user_id) is None:
raise ValueError("User does not exist")
for persons in (self.data.values(), self.yaml_collection.async_items()):
for persons in (
self.data.values(),
self.yaml_collection.async_values(),
):
if any(person for person in persons if person.get(CONF_USER_ID) == user_id):
raise ValueError("User already taken")
@ -363,7 +368,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def _handle_user_removed(event: Event) -> None:
"""Handle a user being removed."""
user_id = event.data[ATTR_USER_ID]
for person in storage_collection.async_items():
for person in storage_collection.async_values():
if person[CONF_USER_ID] == user_id:
await storage_collection.async_update_item(
person[CONF_ID], {CONF_USER_ID: None}
@ -556,7 +561,11 @@ def ws_list_person(
"""List persons."""
yaml, storage, _ = hass.data[DOMAIN]
connection.send_result(
msg[ATTR_ID], {"storage": storage.async_items(), "config": yaml.async_items()}
msg[ATTR_ID],
{
"storage": storage.async_items(),
"config": yaml.async_values(),
},
)

View file

@ -20,10 +20,10 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.collection import (
CollectionEntity,
DictStorageCollection,
DictStorageCollectionWebsocket,
IDManager,
SerializedStorageCollection,
LegacyDictStorageCollection,
LegacySerializedStorageCollection,
YamlCollection,
sync_entity_lifecycle,
)
@ -209,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class ScheduleStorageCollection(DictStorageCollection):
class ScheduleStorageCollection(LegacyDictStorageCollection):
"""Schedules stored in storage."""
SCHEMA = vol.Schema(BASE_SCHEMA | STORAGE_SCHEDULE_SCHEMA)
@ -230,7 +230,7 @@ class ScheduleStorageCollection(DictStorageCollection):
self.SCHEMA(update_data)
return item | update_data
async def _async_load_data(self) -> SerializedStorageCollection | None:
async def _async_load_data(self) -> LegacySerializedStorageCollection | None: # type: ignore[override]
"""Load the data."""
if data := await super()._async_load_data():
data["items"] = [STORAGE_SCHEMA(item) for item in data["items"]]

View file

@ -59,7 +59,7 @@ class TagIDManager(collection.IDManager):
return suggestion
class TagStorageCollection(collection.DictStorageCollection):
class TagStorageCollection(collection.LegacyDictStorageCollection):
"""Tag collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)

View file

@ -162,7 +162,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
class TimerStorageCollection(collection.DictStorageCollection):
class TimerStorageCollection(collection.LegacyDictStorageCollection):
"""Timer storage based collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

View file

@ -163,7 +163,7 @@ def in_zone(zone: State, latitude: float, longitude: float, radius: float = 0) -
return zone_dist - radius < cast(float, zone.attributes[ATTR_RADIUS])
class ZoneStorageCollection(collection.DictStorageCollection):
class ZoneStorageCollection(collection.LegacyDictStorageCollection):
"""Zone collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass
from itertools import groupby
import logging
from operator import attrgetter
from typing import Any, Generic, TypedDict, TypeVar
from typing import Any, Generic, TypedDict, TypeVar, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -138,7 +138,12 @@ class ObservableCollection(ABC, Generic[_ItemT]):
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> list[_ItemT]:
def async_items(self) -> dict[str, _ItemT]:
"""Return a shallow copy of the collection."""
return dict(self.data)
@callback
def async_values(self) -> list[_ItemT]:
"""Return list of items in collection."""
return list(self.data.values())
@ -229,9 +234,21 @@ class YamlCollection(ObservableCollection[dict]):
class SerializedStorageCollection(TypedDict):
"""Serialized storage collection."""
items: list[dict[str, Any]] | dict[str, dict[str, Any]]
class LegacySerializedStorageCollection(TypedDict):
"""Serialized storage collection."""
items: list[dict[str, Any]]
class ModernSerializedStorageCollection(TypedDict):
"""Serialized storage collection."""
items: dict[str, dict[str, Any]]
class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
"""Offer a CRUD interface on top of JSON storage."""
@ -258,20 +275,29 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
async def _async_load_data(self) -> _StoreT | None:
"""Load the data."""
return await self.store.async_load()
if (raw_storage := await self.store.async_load()) is None:
return raw_storage
if isinstance(raw_storage["items"], list):
raw_storage["items"] = {
item.pop(CONF_ID): item for item in raw_storage["items"]
}
await self.store.async_save(raw_storage)
return raw_storage
async def async_load(self) -> None:
"""Load the storage Manager."""
"""Load the collection."""
if not (raw_storage := await self._async_load_data()):
return
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = self._deserialize_item(item)
for item_id, item in raw_storage["items"].items():
self.data[item_id] = self._deserialize_item(item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"]
CollectionChangeSet(CHANGE_ADDED, item_id, item)
for item_id, item in raw_storage["items"].items()
]
)
@ -297,13 +323,10 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
"""Create an item from its serialized representation."""
@abstractmethod
def _serialize_item(self, item_id: str, item: _ItemT) -> dict:
"""Return the serialized representation of an item for storing.
def _serialize_item(self, item: _ItemT) -> dict:
"""Return the serialized representation of an item for storing."""
The serialized representation must include the item_id in the "id" key.
"""
async def async_create_item(self, data: dict) -> _ItemT:
async def async_create_item(self, data: dict) -> tuple[str, _ItemT]:
"""Create a new item."""
validated_data = await self._process_create_data(data)
item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data))
@ -311,7 +334,7 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
self.data[item_id] = item
self._async_schedule_save()
await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)])
return item
return item_id, item
async def async_update_item(self, item_id: str, updates: dict) -> _ItemT:
"""Update item."""
@ -353,10 +376,10 @@ class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
def _base_data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible data for storing to file."""
return {
"items": [
self._serialize_item(item_id, item)
"items": {
item_id: self._serialize_item(item)
for item_id, item in self.data.items()
]
}
}
@abstractmethod
@ -370,13 +393,13 @@ class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection]
def _create_item(self, item_id: str, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return {CONF_ID: item_id} | data
return data
def _deserialize_item(self, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return data
def _serialize_item(self, item_id: str, item: dict) -> dict:
def _serialize_item(self, item: dict) -> dict:
"""Return the serialized representation of an item for storing."""
return item
@ -386,6 +409,43 @@ class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection]
return self._base_data_to_save()
class LegacyDictStorageCollection(DictStorageCollection):
"""A specialized StorageCollection where the items are untyped dicts."""
async def _async_load_data(self) -> LegacySerializedStorageCollection | None: # type: ignore[override]
"""Load the data."""
return cast(
LegacySerializedStorageCollection | None,
await self.store.async_load(),
)
async def async_load(self) -> None:
"""Load the collection."""
raw_storage = await self._async_load_data()
if raw_storage is None:
raw_storage = {"items": []}
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = self._deserialize_item(item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"]
]
)
def _create_item(self, item_id: str, data: dict) -> dict:
"""Create an item from its validated, serialized representation."""
return {CONF_ID: item_id} | data
@callback
def _data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible data for storing to file."""
return {"items": [self._serialize_item(item) for item in self.data.values()]}
class IDLessCollection(YamlCollection):
"""A collection without IDs."""
@ -583,7 +643,13 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]):
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List items."""
connection.send_result(msg["id"], self.storage_collection.async_items())
connection.send_result(
msg["id"],
[
{CONF_ID: item_id} | item
for item_id, item in self.storage_collection.data.items()
],
)
async def ws_create_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
@ -593,8 +659,8 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]):
data = dict(msg)
data.pop("id")
data.pop("type")
item = await self.storage_collection.async_create_item(data)
connection.send_result(msg["id"], item)
item_id, item = await self.storage_collection.async_create_item(data)
connection.send_result(msg["id"], {CONF_ID: item_id} | item)
except vol.Invalid as err:
connection.send_error(
msg["id"],
@ -617,7 +683,7 @@ class StorageCollectionWebsocket(Generic[_StorageCollectionT]):
try:
item = await self.storage_collection.async_update_item(item_id, data)
connection.send_result(msg_id, item)
connection.send_result(msg_id, {CONF_ID: item_id} | item)
except ItemNotFound:
connection.send_error(
msg["id"],

View file

@ -15,7 +15,7 @@ from homeassistant.setup import async_setup_component
from tests.common import flush_store
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
async def test_load_pipelines(hass: HomeAssistant, init_components) -> None:
"""Make sure that we can load/save data correctly."""
pipelines = [
@ -46,7 +46,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
pipeline_data: PipelineData = hass.data[DOMAIN]
store1 = pipeline_data.pipeline_store
for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
pipeline_ids.append((await store1.async_create_item(pipeline))[1].id)
assert len(store1.data) == 3
assert store1.async_get_preferred_item() == list(store1.data)[0]
@ -64,10 +64,10 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
async def test_loading_datasets_from_storage(
async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test loading stored datasets on start."""
"""Test loading stored pipelines on start."""
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,

View file

@ -473,7 +473,7 @@ async def test_ws_create(
)
resp = await client.receive_json()
persons = manager.async_items()
persons = manager.async_values()
assert len(persons) == 2
assert resp["success"]
@ -504,7 +504,7 @@ async def test_ws_create_requires_admin(
)
resp = await client.receive_json()
persons = manager.async_items()
persons = manager.async_values()
assert len(persons) == 1
assert not resp["success"]
@ -517,7 +517,7 @@ async def test_ws_update(
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
persons = manager.async_items()
persons = manager.async_values()
resp = await client.send_json(
{
@ -544,7 +544,7 @@ async def test_ws_update(
)
resp = await client.receive_json()
persons = manager.async_items()
persons = manager.async_values()
assert len(persons) == 1
assert resp["success"]
@ -570,7 +570,7 @@ async def test_ws_update_require_admin(
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
original = dict(manager.async_items()[0])
original = dict(manager.async_values()[0])
resp = await client.send_json(
{
@ -585,7 +585,7 @@ async def test_ws_update_require_admin(
resp = await client.receive_json()
assert not resp["success"]
not_updated = dict(manager.async_items()[0])
not_updated = dict(manager.async_values()[0])
assert original == not_updated
@ -596,14 +596,14 @@ async def test_ws_delete(
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
persons = manager.async_items()
persons = manager.async_values()
resp = await client.send_json(
{"id": 6, "type": "person/delete", "person_id": persons[0]["id"]}
)
resp = await client.receive_json()
persons = manager.async_items()
persons = manager.async_values()
assert len(persons) == 0
assert resp["success"]
@ -628,7 +628,7 @@ async def test_ws_delete_require_admin(
{
"id": 6,
"type": "person/delete",
"person_id": manager.async_items()[0]["id"],
"person_id": manager.async_values()[0]["id"],
"name": "Updated Name",
"device_trackers": [DEVICE_TRACKER_2],
"user_id": None,
@ -637,7 +637,7 @@ async def test_ws_delete_require_admin(
resp = await client.receive_json()
assert not resp["success"]
persons = manager.async_items()
persons = manager.async_values()
assert len(persons) == 1
@ -670,7 +670,7 @@ async def test_update_double_user_id(
await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_admin_user.id}
)
person = await storage_collection.async_create_item({"name": "Hello"})
_, person = await storage_collection.async_create_item({"name": "Hello"})
with pytest.raises(ValueError):
await storage_collection.async_update_item(
@ -680,7 +680,7 @@ async def test_update_double_user_id(
async def test_update_invalid_user_id(hass: HomeAssistant, storage_collection) -> None:
"""Test updating to invalid user ID."""
person = await storage_collection.async_create_item({"name": "Hello"})
_, person = await storage_collection.async_create_item({"name": "Hello"})
with pytest.raises(ValueError):
await storage_collection.async_update_item(
@ -694,7 +694,7 @@ async def test_update_person_when_user_removed(
"""Update person when user is removed."""
storage_collection = hass.data[DOMAIN][1]
person = await storage_collection.async_create_item(
_, person = await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_read_only_user.id}
)
@ -712,7 +712,7 @@ async def test_removing_device_tracker(hass: HomeAssistant, storage_setup) -> No
"device_tracker", "mobile_app", "bla", suggested_object_id="pixel"
)
person = await storage_collection.async_create_item(
_, person = await storage_collection.async_create_item(
{"name": "Hello", "device_trackers": [entry.entity_id]}
)
@ -727,7 +727,7 @@ async def test_add_user_device_tracker(
) -> None:
"""Test adding a device tracker to a person tied to a user."""
storage_collection = hass.data[DOMAIN][1]
pers = await storage_collection.async_create_item(
_, pers = await storage_collection.async_create_item(
{
"name": "Hello",
"user_id": hass_read_only_user.id,

View file

@ -117,9 +117,9 @@ def test_id_manager() -> None:
async def test_observable_collection() -> None:
"""Test observerable collection."""
coll = collection.ObservableCollection(None)
assert coll.async_items() == []
assert coll.async_items() == {}
coll.data["bla"] = 1
assert coll.async_items() == [1]
assert coll.async_items() == {"bla": 1}
changes = track_changes(coll)
await coll.notify_changes(
@ -193,6 +193,69 @@ async def test_yaml_collection_skipping_duplicate_ids() -> None:
async def test_storage_collection(hass: HomeAssistant) -> None:
"""Test storage collection."""
store = storage.Store(hass, 1, "test-data")
await store.async_save(
{
"items": {
"mock-1": {"name": "Mock 1", "data": 1},
"mock-2": {"name": "Mock 2", "data": 2},
}
}
)
id_manager = collection.IDManager()
coll = MockStorageCollection(store, id_manager)
changes = track_changes(coll)
await coll.async_load()
assert id_manager.has_id("mock-1")
assert id_manager.has_id("mock-2")
assert len(changes) == 2
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"name": "Mock 1", "data": 1},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"name": "Mock 2", "data": 2},
)
item_id, item = await coll.async_create_item({"name": "Mock 3"})
assert item_id == "mock_3"
assert len(changes) == 3
assert changes[2] == (
collection.CHANGE_ADDED,
"mock_3",
{"name": "Mock 3"},
)
updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"})
assert id_manager.has_id("mock-2")
assert updated_item == {"name": "Mock 2 updated", "data": 2}
assert len(changes) == 4
assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item)
with pytest.raises(ValueError):
await coll.async_update_item("mock-2", {"id": "mock-2-updated"})
assert id_manager.has_id("mock-2")
assert not id_manager.has_id("mock-2-updated")
assert len(changes) == 4
await flush_store(store)
assert await storage.Store(hass, 1, "test-data").async_load() == {
"items": {
"mock-1": {"name": "Mock 1", "data": 1},
"mock-2": {"name": "Mock 2 updated", "data": 2},
"mock_3": {"name": "Mock 3"},
}
}
async def test_storage_collection_migration(hass: HomeAssistant) -> None:
"""Test storage collection migration from old store format."""
store = storage.Store(hass, 1, "test-data")
await store.async_save(
{
"items": [
@ -212,44 +275,19 @@ async def test_storage_collection(hass: HomeAssistant) -> None:
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"name": "Mock 1", "data": 1},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"id": "mock-2", "name": "Mock 2", "data": 2},
{"name": "Mock 2", "data": 2},
)
item = await coll.async_create_item({"name": "Mock 3"})
assert item["id"] == "mock_3"
assert len(changes) == 3
assert changes[2] == (
collection.CHANGE_ADDED,
"mock_3",
{"id": "mock_3", "name": "Mock 3"},
)
updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"})
assert id_manager.has_id("mock-2")
assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2}
assert len(changes) == 4
assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item)
with pytest.raises(ValueError):
await coll.async_update_item("mock-2", {"id": "mock-2-updated"})
assert id_manager.has_id("mock-2")
assert not id_manager.has_id("mock-2-updated")
assert len(changes) == 4
await flush_store(store)
assert await storage.Store(hass, 1, "test-data").async_load() == {
"items": [
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"id": "mock-2", "name": "Mock 2 updated", "data": 2},
{"id": "mock_3", "name": "Mock 3"},
]
"items": {
"mock-1": {"name": "Mock 1", "data": 1},
"mock-2": {"name": "Mock 2", "data": 2},
}
}
@ -477,6 +515,7 @@ async def test_storage_collection_websocket(
"immutable_string": "no-changes",
}
assert len(changes) == 1
response["result"].pop("id")
assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"])
# List
@ -537,6 +576,7 @@ async def test_storage_collection_websocket(
"immutable_string": "no-changes",
}
assert len(changes) == 2
response["result"].pop("id")
assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"])
# Delete invalid ID
@ -560,7 +600,6 @@ async def test_storage_collection_websocket(
collection.CHANGE_REMOVED,
"initial_name",
{
"id": "initial_name",
"immutable_string": "no-changes",
"name": "Updated name",
},