From 7eda8aafc80690abbd9e4a8cf059041ce964faf0 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Thu, 6 Jun 2024 10:43:31 +0200 Subject: [PATCH] Ignore nested functions when enforcing type hints in tests (#118948) --- pylint/plugins/hass_enforce_type_hints.py | 8 ++++---- tests/pylint/test_enforce_type_hints.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 3c6139a41e7..0adebaf98f6 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -3138,15 +3138,15 @@ class HassTypeHintChecker(BaseChecker): _class_matchers: list[ClassTypeHintMatch] _function_matchers: list[TypeHintMatch] - _module_name: str + _module_node: nodes.Module _in_test_module: bool def visit_module(self, node: nodes.Module) -> None: """Populate matchers for a Module node.""" self._class_matchers = [] self._function_matchers = [] - self._module_name = node.name - self._in_test_module = self._module_name.startswith("tests.") + self._module_node = node + self._in_test_module = node.name.startswith("tests.") if ( self._in_test_module @@ -3230,7 +3230,7 @@ class HassTypeHintChecker(BaseChecker): if node.is_method(): matchers = _METHOD_MATCH else: - if self._in_test_module: + if self._in_test_module and node.parent is self._module_node: if node.name.startswith("test_"): self._check_test_function(node, False) return diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 9f0f4905dab..5b1c494568d 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1152,6 +1152,28 @@ def test_pytest_function( type_hint_checker.visit_asyncfunctiondef(func_node) +def test_pytest_nested_function( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for a test function.""" + func_node, nested_func_node = astroid.extract_node( + """ + async def some_function( #@ + ): + def test_value(value: str) -> bool: #@ + return value == "Yes" + return test_value + """, + "tests.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_no_messages( + linter, + ): + type_hint_checker.visit_asyncfunctiondef(nested_func_node) + + def test_pytest_invalid_function( linter: UnittestLinter, type_hint_checker: BaseChecker ) -> None: