From 51c35ab9a8d89609fc5a0365ba424fbe080a6724 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 11 Feb 2020 09:40:50 -0800 Subject: [PATCH] Entity Registry to store and restore name/icon (#31714) * Entity Registry to store and restore name/icon * Update test_entity_registry.py * Add original name/icon to JSON result --- .../components/config/entity_registry.py | 12 ++++---- homeassistant/helpers/entity.py | 2 +- homeassistant/helpers/entity_platform.py | 2 ++ homeassistant/helpers/entity_registry.py | 28 +++++++++++++++++ .../components/config/test_entity_registry.py | 30 ++++++++++++++++++- tests/helpers/test_entity_registry.py | 17 +++++++++++ 6 files changed, 84 insertions(+), 7 deletions(-) diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 458a9dd3ecb..a7993017116 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -68,6 +68,7 @@ async def websocket_get_entity(hass, connection, msg): vol.Required("entity_id"): cv.entity_id, # If passed in, we update value. Passing None will remove old value. vol.Optional("name"): vol.Any(str, None), + vol.Optional("icon"): vol.Any(str, None), vol.Optional("new_entity_id"): str, # We only allow setting disabled_by user via API. vol.Optional("disabled_by"): vol.Any("user", None), @@ -88,11 +89,9 @@ async def websocket_update_entity(hass, connection, msg): changes = {} - if "name" in msg: - changes["name"] = msg["name"] - - if "disabled_by" in msg: - changes["disabled_by"] = msg["disabled_by"] + for key in ("name", "icon", "disabled_by"): + if key in msg: + changes[key] = msg[key] if "new_entity_id" in msg and msg["new_entity_id"] != msg["entity_id"]: changes["new_entity_id"] = msg["new_entity_id"] @@ -151,5 +150,8 @@ def _entry_dict(entry): "disabled_by": entry.disabled_by, "entity_id": entry.entity_id, "name": entry.name, + "icon": entry.icon, "platform": entry.platform, + "original_name": entry.original_name, + "original_icon": entry.original_icon, } diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 92072b22df2..4c3b9448f5a 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -337,7 +337,7 @@ class Entity(ABC): if name is not None: attr[ATTR_FRIENDLY_NAME] = name - icon = self.icon + icon = (entry and entry.icon) or self.icon if icon is not None: attr[ATTR_ICON] = icon diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index e71b28f1713..e1e046eaa6d 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -369,6 +369,8 @@ class EntityPlatform: supported_features=entity.supported_features, device_class=entity.device_class, unit_of_measurement=entity.unit_of_measurement, + original_name=entity.name, + original_icon=entity.icon, ) entity.registry_entry = entry diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 635f7feba13..05b687a8454 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -17,6 +17,8 @@ import attr from homeassistant.const import ( ATTR_DEVICE_CLASS, + ATTR_FRIENDLY_NAME, + ATTR_ICON, ATTR_SUPPORTED_FEATURES, ATTR_UNIT_OF_MEASUREMENT, EVENT_HOMEASSISTANT_START, @@ -60,6 +62,7 @@ class RegistryEntry: unique_id = attr.ib(type=str) platform = attr.ib(type=str) name = attr.ib(type=str, default=None) + icon = attr.ib(type=str, default=None) device_id: Optional[str] = attr.ib(default=None) config_entry_id: Optional[str] = attr.ib(default=None) disabled_by = attr.ib( @@ -79,6 +82,9 @@ class RegistryEntry: supported_features: int = attr.ib(default=0) device_class: Optional[str] = attr.ib(default=None) unit_of_measurement: Optional[str] = attr.ib(default=None) + # As set by integration + original_name: Optional[str] = attr.ib(default=None) + original_icon: Optional[str] = attr.ib(default=None) domain = attr.ib(type=str, init=False, repr=False) @domain.default @@ -167,6 +173,8 @@ class EntityRegistry: supported_features: Optional[int] = None, device_class: Optional[str] = None, unit_of_measurement: Optional[str] = None, + original_name: Optional[str] = None, + original_icon: Optional[str] = None, ) -> RegistryEntry: """Get entity. Create if it doesn't exist.""" config_entry_id = None @@ -184,6 +192,8 @@ class EntityRegistry: supported_features=supported_features or _UNDEF, device_class=device_class or _UNDEF, unit_of_measurement=unit_of_measurement or _UNDEF, + original_name=original_name or _UNDEF, + original_icon=original_icon or _UNDEF, # When we changed our slugify algorithm, we invalidated some # stored entity IDs with either a __ or ending in _. # Fix introduced in 0.86 (Jan 23, 2019). Next line can be @@ -215,6 +225,8 @@ class EntityRegistry: supported_features=supported_features or 0, device_class=device_class, unit_of_measurement=unit_of_measurement, + original_name=original_name, + original_icon=original_icon, ) self.entities[entity_id] = entity _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) @@ -254,6 +266,7 @@ class EntityRegistry: entity_id, *, name=_UNDEF, + icon=_UNDEF, new_entity_id=_UNDEF, new_unique_id=_UNDEF, disabled_by=_UNDEF, @@ -264,6 +277,7 @@ class EntityRegistry: self._async_update_entity( entity_id, name=name, + icon=icon, new_entity_id=new_entity_id, new_unique_id=new_unique_id, disabled_by=disabled_by, @@ -276,6 +290,7 @@ class EntityRegistry: entity_id, *, name=_UNDEF, + icon=_UNDEF, config_entry_id=_UNDEF, new_entity_id=_UNDEF, device_id=_UNDEF, @@ -285,6 +300,8 @@ class EntityRegistry: supported_features=_UNDEF, device_class=_UNDEF, unit_of_measurement=_UNDEF, + original_name=_UNDEF, + original_icon=_UNDEF, ): """Private facing update properties method.""" old = self.entities[entity_id] @@ -293,6 +310,7 @@ class EntityRegistry: for attr_name, value in ( ("name", name), + ("icon", icon), ("config_entry_id", config_entry_id), ("device_id", device_id), ("disabled_by", disabled_by), @@ -300,6 +318,8 @@ class EntityRegistry: ("supported_features", supported_features), ("device_class", device_class), ("unit_of_measurement", unit_of_measurement), + ("original_name", original_name), + ("original_icon", original_icon), ): if value is not _UNDEF and value != getattr(old, attr_name): changes[attr_name] = value @@ -523,6 +543,14 @@ def async_setup_entity_restore( if entry.unit_of_measurement is not None: attrs[ATTR_UNIT_OF_MEASUREMENT] = entry.unit_of_measurement + name = entry.name or entry.original_name + if name is not None: + attrs[ATTR_FRIENDLY_NAME] = name + + icon = entry.icon or entry.original_icon + if icon is not None: + attrs[ATTR_ICON] = icon + states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs) hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states) diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 133c88d9ceb..8fe7e8fdbe4 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -41,6 +41,9 @@ async def test_list_entities(hass, client): "disabled_by": None, "entity_id": "test_domain.name", "name": "Hello World", + "icon": None, + "original_name": None, + "original_icon": None, "platform": "test_platform", }, { @@ -49,6 +52,9 @@ async def test_list_entities(hass, client): "disabled_by": None, "entity_id": "test_domain.no_name", "name": None, + "icon": None, + "original_name": None, + "original_icon": None, "platform": "test_platform", }, ] @@ -85,6 +91,9 @@ async def test_get_entity(hass, client): "platform": "test_platform", "entity_id": "test_domain.name", "name": "Hello World", + "icon": None, + "original_name": None, + "original_icon": None, } await client.send_json( @@ -103,6 +112,9 @@ async def test_get_entity(hass, client): "platform": "test_platform", "entity_id": "test_domain.no_name", "name": None, + "icon": None, + "original_name": None, + "original_icon": None, } @@ -117,6 +129,7 @@ async def test_update_entity(hass, client): # Using component.async_add_entities is equal to platform "domain" platform="test_platform", name="before update", + icon="icon:before update", ) }, ) @@ -127,14 +140,16 @@ async def test_update_entity(hass, client): state = hass.states.get("test_domain.world") assert state is not None assert state.name == "before update" + assert state.attributes["icon"] == "icon:before update" - # UPDATE NAME + # UPDATE NAME & ICON await client.send_json( { "id": 6, "type": "config/entity_registry/update", "entity_id": "test_domain.world", "name": "after update", + "icon": "icon:after update", } ) @@ -147,10 +162,14 @@ async def test_update_entity(hass, client): "platform": "test_platform", "entity_id": "test_domain.world", "name": "after update", + "icon": "icon:after update", + "original_name": None, + "original_icon": None, } state = hass.states.get("test_domain.world") assert state.name == "after update" + assert state.attributes["icon"] == "icon:after update" # UPDATE DISABLED_BY TO USER await client.send_json( @@ -186,6 +205,9 @@ async def test_update_entity(hass, client): "platform": "test_platform", "entity_id": "test_domain.world", "name": "after update", + "icon": "icon:after update", + "original_name": None, + "original_icon": None, } @@ -229,6 +251,9 @@ async def test_update_entity_no_changes(hass, client): "platform": "test_platform", "entity_id": "test_domain.world", "name": "name of entity", + "icon": None, + "original_name": None, + "original_icon": None, } state = hass.states.get("test_domain.world") @@ -301,6 +326,9 @@ async def test_update_entity_id(hass, client): "platform": "test_platform", "entity_id": "test_domain.planet", "name": None, + "icon": None, + "original_name": None, + "original_icon": None, } assert hass.states.get("test_domain.world") is None diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index e532d99f333..6782007ebe7 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -72,6 +72,9 @@ def test_get_or_create_updates_data(registry): supported_features=5, device_class="mock-device-class", disabled_by=entity_registry.DISABLED_HASS, + unit_of_measurement="initial-unit_of_measurement", + original_name="initial-original_name", + original_icon="initial-original_icon", ) assert orig_entry.config_entry_id == orig_config_entry.entry_id @@ -80,6 +83,9 @@ def test_get_or_create_updates_data(registry): assert orig_entry.supported_features == 5 assert orig_entry.device_class == "mock-device-class" assert orig_entry.disabled_by == entity_registry.DISABLED_HASS + assert orig_entry.unit_of_measurement == "initial-unit_of_measurement" + assert orig_entry.original_name == "initial-original_name" + assert orig_entry.original_icon == "initial-original_icon" new_config_entry = MockConfigEntry(domain="light") @@ -93,6 +99,9 @@ def test_get_or_create_updates_data(registry): supported_features=10, device_class="new-mock-device-class", disabled_by=entity_registry.DISABLED_USER, + unit_of_measurement="updated-unit_of_measurement", + original_name="updated-original_name", + original_icon="updated-original_icon", ) assert new_entry.config_entry_id == new_config_entry.entry_id @@ -100,6 +109,9 @@ def test_get_or_create_updates_data(registry): assert new_entry.capabilities == {"new-max": 100} assert new_entry.supported_features == 10 assert new_entry.device_class == "new-mock-device-class" + assert new_entry.unit_of_measurement == "updated-unit_of_measurement" + assert new_entry.original_name == "updated-original_name" + assert new_entry.original_icon == "updated-original_icon" # Should not be updated assert new_entry.disabled_by == entity_registry.DISABLED_HASS @@ -434,6 +446,7 @@ async def test_update_entity(registry): for attr_name, new_value in ( ("name", "new name"), + ("icon", "new icon"), ("disabled_by", entity_registry.DISABLED_USER), ): changes = {attr_name: new_value} @@ -503,6 +516,8 @@ async def test_restore_states(hass): capabilities={"max": 100}, supported_features=5, device_class="mock-device-class", + original_name="Mock Original Name", + original_icon="hass:original-icon", ) hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {}) @@ -524,6 +539,8 @@ async def test_restore_states(hass): "supported_features": 5, "device_class": "mock-device-class", "restored": True, + "friendly_name": "Mock Original Name", + "icon": "hass:original-icon", } registry.async_remove("light.disabled")