From 3eafe13085444a5f29b57c3a43740eff5be36f35 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 2 Aug 2022 00:03:52 +0200 Subject: [PATCH] Improve UI in pylint plugin (#74157) * Adjust FlowResult result type * Adjust tests * Adjust return_type * Use StrEnum for base device_class * Add test for device_class * Add and use SentinelValues.DEVICE_CLASS * Remove duplicate device_class * Cleanup return-type * Drop inheritance check from device_class * Add caching for class methods * Improve tests * Adjust duplicate checks * Adjust tests * Fix rebase --- pylint/plugins/hass_enforce_type_hints.py | 29 ++--- tests/pylint/test_enforce_type_hints.py | 131 +++++++++++++++++----- 2 files changed, 116 insertions(+), 44 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 551db458b1d..d0d20cedd7c 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -16,7 +16,6 @@ class _Special(Enum): """Sentinel values""" UNDEFINED = 1 - DEVICE_CLASS = 2 _PLATFORMS: set[str] = {platform.value for platform in Platform} @@ -466,6 +465,7 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { } # Overriding properties and functions are normally checked by mypy, and will only # be checked by pylint when --ignore-missing-annotations is False + _ENTITY_MATCH: list[TypeHintMatch] = [ TypeHintMatch( function_name="should_poll", @@ -505,7 +505,7 @@ _ENTITY_MATCH: list[TypeHintMatch] = [ ), TypeHintMatch( function_name="device_class", - return_type=[_Special.DEVICE_CLASS, "str", None], + return_type=["str", None], ), TypeHintMatch( function_name="unit_of_measurement", @@ -1416,15 +1416,6 @@ def _is_valid_type( if expected_type is _Special.UNDEFINED: return True - # Special case for device_class - if expected_type is _Special.DEVICE_CLASS and in_return: - return ( - isinstance(node, nodes.Name) - and node.name.endswith("DeviceClass") - or isinstance(node, nodes.Attribute) - and node.attrname.endswith("DeviceClass") - ) - if isinstance(expected_type, list): for expected_type_item in expected_type: if _is_valid_type(expected_type_item, node, in_return): @@ -1636,18 +1627,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] def visit_classdef(self, node: nodes.ClassDef) -> None: """Called when a ClassDef node is visited.""" ancestor: nodes.ClassDef + checked_class_methods: set[str] = set() for ancestor in node.ancestors(): for class_matches in self._class_matchers: if ancestor.name == class_matches.base_class: - self._visit_class_functions(node, class_matches.matches) + self._visit_class_functions( + node, class_matches.matches, checked_class_methods + ) def _visit_class_functions( - self, node: nodes.ClassDef, matches: list[TypeHintMatch] + self, + node: nodes.ClassDef, + matches: list[TypeHintMatch], + checked_class_methods: set[str], ) -> None: + cached_methods: list[nodes.FunctionDef] = list(node.mymethods()) for match in matches: - for function_node in node.mymethods(): + for function_node in cached_methods: + if function_node.name in checked_class_methods: + continue if match.need_to_check_function(function_node): self._check_function(function_node, match) + checked_class_methods.add(function_node.name) def visit_functiondef(self, node: nodes.FunctionDef) -> None: """Called when a FunctionDef node is visited.""" diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 53c17880716..d9edde9fdee 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -307,7 +307,10 @@ def test_invalid_config_flow_step( """Ensure invalid hints are rejected for ConfigFlow step.""" class_node, func_node, arg_node = astroid.extract_node( """ - class ConfigFlow(): + class FlowHandler(): + pass + + class ConfigFlow(FlowHandler): pass class AxisFlowHandler( #@ @@ -329,18 +332,18 @@ def test_invalid_config_flow_step( msg_id="hass-argument-type", node=arg_node, args=(2, "ZeroconfServiceInfo", "async_step_zeroconf"), - line=10, + line=13, col_offset=8, - end_line=10, + end_line=13, end_col_offset=27, ), pylint.testutils.MessageTest( msg_id="hass-return-type", node=func_node, args=("FlowResult", "async_step_zeroconf"), - line=8, + line=11, col_offset=4, - end_line=8, + end_line=11, end_col_offset=33, ), ): @@ -353,7 +356,10 @@ def test_valid_config_flow_step( """Ensure valid hints are accepted for ConfigFlow step.""" class_node = astroid.extract_node( """ - class ConfigFlow(): + class FlowHandler(): + pass + + class ConfigFlow(FlowHandler): pass class AxisFlowHandler( #@ @@ -377,9 +383,16 @@ def test_invalid_config_flow_async_get_options_flow( linter: UnittestLinter, type_hint_checker: BaseChecker ) -> None: """Ensure invalid hints are rejected for ConfigFlow async_get_options_flow.""" + # AxisOptionsFlow doesn't inherit OptionsFlow, and therefore should fail class_node, func_node, arg_node = astroid.extract_node( """ - class ConfigFlow(): + class FlowHandler(): + pass + + class ConfigFlow(FlowHandler): + pass + + class OptionsFlow(FlowHandler): pass class AxisOptionsFlow(): @@ -403,18 +416,18 @@ def test_invalid_config_flow_async_get_options_flow( msg_id="hass-argument-type", node=arg_node, args=(1, "ConfigEntry", "async_get_options_flow"), - line=12, + line=18, col_offset=8, - end_line=12, + end_line=18, end_col_offset=20, ), pylint.testutils.MessageTest( msg_id="hass-return-type", node=func_node, args=("OptionsFlow", "async_get_options_flow"), - line=11, + line=17, col_offset=4, - end_line=11, + end_line=17, end_col_offset=30, ), ): @@ -427,10 +440,13 @@ def test_valid_config_flow_async_get_options_flow( """Ensure valid hints are accepted for ConfigFlow async_get_options_flow.""" class_node = astroid.extract_node( """ - class ConfigFlow(): + class FlowHandler(): pass - class OptionsFlow(): + class ConfigFlow(FlowHandler): + pass + + class OptionsFlow(FlowHandler): pass class AxisOptionsFlow(OptionsFlow): @@ -467,7 +483,10 @@ def test_invalid_entity_properties( class_node, prop_node, func_node = astroid.extract_node( """ - class LockEntity(): + class Entity(): + pass + + class LockEntity(Entity): pass class DoorLock( #@ @@ -495,27 +514,27 @@ def test_invalid_entity_properties( msg_id="hass-return-type", node=prop_node, args=(["str", None], "changed_by"), - line=9, + line=12, col_offset=4, - end_line=9, + end_line=12, end_col_offset=18, ), pylint.testutils.MessageTest( msg_id="hass-argument-type", node=func_node, args=("kwargs", "Any", "async_lock"), - line=14, + line=17, col_offset=4, - end_line=14, + end_line=17, end_col_offset=24, ), pylint.testutils.MessageTest( msg_id="hass-return-type", node=func_node, args=("None", "async_lock"), - line=14, + line=17, col_offset=4, - end_line=14, + end_line=17, end_col_offset=24, ), ): @@ -531,7 +550,10 @@ def test_ignore_invalid_entity_properties( class_node = astroid.extract_node( """ - class LockEntity(): + class Entity(): + pass + + class LockEntity(Entity): pass class DoorLock( #@ @@ -566,7 +588,13 @@ def test_named_arguments( class_node, func_node, percentage_node, preset_mode_node = astroid.extract_node( """ - class FanEntity(): + class Entity(): + pass + + class ToggleEntity(Entity): + pass + + class FanEntity(ToggleEntity): pass class MyFan( #@ @@ -591,36 +619,36 @@ def test_named_arguments( msg_id="hass-argument-type", node=percentage_node, args=("percentage", "int | None", "async_turn_on"), - line=10, + line=16, col_offset=8, - end_line=10, + end_line=16, end_col_offset=18, ), pylint.testutils.MessageTest( msg_id="hass-argument-type", node=preset_mode_node, args=("preset_mode", "str | None", "async_turn_on"), - line=12, + line=18, col_offset=8, - end_line=12, + end_line=18, end_col_offset=24, ), pylint.testutils.MessageTest( msg_id="hass-argument-type", node=func_node, args=("kwargs", "Any", "async_turn_on"), - line=8, + line=14, col_offset=4, - end_line=8, + end_line=14, end_col_offset=27, ), pylint.testutils.MessageTest( msg_id="hass-return-type", node=func_node, args=("None", "async_turn_on"), - line=8, + line=14, col_offset=4, - end_line=8, + end_line=14, end_col_offset=27, ), ): @@ -829,3 +857,46 @@ def test_invalid_long_tuple( ), ): type_hint_checker.visit_classdef(class_node) + + +def test_invalid_device_class( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure invalid hints are rejected for entity device_class.""" + # Set bypass option + type_hint_checker.config.ignore_missing_annotations = False + + class_node, prop_node = astroid.extract_node( + """ + class Entity(): + pass + + class CoverEntity(Entity): + pass + + class MyCover( #@ + CoverEntity + ): + @property + def device_class( #@ + self + ): + pass + """, + "homeassistant.components.pylint_test.cover", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=prop_node, + args=(["CoverDeviceClass", "str", None], "device_class"), + line=12, + col_offset=4, + end_line=12, + end_col_offset=20, + ), + ): + type_hint_checker.visit_classdef(class_node)