Enforce type hints on device_automation platform (#72126)

This commit is contained in:
epenet 2022-05-23 18:51:40 +02:00 committed by GitHub
parent 3cdc5c8429
commit f25663067c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 221 additions and 12 deletions

View file

@ -31,6 +31,8 @@ _TYPE_HINT_MATCHERS: dict[str, re.Pattern] = {
"x_of_y": re.compile(r"^(\w+)\[(.*?]*)\]$"),
# x_of_y_comma_z matches items such as "Callable[..., Awaitable[None]]"
"x_of_y_comma_z": re.compile(r"^(\w+)\[(.*?]*), (.*?]*)\]$"),
# x_of_y_of_z_comma_a matches items such as "list[dict[str, Any]]"
"x_of_y_of_z_comma_a": re.compile(r"^(\w+)\[(\w+)\[(.*?]*), (.*?]*)\]\]$"),
}
_MODULE_FILTERS: dict[str, re.Pattern] = {
@ -44,12 +46,20 @@ _MODULE_FILTERS: dict[str, re.Pattern] = {
"application_credentials": re.compile(
r"^homeassistant\.components\.\w+\.(application_credentials)$"
),
# device_tracker matches only in the package root (device_tracker.py)
"device_tracker": re.compile(r"^homeassistant\.components\.\w+\.(device_tracker)$"),
# diagnostics matches only in the package root (diagnostics.py)
"diagnostics": re.compile(r"^homeassistant\.components\.\w+\.(diagnostics)$"),
# config_flow matches only in the package root (config_flow.py)
"config_flow": re.compile(r"^homeassistant\.components\.\w+\.(config_flow)$"),
# device_action matches only in the package root (device_action.py)
"device_action": re.compile(r"^homeassistant\.components\.\w+\.(device_action)$"),
# device_condition matches only in the package root (device_condition.py)
"device_condition": re.compile(
r"^homeassistant\.components\.\w+\.(device_condition)$"
),
# device_tracker matches only in the package root (device_tracker.py)
"device_tracker": re.compile(r"^homeassistant\.components\.\w+\.(device_tracker)$"),
# device_trigger matches only in the package root (device_trigger.py)
"device_trigger": re.compile(r"^homeassistant\.components\.\w+\.(device_trigger)$"),
# diagnostics matches only in the package root (diagnostics.py)
"diagnostics": re.compile(r"^homeassistant\.components\.\w+\.(diagnostics)$"),
}
_METHOD_MATCH: list[TypeHintMatch] = [
@ -157,6 +167,88 @@ _METHOD_MATCH: list[TypeHintMatch] = [
},
return_type="AuthorizationServer",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["config_flow"],
function_name="_async_has_devices",
arg_types={
0: "HomeAssistant",
},
return_type="bool",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_action"],
function_name="async_validate_action_config",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="ConfigType",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_action"],
function_name="async_call_action_from_config",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
2: "TemplateVarsType",
3: "Context | None",
},
return_type=None,
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_action"],
function_name="async_get_action_capabilities",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="dict[str, Schema]",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_action"],
function_name="async_get_actions",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=["list[dict[str, str]]", "list[dict[str, Any]]"],
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_condition"],
function_name="async_validate_condition_config",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="ConfigType",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_condition"],
function_name="async_condition_from_config",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="ConditionCheckerType",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_condition"],
function_name="async_get_condition_capabilities",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="dict[str, Schema]",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_condition"],
function_name="async_get_conditions",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=["list[dict[str, str]]", "list[dict[str, Any]]"],
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_tracker"],
function_name="setup_scanner",
@ -197,6 +289,44 @@ _METHOD_MATCH: list[TypeHintMatch] = [
},
return_type=["DeviceScanner", "DeviceScanner | None"],
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_trigger"],
function_name="async_validate_condition_config",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="ConfigType",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_trigger"],
function_name="async_attach_trigger",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
2: "AutomationActionType",
3: "AutomationTriggerInfo",
},
return_type="CALLBACK_TYPE",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_trigger"],
function_name="async_get_trigger_capabilities",
arg_types={
0: "HomeAssistant",
1: "ConfigType",
},
return_type="dict[str, Schema]",
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["device_trigger"],
function_name="async_get_triggers",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=["list[dict[str, str]]", "list[dict[str, Any]]"],
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["diagnostics"],
function_name="async_get_config_entry_diagnostics",
@ -216,14 +346,6 @@ _METHOD_MATCH: list[TypeHintMatch] = [
},
return_type=UNDEFINED,
),
TypeHintMatch(
module_filter=_MODULE_FILTERS["config_flow"],
function_name="_async_has_devices",
arg_types={
0: "HomeAssistant",
},
return_type="bool",
),
]
@ -254,6 +376,18 @@ def _is_valid_type(expected_type: list[str] | str | None, node: astroid.NodeNG)
and _is_valid_type(match.group(2), node.right)
)
# Special case for xxx[yyy[zzz, aaa]]`
if match := _TYPE_HINT_MATCHERS["x_of_y_of_z_comma_a"].match(expected_type):
return (
isinstance(node, astroid.Subscript)
and _is_valid_type(match.group(1), node.value)
and isinstance(subnode := node.slice, astroid.Subscript)
and _is_valid_type(match.group(2), subnode.value)
and isinstance(subnode.slice, astroid.Tuple)
and _is_valid_type(match.group(3), subnode.slice.elts[0])
and _is_valid_type(match.group(4), subnode.slice.elts[1])
)
# Special case for xxx[yyy, zzz]`
if match := _TYPE_HINT_MATCHERS["x_of_y_comma_z"].match(expected_type):
return (

View file

@ -14,6 +14,32 @@ import pytest
from . import assert_adds_messages, assert_no_messages
@pytest.mark.parametrize(
("string", "expected_x", "expected_y", "expected_z", "expected_a"),
[
("list[dict[str, str]]", "list", "dict", "str", "str"),
("list[dict[str, Any]]", "list", "dict", "str", "Any"),
],
)
def test_regex_x_of_y_of_z_comma_a(
hass_enforce_type_hints: ModuleType,
string: str,
expected_x: str,
expected_y: str,
expected_z: str,
expected_a: str,
) -> None:
"""Test x_of_y_of_z_comma_a regexes."""
matchers: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_MATCHERS
assert (match := matchers["x_of_y_of_z_comma_a"].match(string))
assert match.group(0) == string
assert match.group(1) == expected_x
assert match.group(2) == expected_y
assert match.group(3) == expected_z
assert match.group(4) == expected_a
@pytest.mark.parametrize(
("string", "expected_x", "expected_y", "expected_z"),
[
@ -165,3 +191,52 @@ def test_valid_discovery_info(
with assert_no_messages(linter):
type_hint_checker.visit_asyncfunctiondef(func_node)
def test_invalid_list_dict_str_any(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure invalid hints are rejected for discovery_info."""
type_hint_checker.module = "homeassistant.components.pylint_test.device_trigger"
func_node = astroid.extract_node(
"""
async def async_get_triggers( #@
hass: HomeAssistant,
device_id: str
) -> list:
pass
"""
)
with assert_adds_messages(
linter,
pylint.testutils.MessageTest(
msg_id="hass-return-type",
node=func_node,
args=["list[dict[str, str]]", "list[dict[str, Any]]"],
line=2,
col_offset=0,
end_line=2,
end_col_offset=28,
),
):
type_hint_checker.visit_asyncfunctiondef(func_node)
def test_valid_list_dict_str_any(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure valid hints are accepted for discovery_info."""
type_hint_checker.module = "homeassistant.components.pylint_test.device_trigger"
func_node = astroid.extract_node(
"""
async def async_get_triggers( #@
hass: HomeAssistant,
device_id: str
) -> list[dict[str, Any]]:
pass
"""
)
with assert_no_messages(linter):
type_hint_checker.visit_asyncfunctiondef(func_node)