diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index e5659be06cb..c489a508798 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -90,6 +90,18 @@ _METHOD_MATCH: list[TypeHintMatch] = [ ), ] +_TEST_FIXTURES: dict[str, list[str] | str] = { + "mqtt_client_mock": "MqttMockPahoClient", + "mqtt_mock": "MqttMockHAClient", + "mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator", + "mqtt_mock_entry_with_yaml_config": "MqttMockHAClientGenerator", +} +_TEST_FUNCTION_MATCH = TypeHintMatch( + function_name="test_*", + return_type=None, +) + + _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { "__init__": [ TypeHintMatch( @@ -2860,6 +2872,11 @@ def _get_module_platform(module_name: str) -> str | None: return platform.lstrip(".") if platform else "__init__" +def _is_test_function(module_name: str, node: nodes.FunctionDef) -> bool: + """Return True if function is a pytest function.""" + return module_name.startswith("tests.") and node.name.startswith("test_") + + class HassTypeHintChecker(BaseChecker): # type: ignore[misc] """Checker for setup type hints.""" @@ -2890,16 +2907,15 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] ), ) - def __init__(self, linter: PyLinter | None = None) -> None: - """Initialize the HassTypeHintChecker.""" - super().__init__(linter) - self._function_matchers: list[TypeHintMatch] = [] - self._class_matchers: list[ClassTypeHintMatch] = [] + _class_matchers: list[ClassTypeHintMatch] + _function_matchers: list[TypeHintMatch] + _module_name: str def visit_module(self, node: nodes.Module) -> None: """Populate matchers for a Module node.""" - self._function_matchers = [] self._class_matchers = [] + self._function_matchers = [] + self._module_name = node.name if (module_platform := _get_module_platform(node.name)) is None: return @@ -2985,6 +3001,8 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] matchers = _METHOD_MATCH else: matchers = self._function_matchers + if _is_test_function(self._module_name, node): + self._check_test_function(node, annotations) for match in matchers: if not match.need_to_check_function(node): continue @@ -3001,7 +3019,11 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] # Check that all positional arguments are correctly annotated. if match.arg_types: for key, expected_type in match.arg_types.items(): - if node.args.args[key].name in _COMMON_ARGUMENTS: + if ( + node.args.args[key].name in _COMMON_ARGUMENTS + or _is_test_function(self._module_name, node) + and node.args.args[key].name in _TEST_FIXTURES + ): # It has already been checked, avoid double-message continue if not _is_valid_type(expected_type, annotations[key]): @@ -3014,7 +3036,11 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] # Check that all keyword arguments are correctly annotated. if match.named_arg_types is not None: for arg_name, expected_type in match.named_arg_types.items(): - if arg_name in _COMMON_ARGUMENTS: + if ( + arg_name in _COMMON_ARGUMENTS + or _is_test_function(self._module_name, node) + and arg_name in _TEST_FIXTURES + ): # It has already been checked, avoid double-message continue arg_node, annotation = _get_named_annotation(node, arg_name) @@ -3043,6 +3069,26 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] args=(match.return_type or "None", node.name), ) + def _check_test_function( + self, node: nodes.FunctionDef, annotations: list[nodes.NodeNG | None] + ) -> None: + # Check the return type. + if not _is_valid_return_type(_TEST_FUNCTION_MATCH, node.returns): + self.add_message( + "hass-return-type", + node=node, + args=(_TEST_FUNCTION_MATCH.return_type or "None", node.name), + ) + # Check that all positional arguments are correctly annotated. + for arg_name, expected_type in _TEST_FIXTURES.items(): + arg_node, annotation = _get_named_annotation(node, arg_name) + if arg_node and not _is_valid_type(expected_type, annotation): + self.add_message( + "hass-argument-type", + node=arg_node, + args=(arg_name, expected_type, node.name), + ) + def register(linter: PyLinter) -> None: """Register the checker."""