Added recursive detection of functools.partial. (#20284)

This commit is contained in:
Andrew Sayre 2019-01-21 00:27:32 -06:00 committed by Paulus Schoutsen
parent 9482a6303d
commit 5c208da82e
2 changed files with 10 additions and 4 deletions

View file

@ -259,9 +259,10 @@ class HomeAssistant:
"""
task = None
# Check for partials to properly determine if coroutine function
check_target = target
if isinstance(target, functools.partial):
check_target = target.func
while isinstance(check_target, functools.partial):
check_target = check_target.func
if asyncio.iscoroutine(check_target):
task = self.loop.create_task(target) # type: ignore

View file

@ -1,7 +1,7 @@
"""Logging utilities."""
import asyncio
from asyncio.events import AbstractEventLoop
from functools import wraps
from functools import partial, wraps
import inspect
import logging
import threading
@ -139,8 +139,13 @@ def catch_log_exception(
friendly_msg = format_err(*args)
logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg)
# Check for partials to properly determine if coroutine function
check_func = func
while isinstance(check_func, partial):
check_func = check_func.func
wrapper_func = None
if asyncio.iscoroutinefunction(func):
if asyncio.iscoroutinefunction(check_func):
@wraps(func)
async def async_wrapper(*args: Any) -> None:
"""Catch and log exception."""