Improve trace helper typing (#105964)
This commit is contained in:
parent
6eec4998bd
commit
1a6e81767d
1 changed files with 27 additions and 16 deletions
|
@ -2,17 +2,20 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Coroutine, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, cast
|
from typing import Any, TypeVar, TypeVarTuple
|
||||||
|
|
||||||
from homeassistant.core import ServiceResponse
|
from homeassistant.core import ServiceResponse
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from .typing import TemplateVarsType
|
from .typing import TemplateVarsType
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
_Ts = TypeVarTuple("_Ts")
|
||||||
|
|
||||||
|
|
||||||
class TraceElement:
|
class TraceElement:
|
||||||
"""Container for trace data."""
|
"""Container for trace data."""
|
||||||
|
@ -125,21 +128,23 @@ def trace_id_get() -> tuple[str, str] | None:
|
||||||
return trace_id_cv.get()
|
return trace_id_cv.get()
|
||||||
|
|
||||||
|
|
||||||
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
|
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
|
||||||
"""Push an element to the top of a trace stack."""
|
"""Push an element to the top of a trace stack."""
|
||||||
|
trace_stack: list[_T] | None
|
||||||
if (trace_stack := trace_stack_var.get()) is None:
|
if (trace_stack := trace_stack_var.get()) is None:
|
||||||
trace_stack = []
|
trace_stack = []
|
||||||
trace_stack_var.set(trace_stack)
|
trace_stack_var.set(trace_stack)
|
||||||
trace_stack.append(node)
|
trace_stack.append(node)
|
||||||
|
|
||||||
|
|
||||||
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
|
def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
|
||||||
"""Remove the top element from a trace stack."""
|
"""Remove the top element from a trace stack."""
|
||||||
trace_stack = trace_stack_var.get()
|
trace_stack = trace_stack_var.get()
|
||||||
|
if trace_stack is not None:
|
||||||
trace_stack.pop()
|
trace_stack.pop()
|
||||||
|
|
||||||
|
|
||||||
def trace_stack_top(trace_stack_var: ContextVar) -> Any | None:
|
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
|
||||||
"""Return the element at the top of a trace stack."""
|
"""Return the element at the top of a trace stack."""
|
||||||
trace_stack = trace_stack_var.get()
|
trace_stack = trace_stack_var.get()
|
||||||
return trace_stack[-1] if trace_stack else None
|
return trace_stack[-1] if trace_stack else None
|
||||||
|
@ -198,20 +203,19 @@ def trace_clear() -> None:
|
||||||
|
|
||||||
def trace_set_child_id(child_key: str, child_run_id: str) -> None:
|
def trace_set_child_id(child_key: str, child_run_id: str) -> None:
|
||||||
"""Set child trace_id of TraceElement at the top of the stack."""
|
"""Set child trace_id of TraceElement at the top of the stack."""
|
||||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
if node := trace_stack_top(trace_stack_cv):
|
||||||
if node:
|
|
||||||
node.set_child_id(child_key, child_run_id)
|
node.set_child_id(child_key, child_run_id)
|
||||||
|
|
||||||
|
|
||||||
def trace_set_result(**kwargs: Any) -> None:
|
def trace_set_result(**kwargs: Any) -> None:
|
||||||
"""Set the result of TraceElement at the top of the stack."""
|
"""Set the result of TraceElement at the top of the stack."""
|
||||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
if node := trace_stack_top(trace_stack_cv):
|
||||||
node.set_result(**kwargs)
|
node.set_result(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def trace_update_result(**kwargs: Any) -> None:
|
def trace_update_result(**kwargs: Any) -> None:
|
||||||
"""Update the result of TraceElement at the top of the stack."""
|
"""Update the result of TraceElement at the top of the stack."""
|
||||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
if node := trace_stack_top(trace_stack_cv):
|
||||||
node.update_result(**kwargs)
|
node.update_result(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -238,7 +242,7 @@ def script_execution_get() -> str | None:
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def trace_path(suffix: str | list[str]) -> Generator:
|
def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
|
||||||
"""Go deeper in the config tree.
|
"""Go deeper in the config tree.
|
||||||
|
|
||||||
Can not be used as a decorator on couroutine functions.
|
Can not be used as a decorator on couroutine functions.
|
||||||
|
@ -250,17 +254,24 @@ def trace_path(suffix: str | list[str]) -> Generator:
|
||||||
trace_path_pop(count)
|
trace_path_pop(count)
|
||||||
|
|
||||||
|
|
||||||
def async_trace_path(suffix: str | list[str]) -> Callable:
|
def async_trace_path(
|
||||||
|
suffix: str | list[str],
|
||||||
|
) -> Callable[
|
||||||
|
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
|
||||||
|
Callable[[*_Ts], Coroutine[Any, Any, None]],
|
||||||
|
]:
|
||||||
"""Go deeper in the config tree.
|
"""Go deeper in the config tree.
|
||||||
|
|
||||||
To be used as a decorator on coroutine functions.
|
To be used as a decorator on coroutine functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _trace_path_decorator(func: Callable) -> Callable:
|
def _trace_path_decorator(
|
||||||
|
func: Callable[[*_Ts], Coroutine[Any, Any, None]],
|
||||||
|
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]:
|
||||||
"""Decorate a coroutine function."""
|
"""Decorate a coroutine function."""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def async_wrapper(*args: Any) -> None:
|
async def async_wrapper(*args: *_Ts) -> None:
|
||||||
"""Catch and log exception."""
|
"""Catch and log exception."""
|
||||||
with trace_path(suffix):
|
with trace_path(suffix):
|
||||||
await func(*args)
|
await func(*args)
|
||||||
|
|
Loading…
Add table
Reference in a new issue