Clean up entity component (#11691)

* Clean up entity component

* Lint

* List -> Tuple

* Add Entity.async_remove back

* Unflake setting up group test
This commit is contained in:
Paulus Schoutsen 2018-01-22 22:54:41 -08:00 committed by GitHub
parent d478517c51
commit 183e0543b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 230 additions and 191 deletions

View file

@ -338,10 +338,9 @@ class AutomationEntity(ToggleEntity):
yield from self.async_update_ha_state()
@asyncio.coroutine
def async_remove(self):
"""Remove automation from HASS."""
def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS."""
yield from self.async_turn_off()
yield from super().async_remove()
@asyncio.coroutine
def async_enable(self):

View file

@ -238,6 +238,5 @@ class FlicButton(BinarySensorDevice):
import pyflic
if connection_status == pyflic.ConnectionStatus.Disconnected:
_LOGGER.info("Button (%s) disconnected. Reason: %s",
self.address, disconnect_reason)
self.remove()
_LOGGER.warning("Button (%s) disconnected. Reason: %s",
self.address, disconnect_reason)

View file

@ -124,15 +124,15 @@ def async_setup(hass, config):
"""Set up the camera component."""
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(CameraImageView(component.entities))
hass.http.register_view(CameraMjpegStream(component.entities))
hass.http.register_view(CameraImageView(component))
hass.http.register_view(CameraMjpegStream(component))
yield from component.async_setup(config)
@callback
def update_tokens(time):
"""Update tokens of the entities."""
for entity in component.entities.values():
for entity in component.entities:
entity.async_update_token()
hass.async_add_job(entity.async_update_ha_state())
@ -358,14 +358,14 @@ class CameraView(HomeAssistantView):
requires_auth = False
def __init__(self, entities):
def __init__(self, component):
"""Initialize a basic camera view."""
self.entities = entities
self.component = component
@asyncio.coroutine
def get(self, request, entity_id):
"""Start a GET request."""
camera = self.entities.get(entity_id)
camera = self.component.get_entity(entity_id)
if camera is None:
status = 404 if request[KEY_AUTHENTICATED] else 401

View file

@ -42,8 +42,6 @@ ATTR_ORDER = 'order'
ATTR_VIEW = 'view'
ATTR_VISIBLE = 'visible'
DATA_ALL_GROUPS = 'data_all_groups'
SERVICE_SET_VISIBILITY = 'set_visibility'
SERVICE_SET = 'set'
SERVICE_REMOVE = 'remove'
@ -250,8 +248,10 @@ def get_entity_ids(hass, entity_id, domain_filter=None):
@asyncio.coroutine
def async_setup(hass, config):
"""Set up all groups found definded in the configuration."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
hass.data[DATA_ALL_GROUPS] = {}
component = hass.data.get(DOMAIN)
if component is None:
component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
yield from _async_process_config(hass, config, component)
@ -271,10 +271,11 @@ def async_setup(hass, config):
def groups_service_handler(service):
"""Handle dynamic group service functions."""
object_id = service.data[ATTR_OBJECT_ID]
service_groups = hass.data[DATA_ALL_GROUPS]
entity_id = ENTITY_ID_FORMAT.format(object_id)
group = component.get_entity(entity_id)
# new group
if service.service == SERVICE_SET and object_id not in service_groups:
if service.service == SERVICE_SET and group is None:
entity_ids = service.data.get(ATTR_ENTITIES) or \
service.data.get(ATTR_ADD_ENTITIES) or None
@ -289,12 +290,15 @@ def async_setup(hass, config):
user_defined=False,
**extra_arg
)
return
if group is None:
_LOGGER.warning("%s:Group '%s' doesn't exist!",
service.service, object_id)
return
# update group
if service.service == SERVICE_SET:
group = service_groups[object_id]
need_update = False
if ATTR_ADD_ENTITIES in service.data:
@ -333,12 +337,7 @@ def async_setup(hass, config):
# remove group
if service.service == SERVICE_REMOVE:
if object_id not in service_groups:
_LOGGER.warning("Group '%s' doesn't exist!", object_id)
return
del_group = service_groups.pop(object_id)
yield from del_group.async_stop()
yield from component.async_remove_entity(entity_id)
hass.services.async_register(
DOMAIN, SERVICE_SET, groups_service_handler,
@ -395,7 +394,7 @@ class Group(Entity):
"""Track a group of entity ids."""
def __init__(self, hass, name, order=None, visible=True, icon=None,
view=False, control=None, user_defined=True):
view=False, control=None, user_defined=True, entity_ids=None):
"""Initialize a group.
This Object has factory function for creation.
@ -405,7 +404,10 @@ class Group(Entity):
self._state = STATE_UNKNOWN
self._icon = icon
self.view = view
self.tracking = []
if entity_ids:
self.tracking = tuple(ent_id.lower() for ent_id in entity_ids)
else:
self.tracking = tuple()
self.group_on = None
self.group_off = None
self.visible = visible
@ -439,23 +441,21 @@ class Group(Entity):
hass, name,
order=len(hass.states.async_entity_ids(DOMAIN)),
visible=visible, icon=icon, view=view, control=control,
user_defined=user_defined
user_defined=user_defined, entity_ids=entity_ids
)
group.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id or name, hass=hass)
# run other async stuff
if entity_ids is not None:
yield from group.async_update_tracked_entity_ids(entity_ids)
else:
yield from group.async_update_ha_state(True)
# If called before the platform async_setup is called (test cases)
if DATA_ALL_GROUPS not in hass.data:
hass.data[DATA_ALL_GROUPS] = {}
component = hass.data.get(DOMAIN)
if component is None:
component = hass.data[DOMAIN] = \
EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([group], True)
hass.data[DATA_ALL_GROUPS][object_id] = group
return group
@property
@ -534,10 +534,6 @@ class Group(Entity):
yield from self.async_update_ha_state(True)
self.async_start()
def start(self):
"""Start tracking members."""
self.hass.add_job(self.async_start)
@callback
def async_start(self):
"""Start tracking members.
@ -549,17 +545,15 @@ class Group(Entity):
self.hass, self.tracking, self._async_state_changed_listener
)
def stop(self):
"""Unregister the group from Home Assistant."""
run_coroutine_threadsafe(self.async_stop(), self.hass.loop).result()
@asyncio.coroutine
def async_stop(self):
"""Unregister the group from Home Assistant.
This method must be run in the event loop.
"""
yield from self.async_remove()
if self._async_unsub_state_changed:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
@asyncio.coroutine
def async_update(self):
@ -567,17 +561,19 @@ class Group(Entity):
self._state = STATE_UNKNOWN
self._async_update_group_state()
def async_remove(self):
"""Remove group from HASS.
@asyncio.coroutine
def async_added_to_hass(self):
"""Callback when added to HASS."""
if self.tracking:
self.async_start()
This method must be run in the event loop and returns a coroutine.
"""
@asyncio.coroutine
def async_will_remove_from_hass(self):
"""Callback when removed from HASS."""
if self._async_unsub_state_changed:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
return super().async_remove()
@asyncio.coroutine
def _async_state_changed_listener(self, entity_id, old_state, new_state):
"""Respond to a member state changing.

View file

@ -82,7 +82,7 @@ def async_setup(hass, config):
mailbox_entity = MailboxEntity(hass, mailbox)
component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
yield from component.async_add_entity(mailbox_entity)
yield from component.async_add_entities([mailbox_entity])
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
in config_per_platform(config, DOMAIN)]

View file

@ -366,7 +366,7 @@ def async_setup(hass, config):
component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(MediaPlayerImageView(component.entities))
hass.http.register_view(MediaPlayerImageView(component))
yield from component.async_setup(config)
@ -929,14 +929,14 @@ class MediaPlayerImageView(HomeAssistantView):
url = '/api/media_player_proxy/{entity_id}'
name = 'api:media_player:image'
def __init__(self, entities):
def __init__(self, component):
"""Initialize a media player view."""
self.entities = entities
self.component = component
@asyncio.coroutine
def get(self, request, entity_id):
"""Start a get request."""
player = self.entities.get(entity_id)
player = self.component.get_entity(entity_id)
if player is None:
status = 404 if request[KEY_AUTHENTICATED] else 401
return web.Response(status=status)

View file

@ -161,7 +161,7 @@ def async_setup(hass, config):
face.store.pop(g_id)
entity = entities.pop(g_id)
yield from entity.async_remove()
hass.states.async_remove(entity.entity_id)
except HomeAssistantError as err:
_LOGGER.error("Can't delete group '%s' with error: %s", g_id, err)

View file

@ -86,7 +86,7 @@ def _create_instance(hass, account_name, api_key, shared_secret,
token, stored_rtm_config, component):
entity = RememberTheMilk(account_name, api_key, shared_secret,
token, stored_rtm_config)
component.add_entity(entity)
component.add_entities([entity])
hass.services.register(
DOMAIN, '{}_create_task'.format(account_name), entity.create_task,
schema=SERVICE_SCHEMA_CREATE_TASK)

View file

@ -156,7 +156,7 @@ def _async_process_config(hass, config, component):
def service_handler(service):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.entities.get(entity_id)
script = component.get_entity(entity_id)
if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id)
return
@ -219,15 +219,11 @@ class ScriptEntity(ToggleEntity):
"""Turn script off."""
self.script.async_stop()
def async_remove(self):
"""Remove script from HASS.
This method must be run in the event loop and returns a coroutine.
"""
@asyncio.coroutine
def async_will_remove_from_hass(self):
"""Stop script and remove service when it will be removed from HASS."""
if self.script.is_running:
self.script.async_stop()
# remove service
self.hass.services.async_remove(DOMAIN, self.object_id)
return super().async_remove()

View file

@ -15,8 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
from homeassistant.util.async import run_callback_threadsafe
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
@ -66,9 +65,12 @@ class Entity(object):
# this class. These may be used to customize the behavior of the entity.
entity_id = None # type: str
# Owning hass instance. Will be set by EntityComponent
# Owning hass instance. Will be set by EntityPlatform
hass = None # type: Optional[HomeAssistant]
# Owning platform instance. Will be set by EntityPlatform
platform = None
# If we reported if this entity was slow
_slow_reported = False
@ -311,19 +313,13 @@ class Entity(object):
if self.parallel_updates:
self.parallel_updates.release()
def remove(self) -> None:
"""Remove entity from HASS."""
run_coroutine_threadsafe(
self.async_remove(), self.hass.loop
).result()
@asyncio.coroutine
def async_remove(self) -> None:
"""Remove entity from async HASS.
This method must be run in the event loop.
"""
self.hass.states.async_remove(self.entity_id)
def async_remove(self):
"""Remove entity from Home Assistant."""
if self.platform is not None:
yield from self.platform.async_remove_entity(self.entity_id)
else:
self.hass.states.async_remove(self.entity_id)
def _attr_setter(self, name, typ, attr, attrs):
"""Populate attributes based on properties."""

View file

@ -1,6 +1,7 @@
"""Helpers for components that manage entities."""
import asyncio
from datetime import timedelta
from itertools import chain
from homeassistant import config as conf_util
from homeassistant.setup import async_prepare_setup_platform
@ -9,7 +10,6 @@ from homeassistant.const import (
DEVICE_DEFAULT_NAME)
from homeassistant.core import callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.loader import get_component
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.event import (
@ -27,7 +27,15 @@ PLATFORM_NOT_READY_RETRIES = 10
class EntityComponent(object):
"""Helper class that will help a component manage its entities."""
"""The EntityComponent manages platforms that manages entities.
This class has the following responsibilities:
- Process the configuration and set up a platform based component.
- Manage the platforms and their entities.
- Help extract the entities from a service call.
- Maintain a group that tracks all platform entities.
- Listen for discovery events for platforms related to the domain.
"""
def __init__(self, logger, domain, hass,
scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None):
@ -40,7 +48,6 @@ class EntityComponent(object):
self.scan_interval = scan_interval
self.group_name = group_name
self.entities = {}
self.config = None
self._platforms = {
@ -49,6 +56,20 @@ class EntityComponent(object):
self.async_add_entities = self._platforms['core'].async_add_entities
self.add_entities = self._platforms['core'].add_entities
@property
def entities(self):
"""Return an iterable that returns all entities."""
return chain.from_iterable(platform.entities.values() for platform
in self._platforms.values())
def get_entity(self, entity_id):
"""Helper method to get an entity."""
for platform in self._platforms.values():
entity = platform.entities.get(entity_id)
if entity is not None:
return entity
return None
def setup(self, config):
"""Set up a full entity component.
@ -77,11 +98,10 @@ class EntityComponent(object):
# Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform()
@callback
@asyncio.coroutine
def component_platform_discovered(platform, info):
"""Handle the loading of a platform."""
self.hass.async_add_job(
self._async_setup_platform(platform, {}, info))
yield from self._async_setup_platform(platform, {}, info)
discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered)
@ -107,13 +127,11 @@ class EntityComponent(object):
This method must be run in the event loop.
"""
if ATTR_ENTITY_ID not in service.data:
return [entity for entity in self.entities.values()
if entity.available]
return [entity for entity in self.entities if entity.available]
return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service, expand_group)
if entity_id in self.entities and
self.entities[entity_id].available]
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
return [entity for entity in self.entities
if entity.available and entity.entity_id in entity_ids]
@asyncio.coroutine
def _async_setup_platform(self, platform_type, platform_config,
@ -193,80 +211,23 @@ class EntityComponent(object):
finally:
warn_task.cancel()
def add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component."""
return run_coroutine_threadsafe(
self.async_add_entity(entity, platform, update_before_add),
self.hass.loop
).result()
@asyncio.coroutine
def async_add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component.
This method must be run in the event loop.
"""
if entity is None or entity in self.entities.values():
return False
entity.hass = self.hass
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.logger.exception("Error on device update!")
return False
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if platform is not None and platform.entity_namespace is not None:
object_id = '{} {}'.format(platform.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.entity_id_format, object_id,
self.entities.keys())
# Make sure it is valid in case an entity set the value themselves
if entity.entity_id in self.entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
elif not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
run_callback_threadsafe(
self.hass.loop, self.async_update_group).result()
@callback
def async_update_group(self):
"""Set up and/or update component group.
This method must be run in the event loop.
"""
if self.group_name is not None:
ids = sorted(self.entities,
key=lambda x: self.entities[x].name or x)
group = get_component('group')
group.async_set_group(
self.hass, slugify(self.group_name), name=self.group_name,
visible=False, entity_ids=ids
)
if self.group_name is None:
return
ids = [entity.entity_id for entity in
sorted(self.entities,
key=lambda entity: entity.name or entity.entity_id)]
self.hass.components.group.async_set_group(
slugify(self.group_name), name=self.group_name,
visible=False, entity_ids=ids
)
def reset(self):
"""Remove entities and reset the entity component to initial values."""
@ -287,12 +248,17 @@ class EntityComponent(object):
self._platforms = {
'core': self._platforms['core']
}
self.entities = {}
self.config = None
if self.group_name is not None:
group = get_component('group')
group.async_remove(self.hass, slugify(self.group_name))
self.hass.components.group.async_remove(slugify(self.group_name))
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove an entity managed by one of the platforms."""
for platform in self._platforms.values():
if entity_id in platform.entities:
yield from platform.async_remove_entity(entity_id)
def prepare_reload(self):
"""Prepare reloading this entity component."""
@ -323,7 +289,7 @@ class EntityComponent(object):
class EntityPlatform(object):
"""Keep track of entities for a single platform and stay in loop."""
"""Manage the entities for a single platform."""
def __init__(self, component, platform, scan_interval, parallel_updates,
entity_namespace):
@ -333,7 +299,7 @@ class EntityPlatform(object):
self.scan_interval = scan_interval
self.parallel_updates = None
self.entity_namespace = entity_namespace
self.platform_entities = []
self.entities = {}
self._tasks = []
self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=component.hass.loop)
@ -391,40 +357,88 @@ class EntityPlatform(object):
if not new_entities:
return
@asyncio.coroutine
def async_process_entity(new_entity):
"""Add entities to StateMachine."""
new_entity.parallel_updates = self.parallel_updates
ret = yield from self.component.async_add_entity(
new_entity, self, update_before_add=update_before_add
)
if ret:
self.platform_entities.append(new_entity)
component_entities = set(entity.entity_id for entity
in self.component.entities)
tasks = [async_process_entity(entity) for entity in new_entities]
tasks = [
self._async_add_entity(entity, update_before_add,
component_entities)
for entity in new_entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
self.component.async_update_group()
if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity
in self.platform_entities):
in self.entities.values()):
return
self._async_unsub_polling = async_track_time_interval(
self.component.hass, self._update_entity_states, self.scan_interval
)
@asyncio.coroutine
def _async_add_entity(self, entity, update_before_add, component_entities):
"""Helper method to add an entity to the platform."""
if entity is None:
raise ValueError('Entity cannot be None')
# Do nothing if entity has already been added based on unique id.
if entity in self.component.entities:
return
entity.hass = self.component.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.component.logger.exception(
"%s: Error on device update!", self.platform)
return
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if self.entity_namespace is not None:
object_id = '{} {}'.format(self.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.component.entity_id_format, object_id,
component_entities)
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
component_entities.add(entity.entity_id)
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
@asyncio.coroutine
def async_reset(self):
"""Remove all entities and reset data.
This method must be run in the event loop.
"""
if not self.platform_entities:
if not self.entities:
return
tasks = [entity.async_remove() for entity in self.platform_entities]
tasks = [self._async_remove_entity(entity_id)
for entity_id in self.entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
@ -432,6 +446,28 @@ class EntityPlatform(object):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
yield from self._async_remove_entity(entity_id)
# Clean up polling job if no longer needed
if (self._async_unsub_polling is not None and
not any(entity.should_poll for entity
in self.entities.values())):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def _async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
entity = self.entities.pop(entity_id)
if hasattr(entity, 'async_will_remove_from_hass'):
yield from entity.async_will_remove_from_hass()
self.component.hass.states.async_remove(entity_id)
@asyncio.coroutine
def _update_entity_states(self, now):
"""Update the states of all the polling entities.
@ -450,7 +486,7 @@ class EntityPlatform(object):
with (yield from self._process_updates):
tasks = []
for entity in self.platform_entities:
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True))

View file

@ -350,7 +350,7 @@ class TestComponentsGroup(unittest.TestCase):
assert sorted(self.hass.states.entity_ids()) == \
['group.empty_group', 'group.second_group', 'group.test_group']
assert self.hass.bus.listeners['state_changed'] == 3
assert self.hass.bus.listeners['state_changed'] == 2
with patch('homeassistant.config.load_yaml_config_file', return_value={
'group': {
@ -365,14 +365,6 @@ class TestComponentsGroup(unittest.TestCase):
assert self.hass.states.entity_ids() == ['group.hello']
assert self.hass.bus.listeners['state_changed'] == 1
def test_stopping_a_group(self):
"""Test that a group correctly removes itself."""
grp = group.Group.create_group(
self.hass, 'light', ['light.test_1', 'light.test_2'])
assert self.hass.states.entity_ids() == ['group.light']
grp.stop()
assert self.hass.states.entity_ids() == []
def test_changing_group_visibility(self):
"""Test that a group can be hidden and shown."""
assert setup_component(self.hass, 'group', {

View file

@ -388,3 +388,15 @@ def test_async_pararell_updates_with_two(hass):
test_lock.release()
yield from asyncio.sleep(0, loop=hass.loop)
test_lock.release()
@asyncio.coroutine
def test_async_remove_no_platform(hass):
"""Test async_remove method when no platform set."""
ent = entity.Entity()
ent.hass = hass
ent.entity_id = 'test.test'
yield from ent.async_update_ha_state()
assert len(hass.states.async_entity_ids()) == 1
yield from ent.async_remove()
assert len(hass.states.async_entity_ids()) == 0

View file

@ -86,6 +86,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert len(self.hass.states.entity_ids()) == 0
component.add_entities([EntityTest()])
self.hass.block_till_done()
# group exists
assert len(self.hass.states.entity_ids()) == 2
@ -98,6 +99,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
# group extended
component.add_entities([EntityTest(name='goodbye')])
self.hass.block_till_done()
assert len(self.hass.states.entity_ids()) == 3
group = self.hass.states.get('group.everyone')
@ -214,7 +216,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert 0 == len(self.hass.states.entity_ids())
component.add_entities([None, EntityTest(unique_id='not_very_unique')])
component.add_entities([EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids())
@ -671,3 +673,14 @@ def test_raise_error_on_update(hass):
assert len(updates) == 1
assert 1 in updates
@asyncio.coroutine
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
yield from component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0