Clean up entity component (#11691)
* Clean up entity component * Lint * List -> Tuple * Add Entity.async_remove back * Unflake setting up group test
This commit is contained in:
parent
d478517c51
commit
183e0543b4
14 changed files with 230 additions and 191 deletions
|
@ -338,10 +338,9 @@ class AutomationEntity(ToggleEntity):
|
|||
yield from self.async_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove(self):
|
||||
"""Remove automation from HASS."""
|
||||
def async_will_remove_from_hass(self):
|
||||
"""Remove listeners when removing automation from HASS."""
|
||||
yield from self.async_turn_off()
|
||||
yield from super().async_remove()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_enable(self):
|
||||
|
|
|
@ -238,6 +238,5 @@ class FlicButton(BinarySensorDevice):
|
|||
import pyflic
|
||||
|
||||
if connection_status == pyflic.ConnectionStatus.Disconnected:
|
||||
_LOGGER.info("Button (%s) disconnected. Reason: %s",
|
||||
self.address, disconnect_reason)
|
||||
self.remove()
|
||||
_LOGGER.warning("Button (%s) disconnected. Reason: %s",
|
||||
self.address, disconnect_reason)
|
||||
|
|
|
@ -124,15 +124,15 @@ def async_setup(hass, config):
|
|||
"""Set up the camera component."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL)
|
||||
|
||||
hass.http.register_view(CameraImageView(component.entities))
|
||||
hass.http.register_view(CameraMjpegStream(component.entities))
|
||||
hass.http.register_view(CameraImageView(component))
|
||||
hass.http.register_view(CameraMjpegStream(component))
|
||||
|
||||
yield from component.async_setup(config)
|
||||
|
||||
@callback
|
||||
def update_tokens(time):
|
||||
"""Update tokens of the entities."""
|
||||
for entity in component.entities.values():
|
||||
for entity in component.entities:
|
||||
entity.async_update_token()
|
||||
hass.async_add_job(entity.async_update_ha_state())
|
||||
|
||||
|
@ -358,14 +358,14 @@ class CameraView(HomeAssistantView):
|
|||
|
||||
requires_auth = False
|
||||
|
||||
def __init__(self, entities):
|
||||
def __init__(self, component):
|
||||
"""Initialize a basic camera view."""
|
||||
self.entities = entities
|
||||
self.component = component
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, entity_id):
|
||||
"""Start a GET request."""
|
||||
camera = self.entities.get(entity_id)
|
||||
camera = self.component.get_entity(entity_id)
|
||||
|
||||
if camera is None:
|
||||
status = 404 if request[KEY_AUTHENTICATED] else 401
|
||||
|
|
|
@ -42,8 +42,6 @@ ATTR_ORDER = 'order'
|
|||
ATTR_VIEW = 'view'
|
||||
ATTR_VISIBLE = 'visible'
|
||||
|
||||
DATA_ALL_GROUPS = 'data_all_groups'
|
||||
|
||||
SERVICE_SET_VISIBILITY = 'set_visibility'
|
||||
SERVICE_SET = 'set'
|
||||
SERVICE_REMOVE = 'remove'
|
||||
|
@ -250,8 +248,10 @@ def get_entity_ids(hass, entity_id, domain_filter=None):
|
|||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Set up all groups found definded in the configuration."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
hass.data[DATA_ALL_GROUPS] = {}
|
||||
component = hass.data.get(DOMAIN)
|
||||
|
||||
if component is None:
|
||||
component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
yield from _async_process_config(hass, config, component)
|
||||
|
||||
|
@ -271,10 +271,11 @@ def async_setup(hass, config):
|
|||
def groups_service_handler(service):
|
||||
"""Handle dynamic group service functions."""
|
||||
object_id = service.data[ATTR_OBJECT_ID]
|
||||
service_groups = hass.data[DATA_ALL_GROUPS]
|
||||
entity_id = ENTITY_ID_FORMAT.format(object_id)
|
||||
group = component.get_entity(entity_id)
|
||||
|
||||
# new group
|
||||
if service.service == SERVICE_SET and object_id not in service_groups:
|
||||
if service.service == SERVICE_SET and group is None:
|
||||
entity_ids = service.data.get(ATTR_ENTITIES) or \
|
||||
service.data.get(ATTR_ADD_ENTITIES) or None
|
||||
|
||||
|
@ -289,12 +290,15 @@ def async_setup(hass, config):
|
|||
user_defined=False,
|
||||
**extra_arg
|
||||
)
|
||||
return
|
||||
|
||||
if group is None:
|
||||
_LOGGER.warning("%s:Group '%s' doesn't exist!",
|
||||
service.service, object_id)
|
||||
return
|
||||
|
||||
# update group
|
||||
if service.service == SERVICE_SET:
|
||||
group = service_groups[object_id]
|
||||
need_update = False
|
||||
|
||||
if ATTR_ADD_ENTITIES in service.data:
|
||||
|
@ -333,12 +337,7 @@ def async_setup(hass, config):
|
|||
|
||||
# remove group
|
||||
if service.service == SERVICE_REMOVE:
|
||||
if object_id not in service_groups:
|
||||
_LOGGER.warning("Group '%s' doesn't exist!", object_id)
|
||||
return
|
||||
|
||||
del_group = service_groups.pop(object_id)
|
||||
yield from del_group.async_stop()
|
||||
yield from component.async_remove_entity(entity_id)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SET, groups_service_handler,
|
||||
|
@ -395,7 +394,7 @@ class Group(Entity):
|
|||
"""Track a group of entity ids."""
|
||||
|
||||
def __init__(self, hass, name, order=None, visible=True, icon=None,
|
||||
view=False, control=None, user_defined=True):
|
||||
view=False, control=None, user_defined=True, entity_ids=None):
|
||||
"""Initialize a group.
|
||||
|
||||
This Object has factory function for creation.
|
||||
|
@ -405,7 +404,10 @@ class Group(Entity):
|
|||
self._state = STATE_UNKNOWN
|
||||
self._icon = icon
|
||||
self.view = view
|
||||
self.tracking = []
|
||||
if entity_ids:
|
||||
self.tracking = tuple(ent_id.lower() for ent_id in entity_ids)
|
||||
else:
|
||||
self.tracking = tuple()
|
||||
self.group_on = None
|
||||
self.group_off = None
|
||||
self.visible = visible
|
||||
|
@ -439,23 +441,21 @@ class Group(Entity):
|
|||
hass, name,
|
||||
order=len(hass.states.async_entity_ids(DOMAIN)),
|
||||
visible=visible, icon=icon, view=view, control=control,
|
||||
user_defined=user_defined
|
||||
user_defined=user_defined, entity_ids=entity_ids
|
||||
)
|
||||
|
||||
group.entity_id = async_generate_entity_id(
|
||||
ENTITY_ID_FORMAT, object_id or name, hass=hass)
|
||||
|
||||
# run other async stuff
|
||||
if entity_ids is not None:
|
||||
yield from group.async_update_tracked_entity_ids(entity_ids)
|
||||
else:
|
||||
yield from group.async_update_ha_state(True)
|
||||
|
||||
# If called before the platform async_setup is called (test cases)
|
||||
if DATA_ALL_GROUPS not in hass.data:
|
||||
hass.data[DATA_ALL_GROUPS] = {}
|
||||
component = hass.data.get(DOMAIN)
|
||||
|
||||
if component is None:
|
||||
component = hass.data[DOMAIN] = \
|
||||
EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
yield from component.async_add_entities([group], True)
|
||||
|
||||
hass.data[DATA_ALL_GROUPS][object_id] = group
|
||||
return group
|
||||
|
||||
@property
|
||||
|
@ -534,10 +534,6 @@ class Group(Entity):
|
|||
yield from self.async_update_ha_state(True)
|
||||
self.async_start()
|
||||
|
||||
def start(self):
|
||||
"""Start tracking members."""
|
||||
self.hass.add_job(self.async_start)
|
||||
|
||||
@callback
|
||||
def async_start(self):
|
||||
"""Start tracking members.
|
||||
|
@ -549,17 +545,15 @@ class Group(Entity):
|
|||
self.hass, self.tracking, self._async_state_changed_listener
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""Unregister the group from Home Assistant."""
|
||||
run_coroutine_threadsafe(self.async_stop(), self.hass.loop).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_stop(self):
|
||||
"""Unregister the group from Home Assistant.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
yield from self.async_remove()
|
||||
if self._async_unsub_state_changed:
|
||||
self._async_unsub_state_changed()
|
||||
self._async_unsub_state_changed = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update(self):
|
||||
|
@ -567,17 +561,19 @@ class Group(Entity):
|
|||
self._state = STATE_UNKNOWN
|
||||
self._async_update_group_state()
|
||||
|
||||
def async_remove(self):
|
||||
"""Remove group from HASS.
|
||||
@asyncio.coroutine
|
||||
def async_added_to_hass(self):
|
||||
"""Callback when added to HASS."""
|
||||
if self.tracking:
|
||||
self.async_start()
|
||||
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
@asyncio.coroutine
|
||||
def async_will_remove_from_hass(self):
|
||||
"""Callback when removed from HASS."""
|
||||
if self._async_unsub_state_changed:
|
||||
self._async_unsub_state_changed()
|
||||
self._async_unsub_state_changed = None
|
||||
|
||||
return super().async_remove()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_state_changed_listener(self, entity_id, old_state, new_state):
|
||||
"""Respond to a member state changing.
|
||||
|
|
|
@ -82,7 +82,7 @@ def async_setup(hass, config):
|
|||
mailbox_entity = MailboxEntity(hass, mailbox)
|
||||
component = EntityComponent(
|
||||
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
||||
yield from component.async_add_entity(mailbox_entity)
|
||||
yield from component.async_add_entities([mailbox_entity])
|
||||
|
||||
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
|
||||
in config_per_platform(config, DOMAIN)]
|
||||
|
|
|
@ -366,7 +366,7 @@ def async_setup(hass, config):
|
|||
component = EntityComponent(
|
||||
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
|
||||
|
||||
hass.http.register_view(MediaPlayerImageView(component.entities))
|
||||
hass.http.register_view(MediaPlayerImageView(component))
|
||||
|
||||
yield from component.async_setup(config)
|
||||
|
||||
|
@ -929,14 +929,14 @@ class MediaPlayerImageView(HomeAssistantView):
|
|||
url = '/api/media_player_proxy/{entity_id}'
|
||||
name = 'api:media_player:image'
|
||||
|
||||
def __init__(self, entities):
|
||||
def __init__(self, component):
|
||||
"""Initialize a media player view."""
|
||||
self.entities = entities
|
||||
self.component = component
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, entity_id):
|
||||
"""Start a get request."""
|
||||
player = self.entities.get(entity_id)
|
||||
player = self.component.get_entity(entity_id)
|
||||
if player is None:
|
||||
status = 404 if request[KEY_AUTHENTICATED] else 401
|
||||
return web.Response(status=status)
|
||||
|
|
|
@ -161,7 +161,7 @@ def async_setup(hass, config):
|
|||
face.store.pop(g_id)
|
||||
|
||||
entity = entities.pop(g_id)
|
||||
yield from entity.async_remove()
|
||||
hass.states.async_remove(entity.entity_id)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Can't delete group '%s' with error: %s", g_id, err)
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ def _create_instance(hass, account_name, api_key, shared_secret,
|
|||
token, stored_rtm_config, component):
|
||||
entity = RememberTheMilk(account_name, api_key, shared_secret,
|
||||
token, stored_rtm_config)
|
||||
component.add_entity(entity)
|
||||
component.add_entities([entity])
|
||||
hass.services.register(
|
||||
DOMAIN, '{}_create_task'.format(account_name), entity.create_task,
|
||||
schema=SERVICE_SCHEMA_CREATE_TASK)
|
||||
|
|
|
@ -156,7 +156,7 @@ def _async_process_config(hass, config, component):
|
|||
def service_handler(service):
|
||||
"""Execute a service call to script.<script name>."""
|
||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||
script = component.entities.get(entity_id)
|
||||
script = component.get_entity(entity_id)
|
||||
if script.is_on:
|
||||
_LOGGER.warning("Script %s already running.", entity_id)
|
||||
return
|
||||
|
@ -219,15 +219,11 @@ class ScriptEntity(ToggleEntity):
|
|||
"""Turn script off."""
|
||||
self.script.async_stop()
|
||||
|
||||
def async_remove(self):
|
||||
"""Remove script from HASS.
|
||||
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
@asyncio.coroutine
|
||||
def async_will_remove_from_hass(self):
|
||||
"""Stop script and remove service when it will be removed from HASS."""
|
||||
if self.script.is_running:
|
||||
self.script.async_stop()
|
||||
|
||||
# remove service
|
||||
self.hass.services.async_remove(DOMAIN, self.object_id)
|
||||
|
||||
return super().async_remove()
|
||||
|
|
|
@ -15,8 +15,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.config import DATA_CUSTOMIZE
|
||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
from homeassistant.util import ensure_unique_string, slugify
|
||||
from homeassistant.util.async import (
|
||||
run_coroutine_threadsafe, run_callback_threadsafe)
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SLOW_UPDATE_WARNING = 10
|
||||
|
@ -66,9 +65,12 @@ class Entity(object):
|
|||
# this class. These may be used to customize the behavior of the entity.
|
||||
entity_id = None # type: str
|
||||
|
||||
# Owning hass instance. Will be set by EntityComponent
|
||||
# Owning hass instance. Will be set by EntityPlatform
|
||||
hass = None # type: Optional[HomeAssistant]
|
||||
|
||||
# Owning platform instance. Will be set by EntityPlatform
|
||||
platform = None
|
||||
|
||||
# If we reported if this entity was slow
|
||||
_slow_reported = False
|
||||
|
||||
|
@ -311,19 +313,13 @@ class Entity(object):
|
|||
if self.parallel_updates:
|
||||
self.parallel_updates.release()
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Remove entity from HASS."""
|
||||
run_coroutine_threadsafe(
|
||||
self.async_remove(), self.hass.loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove(self) -> None:
|
||||
"""Remove entity from async HASS.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
self.hass.states.async_remove(self.entity_id)
|
||||
def async_remove(self):
|
||||
"""Remove entity from Home Assistant."""
|
||||
if self.platform is not None:
|
||||
yield from self.platform.async_remove_entity(self.entity_id)
|
||||
else:
|
||||
self.hass.states.async_remove(self.entity_id)
|
||||
|
||||
def _attr_setter(self, name, typ, attr, attrs):
|
||||
"""Populate attributes based on properties."""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Helpers for components that manage entities."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from itertools import chain
|
||||
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
@ -9,7 +10,6 @@ from homeassistant.const import (
|
|||
DEVICE_DEFAULT_NAME)
|
||||
from homeassistant.core import callback, valid_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.loader import get_component
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.entity import async_generate_entity_id
|
||||
from homeassistant.helpers.event import (
|
||||
|
@ -27,7 +27,15 @@ PLATFORM_NOT_READY_RETRIES = 10
|
|||
|
||||
|
||||
class EntityComponent(object):
|
||||
"""Helper class that will help a component manage its entities."""
|
||||
"""The EntityComponent manages platforms that manages entities.
|
||||
|
||||
This class has the following responsibilities:
|
||||
- Process the configuration and set up a platform based component.
|
||||
- Manage the platforms and their entities.
|
||||
- Help extract the entities from a service call.
|
||||
- Maintain a group that tracks all platform entities.
|
||||
- Listen for discovery events for platforms related to the domain.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, domain, hass,
|
||||
scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None):
|
||||
|
@ -40,7 +48,6 @@ class EntityComponent(object):
|
|||
self.scan_interval = scan_interval
|
||||
self.group_name = group_name
|
||||
|
||||
self.entities = {}
|
||||
self.config = None
|
||||
|
||||
self._platforms = {
|
||||
|
@ -49,6 +56,20 @@ class EntityComponent(object):
|
|||
self.async_add_entities = self._platforms['core'].async_add_entities
|
||||
self.add_entities = self._platforms['core'].add_entities
|
||||
|
||||
@property
|
||||
def entities(self):
|
||||
"""Return an iterable that returns all entities."""
|
||||
return chain.from_iterable(platform.entities.values() for platform
|
||||
in self._platforms.values())
|
||||
|
||||
def get_entity(self, entity_id):
|
||||
"""Helper method to get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity = platform.entities.get(entity_id)
|
||||
if entity is not None:
|
||||
return entity
|
||||
return None
|
||||
|
||||
def setup(self, config):
|
||||
"""Set up a full entity component.
|
||||
|
||||
|
@ -77,11 +98,10 @@ class EntityComponent(object):
|
|||
|
||||
# Generic discovery listener for loading platform dynamically
|
||||
# Refer to: homeassistant.components.discovery.load_platform()
|
||||
@callback
|
||||
@asyncio.coroutine
|
||||
def component_platform_discovered(platform, info):
|
||||
"""Handle the loading of a platform."""
|
||||
self.hass.async_add_job(
|
||||
self._async_setup_platform(platform, {}, info))
|
||||
yield from self._async_setup_platform(platform, {}, info)
|
||||
|
||||
discovery.async_listen_platform(
|
||||
self.hass, self.domain, component_platform_discovered)
|
||||
|
@ -107,13 +127,11 @@ class EntityComponent(object):
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
if ATTR_ENTITY_ID not in service.data:
|
||||
return [entity for entity in self.entities.values()
|
||||
if entity.available]
|
||||
return [entity for entity in self.entities if entity.available]
|
||||
|
||||
return [self.entities[entity_id] for entity_id
|
||||
in extract_entity_ids(self.hass, service, expand_group)
|
||||
if entity_id in self.entities and
|
||||
self.entities[entity_id].available]
|
||||
entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
|
||||
return [entity for entity in self.entities
|
||||
if entity.available and entity.entity_id in entity_ids]
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_setup_platform(self, platform_type, platform_config,
|
||||
|
@ -193,80 +211,23 @@ class EntityComponent(object):
|
|||
finally:
|
||||
warn_task.cancel()
|
||||
|
||||
def add_entity(self, entity, platform=None, update_before_add=False):
|
||||
"""Add entity to component."""
|
||||
return run_coroutine_threadsafe(
|
||||
self.async_add_entity(entity, platform, update_before_add),
|
||||
self.hass.loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_add_entity(self, entity, platform=None, update_before_add=False):
|
||||
"""Add entity to component.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if entity is None or entity in self.entities.values():
|
||||
return False
|
||||
|
||||
entity.hass = self.hass
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
try:
|
||||
yield from entity.async_device_update(warning=False)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception("Error on device update!")
|
||||
return False
|
||||
|
||||
# Write entity_id to entity
|
||||
if getattr(entity, 'entity_id', None) is None:
|
||||
object_id = entity.name or DEVICE_DEFAULT_NAME
|
||||
|
||||
if platform is not None and platform.entity_namespace is not None:
|
||||
object_id = '{} {}'.format(platform.entity_namespace,
|
||||
object_id)
|
||||
|
||||
entity.entity_id = async_generate_entity_id(
|
||||
self.entity_id_format, object_id,
|
||||
self.entities.keys())
|
||||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if entity.entity_id in self.entities:
|
||||
raise HomeAssistantError(
|
||||
'Entity id already exists: {}'.format(entity.entity_id))
|
||||
elif not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(
|
||||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
yield from entity.async_added_to_hass()
|
||||
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
return True
|
||||
|
||||
def update_group(self):
|
||||
"""Set up and/or update component group."""
|
||||
run_callback_threadsafe(
|
||||
self.hass.loop, self.async_update_group).result()
|
||||
|
||||
@callback
|
||||
def async_update_group(self):
|
||||
"""Set up and/or update component group.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if self.group_name is not None:
|
||||
ids = sorted(self.entities,
|
||||
key=lambda x: self.entities[x].name or x)
|
||||
group = get_component('group')
|
||||
group.async_set_group(
|
||||
self.hass, slugify(self.group_name), name=self.group_name,
|
||||
visible=False, entity_ids=ids
|
||||
)
|
||||
if self.group_name is None:
|
||||
return
|
||||
|
||||
ids = [entity.entity_id for entity in
|
||||
sorted(self.entities,
|
||||
key=lambda entity: entity.name or entity.entity_id)]
|
||||
|
||||
self.hass.components.group.async_set_group(
|
||||
slugify(self.group_name), name=self.group_name,
|
||||
visible=False, entity_ids=ids
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""Remove entities and reset the entity component to initial values."""
|
||||
|
@ -287,12 +248,17 @@ class EntityComponent(object):
|
|||
self._platforms = {
|
||||
'core': self._platforms['core']
|
||||
}
|
||||
self.entities = {}
|
||||
self.config = None
|
||||
|
||||
if self.group_name is not None:
|
||||
group = get_component('group')
|
||||
group.async_remove(self.hass, slugify(self.group_name))
|
||||
self.hass.components.group.async_remove(slugify(self.group_name))
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove_entity(self, entity_id):
|
||||
"""Remove an entity managed by one of the platforms."""
|
||||
for platform in self._platforms.values():
|
||||
if entity_id in platform.entities:
|
||||
yield from platform.async_remove_entity(entity_id)
|
||||
|
||||
def prepare_reload(self):
|
||||
"""Prepare reloading this entity component."""
|
||||
|
@ -323,7 +289,7 @@ class EntityComponent(object):
|
|||
|
||||
|
||||
class EntityPlatform(object):
|
||||
"""Keep track of entities for a single platform and stay in loop."""
|
||||
"""Manage the entities for a single platform."""
|
||||
|
||||
def __init__(self, component, platform, scan_interval, parallel_updates,
|
||||
entity_namespace):
|
||||
|
@ -333,7 +299,7 @@ class EntityPlatform(object):
|
|||
self.scan_interval = scan_interval
|
||||
self.parallel_updates = None
|
||||
self.entity_namespace = entity_namespace
|
||||
self.platform_entities = []
|
||||
self.entities = {}
|
||||
self._tasks = []
|
||||
self._async_unsub_polling = None
|
||||
self._process_updates = asyncio.Lock(loop=component.hass.loop)
|
||||
|
@ -391,40 +357,88 @@ class EntityPlatform(object):
|
|||
if not new_entities:
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_process_entity(new_entity):
|
||||
"""Add entities to StateMachine."""
|
||||
new_entity.parallel_updates = self.parallel_updates
|
||||
ret = yield from self.component.async_add_entity(
|
||||
new_entity, self, update_before_add=update_before_add
|
||||
)
|
||||
if ret:
|
||||
self.platform_entities.append(new_entity)
|
||||
component_entities = set(entity.entity_id for entity
|
||||
in self.component.entities)
|
||||
|
||||
tasks = [async_process_entity(entity) for entity in new_entities]
|
||||
tasks = [
|
||||
self._async_add_entity(entity, update_before_add,
|
||||
component_entities)
|
||||
for entity in new_entities]
|
||||
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
self.component.async_update_group()
|
||||
|
||||
if self._async_unsub_polling is not None or \
|
||||
not any(entity.should_poll for entity
|
||||
in self.platform_entities):
|
||||
in self.entities.values()):
|
||||
return
|
||||
|
||||
self._async_unsub_polling = async_track_time_interval(
|
||||
self.component.hass, self._update_entity_states, self.scan_interval
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_add_entity(self, entity, update_before_add, component_entities):
|
||||
"""Helper method to add an entity to the platform."""
|
||||
if entity is None:
|
||||
raise ValueError('Entity cannot be None')
|
||||
|
||||
# Do nothing if entity has already been added based on unique id.
|
||||
if entity in self.component.entities:
|
||||
return
|
||||
|
||||
entity.hass = self.component.hass
|
||||
entity.platform = self
|
||||
entity.parallel_updates = self.parallel_updates
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
try:
|
||||
yield from entity.async_device_update(warning=False)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.component.logger.exception(
|
||||
"%s: Error on device update!", self.platform)
|
||||
return
|
||||
|
||||
# Write entity_id to entity
|
||||
if getattr(entity, 'entity_id', None) is None:
|
||||
object_id = entity.name or DEVICE_DEFAULT_NAME
|
||||
|
||||
if self.entity_namespace is not None:
|
||||
object_id = '{} {}'.format(self.entity_namespace,
|
||||
object_id)
|
||||
|
||||
entity.entity_id = async_generate_entity_id(
|
||||
self.component.entity_id_format, object_id,
|
||||
component_entities)
|
||||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(
|
||||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
elif entity.entity_id in component_entities:
|
||||
raise HomeAssistantError(
|
||||
'Entity id already exists: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
component_entities.add(entity.entity_id)
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
yield from entity.async_added_to_hass()
|
||||
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_reset(self):
|
||||
"""Remove all entities and reset data.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if not self.platform_entities:
|
||||
if not self.entities:
|
||||
return
|
||||
|
||||
tasks = [entity.async_remove() for entity in self.platform_entities]
|
||||
tasks = [self._async_remove_entity(entity_id)
|
||||
for entity_id in self.entities]
|
||||
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
|
||||
|
@ -432,6 +446,28 @@ class EntityPlatform(object):
|
|||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_remove_entity(self, entity_id):
|
||||
"""Remove entity id from platform."""
|
||||
yield from self._async_remove_entity(entity_id)
|
||||
|
||||
# Clean up polling job if no longer needed
|
||||
if (self._async_unsub_polling is not None and
|
||||
not any(entity.should_poll for entity
|
||||
in self.entities.values())):
|
||||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_remove_entity(self, entity_id):
|
||||
"""Remove entity id from platform."""
|
||||
entity = self.entities.pop(entity_id)
|
||||
|
||||
if hasattr(entity, 'async_will_remove_from_hass'):
|
||||
yield from entity.async_will_remove_from_hass()
|
||||
|
||||
self.component.hass.states.async_remove(entity_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _update_entity_states(self, now):
|
||||
"""Update the states of all the polling entities.
|
||||
|
@ -450,7 +486,7 @@ class EntityPlatform(object):
|
|||
|
||||
with (yield from self._process_updates):
|
||||
tasks = []
|
||||
for entity in self.platform_entities:
|
||||
for entity in self.entities.values():
|
||||
if not entity.should_poll:
|
||||
continue
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
|
|
|
@ -350,7 +350,7 @@ class TestComponentsGroup(unittest.TestCase):
|
|||
|
||||
assert sorted(self.hass.states.entity_ids()) == \
|
||||
['group.empty_group', 'group.second_group', 'group.test_group']
|
||||
assert self.hass.bus.listeners['state_changed'] == 3
|
||||
assert self.hass.bus.listeners['state_changed'] == 2
|
||||
|
||||
with patch('homeassistant.config.load_yaml_config_file', return_value={
|
||||
'group': {
|
||||
|
@ -365,14 +365,6 @@ class TestComponentsGroup(unittest.TestCase):
|
|||
assert self.hass.states.entity_ids() == ['group.hello']
|
||||
assert self.hass.bus.listeners['state_changed'] == 1
|
||||
|
||||
def test_stopping_a_group(self):
|
||||
"""Test that a group correctly removes itself."""
|
||||
grp = group.Group.create_group(
|
||||
self.hass, 'light', ['light.test_1', 'light.test_2'])
|
||||
assert self.hass.states.entity_ids() == ['group.light']
|
||||
grp.stop()
|
||||
assert self.hass.states.entity_ids() == []
|
||||
|
||||
def test_changing_group_visibility(self):
|
||||
"""Test that a group can be hidden and shown."""
|
||||
assert setup_component(self.hass, 'group', {
|
||||
|
|
|
@ -388,3 +388,15 @@ def test_async_pararell_updates_with_two(hass):
|
|||
test_lock.release()
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
test_lock.release()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_async_remove_no_platform(hass):
|
||||
"""Test async_remove method when no platform set."""
|
||||
ent = entity.Entity()
|
||||
ent.hass = hass
|
||||
ent.entity_id = 'test.test'
|
||||
yield from ent.async_update_ha_state()
|
||||
assert len(hass.states.async_entity_ids()) == 1
|
||||
yield from ent.async_remove()
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
|
|
@ -86,6 +86,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
|
|||
assert len(self.hass.states.entity_ids()) == 0
|
||||
|
||||
component.add_entities([EntityTest()])
|
||||
self.hass.block_till_done()
|
||||
|
||||
# group exists
|
||||
assert len(self.hass.states.entity_ids()) == 2
|
||||
|
@ -98,6 +99,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
|
|||
|
||||
# group extended
|
||||
component.add_entities([EntityTest(name='goodbye')])
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.hass.states.entity_ids()) == 3
|
||||
group = self.hass.states.get('group.everyone')
|
||||
|
@ -214,7 +216,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
|
|||
|
||||
assert 0 == len(self.hass.states.entity_ids())
|
||||
|
||||
component.add_entities([None, EntityTest(unique_id='not_very_unique')])
|
||||
component.add_entities([EntityTest(unique_id='not_very_unique')])
|
||||
|
||||
assert 1 == len(self.hass.states.entity_ids())
|
||||
|
||||
|
@ -671,3 +673,14 @@ def test_raise_error_on_update(hass):
|
|||
|
||||
assert len(updates) == 1
|
||||
assert 1 in updates
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_async_remove_with_platform(hass):
|
||||
"""Remove an entity from a platform."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entity1 = EntityTest(name='test_1')
|
||||
yield from component.async_add_entities([entity1])
|
||||
assert len(hass.states.async_entity_ids()) == 1
|
||||
yield from entity1.async_remove()
|
||||
assert len(hass.states.async_entity_ids()) == 0
|
||||
|
|
Loading…
Add table
Reference in a new issue