From 1a6e81767d7cb406df997bbbf2fd8a6df45616a6 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:00:24 +0100 Subject: [PATCH] Improve trace helper typing (#105964) --- homeassistant/helpers/trace.py | 43 +++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/homeassistant/helpers/trace.py b/homeassistant/helpers/trace.py index fd7a3081f7a..41be606488a 100644 --- a/homeassistant/helpers/trace.py +++ b/homeassistant/helpers/trace.py @@ -2,17 +2,20 @@ from __future__ import annotations from collections import deque -from collections.abc import Callable, Generator +from collections.abc import Callable, Coroutine, Generator from contextlib import contextmanager from contextvars import ContextVar from functools import wraps -from typing import Any, cast +from typing import Any, TypeVar, TypeVarTuple from homeassistant.core import ServiceResponse import homeassistant.util.dt as dt_util from .typing import TemplateVarsType +_T = TypeVar("_T") +_Ts = TypeVarTuple("_Ts") + class TraceElement: """Container for trace data.""" @@ -125,21 +128,23 @@ def trace_id_get() -> tuple[str, str] | None: return trace_id_cv.get() -def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None: +def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None: """Push an element to the top of a trace stack.""" + trace_stack: list[_T] | None if (trace_stack := trace_stack_var.get()) is None: trace_stack = [] trace_stack_var.set(trace_stack) trace_stack.append(node) -def trace_stack_pop(trace_stack_var: ContextVar) -> None: +def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None: """Remove the top element from a trace stack.""" trace_stack = trace_stack_var.get() - trace_stack.pop() + if trace_stack is not None: + trace_stack.pop() -def trace_stack_top(trace_stack_var: ContextVar) -> Any | None: +def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None: """Return the element at the top of a trace stack.""" trace_stack = trace_stack_var.get() return trace_stack[-1] if trace_stack else None @@ -198,21 +203,20 @@ def trace_clear() -> None: def trace_set_child_id(child_key: str, child_run_id: str) -> None: """Set child trace_id of TraceElement at the top of the stack.""" - node = cast(TraceElement, trace_stack_top(trace_stack_cv)) - if node: + if node := trace_stack_top(trace_stack_cv): node.set_child_id(child_key, child_run_id) def trace_set_result(**kwargs: Any) -> None: """Set the result of TraceElement at the top of the stack.""" - node = cast(TraceElement, trace_stack_top(trace_stack_cv)) - node.set_result(**kwargs) + if node := trace_stack_top(trace_stack_cv): + node.set_result(**kwargs) def trace_update_result(**kwargs: Any) -> None: """Update the result of TraceElement at the top of the stack.""" - node = cast(TraceElement, trace_stack_top(trace_stack_cv)) - node.update_result(**kwargs) + if node := trace_stack_top(trace_stack_cv): + node.update_result(**kwargs) class StopReason: @@ -238,7 +242,7 @@ def script_execution_get() -> str | None: @contextmanager -def trace_path(suffix: str | list[str]) -> Generator: +def trace_path(suffix: str | list[str]) -> Generator[None, None, None]: """Go deeper in the config tree. Can not be used as a decorator on couroutine functions. @@ -250,17 +254,24 @@ def trace_path(suffix: str | list[str]) -> Generator: trace_path_pop(count) -def async_trace_path(suffix: str | list[str]) -> Callable: +def async_trace_path( + suffix: str | list[str], +) -> Callable[ + [Callable[[*_Ts], Coroutine[Any, Any, None]]], + Callable[[*_Ts], Coroutine[Any, Any, None]], +]: """Go deeper in the config tree. To be used as a decorator on coroutine functions. """ - def _trace_path_decorator(func: Callable) -> Callable: + def _trace_path_decorator( + func: Callable[[*_Ts], Coroutine[Any, Any, None]], + ) -> Callable[[*_Ts], Coroutine[Any, Any, None]]: """Decorate a coroutine function.""" @wraps(func) - async def async_wrapper(*args: Any) -> None: + async def async_wrapper(*args: *_Ts) -> None: """Catch and log exception.""" with trace_path(suffix): await func(*args)