Improve trace helper typing (#105964)
This commit is contained in:
parent
6eec4998bd
commit
1a6e81767d
1 changed files with 27 additions and 16 deletions
|
@ -2,17 +2,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Callable, Coroutine, Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypeVar, TypeVarTuple
|
||||
|
||||
from homeassistant.core import ServiceResponse
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .typing import TemplateVarsType
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_Ts = TypeVarTuple("_Ts")
|
||||
|
||||
|
||||
class TraceElement:
|
||||
"""Container for trace data."""
|
||||
|
@ -125,21 +128,23 @@ def trace_id_get() -> tuple[str, str] | None:
|
|||
return trace_id_cv.get()
|
||||
|
||||
|
||||
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
|
||||
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
|
||||
"""Push an element to the top of a trace stack."""
|
||||
trace_stack: list[_T] | None
|
||||
if (trace_stack := trace_stack_var.get()) is None:
|
||||
trace_stack = []
|
||||
trace_stack_var.set(trace_stack)
|
||||
trace_stack.append(node)
|
||||
|
||||
|
||||
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
|
||||
def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
|
||||
"""Remove the top element from a trace stack."""
|
||||
trace_stack = trace_stack_var.get()
|
||||
trace_stack.pop()
|
||||
if trace_stack is not None:
|
||||
trace_stack.pop()
|
||||
|
||||
|
||||
def trace_stack_top(trace_stack_var: ContextVar) -> Any | None:
|
||||
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
|
||||
"""Return the element at the top of a trace stack."""
|
||||
trace_stack = trace_stack_var.get()
|
||||
return trace_stack[-1] if trace_stack else None
|
||||
|
@ -198,21 +203,20 @@ def trace_clear() -> None:
|
|||
|
||||
def trace_set_child_id(child_key: str, child_run_id: str) -> None:
|
||||
"""Set child trace_id of TraceElement at the top of the stack."""
|
||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
||||
if node:
|
||||
if node := trace_stack_top(trace_stack_cv):
|
||||
node.set_child_id(child_key, child_run_id)
|
||||
|
||||
|
||||
def trace_set_result(**kwargs: Any) -> None:
|
||||
"""Set the result of TraceElement at the top of the stack."""
|
||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
||||
node.set_result(**kwargs)
|
||||
if node := trace_stack_top(trace_stack_cv):
|
||||
node.set_result(**kwargs)
|
||||
|
||||
|
||||
def trace_update_result(**kwargs: Any) -> None:
|
||||
"""Update the result of TraceElement at the top of the stack."""
|
||||
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
|
||||
node.update_result(**kwargs)
|
||||
if node := trace_stack_top(trace_stack_cv):
|
||||
node.update_result(**kwargs)
|
||||
|
||||
|
||||
class StopReason:
|
||||
|
@ -238,7 +242,7 @@ def script_execution_get() -> str | None:
|
|||
|
||||
|
||||
@contextmanager
|
||||
def trace_path(suffix: str | list[str]) -> Generator:
|
||||
def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
|
||||
"""Go deeper in the config tree.
|
||||
|
||||
Can not be used as a decorator on couroutine functions.
|
||||
|
@ -250,17 +254,24 @@ def trace_path(suffix: str | list[str]) -> Generator:
|
|||
trace_path_pop(count)
|
||||
|
||||
|
||||
def async_trace_path(suffix: str | list[str]) -> Callable:
|
||||
def async_trace_path(
|
||||
suffix: str | list[str],
|
||||
) -> Callable[
|
||||
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
|
||||
Callable[[*_Ts], Coroutine[Any, Any, None]],
|
||||
]:
|
||||
"""Go deeper in the config tree.
|
||||
|
||||
To be used as a decorator on coroutine functions.
|
||||
"""
|
||||
|
||||
def _trace_path_decorator(func: Callable) -> Callable:
|
||||
def _trace_path_decorator(
|
||||
func: Callable[[*_Ts], Coroutine[Any, Any, None]],
|
||||
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]:
|
||||
"""Decorate a coroutine function."""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any) -> None:
|
||||
async def async_wrapper(*args: *_Ts) -> None:
|
||||
"""Catch and log exception."""
|
||||
with trace_path(suffix):
|
||||
await func(*args)
|
||||
|
|
Loading…
Add table
Reference in a new issue