Async gather wait (#4247)
* Fix config validation for input_*, script * Allow scheduling coroutines * Validate entity ids when entity ids set by platform * Async: gather -> wait * Script/Group: use async_add_job instead of create_task
This commit is contained in:
parent
d4e8b831a0
commit
a343c20404
13 changed files with 90 additions and 60 deletions
|
@ -119,7 +119,7 @@ def async_setup(hass, config):
|
|||
tasks.append(hass.services.async_call(
|
||||
domain, service.service, data, blocking))
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
ha.DOMAIN, SERVICE_TURN_OFF, handle_turn_service)
|
||||
|
|
|
@ -165,7 +165,7 @@ def async_setup(hass, config):
|
|||
for entity in component.async_extract_from_service(service_call):
|
||||
tasks.append(entity.async_trigger(
|
||||
service_call.data.get(ATTR_VARIABLES), True))
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def turn_onoff_service_handler(service_call):
|
||||
|
@ -174,7 +174,7 @@ def async_setup(hass, config):
|
|||
method = 'async_{}'.format(service_call.service)
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
tasks.append(getattr(entity, method)())
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def toggle_service_handler(service_call):
|
||||
|
@ -185,7 +185,7 @@ def async_setup(hass, config):
|
|||
tasks.append(entity.async_turn_off())
|
||||
else:
|
||||
tasks.append(entity.async_turn_on())
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def reload_service_handler(service_call):
|
||||
|
@ -348,8 +348,10 @@ def _async_process_config(hass, config, component):
|
|||
tasks.append(entity.async_enable())
|
||||
entities.append(entity)
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from component.async_add_entities(entities)
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
if entities:
|
||||
yield from component.async_add_entities(entities)
|
||||
|
||||
return len(entities) > 0
|
||||
|
||||
|
|
|
@ -144,7 +144,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
|
|||
tasks.append(device.async_read_sid())
|
||||
devices.append(device)
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
hass.loop.create_task(async_add_devices(devices))
|
||||
|
||||
|
||||
|
|
|
@ -184,7 +184,7 @@ def async_setup(hass, config):
|
|||
tasks = [group.async_set_visible(visible) for group
|
||||
in component.async_extract_from_service(service,
|
||||
expand_group=False)]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SET_VISIBILITY, visibility_service_handler,
|
||||
|
@ -207,13 +207,14 @@ def _async_process_config(hass, config, component):
|
|||
icon = conf.get(CONF_ICON)
|
||||
view = conf.get(CONF_VIEW)
|
||||
|
||||
# This order is important as groups get a number based on creation
|
||||
# order.
|
||||
# Don't create tasks and await them all. The order is important as
|
||||
# groups get a number based on creation order.
|
||||
group = yield from Group.async_create_group(
|
||||
hass, name, entity_ids, icon=icon, view=view, object_id=object_id)
|
||||
groups.append(group)
|
||||
|
||||
yield from component.async_add_entities(groups)
|
||||
if groups:
|
||||
yield from component.async_add_entities(groups)
|
||||
|
||||
|
||||
class Group(Entity):
|
||||
|
@ -394,7 +395,7 @@ class Group(Entity):
|
|||
This method must be run in the event loop.
|
||||
"""
|
||||
self._async_update_group_state(new_state)
|
||||
self.hass.loop.create_task(self.async_update_ha_state())
|
||||
self.hass.async_add_job(self.async_update_ha_state())
|
||||
|
||||
@property
|
||||
def _tracking_states(self):
|
||||
|
|
|
@ -23,17 +23,23 @@ ENTITY_ID_FORMAT = DOMAIN + '.{}'
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_INITIAL = 'initial'
|
||||
DEFAULT_INITIAL = False
|
||||
|
||||
SERVICE_SCHEMA = vol.Schema({
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
|
||||
})
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: {
|
||||
cv.slug: vol.Any({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Optional(CONF_INITIAL, default=False): cv.boolean,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
}, None)}}, extra=vol.ALLOW_EXTRA)
|
||||
DEFAULT_CONFIG = {CONF_INITIAL: DEFAULT_INITIAL}
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema({
|
||||
cv.slug: vol.Any({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.boolean,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
}, None)
|
||||
})
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
def is_on(hass, entity_id):
|
||||
|
@ -65,10 +71,10 @@ def async_setup(hass, config):
|
|||
|
||||
for object_id, cfg in config[DOMAIN].items():
|
||||
if not cfg:
|
||||
cfg = {}
|
||||
cfg = DEFAULT_CONFIG
|
||||
|
||||
name = cfg.get(CONF_NAME)
|
||||
state = cfg.get(CONF_INITIAL, False)
|
||||
state = cfg.get(CONF_INITIAL)
|
||||
icon = cfg.get(CONF_ICON)
|
||||
|
||||
entities.append(InputBoolean(object_id, name, state, icon))
|
||||
|
@ -89,7 +95,7 @@ def async_setup(hass, config):
|
|||
attr = 'async_toggle'
|
||||
|
||||
tasks = [getattr(input_b, attr)() for input_b in target_inputs]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_TURN_OFF, async_handler_service, schema=SERVICE_SCHEMA)
|
||||
|
|
|
@ -55,14 +55,16 @@ def _cv_input_select(cfg):
|
|||
return cfg
|
||||
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: {
|
||||
cv.slug: vol.All({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Required(CONF_OPTIONS): vol.All(cv.ensure_list, vol.Length(min=1),
|
||||
[cv.string]),
|
||||
vol.Optional(CONF_INITIAL): cv.string,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
}, _cv_input_select)}}, required=True, extra=vol.ALLOW_EXTRA)
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema({
|
||||
cv.slug: vol.All({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Required(CONF_OPTIONS):
|
||||
vol.All(cv.ensure_list, vol.Length(min=1), [cv.string]),
|
||||
vol.Optional(CONF_INITIAL): cv.string,
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
}, _cv_input_select)})
|
||||
}, required=True, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
def select_option(hass, entity_id, option):
|
||||
|
@ -111,7 +113,7 @@ def async_setup(hass, config):
|
|||
|
||||
tasks = [input_select.async_select_option(call.data[ATTR_OPTION])
|
||||
for input_select in target_inputs]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SELECT_OPTION, async_select_option_service,
|
||||
|
@ -124,7 +126,7 @@ def async_setup(hass, config):
|
|||
|
||||
tasks = [input_select.async_offset_index(1)
|
||||
for input_select in target_inputs]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SELECT_NEXT, async_select_next_service,
|
||||
|
@ -137,7 +139,7 @@ def async_setup(hass, config):
|
|||
|
||||
tasks = [input_select.async_offset_index(-1)
|
||||
for input_select in target_inputs]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SELECT_PREVIOUS, async_select_previous_service,
|
||||
|
|
|
@ -51,17 +51,20 @@ def _cv_input_slider(cfg):
|
|||
cfg[CONF_INITIAL] = state
|
||||
return cfg
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: {
|
||||
cv.slug: vol.All({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Required(CONF_MIN): vol.Coerce(float),
|
||||
vol.Required(CONF_MAX): vol.Coerce(float),
|
||||
vol.Optional(CONF_INITIAL): vol.Coerce(float),
|
||||
vol.Optional(CONF_STEP, default=1): vol.All(vol.Coerce(float),
|
||||
vol.Range(min=1e-3)),
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string
|
||||
}, _cv_input_slider)}}, required=True, extra=vol.ALLOW_EXTRA)
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema({
|
||||
cv.slug: vol.All({
|
||||
vol.Optional(CONF_NAME): cv.string,
|
||||
vol.Required(CONF_MIN): vol.Coerce(float),
|
||||
vol.Required(CONF_MAX): vol.Coerce(float),
|
||||
vol.Optional(CONF_INITIAL): vol.Coerce(float),
|
||||
vol.Optional(CONF_STEP, default=1): vol.All(vol.Coerce(float),
|
||||
vol.Range(min=1e-3)),
|
||||
vol.Optional(CONF_ICON): cv.icon,
|
||||
vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string
|
||||
}, _cv_input_slider)
|
||||
})
|
||||
}, required=True, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
def select_value(hass, entity_id, value):
|
||||
|
@ -101,7 +104,7 @@ def async_setup(hass, config):
|
|||
|
||||
tasks = [input_slider.async_select_value(call.data[ATTR_VALUE])
|
||||
for input_slider in target_inputs]
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SELECT_VALUE, async_select_value_service,
|
||||
|
|
|
@ -40,7 +40,7 @@ _SCRIPT_ENTRY_SCHEMA = vol.Schema({
|
|||
})
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
vol.Required(DOMAIN): {cv.slug: _SCRIPT_ENTRY_SCHEMA}
|
||||
vol.Required(DOMAIN): vol.Schema({cv.slug: _SCRIPT_ENTRY_SCHEMA})
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
||||
|
|
|
@ -218,4 +218,5 @@ class YrData(object):
|
|||
dev._state = new_state
|
||||
tasks.append(dev.async_update_ha_state())
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=self.hass.loop)
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=self.hass.loop)
|
||||
|
|
|
@ -110,7 +110,7 @@ def async_setup(hass, config):
|
|||
zone.entity_id = ENTITY_ID_HOME
|
||||
tasks.append(zone.async_update_ha_state())
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -207,7 +207,9 @@ class HomeAssistant(object):
|
|||
"""
|
||||
task = None
|
||||
|
||||
if is_callback(target):
|
||||
if asyncio.iscoroutine(target):
|
||||
task = self.loop.create_task(target)
|
||||
elif is_callback(target):
|
||||
self.loop.call_soon(target, *args)
|
||||
elif asyncio.iscoroutinefunction(target):
|
||||
task = self.loop.create_task(target(*args))
|
||||
|
|
|
@ -7,7 +7,7 @@ from homeassistant.bootstrap import (
|
|||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE,
|
||||
DEVICE_DEFAULT_NAME)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import callback, valid_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import get_component
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
|
@ -71,14 +71,15 @@ class EntityComponent(object):
|
|||
for p_type, p_config in config_per_platform(config, self.domain):
|
||||
tasks.append(self._async_setup_platform(p_type, p_config))
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=self.hass.loop)
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=self.hass.loop)
|
||||
|
||||
# Generic discovery listener for loading platform dynamically
|
||||
# Refer to: homeassistant.components.discovery.load_platform()
|
||||
@callback
|
||||
def component_platform_discovered(platform, info):
|
||||
"""Callback to load a platform."""
|
||||
self.hass.loop.create_task(
|
||||
self.hass.async_add_job(
|
||||
self._async_setup_platform(platform, {}, info))
|
||||
|
||||
discovery.async_listen_platform(
|
||||
|
@ -190,6 +191,14 @@ class EntityComponent(object):
|
|||
self.entity_id_format, object_id,
|
||||
self.entities.keys())
|
||||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if entity.entity_id in self.entities:
|
||||
raise HomeAssistantError(
|
||||
'Entity id already exists: {}'.format(entity.entity_id))
|
||||
elif not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(
|
||||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
|
@ -229,7 +238,8 @@ class EntityComponent(object):
|
|||
tasks = [platform.async_reset() for platform
|
||||
in self._platforms.values()]
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=self.hass.loop)
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=self.hass.loop)
|
||||
|
||||
self._platforms = {
|
||||
'core': self._platforms['core']
|
||||
|
@ -293,14 +303,14 @@ class EntityPlatform(object):
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
# handle empty list from component/platform
|
||||
if not new_entities:
|
||||
return
|
||||
|
||||
tasks = [self._async_process_entity(entity, update_before_add)
|
||||
for entity in new_entities]
|
||||
|
||||
# handle empty list from component/platform
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
yield from self.component.async_update_group()
|
||||
|
||||
if self._async_unsub_polling is not None or \
|
||||
|
@ -327,9 +337,12 @@ class EntityPlatform(object):
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if not self.platform_entities:
|
||||
return
|
||||
|
||||
tasks = [entity.async_remove() for entity in self.platform_entities]
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
|
||||
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
|
||||
|
||||
if self._async_unsub_polling is not None:
|
||||
self._async_unsub_polling()
|
||||
|
@ -343,6 +356,6 @@ class EntityPlatform(object):
|
|||
"""
|
||||
for entity in self.platform_entities:
|
||||
if entity.should_poll:
|
||||
self.component.hass.loop.create_task(
|
||||
self.component.hass.async_add_job(
|
||||
entity.async_update_ha_state(True)
|
||||
)
|
||||
|
|
|
@ -85,7 +85,7 @@ class Script():
|
|||
def script_delay(now):
|
||||
"""Called after delay is done."""
|
||||
self._async_unsub_delay_listener = None
|
||||
self.hass.loop.create_task(self.async_run(variables))
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue