Clean up some async stuff (#3915)
* Clean up some async stuff * Adjust comments * Pass hass instance to eventbus
This commit is contained in:
parent
daea93d9f9
commit
4c8d1d9d2f
7 changed files with 139 additions and 111 deletions
|
@ -8,7 +8,6 @@ of entities and react to changes.
|
|||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import enum
|
||||
import functools as ft
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
@ -137,8 +136,8 @@ class HomeAssistant(object):
|
|||
self.executor = ThreadPoolExecutor(max_workers=5)
|
||||
self.loop.set_default_executor(self.executor)
|
||||
self.loop.set_exception_handler(self._async_exception_handler)
|
||||
self.pool = pool = create_worker_pool()
|
||||
self.bus = EventBus(pool, self.loop)
|
||||
self.pool = create_worker_pool()
|
||||
self.bus = EventBus(self)
|
||||
self.services = ServiceRegistry(self.bus, self.add_job, self.loop)
|
||||
self.states = StateMachine(self.bus, self.loop)
|
||||
self.config = Config() # type: Config
|
||||
|
@ -218,8 +217,8 @@ class HomeAssistant(object):
|
|||
"""
|
||||
# pylint: disable=protected-access
|
||||
self.loop._thread_ident = threading.get_ident()
|
||||
async_create_timer(self)
|
||||
async_monitor_worker_pool(self)
|
||||
_async_create_timer(self)
|
||||
_async_monitor_worker_pool(self)
|
||||
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||
yield from self.loop.run_in_executor(None, self.pool.block_till_done)
|
||||
self.state = CoreState.running
|
||||
|
@ -235,9 +234,12 @@ class HomeAssistant(object):
|
|||
"""
|
||||
self.pool.add_job(priority, (target,) + args)
|
||||
|
||||
@callback
|
||||
def async_add_job(self, target: Callable[..., None], *args: Any):
|
||||
"""Add a job from within the eventloop.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
|
@ -248,9 +250,12 @@ class HomeAssistant(object):
|
|||
else:
|
||||
self.add_job(target, *args)
|
||||
|
||||
@callback
|
||||
def async_run_job(self, target: Callable[..., None], *args: Any):
|
||||
"""Run a job from within the event loop.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
|
@ -369,7 +374,10 @@ class Event(object):
|
|||
self.time_fired = time_fired or dt_util.utcnow()
|
||||
|
||||
def as_dict(self):
|
||||
"""Create a dict representation of this Event."""
|
||||
"""Create a dict representation of this Event.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return {
|
||||
'event_type': self.event_type,
|
||||
'data': dict(self.data),
|
||||
|
@ -400,13 +408,12 @@ class Event(object):
|
|||
class EventBus(object):
|
||||
"""Allows firing of and listening for events."""
|
||||
|
||||
def __init__(self, pool: util.ThreadPool,
|
||||
loop: asyncio.AbstractEventLoop) -> None:
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize a new event bus."""
|
||||
self._listeners = {}
|
||||
self._pool = pool
|
||||
self._loop = loop
|
||||
self._hass = hass
|
||||
|
||||
@callback
|
||||
def async_listeners(self):
|
||||
"""Dict with events and the number of listeners.
|
||||
|
||||
|
@ -419,23 +426,25 @@ class EventBus(object):
|
|||
def listeners(self):
|
||||
"""Dict with events and the number of listeners."""
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_listeners
|
||||
self._hass.loop, self.async_listeners
|
||||
).result()
|
||||
|
||||
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
|
||||
"""Fire an event."""
|
||||
if not self._pool.running:
|
||||
raise HomeAssistantError('Home Assistant has shut down.')
|
||||
|
||||
self._loop.call_soon_threadsafe(self.async_fire, event_type,
|
||||
event_data, origin)
|
||||
self._hass.loop.call_soon_threadsafe(self.async_fire, event_type,
|
||||
event_data, origin)
|
||||
|
||||
@callback
|
||||
def async_fire(self, event_type: str, event_data=None,
|
||||
origin=EventOrigin.local, wait=False):
|
||||
"""Fire an event.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if event_type != EVENT_HOMEASSISTANT_STOP and \
|
||||
self._hass.state == CoreState.stopping:
|
||||
raise HomeAssistantError('Home Assistant is shutting down.')
|
||||
|
||||
# Copy the list of the current listeners because some listeners
|
||||
# remove themselves as a listener while being executed which
|
||||
# causes the iterator to be confused.
|
||||
|
@ -450,20 +459,8 @@ class EventBus(object):
|
|||
if not listeners:
|
||||
return
|
||||
|
||||
job_priority = JobPriority.from_event_type(event_type)
|
||||
|
||||
sync_jobs = []
|
||||
for func in listeners:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
self._loop.create_task(func(event))
|
||||
elif is_callback(func):
|
||||
self._loop.call_soon(func, event)
|
||||
else:
|
||||
sync_jobs.append((job_priority, (func, event)))
|
||||
|
||||
# Send all the sync jobs at once
|
||||
if sync_jobs:
|
||||
self._pool.add_many_jobs(sync_jobs)
|
||||
self._hass.async_add_job(func, event)
|
||||
|
||||
def listen(self, event_type, listener):
|
||||
"""Listen for all events or events of a specific type.
|
||||
|
@ -471,16 +468,17 @@ class EventBus(object):
|
|||
To listen to all events specify the constant ``MATCH_ALL``
|
||||
as event_type.
|
||||
"""
|
||||
future = run_callback_threadsafe(
|
||||
self._loop, self.async_listen, event_type, listener)
|
||||
future.result()
|
||||
async_remove_listener = run_callback_threadsafe(
|
||||
self._hass.loop, self.async_listen, event_type, listener).result()
|
||||
|
||||
def remove_listener():
|
||||
"""Remove the listener."""
|
||||
self._remove_listener(event_type, listener)
|
||||
run_callback_threadsafe(
|
||||
self._hass.loop, async_remove_listener).result()
|
||||
|
||||
return remove_listener
|
||||
|
||||
@callback
|
||||
def async_listen(self, event_type, listener):
|
||||
"""Listen for all events or events of a specific type.
|
||||
|
||||
|
@ -496,7 +494,7 @@ class EventBus(object):
|
|||
|
||||
def remove_listener():
|
||||
"""Remove the listener."""
|
||||
self.async_remove_listener(event_type, listener)
|
||||
self._async_remove_listener(event_type, listener)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
@ -508,26 +506,18 @@ class EventBus(object):
|
|||
|
||||
Returns function to unsubscribe the listener.
|
||||
"""
|
||||
@ft.wraps(listener)
|
||||
def onetime_listener(event):
|
||||
"""Remove listener from eventbus and then fire listener."""
|
||||
if hasattr(onetime_listener, 'run'):
|
||||
return
|
||||
# Set variable so that we will never run twice.
|
||||
# Because the event bus might have to wait till a thread comes
|
||||
# available to execute this listener it might occur that the
|
||||
# listener gets lined up twice to be executed.
|
||||
# This will make sure the second time it does nothing.
|
||||
setattr(onetime_listener, 'run', True)
|
||||
async_remove_listener = run_callback_threadsafe(
|
||||
self._hass.loop, self.async_listen_once, event_type, listener,
|
||||
).result()
|
||||
|
||||
remove_listener()
|
||||
|
||||
listener(event)
|
||||
|
||||
remove_listener = self.listen(event_type, onetime_listener)
|
||||
def remove_listener():
|
||||
"""Remove the listener."""
|
||||
run_callback_threadsafe(
|
||||
self._hass.loop, async_remove_listener).result()
|
||||
|
||||
return remove_listener
|
||||
|
||||
@callback
|
||||
def async_listen_once(self, event_type, listener):
|
||||
"""Listen once for event of a specific type.
|
||||
|
||||
|
@ -538,8 +528,7 @@ class EventBus(object):
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
@ft.wraps(listener)
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def onetime_listener(event):
|
||||
"""Remove listener from eventbus and then fire listener."""
|
||||
if hasattr(onetime_listener, 'run'):
|
||||
|
@ -550,34 +539,14 @@ class EventBus(object):
|
|||
# multiple times as well.
|
||||
# This will make sure the second time it does nothing.
|
||||
setattr(onetime_listener, 'run', True)
|
||||
self._async_remove_listener(event_type, onetime_listener)
|
||||
|
||||
self.async_remove_listener(event_type, onetime_listener)
|
||||
self._hass.async_run_job(listener, event)
|
||||
|
||||
if asyncio.iscoroutinefunction(listener):
|
||||
yield from listener(event)
|
||||
else:
|
||||
job_priority = JobPriority.from_event_type(event.event_type)
|
||||
self._pool.add_job(job_priority, (listener, event))
|
||||
return self.async_listen(event_type, onetime_listener)
|
||||
|
||||
self.async_listen(event_type, onetime_listener)
|
||||
|
||||
return onetime_listener
|
||||
|
||||
def remove_listener(self, event_type, listener):
|
||||
"""Remove a listener of a specific event_type. (DEPRECATED 0.28)."""
|
||||
_LOGGER.warning('bus.remove_listener has been deprecated. Please use '
|
||||
'the function returned from calling listen.')
|
||||
self._remove_listener(event_type, listener)
|
||||
|
||||
def _remove_listener(self, event_type, listener):
|
||||
"""Remove a listener of a specific event_type."""
|
||||
future = run_callback_threadsafe(
|
||||
self._loop,
|
||||
self.async_remove_listener, event_type, listener
|
||||
)
|
||||
future.result()
|
||||
|
||||
def async_remove_listener(self, event_type, listener):
|
||||
@callback
|
||||
def _async_remove_listener(self, event_type, listener):
|
||||
"""Remove a listener of a specific event_type.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -644,6 +613,8 @@ class State(object):
|
|||
def as_dict(self):
|
||||
"""Return a dict representation of the State.
|
||||
|
||||
Async friendly.
|
||||
|
||||
To be used for JSON serialization.
|
||||
Ensures: state == State.from_dict(state.as_dict())
|
||||
"""
|
||||
|
@ -657,6 +628,8 @@ class State(object):
|
|||
def from_dict(cls, json_dict):
|
||||
"""Initialize a state from a dict.
|
||||
|
||||
Async friendly.
|
||||
|
||||
Ensures: state == State.from_json_dict(state.to_json_dict())
|
||||
"""
|
||||
if not (json_dict and 'entity_id' in json_dict and
|
||||
|
@ -709,8 +682,12 @@ class StateMachine(object):
|
|||
)
|
||||
return future.result()
|
||||
|
||||
@callback
|
||||
def async_entity_ids(self, domain_filter=None):
|
||||
"""List of entity ids that are being tracked."""
|
||||
"""List of entity ids that are being tracked.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if domain_filter is None:
|
||||
return list(self._states.keys())
|
||||
|
||||
|
@ -723,6 +700,7 @@ class StateMachine(object):
|
|||
"""Create a list of all states."""
|
||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
||||
|
||||
@callback
|
||||
def async_all(self):
|
||||
"""Create a list of all states.
|
||||
|
||||
|
@ -763,6 +741,7 @@ class StateMachine(object):
|
|||
return run_callback_threadsafe(
|
||||
self._loop, self.async_remove, entity_id).result()
|
||||
|
||||
@callback
|
||||
def async_remove(self, entity_id):
|
||||
"""Remove the state of an entity.
|
||||
|
||||
|
@ -800,6 +779,7 @@ class StateMachine(object):
|
|||
self.async_set, entity_id, new_state, attributes, force_update,
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_set(self, entity_id, new_state, attributes=None,
|
||||
force_update=False):
|
||||
"""Set the state of an entity, add entity if it does not exist.
|
||||
|
@ -908,14 +888,21 @@ class ServiceRegistry(object):
|
|||
self._loop, self.async_services,
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_services(self):
|
||||
"""Dict with per domain a list of available services."""
|
||||
"""Dict with per domain a list of available services.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return {domain: {key: value.as_dict() for key, value
|
||||
in self._services[domain].items()}
|
||||
for domain in self._services}
|
||||
|
||||
def has_service(self, domain, service):
|
||||
"""Test if specified service exists."""
|
||||
"""Test if specified service exists.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return service.lower() in self._services.get(domain.lower(), [])
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
|
@ -935,6 +922,7 @@ class ServiceRegistry(object):
|
|||
schema
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_register(self, domain, service, service_func, description=None,
|
||||
schema=None):
|
||||
"""
|
||||
|
@ -985,7 +973,7 @@ class ServiceRegistry(object):
|
|||
self._loop
|
||||
).result()
|
||||
|
||||
@callback
|
||||
@asyncio.coroutine
|
||||
def async_call(self, domain, service, service_data=None, blocking=False):
|
||||
"""
|
||||
Call a service.
|
||||
|
@ -1121,18 +1109,27 @@ class Config(object):
|
|||
self.config_dir = None
|
||||
|
||||
def distance(self: object, lat: float, lon: float) -> float:
|
||||
"""Calculate distance from Home Assistant."""
|
||||
"""Calculate distance from Home Assistant.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return self.units.length(
|
||||
location.distance(self.latitude, self.longitude, lat, lon), 'm')
|
||||
|
||||
def path(self, *path):
|
||||
"""Generate path to the file within the config dir."""
|
||||
"""Generate path to the file within the config dir.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
if self.config_dir is None:
|
||||
raise HomeAssistantError("config_dir is not set")
|
||||
return os.path.join(self.config_dir, *path)
|
||||
|
||||
def as_dict(self):
|
||||
"""Create a dict representation of this dict."""
|
||||
"""Create a dict representation of this dict.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
time_zone = self.time_zone or dt_util.UTC
|
||||
|
||||
return {
|
||||
|
@ -1147,7 +1144,7 @@ class Config(object):
|
|||
}
|
||||
|
||||
|
||||
def async_create_timer(hass, interval=TIMER_INTERVAL):
|
||||
def _async_create_timer(hass, interval=TIMER_INTERVAL):
|
||||
"""Create a timer that will start on HOMEASSISTANT_START."""
|
||||
stop_event = asyncio.Event(loop=hass.loop)
|
||||
|
||||
|
@ -1230,7 +1227,7 @@ def create_worker_pool(worker_count=None):
|
|||
return util.ThreadPool(job_handler, worker_count)
|
||||
|
||||
|
||||
def async_monitor_worker_pool(hass):
|
||||
def _async_monitor_worker_pool(hass):
|
||||
"""Create a monitor for the thread pool to check if pool is misbehaving."""
|
||||
busy_threshold = hass.pool.worker_count * 3
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue