From d58548dd1c830f0aa5b6d68c8b9e0c68be5e014e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Oct 2016 22:39:27 -0700 Subject: [PATCH] Address asyncio comments (#3663) * Template platforms: create_task instead of yield from * Automation: less yielding, more create_tasking * Helpers.script: less yielding, more create_tasking * Deflake logbook test * Deflake automation reload config test * MQTT: Use async_add_job and threaded_listener_factory * Deflake other logbook test * lint * Add test for automation trigger service * MQTT client can be called from within async --- .../components/automation/__init__.py | 74 ++++++++++++------- .../components/binary_sensor/template.py | 2 +- homeassistant/components/mqtt/__init__.py | 32 ++------ homeassistant/components/script.py | 2 +- homeassistant/components/sensor/template.py | 2 +- homeassistant/components/switch/template.py | 2 +- homeassistant/helpers/event.py | 19 ++--- homeassistant/helpers/script.py | 25 +++---- tests/components/automation/test_init.py | 31 ++++++++ tests/components/test_logbook.py | 10 +++ 10 files changed, 123 insertions(+), 76 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index a677fe1da4e..579d4b40003 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -154,15 +154,24 @@ def setup(hass, config): def trigger_service_handler(service_call): """Handle automation triggers.""" for entity in component.extract_from_service(service_call): - yield from entity.async_trigger( - service_call.data.get(ATTR_VARIABLES)) + hass.loop.create_task(entity.async_trigger( + service_call.data.get(ATTR_VARIABLES), True)) @asyncio.coroutine - def service_handler(service_call): - """Handle automation service calls.""" + def turn_onoff_service_handler(service_call): + """Handle automation turn on/off service calls.""" method = 'async_{}'.format(service_call.service) for entity in component.extract_from_service(service_call): - yield from getattr(entity, method)() + hass.loop.create_task(getattr(entity, method)()) + + @asyncio.coroutine + def toggle_service_handler(service_call): + """Handle automation toggle service calls.""" + for entity in component.extract_from_service(service_call): + if entity.is_on: + hass.loop.create_task(entity.async_turn_off()) + else: + hass.loop.create_task(entity.async_turn_on()) @asyncio.coroutine def reload_service_handler(service_call): @@ -171,7 +180,7 @@ def setup(hass, config): None, component.prepare_reload) if conf is None: return - yield from _async_process_config(hass, conf, component) + hass.loop.create_task(_async_process_config(hass, conf, component)) hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler, descriptions.get(SERVICE_TRIGGER), @@ -181,8 +190,12 @@ def setup(hass, config): descriptions.get(SERVICE_RELOAD), schema=RELOAD_SERVICE_SCHEMA) - for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE): - hass.services.register(DOMAIN, service, service_handler, + hass.services.register(DOMAIN, SERVICE_TOGGLE, toggle_service_handler, + descriptions.get(SERVICE_TOGGLE), + schema=SERVICE_SCHEMA) + + for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF): + hass.services.register(DOMAIN, service, turn_onoff_service_handler, descriptions.get(service), schema=SERVICE_SCHEMA) @@ -236,8 +249,11 @@ class AutomationEntity(ToggleEntity): @asyncio.coroutine def async_turn_on(self, **kwargs) -> None: """Turn the entity on and update the state.""" + if self._enabled: + return + yield from self.async_enable() - yield from self.async_update_ha_state() + self.hass.loop.create_task(self.async_update_ha_state()) @asyncio.coroutine def async_turn_off(self, **kwargs) -> None: @@ -248,23 +264,18 @@ class AutomationEntity(ToggleEntity): self._async_detach_triggers() self._async_detach_triggers = None self._enabled = False - yield from self.async_update_ha_state() + self.hass.loop.create_task(self.async_update_ha_state()) @asyncio.coroutine - def async_toggle(self): - """Toggle the state of the entity.""" - if self._enabled: - yield from self.async_turn_off() - else: - yield from self.async_turn_on() + def async_trigger(self, variables, skip_condition=False): + """Trigger automation. - @asyncio.coroutine - def async_trigger(self, variables): - """Trigger automation.""" - if self._cond_func(variables): + This method is a coroutine. + """ + if skip_condition or self._cond_func(variables): yield from self._async_action(variables) self._last_triggered = utcnow() - yield from self.async_update_ha_state() + self.hass.loop.create_task(self.async_update_ha_state()) def remove(self): """Remove automation from HASS.""" @@ -274,7 +285,10 @@ class AutomationEntity(ToggleEntity): @asyncio.coroutine def async_enable(self): - """Enable this automation entity.""" + """Enable this automation entity. + + This method is a coroutine. + """ if self._enabled: return @@ -285,8 +299,12 @@ class AutomationEntity(ToggleEntity): @asyncio.coroutine def _async_process_config(hass, config, component): - """Process config and add automations.""" + """Process config and add automations. + + This method is a coroutine. + """ entities = [] + tasks = [] for config_key in extract_domain_configs(config, DOMAIN): conf = config[config_key] @@ -315,9 +333,10 @@ def _async_process_config(hass, config, component): config_block.get(CONF_TRIGGER, []), name) entity = AutomationEntity(name, async_attach_triggers, cond_func, action, hidden) - yield from entity.async_enable() + tasks.append(hass.loop.create_task(entity.async_enable())) entities.append(entity) + yield from asyncio.gather(*tasks, loop=hass.loop) yield from hass.loop.run_in_executor( None, component.add_entities, entities) @@ -333,7 +352,7 @@ def _async_get_action(hass, config, name): """Action to be executed.""" _LOGGER.info('Executing %s', name) logbook.async_log_entry(hass, name, 'has been triggered', DOMAIN) - yield from script_obj.async_run(variables) + hass.loop.create_task(script_obj.async_run(variables)) return action @@ -359,7 +378,10 @@ def _async_process_if(hass, config, p_config): @asyncio.coroutine def _async_process_trigger(hass, config, trigger_configs, name, action): - """Setup the triggers.""" + """Setup the triggers. + + This method is a coroutine. + """ removes = [] for conf in trigger_configs: diff --git a/homeassistant/components/binary_sensor/template.py b/homeassistant/components/binary_sensor/template.py index 85c9f0e8950..339a5cb9ba1 100644 --- a/homeassistant/components/binary_sensor/template.py +++ b/homeassistant/components/binary_sensor/template.py @@ -85,7 +85,7 @@ class BinarySensorTemplate(BinarySensorDevice): @asyncio.coroutine def template_bsensor_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" - yield from self.async_update_ha_state(True) + hass.loop.create_task(self.async_update_ha_state(True)) track_state_change(hass, entity_ids, template_bsensor_state_listener) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 86896e8309e..7995d9bf39a 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -12,16 +12,14 @@ import time import voluptuous as vol -from homeassistant.core import JobPriority from homeassistant.bootstrap import prepare_setup_platform from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import template -import homeassistant.helpers.config_validation as cv +from homeassistant.helpers import template, config_validation as cv +from homeassistant.helpers.event import threaded_listener_factory from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_PLATFORM, CONF_SCAN_INTERVAL, CONF_VALUE_TEMPLATE) -from homeassistant.util.async import run_callback_threadsafe _LOGGER = logging.getLogger(__name__) @@ -165,18 +163,6 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): hass.services.call(DOMAIN, SERVICE_PUBLISH, data) -def subscribe(hass, topic, callback, qos=DEFAULT_QOS): - """Subscribe to an MQTT topic.""" - async_remove = run_callback_threadsafe( - hass.loop, async_subscribe, hass, topic, callback, qos).result() - - def remove_mqtt(): - """Remove MQTT subscription.""" - run_callback_threadsafe(hass.loop, async_remove).result() - - return remove_mqtt - - def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS): """Subscribe to an MQTT topic.""" @asyncio.coroutine @@ -185,14 +171,8 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS): if not _match_topic(topic, event.data[ATTR_TOPIC]): return - if asyncio.iscoroutinefunction(callback): - yield from callback( - event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], - event.data[ATTR_QOS]) - else: - hass.add_job(callback, event.data[ATTR_TOPIC], - event.data[ATTR_PAYLOAD], event.data[ATTR_QOS], - priority=JobPriority.EVENT_CALLBACK) + hass.async_add_job(callback, event.data[ATTR_TOPIC], + event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber) @@ -203,6 +183,10 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS): return async_remove +# pylint: disable=invalid-name +subscribe = threaded_listener_factory(async_subscribe) + + def _setup_server(hass, config): """Try to start embedded MQTT broker.""" conf = config.get(DOMAIN, {}) diff --git a/homeassistant/components/script.py b/homeassistant/components/script.py index b235c4d4eb7..961c37f896a 100644 --- a/homeassistant/components/script.py +++ b/homeassistant/components/script.py @@ -124,7 +124,7 @@ class ScriptEntity(ToggleEntity): def __init__(self, hass, object_id, name, sequence): """Initialize the script.""" self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self.script = Script(hass, sequence, name, self.update_ha_state) + self.script = Script(hass, sequence, name, self.async_update_ha_state) @property def should_poll(self): diff --git a/homeassistant/components/sensor/template.py b/homeassistant/components/sensor/template.py index 4b6f322b5aa..ed905f44ebd 100644 --- a/homeassistant/components/sensor/template.py +++ b/homeassistant/components/sensor/template.py @@ -82,7 +82,7 @@ class SensorTemplate(Entity): @asyncio.coroutine def template_sensor_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" - yield from self.async_update_ha_state(True) + hass.loop.create_task(self.async_update_ha_state(True)) track_state_change(hass, entity_ids, template_sensor_state_listener) diff --git a/homeassistant/components/switch/template.py b/homeassistant/components/switch/template.py index 7c6f4f5886d..bcd74454ce5 100644 --- a/homeassistant/components/switch/template.py +++ b/homeassistant/components/switch/template.py @@ -91,7 +91,7 @@ class SwitchTemplate(SwitchDevice): @asyncio.coroutine def template_switch_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" - yield from self.async_update_ha_state(True) + hass.loop.create_task(self.async_update_ha_state(True)) track_state_change(hass, entity_ids, template_switch_state_listener) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index bf781e7e746..69f620adb82 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -9,11 +9,11 @@ from ..const import ( from ..util import dt as dt_util from ..util.async import run_callback_threadsafe -# PyLint does not like the use of _threaded_factory +# PyLint does not like the use of threaded_listener_factory # pylint: disable=invalid-name -def _threaded_factory(async_factory): +def threaded_listener_factory(async_factory): """Convert an async event helper to a threaded one.""" @ft.wraps(async_factory) def factory(*args, **kwargs): @@ -83,7 +83,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) -track_state_change = _threaded_factory(async_track_state_change) +track_state_change = threaded_listener_factory(async_track_state_change) def async_track_point_in_time(hass, action, point_in_time): @@ -100,7 +100,7 @@ def async_track_point_in_time(hass, action, point_in_time): utc_point_in_time) -track_point_in_time = _threaded_factory(async_track_point_in_time) +track_point_in_time = threaded_listener_factory(async_track_point_in_time) def async_track_point_in_utc_time(hass, action, point_in_time): @@ -133,7 +133,8 @@ def async_track_point_in_utc_time(hass, action, point_in_time): return async_unsub -track_point_in_utc_time = _threaded_factory(async_track_point_in_utc_time) +track_point_in_utc_time = threaded_listener_factory( + async_track_point_in_utc_time) def async_track_sunrise(hass, action, offset=None): @@ -169,7 +170,7 @@ def async_track_sunrise(hass, action, offset=None): return remove_listener -track_sunrise = _threaded_factory(async_track_sunrise) +track_sunrise = threaded_listener_factory(async_track_sunrise) def async_track_sunset(hass, action, offset=None): @@ -205,7 +206,7 @@ def async_track_sunset(hass, action, offset=None): return remove_listener -track_sunset = _threaded_factory(async_track_sunset) +track_sunset = threaded_listener_factory(async_track_sunset) # pylint: disable=too-many-arguments @@ -251,7 +252,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None, pattern_time_change_listener) -track_utc_time_change = _threaded_factory(async_track_utc_time_change) +track_utc_time_change = threaded_listener_factory(async_track_utc_time_change) # pylint: disable=too-many-arguments @@ -262,7 +263,7 @@ def async_track_time_change(hass, action, year=None, month=None, day=None, minute, second, local=True) -track_time_change = _threaded_factory(async_track_time_change) +track_time_change = threaded_listener_factory(async_track_time_change) def _process_state_match(parameter): diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1bfe7d550ad..cb4a1fbbe04 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -66,7 +66,7 @@ class Script(): def async_run(self, variables: Optional[Sequence]=None) -> None: """Run script. - Returns a coroutine. + This method is a coroutine. """ if self._cur == -1: self._log('Running script') @@ -85,7 +85,7 @@ class Script(): def script_delay(now): """Called after delay is done.""" self._async_unsub_delay_listener = None - yield from self.async_run(variables) + self.hass.loop.create_task(self.async_run(variables)) delay = action[CONF_DELAY] @@ -100,7 +100,8 @@ class Script(): self.hass, script_delay, date_util.utcnow() + delay) self._cur = cur + 1 - self._trigger_change_listener() + if self._change_listener: + self.hass.async_add_job(self._change_listener) return elif CONF_CONDITION in action: @@ -115,7 +116,8 @@ class Script(): self._cur = -1 self.last_action = None - self._trigger_change_listener() + if self._change_listener: + self.hass.async_add_job(self._change_listener) def stop(self) -> None: """Stop running script.""" @@ -128,11 +130,15 @@ class Script(): self._cur = -1 self._async_remove_listener() - self._trigger_change_listener() + if self._change_listener: + self.hass.async_add_job(self._change_listener) @asyncio.coroutine def _async_call_service(self, action, variables): - """Call the service specified in the action.""" + """Call the service specified in the action. + + This method is a coroutine. + """ self.last_action = action.get(CONF_ALIAS, 'call service') self._log("Executing step %s" % self.last_action) yield from service.async_call_from_config( @@ -165,10 +171,3 @@ class Script(): msg = "Script {}: {}".format(self.name, msg) _LOGGER.info(msg) - - def _trigger_change_listener(self): - """Trigger the change listener.""" - if not self._change_listener: - return - - self.hass.async_add_job(self._change_listener) diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index b667436d9a6..8cc7825bb1f 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -144,6 +144,35 @@ class TestAutomation(unittest.TestCase): self.hass.block_till_done() self.assertEqual(2, len(self.calls)) + def test_trigger_service_ignoring_condition(self): + """Test triggers.""" + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': [ + { + 'platform': 'event', + 'event_type': 'test_event', + }, + ], + 'condition': { + 'condition': 'state', + 'entity_id': 'non.existing', + 'state': 'beer', + }, + 'action': { + 'service': 'test.automation', + } + } + }) + + self.hass.bus.fire('test_event') + self.hass.block_till_done() + assert len(self.calls) == 0 + + self.hass.services.call('automation', 'trigger', blocking=True) + self.hass.block_till_done() + assert len(self.calls) == 1 + def test_two_conditions_with_and(self): """Test two and conditions.""" entity_id = 'test.entity' @@ -348,6 +377,8 @@ class TestAutomation(unittest.TestCase): automation.reload(self.hass) self.hass.block_till_done() + # De-flake ?! + self.hass.block_till_done() assert self.hass.states.get('automation.hello') is None assert self.hass.states.get('automation.bye') is not None diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index 539622d9296..9e8ab09a5a6 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -50,6 +50,11 @@ class TestComponentLogbook(unittest.TestCase): logbook.ATTR_ENTITY_ID: 'switch.test_switch' }, True) + # Logbook entry service call results in firing an event. + # Our service call will unblock when the event listeners have been + # scheduled. This means that they may not have been processed yet. + self.hass.block_till_done() + self.assertEqual(1, len(calls)) last_call = calls[-1] @@ -70,6 +75,11 @@ class TestComponentLogbook(unittest.TestCase): self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) self.hass.services.call(logbook.DOMAIN, 'log', {}, True) + # Logbook entry service call results in firing an event. + # Our service call will unblock when the event listeners have been + # scheduled. This means that they may not have been processed yet. + self.hass.block_till_done() + self.assertEqual(0, len(calls)) def test_humanify_filter_sensor(self):