Clean up some async stuff (#3915)

* Clean up some async stuff

* Adjust comments

* Pass hass instance to eventbus
This commit is contained in:
Paulus Schoutsen 2016-10-17 19:38:41 -07:00 committed by GitHub
parent daea93d9f9
commit 4c8d1d9d2f
7 changed files with 139 additions and 111 deletions

View file

@ -79,8 +79,7 @@ class NuimoThread(threading.Thread):
self._name = name
self._hass_is_running = True
self._nuimo = None
self._listener = hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP,
self.stop)
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.stop)
def run(self):
"""Setup connection or be idle."""
@ -99,8 +98,6 @@ class NuimoThread(threading.Thread):
"""Terminate Thread by unsetting flag."""
_LOGGER.debug('Stopping thread for Nuimo %s', self._mac)
self._hass_is_running = False
self._hass.bus.remove_listener(EVENT_HOMEASSISTANT_STOP,
self._listener)
def _attach(self):
"""Create a nuimo object from mac address or discovery."""

View file

@ -8,7 +8,6 @@ of entities and react to changes.
import asyncio
from concurrent.futures import ThreadPoolExecutor
import enum
import functools as ft
import logging
import os
import re
@ -137,8 +136,8 @@ class HomeAssistant(object):
self.executor = ThreadPoolExecutor(max_workers=5)
self.loop.set_default_executor(self.executor)
self.loop.set_exception_handler(self._async_exception_handler)
self.pool = pool = create_worker_pool()
self.bus = EventBus(pool, self.loop)
self.pool = create_worker_pool()
self.bus = EventBus(self)
self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
self.states = StateMachine(self.bus, self.loop)
self.config = Config() # type: Config
@ -218,8 +217,8 @@ class HomeAssistant(object):
"""
# pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident()
async_create_timer(self)
async_monitor_worker_pool(self)
_async_create_timer(self)
_async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from self.loop.run_in_executor(None, self.pool.block_till_done)
self.state = CoreState.running
@ -235,9 +234,12 @@ class HomeAssistant(object):
"""
self.pool.add_job(priority, (target,) + args)
@callback
def async_add_job(self, target: Callable[..., None], *args: Any):
"""Add a job from within the eventloop.
This method must be run in the event loop.
target: target to call.
args: parameters for method to call.
"""
@ -248,9 +250,12 @@ class HomeAssistant(object):
else:
self.add_job(target, *args)
@callback
def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop.
This method must be run in the event loop.
target: target to call.
args: parameters for method to call.
"""
@ -369,7 +374,10 @@ class Event(object):
self.time_fired = time_fired or dt_util.utcnow()
def as_dict(self):
"""Create a dict representation of this Event."""
"""Create a dict representation of this Event.
Async friendly.
"""
return {
'event_type': self.event_type,
'data': dict(self.data),
@ -400,13 +408,12 @@ class Event(object):
class EventBus(object):
"""Allows firing of and listening for events."""
def __init__(self, pool: util.ThreadPool,
loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners = {}
self._pool = pool
self._loop = loop
self._hass = hass
@callback
def async_listeners(self):
"""Dict with events and the number of listeners.
@ -419,23 +426,25 @@ class EventBus(object):
def listeners(self):
"""Dict with events and the number of listeners."""
return run_callback_threadsafe(
self._loop, self.async_listeners
self._hass.loop, self.async_listeners
).result()
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
"""Fire an event."""
if not self._pool.running:
raise HomeAssistantError('Home Assistant has shut down.')
self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
self._hass.loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
@callback
def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False):
"""Fire an event.
This method must be run in the event loop.
"""
if event_type != EVENT_HOMEASSISTANT_STOP and \
self._hass.state == CoreState.stopping:
raise HomeAssistantError('Home Assistant is shutting down.')
# Copy the list of the current listeners because some listeners
# remove themselves as a listener while being executed which
# causes the iterator to be confused.
@ -450,20 +459,8 @@ class EventBus(object):
if not listeners:
return
job_priority = JobPriority.from_event_type(event_type)
sync_jobs = []
for func in listeners:
if asyncio.iscoroutinefunction(func):
self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else:
sync_jobs.append((job_priority, (func, event)))
# Send all the sync jobs at once
if sync_jobs:
self._pool.add_many_jobs(sync_jobs)
self._hass.async_add_job(func, event)
def listen(self, event_type, listener):
"""Listen for all events or events of a specific type.
@ -471,16 +468,17 @@ class EventBus(object):
To listen to all events specify the constant ``MATCH_ALL``
as event_type.
"""
future = run_callback_threadsafe(
self._loop, self.async_listen, event_type, listener)
future.result()
async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen, event_type, listener).result()
def remove_listener():
"""Remove the listener."""
self._remove_listener(event_type, listener)
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
return remove_listener
@callback
def async_listen(self, event_type, listener):
"""Listen for all events or events of a specific type.
@ -496,7 +494,7 @@ class EventBus(object):
def remove_listener():
"""Remove the listener."""
self.async_remove_listener(event_type, listener)
self._async_remove_listener(event_type, listener)
return remove_listener
@ -508,26 +506,18 @@ class EventBus(object):
Returns function to unsubscribe the listener.
"""
@ft.wraps(listener)
def onetime_listener(event):
"""Remove listener from eventbus and then fire listener."""
if hasattr(onetime_listener, 'run'):
return
# Set variable so that we will never run twice.
# Because the event bus might have to wait till a thread comes
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed.
# This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True)
async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen_once, event_type, listener,
).result()
remove_listener()
listener(event)
remove_listener = self.listen(event_type, onetime_listener)
def remove_listener():
"""Remove the listener."""
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
return remove_listener
@callback
def async_listen_once(self, event_type, listener):
"""Listen once for event of a specific type.
@ -538,8 +528,7 @@ class EventBus(object):
This method must be run in the event loop.
"""
@ft.wraps(listener)
@asyncio.coroutine
@callback
def onetime_listener(event):
"""Remove listener from eventbus and then fire listener."""
if hasattr(onetime_listener, 'run'):
@ -550,34 +539,14 @@ class EventBus(object):
# multiple times as well.
# This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True)
self._async_remove_listener(event_type, onetime_listener)
self.async_remove_listener(event_type, onetime_listener)
self._hass.async_run_job(listener, event)
if asyncio.iscoroutinefunction(listener):
yield from listener(event)
else:
job_priority = JobPriority.from_event_type(event.event_type)
self._pool.add_job(job_priority, (listener, event))
return self.async_listen(event_type, onetime_listener)
self.async_listen(event_type, onetime_listener)
return onetime_listener
def remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type. (DEPRECATED 0.28)."""
_LOGGER.warning('bus.remove_listener has been deprecated. Please use '
'the function returned from calling listen.')
self._remove_listener(event_type, listener)
def _remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type."""
future = run_callback_threadsafe(
self._loop,
self.async_remove_listener, event_type, listener
)
future.result()
def async_remove_listener(self, event_type, listener):
@callback
def _async_remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type.
This method must be run in the event loop.
@ -644,6 +613,8 @@ class State(object):
def as_dict(self):
"""Return a dict representation of the State.
Async friendly.
To be used for JSON serialization.
Ensures: state == State.from_dict(state.as_dict())
"""
@ -657,6 +628,8 @@ class State(object):
def from_dict(cls, json_dict):
"""Initialize a state from a dict.
Async friendly.
Ensures: state == State.from_json_dict(state.to_json_dict())
"""
if not (json_dict and 'entity_id' in json_dict and
@ -709,8 +682,12 @@ class StateMachine(object):
)
return future.result()
@callback
def async_entity_ids(self, domain_filter=None):
"""List of entity ids that are being tracked."""
"""List of entity ids that are being tracked.
This method must be run in the event loop.
"""
if domain_filter is None:
return list(self._states.keys())
@ -723,6 +700,7 @@ class StateMachine(object):
"""Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result()
@callback
def async_all(self):
"""Create a list of all states.
@ -763,6 +741,7 @@ class StateMachine(object):
return run_callback_threadsafe(
self._loop, self.async_remove, entity_id).result()
@callback
def async_remove(self, entity_id):
"""Remove the state of an entity.
@ -800,6 +779,7 @@ class StateMachine(object):
self.async_set, entity_id, new_state, attributes, force_update,
).result()
@callback
def async_set(self, entity_id, new_state, attributes=None,
force_update=False):
"""Set the state of an entity, add entity if it does not exist.
@ -908,14 +888,21 @@ class ServiceRegistry(object):
self._loop, self.async_services,
).result()
@callback
def async_services(self):
"""Dict with per domain a list of available services."""
"""Dict with per domain a list of available services.
This method must be run in the event loop.
"""
return {domain: {key: value.as_dict() for key, value
in self._services[domain].items()}
for domain in self._services}
def has_service(self, domain, service):
"""Test if specified service exists."""
"""Test if specified service exists.
Async friendly.
"""
return service.lower() in self._services.get(domain.lower(), [])
# pylint: disable=too-many-arguments
@ -935,6 +922,7 @@ class ServiceRegistry(object):
schema
).result()
@callback
def async_register(self, domain, service, service_func, description=None,
schema=None):
"""
@ -985,7 +973,7 @@ class ServiceRegistry(object):
self._loop
).result()
@callback
@asyncio.coroutine
def async_call(self, domain, service, service_data=None, blocking=False):
"""
Call a service.
@ -1121,18 +1109,27 @@ class Config(object):
self.config_dir = None
def distance(self: object, lat: float, lon: float) -> float:
"""Calculate distance from Home Assistant."""
"""Calculate distance from Home Assistant.
Async friendly.
"""
return self.units.length(
location.distance(self.latitude, self.longitude, lat, lon), 'm')
def path(self, *path):
"""Generate path to the file within the config dir."""
"""Generate path to the file within the config dir.
Async friendly.
"""
if self.config_dir is None:
raise HomeAssistantError("config_dir is not set")
return os.path.join(self.config_dir, *path)
def as_dict(self):
"""Create a dict representation of this dict."""
"""Create a dict representation of this dict.
Async friendly.
"""
time_zone = self.time_zone or dt_util.UTC
return {
@ -1147,7 +1144,7 @@ class Config(object):
}
def async_create_timer(hass, interval=TIMER_INTERVAL):
def _async_create_timer(hass, interval=TIMER_INTERVAL):
"""Create a timer that will start on HOMEASSISTANT_START."""
stop_event = asyncio.Event(loop=hass.loop)
@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None):
return util.ThreadPool(job_handler, worker_count)
def async_monitor_worker_pool(hass):
def _async_monitor_worker_pool(hass):
"""Create a monitor for the thread pool to check if pool is misbehaving."""
busy_threshold = hass.pool.worker_count * 3

View file

@ -124,9 +124,9 @@ class HomeAssistant(ha.HomeAssistant):
self.remote_api = remote_api
self.loop = loop or asyncio.get_event_loop()
self.pool = pool = ha.create_worker_pool()
self.pool = ha.create_worker_pool()
self.bus = EventBus(remote_api, pool, self.loop)
self.bus = EventBus(remote_api, self)
self.services = ha.ServiceRegistry(self.bus, self.add_job, self.loop)
self.states = StateMachine(self.bus, self.loop, self.remote_api)
self.config = ha.Config()
@ -143,7 +143,7 @@ class HomeAssistant(ha.HomeAssistant):
'Unable to setup local API to receive events')
self.state = ha.CoreState.starting
ha.async_create_timer(self)
ha._async_create_timer(self) # pylint: disable=protected-access
self.bus.fire(ha.EVENT_HOMEASSISTANT_START,
origin=ha.EventOrigin.remote)
@ -180,9 +180,9 @@ class EventBus(ha.EventBus):
"""EventBus implementation that forwards fire_event to remote API."""
# pylint: disable=too-few-public-methods
def __init__(self, api, pool, loop):
def __init__(self, api, hass):
"""Initalize the eventbus."""
super().__init__(pool, loop)
super().__init__(hass)
self._api = api
def fire(self, event_type, event_data=None, origin=ha.EventOrigin.local):

View file

@ -76,8 +76,8 @@ def get_test_home_assistant(num_threads=None):
"""Fake stop."""
yield None
@patch.object(ha, 'async_create_timer')
@patch.object(ha, 'async_monitor_worker_pool')
@patch.object(ha, '_async_create_timer')
@patch.object(ha, '_async_monitor_worker_pool')
@patch.object(hass.loop, 'add_signal_handler')
@patch.object(hass.loop, 'run_forever')
@patch.object(hass.loop, 'close')

View file

@ -145,14 +145,14 @@ class TestAPI(unittest.TestCase):
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set"}),
headers=HA_HEADERS)
hass.bus._pool.block_till_done()
hass.block_till_done()
self.assertEqual(0, len(events))
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "not_to_be_set",
"force_update": True}),
headers=HA_HEADERS)
hass.bus._pool.block_till_done()
hass.block_till_done()
self.assertEqual(1, len(events))
# pylint: disable=invalid-name

View file

@ -179,19 +179,16 @@ class TestEventBus(unittest.TestCase):
def listener(_): pass
self.bus.listen('test', listener)
unsub = self.bus.listen('test', listener)
self.assertEqual(old_count + 1, len(self.bus.listeners))
# Try deleting a non registered listener, nothing should happen
self.bus._remove_listener('test', lambda x: len)
# Remove listener
self.bus._remove_listener('test', listener)
unsub()
self.assertEqual(old_count, len(self.bus.listeners))
# Try deleting listener while category doesn't exist either
self.bus._remove_listener('test', listener)
# Should do nothing now
unsub()
def test_unsubscribe_listener(self):
"""Test unsubscribe listener from returned function."""
@ -215,11 +212,48 @@ class TestEventBus(unittest.TestCase):
assert len(calls) == 1
def test_listen_once_event(self):
def test_listen_once_event_with_callback(self):
"""Test listen_once_event method."""
runs = []
self.bus.listen_once('test_event', lambda x: runs.append(1))
@ha.callback
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event')
# Second time it should not increase runs
self.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(runs))
def test_listen_once_event_with_coroutine(self):
"""Test listen_once_event method."""
runs = []
@asyncio.coroutine
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event')
# Second time it should not increase runs
self.bus.fire('test_event')
self.hass.block_till_done()
self.assertEqual(1, len(runs))
def test_listen_once_event_with_thread(self):
"""Test listen_once_event method."""
runs = []
def event_handler(event):
runs.append(event)
self.bus.listen_once('test_event', event_handler)
self.bus.fire('test_event')
# Second time it should not increase runs
@ -604,7 +638,7 @@ class TestWorkerPoolMonitor(object):
schedule_handle = MagicMock()
hass.loop.call_later.return_value = schedule_handle
ha.async_monitor_worker_pool(hass)
ha._async_monitor_worker_pool(hass)
assert hass.loop.call_later.called
assert hass.bus.async_listen_once.called
assert not schedule_handle.called
@ -650,7 +684,7 @@ class TestAsyncCreateTimer(object):
now.second = 1
mock_utcnow.reset_mock()
ha.async_create_timer(hass)
ha._async_create_timer(hass)
assert len(hass.bus.async_listen_once.mock_calls) == 2
start_timer = hass.bus.async_listen_once.mock_calls[1][1][1]

View file

@ -69,7 +69,7 @@ def setUpModule(): # pylint: disable=invalid-name
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SLAVE_PORT}})
with patch.object(ha, 'async_create_timer', return_value=None):
with patch.object(ha, '_async_create_timer', return_value=None):
slave.start()