Added recursive detection of functools.partial. (#20284)
This commit is contained in:
parent
9482a6303d
commit
5c208da82e
2 changed files with 10 additions and 4 deletions
|
@ -259,9 +259,10 @@ class HomeAssistant:
|
|||
"""
|
||||
task = None
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_target = target
|
||||
if isinstance(target, functools.partial):
|
||||
check_target = target.func
|
||||
while isinstance(check_target, functools.partial):
|
||||
check_target = check_target.func
|
||||
|
||||
if asyncio.iscoroutine(check_target):
|
||||
task = self.loop.create_task(target) # type: ignore
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Logging utilities."""
|
||||
import asyncio
|
||||
from asyncio.events import AbstractEventLoop
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
|
@ -139,8 +139,13 @@ def catch_log_exception(
|
|||
friendly_msg = format_err(*args)
|
||||
logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg)
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = func
|
||||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func
|
||||
|
||||
wrapper_func = None
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any) -> None:
|
||||
"""Catch and log exception."""
|
||||
|
|
Loading…
Add table
Reference in a new issue