diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index ac58db37b72..d82efa2fb3e 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -3092,11 +3092,6 @@ def _get_module_platform(module_name: str) -> str | None: return platform.lstrip(".") if platform else "__init__" -def _is_test_function(module_name: str, node: nodes.FunctionDef) -> bool: - """Return True if function is a pytest function.""" - return module_name.startswith("tests.") and node.name.startswith("test_") - - class HassTypeHintChecker(BaseChecker): """Checker for setup type hints.""" @@ -3136,12 +3131,14 @@ class HassTypeHintChecker(BaseChecker): _class_matchers: list[ClassTypeHintMatch] _function_matchers: list[TypeHintMatch] _module_name: str + _in_test_module: bool def visit_module(self, node: nodes.Module) -> None: """Populate matchers for a Module node.""" self._class_matchers = [] self._function_matchers = [] self._module_name = node.name + self._in_test_module = self._module_name.startswith("tests.") if (module_platform := _get_module_platform(node.name)) is None: return @@ -3233,8 +3230,10 @@ class HassTypeHintChecker(BaseChecker): matchers = _METHOD_MATCH else: matchers = self._function_matchers - if _is_test_function(self._module_name, node): - self._check_test_function(node, annotations) + if self._in_test_module and node.name.startswith("test_"): + self._check_test_function(node) + return + for match in matchers: if not match.need_to_check_function(node): continue @@ -3251,11 +3250,7 @@ class HassTypeHintChecker(BaseChecker): # Check that all positional arguments are correctly annotated. if match.arg_types: for key, expected_type in match.arg_types.items(): - if ( - node.args.args[key].name in _COMMON_ARGUMENTS - or _is_test_function(self._module_name, node) - and node.args.args[key].name in _TEST_FIXTURES - ): + if node.args.args[key].name in _COMMON_ARGUMENTS: # It has already been checked, avoid double-message continue if not _is_valid_type(expected_type, annotations[key]): @@ -3268,11 +3263,7 @@ class HassTypeHintChecker(BaseChecker): # Check that all keyword arguments are correctly annotated. if match.named_arg_types is not None: for arg_name, expected_type in match.named_arg_types.items(): - if ( - arg_name in _COMMON_ARGUMENTS - or _is_test_function(self._module_name, node) - and arg_name in _TEST_FIXTURES - ): + if arg_name in _COMMON_ARGUMENTS: # It has already been checked, avoid double-message continue arg_node, annotation = _get_named_annotation(node, arg_name) @@ -3301,9 +3292,7 @@ class HassTypeHintChecker(BaseChecker): args=(match.return_type or "None", node.name), ) - def _check_test_function( - self, node: nodes.FunctionDef, annotations: list[nodes.NodeNG | None] - ) -> None: + def _check_test_function(self, node: nodes.FunctionDef) -> None: # Check the return type, should always be `None` for test_*** functions. if not _is_valid_type(None, node.returns, True): self.add_message(