diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index d82efa2fb3e..16449e2e5a0 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -3140,7 +3140,10 @@ class HassTypeHintChecker(BaseChecker): self._module_name = node.name self._in_test_module = self._module_name.startswith("tests.") - if (module_platform := _get_module_platform(node.name)) is None: + if ( + self._in_test_module + or (module_platform := _get_module_platform(node.name)) is None + ): return if module_platform in _PLATFORMS: @@ -3229,10 +3232,19 @@ class HassTypeHintChecker(BaseChecker): if node.is_method(): matchers = _METHOD_MATCH else: + if self._in_test_module: + if node.name.startswith("test_"): + self._check_test_function(node, False) + return + if (decoratornames := node.decoratornames()) and ( + # `@pytest.fixture` + "_pytest.fixtures.fixture" in decoratornames + # `@pytest.fixture(...)` + or "_pytest.fixtures.FixtureFunctionMarker" in decoratornames + ): + self._check_test_function(node, True) + return matchers = self._function_matchers - if self._in_test_module and node.name.startswith("test_"): - self._check_test_function(node) - return for match in matchers: if not match.need_to_check_function(node): @@ -3292,9 +3304,9 @@ class HassTypeHintChecker(BaseChecker): args=(match.return_type or "None", node.name), ) - def _check_test_function(self, node: nodes.FunctionDef) -> None: + def _check_test_function(self, node: nodes.FunctionDef, is_fixture: bool) -> None: # Check the return type, should always be `None` for test_*** functions. - if not _is_valid_type(None, node.returns, True): + if not is_fixture and not _is_valid_type(None, node.returns, True): self.add_message( "hass-return-type", node=node, @@ -3303,7 +3315,7 @@ class HassTypeHintChecker(BaseChecker): # Check that all positional arguments are correctly annotated. for arg_name, expected_type in _TEST_FIXTURES.items(): arg_node, annotation = _get_named_annotation(node, arg_name) - if arg_node and expected_type == "None": + if arg_node and expected_type == "None" and not is_fixture: self.add_message( "hass-consider-usefixtures-decorator", node=arg_node, diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 0153214c267..68e1e14a34f 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1232,6 +1232,86 @@ def test_pytest_invalid_function( type_hint_checker.visit_asyncfunctiondef(func_node) +def test_pytest_fixture(linter: UnittestLinter, type_hint_checker: BaseChecker) -> None: + """Ensure valid hints are accepted for a test fixture.""" + func_node = astroid.extract_node( + """ + import pytest + + @pytest.fixture + def sample_fixture( #@ + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + aiohttp_server: Callable[[], TestServer], + unused_tcp_port_factory: Callable[[], int], + enable_custom_integrations: None, + ) -> None: + pass + """, + "tests.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_no_messages( + linter, + ): + type_hint_checker.visit_asyncfunctiondef(func_node) + + +@pytest.mark.parametrize("decorator", ["@pytest.fixture", "@pytest.fixture()"]) +def test_pytest_invalid_fixture( + linter: UnittestLinter, type_hint_checker: BaseChecker, decorator: str +) -> None: + """Ensure invalid hints are rejected for a test fixture.""" + func_node, hass_node, caplog_node, none_node = astroid.extract_node( + f""" + import pytest + + {decorator} + def sample_fixture( #@ + hass: Something, #@ + caplog: SomethingElse, #@ + current_request_with_host, #@ + ) -> Any: + pass + """, + "tests.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=hass_node, + args=("hass", ["HomeAssistant", "HomeAssistant | None"], "sample_fixture"), + line=6, + col_offset=4, + end_line=6, + end_col_offset=19, + ), + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=caplog_node, + args=("caplog", "pytest.LogCaptureFixture", "sample_fixture"), + line=7, + col_offset=4, + end_line=7, + end_col_offset=25, + ), + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=none_node, + args=("current_request_with_host", "None", "sample_fixture"), + line=8, + col_offset=4, + end_line=8, + end_col_offset=29, + ), + ): + type_hint_checker.visit_asyncfunctiondef(func_node) + + @pytest.mark.parametrize( "entry_annotation", [