Improve decorator type annotations [zwave_js] (#104825)

* Improve decorator type annotations [zwave_js]

* Improve _async_get_entry annotation
This commit is contained in:
Marc Mueller 2023-12-06 16:22:32 +01:00 committed by GitHub
parent 6721f9fdb2
commit c93abd9d20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 14 deletions

View file

@ -1,10 +1,10 @@
"""Websocket API for Z-Wave JS."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Coroutine
import dataclasses
from functools import partial, wraps
from typing import Any, Literal, cast
from typing import Any, Concatenate, Literal, ParamSpec, cast
from aiohttp import web, web_exceptions, web_request
import voluptuous as vol
@ -85,6 +85,8 @@ from .helpers import (
get_device_id,
)
_P = ParamSpec("_P")
DATA_UNSUBSCRIBE = "unsubs"
# general API constants
@ -264,8 +266,11 @@ QR_CODE_STRING_SCHEMA = vol.All(str, vol.Length(min=MINIMUM_QR_STRING_LENGTH))
async def _async_get_entry(
hass: HomeAssistant, connection: ActiveConnection, msg: dict, entry_id: str
) -> tuple[ConfigEntry | None, Client | None, Driver | None]:
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict[str, Any],
entry_id: str,
) -> tuple[ConfigEntry, Client, Driver] | tuple[None, None, None]:
"""Get config entry and client from message data."""
entry = hass.config_entries.async_get_entry(entry_id)
if entry is None:
@ -293,19 +298,26 @@ async def _async_get_entry(
return entry, client, client.driver
def async_get_entry(orig_func: Callable) -> Callable:
def async_get_entry(
orig_func: Callable[
[HomeAssistant, ActiveConnection, dict[str, Any], ConfigEntry, Client, Driver],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
]:
"""Decorate async function to get entry."""
@wraps(orig_func)
async def async_get_entry_func(
hass: HomeAssistant, connection: ActiveConnection, msg: dict
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Provide user specific data and store to function."""
entry, client, driver = await _async_get_entry(
hass, connection, msg, msg[ENTRY_ID]
)
if not entry and not client and not driver:
if not entry or not client or not driver:
return
await orig_func(hass, connection, msg, entry, client, driver)
@ -328,12 +340,19 @@ async def _async_get_node(
return node
def async_get_node(orig_func: Callable) -> Callable:
def async_get_node(
orig_func: Callable[
[HomeAssistant, ActiveConnection, dict[str, Any], Node],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
]:
"""Decorate async function to get node."""
@wraps(orig_func)
async def async_get_node_func(
hass: HomeAssistant, connection: ActiveConnection, msg: dict
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Provide user specific data and store to function."""
node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID])
@ -344,16 +363,24 @@ def async_get_node(orig_func: Callable) -> Callable:
return async_get_node_func
def async_handle_failed_command(orig_func: Callable) -> Callable:
def async_handle_failed_command(
orig_func: Callable[
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
Coroutine[Any, Any, None],
],
) -> Callable[
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
Coroutine[Any, Any, None],
]:
"""Decorate async function to handle FailedCommand and send relevant error."""
@wraps(orig_func)
async def async_handle_failed_command_func(
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict,
*args: Any,
**kwargs: Any,
msg: dict[str, Any],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
"""Handle FailedCommand within function and send relevant error."""
try: