From f25663067cfbcc98b7307fd501ce67cc76a81937 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 23 May 2022 18:51:40 +0200 Subject: [PATCH] Enforce type hints on device_automation platform (#72126) --- pylint/plugins/hass_enforce_type_hints.py | 158 ++++++++++++++++++++-- tests/pylint/test_enforce_type_hints.py | 75 ++++++++++ 2 files changed, 221 insertions(+), 12 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 6cf521addac..9d35d07fab2 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -31,6 +31,8 @@ _TYPE_HINT_MATCHERS: dict[str, re.Pattern] = { "x_of_y": re.compile(r"^(\w+)\[(.*?]*)\]$"), # x_of_y_comma_z matches items such as "Callable[..., Awaitable[None]]" "x_of_y_comma_z": re.compile(r"^(\w+)\[(.*?]*), (.*?]*)\]$"), + # x_of_y_of_z_comma_a matches items such as "list[dict[str, Any]]" + "x_of_y_of_z_comma_a": re.compile(r"^(\w+)\[(\w+)\[(.*?]*), (.*?]*)\]\]$"), } _MODULE_FILTERS: dict[str, re.Pattern] = { @@ -44,12 +46,20 @@ _MODULE_FILTERS: dict[str, re.Pattern] = { "application_credentials": re.compile( r"^homeassistant\.components\.\w+\.(application_credentials)$" ), - # device_tracker matches only in the package root (device_tracker.py) - "device_tracker": re.compile(r"^homeassistant\.components\.\w+\.(device_tracker)$"), - # diagnostics matches only in the package root (diagnostics.py) - "diagnostics": re.compile(r"^homeassistant\.components\.\w+\.(diagnostics)$"), # config_flow matches only in the package root (config_flow.py) "config_flow": re.compile(r"^homeassistant\.components\.\w+\.(config_flow)$"), + # device_action matches only in the package root (device_action.py) + "device_action": re.compile(r"^homeassistant\.components\.\w+\.(device_action)$"), + # device_condition matches only in the package root (device_condition.py) + "device_condition": re.compile( + r"^homeassistant\.components\.\w+\.(device_condition)$" + ), + # device_tracker matches only in the package root (device_tracker.py) + "device_tracker": re.compile(r"^homeassistant\.components\.\w+\.(device_tracker)$"), + # device_trigger matches only in the package root (device_trigger.py) + "device_trigger": re.compile(r"^homeassistant\.components\.\w+\.(device_trigger)$"), + # diagnostics matches only in the package root (diagnostics.py) + "diagnostics": re.compile(r"^homeassistant\.components\.\w+\.(diagnostics)$"), } _METHOD_MATCH: list[TypeHintMatch] = [ @@ -157,6 +167,88 @@ _METHOD_MATCH: list[TypeHintMatch] = [ }, return_type="AuthorizationServer", ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["config_flow"], + function_name="_async_has_devices", + arg_types={ + 0: "HomeAssistant", + }, + return_type="bool", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_action"], + function_name="async_validate_action_config", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="ConfigType", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_action"], + function_name="async_call_action_from_config", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + 2: "TemplateVarsType", + 3: "Context | None", + }, + return_type=None, + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_action"], + function_name="async_get_action_capabilities", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="dict[str, Schema]", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_action"], + function_name="async_get_actions", + arg_types={ + 0: "HomeAssistant", + 1: "str", + }, + return_type=["list[dict[str, str]]", "list[dict[str, Any]]"], + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_condition"], + function_name="async_validate_condition_config", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="ConfigType", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_condition"], + function_name="async_condition_from_config", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="ConditionCheckerType", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_condition"], + function_name="async_get_condition_capabilities", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="dict[str, Schema]", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_condition"], + function_name="async_get_conditions", + arg_types={ + 0: "HomeAssistant", + 1: "str", + }, + return_type=["list[dict[str, str]]", "list[dict[str, Any]]"], + ), TypeHintMatch( module_filter=_MODULE_FILTERS["device_tracker"], function_name="setup_scanner", @@ -197,6 +289,44 @@ _METHOD_MATCH: list[TypeHintMatch] = [ }, return_type=["DeviceScanner", "DeviceScanner | None"], ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_trigger"], + function_name="async_validate_condition_config", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="ConfigType", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_trigger"], + function_name="async_attach_trigger", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + 2: "AutomationActionType", + 3: "AutomationTriggerInfo", + }, + return_type="CALLBACK_TYPE", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_trigger"], + function_name="async_get_trigger_capabilities", + arg_types={ + 0: "HomeAssistant", + 1: "ConfigType", + }, + return_type="dict[str, Schema]", + ), + TypeHintMatch( + module_filter=_MODULE_FILTERS["device_trigger"], + function_name="async_get_triggers", + arg_types={ + 0: "HomeAssistant", + 1: "str", + }, + return_type=["list[dict[str, str]]", "list[dict[str, Any]]"], + ), TypeHintMatch( module_filter=_MODULE_FILTERS["diagnostics"], function_name="async_get_config_entry_diagnostics", @@ -216,14 +346,6 @@ _METHOD_MATCH: list[TypeHintMatch] = [ }, return_type=UNDEFINED, ), - TypeHintMatch( - module_filter=_MODULE_FILTERS["config_flow"], - function_name="_async_has_devices", - arg_types={ - 0: "HomeAssistant", - }, - return_type="bool", - ), ] @@ -254,6 +376,18 @@ def _is_valid_type(expected_type: list[str] | str | None, node: astroid.NodeNG) and _is_valid_type(match.group(2), node.right) ) + # Special case for xxx[yyy[zzz, aaa]]` + if match := _TYPE_HINT_MATCHERS["x_of_y_of_z_comma_a"].match(expected_type): + return ( + isinstance(node, astroid.Subscript) + and _is_valid_type(match.group(1), node.value) + and isinstance(subnode := node.slice, astroid.Subscript) + and _is_valid_type(match.group(2), subnode.value) + and isinstance(subnode.slice, astroid.Tuple) + and _is_valid_type(match.group(3), subnode.slice.elts[0]) + and _is_valid_type(match.group(4), subnode.slice.elts[1]) + ) + # Special case for xxx[yyy, zzz]` if match := _TYPE_HINT_MATCHERS["x_of_y_comma_z"].match(expected_type): return ( diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 64be85c9a44..feb3b6b341c 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -14,6 +14,32 @@ import pytest from . import assert_adds_messages, assert_no_messages +@pytest.mark.parametrize( + ("string", "expected_x", "expected_y", "expected_z", "expected_a"), + [ + ("list[dict[str, str]]", "list", "dict", "str", "str"), + ("list[dict[str, Any]]", "list", "dict", "str", "Any"), + ], +) +def test_regex_x_of_y_of_z_comma_a( + hass_enforce_type_hints: ModuleType, + string: str, + expected_x: str, + expected_y: str, + expected_z: str, + expected_a: str, +) -> None: + """Test x_of_y_of_z_comma_a regexes.""" + matchers: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_MATCHERS + + assert (match := matchers["x_of_y_of_z_comma_a"].match(string)) + assert match.group(0) == string + assert match.group(1) == expected_x + assert match.group(2) == expected_y + assert match.group(3) == expected_z + assert match.group(4) == expected_a + + @pytest.mark.parametrize( ("string", "expected_x", "expected_y", "expected_z"), [ @@ -165,3 +191,52 @@ def test_valid_discovery_info( with assert_no_messages(linter): type_hint_checker.visit_asyncfunctiondef(func_node) + + +def test_invalid_list_dict_str_any( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure invalid hints are rejected for discovery_info.""" + type_hint_checker.module = "homeassistant.components.pylint_test.device_trigger" + func_node = astroid.extract_node( + """ + async def async_get_triggers( #@ + hass: HomeAssistant, + device_id: str + ) -> list: + pass + """ + ) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args=["list[dict[str, str]]", "list[dict[str, Any]]"], + line=2, + col_offset=0, + end_line=2, + end_col_offset=28, + ), + ): + type_hint_checker.visit_asyncfunctiondef(func_node) + + +def test_valid_list_dict_str_any( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for discovery_info.""" + type_hint_checker.module = "homeassistant.components.pylint_test.device_trigger" + func_node = astroid.extract_node( + """ + async def async_get_triggers( #@ + hass: HomeAssistant, + device_id: str + ) -> list[dict[str, Any]]: + pass + """ + ) + + with assert_no_messages(linter): + type_hint_checker.visit_asyncfunctiondef(func_node)