diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 9e50b55830c..7f4855bfbe5 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -1,10 +1,10 @@ """Websocket API for Z-Wave JS.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Coroutine import dataclasses from functools import partial, wraps -from typing import Any, Literal, cast +from typing import Any, Concatenate, Literal, ParamSpec, cast from aiohttp import web, web_exceptions, web_request import voluptuous as vol @@ -85,6 +85,8 @@ from .helpers import ( get_device_id, ) +_P = ParamSpec("_P") + DATA_UNSUBSCRIBE = "unsubs" # general API constants @@ -264,8 +266,11 @@ QR_CODE_STRING_SCHEMA = vol.All(str, vol.Length(min=MINIMUM_QR_STRING_LENGTH)) async def _async_get_entry( - hass: HomeAssistant, connection: ActiveConnection, msg: dict, entry_id: str -) -> tuple[ConfigEntry | None, Client | None, Driver | None]: + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], + entry_id: str, +) -> tuple[ConfigEntry, Client, Driver] | tuple[None, None, None]: """Get config entry and client from message data.""" entry = hass.config_entries.async_get_entry(entry_id) if entry is None: @@ -293,19 +298,26 @@ async def _async_get_entry( return entry, client, client.driver -def async_get_entry(orig_func: Callable) -> Callable: +def async_get_entry( + orig_func: Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any], ConfigEntry, Client, Driver], + Coroutine[Any, Any, None], + ], +) -> Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None] +]: """Decorate async function to get entry.""" @wraps(orig_func) async def async_get_entry_func( - hass: HomeAssistant, connection: ActiveConnection, msg: dict + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Provide user specific data and store to function.""" entry, client, driver = await _async_get_entry( hass, connection, msg, msg[ENTRY_ID] ) - if not entry and not client and not driver: + if not entry or not client or not driver: return await orig_func(hass, connection, msg, entry, client, driver) @@ -328,12 +340,19 @@ async def _async_get_node( return node -def async_get_node(orig_func: Callable) -> Callable: +def async_get_node( + orig_func: Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any], Node], + Coroutine[Any, Any, None], + ], +) -> Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None] +]: """Decorate async function to get node.""" @wraps(orig_func) async def async_get_node_func( - hass: HomeAssistant, connection: ActiveConnection, msg: dict + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Provide user specific data and store to function.""" node = await _async_get_node(hass, connection, msg, msg[DEVICE_ID]) @@ -344,16 +363,24 @@ def async_get_node(orig_func: Callable) -> Callable: return async_get_node_func -def async_handle_failed_command(orig_func: Callable) -> Callable: +def async_handle_failed_command( + orig_func: Callable[ + Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P], + Coroutine[Any, Any, None], + ], +) -> Callable[ + Concatenate[HomeAssistant, ActiveConnection, dict[str, Any], _P], + Coroutine[Any, Any, None], +]: """Decorate async function to handle FailedCommand and send relevant error.""" @wraps(orig_func) async def async_handle_failed_command_func( hass: HomeAssistant, connection: ActiveConnection, - msg: dict, - *args: Any, - **kwargs: Any, + msg: dict[str, Any], + *args: _P.args, + **kwargs: _P.kwargs, ) -> None: """Handle FailedCommand within function and send relevant error.""" try: diff --git a/homeassistant/components/zwave_js/helpers.py b/homeassistant/components/zwave_js/helpers.py index 5d78d3e57e7..65c77f8ab2d 100644 --- a/homeassistant/components/zwave_js/helpers.py +++ b/homeassistant/components/zwave_js/helpers.py @@ -456,7 +456,9 @@ def remove_keys_with_empty_values(config: ConfigType) -> ConfigType: return {key: value for key, value in config.items() if value not in ("", None)} -def check_type_schema_map(schema_map: dict[str, vol.Schema]) -> Callable: +def check_type_schema_map( + schema_map: dict[str, vol.Schema] +) -> Callable[[ConfigType], ConfigType]: """Check type specific schema against config.""" def _check_type_schema(config: ConfigType) -> ConfigType: