From 06251d403a0ab4194931267b3a7729a889ca6571 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Thu, 30 May 2024 10:41:32 +0200 Subject: [PATCH] Fix special case in pylint type hint plugin (#118454) * Fix special case in pylint type hint plugin * Simplify * Simplify * Simplify * Apply Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --------- Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- pylint/plugins/hass_enforce_type_hints.py | 6 +++++- tests/pylint/test_enforce_type_hints.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 6d3b68cbeb6..0fc522f46c2 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -69,7 +69,7 @@ class ClassTypeHintMatch: matches: list[TypeHintMatch] -_INNER_MATCH = r"((?:[\w\| ]+)|(?:\.{3})|(?:\w+\[.+\]))" +_INNER_MATCH = r"((?:[\w\| ]+)|(?:\.{3})|(?:\w+\[.+\])|(?:\[\]))" _TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = { # a_or_b matches items such as "DiscoveryInfoType | None" # or "dict | list | None" @@ -2914,6 +2914,10 @@ def _is_valid_type( if expected_type == "...": return isinstance(node, nodes.Const) and node.value == Ellipsis + # Special case for an empty list, such as Callable[[], TestServer] + if expected_type == "[]": + return isinstance(node, nodes.List) and not node.elts + # Special case for `xxx | yyy` if match := _TYPE_HINT_MATCHERS["a_or_b"].match(expected_type): return ( diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 64dd472827e..0153214c267 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -54,6 +54,7 @@ def test_regex_get_module_platform( ("list[dict[str, str]]", 1, ("list", "dict[str, str]")), ("list[dict[str, Any]]", 1, ("list", "dict[str, Any]")), ("tuple[bytes | None, str | None]", 2, ("tuple", "bytes | None", "str | None")), + ("Callable[[], TestServer]", 2, ("Callable", "[]", "TestServer")), ], ) def test_regex_x_of_y_i( @@ -1130,12 +1131,14 @@ def test_notify_get_service( def test_pytest_function( linter: UnittestLinter, type_hint_checker: BaseChecker ) -> None: - """Ensure valid hints are accepted for async_get_service.""" + """Ensure valid hints are accepted for a test function.""" func_node = astroid.extract_node( """ async def test_sample( #@ hass: HomeAssistant, caplog: pytest.LogCaptureFixture, + aiohttp_server: Callable[[], TestServer], + unused_tcp_port_factory: Callable[[], int], ) -> None: pass """,