Address asyncio comments (#3663)
* Template platforms: create_task instead of yield from * Automation: less yielding, more create_tasking * Helpers.script: less yielding, more create_tasking * Deflake logbook test * Deflake automation reload config test * MQTT: Use async_add_job and threaded_listener_factory * Deflake other logbook test * lint * Add test for automation trigger service * MQTT client can be called from within async
This commit is contained in:
parent
f2a12b7ac2
commit
d58548dd1c
10 changed files with 123 additions and 76 deletions
|
@ -154,15 +154,24 @@ def setup(hass, config):
|
|||
def trigger_service_handler(service_call):
|
||||
"""Handle automation triggers."""
|
||||
for entity in component.extract_from_service(service_call):
|
||||
yield from entity.async_trigger(
|
||||
service_call.data.get(ATTR_VARIABLES))
|
||||
hass.loop.create_task(entity.async_trigger(
|
||||
service_call.data.get(ATTR_VARIABLES), True))
|
||||
|
||||
@asyncio.coroutine
|
||||
def service_handler(service_call):
|
||||
"""Handle automation service calls."""
|
||||
def turn_onoff_service_handler(service_call):
|
||||
"""Handle automation turn on/off service calls."""
|
||||
method = 'async_{}'.format(service_call.service)
|
||||
for entity in component.extract_from_service(service_call):
|
||||
yield from getattr(entity, method)()
|
||||
hass.loop.create_task(getattr(entity, method)())
|
||||
|
||||
@asyncio.coroutine
|
||||
def toggle_service_handler(service_call):
|
||||
"""Handle automation toggle service calls."""
|
||||
for entity in component.extract_from_service(service_call):
|
||||
if entity.is_on:
|
||||
hass.loop.create_task(entity.async_turn_off())
|
||||
else:
|
||||
hass.loop.create_task(entity.async_turn_on())
|
||||
|
||||
@asyncio.coroutine
|
||||
def reload_service_handler(service_call):
|
||||
|
@ -171,7 +180,7 @@ def setup(hass, config):
|
|||
None, component.prepare_reload)
|
||||
if conf is None:
|
||||
return
|
||||
yield from _async_process_config(hass, conf, component)
|
||||
hass.loop.create_task(_async_process_config(hass, conf, component))
|
||||
|
||||
hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
|
||||
descriptions.get(SERVICE_TRIGGER),
|
||||
|
@ -181,8 +190,12 @@ def setup(hass, config):
|
|||
descriptions.get(SERVICE_RELOAD),
|
||||
schema=RELOAD_SERVICE_SCHEMA)
|
||||
|
||||
for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE):
|
||||
hass.services.register(DOMAIN, service, service_handler,
|
||||
hass.services.register(DOMAIN, SERVICE_TOGGLE, toggle_service_handler,
|
||||
descriptions.get(SERVICE_TOGGLE),
|
||||
schema=SERVICE_SCHEMA)
|
||||
|
||||
for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF):
|
||||
hass.services.register(DOMAIN, service, turn_onoff_service_handler,
|
||||
descriptions.get(service),
|
||||
schema=SERVICE_SCHEMA)
|
||||
|
||||
|
@ -236,8 +249,11 @@ class AutomationEntity(ToggleEntity):
|
|||
@asyncio.coroutine
|
||||
def async_turn_on(self, **kwargs) -> None:
|
||||
"""Turn the entity on and update the state."""
|
||||
if self._enabled:
|
||||
return
|
||||
|
||||
yield from self.async_enable()
|
||||
yield from self.async_update_ha_state()
|
||||
self.hass.loop.create_task(self.async_update_ha_state())
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_off(self, **kwargs) -> None:
|
||||
|
@ -248,23 +264,18 @@ class AutomationEntity(ToggleEntity):
|
|||
self._async_detach_triggers()
|
||||
self._async_detach_triggers = None
|
||||
self._enabled = False
|
||||
yield from self.async_update_ha_state()
|
||||
self.hass.loop.create_task(self.async_update_ha_state())
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_toggle(self):
|
||||
"""Toggle the state of the entity."""
|
||||
if self._enabled:
|
||||
yield from self.async_turn_off()
|
||||
else:
|
||||
yield from self.async_turn_on()
|
||||
def async_trigger(self, variables, skip_condition=False):
|
||||
"""Trigger automation.
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_trigger(self, variables):
|
||||
"""Trigger automation."""
|
||||
if self._cond_func(variables):
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if skip_condition or self._cond_func(variables):
|
||||
yield from self._async_action(variables)
|
||||
self._last_triggered = utcnow()
|
||||
yield from self.async_update_ha_state()
|
||||
self.hass.loop.create_task(self.async_update_ha_state())
|
||||
|
||||
def remove(self):
|
||||
"""Remove automation from HASS."""
|
||||
|
@ -274,7 +285,10 @@ class AutomationEntity(ToggleEntity):
|
|||
|
||||
@asyncio.coroutine
|
||||
def async_enable(self):
|
||||
"""Enable this automation entity."""
|
||||
"""Enable this automation entity.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if self._enabled:
|
||||
return
|
||||
|
||||
|
@ -285,8 +299,12 @@ class AutomationEntity(ToggleEntity):
|
|||
|
||||
@asyncio.coroutine
|
||||
def _async_process_config(hass, config, component):
|
||||
"""Process config and add automations."""
|
||||
"""Process config and add automations.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
entities = []
|
||||
tasks = []
|
||||
|
||||
for config_key in extract_domain_configs(config, DOMAIN):
|
||||
conf = config[config_key]
|
||||
|
@ -315,9 +333,10 @@ def _async_process_config(hass, config, component):
|
|||
config_block.get(CONF_TRIGGER, []), name)
|
||||
entity = AutomationEntity(name, async_attach_triggers, cond_func,
|
||||
action, hidden)
|
||||
yield from entity.async_enable()
|
||||
tasks.append(hass.loop.create_task(entity.async_enable()))
|
||||
entities.append(entity)
|
||||
|
||||
yield from asyncio.gather(*tasks, loop=hass.loop)
|
||||
yield from hass.loop.run_in_executor(
|
||||
None, component.add_entities, entities)
|
||||
|
||||
|
@ -333,7 +352,7 @@ def _async_get_action(hass, config, name):
|
|||
"""Action to be executed."""
|
||||
_LOGGER.info('Executing %s', name)
|
||||
logbook.async_log_entry(hass, name, 'has been triggered', DOMAIN)
|
||||
yield from script_obj.async_run(variables)
|
||||
hass.loop.create_task(script_obj.async_run(variables))
|
||||
|
||||
return action
|
||||
|
||||
|
@ -359,7 +378,10 @@ def _async_process_if(hass, config, p_config):
|
|||
|
||||
@asyncio.coroutine
|
||||
def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||
"""Setup the triggers."""
|
||||
"""Setup the triggers.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
removes = []
|
||||
|
||||
for conf in trigger_configs:
|
||||
|
|
|
@ -85,7 +85,7 @@ class BinarySensorTemplate(BinarySensorDevice):
|
|||
@asyncio.coroutine
|
||||
def template_bsensor_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
yield from self.async_update_ha_state(True)
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
||||
track_state_change(hass, entity_ids, template_bsensor_state_listener)
|
||||
|
||||
|
|
|
@ -12,16 +12,14 @@ import time
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import JobPriority
|
||||
from homeassistant.bootstrap import prepare_setup_platform
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import template
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers import template, config_validation as cv
|
||||
from homeassistant.helpers.event import threaded_listener_factory
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
CONF_PLATFORM, CONF_SCAN_INTERVAL, CONF_VALUE_TEMPLATE)
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -165,18 +163,6 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
|||
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
|
||||
|
||||
|
||||
def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
||||
"""Subscribe to an MQTT topic."""
|
||||
async_remove = run_callback_threadsafe(
|
||||
hass.loop, async_subscribe, hass, topic, callback, qos).result()
|
||||
|
||||
def remove_mqtt():
|
||||
"""Remove MQTT subscription."""
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
return remove_mqtt
|
||||
|
||||
|
||||
def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
||||
"""Subscribe to an MQTT topic."""
|
||||
@asyncio.coroutine
|
||||
|
@ -185,14 +171,8 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
|||
if not _match_topic(topic, event.data[ATTR_TOPIC]):
|
||||
return
|
||||
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
yield from callback(
|
||||
event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
|
||||
event.data[ATTR_QOS])
|
||||
else:
|
||||
hass.add_job(callback, event.data[ATTR_TOPIC],
|
||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
|
||||
priority=JobPriority.EVENT_CALLBACK)
|
||||
hass.async_add_job(callback, event.data[ATTR_TOPIC],
|
||||
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
|
||||
|
||||
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
|
||||
mqtt_topic_subscriber)
|
||||
|
@ -203,6 +183,10 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
|
|||
return async_remove
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
subscribe = threaded_listener_factory(async_subscribe)
|
||||
|
||||
|
||||
def _setup_server(hass, config):
|
||||
"""Try to start embedded MQTT broker."""
|
||||
conf = config.get(DOMAIN, {})
|
||||
|
|
|
@ -124,7 +124,7 @@ class ScriptEntity(ToggleEntity):
|
|||
def __init__(self, hass, object_id, name, sequence):
|
||||
"""Initialize the script."""
|
||||
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
|
||||
self.script = Script(hass, sequence, name, self.update_ha_state)
|
||||
self.script = Script(hass, sequence, name, self.async_update_ha_state)
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
|
|
|
@ -82,7 +82,7 @@ class SensorTemplate(Entity):
|
|||
@asyncio.coroutine
|
||||
def template_sensor_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
yield from self.async_update_ha_state(True)
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
||||
track_state_change(hass, entity_ids, template_sensor_state_listener)
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ class SwitchTemplate(SwitchDevice):
|
|||
@asyncio.coroutine
|
||||
def template_switch_state_listener(entity, old_state, new_state):
|
||||
"""Called when the target device changes state."""
|
||||
yield from self.async_update_ha_state(True)
|
||||
hass.loop.create_task(self.async_update_ha_state(True))
|
||||
|
||||
track_state_change(hass, entity_ids, template_switch_state_listener)
|
||||
|
||||
|
|
|
@ -9,11 +9,11 @@ from ..const import (
|
|||
from ..util import dt as dt_util
|
||||
from ..util.async import run_callback_threadsafe
|
||||
|
||||
# PyLint does not like the use of _threaded_factory
|
||||
# PyLint does not like the use of threaded_listener_factory
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _threaded_factory(async_factory):
|
||||
def threaded_listener_factory(async_factory):
|
||||
"""Convert an async event helper to a threaded one."""
|
||||
@ft.wraps(async_factory)
|
||||
def factory(*args, **kwargs):
|
||||
|
@ -83,7 +83,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
|
|||
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
|
||||
|
||||
|
||||
track_state_change = _threaded_factory(async_track_state_change)
|
||||
track_state_change = threaded_listener_factory(async_track_state_change)
|
||||
|
||||
|
||||
def async_track_point_in_time(hass, action, point_in_time):
|
||||
|
@ -100,7 +100,7 @@ def async_track_point_in_time(hass, action, point_in_time):
|
|||
utc_point_in_time)
|
||||
|
||||
|
||||
track_point_in_time = _threaded_factory(async_track_point_in_time)
|
||||
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
|
||||
|
||||
|
||||
def async_track_point_in_utc_time(hass, action, point_in_time):
|
||||
|
@ -133,7 +133,8 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
|
|||
return async_unsub
|
||||
|
||||
|
||||
track_point_in_utc_time = _threaded_factory(async_track_point_in_utc_time)
|
||||
track_point_in_utc_time = threaded_listener_factory(
|
||||
async_track_point_in_utc_time)
|
||||
|
||||
|
||||
def async_track_sunrise(hass, action, offset=None):
|
||||
|
@ -169,7 +170,7 @@ def async_track_sunrise(hass, action, offset=None):
|
|||
return remove_listener
|
||||
|
||||
|
||||
track_sunrise = _threaded_factory(async_track_sunrise)
|
||||
track_sunrise = threaded_listener_factory(async_track_sunrise)
|
||||
|
||||
|
||||
def async_track_sunset(hass, action, offset=None):
|
||||
|
@ -205,7 +206,7 @@ def async_track_sunset(hass, action, offset=None):
|
|||
return remove_listener
|
||||
|
||||
|
||||
track_sunset = _threaded_factory(async_track_sunset)
|
||||
track_sunset = threaded_listener_factory(async_track_sunset)
|
||||
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
|
@ -251,7 +252,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
|
|||
pattern_time_change_listener)
|
||||
|
||||
|
||||
track_utc_time_change = _threaded_factory(async_track_utc_time_change)
|
||||
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
|
||||
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
|
@ -262,7 +263,7 @@ def async_track_time_change(hass, action, year=None, month=None, day=None,
|
|||
minute, second, local=True)
|
||||
|
||||
|
||||
track_time_change = _threaded_factory(async_track_time_change)
|
||||
track_time_change = threaded_listener_factory(async_track_time_change)
|
||||
|
||||
|
||||
def _process_state_match(parameter):
|
||||
|
|
|
@ -66,7 +66,7 @@ class Script():
|
|||
def async_run(self, variables: Optional[Sequence]=None) -> None:
|
||||
"""Run script.
|
||||
|
||||
Returns a coroutine.
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if self._cur == -1:
|
||||
self._log('Running script')
|
||||
|
@ -85,7 +85,7 @@ class Script():
|
|||
def script_delay(now):
|
||||
"""Called after delay is done."""
|
||||
self._async_unsub_delay_listener = None
|
||||
yield from self.async_run(variables)
|
||||
self.hass.loop.create_task(self.async_run(variables))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
|
@ -100,7 +100,8 @@ class Script():
|
|||
self.hass, script_delay,
|
||||
date_util.utcnow() + delay)
|
||||
self._cur = cur + 1
|
||||
self._trigger_change_listener()
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
return
|
||||
|
||||
elif CONF_CONDITION in action:
|
||||
|
@ -115,7 +116,8 @@ class Script():
|
|||
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
self._trigger_change_listener()
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop running script."""
|
||||
|
@ -128,11 +130,15 @@ class Script():
|
|||
|
||||
self._cur = -1
|
||||
self._async_remove_listener()
|
||||
self._trigger_change_listener()
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_call_service(self, action, variables):
|
||||
"""Call the service specified in the action."""
|
||||
"""Call the service specified in the action.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self.last_action = action.get(CONF_ALIAS, 'call service')
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
yield from service.async_call_from_config(
|
||||
|
@ -165,10 +171,3 @@ class Script():
|
|||
msg = "Script {}: {}".format(self.name, msg)
|
||||
|
||||
_LOGGER.info(msg)
|
||||
|
||||
def _trigger_change_listener(self):
|
||||
"""Trigger the change listener."""
|
||||
if not self._change_listener:
|
||||
return
|
||||
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
|
|
@ -144,6 +144,35 @@ class TestAutomation(unittest.TestCase):
|
|||
self.hass.block_till_done()
|
||||
self.assertEqual(2, len(self.calls))
|
||||
|
||||
def test_trigger_service_ignoring_condition(self):
|
||||
"""Test triggers."""
|
||||
assert setup_component(self.hass, automation.DOMAIN, {
|
||||
automation.DOMAIN: {
|
||||
'trigger': [
|
||||
{
|
||||
'platform': 'event',
|
||||
'event_type': 'test_event',
|
||||
},
|
||||
],
|
||||
'condition': {
|
||||
'condition': 'state',
|
||||
'entity_id': 'non.existing',
|
||||
'state': 'beer',
|
||||
},
|
||||
'action': {
|
||||
'service': 'test.automation',
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
self.hass.bus.fire('test_event')
|
||||
self.hass.block_till_done()
|
||||
assert len(self.calls) == 0
|
||||
|
||||
self.hass.services.call('automation', 'trigger', blocking=True)
|
||||
self.hass.block_till_done()
|
||||
assert len(self.calls) == 1
|
||||
|
||||
def test_two_conditions_with_and(self):
|
||||
"""Test two and conditions."""
|
||||
entity_id = 'test.entity'
|
||||
|
@ -348,6 +377,8 @@ class TestAutomation(unittest.TestCase):
|
|||
|
||||
automation.reload(self.hass)
|
||||
self.hass.block_till_done()
|
||||
# De-flake ?!
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert self.hass.states.get('automation.hello') is None
|
||||
assert self.hass.states.get('automation.bye') is not None
|
||||
|
|
|
@ -50,6 +50,11 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
logbook.ATTR_ENTITY_ID: 'switch.test_switch'
|
||||
}, True)
|
||||
|
||||
# Logbook entry service call results in firing an event.
|
||||
# Our service call will unblock when the event listeners have been
|
||||
# scheduled. This means that they may not have been processed yet.
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(calls))
|
||||
last_call = calls[-1]
|
||||
|
||||
|
@ -70,6 +75,11 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
|
||||
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
|
||||
|
||||
# Logbook entry service call results in firing an event.
|
||||
# Our service call will unblock when the event listeners have been
|
||||
# scheduled. This means that they may not have been processed yet.
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(0, len(calls))
|
||||
|
||||
def test_humanify_filter_sensor(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue