Improve trace helper typing (#105964)

This commit is contained in:
Marc Mueller 2023-12-28 14:00:24 +01:00 committed by GitHub
parent 6eec4998bd
commit 1a6e81767d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,17 +2,20 @@
from __future__ import annotations
from collections import deque
from collections.abc import Callable, Generator
from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from functools import wraps
from typing import Any, cast
from typing import Any, TypeVar, TypeVarTuple
from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
class TraceElement:
"""Container for trace data."""
@ -125,21 +128,23 @@ def trace_id_get() -> tuple[str, str] | None:
return trace_id_cv.get()
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
"""Push an element to the top of a trace stack."""
trace_stack: list[_T] | None
if (trace_stack := trace_stack_var.get()) is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
def trace_stack_pop(trace_stack_var: ContextVar) -> None:
def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
"""Remove the top element from a trace stack."""
trace_stack = trace_stack_var.get()
trace_stack.pop()
if trace_stack is not None:
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Any | None:
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
@ -198,21 +203,20 @@ def trace_clear() -> None:
def trace_set_child_id(child_key: str, child_run_id: str) -> None:
"""Set child trace_id of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
if node:
if node := trace_stack_top(trace_stack_cv):
node.set_child_id(child_key, child_run_id)
def trace_set_result(**kwargs: Any) -> None:
"""Set the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
node.set_result(**kwargs)
if node := trace_stack_top(trace_stack_cv):
node.set_result(**kwargs)
def trace_update_result(**kwargs: Any) -> None:
"""Update the result of TraceElement at the top of the stack."""
node = cast(TraceElement, trace_stack_top(trace_stack_cv))
node.update_result(**kwargs)
if node := trace_stack_top(trace_stack_cv):
node.update_result(**kwargs)
class StopReason:
@ -238,7 +242,7 @@ def script_execution_get() -> str | None:
@contextmanager
def trace_path(suffix: str | list[str]) -> Generator:
def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
"""Go deeper in the config tree.
Can not be used as a decorator on couroutine functions.
@ -250,17 +254,24 @@ def trace_path(suffix: str | list[str]) -> Generator:
trace_path_pop(count)
def async_trace_path(suffix: str | list[str]) -> Callable:
def async_trace_path(
suffix: str | list[str],
) -> Callable[
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
Callable[[*_Ts], Coroutine[Any, Any, None]],
]:
"""Go deeper in the config tree.
To be used as a decorator on coroutine functions.
"""
def _trace_path_decorator(func: Callable) -> Callable:
def _trace_path_decorator(
func: Callable[[*_Ts], Coroutine[Any, Any, None]],
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]:
"""Decorate a coroutine function."""
@wraps(func)
async def async_wrapper(*args: Any) -> None:
async def async_wrapper(*args: *_Ts) -> None:
"""Catch and log exception."""
with trace_path(suffix):
await func(*args)