Async EntitiesComponent (#3820)

* first version

* First draft component entities

* Change add_entities to callback from coroutine

* Fix bug add async_prepare_reload

* Group draft v1

* group async

* bugfix

* bugfix v2

* fix lint

* fix extract_entity_ids

* fix other things

* move get_component out of executor

* bugfix

* Address minor changes

* lint

* bugfix - should work now

* make group init async only

* change update handling to old stuff

* fix group handling, remove generator from init

* fix lint

* protect loop for spaming with updates

* fix lint

* update test_group

* fix

* update group handling

* fix __init__ async trouble

* move device_tracker to new layout

* lint

* fix group unittest

* Test with coroutine

* fix bug

* now it works 💯

* ups

* first part of suggestion

* add_entities to coroutine

* change group

* convert add async_add_entity to coroutine

* fix unit tests

* fix lint

* fix lint part 2

* fix wrong import delete

* change async_update_tracked_entity_ids to coroutine

* fix

* revert last change

* fix unittest entity id

* fix unittest

* fix unittest

* fix unittest entity_component

* fix group

* fix group_test

* try part 2 to fix test_group

* fix all entity_component

* rename _process_config

* Change Group to init with factory

* fix lint

* fix lint

* fix callback

* Tweak entity component and group

* More fixes

* Final fixes

* No longer needed blocks

* Address @bbangert comments

* Add test for group.stop

* More callbacks for automation
This commit is contained in:
Pascal Vizeli 2016-10-16 18:35:46 +02:00 committed by Paulus Schoutsen
parent a0fdb2778d
commit 0b8b9ecb94
14 changed files with 503 additions and 266 deletions

View file

@ -1,5 +1,5 @@
"""Helpers for components that manage entities."""
from threading import Lock
import asyncio
from homeassistant import config as conf_util
from homeassistant.bootstrap import (prepare_setup_platform,
@ -7,12 +7,15 @@ from homeassistant.bootstrap import (prepare_setup_platform,
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE,
DEVICE_DEFAULT_NAME)
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.event import track_utc_time_change
from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.event import async_track_utc_time_change
from homeassistant.helpers.service import extract_entity_ids
from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe)
DEFAULT_SCAN_INTERVAL = 15
@ -37,11 +40,11 @@ class EntityComponent(object):
self.group = None
self.config = None
self.lock = Lock()
self._platforms = {
'core': EntityPlatform(self, self.scan_interval, None),
}
self.async_add_entities = self._platforms['core'].async_add_entities
self.add_entities = self._platforms['core'].add_entities
def setup(self, config):
@ -50,20 +53,38 @@ class EntityComponent(object):
Loads the platforms from the config and will listen for supported
discovered platforms.
"""
run_coroutine_threadsafe(
self.async_setup(config), self.hass.loop
).result()
@asyncio.coroutine
def async_setup(self, config):
"""Set up a full entity component.
Loads the platforms from the config and will listen for supported
discovered platforms.
This method must be run in the event loop.
"""
self.config = config
# Look in config for Domain, Domain 2, Domain 3 etc and load them
tasks = []
for p_type, p_config in config_per_platform(config, self.domain):
self._setup_platform(p_type, p_config)
tasks.append(self._async_setup_platform(p_type, p_config))
yield from asyncio.gather(*tasks, loop=self.hass.loop)
# Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform()
@callback
def component_platform_discovered(platform, info):
"""Callback to load a platform."""
self._setup_platform(platform, {}, info)
self.hass.loop.create_task(
self._async_setup_platform(platform, {}, info))
discovery.listen_platform(self.hass, self.domain,
component_platform_discovered)
discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered)
def extract_from_service(self, service):
"""Extract all known entities from a service call.
@ -71,19 +92,36 @@ class EntityComponent(object):
Will return all entities if no entities specified in call.
Will return an empty list if entities specified but unknown.
"""
with self.lock:
if ATTR_ENTITY_ID not in service.data:
return list(self.entities.values())
return run_callback_threadsafe(
self.hass.loop, self.async_extract_from_service, service
).result()
return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service)
if entity_id in self.entities]
def async_extract_from_service(self, service):
"""Extract all known entities from a service call.
def _setup_platform(self, platform_type, platform_config,
discovery_info=None):
"""Setup a platform for this component."""
platform = prepare_setup_platform(
self.hass, self.config, self.domain, platform_type)
Will return all entities if no entities specified in call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
if ATTR_ENTITY_ID not in service.data:
return list(self.entities.values())
return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service)
if entity_id in self.entities]
@asyncio.coroutine
def _async_setup_platform(self, platform_type, platform_config,
discovery_info=None):
"""Setup a platform for this component.
This method must be run in the event loop.
"""
platform = yield from self.hass.loop.run_in_executor(
None, prepare_setup_platform, self.hass, self.config, self.domain,
platform_type
)
if platform is None:
return
@ -102,9 +140,16 @@ class EntityComponent(object):
entity_platform = self._platforms[key]
try:
platform.setup_platform(self.hass, platform_config,
entity_platform.add_entities,
discovery_info)
if getattr(platform, 'async_setup_platform', None):
yield from platform.async_setup_platform(
self.hass, platform_config,
entity_platform.async_add_entities, discovery_info
)
else:
yield from self.hass.loop.run_in_executor(
None, platform.setup_platform, self.hass, platform_config,
entity_platform.add_entities, discovery_info
)
self.hass.config.components.append(
'{}.{}'.format(self.domain, platform_type))
@ -114,6 +159,16 @@ class EntityComponent(object):
def add_entity(self, entity, platform=None):
"""Add entity to component."""
return run_coroutine_threadsafe(
self.async_add_entity(entity, platform), self.hass.loop
).result()
@asyncio.coroutine
def async_add_entity(self, entity, platform=None):
"""Add entity to component.
This method must be run in the event loop.
"""
if entity is None or entity in self.entities.values():
return False
@ -126,40 +181,60 @@ class EntityComponent(object):
object_id = '{} {}'.format(platform.entity_namespace,
object_id)
entity.entity_id = generate_entity_id(
entity.entity_id = async_generate_entity_id(
self.entity_id_format, object_id,
self.entities.keys())
self.entities[entity.entity_id] = entity
entity.update_ha_state()
yield from entity.async_update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
run_callback_threadsafe(
self.hass.loop, self.async_update_group).result()
@asyncio.coroutine
def async_update_group(self):
"""Set up and/or update component group.
This method must be run in the event loop.
"""
if self.group is None and self.group_name is not None:
group = get_component('group')
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
self.group = yield from group.Group.async_create_group(
self.hass, self.group_name, self.entities.keys(),
user_defined=False
)
elif self.group is not None:
yield from self.group.async_update_tracked_entity_ids(
self.entities.keys())
def reset(self):
"""Remove entities and reset the entity component to initial values."""
with self.lock:
for platform in self._platforms.values():
platform.reset()
run_coroutine_threadsafe(self.async_reset(), self.hass.loop).result()
self._platforms = {
'core': self._platforms['core']
}
self.entities = {}
self.config = None
@asyncio.coroutine
def async_reset(self):
"""Remove entities and reset the entity component to initial values.
if self.group is not None:
self.group.stop()
self.group = None
This method must be run in the event loop.
"""
tasks = [platform.async_reset() for platform
in self._platforms.values()]
yield from asyncio.gather(*tasks, loop=self.hass.loop)
self._platforms = {
'core': self._platforms['core']
}
self.entities = {}
self.config = None
if self.group is not None:
yield from self.group.async_stop()
self.group = None
def prepare_reload(self):
"""Prepare reloading this entity component."""
@ -178,9 +253,20 @@ class EntityComponent(object):
self.reset()
return conf
@asyncio.coroutine
def async_prepare_reload(self):
"""Prepare reloading this entity component.
This method must be run in the event loop.
"""
conf = yield from self.hass.loop.run_in_executor(
None, self.prepare_reload
)
return conf
class EntityPlatform(object):
"""Keep track of entities for a single platform."""
"""Keep track of entities for a single platform and stay in loop."""
# pylint: disable=too-few-public-methods
def __init__(self, component, scan_interval, entity_namespace):
@ -189,41 +275,58 @@ class EntityPlatform(object):
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.platform_entities = []
self._unsub_polling = None
self._async_unsub_polling = None
def add_entities(self, new_entities):
"""Add entities for a single platform."""
with self.component.lock:
for entity in new_entities:
if self.component.add_entity(entity, self):
self.platform_entities.append(entity)
run_coroutine_threadsafe(
self.async_add_entities(new_entities), self.component.hass.loop
).result()
self.component.update_group()
@asyncio.coroutine
def async_add_entities(self, new_entities):
"""Add entities for a single platform async.
if self._unsub_polling is not None or \
not any(entity.should_poll for entity
in self.platform_entities):
return
This method must be run in the event loop.
"""
for entity in new_entities:
ret = yield from self.component.async_add_entity(entity, self)
if ret:
self.platform_entities.append(entity)
self._unsub_polling = track_utc_time_change(
self.component.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
yield from self.component.async_update_group()
def reset(self):
"""Remove all entities and reset data."""
for entity in self.platform_entities:
entity.remove()
if self._unsub_polling is not None:
self._unsub_polling()
self._unsub_polling = None
if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity
in self.platform_entities):
return
self._async_unsub_polling = async_track_utc_time_change(
self.component.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
@asyncio.coroutine
def async_reset(self):
"""Remove all entities and reset data.
This method must be run in the event loop.
"""
tasks = [entity.async_remove() for entity in self.platform_entities]
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
if self._async_unsub_polling is not None:
self._async_unsub_polling()
self._async_unsub_polling = None
@callback
def _update_entity_states(self, now):
"""Update the states of all the polling entities."""
with self.component.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
entities = list(entity for entity in self.platform_entities
if entity.should_poll)
"""Update the states of all the polling entities.
for entity in entities:
entity.update_ha_state(True)
This method must be run in the event loop.
"""
for entity in self.platform_entities:
if entity.should_poll:
self.component.hass.loop.create_task(
entity.async_update_ha_state(True)
)