diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index c594bf1f99e..7c0867e3852 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -20,6 +20,7 @@ SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('entity_id'): cv.entity_id, # If passed in, we update value. Passing None will remove old value. vol.Optional('name'): vol.Any(str, None), + vol.Optional('new_entity_id'): str, }) @@ -74,13 +75,28 @@ def websocket_update_entity(hass, connection, msg): msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) return - entry = registry.async_update_entity( - msg['entity_id'], name=msg['name']) - connection.send_message_outside(websocket_api.result_message( - msg['id'], _entry_dict(entry) - )) + changes = {} - hass.async_add_job(update_entity()) + if 'name' in msg: + changes['name'] = msg['name'] + + if 'new_entity_id' in msg: + changes['new_entity_id'] = msg['new_entity_id'] + + try: + if changes: + entry = registry.async_update_entity( + msg['entity_id'], **changes) + except ValueError as err: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'invalid_info', str(err) + )) + else: + connection.send_message_outside(websocket_api.result_message( + msg['id'], _entry_dict(entry) + )) + + hass.async_create_task(update_entity()) @callback diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index c7e88b210b3..f466664fc61 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -82,6 +82,9 @@ class Entity: # Name in the entity registry registry_name = None + # Hold list for functions to call on remove. + _on_remove = None + @property def should_poll(self) -> bool: """Return True if entity has to be polled for state. @@ -324,8 +327,19 @@ class Entity: if self.parallel_updates: self.parallel_updates.release() + @callback + def async_on_remove(self, func): + """Add a function to call when entity removed.""" + if self._on_remove is None: + self._on_remove = [] + self._on_remove.append(func) + async def async_remove(self): """Remove entity from Home Assistant.""" + if self._on_remove is not None: + while self._on_remove: + self._on_remove.pop()() + if self.platform is not None: await self.platform.async_remove_entity(self.entity_id) else: @@ -335,7 +349,17 @@ class Entity: def async_registry_updated(self, old, new): """Called when the entity registry has been updated.""" self.registry_name = new.name - self.async_schedule_update_ha_state() + + if new.entity_id == self.entity_id: + self.async_schedule_update_ha_state() + return + + async def readd(): + """Remove and add entity again.""" + await self.async_remove() + await self.platform.async_add_entities([self]) + + self.hass.async_create_task(readd()) def __eq__(self, other): """Return the comparison.""" diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 0847c116954..dc1e376f471 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -283,7 +283,7 @@ class EntityPlatform: entity.entity_id = entry.entity_id entity.registry_name = entry.name - entry.add_update_listener(entity) + entity.async_on_remove(entry.add_update_listener(entity)) # We won't generate an entity ID if the platform has already set one # We will however make sure that platform cannot pick a registered ID diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index b222d78b577..2fa64ff8680 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -19,10 +19,10 @@ import weakref import attr -from ..core import callback, split_entity_id -from ..loader import bind_hass -from ..util import ensure_unique_string, slugify -from ..util.yaml import load_yaml, save_yaml +from homeassistant.core import callback, split_entity_id, valid_entity_id +from homeassistant.loader import bind_hass +from homeassistant.util import ensure_unique_string, slugify +from homeassistant.util.yaml import load_yaml, save_yaml PATH_REGISTRY = 'entity_registry.yaml' DATA_REGISTRY = 'entity_registry' @@ -63,8 +63,13 @@ class RegistryEntry: """Listen for when entry is updated. Listener: Callback function(old_entry, new_entry) + + Returns function to unlisten. """ - self.update_listeners.append(weakref.ref(listener)) + weak_listener = weakref.ref(listener) + self.update_listeners.append(weak_listener) + + return lambda: self.update_listeners.remove(weak_listener) class EntityRegistry: @@ -133,13 +138,18 @@ class EntityRegistry: return entity @callback - def async_update_entity(self, entity_id, *, name=_UNDEF): + def async_update_entity(self, entity_id, *, name=_UNDEF, + new_entity_id=_UNDEF): """Update properties of an entity.""" - return self._async_update_entity(entity_id, name=name) + return self._async_update_entity( + entity_id, + name=name, + new_entity_id=new_entity_id + ) @callback def _async_update_entity(self, entity_id, *, name=_UNDEF, - config_entry_id=_UNDEF): + config_entry_id=_UNDEF, new_entity_id=_UNDEF): """Private facing update properties method.""" old = self.entities[entity_id] @@ -152,6 +162,20 @@ class EntityRegistry: config_entry_id != old.config_entry_id): changes['config_entry_id'] = config_entry_id + if new_entity_id is not _UNDEF and new_entity_id != old.entity_id: + if self.async_is_registered(new_entity_id): + raise ValueError('Entity is already registered') + + if not valid_entity_id(new_entity_id): + raise ValueError('Invalid entity ID') + + if (split_entity_id(new_entity_id)[0] != + split_entity_id(entity_id)[0]): + raise ValueError('New entity ID should be same domain') + + self.entities.pop(entity_id) + entity_id = changes['entity_id'] = new_entity_id + if not changes: return old diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 1591b8da1d2..559f29372de 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -54,8 +54,8 @@ async def test_get_entity(hass, client): } -async def test_update_entity(hass, client): - """Test get entry.""" +async def test_update_entity_name(hass, client): + """Test updating entity name.""" mock_registry(hass, { 'test_domain.world': RegistryEntry( entity_id='test_domain.world', @@ -92,7 +92,7 @@ async def test_update_entity(hass, client): async def test_update_entity_no_changes(hass, client): - """Test get entry.""" + """Test update entity with no changes.""" mock_registry(hass, { 'test_domain.world': RegistryEntry( entity_id='test_domain.world', @@ -129,7 +129,7 @@ async def test_update_entity_no_changes(hass, client): async def test_get_nonexisting_entity(client): - """Test get entry.""" + """Test get entry with nonexisting entity.""" await client.send_json({ 'id': 6, 'type': 'config/entity_registry/get', @@ -141,7 +141,7 @@ async def test_get_nonexisting_entity(client): async def test_update_nonexisting_entity(client): - """Test get entry.""" + """Test update a nonexisting entity.""" await client.send_json({ 'id': 6, 'type': 'config/entity_registry/update', @@ -151,3 +151,37 @@ async def test_update_nonexisting_entity(client): msg = await client.receive_json() assert not msg['success'] + + +async def test_update_entity_id(hass, client): + """Test update entity id.""" + mock_registry(hass, { + 'test_domain.world': RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + ) + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + assert hass.states.get('test_domain.world') is not None + + await client.send_json({ + 'id': 6, + 'type': 'config/entity_registry/update', + 'entity_id': 'test_domain.world', + 'new_entity_id': 'test_domain.planet', + }) + + msg = await client.receive_json() + + assert msg['result'] == { + 'entity_id': 'test_domain.planet', + 'name': None + } + + assert hass.states.get('test_domain.world') is None + assert hass.states.get('test_domain.planet') is not None diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 4981ad23cc0..e24bec489f4 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -400,3 +400,15 @@ def test_async_remove_no_platform(hass): assert len(hass.states.async_entity_ids()) == 1 yield from ent.async_remove() assert len(hass.states.async_entity_ids()) == 0 + + +async def test_async_remove_runs_callbacks(hass): + """Test async_remove method when no platform set.""" + result = [] + + ent = entity.Entity() + ent.hass = hass + ent.entity_id = 'test.test' + ent.async_on_remove(lambda: result.append(1)) + await ent.async_remove() + assert len(result) == 1 diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 2d2f148189f..b52405aa8be 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -5,6 +5,8 @@ import unittest from unittest.mock import patch, Mock, MagicMock from datetime import timedelta +import pytest + from homeassistant.exceptions import PlatformNotReady import homeassistant.loader as loader from homeassistant.helpers.entity import generate_entity_id @@ -487,7 +489,7 @@ def test_registry_respect_entity_disabled(hass): assert hass.states.async_entity_ids() == [] -async def test_entity_registry_updates(hass): +async def test_entity_registry_updates_name(hass): """Test that updates on the entity registry update platform entities.""" registry = mock_registry(hass, { 'test_domain.world': entity_registry.RegistryEntry( @@ -602,3 +604,75 @@ def test_not_fails_with_adding_empty_entities_(hass): yield from component.async_add_entities([]) assert len(hass.states.async_entity_ids()) == 0 + + +async def test_entity_registry_updates_entity_id(hass): + """Test that updates on the entity registry update platform entities.""" + registry = mock_registry(hass, { + 'test_domain.world': entity_registry.RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + name='Some name' + ) + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + state = hass.states.get('test_domain.world') + assert state is not None + assert state.name == 'Some name' + + registry.async_update_entity('test_domain.world', + new_entity_id='test_domain.planet') + await hass.async_block_till_done() + await hass.async_block_till_done() + + assert hass.states.get('test_domain.world') is None + assert hass.states.get('test_domain.planet') is not None + + +async def test_entity_registry_updates_invalid_entity_id(hass): + """Test that we can't update to an invalid entity id.""" + registry = mock_registry(hass, { + 'test_domain.world': entity_registry.RegistryEntry( + entity_id='test_domain.world', + unique_id='1234', + # Using component.async_add_entities is equal to platform "domain" + platform='test_platform', + name='Some name' + ), + 'test_domain.existing': entity_registry.RegistryEntry( + entity_id='test_domain.existing', + unique_id='5678', + platform='test_platform', + ), + }) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id='1234') + await platform.async_add_entities([entity]) + + state = hass.states.get('test_domain.world') + assert state is not None + assert state.name == 'Some name' + + with pytest.raises(ValueError): + registry.async_update_entity('test_domain.world', + new_entity_id='test_domain.existing') + + with pytest.raises(ValueError): + registry.async_update_entity('test_domain.world', + new_entity_id='invalid_entity_id') + + with pytest.raises(ValueError): + registry.async_update_entity('test_domain.world', + new_entity_id='diff_domain.world') + + await hass.async_block_till_done() + await hass.async_block_till_done() + + assert hass.states.get('test_domain.world') is not None + assert hass.states.get('invalid_entity_id') is None + assert hass.states.get('diff_domain.world') is None