Make TypeVars private (2) (#68206)
This commit is contained in:
parent
be7ef6115c
commit
eae0c75620
5 changed files with 31 additions and 29 deletions
|
@ -4,13 +4,15 @@ from __future__ import annotations
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
T = TypeVar("T", bound="StrEnum")
|
_StrEnumT = TypeVar("_StrEnumT", bound="StrEnum")
|
||||||
|
|
||||||
|
|
||||||
class StrEnum(str, Enum):
|
class StrEnum(str, Enum):
|
||||||
"""Partial backport of Python 3.11's StrEnum for our basic use cases."""
|
"""Partial backport of Python 3.11's StrEnum for our basic use cases."""
|
||||||
|
|
||||||
def __new__(cls: type[T], value: str, *args: Any, **kwargs: Any) -> T:
|
def __new__(
|
||||||
|
cls: type[_StrEnumT], value: str, *args: Any, **kwargs: Any
|
||||||
|
) -> _StrEnumT:
|
||||||
"""Create a new StrEnum instance."""
|
"""Create a new StrEnum instance."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError(f"{value!r} is not a string")
|
raise TypeError(f"{value!r} is not a string")
|
||||||
|
|
|
@ -102,7 +102,7 @@ sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE))
|
||||||
port = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535))
|
port = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535))
|
||||||
|
|
||||||
# typing typevar
|
# typing typevar
|
||||||
T = TypeVar("T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
def path(value: Any) -> str:
|
def path(value: Any) -> str:
|
||||||
|
@ -253,20 +253,20 @@ def ensure_list(value: None) -> list[Any]:
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def ensure_list(value: list[T]) -> list[T]:
|
def ensure_list(value: list[_T]) -> list[_T]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def ensure_list(value: list[T] | T) -> list[T]:
|
def ensure_list(value: list[_T] | _T) -> list[_T]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
def ensure_list(value: T | None) -> list[T] | list[Any]:
|
def ensure_list(value: _T | None) -> list[_T] | list[Any]:
|
||||||
"""Wrap value in list if it is not one."""
|
"""Wrap value in list if it is not one."""
|
||||||
if value is None:
|
if value is None:
|
||||||
return []
|
return []
|
||||||
return cast("list[T]", value) if isinstance(value, list) else [value]
|
return cast("list[_T]", value) if isinstance(value, list) else [value]
|
||||||
|
|
||||||
|
|
||||||
def entity_id(value: Any) -> str:
|
def entity_id(value: Any) -> str:
|
||||||
|
@ -467,7 +467,7 @@ def time_period_seconds(value: float | str) -> timedelta:
|
||||||
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
|
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
|
||||||
|
|
||||||
|
|
||||||
def match_all(value: T) -> T:
|
def match_all(value: _T) -> _T:
|
||||||
"""Validate that matches all values."""
|
"""Validate that matches all values."""
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -483,7 +483,7 @@ positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
|
||||||
positive_time_period = vol.All(time_period, positive_timedelta)
|
positive_time_period = vol.All(time_period, positive_timedelta)
|
||||||
|
|
||||||
|
|
||||||
def remove_falsy(value: list[T]) -> list[T]:
|
def remove_falsy(value: list[_T]) -> list[_T]:
|
||||||
"""Remove falsy values from a list."""
|
"""Remove falsy values from a list."""
|
||||||
return [v for v in value if v]
|
return [v for v in value if v]
|
||||||
|
|
||||||
|
@ -510,7 +510,7 @@ def slug(value: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
def schema_with_slug_keys(
|
def schema_with_slug_keys(
|
||||||
value_schema: T | Callable, *, slug_validator: Callable[[Any], str] = slug
|
value_schema: _T | Callable, *, slug_validator: Callable[[Any], str] = slug
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Ensure dicts have slugs as keys.
|
"""Ensure dicts have slugs as keys.
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
# Keep track of integrations already reported to prevent flooding
|
# Keep track of integrations already reported to prevent flooding
|
||||||
_REPORTED_INTEGRATIONS: set[str] = set()
|
_REPORTED_INTEGRATIONS: set[str] = set()
|
||||||
|
|
||||||
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
|
_CallableT = TypeVar("_CallableT", bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
def get_integration_frame(
|
def get_integration_frame(
|
||||||
|
@ -113,7 +113,7 @@ def report_integration(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def warn_use(func: CALLABLE_T, what: str) -> CALLABLE_T:
|
def warn_use(func: _CallableT, what: str) -> _CallableT:
|
||||||
"""Mock a function to warn when it was about to be used."""
|
"""Mock a function to warn when it was about to be used."""
|
||||||
if asyncio.iscoroutinefunction(func):
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
|
||||||
|
@ -127,4 +127,4 @@ def warn_use(func: CALLABLE_T, what: str) -> CALLABLE_T:
|
||||||
def report_use(*args: Any, **kwargs: Any) -> None:
|
def report_use(*args: Any, **kwargs: Any) -> None:
|
||||||
report(what)
|
report(what)
|
||||||
|
|
||||||
return cast(CALLABLE_T, report_use)
|
return cast(_CallableT, report_use)
|
||||||
|
|
|
@ -9,9 +9,9 @@ from typing import TypeVar, cast
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
T = TypeVar("T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
FUNC = Callable[[HomeAssistant], T]
|
FUNC = Callable[[HomeAssistant], _T]
|
||||||
|
|
||||||
|
|
||||||
def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
|
def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
|
||||||
|
@ -26,30 +26,30 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapped(hass: HomeAssistant) -> T:
|
def wrapped(hass: HomeAssistant) -> _T:
|
||||||
if data_key not in hass.data:
|
if data_key not in hass.data:
|
||||||
hass.data[data_key] = func(hass)
|
hass.data[data_key] = func(hass)
|
||||||
return cast(T, hass.data[data_key])
|
return cast(_T, hass.data[data_key])
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def async_wrapped(hass: HomeAssistant) -> T:
|
async def async_wrapped(hass: HomeAssistant) -> _T:
|
||||||
if data_key not in hass.data:
|
if data_key not in hass.data:
|
||||||
evt = hass.data[data_key] = asyncio.Event()
|
evt = hass.data[data_key] = asyncio.Event()
|
||||||
result = await func(hass)
|
result = await func(hass)
|
||||||
hass.data[data_key] = result
|
hass.data[data_key] = result
|
||||||
evt.set()
|
evt.set()
|
||||||
return cast(T, result)
|
return cast(_T, result)
|
||||||
|
|
||||||
obj_or_evt = hass.data[data_key]
|
obj_or_evt = hass.data[data_key]
|
||||||
|
|
||||||
if isinstance(obj_or_evt, asyncio.Event):
|
if isinstance(obj_or_evt, asyncio.Event):
|
||||||
await obj_or_evt.wait()
|
await obj_or_evt.wait()
|
||||||
return cast(T, hass.data[data_key])
|
return cast(_T, hass.data[data_key])
|
||||||
|
|
||||||
return cast(T, obj_or_evt)
|
return cast(_T, obj_or_evt)
|
||||||
|
|
||||||
return async_wrapped
|
return async_wrapped
|
||||||
|
|
||||||
|
|
|
@ -23,14 +23,14 @@ from .debounce import Debouncer
|
||||||
REQUEST_REFRESH_DEFAULT_COOLDOWN = 10
|
REQUEST_REFRESH_DEFAULT_COOLDOWN = 10
|
||||||
REQUEST_REFRESH_DEFAULT_IMMEDIATE = True
|
REQUEST_REFRESH_DEFAULT_IMMEDIATE = True
|
||||||
|
|
||||||
T = TypeVar("T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
class UpdateFailed(Exception):
|
class UpdateFailed(Exception):
|
||||||
"""Raised when an update has failed."""
|
"""Raised when an update has failed."""
|
||||||
|
|
||||||
|
|
||||||
class DataUpdateCoordinator(Generic[T]):
|
class DataUpdateCoordinator(Generic[_T]):
|
||||||
"""Class to manage fetching data from single endpoint."""
|
"""Class to manage fetching data from single endpoint."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -40,7 +40,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||||
*,
|
*,
|
||||||
name: str,
|
name: str,
|
||||||
update_interval: timedelta | None = None,
|
update_interval: timedelta | None = None,
|
||||||
update_method: Callable[[], Awaitable[T]] | None = None,
|
update_method: Callable[[], Awaitable[_T]] | None = None,
|
||||||
request_refresh_debouncer: Debouncer | None = None,
|
request_refresh_debouncer: Debouncer | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize global data updater."""
|
"""Initialize global data updater."""
|
||||||
|
@ -56,7 +56,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||||
# to make sure the first update was successful.
|
# to make sure the first update was successful.
|
||||||
# Set type to just T to remove annoying checks that data is not None
|
# Set type to just T to remove annoying checks that data is not None
|
||||||
# when it was already checked during setup.
|
# when it was already checked during setup.
|
||||||
self.data: T = None # type: ignore[assignment]
|
self.data: _T = None # type: ignore[assignment]
|
||||||
|
|
||||||
self._listeners: list[CALLBACK_TYPE] = []
|
self._listeners: list[CALLBACK_TYPE] = []
|
||||||
self._job = HassJob(self._handle_refresh_interval)
|
self._job = HassJob(self._handle_refresh_interval)
|
||||||
|
@ -140,7 +140,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||||
"""
|
"""
|
||||||
await self._debounced_refresh.async_call()
|
await self._debounced_refresh.async_call()
|
||||||
|
|
||||||
async def _async_update_data(self) -> T:
|
async def _async_update_data(self) -> _T:
|
||||||
"""Fetch the latest data from the source."""
|
"""Fetch the latest data from the source."""
|
||||||
if self.update_method is None:
|
if self.update_method is None:
|
||||||
raise NotImplementedError("Update method not implemented")
|
raise NotImplementedError("Update method not implemented")
|
||||||
|
@ -265,7 +265,7 @@ class DataUpdateCoordinator(Generic[T]):
|
||||||
update_callback()
|
update_callback()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_updated_data(self, data: T) -> None:
|
def async_set_updated_data(self, data: _T) -> None:
|
||||||
"""Manually update data, notify listeners and reset refresh interval."""
|
"""Manually update data, notify listeners and reset refresh interval."""
|
||||||
if self._unsub_refresh:
|
if self._unsub_refresh:
|
||||||
self._unsub_refresh()
|
self._unsub_refresh()
|
||||||
|
@ -295,10 +295,10 @@ class DataUpdateCoordinator(Generic[T]):
|
||||||
self._unsub_refresh = None
|
self._unsub_refresh = None
|
||||||
|
|
||||||
|
|
||||||
class CoordinatorEntity(Generic[T], entity.Entity):
|
class CoordinatorEntity(Generic[_T], entity.Entity):
|
||||||
"""A class for entities using DataUpdateCoordinator."""
|
"""A class for entities using DataUpdateCoordinator."""
|
||||||
|
|
||||||
def __init__(self, coordinator: DataUpdateCoordinator[T]) -> None:
|
def __init__(self, coordinator: DataUpdateCoordinator[_T]) -> None:
|
||||||
"""Create the entity with a DataUpdateCoordinator."""
|
"""Create the entity with a DataUpdateCoordinator."""
|
||||||
self.coordinator = coordinator
|
self.coordinator = coordinator
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue