Use PEP 695 for function annotations (2) (#117659)

This commit is contained in:
Marc Mueller 2024-05-18 11:44:39 +02:00 committed by GitHub
parent 4cf0a3f154
commit 900b6211ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 30 additions and 73 deletions

View file

@ -8,14 +8,10 @@ from enum import Enum
import functools
import inspect
import logging
from typing import Any, NamedTuple, ParamSpec, TypeVar
_ObjectT = TypeVar("_ObjectT", bound=object)
_R = TypeVar("_R")
_P = ParamSpec("_P")
from typing import Any, NamedTuple
def deprecated_substitute(
def deprecated_substitute[_ObjectT: object](
substitute_name: str,
) -> Callable[[Callable[[_ObjectT], Any]], Callable[[_ObjectT], Any]]:
"""Help migrate properties to new names.
@ -92,7 +88,7 @@ def get_deprecated(
return config.get(new_name, default)
def deprecated_class(
def deprecated_class[**_P, _R](
replacement: str, *, breaks_in_ha_version: str | None = None
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Mark class as deprecated and provide a replacement class to be used instead.
@ -117,7 +113,7 @@ def deprecated_class(
return deprecated_decorator
def deprecated_function(
def deprecated_function[**_P, _R](
replacement: str, *, breaks_in_ha_version: str | None = None
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Mark function as deprecated and provide a replacement to be used instead.

View file

@ -16,16 +16,7 @@ from operator import attrgetter
import sys
import time
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
NotRequired,
TypedDict,
TypeVar,
final,
)
from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final
import voluptuous as vol
@ -79,8 +70,6 @@ timer = time.time
if TYPE_CHECKING:
from .entity_platform import EntityPlatform
_T = TypeVar("_T")
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
DATA_ENTITY_SOURCE = "entity_info"
@ -1603,7 +1592,7 @@ class Entity(
return f"<entity unknown.unknown={STATE_UNKNOWN}>"
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"
async def async_request_call(self, coro: Coroutine[Any, Any, _T]) -> _T:
async def async_request_call[_T](self, coro: Coroutine[Any, Any, _T]) -> _T:
"""Process request batched."""
if self.parallel_updates:
await self.parallel_updates.acquire()

View file

@ -16,7 +16,7 @@ from enum import StrEnum
from functools import cached_property
import logging
import time
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict
import attr
import voluptuous as vol
@ -65,8 +65,6 @@ from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry
T = TypeVar("T")
DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry")
EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType(
"entity_registry_updated"
@ -852,7 +850,7 @@ class EntityRegistry(BaseRegistry):
):
disabled_by = RegistryEntryDisabler.INTEGRATION
def none_if_undefined(value: T | UndefinedType) -> T | None:
def none_if_undefined[_T](value: _T | UndefinedType) -> _T | None:
"""Return None if value is UNDEFINED, otherwise return value."""
return None if value is UNDEFINED else value

View file

@ -12,7 +12,7 @@ import linecache
import logging
import sys
from types import FrameType
from typing import Any, TypeVar, cast
from typing import Any, cast
from homeassistant.core import HomeAssistant, async_get_hass
from homeassistant.exceptions import HomeAssistantError
@ -23,8 +23,6 @@ _LOGGER = logging.getLogger(__name__)
# Keep track of integrations already reported to prevent flooding
_REPORTED_INTEGRATIONS: set[str] = set()
_CallableT = TypeVar("_CallableT", bound=Callable)
@dataclass(kw_only=True)
class IntegrationFrame:
@ -209,7 +207,7 @@ def _report_integration(
)
def warn_use(func: _CallableT, what: str) -> _CallableT:
def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
"""Mock a function to warn when it was about to be used."""
if asyncio.iscoroutinefunction(func):

View file

@ -6,12 +6,9 @@ import asyncio
from collections.abc import Callable, Hashable
import logging
import time
from typing import TypeVarTuple
from homeassistant.core import HomeAssistant, callback
_Ts = TypeVarTuple("_Ts")
_LOGGER = logging.getLogger(__name__)
@ -52,7 +49,7 @@ class KeyedRateLimit:
self._rate_limit_timers.clear()
@callback
def async_schedule_action(
def async_schedule_action[*_Ts](
self,
key: Hashable,
rate_limit: float | None,

View file

@ -3,15 +3,12 @@
from __future__ import annotations
from collections.abc import Callable, Iterable, Mapping
from typing import Any, TypeVar, cast, overload
from typing import Any, cast, overload
from homeassistant.core import callback
REDACTED = "**REDACTED**"
_T = TypeVar("_T")
_ValueT = TypeVar("_ValueT")
def partial_redact(
x: str | Any, unmasked_prefix: int = 4, unmasked_suffix: int = 4
@ -32,19 +29,19 @@ def partial_redact(
@overload
def async_redact_data( # type: ignore[overload-overlap]
def async_redact_data[_ValueT]( # type: ignore[overload-overlap]
data: Mapping, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> dict: ...
@overload
def async_redact_data(
def async_redact_data[_T, _ValueT](
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> _T: ...
@callback
def async_redact_data(
def async_redact_data[_T, _ValueT](
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> _T:
"""Redact sensitive data in a dict."""

View file

@ -13,7 +13,7 @@ from functools import cached_property, partial
import itertools
import logging
from types import MappingProxyType
from typing import Any, Literal, TypedDict, TypeVar, cast
from typing import Any, Literal, TypedDict, cast
import async_interrupt
import voluptuous as vol
@ -111,8 +111,6 @@ from .typing import UNDEFINED, ConfigType, UndefinedType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_T = TypeVar("_T")
SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued"
SCRIPT_MODE_RESTART = "restart"
@ -713,7 +711,9 @@ class _ScriptRun:
else:
wait_var["remaining"] = None
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
async def _async_run_long_action[_T](
self, long_task: asyncio.Task[_T]
) -> _T | None:
"""Run a long task while monitoring for stop request."""
try:
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):

View file

@ -68,9 +68,6 @@ from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from .entity import Entity
_EntityT = TypeVar("_EntityT", bound=Entity)
CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__)
@ -434,7 +431,7 @@ def extract_entity_ids(
@bind_hass
async def async_extract_entities(
async def async_extract_entities[_EntityT: Entity](
hass: HomeAssistant,
entities: Iterable[_EntityT],
service_call: ServiceCall,

View file

@ -22,17 +22,7 @@ import statistics
from struct import error as StructError, pack, unpack_from
import sys
from types import CodeType, TracebackType
from typing import (
Any,
Concatenate,
Literal,
NoReturn,
ParamSpec,
Self,
TypeVar,
cast,
overload,
)
from typing import Any, Concatenate, Literal, NoReturn, Self, cast, overload
from urllib.parse import urlencode as urllib_urlencode
import weakref
@ -134,10 +124,6 @@ _COLLECTABLE_STATE_ATTRIBUTES = {
"name",
}
_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
ALL_STATES_RATE_LIMIT = 60 # seconds
DOMAIN_STATES_RATE_LIMIT = 1 # seconds
@ -1217,10 +1203,10 @@ def forgiving_boolean(value: Any) -> bool | object: ...
@overload
def forgiving_boolean(value: Any, default: _T) -> bool | _T: ...
def forgiving_boolean[_T](value: Any, default: _T) -> bool | _T: ...
def forgiving_boolean(
def forgiving_boolean[_T](
value: Any, default: _T | object = _SENTINEL
) -> bool | _T | object:
"""Try to convert value to a boolean."""
@ -2840,7 +2826,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# evaluated fresh with every execution, rather than executed
# at compile time and the value stored. The context itself
# can be discarded, we only need to get at the hass object.
def hassfunction(
def hassfunction[**_P, _R](
func: Callable[Concatenate[HomeAssistant, _P], _R],
jinja_context: Callable[
[Callable[Concatenate[Any, _P], _R]],

View file

@ -7,16 +7,13 @@ from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from functools import wraps
from typing import Any, TypeVar, TypeVarTuple
from typing import Any
from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
class TraceElement:
"""Container for trace data."""
@ -135,7 +132,9 @@ def trace_id_get() -> tuple[str, str] | None:
return trace_id_cv.get()
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
def trace_stack_push[_T](
trace_stack_var: ContextVar[list[_T] | None], node: _T
) -> None:
"""Push an element to the top of a trace stack."""
trace_stack: list[_T] | None
if (trace_stack := trace_stack_var.get()) is None:
@ -151,7 +150,7 @@ def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
def trace_stack_top[_T](trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
@ -261,7 +260,7 @@ def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
trace_path_pop(count)
def async_trace_path(
def async_trace_path[*_Ts](
suffix: str | list[str],
) -> Callable[
[Callable[[*_Ts], Coroutine[Any, Any, None]]],