From 47da1c456b9d6ee964040b14a3a8b79f6be0a255 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 13 Sep 2022 20:55:06 +0200 Subject: [PATCH] Don't allow partial update of counter settings (#78371) --- homeassistant/components/counter/__init__.py | 23 ++++--------- tests/components/counter/test_init.py | 36 ++++++++++---------- 2 files changed, 24 insertions(+), 35 deletions(-) diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 61ec384ae50..113826c2291 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -47,7 +47,7 @@ SERVICE_CONFIGURE = "configure" STORAGE_KEY = DOMAIN STORAGE_VERSION = 1 -CREATE_FIELDS = { +STORAGE_FIELDS = { vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int, vol.Required(CONF_NAME): vol.All(cv.string, vol.Length(min=1)), @@ -57,16 +57,6 @@ CREATE_FIELDS = { vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int, } -UPDATE_FIELDS = { - vol.Optional(CONF_ICON): cv.icon, - vol.Optional(CONF_INITIAL): cv.positive_int, - vol.Optional(CONF_NAME): cv.string, - vol.Optional(CONF_MAXIMUM): vol.Any(None, vol.Coerce(int)), - vol.Optional(CONF_MINIMUM): vol.Any(None, vol.Coerce(int)), - vol.Optional(CONF_RESTORE): cv.boolean, - vol.Optional(CONF_STEP): cv.positive_int, -} - def _none_to_empty_dict(value): if value is None: @@ -128,7 +118,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await storage_collection.async_load() collection.StorageCollectionWebsocket( - storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") @@ -152,12 +142,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: class CounterStorageCollection(collection.StorageCollection): """Input storage based collection.""" - CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) - UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) + CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) async def _process_create_data(self, data: dict) -> dict: """Validate the config is valid.""" - return self.CREATE_SCHEMA(data) + return self.CREATE_UPDATE_SCHEMA(data) @callback def _get_suggested_id(self, info: dict) -> str: @@ -166,8 +155,8 @@ class CounterStorageCollection(collection.StorageCollection): async def _update_data(self, data: dict, update_data: dict) -> dict: """Return a new updated data object.""" - update_data = self.UPDATE_SCHEMA(update_data) - return {**data, **update_data} + update_data = self.CREATE_UPDATE_SCHEMA(update_data) + return {CONF_ID: data[CONF_ID]} | update_data class Counter(collection.CollectionEntity, RestoreEntity): diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py index 107dd97924d..90885be770d 100644 --- a/tests/components/counter/test_init.py +++ b/tests/components/counter/test_init.py @@ -591,17 +591,15 @@ async def test_ws_delete(hass, hass_ws_client, storage_setup): async def test_update_min_max(hass, hass_ws_client, storage_setup): """Test updating min/max updates the state.""" - items = [ - { - "id": "from_storage", - "initial": 15, - "name": "from storage", - "maximum": 100, - "minimum": 10, - "step": 3, - "restore": True, - } - ] + settings = { + "initial": 15, + "name": "from storage", + "maximum": 100, + "minimum": 10, + "step": 3, + "restore": True, + } + items = [{"id": "from_storage"} | settings] assert await storage_setup(items) input_id = "from_storage" @@ -618,16 +616,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup): client = await hass_ws_client(hass) + updated_settings = settings | {"minimum": 19} await client.send_json( { "id": 6, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", - "minimum": 19, + **updated_settings, } ) resp = await client.receive_json() assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings state = hass.states.get(input_entity_id) assert int(state.state) == 19 @@ -635,18 +635,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup): assert state.attributes[ATTR_MAXIMUM] == 100 assert state.attributes[ATTR_STEP] == 3 + updated_settings = settings | {"maximum": 5, "minimum": 2, "step": 5} await client.send_json( { "id": 7, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", - "maximum": 5, - "minimum": 2, - "step": 5, + **updated_settings, } ) resp = await client.receive_json() assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings state = hass.states.get(input_entity_id) assert int(state.state) == 5 @@ -654,18 +654,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup): assert state.attributes[ATTR_MAXIMUM] == 5 assert state.attributes[ATTR_STEP] == 5 + updated_settings = settings | {"maximum": None, "minimum": None, "step": 6} await client.send_json( { "id": 8, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", - "maximum": None, - "minimum": None, - "step": 6, + **updated_settings, } ) resp = await client.receive_json() assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings state = hass.states.get(input_entity_id) assert int(state.state) == 5