Add async_safe annotation (#3688)

* Add async_safe annotation

* More async_run_job

* coroutine -> async_save

* Lint

* Rename async_safe -> callback

* Add tests to core for different job types

* Add one more test with different type of callbacks

* Fix typing signature for callback methods

* Fix callback service executed method

* Fix method signatures for callback
This commit is contained in:
Paulus Schoutsen 2016-10-04 20:44:32 -07:00 committed by GitHub
parent be7401f4a2
commit 5085cdb0f7
19 changed files with 231 additions and 87 deletions

View file

@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool:
return ENTITY_ID_PATTERN.match(entity_id) is not None
def callback(func: Callable[..., None]) -> Callable[..., None]:
"""Annotation to mark method as safe to call from within the event loop."""
# pylint: disable=protected-access
func._hass_callback = True
return func
def is_callback(func: Callable[..., Any]) -> bool:
"""Check if function is safe to be called in the event loop."""
return '_hass_callback' in func.__dict__
class CoreState(enum.Enum):
"""Represent the current state of Home Assistant."""
@ -224,11 +236,24 @@ class HomeAssistant(object):
target: target to call.
args: parameters for method to call.
"""
if asyncio.iscoroutinefunction(target):
if is_callback(target):
self.loop.call_soon(target, *args)
elif asyncio.iscoroutinefunction(target):
self.loop.create_task(target(*args))
else:
self.add_job(target, *args)
def async_run_job(self, target: Callable[..., None], *args: Any):
"""Run a job from within the event loop.
target: target to call.
args: parameters for method to call.
"""
if is_callback(target):
target(*args)
else:
self.async_add_job(target, *args)
def _loop_empty(self):
"""Python 3.4.2 empty loop compatibility function."""
# pylint: disable=protected-access
@ -380,7 +405,6 @@ class EventBus(object):
self._loop.call_soon_threadsafe(self.async_fire, event_type,
event_data, origin)
return
def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local, wait=False):
@ -408,6 +432,8 @@ class EventBus(object):
for func in listeners:
if asyncio.iscoroutinefunction(func):
self._loop.create_task(func(event))
elif is_callback(func):
self._loop.call_soon(func, event)
else:
sync_jobs.append((job_priority, (func, event)))
@ -795,7 +821,7 @@ class Service(object):
"""Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction']
'is_callback', 'is_coroutinefunction']
def __init__(self, func, description, fields, schema):
"""Initialize a service."""
@ -803,7 +829,8 @@ class Service(object):
self.description = description or ''
self.fields = fields or {}
self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
self.is_callback = is_callback(func)
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self):
"""Return dictionary representation of this service."""
@ -934,7 +961,7 @@ class ServiceRegistry(object):
self._loop
).result()
@asyncio.coroutine
@callback
def async_call(self, domain, service, service_data=None, blocking=False):
"""
Call a service.
@ -966,7 +993,7 @@ class ServiceRegistry(object):
if blocking:
fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine
@callback
def service_executed(event):
"""Callback method that is called when service is executed."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
@ -1007,7 +1034,8 @@ class ServiceRegistry(object):
data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction:
if (service_handler.is_coroutinefunction or
service_handler.is_callback):
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
@ -1023,17 +1051,19 @@ class ServiceRegistry(object):
service_call = ServiceCall(domain, service, service_data, call_id)
if not service_handler.iscoroutinefunction:
if service_handler.is_callback:
service_handler.func(service_call)
fire_service_executed()
elif service_handler.is_coroutinefunction:
yield from service_handler.func(service_call)
fire_service_executed()
else:
def execute_service():
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
return
yield from service_handler.func(service_call)
fire_service_executed()
def _generate_unique_id(self):
"""Generate a unique service call id."""
@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL):
stop_event = asyncio.Event(loop=hass.loop)
# Setting the Event inside the loop by marking it as a coroutine
@asyncio.coroutine
@callback
def stop_timer(event):
"""Stop the timer."""
stop_event.set()
@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass):
schedule()
@asyncio.coroutine
@callback
def stop_monitor(event):
"""Stop the monitor."""
handle.cancel()