Use PEP 695 for function annotations (2) (#117659)

This commit is contained in:
Marc Mueller 2024-05-18 11:44:39 +02:00 committed by GitHub
parent 4cf0a3f154
commit 900b6211ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 30 additions and 73 deletions

View file

@ -8,14 +8,10 @@ from enum import Enum
import functools import functools
import inspect import inspect
import logging import logging
from typing import Any, NamedTuple, ParamSpec, TypeVar from typing import Any, NamedTuple
_ObjectT = TypeVar("_ObjectT", bound=object)
_R = TypeVar("_R")
_P = ParamSpec("_P")
def deprecated_substitute( def deprecated_substitute[_ObjectT: object](
substitute_name: str, substitute_name: str,
) -> Callable[[Callable[[_ObjectT], Any]], Callable[[_ObjectT], Any]]: ) -> Callable[[Callable[[_ObjectT], Any]], Callable[[_ObjectT], Any]]:
"""Help migrate properties to new names. """Help migrate properties to new names.
@ -92,7 +88,7 @@ def get_deprecated(
return config.get(new_name, default) return config.get(new_name, default)
def deprecated_class( def deprecated_class[**_P, _R](
replacement: str, *, breaks_in_ha_version: str | None = None replacement: str, *, breaks_in_ha_version: str | None = None
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Mark class as deprecated and provide a replacement class to be used instead. """Mark class as deprecated and provide a replacement class to be used instead.
@ -117,7 +113,7 @@ def deprecated_class(
return deprecated_decorator return deprecated_decorator
def deprecated_function( def deprecated_function[**_P, _R](
replacement: str, *, breaks_in_ha_version: str | None = None replacement: str, *, breaks_in_ha_version: str | None = None
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Mark function as deprecated and provide a replacement to be used instead. """Mark function as deprecated and provide a replacement to be used instead.

View file

@ -16,16 +16,7 @@ from operator import attrgetter
import sys import sys
import time import time
from types import FunctionType from types import FunctionType
from typing import ( from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final
TYPE_CHECKING,
Any,
Final,
Literal,
NotRequired,
TypedDict,
TypeVar,
final,
)
import voluptuous as vol import voluptuous as vol
@ -79,8 +70,6 @@ timer = time.time
if TYPE_CHECKING: if TYPE_CHECKING:
from .entity_platform import EntityPlatform from .entity_platform import EntityPlatform
_T = TypeVar("_T")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10 SLOW_UPDATE_WARNING = 10
DATA_ENTITY_SOURCE = "entity_info" DATA_ENTITY_SOURCE = "entity_info"
@ -1603,7 +1592,7 @@ class Entity(
return f"<entity unknown.unknown={STATE_UNKNOWN}>" return f"<entity unknown.unknown={STATE_UNKNOWN}>"
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>" return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"
async def async_request_call(self, coro: Coroutine[Any, Any, _T]) -> _T: async def async_request_call[_T](self, coro: Coroutine[Any, Any, _T]) -> _T:
"""Process request batched.""" """Process request batched."""
if self.parallel_updates: if self.parallel_updates:
await self.parallel_updates.acquire() await self.parallel_updates.acquire()

View file

@ -16,7 +16,7 @@ from enum import StrEnum
from functools import cached_property from functools import cached_property
import logging import logging
import time import time
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -65,8 +65,6 @@ from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
T = TypeVar("T")
DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry") DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry")
EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType( EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType(
"entity_registry_updated" "entity_registry_updated"
@ -852,7 +850,7 @@ class EntityRegistry(BaseRegistry):
): ):
disabled_by = RegistryEntryDisabler.INTEGRATION disabled_by = RegistryEntryDisabler.INTEGRATION
def none_if_undefined(value: T | UndefinedType) -> T | None: def none_if_undefined[_T](value: _T | UndefinedType) -> _T | None:
"""Return None if value is UNDEFINED, otherwise return value.""" """Return None if value is UNDEFINED, otherwise return value."""
return None if value is UNDEFINED else value return None if value is UNDEFINED else value

View file

@ -12,7 +12,7 @@ import linecache
import logging import logging
import sys import sys
from types import FrameType from types import FrameType
from typing import Any, TypeVar, cast from typing import Any, cast
from homeassistant.core import HomeAssistant, async_get_hass from homeassistant.core import HomeAssistant, async_get_hass
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -23,8 +23,6 @@ _LOGGER = logging.getLogger(__name__)
# Keep track of integrations already reported to prevent flooding # Keep track of integrations already reported to prevent flooding
_REPORTED_INTEGRATIONS: set[str] = set() _REPORTED_INTEGRATIONS: set[str] = set()
_CallableT = TypeVar("_CallableT", bound=Callable)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class IntegrationFrame: class IntegrationFrame:
@ -209,7 +207,7 @@ def _report_integration(
) )
def warn_use(func: _CallableT, what: str) -> _CallableT: def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
"""Mock a function to warn when it was about to be used.""" """Mock a function to warn when it was about to be used."""
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):

View file

@ -6,12 +6,9 @@ import asyncio
from collections.abc import Callable, Hashable from collections.abc import Callable, Hashable
import logging import logging
import time import time
from typing import TypeVarTuple
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
_Ts = TypeVarTuple("_Ts")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -52,7 +49,7 @@ class KeyedRateLimit:
self._rate_limit_timers.clear() self._rate_limit_timers.clear()
@callback @callback
def async_schedule_action( def async_schedule_action[*_Ts](
self, self,
key: Hashable, key: Hashable,
rate_limit: float | None, rate_limit: float | None,

View file

@ -3,15 +3,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from typing import Any, TypeVar, cast, overload from typing import Any, cast, overload
from homeassistant.core import callback from homeassistant.core import callback
REDACTED = "**REDACTED**" REDACTED = "**REDACTED**"
_T = TypeVar("_T")
_ValueT = TypeVar("_ValueT")
def partial_redact( def partial_redact(
x: str | Any, unmasked_prefix: int = 4, unmasked_suffix: int = 4 x: str | Any, unmasked_prefix: int = 4, unmasked_suffix: int = 4
@ -32,19 +29,19 @@ def partial_redact(
@overload @overload
def async_redact_data( # type: ignore[overload-overlap] def async_redact_data[_ValueT]( # type: ignore[overload-overlap]
data: Mapping, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]] data: Mapping, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> dict: ... ) -> dict: ...
@overload @overload
def async_redact_data( def async_redact_data[_T, _ValueT](
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]] data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> _T: ... ) -> _T: ...
@callback @callback
def async_redact_data( def async_redact_data[_T, _ValueT](
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]] data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
) -> _T: ) -> _T:
"""Redact sensitive data in a dict.""" """Redact sensitive data in a dict."""

View file

@ -13,7 +13,7 @@ from functools import cached_property, partial
import itertools import itertools
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Literal, TypedDict, TypeVar, cast from typing import Any, Literal, TypedDict, cast
import async_interrupt import async_interrupt
import voluptuous as vol import voluptuous as vol
@ -111,8 +111,6 @@ from .typing import UNDEFINED, ConfigType, UndefinedType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_T = TypeVar("_T")
SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued" SCRIPT_MODE_QUEUED = "queued"
SCRIPT_MODE_RESTART = "restart" SCRIPT_MODE_RESTART = "restart"
@ -713,7 +711,9 @@ class _ScriptRun:
else: else:
wait_var["remaining"] = None wait_var["remaining"] = None
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None: async def _async_run_long_action[_T](
self, long_task: asyncio.Task[_T]
) -> _T | None:
"""Run a long task while monitoring for stop request.""" """Run a long task while monitoring for stop request."""
try: try:
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None): async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):

View file

@ -68,9 +68,6 @@ from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING: if TYPE_CHECKING:
from .entity import Entity from .entity import Entity
_EntityT = TypeVar("_EntityT", bound=Entity)
CONF_SERVICE_ENTITY_ID = "entity_id" CONF_SERVICE_ENTITY_ID = "entity_id"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -434,7 +431,7 @@ def extract_entity_ids(
@bind_hass @bind_hass
async def async_extract_entities( async def async_extract_entities[_EntityT: Entity](
hass: HomeAssistant, hass: HomeAssistant,
entities: Iterable[_EntityT], entities: Iterable[_EntityT],
service_call: ServiceCall, service_call: ServiceCall,

View file

@ -22,17 +22,7 @@ import statistics
from struct import error as StructError, pack, unpack_from from struct import error as StructError, pack, unpack_from
import sys import sys
from types import CodeType, TracebackType from types import CodeType, TracebackType
from typing import ( from typing import Any, Concatenate, Literal, NoReturn, Self, cast, overload
Any,
Concatenate,
Literal,
NoReturn,
ParamSpec,
Self,
TypeVar,
cast,
overload,
)
from urllib.parse import urlencode as urllib_urlencode from urllib.parse import urlencode as urllib_urlencode
import weakref import weakref
@ -134,10 +124,6 @@ _COLLECTABLE_STATE_ATTRIBUTES = {
"name", "name",
} }
_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
ALL_STATES_RATE_LIMIT = 60 # seconds ALL_STATES_RATE_LIMIT = 60 # seconds
DOMAIN_STATES_RATE_LIMIT = 1 # seconds DOMAIN_STATES_RATE_LIMIT = 1 # seconds
@ -1217,10 +1203,10 @@ def forgiving_boolean(value: Any) -> bool | object: ...
@overload @overload
def forgiving_boolean(value: Any, default: _T) -> bool | _T: ... def forgiving_boolean[_T](value: Any, default: _T) -> bool | _T: ...
def forgiving_boolean( def forgiving_boolean[_T](
value: Any, default: _T | object = _SENTINEL value: Any, default: _T | object = _SENTINEL
) -> bool | _T | object: ) -> bool | _T | object:
"""Try to convert value to a boolean.""" """Try to convert value to a boolean."""
@ -2840,7 +2826,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# evaluated fresh with every execution, rather than executed # evaluated fresh with every execution, rather than executed
# at compile time and the value stored. The context itself # at compile time and the value stored. The context itself
# can be discarded, we only need to get at the hass object. # can be discarded, we only need to get at the hass object.
def hassfunction( def hassfunction[**_P, _R](
func: Callable[Concatenate[HomeAssistant, _P], _R], func: Callable[Concatenate[HomeAssistant, _P], _R],
jinja_context: Callable[ jinja_context: Callable[
[Callable[Concatenate[Any, _P], _R]], [Callable[Concatenate[Any, _P], _R]],

View file

@ -7,16 +7,13 @@ from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from functools import wraps from functools import wraps
from typing import Any, TypeVar, TypeVarTuple from typing import Any
from homeassistant.core import ServiceResponse from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType from .typing import TemplateVarsType
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
class TraceElement: class TraceElement:
"""Container for trace data.""" """Container for trace data."""
@ -135,7 +132,9 @@ def trace_id_get() -> tuple[str, str] | None:
return trace_id_cv.get() return trace_id_cv.get()
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None: def trace_stack_push[_T](
trace_stack_var: ContextVar[list[_T] | None], node: _T
) -> None:
"""Push an element to the top of a trace stack.""" """Push an element to the top of a trace stack."""
trace_stack: list[_T] | None trace_stack: list[_T] | None
if (trace_stack := trace_stack_var.get()) is None: if (trace_stack := trace_stack_var.get()) is None:
@ -151,7 +150,7 @@ def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
trace_stack.pop() trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None: def trace_stack_top[_T](trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
"""Return the element at the top of a trace stack.""" """Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get() trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None return trace_stack[-1] if trace_stack else None
@ -261,7 +260,7 @@ def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
trace_path_pop(count) trace_path_pop(count)
def async_trace_path( def async_trace_path[*_Ts](
suffix: str | list[str], suffix: str | list[str],
) -> Callable[ ) -> Callable[
[Callable[[*_Ts], Coroutine[Any, Any, None]]], [Callable[[*_Ts], Coroutine[Any, Any, None]]],