Clean up some async stuff (#3915)

* Clean up some async stuff

* Adjust comments

* Pass hass instance to eventbus
This commit is contained in:
Paulus Schoutsen 2016-10-17 19:38:41 -07:00 committed by GitHub
parent daea93d9f9
commit 4c8d1d9d2f
7 changed files with 139 additions and 111 deletions

View file

@ -8,7 +8,6 @@ of entities and react to changes.
import asyncio
from concurrent.futures import ThreadPoolExecutor
import enum
import functools as ft
import logging
import os
import re
@ -137,8 +136,8 @@ class HomeAssistant(object):
self.executor = ThreadPoolExecutor(max_workers=5)
self.loop.set_default_executor(self.executor)
self.loop.set_exception_handler(self._async_exception_handler)
self.pool = pool = create_worker_pool()
self.bus = EventBus(pool, self.loop)
self.pool = create_worker_pool()
self.bus = EventBus(self)
self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
self.states = StateMachine(self.bus, self.loop)
self.config = Config() # type: Config
@ -218,8 +217,8 @@ class HomeAssistant(object):
"""
# pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident()
async_create_timer(self)
async_monitor_worker_pool(self)
_async_create_timer(self)
_async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from self.loop.run_in_executor(None, self.pool.block_till_done)
self.state = CoreState.running
@ -235,9 +234,12 @@ class HomeAssistant(object):
"""
self.pool.add_job(priority, (target,) + args)
@callback
def async_add_job(self, target: Callable[..., None], *args: Any):
"""Add a job from within the eventloop.
This method must be run in the event loop.
target: target to call.
args: parameters for method to call.
"""
@ -248,9 +250,12 @@ class HomeAssistant(object):
else:
self.add_job(target, *args)
@callback
def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop.
This method must be run in the event loop.
target: target to call.
args: parameters for method to call.
"""
@ -369,7 +374,10 @@ class Event(object):
self.time_fired = time_fired or dt_util.utcnow()
def as_dict(self):
"""Create a dict representation of this Event."""
"""Create a dict representation of this Event.
Async friendly.
"""
return {
'event_type': self.event_type,
'data': dict(self.data),
@ -400,13 +408,12 @@ class Event(object):
class EventBus(object):
"""Allows firing of and listening for events."""
def __init__(self, pool: util.ThreadPool,
loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners = {}
self._pool = pool
self._loop = loop
self._hass = hass
@callback
def async_listeners(self):
"""Dict with events and the number of listeners.
@ -419,23 +426,25 @@ class EventBus(object):
def listeners(self):
"""Dict with events and the number of listeners."""
return run_callback_threadsafe(
self._loop, self.async_listeners
self._hass.loop, self.async_listeners
).result()
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
"""Fire an event."""
if not self._pool.running:
raise HomeAssistantError('Home Assistant has shut down.')
self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
self._hass.loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
@callback
def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False):
"""Fire an event.
This method must be run in the event loop.
"""
if event_type != EVENT_HOMEASSISTANT_STOP and \
self._hass.state == CoreState.stopping:
raise HomeAssistantError('Home Assistant is shutting down.')
# Copy the list of the current listeners because some listeners
# remove themselves as a listener while being executed which
# causes the iterator to be confused.
@ -450,20 +459,8 @@ class EventBus(object):
if not listeners:
return
job_priority = JobPriority.from_event_type(event_type)
sync_jobs = []
for func in listeners:
if asyncio.iscoroutinefunction(func):
self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else:
sync_jobs.append((job_priority, (func, event)))
# Send all the sync jobs at once
if sync_jobs:
self._pool.add_many_jobs(sync_jobs)
self._hass.async_add_job(func, event)
def listen(self, event_type, listener):
"""Listen for all events or events of a specific type.
@ -471,16 +468,17 @@ class EventBus(object):
To listen to all events specify the constant ``MATCH_ALL``
as event_type.
"""
future = run_callback_threadsafe(
self._loop, self.async_listen, event_type, listener)
future.result()
async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen, event_type, listener).result()
def remove_listener():
"""Remove the listener."""
self._remove_listener(event_type, listener)
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
return remove_listener
@callback
def async_listen(self, event_type, listener):
"""Listen for all events or events of a specific type.
@ -496,7 +494,7 @@ class EventBus(object):
def remove_listener():
"""Remove the listener."""
self.async_remove_listener(event_type, listener)
self._async_remove_listener(event_type, listener)
return remove_listener
@ -508,26 +506,18 @@ class EventBus(object):
Returns function to unsubscribe the listener.
"""
@ft.wraps(listener)
def onetime_listener(event):
"""Remove listener from eventbus and then fire listener."""
if hasattr(onetime_listener, 'run'):
return
# Set variable so that we will never run twice.
# Because the event bus might have to wait till a thread comes
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed.
# This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True)
async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen_once, event_type, listener,
).result()
remove_listener()
listener(event)
remove_listener = self.listen(event_type, onetime_listener)
def remove_listener():
"""Remove the listener."""
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
return remove_listener
@callback
def async_listen_once(self, event_type, listener):
"""Listen once for event of a specific type.
@ -538,8 +528,7 @@ class EventBus(object):
This method must be run in the event loop.
"""
@ft.wraps(listener)
@asyncio.coroutine
@callback
def onetime_listener(event):
"""Remove listener from eventbus and then fire listener."""
if hasattr(onetime_listener, 'run'):
@ -550,34 +539,14 @@ class EventBus(object):
# multiple times as well.
# This will make sure the second time it does nothing.
setattr(onetime_listener, 'run', True)
self._async_remove_listener(event_type, onetime_listener)
self.async_remove_listener(event_type, onetime_listener)
self._hass.async_run_job(listener, event)
if asyncio.iscoroutinefunction(listener):
yield from listener(event)
else:
job_priority = JobPriority.from_event_type(event.event_type)
self._pool.add_job(job_priority, (listener, event))
return self.async_listen(event_type, onetime_listener)
self.async_listen(event_type, onetime_listener)
return onetime_listener
def remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type. (DEPRECATED 0.28)."""
_LOGGER.warning('bus.remove_listener has been deprecated. Please use '
'the function returned from calling listen.')
self._remove_listener(event_type, listener)
def _remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type."""
future = run_callback_threadsafe(
self._loop,
self.async_remove_listener, event_type, listener
)
future.result()
def async_remove_listener(self, event_type, listener):
@callback
def _async_remove_listener(self, event_type, listener):
"""Remove a listener of a specific event_type.
This method must be run in the event loop.
@ -644,6 +613,8 @@ class State(object):
def as_dict(self):
"""Return a dict representation of the State.
Async friendly.
To be used for JSON serialization.
Ensures: state == State.from_dict(state.as_dict())
"""
@ -657,6 +628,8 @@ class State(object):
def from_dict(cls, json_dict):
"""Initialize a state from a dict.
Async friendly.
Ensures: state == State.from_json_dict(state.to_json_dict())
"""
if not (json_dict and 'entity_id' in json_dict and
@ -709,8 +682,12 @@ class StateMachine(object):
)
return future.result()
@callback
def async_entity_ids(self, domain_filter=None):
"""List of entity ids that are being tracked."""
"""List of entity ids that are being tracked.
This method must be run in the event loop.
"""
if domain_filter is None:
return list(self._states.keys())
@ -723,6 +700,7 @@ class StateMachine(object):
"""Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result()
@callback
def async_all(self):
"""Create a list of all states.
@ -763,6 +741,7 @@ class StateMachine(object):
return run_callback_threadsafe(
self._loop, self.async_remove, entity_id).result()
@callback
def async_remove(self, entity_id):
"""Remove the state of an entity.
@ -800,6 +779,7 @@ class StateMachine(object):
self.async_set, entity_id, new_state, attributes, force_update,
).result()
@callback
def async_set(self, entity_id, new_state, attributes=None,
force_update=False):
"""Set the state of an entity, add entity if it does not exist.
@ -908,14 +888,21 @@ class ServiceRegistry(object):
self._loop, self.async_services,
).result()
@callback
def async_services(self):
"""Dict with per domain a list of available services."""
"""Dict with per domain a list of available services.
This method must be run in the event loop.
"""
return {domain: {key: value.as_dict() for key, value
in self._services[domain].items()}
for domain in self._services}
def has_service(self, domain, service):
"""Test if specified service exists."""
"""Test if specified service exists.
Async friendly.
"""
return service.lower() in self._services.get(domain.lower(), [])
# pylint: disable=too-many-arguments
@ -935,6 +922,7 @@ class ServiceRegistry(object):
schema
).result()
@callback
def async_register(self, domain, service, service_func, description=None,
schema=None):
"""
@ -985,7 +973,7 @@ class ServiceRegistry(object):
self._loop
).result()
@callback
@asyncio.coroutine
def async_call(self, domain, service, service_data=None, blocking=False):
"""
Call a service.
@ -1121,18 +1109,27 @@ class Config(object):
self.config_dir = None
def distance(self: object, lat: float, lon: float) -> float:
"""Calculate distance from Home Assistant."""
"""Calculate distance from Home Assistant.
Async friendly.
"""
return self.units.length(
location.distance(self.latitude, self.longitude, lat, lon), 'm')
def path(self, *path):
"""Generate path to the file within the config dir."""
"""Generate path to the file within the config dir.
Async friendly.
"""
if self.config_dir is None:
raise HomeAssistantError("config_dir is not set")
return os.path.join(self.config_dir, *path)
def as_dict(self):
"""Create a dict representation of this dict."""
"""Create a dict representation of this dict.
Async friendly.
"""
time_zone = self.time_zone or dt_util.UTC
return {
@ -1147,7 +1144,7 @@ class Config(object):
}
def async_create_timer(hass, interval=TIMER_INTERVAL):
def _async_create_timer(hass, interval=TIMER_INTERVAL):
"""Create a timer that will start on HOMEASSISTANT_START."""
stop_event = asyncio.Event(loop=hass.loop)
@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None):
return util.ThreadPool(job_handler, worker_count)
def async_monitor_worker_pool(hass):
def _async_monitor_worker_pool(hass):
"""Create a monitor for the thread pool to check if pool is misbehaving."""
busy_threshold = hass.pool.worker_count * 3