From 3c1e62aeefd05f969274c34840dfb801d619bc63 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 22 Feb 2023 13:39:28 +0100 Subject: [PATCH] Improve type hint in zwave_js trigger (#88597) Imrpove type hint in zwave_js trigger --- homeassistant/components/zwave_js/trigger.py | 26 ++++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/zwave_js/trigger.py b/homeassistant/components/zwave_js/trigger.py index d1751dc4f4f..f747c25c71b 100644 --- a/homeassistant/components/zwave_js/trigger.py +++ b/homeassistant/components/zwave_js/trigger.py @@ -1,12 +1,15 @@ """Z-Wave JS trigger dispatcher.""" from __future__ import annotations -from types import ModuleType from typing import cast from homeassistant.const import CONF_PLATFORM from homeassistant.core import CALLBACK_TYPE, HomeAssistant -from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo +from homeassistant.helpers.trigger import ( + TriggerActionType, + TriggerInfo, + TriggerProtocol, +) from homeassistant.helpers.typing import ConfigType from .triggers import event, value_updated @@ -17,7 +20,7 @@ TRIGGERS = { } -def _get_trigger_platform(config: ConfigType) -> ModuleType: +def _get_trigger_platform(config: ConfigType) -> TriggerProtocol: """Return trigger platform.""" platform_split = config[CONF_PLATFORM].split(".", maxsplit=1) if len(platform_split) < 2 or platform_split[1] not in TRIGGERS: @@ -31,12 +34,9 @@ async def async_validate_trigger_config( """Validate config.""" platform = _get_trigger_platform(config) if hasattr(platform, "async_validate_trigger_config"): - return cast( - ConfigType, - await getattr(platform, "async_validate_trigger_config")(hass, config), - ) - assert hasattr(platform, "TRIGGER_SCHEMA") - return cast(ConfigType, getattr(platform, "TRIGGER_SCHEMA")(config)) + return await platform.async_validate_trigger_config(hass, config) + + return cast(ConfigType, platform.TRIGGER_SCHEMA(config)) async def async_attach_trigger( @@ -47,10 +47,4 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Attach trigger of specified platform.""" platform = _get_trigger_platform(config) - assert hasattr(platform, "async_attach_trigger") - return cast( - CALLBACK_TYPE, - await getattr(platform, "async_attach_trigger")( - hass, config, action, trigger_info - ), - ) + return await platform.async_attach_trigger(hass, config, action, trigger_info)