Improve decorator type annotations [zwave_js] (#104825)

* Improve decorator type annotations [zwave_js]

* Improve _async_get_entry annotation
This commit is contained in:
Marc Mueller 2023-12-06 16:22:32 +01:00 committed by GitHub
parent 6721f9fdb2
commit c93abd9d20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 14 deletions

View file

@ -1,10 +1,10 @@
"""Websocket API for Z-Wave JS.""" """Websocket API for Z-Wave JS."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Coroutine
import dataclasses import dataclasses
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Literal, cast from typing import Any, Concatenate, Literal, ParamSpec, cast
from aiohttp import web, web_exceptions, web_request from aiohttp import web, web_exceptions, web_request
import voluptuous as vol import voluptuous as vol
@ -85,6 +85,8 @@ from .helpers import (
get_device_id, get_device_id,
) )
_P = ParamSpec("_P")
DATA_UNSUBSCRIBE = "unsubs" DATA_UNSUBSCRIBE = "unsubs"
# general API constants # general API constants
@ -264,8 +266,11 @@ QR_CODE_STRING_SCHEMA = vol.All(str, vol.Length(min=MINIMUM_QR_STRING_LENGTH))
async def _async_get_entry( async def _async_get_entry(
hass: HomeAssistant, connection: ActiveConnection, msg: dict, entry_id: str hass: HomeAssistant,
) -> tuple[ConfigEntry | None, Client | None, Driver | None]: connection: ActiveConnection,
msg: dict[str, Any],
entry_id: str,
) -> tuple[ConfigEntry, Client, Driver] | tuple[None, None, None]:
"""Get config entry and client from message data.""" """Get config entry and client from message data."""
entry = hass.config_entries.async_get_entry(entry_id) entry = hass.config_entries.async_get_entry(entry_id)
if entry is None: if entry is None:
@ -293,19 +298,26 @@ async def _async_get_entry(
return entry, client, client.driver return entry, client, client.driver
def async_get_entry(orig_func: Callable) -> Callable: def async_get_entry(
orig_func: Callable[
[HomeAssistant, ActiveConnection, dict[str, Any], ConfigEntry, Client, Driver],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
]:
"""Decorate async function to get entry.""" """Decorate async function to get entry."""
@wraps(orig_func) @wraps(orig_func)
async def async_get_entry_func( async def async_get_entry_func(
hass: HomeAssistant, connection: ActiveConnection, msg: dict hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Provide user specific data and store to function.""" """Provide user specific data and store to function."""
entry, client, driver = await _async_get_entry( entry, client, driver = await _async_get_entry(
hass, connection, msg, msg[ENTRY_ID] hass, connection, msg, msg[ENTRY_ID]
) )
if not entry and not client and not driver: if not entry or not client or not driver:
return return
await orig_func(hass, connection, msg, entry, client, driver) await orig_func(hass, connection, msg, entry, client, driver)
@ -328,12 +340,19 @@ async def _async_get_node(
return node return node
def async_get_node(orig_func: Callable) -> Callable: def async_get_node(
orig_func: Callable[
[HomeAssistant, ActiveConnection, dict[str, Any], Node],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
]:
"""Decorate async function to get node.""" """Decorate async function to get node."""
@wraps(orig_func) @wraps(orig_func)
async def async_get_node_func( async def async_get_node_func(
hass: HomeAssistant, connection: ActiveConnection, msg: dict hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Provide user specific data and store to function.""" """Provide user specific data and store to function."""
node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID]) node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID])
@ -344,16 +363,24 @@ def async_get_node(orig_func: Callable) -> Callable:
return async_get_node_func return async_get_node_func
def async_handle_failed_command(orig_func: Callable) -> Callable: def async_handle_failed_command(
orig_func: Callable[
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
Coroutine[Any, Any, None],
],
) -> Callable[
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
Coroutine[Any, Any, None],
]:
"""Decorate async function to handle FailedCommand and send relevant error.""" """Decorate async function to handle FailedCommand and send relevant error."""
@wraps(orig_func) @wraps(orig_func)
async def async_handle_failed_command_func( async def async_handle_failed_command_func(
hass: HomeAssistant, hass: HomeAssistant,
connection: ActiveConnection, connection: ActiveConnection,
msg: dict, msg: dict[str, Any],
*args: Any, *args: _P.args,
**kwargs: Any, **kwargs: _P.kwargs,
) -> None: ) -> None:
"""Handle FailedCommand within function and send relevant error.""" """Handle FailedCommand within function and send relevant error."""
try: try:

View file

@ -456,7 +456,9 @@ def remove_keys_with_empty_values(config: ConfigType) -> ConfigType:
return {key: value for key, value in config.items() if value not in ("", None)} return {key: value for key, value in config.items() if value not in ("", None)}
def check_type_schema_map(schema_map: dict[str, vol.Schema]) -> Callable: def check_type_schema_map(
schema_map: dict[str, vol.Schema]
) -> Callable[[ConfigType], ConfigType]:
"""Check type specific schema against config.""" """Check type specific schema against config."""
def _check_type_schema(config: ConfigType) -> ConfigType: def _check_type_schema(config: ConfigType) -> ConfigType: