Use PEP 695 for function annotations (2) (#117659)
This commit is contained in:
parent
4cf0a3f154
commit
900b6211ef
10 changed files with 30 additions and 73 deletions
|
@ -8,14 +8,10 @@ from enum import Enum
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, NamedTuple, ParamSpec, TypeVar
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
_ObjectT = TypeVar("_ObjectT", bound=object)
|
|
||||||
_R = TypeVar("_R")
|
|
||||||
_P = ParamSpec("_P")
|
|
||||||
|
|
||||||
|
|
||||||
def deprecated_substitute(
|
def deprecated_substitute[_ObjectT: object](
|
||||||
substitute_name: str,
|
substitute_name: str,
|
||||||
) -> Callable[[Callable[[_ObjectT], Any]], Callable[[_ObjectT], Any]]:
|
) -> Callable[[Callable[[_ObjectT], Any]], Callable[[_ObjectT], Any]]:
|
||||||
"""Help migrate properties to new names.
|
"""Help migrate properties to new names.
|
||||||
|
@ -92,7 +88,7 @@ def get_deprecated(
|
||||||
return config.get(new_name, default)
|
return config.get(new_name, default)
|
||||||
|
|
||||||
|
|
||||||
def deprecated_class(
|
def deprecated_class[**_P, _R](
|
||||||
replacement: str, *, breaks_in_ha_version: str | None = None
|
replacement: str, *, breaks_in_ha_version: str | None = None
|
||||||
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||||
"""Mark class as deprecated and provide a replacement class to be used instead.
|
"""Mark class as deprecated and provide a replacement class to be used instead.
|
||||||
|
@ -117,7 +113,7 @@ def deprecated_class(
|
||||||
return deprecated_decorator
|
return deprecated_decorator
|
||||||
|
|
||||||
|
|
||||||
def deprecated_function(
|
def deprecated_function[**_P, _R](
|
||||||
replacement: str, *, breaks_in_ha_version: str | None = None
|
replacement: str, *, breaks_in_ha_version: str | None = None
|
||||||
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||||
"""Mark function as deprecated and provide a replacement to be used instead.
|
"""Mark function as deprecated and provide a replacement to be used instead.
|
||||||
|
|
|
@ -16,16 +16,7 @@ from operator import attrgetter
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, Final, Literal, NotRequired, TypedDict, final
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Final,
|
|
||||||
Literal,
|
|
||||||
NotRequired,
|
|
||||||
TypedDict,
|
|
||||||
TypeVar,
|
|
||||||
final,
|
|
||||||
)
|
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -79,8 +70,6 @@ timer = time.time
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .entity_platform import EntityPlatform
|
from .entity_platform import EntityPlatform
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
SLOW_UPDATE_WARNING = 10
|
SLOW_UPDATE_WARNING = 10
|
||||||
DATA_ENTITY_SOURCE = "entity_info"
|
DATA_ENTITY_SOURCE = "entity_info"
|
||||||
|
@ -1603,7 +1592,7 @@ class Entity(
|
||||||
return f"<entity unknown.unknown={STATE_UNKNOWN}>"
|
return f"<entity unknown.unknown={STATE_UNKNOWN}>"
|
||||||
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"
|
return f"<entity {self.entity_id}={self._stringify_state(self.available)}>"
|
||||||
|
|
||||||
async def async_request_call(self, coro: Coroutine[Any, Any, _T]) -> _T:
|
async def async_request_call[_T](self, coro: Coroutine[Any, Any, _T]) -> _T:
|
||||||
"""Process request batched."""
|
"""Process request batched."""
|
||||||
if self.parallel_updates:
|
if self.parallel_updates:
|
||||||
await self.parallel_updates.acquire()
|
await self.parallel_updates.acquire()
|
||||||
|
|
|
@ -16,7 +16,7 @@ from enum import StrEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar
|
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -65,8 +65,6 @@ from .typing import UNDEFINED, UndefinedType
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry")
|
DATA_REGISTRY: HassKey[EntityRegistry] = HassKey("entity_registry")
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType(
|
EVENT_ENTITY_REGISTRY_UPDATED: EventType[EventEntityRegistryUpdatedData] = EventType(
|
||||||
"entity_registry_updated"
|
"entity_registry_updated"
|
||||||
|
@ -852,7 +850,7 @@ class EntityRegistry(BaseRegistry):
|
||||||
):
|
):
|
||||||
disabled_by = RegistryEntryDisabler.INTEGRATION
|
disabled_by = RegistryEntryDisabler.INTEGRATION
|
||||||
|
|
||||||
def none_if_undefined(value: T | UndefinedType) -> T | None:
|
def none_if_undefined[_T](value: _T | UndefinedType) -> _T | None:
|
||||||
"""Return None if value is UNDEFINED, otherwise return value."""
|
"""Return None if value is UNDEFINED, otherwise return value."""
|
||||||
return None if value is UNDEFINED else value
|
return None if value is UNDEFINED else value
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import linecache
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from types import FrameType
|
from types import FrameType
|
||||||
from typing import Any, TypeVar, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, async_get_hass
|
from homeassistant.core import HomeAssistant, async_get_hass
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
@ -23,8 +23,6 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
# Keep track of integrations already reported to prevent flooding
|
# Keep track of integrations already reported to prevent flooding
|
||||||
_REPORTED_INTEGRATIONS: set[str] = set()
|
_REPORTED_INTEGRATIONS: set[str] = set()
|
||||||
|
|
||||||
_CallableT = TypeVar("_CallableT", bound=Callable)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class IntegrationFrame:
|
class IntegrationFrame:
|
||||||
|
@ -209,7 +207,7 @@ def _report_integration(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def warn_use(func: _CallableT, what: str) -> _CallableT:
|
def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
|
||||||
"""Mock a function to warn when it was about to be used."""
|
"""Mock a function to warn when it was about to be used."""
|
||||||
if asyncio.iscoroutinefunction(func):
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,9 @@ import asyncio
|
||||||
from collections.abc import Callable, Hashable
|
from collections.abc import Callable, Hashable
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TypeVarTuple
|
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
_Ts = TypeVarTuple("_Ts")
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +49,7 @@ class KeyedRateLimit:
|
||||||
self._rate_limit_timers.clear()
|
self._rate_limit_timers.clear()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_schedule_action(
|
def async_schedule_action[*_Ts](
|
||||||
self,
|
self,
|
||||||
key: Hashable,
|
key: Hashable,
|
||||||
rate_limit: float | None,
|
rate_limit: float | None,
|
||||||
|
|
|
@ -3,15 +3,12 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from typing import Any, TypeVar, cast, overload
|
from typing import Any, cast, overload
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
REDACTED = "**REDACTED**"
|
REDACTED = "**REDACTED**"
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_ValueT = TypeVar("_ValueT")
|
|
||||||
|
|
||||||
|
|
||||||
def partial_redact(
|
def partial_redact(
|
||||||
x: str | Any, unmasked_prefix: int = 4, unmasked_suffix: int = 4
|
x: str | Any, unmasked_prefix: int = 4, unmasked_suffix: int = 4
|
||||||
|
@ -32,19 +29,19 @@ def partial_redact(
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def async_redact_data( # type: ignore[overload-overlap]
|
def async_redact_data[_ValueT]( # type: ignore[overload-overlap]
|
||||||
data: Mapping, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
data: Mapping, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
||||||
) -> dict: ...
|
) -> dict: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def async_redact_data(
|
def async_redact_data[_T, _ValueT](
|
||||||
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
||||||
) -> _T: ...
|
) -> _T: ...
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_redact_data(
|
def async_redact_data[_T, _ValueT](
|
||||||
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
data: _T, to_redact: Iterable[Any] | Mapping[Any, Callable[[_ValueT], _ValueT]]
|
||||||
) -> _T:
|
) -> _T:
|
||||||
"""Redact sensitive data in a dict."""
|
"""Redact sensitive data in a dict."""
|
||||||
|
|
|
@ -13,7 +13,7 @@ from functools import cached_property, partial
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, Literal, TypedDict, TypeVar, cast
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
import async_interrupt
|
import async_interrupt
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -111,8 +111,6 @@ from .typing import UNDEFINED, ConfigType, UndefinedType
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
SCRIPT_MODE_PARALLEL = "parallel"
|
SCRIPT_MODE_PARALLEL = "parallel"
|
||||||
SCRIPT_MODE_QUEUED = "queued"
|
SCRIPT_MODE_QUEUED = "queued"
|
||||||
SCRIPT_MODE_RESTART = "restart"
|
SCRIPT_MODE_RESTART = "restart"
|
||||||
|
@ -713,7 +711,9 @@ class _ScriptRun:
|
||||||
else:
|
else:
|
||||||
wait_var["remaining"] = None
|
wait_var["remaining"] = None
|
||||||
|
|
||||||
async def _async_run_long_action(self, long_task: asyncio.Task[_T]) -> _T | None:
|
async def _async_run_long_action[_T](
|
||||||
|
self, long_task: asyncio.Task[_T]
|
||||||
|
) -> _T | None:
|
||||||
"""Run a long task while monitoring for stop request."""
|
"""Run a long task while monitoring for stop request."""
|
||||||
try:
|
try:
|
||||||
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
async with async_interrupt.interrupt(self._stop, ScriptStoppedError, None):
|
||||||
|
|
|
@ -68,9 +68,6 @@ from .typing import ConfigType, TemplateVarsType
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .entity import Entity
|
from .entity import Entity
|
||||||
|
|
||||||
_EntityT = TypeVar("_EntityT", bound=Entity)
|
|
||||||
|
|
||||||
|
|
||||||
CONF_SERVICE_ENTITY_ID = "entity_id"
|
CONF_SERVICE_ENTITY_ID = "entity_id"
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -434,7 +431,7 @@ def extract_entity_ids(
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_extract_entities(
|
async def async_extract_entities[_EntityT: Entity](
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
entities: Iterable[_EntityT],
|
entities: Iterable[_EntityT],
|
||||||
service_call: ServiceCall,
|
service_call: ServiceCall,
|
||||||
|
|
|
@ -22,17 +22,7 @@ import statistics
|
||||||
from struct import error as StructError, pack, unpack_from
|
from struct import error as StructError, pack, unpack_from
|
||||||
import sys
|
import sys
|
||||||
from types import CodeType, TracebackType
|
from types import CodeType, TracebackType
|
||||||
from typing import (
|
from typing import Any, Concatenate, Literal, NoReturn, Self, cast, overload
|
||||||
Any,
|
|
||||||
Concatenate,
|
|
||||||
Literal,
|
|
||||||
NoReturn,
|
|
||||||
ParamSpec,
|
|
||||||
Self,
|
|
||||||
TypeVar,
|
|
||||||
cast,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
from urllib.parse import urlencode as urllib_urlencode
|
from urllib.parse import urlencode as urllib_urlencode
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
@ -134,10 +124,6 @@ _COLLECTABLE_STATE_ATTRIBUTES = {
|
||||||
"name",
|
"name",
|
||||||
}
|
}
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_R = TypeVar("_R")
|
|
||||||
_P = ParamSpec("_P")
|
|
||||||
|
|
||||||
ALL_STATES_RATE_LIMIT = 60 # seconds
|
ALL_STATES_RATE_LIMIT = 60 # seconds
|
||||||
DOMAIN_STATES_RATE_LIMIT = 1 # seconds
|
DOMAIN_STATES_RATE_LIMIT = 1 # seconds
|
||||||
|
|
||||||
|
@ -1217,10 +1203,10 @@ def forgiving_boolean(value: Any) -> bool | object: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def forgiving_boolean(value: Any, default: _T) -> bool | _T: ...
|
def forgiving_boolean[_T](value: Any, default: _T) -> bool | _T: ...
|
||||||
|
|
||||||
|
|
||||||
def forgiving_boolean(
|
def forgiving_boolean[_T](
|
||||||
value: Any, default: _T | object = _SENTINEL
|
value: Any, default: _T | object = _SENTINEL
|
||||||
) -> bool | _T | object:
|
) -> bool | _T | object:
|
||||||
"""Try to convert value to a boolean."""
|
"""Try to convert value to a boolean."""
|
||||||
|
@ -2840,7 +2826,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
|
||||||
# evaluated fresh with every execution, rather than executed
|
# evaluated fresh with every execution, rather than executed
|
||||||
# at compile time and the value stored. The context itself
|
# at compile time and the value stored. The context itself
|
||||||
# can be discarded, we only need to get at the hass object.
|
# can be discarded, we only need to get at the hass object.
|
||||||
def hassfunction(
|
def hassfunction[**_P, _R](
|
||||||
func: Callable[Concatenate[HomeAssistant, _P], _R],
|
func: Callable[Concatenate[HomeAssistant, _P], _R],
|
||||||
jinja_context: Callable[
|
jinja_context: Callable[
|
||||||
[Callable[Concatenate[Any, _P], _R]],
|
[Callable[Concatenate[Any, _P], _R]],
|
||||||
|
|
|
@ -7,16 +7,13 @@ from collections.abc import Callable, Coroutine, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, TypeVar, TypeVarTuple
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.core import ServiceResponse
|
from homeassistant.core import ServiceResponse
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from .typing import TemplateVarsType
|
from .typing import TemplateVarsType
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_Ts = TypeVarTuple("_Ts")
|
|
||||||
|
|
||||||
|
|
||||||
class TraceElement:
|
class TraceElement:
|
||||||
"""Container for trace data."""
|
"""Container for trace data."""
|
||||||
|
@ -135,7 +132,9 @@ def trace_id_get() -> tuple[str, str] | None:
|
||||||
return trace_id_cv.get()
|
return trace_id_cv.get()
|
||||||
|
|
||||||
|
|
||||||
def trace_stack_push(trace_stack_var: ContextVar[list[_T] | None], node: _T) -> None:
|
def trace_stack_push[_T](
|
||||||
|
trace_stack_var: ContextVar[list[_T] | None], node: _T
|
||||||
|
) -> None:
|
||||||
"""Push an element to the top of a trace stack."""
|
"""Push an element to the top of a trace stack."""
|
||||||
trace_stack: list[_T] | None
|
trace_stack: list[_T] | None
|
||||||
if (trace_stack := trace_stack_var.get()) is None:
|
if (trace_stack := trace_stack_var.get()) is None:
|
||||||
|
@ -151,7 +150,7 @@ def trace_stack_pop(trace_stack_var: ContextVar[list[Any] | None]) -> None:
|
||||||
trace_stack.pop()
|
trace_stack.pop()
|
||||||
|
|
||||||
|
|
||||||
def trace_stack_top(trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
|
def trace_stack_top[_T](trace_stack_var: ContextVar[list[_T] | None]) -> _T | None:
|
||||||
"""Return the element at the top of a trace stack."""
|
"""Return the element at the top of a trace stack."""
|
||||||
trace_stack = trace_stack_var.get()
|
trace_stack = trace_stack_var.get()
|
||||||
return trace_stack[-1] if trace_stack else None
|
return trace_stack[-1] if trace_stack else None
|
||||||
|
@ -261,7 +260,7 @@ def trace_path(suffix: str | list[str]) -> Generator[None, None, None]:
|
||||||
trace_path_pop(count)
|
trace_path_pop(count)
|
||||||
|
|
||||||
|
|
||||||
def async_trace_path(
|
def async_trace_path[*_Ts](
|
||||||
suffix: str | list[str],
|
suffix: str | list[str],
|
||||||
) -> Callable[
|
) -> Callable[
|
||||||
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
|
[Callable[[*_Ts], Coroutine[Any, Any, None]]],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue