diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 987bdeae6ca..687ed0b6f8b 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -363,14 +363,16 @@ class Entity: async def async_remove(self): """Remove entity from Home Assistant.""" + will_remove = getattr(self, 'async_will_remove_from_hass', None) + + if will_remove: + await will_remove() # pylint: disable=not-callable + if self._on_remove is not None: while self._on_remove: self._on_remove.pop()() - if self.platform is not None: - await self.platform.async_remove_entity(self.entity_id) - else: - self.hass.states.async_remove(self.entity_id) + self.hass.states.async_remove(self.entity_id) @callback def async_registry_updated(self, old, new): diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 3ab45577236..5fd580a33f0 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -345,8 +345,10 @@ class EntityPlatform: raise HomeAssistantError( msg) - self.entities[entity.entity_id] = entity - component_entities.add(entity.entity_id) + entity_id = entity.entity_id + self.entities[entity_id] = entity + component_entities.add(entity_id) + entity.async_on_remove(lambda: self.entities.pop(entity_id)) if hasattr(entity, 'async_added_to_hass'): await entity.async_added_to_hass() @@ -365,7 +367,7 @@ class EntityPlatform: if not self.entities: return - tasks = [self._async_remove_entity(entity_id) + tasks = [self.async_remove_entity(entity_id) for entity_id in self.entities] await asyncio.wait(tasks, loop=self.hass.loop) @@ -376,7 +378,7 @@ class EntityPlatform: async def async_remove_entity(self, entity_id): """Remove entity id from platform.""" - await self._async_remove_entity(entity_id) + await self.entities[entity_id].async_remove() # Clean up polling job if no longer needed if (self._async_unsub_polling is not None and @@ -385,15 +387,6 @@ class EntityPlatform: self._async_unsub_polling() self._async_unsub_polling = None - async def _async_remove_entity(self, entity_id): - """Remove entity id from platform.""" - entity = self.entities.pop(entity_id) - - if hasattr(entity, 'async_will_remove_from_hass'): - await entity.async_will_remove_from_hass() - - self.hass.states.async_remove(entity_id) - async def _update_entity_states(self, now): """Update the states of all the polling entities. diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 340118502b1..59777e2e6bb 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -11,7 +11,8 @@ from homeassistant.setup import async_setup_component from homeassistant.util import dt from tests.common import ( - MockModule, mock_coro, MockConfigEntry, async_fire_time_changed) + MockModule, mock_coro, MockConfigEntry, async_fire_time_changed, + MockPlatform, MockEntity) @pytest.fixture @@ -40,35 +41,87 @@ def test_call_setup_entry(hass): assert len(mock_setup_entry.mock_calls) == 1 -@asyncio.coroutine -def test_remove_entry(hass, manager): +async def test_remove_entry(hass, manager): """Test that we can remove an entry.""" - mock_unload_entry = MagicMock(return_value=mock_coro(True)) + async def mock_setup_entry(hass, entry): + """Mock setting up entry.""" + hass.loop.create_task(hass.config_entries.async_forward_entry_setup( + entry, 'light')) + return True + async def mock_unload_entry(hass, entry): + """Mock unloading an entry.""" + result = await hass.config_entries.async_forward_entry_unload( + entry, 'light') + assert result + return result + + entity = MockEntity( + unique_id='1234', + name='Test Entity', + ) + + async def mock_setup_entry_platform(hass, entry, async_add_entities): + """Mock setting up platform.""" + async_add_entities([entity]) + + loader.set_component(hass, 'test', MockModule( + 'test', + async_setup_entry=mock_setup_entry, + async_unload_entry=mock_unload_entry + )) loader.set_component( - hass, 'test', - MockModule('comp', async_unload_entry=mock_unload_entry)) + hass, 'light.test', + MockPlatform(async_setup_entry=mock_setup_entry_platform)) MockConfigEntry(domain='test', entry_id='test1').add_to_manager(manager) - MockConfigEntry( + entry = MockConfigEntry( domain='test', entry_id='test2', - state=config_entries.ENTRY_STATE_LOADED - ).add_to_manager(manager) + ) + entry.add_to_manager(manager) MockConfigEntry(domain='test', entry_id='test3').add_to_manager(manager) + # Check all config entries exist assert [item.entry_id for item in manager.async_entries()] == \ ['test1', 'test2', 'test3'] - result = yield from manager.async_remove('test2') + # Setup entry + await entry.async_setup(hass) + await hass.async_block_till_done() + # Check entity state got added + assert hass.states.get('light.test_entity') is not None + # Group all_lights, light.test_entity + assert len(hass.states.async_all()) == 2 + + # Check entity got added to entity registry + ent_reg = await hass.helpers.entity_registry.async_get_registry() + assert len(ent_reg.entities) == 1 + entity_entry = list(ent_reg.entities.values())[0] + assert entity_entry.config_entry_id == entry.entry_id + + # Remove entry + result = await manager.async_remove('test2') + await hass.async_block_till_done() + + # Check that unload went well and so no need to restart assert result == { 'require_restart': False } + + # Check that config entry was removed. assert [item.entry_id for item in manager.async_entries()] == \ ['test1', 'test3'] - assert len(mock_unload_entry.mock_calls) == 1 + # Check that entity state has been removed + assert hass.states.get('light.test_entity') is None + # Just Group all_lights + assert len(hass.states.async_all()) == 1 + + # Check that entity registry entry no longer references config_entry_id + entity_entry = list(ent_reg.entities.values())[0] + assert entity_entry.config_entry_id is None @asyncio.coroutine