Improve typing [util.decorator] (#67087)

This commit is contained in:
Marc Mueller 2022-02-23 20:58:42 +01:00 committed by GitHub
parent 46c2bd0eb0
commit ec980a574b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 59 additions and 38 deletions

View file

@ -22,6 +22,7 @@ homeassistant.helpers.script_variables
homeassistant.helpers.translation homeassistant.helpers.translation
homeassistant.util.async_ homeassistant.util.async_
homeassistant.util.color homeassistant.util.color
homeassistant.util.decorator
homeassistant.util.process homeassistant.util.process
homeassistant.util.unit_system homeassistant.util.unit_system

View file

@ -16,7 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
MULTI_FACTOR_AUTH_MODULES = Registry() MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema( MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
{ {
@ -129,7 +129,7 @@ async def auth_mfa_module_from_config(
hass: HomeAssistant, config: dict[str, Any] hass: HomeAssistant, config: dict[str, Any]
) -> MultiFactorAuthModule: ) -> MultiFactorAuthModule:
"""Initialize an auth module from a config.""" """Initialize an auth module from a config."""
module_name = config[CONF_TYPE] module_name: str = config[CONF_TYPE]
module = await _load_mfa_module(hass, module_name) module = await _load_mfa_module(hass, module_name)
try: try:
@ -142,7 +142,7 @@ async def auth_mfa_module_from_config(
) )
raise raise
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore[no-any-return] return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config)
async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType: async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType:

View file

@ -25,7 +25,7 @@ from ..models import Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed" DATA_REQS = "auth_prov_reqs_processed"
AUTH_PROVIDERS = Registry() AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema( AUTH_PROVIDER_SCHEMA = vol.Schema(
{ {
@ -136,7 +136,7 @@ async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, config: dict[str, Any] hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> AuthProvider: ) -> AuthProvider:
"""Initialize an auth provider from a config.""" """Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE] provider_name: str = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name) module = await load_auth_provider_module(hass, provider_name)
try: try:
@ -149,7 +149,7 @@ async def auth_provider_from_config(
) )
raise raise
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore[no-any-return] return AUTH_PROVIDERS[provider_name](hass, store, config)
async def load_auth_provider_module( async def load_auth_provider_module(

View file

@ -83,7 +83,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ENTITY_ADAPTERS = Registry() ENTITY_ADAPTERS: Registry[str, type[AlexaEntity]] = Registry()
TRANSLATION_TABLE = dict.fromkeys(map(ord, r"}{\/|\"()[]+~!><*%"), None) TRANSLATION_TABLE = dict.fromkeys(map(ord, r"}{\/|\"()[]+~!><*%"), None)

View file

@ -73,7 +73,7 @@ from .errors import (
from .state_report import async_enable_proactive_mode from .state_report import async_enable_proactive_mode
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HANDLERS = Registry() HANDLERS = Registry() # type: ignore[var-annotated]
@HANDLERS.register(("Alexa.Discovery", "Discover")) @HANDLERS.register(("Alexa.Discovery", "Discover"))

View file

@ -12,7 +12,7 @@ from .const import DOMAIN, SYN_RESOLUTION_MATCH
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HANDLERS = Registry() HANDLERS = Registry() # type: ignore[var-annotated]
INTENTS_API_ENDPOINT = "/api/alexa" INTENTS_API_ENDPOINT = "/api/alexa"

View file

@ -52,7 +52,7 @@ FILTER_NAME_OUTLIER = "outlier"
FILTER_NAME_THROTTLE = "throttle" FILTER_NAME_THROTTLE = "throttle"
FILTER_NAME_TIME_THROTTLE = "time_throttle" FILTER_NAME_TIME_THROTTLE = "time_throttle"
FILTER_NAME_TIME_SMA = "time_simple_moving_average" FILTER_NAME_TIME_SMA = "time_simple_moving_average"
FILTERS = Registry() FILTERS: Registry[str, type[Filter]] = Registry()
CONF_FILTERS = "filters" CONF_FILTERS = "filters"
CONF_FILTER_NAME = "filter" CONF_FILTER_NAME = "filter"

View file

@ -19,7 +19,7 @@ from .helpers import GoogleEntity, RequestData, async_get_entities
EXECUTE_LIMIT = 2 # Wait 2 seconds for execute to finish EXECUTE_LIMIT = 2 # Wait 2 seconds for execute to finish
HANDLERS = Registry() HANDLERS = Registry() # type: ignore[var-annotated]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View file

@ -1,4 +1,6 @@
"""Extend the basic Accessory and Bridge functions.""" """Extend the basic Accessory and Bridge functions."""
from __future__ import annotations
import logging import logging
from pyhap.accessory import Accessory, Bridge from pyhap.accessory import Accessory, Bridge
@ -90,7 +92,7 @@ SWITCH_TYPES = {
TYPE_SWITCH: "Switch", TYPE_SWITCH: "Switch",
TYPE_VALVE: "Valve", TYPE_VALVE: "Valve",
} }
TYPES = Registry() TYPES: Registry[str, type[HomeAccessory]] = Registry()
def get_accessory(hass, driver, state, aid, config): # noqa: C901 def get_accessory(hass, driver, state, aid, config): # noqa: C901

View file

@ -9,7 +9,7 @@ from homeassistant.util import decorator
from .const import CONF_INVERSE, SIGNAL_DS18B20_NEW from .const import CONF_INVERSE, SIGNAL_DS18B20_NEW
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HANDLERS = decorator.Registry() HANDLERS = decorator.Registry() # type: ignore[var-annotated]
@HANDLERS.register("state") @HANDLERS.register("state")

View file

@ -109,7 +109,7 @@ _LOGGER = logging.getLogger(__name__)
DELAY_SAVE = 10 DELAY_SAVE = 10
WEBHOOK_COMMANDS = Registry() WEBHOOK_COMMANDS = Registry() # type: ignore[var-annotated]
COMBINED_CLASSES = set(BINARY_SENSOR_CLASSES + SENSOR_CLASSES) COMBINED_CLASSES = set(BINARY_SENSOR_CLASSES + SENSOR_CLASSES)
SENSOR_TYPES = [ATTR_SENSOR_TYPE_BINARY_SENSOR, ATTR_SENSOR_TYPE_SENSOR] SENSOR_TYPES = [ATTR_SENSOR_TYPE_BINARY_SENSOR, ATTR_SENSOR_TYPE_SENSOR]

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Coroutine from collections.abc import Callable
import logging import logging
import socket import socket
import sys import sys
@ -337,9 +337,7 @@ def _gw_callback_factory(
_LOGGER.debug("Node update: node %s child %s", msg.node_id, msg.child_id) _LOGGER.debug("Node update: node %s child %s", msg.node_id, msg.child_id)
msg_type = msg.gateway.const.MessageType(msg.type) msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[ msg_handler = HANDLERS.get(msg_type.name)
[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
] | None = HANDLERS.get(msg_type.name)
if msg_handler is None: if msg_handler is None:
return return

View file

@ -1,6 +1,9 @@
"""Handle MySensors messages.""" """Handle MySensors messages."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine
from typing import Any
from mysensors import Message from mysensors import Message
from homeassistant.const import Platform from homeassistant.const import Platform
@ -12,7 +15,9 @@ from .const import CHILD_CALLBACK, NODE_CALLBACK, DevId, GatewayId
from .device import get_mysensors_devices from .device import get_mysensors_devices
from .helpers import discover_mysensors_platform, validate_set_msg from .helpers import discover_mysensors_platform, validate_set_msg
HANDLERS = decorator.Registry() HANDLERS: decorator.Registry[
str, Callable[[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]]
] = decorator.Registry()
@HANDLERS.register("set") @HANDLERS.register("set")

View file

@ -31,7 +31,9 @@ from .const import (
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCHEMAS = Registry() SCHEMAS: Registry[
tuple[str, str], Callable[[BaseAsyncGateway, ChildSensor, ValueType], vol.Schema]
] = Registry()
@callback @callback

View file

@ -1,10 +1,13 @@
"""ONVIF event parsers.""" """ONVIF event parsers."""
from collections.abc import Callable, Coroutine
from typing import Any
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from .models import Event from .models import Event
PARSERS = Registry() PARSERS: Registry[str, Callable[[str, Any], Coroutine[Any, Any, Event]]] = Registry()
@PARSERS.register("tns1:VideoSource/MotionAlarm") @PARSERS.register("tns1:VideoSource/MotionAlarm")

View file

@ -1,8 +1,10 @@
"""Helpers to help coordinate updates.""" """Helpers to help coordinate updates."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any
from aiohttp import ServerDisconnectedError from aiohttp import ServerDisconnectedError
from pyoverkiz.client import OverkizClient from pyoverkiz.client import OverkizClient
@ -25,7 +27,9 @@ from homeassistant.util.decorator import Registry
from .const import DOMAIN, LOGGER, UPDATE_INTERVAL from .const import DOMAIN, LOGGER, UPDATE_INTERVAL
EVENT_HANDLERS = Registry() EVENT_HANDLERS: Registry[
str, Callable[[OverkizDataUpdateCoordinator, Event], Coroutine[Any, Any, None]]
] = Registry()
class OverkizDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Device]]): class OverkizDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Device]]):

View file

@ -17,7 +17,7 @@ from .helper import supports_encryption
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HANDLERS = decorator.Registry() HANDLERS = decorator.Registry() # type: ignore[var-annotated]
def get_cipher(): def get_cipher():

View file

@ -245,7 +245,7 @@ class Stream:
self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT
) -> StreamOutput: ) -> StreamOutput:
"""Add provider output stream.""" """Add provider output stream."""
if not self._outputs.get(fmt): if not (provider := self._outputs.get(fmt)):
@callback @callback
def idle_callback() -> None: def idle_callback() -> None:
@ -259,7 +259,7 @@ class Stream:
self.hass, IdleTimer(self.hass, timeout, idle_callback) self.hass, IdleTimer(self.hass, timeout, idle_callback)
) )
self._outputs[fmt] = provider self._outputs[fmt] = provider
return self._outputs[fmt] return provider
def remove_provider(self, provider: StreamOutput) -> None: def remove_provider(self, provider: StreamOutput) -> None:
"""Remove provider output stream.""" """Remove provider output stream."""

View file

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from . import Stream from . import Stream
PROVIDERS = Registry() PROVIDERS: Registry[str, type[StreamOutput]] = Registry()
@attr.s(slots=True) @attr.s(slots=True)

View file

@ -62,7 +62,7 @@ SOURCE_UNIGNORE = "unignore"
# This is used to signal that re-authentication is required by the user. # This is used to signal that re-authentication is required by the user.
SOURCE_REAUTH = "reauth" SOURCE_REAUTH = "reauth"
HANDLERS = Registry() HANDLERS: Registry[str, type[ConfigFlow]] = Registry()
STORAGE_KEY = "core.config_entries" STORAGE_KEY = "core.config_entries"
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@ -530,8 +530,10 @@ class ConfigEntry:
) )
return False return False
# Handler may be a partial # Handler may be a partial
# Keep for backwards compatibility
# https://github.com/home-assistant/core/pull/67087#discussion_r812559950
while isinstance(handler, functools.partial): while isinstance(handler, functools.partial):
handler = handler.func handler = handler.func # type: ignore[unreachable]
if self.version == handler.VERSION: if self.version == handler.VERSION:
return True return True
@ -753,7 +755,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
if not context or "source" not in context: if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set") raise KeyError("Context not set or doesn't have a source set")
flow = cast(ConfigFlow, handler()) flow = handler()
flow.init_step = context["source"] flow.init_step = context["source"]
return flow return flow
@ -1496,7 +1498,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
if entry.domain not in HANDLERS: if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
return cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry)) return HANDLERS[entry.domain].async_get_options_flow(entry)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult

View file

@ -13,7 +13,7 @@ from homeassistant.util import decorator
from . import config_validation as cv from . import config_validation as cv
SELECTORS = decorator.Registry() SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry()
def _get_selector_class(config: Any) -> type[Selector]: def _get_selector_class(config: Any) -> type[Selector]:
@ -24,12 +24,12 @@ def _get_selector_class(config: Any) -> type[Selector]:
if len(config) != 1: if len(config) != 1:
raise vol.Invalid(f"Only one type can be specified. Found {', '.join(config)}") raise vol.Invalid(f"Only one type can be specified. Found {', '.join(config)}")
selector_type = list(config)[0] selector_type: str = list(config)[0]
if (selector_class := SELECTORS.get(selector_type)) is None: if (selector_class := SELECTORS.get(selector_type)) is None:
raise vol.Invalid(f"Unknown selector type {selector_type} found") raise vol.Invalid(f"Unknown selector type {selector_type} found")
return cast(type[Selector], selector_class) return selector_class
def selector(config: Any) -> Selector: def selector(config: Any) -> Selector:

View file

@ -2,18 +2,19 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Hashable from collections.abc import Callable, Hashable
from typing import TypeVar from typing import Any, TypeVar
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name _KT = TypeVar("_KT", bound=Hashable)
_VT = TypeVar("_VT", bound=Callable[..., Any])
class Registry(dict): class Registry(dict[_KT, _VT]):
"""Registry of items.""" """Registry of items."""
def register(self, name: Hashable) -> Callable[[CALLABLE_T], CALLABLE_T]: def register(self, name: _KT) -> Callable[[_VT], _VT]:
"""Return decorator to register item with a specific name.""" """Return decorator to register item with a specific name."""
def decorator(func: CALLABLE_T) -> CALLABLE_T: def decorator(func: _VT) -> _VT:
"""Register decorated function.""" """Register decorated function."""
self[name] = func self[name] = func
return func return func

View file

@ -76,6 +76,9 @@ disallow_any_generics = true
[mypy-homeassistant.util.color] [mypy-homeassistant.util.color]
disallow_any_generics = true disallow_any_generics = true
[mypy-homeassistant.util.decorator]
disallow_any_generics = true
[mypy-homeassistant.util.process] [mypy-homeassistant.util.process]
disallow_any_generics = true disallow_any_generics = true