Async gather wait (#4247)

* Fix config validation for input_*, script

* Allow scheduling coroutines

* Validate entity ids when entity ids set by platform

* Async: gather -> wait

* Script/Group: use async_add_job instead of create_task
This commit is contained in:
Paulus Schoutsen 2016-11-06 09:26:40 -08:00 committed by GitHub
parent d4e8b831a0
commit a343c20404
13 changed files with 90 additions and 60 deletions

View file

@ -119,7 +119,7 @@ def async_setup(hass, config):
tasks.append(hass.services.async_call( tasks.append(hass.services.async_call(
domain, service.service, data, blocking)) domain, service.service, data, blocking))
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
ha.DOMAIN, SERVICE_TURN_OFF, handle_turn_service) ha.DOMAIN, SERVICE_TURN_OFF, handle_turn_service)

View file

@ -165,7 +165,7 @@ def async_setup(hass, config):
for entity in component.async_extract_from_service(service_call): for entity in component.async_extract_from_service(service_call):
tasks.append(entity.async_trigger( tasks.append(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES), True)) service_call.data.get(ATTR_VARIABLES), True))
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine @asyncio.coroutine
def turn_onoff_service_handler(service_call): def turn_onoff_service_handler(service_call):
@ -174,7 +174,7 @@ def async_setup(hass, config):
method = 'async_{}'.format(service_call.service) method = 'async_{}'.format(service_call.service)
for entity in component.async_extract_from_service(service_call): for entity in component.async_extract_from_service(service_call):
tasks.append(getattr(entity, method)()) tasks.append(getattr(entity, method)())
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine @asyncio.coroutine
def toggle_service_handler(service_call): def toggle_service_handler(service_call):
@ -185,7 +185,7 @@ def async_setup(hass, config):
tasks.append(entity.async_turn_off()) tasks.append(entity.async_turn_off())
else: else:
tasks.append(entity.async_turn_on()) tasks.append(entity.async_turn_on())
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine @asyncio.coroutine
def reload_service_handler(service_call): def reload_service_handler(service_call):
@ -348,8 +348,10 @@ def _async_process_config(hass, config, component):
tasks.append(entity.async_enable()) tasks.append(entity.async_enable())
entities.append(entity) entities.append(entity)
yield from asyncio.gather(*tasks, loop=hass.loop) if tasks:
yield from component.async_add_entities(entities) yield from asyncio.wait(tasks, loop=hass.loop)
if entities:
yield from component.async_add_entities(entities)
return len(entities) > 0 return len(entities) > 0

View file

@ -144,7 +144,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
tasks.append(device.async_read_sid()) tasks.append(device.async_read_sid())
devices.append(device) devices.append(device)
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.loop.create_task(async_add_devices(devices)) hass.loop.create_task(async_add_devices(devices))

View file

@ -184,7 +184,7 @@ def async_setup(hass, config):
tasks = [group.async_set_visible(visible) for group tasks = [group.async_set_visible(visible) for group
in component.async_extract_from_service(service, in component.async_extract_from_service(service,
expand_group=False)] expand_group=False)]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SET_VISIBILITY, visibility_service_handler, DOMAIN, SERVICE_SET_VISIBILITY, visibility_service_handler,
@ -207,13 +207,14 @@ def _async_process_config(hass, config, component):
icon = conf.get(CONF_ICON) icon = conf.get(CONF_ICON)
view = conf.get(CONF_VIEW) view = conf.get(CONF_VIEW)
# This order is important as groups get a number based on creation # Don't create tasks and await them all. The order is important as
# order. # groups get a number based on creation order.
group = yield from Group.async_create_group( group = yield from Group.async_create_group(
hass, name, entity_ids, icon=icon, view=view, object_id=object_id) hass, name, entity_ids, icon=icon, view=view, object_id=object_id)
groups.append(group) groups.append(group)
yield from component.async_add_entities(groups) if groups:
yield from component.async_add_entities(groups)
class Group(Entity): class Group(Entity):
@ -394,7 +395,7 @@ class Group(Entity):
This method must be run in the event loop. This method must be run in the event loop.
""" """
self._async_update_group_state(new_state) self._async_update_group_state(new_state)
self.hass.loop.create_task(self.async_update_ha_state()) self.hass.async_add_job(self.async_update_ha_state())
@property @property
def _tracking_states(self): def _tracking_states(self):

View file

@ -23,17 +23,23 @@ ENTITY_ID_FORMAT = DOMAIN + '.{}'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_INITIAL = 'initial' CONF_INITIAL = 'initial'
DEFAULT_INITIAL = False
SERVICE_SCHEMA = vol.Schema({ SERVICE_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
}) })
CONFIG_SCHEMA = vol.Schema({DOMAIN: { DEFAULT_CONFIG = {CONF_INITIAL: DEFAULT_INITIAL}
cv.slug: vol.Any({
vol.Optional(CONF_NAME): cv.string, CONFIG_SCHEMA = vol.Schema({
vol.Optional(CONF_INITIAL, default=False): cv.boolean, DOMAIN: vol.Schema({
vol.Optional(CONF_ICON): cv.icon, cv.slug: vol.Any({
}, None)}}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_NAME): cv.string,
vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.boolean,
vol.Optional(CONF_ICON): cv.icon,
}, None)
})
}, extra=vol.ALLOW_EXTRA)
def is_on(hass, entity_id): def is_on(hass, entity_id):
@ -65,10 +71,10 @@ def async_setup(hass, config):
for object_id, cfg in config[DOMAIN].items(): for object_id, cfg in config[DOMAIN].items():
if not cfg: if not cfg:
cfg = {} cfg = DEFAULT_CONFIG
name = cfg.get(CONF_NAME) name = cfg.get(CONF_NAME)
state = cfg.get(CONF_INITIAL, False) state = cfg.get(CONF_INITIAL)
icon = cfg.get(CONF_ICON) icon = cfg.get(CONF_ICON)
entities.append(InputBoolean(object_id, name, state, icon)) entities.append(InputBoolean(object_id, name, state, icon))
@ -89,7 +95,7 @@ def async_setup(hass, config):
attr = 'async_toggle' attr = 'async_toggle'
tasks = [getattr(input_b, attr)() for input_b in target_inputs] tasks = [getattr(input_b, attr)() for input_b in target_inputs]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_TURN_OFF, async_handler_service, schema=SERVICE_SCHEMA) DOMAIN, SERVICE_TURN_OFF, async_handler_service, schema=SERVICE_SCHEMA)

View file

@ -55,14 +55,16 @@ def _cv_input_select(cfg):
return cfg return cfg
CONFIG_SCHEMA = vol.Schema({DOMAIN: { CONFIG_SCHEMA = vol.Schema({
cv.slug: vol.All({ DOMAIN: vol.Schema({
vol.Optional(CONF_NAME): cv.string, cv.slug: vol.All({
vol.Required(CONF_OPTIONS): vol.All(cv.ensure_list, vol.Length(min=1), vol.Optional(CONF_NAME): cv.string,
[cv.string]), vol.Required(CONF_OPTIONS):
vol.Optional(CONF_INITIAL): cv.string, vol.All(cv.ensure_list, vol.Length(min=1), [cv.string]),
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_INITIAL): cv.string,
}, _cv_input_select)}}, required=True, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_ICON): cv.icon,
}, _cv_input_select)})
}, required=True, extra=vol.ALLOW_EXTRA)
def select_option(hass, entity_id, option): def select_option(hass, entity_id, option):
@ -111,7 +113,7 @@ def async_setup(hass, config):
tasks = [input_select.async_select_option(call.data[ATTR_OPTION]) tasks = [input_select.async_select_option(call.data[ATTR_OPTION])
for input_select in target_inputs] for input_select in target_inputs]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SELECT_OPTION, async_select_option_service, DOMAIN, SERVICE_SELECT_OPTION, async_select_option_service,
@ -124,7 +126,7 @@ def async_setup(hass, config):
tasks = [input_select.async_offset_index(1) tasks = [input_select.async_offset_index(1)
for input_select in target_inputs] for input_select in target_inputs]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SELECT_NEXT, async_select_next_service, DOMAIN, SERVICE_SELECT_NEXT, async_select_next_service,
@ -137,7 +139,7 @@ def async_setup(hass, config):
tasks = [input_select.async_offset_index(-1) tasks = [input_select.async_offset_index(-1)
for input_select in target_inputs] for input_select in target_inputs]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SELECT_PREVIOUS, async_select_previous_service, DOMAIN, SERVICE_SELECT_PREVIOUS, async_select_previous_service,

View file

@ -51,17 +51,20 @@ def _cv_input_slider(cfg):
cfg[CONF_INITIAL] = state cfg[CONF_INITIAL] = state
return cfg return cfg
CONFIG_SCHEMA = vol.Schema({DOMAIN: { CONFIG_SCHEMA = vol.Schema({
cv.slug: vol.All({ DOMAIN: vol.Schema({
vol.Optional(CONF_NAME): cv.string, cv.slug: vol.All({
vol.Required(CONF_MIN): vol.Coerce(float), vol.Optional(CONF_NAME): cv.string,
vol.Required(CONF_MAX): vol.Coerce(float), vol.Required(CONF_MIN): vol.Coerce(float),
vol.Optional(CONF_INITIAL): vol.Coerce(float), vol.Required(CONF_MAX): vol.Coerce(float),
vol.Optional(CONF_STEP, default=1): vol.All(vol.Coerce(float), vol.Optional(CONF_INITIAL): vol.Coerce(float),
vol.Range(min=1e-3)), vol.Optional(CONF_STEP, default=1): vol.All(vol.Coerce(float),
vol.Optional(CONF_ICON): cv.icon, vol.Range(min=1e-3)),
vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string vol.Optional(CONF_ICON): cv.icon,
}, _cv_input_slider)}}, required=True, extra=vol.ALLOW_EXTRA) vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string
}, _cv_input_slider)
})
}, required=True, extra=vol.ALLOW_EXTRA)
def select_value(hass, entity_id, value): def select_value(hass, entity_id, value):
@ -101,7 +104,7 @@ def async_setup(hass, config):
tasks = [input_slider.async_select_value(call.data[ATTR_VALUE]) tasks = [input_slider.async_select_value(call.data[ATTR_VALUE])
for input_slider in target_inputs] for input_slider in target_inputs]
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SELECT_VALUE, async_select_value_service, DOMAIN, SERVICE_SELECT_VALUE, async_select_value_service,

View file

@ -40,7 +40,7 @@ _SCRIPT_ENTRY_SCHEMA = vol.Schema({
}) })
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema({
vol.Required(DOMAIN): {cv.slug: _SCRIPT_ENTRY_SCHEMA} vol.Required(DOMAIN): vol.Schema({cv.slug: _SCRIPT_ENTRY_SCHEMA})
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict) SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)

View file

@ -218,4 +218,5 @@ class YrData(object):
dev._state = new_state dev._state = new_state
tasks.append(dev.async_update_ha_state()) tasks.append(dev.async_update_ha_state())
yield from asyncio.gather(*tasks, loop=self.hass.loop) if tasks:
yield from asyncio.wait(tasks, loop=self.hass.loop)

View file

@ -110,7 +110,7 @@ def async_setup(hass, config):
zone.entity_id = ENTITY_ID_HOME zone.entity_id = ENTITY_ID_HOME
tasks.append(zone.async_update_ha_state()) tasks.append(zone.async_update_ha_state())
yield from asyncio.gather(*tasks, loop=hass.loop) yield from asyncio.wait(tasks, loop=hass.loop)
return True return True

View file

@ -207,7 +207,9 @@ class HomeAssistant(object):
""" """
task = None task = None
if is_callback(target): if asyncio.iscoroutine(target):
task = self.loop.create_task(target)
elif is_callback(target):
self.loop.call_soon(target, *args) self.loop.call_soon(target, *args)
elif asyncio.iscoroutinefunction(target): elif asyncio.iscoroutinefunction(target):
task = self.loop.create_task(target(*args)) task = self.loop.create_task(target(*args))

View file

@ -7,7 +7,7 @@ from homeassistant.bootstrap import (
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE, ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE,
DEVICE_DEFAULT_NAME) DEVICE_DEFAULT_NAME)
from homeassistant.core import callback from homeassistant.core import callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component from homeassistant.loader import get_component
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
@ -71,14 +71,15 @@ class EntityComponent(object):
for p_type, p_config in config_per_platform(config, self.domain): for p_type, p_config in config_per_platform(config, self.domain):
tasks.append(self._async_setup_platform(p_type, p_config)) tasks.append(self._async_setup_platform(p_type, p_config))
yield from asyncio.gather(*tasks, loop=self.hass.loop) if tasks:
yield from asyncio.wait(tasks, loop=self.hass.loop)
# Generic discovery listener for loading platform dynamically # Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform() # Refer to: homeassistant.components.discovery.load_platform()
@callback @callback
def component_platform_discovered(platform, info): def component_platform_discovered(platform, info):
"""Callback to load a platform.""" """Callback to load a platform."""
self.hass.loop.create_task( self.hass.async_add_job(
self._async_setup_platform(platform, {}, info)) self._async_setup_platform(platform, {}, info))
discovery.async_listen_platform( discovery.async_listen_platform(
@ -190,6 +191,14 @@ class EntityComponent(object):
self.entity_id_format, object_id, self.entity_id_format, object_id,
self.entities.keys()) self.entities.keys())
# Make sure it is valid in case an entity set the value themselves
if entity.entity_id in self.entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
elif not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity self.entities[entity.entity_id] = entity
yield from entity.async_update_ha_state() yield from entity.async_update_ha_state()
@ -229,7 +238,8 @@ class EntityComponent(object):
tasks = [platform.async_reset() for platform tasks = [platform.async_reset() for platform
in self._platforms.values()] in self._platforms.values()]
yield from asyncio.gather(*tasks, loop=self.hass.loop) if tasks:
yield from asyncio.wait(tasks, loop=self.hass.loop)
self._platforms = { self._platforms = {
'core': self._platforms['core'] 'core': self._platforms['core']
@ -293,14 +303,14 @@ class EntityPlatform(object):
This method must be run in the event loop. This method must be run in the event loop.
""" """
# handle empty list from component/platform
if not new_entities:
return
tasks = [self._async_process_entity(entity, update_before_add) tasks = [self._async_process_entity(entity, update_before_add)
for entity in new_entities] for entity in new_entities]
# handle empty list from component/platform yield from asyncio.wait(tasks, loop=self.component.hass.loop)
if not tasks:
return
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
yield from self.component.async_update_group() yield from self.component.async_update_group()
if self._async_unsub_polling is not None or \ if self._async_unsub_polling is not None or \
@ -327,9 +337,12 @@ class EntityPlatform(object):
This method must be run in the event loop. This method must be run in the event loop.
""" """
if not self.platform_entities:
return
tasks = [entity.async_remove() for entity in self.platform_entities] tasks = [entity.async_remove() for entity in self.platform_entities]
yield from asyncio.gather(*tasks, loop=self.component.hass.loop) yield from asyncio.wait(tasks, loop=self.component.hass.loop)
if self._async_unsub_polling is not None: if self._async_unsub_polling is not None:
self._async_unsub_polling() self._async_unsub_polling()
@ -343,6 +356,6 @@ class EntityPlatform(object):
""" """
for entity in self.platform_entities: for entity in self.platform_entities:
if entity.should_poll: if entity.should_poll:
self.component.hass.loop.create_task( self.component.hass.async_add_job(
entity.async_update_ha_state(True) entity.async_update_ha_state(True)
) )

View file

@ -85,7 +85,7 @@ class Script():
def script_delay(now): def script_delay(now):
"""Called after delay is done.""" """Called after delay is done."""
self._async_unsub_delay_listener = None self._async_unsub_delay_listener = None
self.hass.loop.create_task(self.async_run(variables)) self.hass.async_add_job(self.async_run(variables))
delay = action[CONF_DELAY] delay = action[CONF_DELAY]