Make Service.call_from_config async

This commit is contained in:
Paulus Schoutsen 2016-09-30 22:34:45 -07:00
parent 4198c42736
commit 33a51623f8

View file

@ -1,4 +1,5 @@
"""Service calling related helpers."""
import asyncio
import functools
import logging
# pylint: disable=unused-import
@ -11,6 +12,7 @@ from homeassistant.core import HomeAssistant # NOQA
from homeassistant.exceptions import TemplateError
from homeassistant.loader import get_component
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
HASS = None # type: Optional[HomeAssistant]
@ -37,6 +39,15 @@ def service(domain, service_name):
def call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
"""Call a service based on a config hash."""
run_coroutine_threadsafe(
async_call_from_config(hass, config, blocking, variables,
validate_config), hass.loop).result()
@asyncio.coroutine
def async_call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
"""Call a service based on a config hash."""
if validate_config:
try:
config = cv.SERVICE_SCHEMA(config)
@ -49,7 +60,8 @@ def call_from_config(hass, config, blocking=False, variables=None,
else:
try:
config[CONF_SERVICE_TEMPLATE].hass = hass
domain_service = config[CONF_SERVICE_TEMPLATE].render(variables)
domain_service = config[CONF_SERVICE_TEMPLATE].async_render(
variables)
domain_service = cv.service(domain_service)
except TemplateError as ex:
_LOGGER.error('Error rendering service name template: %s', ex)
@ -71,14 +83,15 @@ def call_from_config(hass, config, blocking=False, variables=None,
return {key: _data_template_creator(item)
for key, item in value.items()}
value.hass = hass
return value.render(variables)
return value.async_render(variables)
service_data.update(_data_template_creator(
config[CONF_SERVICE_DATA_TEMPLATE]))
if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
hass.services.call(domain, service_name, service_data, blocking)
yield from hass.services.async_call(
domain, service_name, service_data, blocking)
def extract_entity_ids(hass, service_call):