Add async_safe annotation (#3688)
* Add async_safe annotation * More async_run_job * coroutine -> async_save * Lint * Rename async_safe -> callback * Add tests to core for different job types * Add one more test with different type of callbacks * Fix typing signature for callback methods * Fix callback service executed method * Fix method signatures for callback
This commit is contained in:
parent
be7401f4a2
commit
5085cdb0f7
19 changed files with 231 additions and 87 deletions
|
@ -78,6 +78,18 @@ def valid_entity_id(entity_id: str) -> bool:
|
|||
return ENTITY_ID_PATTERN.match(entity_id) is not None
|
||||
|
||||
|
||||
def callback(func: Callable[..., None]) -> Callable[..., None]:
|
||||
"""Annotation to mark method as safe to call from within the event loop."""
|
||||
# pylint: disable=protected-access
|
||||
func._hass_callback = True
|
||||
return func
|
||||
|
||||
|
||||
def is_callback(func: Callable[..., Any]) -> bool:
|
||||
"""Check if function is safe to be called in the event loop."""
|
||||
return '_hass_callback' in func.__dict__
|
||||
|
||||
|
||||
class CoreState(enum.Enum):
|
||||
"""Represent the current state of Home Assistant."""
|
||||
|
||||
|
@ -224,11 +236,24 @@ class HomeAssistant(object):
|
|||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(target):
|
||||
if is_callback(target):
|
||||
self.loop.call_soon(target, *args)
|
||||
elif asyncio.iscoroutinefunction(target):
|
||||
self.loop.create_task(target(*args))
|
||||
else:
|
||||
self.add_job(target, *args)
|
||||
|
||||
def async_run_job(self, target: Callable[..., None], *args: Any):
|
||||
"""Run a job from within the event loop.
|
||||
|
||||
target: target to call.
|
||||
args: parameters for method to call.
|
||||
"""
|
||||
if is_callback(target):
|
||||
target(*args)
|
||||
else:
|
||||
self.async_add_job(target, *args)
|
||||
|
||||
def _loop_empty(self):
|
||||
"""Python 3.4.2 empty loop compatibility function."""
|
||||
# pylint: disable=protected-access
|
||||
|
@ -380,7 +405,6 @@ class EventBus(object):
|
|||
|
||||
self._loop.call_soon_threadsafe(self.async_fire, event_type,
|
||||
event_data, origin)
|
||||
return
|
||||
|
||||
def async_fire(self, event_type: str, event_data=None,
|
||||
origin=EventOrigin.local, wait=False):
|
||||
|
@ -408,6 +432,8 @@ class EventBus(object):
|
|||
for func in listeners:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
self._loop.create_task(func(event))
|
||||
elif is_callback(func):
|
||||
self._loop.call_soon(func, event)
|
||||
else:
|
||||
sync_jobs.append((job_priority, (func, event)))
|
||||
|
||||
|
@ -795,7 +821,7 @@ class Service(object):
|
|||
"""Represents a callable service."""
|
||||
|
||||
__slots__ = ['func', 'description', 'fields', 'schema',
|
||||
'iscoroutinefunction']
|
||||
'is_callback', 'is_coroutinefunction']
|
||||
|
||||
def __init__(self, func, description, fields, schema):
|
||||
"""Initialize a service."""
|
||||
|
@ -803,7 +829,8 @@ class Service(object):
|
|||
self.description = description or ''
|
||||
self.fields = fields or {}
|
||||
self.schema = schema
|
||||
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
self.is_callback = is_callback(func)
|
||||
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
|
||||
def as_dict(self):
|
||||
"""Return dictionary representation of this service."""
|
||||
|
@ -934,7 +961,7 @@ class ServiceRegistry(object):
|
|||
self._loop
|
||||
).result()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def async_call(self, domain, service, service_data=None, blocking=False):
|
||||
"""
|
||||
Call a service.
|
||||
|
@ -966,7 +993,7 @@ class ServiceRegistry(object):
|
|||
if blocking:
|
||||
fut = asyncio.Future(loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def service_executed(event):
|
||||
"""Callback method that is called when service is executed."""
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
|
@ -1007,7 +1034,8 @@ class ServiceRegistry(object):
|
|||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if service_handler.iscoroutinefunction:
|
||||
if (service_handler.is_coroutinefunction or
|
||||
service_handler.is_callback):
|
||||
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
||||
else:
|
||||
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
|
||||
|
@ -1023,17 +1051,19 @@ class ServiceRegistry(object):
|
|||
|
||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||
|
||||
if not service_handler.iscoroutinefunction:
|
||||
if service_handler.is_callback:
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
elif service_handler.is_coroutinefunction:
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
else:
|
||||
def execute_service():
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
|
||||
return
|
||||
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _generate_unique_id(self):
|
||||
"""Generate a unique service call id."""
|
||||
|
@ -1098,7 +1128,7 @@ def async_create_timer(hass, interval=TIMER_INTERVAL):
|
|||
stop_event = asyncio.Event(loop=hass.loop)
|
||||
|
||||
# Setting the Event inside the loop by marking it as a coroutine
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def stop_timer(event):
|
||||
"""Stop the timer."""
|
||||
stop_event.set()
|
||||
|
@ -1212,7 +1242,7 @@ def async_monitor_worker_pool(hass):
|
|||
|
||||
schedule()
|
||||
|
||||
@asyncio.coroutine
|
||||
@callback
|
||||
def stop_monitor(event):
|
||||
"""Stop the monitor."""
|
||||
handle.cancel()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue