Improve typing [util.decorator] (#67087)
This commit is contained in:
parent
46c2bd0eb0
commit
ec980a574b
23 changed files with 59 additions and 38 deletions
|
@ -22,6 +22,7 @@ homeassistant.helpers.script_variables
|
|||
homeassistant.helpers.translation
|
||||
homeassistant.util.async_
|
||||
homeassistant.util.color
|
||||
homeassistant.util.decorator
|
||||
homeassistant.util.process
|
||||
homeassistant.util.unit_system
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
MULTI_FACTOR_AUTH_MODULES = Registry()
|
||||
MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
|
||||
|
||||
MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
|
||||
{
|
||||
|
@ -129,7 +129,7 @@ async def auth_mfa_module_from_config(
|
|||
hass: HomeAssistant, config: dict[str, Any]
|
||||
) -> MultiFactorAuthModule:
|
||||
"""Initialize an auth module from a config."""
|
||||
module_name = config[CONF_TYPE]
|
||||
module_name: str = config[CONF_TYPE]
|
||||
module = await _load_mfa_module(hass, module_name)
|
||||
|
||||
try:
|
||||
|
@ -142,7 +142,7 @@ async def auth_mfa_module_from_config(
|
|||
)
|
||||
raise
|
||||
|
||||
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore[no-any-return]
|
||||
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config)
|
||||
|
||||
|
||||
async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType:
|
||||
|
|
|
@ -25,7 +25,7 @@ from ..models import Credentials, RefreshToken, User, UserMeta
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = "auth_prov_reqs_processed"
|
||||
|
||||
AUTH_PROVIDERS = Registry()
|
||||
AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()
|
||||
|
||||
AUTH_PROVIDER_SCHEMA = vol.Schema(
|
||||
{
|
||||
|
@ -136,7 +136,7 @@ async def auth_provider_from_config(
|
|||
hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
|
||||
) -> AuthProvider:
|
||||
"""Initialize an auth provider from a config."""
|
||||
provider_name = config[CONF_TYPE]
|
||||
provider_name: str = config[CONF_TYPE]
|
||||
module = await load_auth_provider_module(hass, provider_name)
|
||||
|
||||
try:
|
||||
|
@ -149,7 +149,7 @@ async def auth_provider_from_config(
|
|||
)
|
||||
raise
|
||||
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore[no-any-return]
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||
|
||||
|
||||
async def load_auth_provider_module(
|
||||
|
|
|
@ -83,7 +83,7 @@ if TYPE_CHECKING:
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_ADAPTERS = Registry()
|
||||
ENTITY_ADAPTERS: Registry[str, type[AlexaEntity]] = Registry()
|
||||
|
||||
TRANSLATION_TABLE = dict.fromkeys(map(ord, r"}{\/|\"()[]+~!><*%"), None)
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ from .errors import (
|
|||
from .state_report import async_enable_proactive_mode
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
HANDLERS = Registry()
|
||||
HANDLERS = Registry() # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@HANDLERS.register(("Alexa.Discovery", "Discover"))
|
||||
|
|
|
@ -12,7 +12,7 @@ from .const import DOMAIN, SYN_RESOLUTION_MATCH
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
HANDLERS = Registry()
|
||||
HANDLERS = Registry() # type: ignore[var-annotated]
|
||||
|
||||
INTENTS_API_ENDPOINT = "/api/alexa"
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ FILTER_NAME_OUTLIER = "outlier"
|
|||
FILTER_NAME_THROTTLE = "throttle"
|
||||
FILTER_NAME_TIME_THROTTLE = "time_throttle"
|
||||
FILTER_NAME_TIME_SMA = "time_simple_moving_average"
|
||||
FILTERS = Registry()
|
||||
FILTERS: Registry[str, type[Filter]] = Registry()
|
||||
|
||||
CONF_FILTERS = "filters"
|
||||
CONF_FILTER_NAME = "filter"
|
||||
|
|
|
@ -19,7 +19,7 @@ from .helpers import GoogleEntity, RequestData, async_get_entities
|
|||
|
||||
EXECUTE_LIMIT = 2 # Wait 2 seconds for execute to finish
|
||||
|
||||
HANDLERS = Registry()
|
||||
HANDLERS = Registry() # type: ignore[var-annotated]
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Extend the basic Accessory and Bridge functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pyhap.accessory import Accessory, Bridge
|
||||
|
@ -90,7 +92,7 @@ SWITCH_TYPES = {
|
|||
TYPE_SWITCH: "Switch",
|
||||
TYPE_VALVE: "Valve",
|
||||
}
|
||||
TYPES = Registry()
|
||||
TYPES: Registry[str, type[HomeAccessory]] = Registry()
|
||||
|
||||
|
||||
def get_accessory(hass, driver, state, aid, config): # noqa: C901
|
||||
|
|
|
@ -9,7 +9,7 @@ from homeassistant.util import decorator
|
|||
from .const import CONF_INVERSE, SIGNAL_DS18B20_NEW
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
HANDLERS = decorator.Registry()
|
||||
HANDLERS = decorator.Registry() # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@HANDLERS.register("state")
|
||||
|
|
|
@ -109,7 +109,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
DELAY_SAVE = 10
|
||||
|
||||
WEBHOOK_COMMANDS = Registry()
|
||||
WEBHOOK_COMMANDS = Registry() # type: ignore[var-annotated]
|
||||
|
||||
COMBINED_CLASSES = set(BINARY_SENSOR_CLASSES + SENSOR_CLASSES)
|
||||
SENSOR_TYPES = [ATTR_SENSOR_TYPE_BINARY_SENSOR, ATTR_SENSOR_TYPE_SENSOR]
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
|
@ -337,9 +337,7 @@ def _gw_callback_factory(
|
|||
_LOGGER.debug("Node update: node %s child %s", msg.node_id, msg.child_id)
|
||||
|
||||
msg_type = msg.gateway.const.MessageType(msg.type)
|
||||
msg_handler: Callable[
|
||||
[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
|
||||
] | None = HANDLERS.get(msg_type.name)
|
||||
msg_handler = HANDLERS.get(msg_type.name)
|
||||
|
||||
if msg_handler is None:
|
||||
return
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Handle MySensors messages."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from mysensors import Message
|
||||
|
||||
from homeassistant.const import Platform
|
||||
|
@ -12,7 +15,9 @@ from .const import CHILD_CALLBACK, NODE_CALLBACK, DevId, GatewayId
|
|||
from .device import get_mysensors_devices
|
||||
from .helpers import discover_mysensors_platform, validate_set_msg
|
||||
|
||||
HANDLERS = decorator.Registry()
|
||||
HANDLERS: decorator.Registry[
|
||||
str, Callable[[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]]
|
||||
] = decorator.Registry()
|
||||
|
||||
|
||||
@HANDLERS.register("set")
|
||||
|
|
|
@ -31,7 +31,9 @@ from .const import (
|
|||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SCHEMAS = Registry()
|
||||
SCHEMAS: Registry[
|
||||
tuple[str, str], Callable[[BaseAsyncGateway, ChildSensor, ValueType], vol.Schema]
|
||||
] = Registry()
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
"""ONVIF event parsers."""
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from .models import Event
|
||||
|
||||
PARSERS = Registry()
|
||||
PARSERS: Registry[str, Callable[[str, Any], Coroutine[Any, Any, Event]]] = Registry()
|
||||
|
||||
|
||||
@PARSERS.register("tns1:VideoSource/MotionAlarm")
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
"""Helpers to help coordinate updates."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ServerDisconnectedError
|
||||
from pyoverkiz.client import OverkizClient
|
||||
|
@ -25,7 +27,9 @@ from homeassistant.util.decorator import Registry
|
|||
|
||||
from .const import DOMAIN, LOGGER, UPDATE_INTERVAL
|
||||
|
||||
EVENT_HANDLERS = Registry()
|
||||
EVENT_HANDLERS: Registry[
|
||||
str, Callable[[OverkizDataUpdateCoordinator, Event], Coroutine[Any, Any, None]]
|
||||
] = Registry()
|
||||
|
||||
|
||||
class OverkizDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Device]]):
|
||||
|
|
|
@ -17,7 +17,7 @@ from .helper import supports_encryption
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
HANDLERS = decorator.Registry()
|
||||
HANDLERS = decorator.Registry() # type: ignore[var-annotated]
|
||||
|
||||
|
||||
def get_cipher():
|
||||
|
|
|
@ -245,7 +245,7 @@ class Stream:
|
|||
self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT
|
||||
) -> StreamOutput:
|
||||
"""Add provider output stream."""
|
||||
if not self._outputs.get(fmt):
|
||||
if not (provider := self._outputs.get(fmt)):
|
||||
|
||||
@callback
|
||||
def idle_callback() -> None:
|
||||
|
@ -259,7 +259,7 @@ class Stream:
|
|||
self.hass, IdleTimer(self.hass, timeout, idle_callback)
|
||||
)
|
||||
self._outputs[fmt] = provider
|
||||
return self._outputs[fmt]
|
||||
return provider
|
||||
|
||||
def remove_provider(self, provider: StreamOutput) -> None:
|
||||
"""Remove provider output stream."""
|
||||
|
|
|
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from . import Stream
|
||||
|
||||
PROVIDERS = Registry()
|
||||
PROVIDERS: Registry[str, type[StreamOutput]] = Registry()
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
|
|
@ -62,7 +62,7 @@ SOURCE_UNIGNORE = "unignore"
|
|||
# This is used to signal that re-authentication is required by the user.
|
||||
SOURCE_REAUTH = "reauth"
|
||||
|
||||
HANDLERS = Registry()
|
||||
HANDLERS: Registry[str, type[ConfigFlow]] = Registry()
|
||||
|
||||
STORAGE_KEY = "core.config_entries"
|
||||
STORAGE_VERSION = 1
|
||||
|
@ -530,8 +530,10 @@ class ConfigEntry:
|
|||
)
|
||||
return False
|
||||
# Handler may be a partial
|
||||
# Keep for backwards compatibility
|
||||
# https://github.com/home-assistant/core/pull/67087#discussion_r812559950
|
||||
while isinstance(handler, functools.partial):
|
||||
handler = handler.func
|
||||
handler = handler.func # type: ignore[unreachable]
|
||||
|
||||
if self.version == handler.VERSION:
|
||||
return True
|
||||
|
@ -753,7 +755,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
if not context or "source" not in context:
|
||||
raise KeyError("Context not set or doesn't have a source set")
|
||||
|
||||
flow = cast(ConfigFlow, handler())
|
||||
flow = handler()
|
||||
flow.init_step = context["source"]
|
||||
return flow
|
||||
|
||||
|
@ -1496,7 +1498,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
|
|||
if entry.domain not in HANDLERS:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
return cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
|
||||
return HANDLERS[entry.domain].async_get_options_flow(entry)
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.util import decorator
|
|||
|
||||
from . import config_validation as cv
|
||||
|
||||
SELECTORS = decorator.Registry()
|
||||
SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry()
|
||||
|
||||
|
||||
def _get_selector_class(config: Any) -> type[Selector]:
|
||||
|
@ -24,12 +24,12 @@ def _get_selector_class(config: Any) -> type[Selector]:
|
|||
if len(config) != 1:
|
||||
raise vol.Invalid(f"Only one type can be specified. Found {', '.join(config)}")
|
||||
|
||||
selector_type = list(config)[0]
|
||||
selector_type: str = list(config)[0]
|
||||
|
||||
if (selector_class := SELECTORS.get(selector_type)) is None:
|
||||
raise vol.Invalid(f"Unknown selector type {selector_type} found")
|
||||
|
||||
return cast(type[Selector], selector_class)
|
||||
return selector_class
|
||||
|
||||
|
||||
def selector(config: Any) -> Selector:
|
||||
|
|
|
@ -2,18 +2,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Hashable
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
|
||||
_KT = TypeVar("_KT", bound=Hashable)
|
||||
_VT = TypeVar("_VT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class Registry(dict):
|
||||
class Registry(dict[_KT, _VT]):
|
||||
"""Registry of items."""
|
||||
|
||||
def register(self, name: Hashable) -> Callable[[CALLABLE_T], CALLABLE_T]:
|
||||
def register(self, name: _KT) -> Callable[[_VT], _VT]:
|
||||
"""Return decorator to register item with a specific name."""
|
||||
|
||||
def decorator(func: CALLABLE_T) -> CALLABLE_T:
|
||||
def decorator(func: _VT) -> _VT:
|
||||
"""Register decorated function."""
|
||||
self[name] = func
|
||||
return func
|
||||
|
|
3
mypy.ini
3
mypy.ini
|
@ -76,6 +76,9 @@ disallow_any_generics = true
|
|||
[mypy-homeassistant.util.color]
|
||||
disallow_any_generics = true
|
||||
|
||||
[mypy-homeassistant.util.decorator]
|
||||
disallow_any_generics = true
|
||||
|
||||
[mypy-homeassistant.util.process]
|
||||
disallow_any_generics = true
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue