diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 3c6139a41e7..0adebaf98f6 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -3138,15 +3138,15 @@ class HassTypeHintChecker(BaseChecker): _class_matchers: list[ClassTypeHintMatch] _function_matchers: list[TypeHintMatch] - _module_name: str + _module_node: nodes.Module _in_test_module: bool def visit_module(self, node: nodes.Module) -> None: """Populate matchers for a Module node.""" self._class_matchers = [] self._function_matchers = [] - self._module_name = node.name - self._in_test_module = self._module_name.startswith("tests.") + self._module_node = node + self._in_test_module = node.name.startswith("tests.") if ( self._in_test_module @@ -3230,7 +3230,7 @@ class HassTypeHintChecker(BaseChecker): if node.is_method(): matchers = _METHOD_MATCH else: - if self._in_test_module: + if self._in_test_module and node.parent is self._module_node: if node.name.startswith("test_"): self._check_test_function(node, False) return diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 9f0f4905dab..5b1c494568d 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1152,6 +1152,28 @@ def test_pytest_function( type_hint_checker.visit_asyncfunctiondef(func_node) +def test_pytest_nested_function( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for a test function.""" + func_node, nested_func_node = astroid.extract_node( + """ + async def some_function( #@ + ): + def test_value(value: str) -> bool: #@ + return value == "Yes" + return test_value + """, + "tests.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_no_messages( + linter, + ): + type_hint_checker.visit_asyncfunctiondef(nested_func_node) + + def test_pytest_invalid_function( linter: UnittestLinter, type_hint_checker: BaseChecker ) -> None: