Improve trace helper typing (#105964)

This commit is contained in:
Marc Mueller 2023-12-28 14:00:24 +01:00 committed by GitHub
parent 6eec4998bd
commit 1a6e81767d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,17 +2,20 @@
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from collections.abc import Callable, Generator from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from functools import wraps from functools import wraps
from typing import Any, cast from typing import Any, TypeVar, TypeVarTuple
from homeassistant.core import ServiceResponse from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType from .typing import TemplateVarsType
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
class TraceElement: class TraceElement:
"""Container for trace data.""" """Container for trace data."""
@ -125,21 +128,23 @@ def trace_id_get() -> tuple[str, str] | None:
return trace_id_cv.get() return trace_id_cv.get()
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None: def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
"""Push an element to the top of a trace stack.""" """Push an element to the top of a trace stack."""
trace_stack: list[_T] | None
if (trace_stack := trace_stack_var.get()) is None: if (trace_stack := trace_stack_var.get()) is None:
trace_stack = [] trace_stack = []
trace_stack_var.set(trace_stack) trace_stack_var.set(trace_stack)
trace_stack.append(node) trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None: def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
"""Remove the top element from a trace stack.""" """Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get() trace_stack = trace_stack_var.get()
if trace_stack is not None:
trace_stack.pop() trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Any | None: def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
"""Return the element at the top of a trace stack.""" """Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get() trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None return trace_stack[-1] if trace_stack else None
@ -198,20 +203,19 @@ def trace_clear() -> None:
def trace_set_child_id(child_key: str, child_run_id: str) -> None: def trace_set_child_id(child_key: str, child_run_id: str) -> None:
"""Set child trace_id of TraceElement at the top of the stack.""" """Set child trace_id of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv)) if node := trace_stack_top(trace_stack_cv):
if node:
node.set_child_id(child_key, child_run_id) node.set_child_id(child_key, child_run_id)
def trace_set_result(**kwargs: Any) -> None: def trace_set_result(**kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack.""" """Set the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv)) if node := trace_stack_top(trace_stack_cv):
node.set_result(**kwargs) node.set_result(**kwargs)
def trace_update_result(**kwargs: Any) -> None: def trace_update_result(**kwargs: Any) -> None:
"""Update the result of TraceElement at the top of the stack.""" """Update the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv)) if node := trace_stack_top(trace_stack_cv):
node.update_result(**kwargs) node.update_result(**kwargs)
@ -238,7 +242,7 @@ def script_execution_get() -> str | None:
@contextmanager @contextmanager
def trace_path(suffix: str | list[str]) -> Generator: def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
"""Go deeper in the config tree. """Go deeper in the config tree.
Can not be used as a decorator on couroutine functions. Can not be used as a decorator on couroutine functions.
@ -250,17 +254,24 @@ def trace_path(suffix: str | list[str]) -> Generator:
trace_path_pop(count) trace_path_pop(count)
def async_trace_path(suffix: str | list[str]) -> Callable: def async_trace_path(
suffix: str | list[str],
) -> Callable[
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
Callable[[*_Ts], Coroutine[Any, Any, None]],
]:
"""Go deeper in the config tree. """Go deeper in the config tree.
To be used as a decorator on coroutine functions. To be used as a decorator on coroutine functions.
""" """
def _trace_path_decorator(func: Callable) -> Callable: def _trace_path_decorator(
func: Callable[[*_Ts], Coroutine[Any, Any, None]],
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]:
"""Decorate a coroutine function.""" """Decorate a coroutine function."""
@wraps(func) @wraps(func)
async def async_wrapper(*args: Any) -> None: async def async_wrapper(*args: *_Ts) -> None:
"""Catch and log exception.""" """Catch and log exception."""
with trace_path(suffix): with trace_path(suffix):
await func(*args) await func(*args)