Improve decorator type annotations [zwave_js] (#104825)
* Improve decorator type annotations [zwave_js] * Improve _async_get_entry annotation
This commit is contained in:
parent
6721f9fdb2
commit
c93abd9d20
2 changed files with 43 additions and 14 deletions
|
@ -1,10 +1,10 @@
|
||||||
"""Websocket API for Z-Wave JS."""
|
"""Websocket API for Z-Wave JS."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Coroutine
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Concatenate, Literal, ParamSpec, cast
|
||||||
|
|
||||||
from aiohttp import web, web_exceptions, web_request
|
from aiohttp import web, web_exceptions, web_request
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -85,6 +85,8 @@ from .helpers import (
|
||||||
get_device_id,
|
get_device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
DATA_UNSUBSCRIBE = "unsubs"
|
DATA_UNSUBSCRIBE = "unsubs"
|
||||||
|
|
||||||
# general API constants
|
# general API constants
|
||||||
|
@ -264,8 +266,11 @@ QR_CODE_STRING_SCHEMA = vol.All(str, vol.Length(min=MINIMUM_QR_STRING_LENGTH))
|
||||||
|
|
||||||
|
|
||||||
async def _async_get_entry(
|
async def _async_get_entry(
|
||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict, entry_id: str
|
hass: HomeAssistant,
|
||||||
) -> tuple[ConfigEntry | None, Client | None, Driver | None]:
|
connection: ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
entry_id: str,
|
||||||
|
) -> tuple[ConfigEntry, Client, Driver] | tuple[None, None, None]:
|
||||||
"""Get config entry and client from message data."""
|
"""Get config entry and client from message data."""
|
||||||
entry = hass.config_entries.async_get_entry(entry_id)
|
entry = hass.config_entries.async_get_entry(entry_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
|
@ -293,19 +298,26 @@ async def _async_get_entry(
|
||||||
return entry, client, client.driver
|
return entry, client, client.driver
|
||||||
|
|
||||||
|
|
||||||
def async_get_entry(orig_func: Callable) -> Callable:
|
def async_get_entry(
|
||||||
|
orig_func: Callable[
|
||||||
|
[HomeAssistant, ActiveConnection, dict[str, Any], ConfigEntry, Client, Driver],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
],
|
||||||
|
) -> Callable[
|
||||||
|
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
|
||||||
|
]:
|
||||||
"""Decorate async function to get entry."""
|
"""Decorate async function to get entry."""
|
||||||
|
|
||||||
@wraps(orig_func)
|
@wraps(orig_func)
|
||||||
async def async_get_entry_func(
|
async def async_get_entry_func(
|
||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Provide user specific data and store to function."""
|
"""Provide user specific data and store to function."""
|
||||||
entry, client, driver = await _async_get_entry(
|
entry, client, driver = await _async_get_entry(
|
||||||
hass, connection, msg, msg[ENTRY_ID]
|
hass, connection, msg, msg[ENTRY_ID]
|
||||||
)
|
)
|
||||||
|
|
||||||
if not entry and not client and not driver:
|
if not entry or not client or not driver:
|
||||||
return
|
return
|
||||||
|
|
||||||
await orig_func(hass, connection, msg, entry, client, driver)
|
await orig_func(hass, connection, msg, entry, client, driver)
|
||||||
|
@ -328,12 +340,19 @@ async def _async_get_node(
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
def async_get_node(orig_func: Callable) -> Callable:
|
def async_get_node(
|
||||||
|
orig_func: Callable[
|
||||||
|
[HomeAssistant, ActiveConnection, dict[str, Any], Node],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
],
|
||||||
|
) -> Callable[
|
||||||
|
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
|
||||||
|
]:
|
||||||
"""Decorate async function to get node."""
|
"""Decorate async function to get node."""
|
||||||
|
|
||||||
@wraps(orig_func)
|
@wraps(orig_func)
|
||||||
async def async_get_node_func(
|
async def async_get_node_func(
|
||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Provide user specific data and store to function."""
|
"""Provide user specific data and store to function."""
|
||||||
node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID])
|
node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID])
|
||||||
|
@ -344,16 +363,24 @@ def async_get_node(orig_func: Callable) -> Callable:
|
||||||
return async_get_node_func
|
return async_get_node_func
|
||||||
|
|
||||||
|
|
||||||
def async_handle_failed_command(orig_func: Callable) -> Callable:
|
def async_handle_failed_command(
|
||||||
|
orig_func: Callable[
|
||||||
|
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
],
|
||||||
|
) -> Callable[
|
||||||
|
Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
]:
|
||||||
"""Decorate async function to handle FailedCommand and send relevant error."""
|
"""Decorate async function to handle FailedCommand and send relevant error."""
|
||||||
|
|
||||||
@wraps(orig_func)
|
@wraps(orig_func)
|
||||||
async def async_handle_failed_command_func(
|
async def async_handle_failed_command_func(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
connection: ActiveConnection,
|
connection: ActiveConnection,
|
||||||
msg: dict,
|
msg: dict[str, Any],
|
||||||
*args: Any,
|
*args: _P.args,
|
||||||
**kwargs: Any,
|
**kwargs: _P.kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle FailedCommand within function and send relevant error."""
|
"""Handle FailedCommand within function and send relevant error."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -456,7 +456,9 @@ def remove_keys_with_empty_values(config: ConfigType) -> ConfigType:
|
||||||
return {key: value for key, value in config.items() if value not in ("", None)}
|
return {key: value for key, value in config.items() if value not in ("", None)}
|
||||||
|
|
||||||
|
|
||||||
def check_type_schema_map(schema_map: dict[str, vol.Schema]) -> Callable:
|
def check_type_schema_map(
|
||||||
|
schema_map: dict[str, vol.Schema]
|
||||||
|
) -> Callable[[ConfigType], ConfigType]:
|
||||||
"""Check type specific schema against config."""
|
"""Check type specific schema against config."""
|
||||||
|
|
||||||
def _check_type_schema(config: ConfigType) -> ConfigType:
|
def _check_type_schema(config: ConfigType) -> ConfigType:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue