Add type hints to requirements script (#82075)

This commit is contained in:
epenet 2022-11-16 13:00:35 +01:00 committed by GitHub
parent 1582d88957
commit 0538154767
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Generate an updated requirements_all.txt.""" """Generate updated constraint and requirements files."""
from __future__ import annotations
import difflib import difflib
import importlib import importlib
import os import os
@ -7,6 +9,7 @@ from pathlib import Path
import pkgutil import pkgutil
import re import re
import sys import sys
from typing import Any
from homeassistant.util.yaml.loader import load_yaml from homeassistant.util.yaml.loader import load_yaml
from script.hassfest.model import Integration from script.hassfest.model import Integration
@ -157,7 +160,7 @@ IGNORE_PRE_COMMIT_HOOK_ID = (
PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$") PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$")
def has_tests(module: str): def has_tests(module: str) -> bool:
"""Test if a module has tests. """Test if a module has tests.
Module format: homeassistant.components.hue Module format: homeassistant.components.hue
@ -169,11 +172,11 @@ def has_tests(module: str):
return path.exists() return path.exists()
def explore_module(package, explore_children): def explore_module(package: str, explore_children: bool) -> list[str]:
"""Explore the modules.""" """Explore the modules."""
module = importlib.import_module(package) module = importlib.import_module(package)
found = [] found: list[str] = []
if not hasattr(module, "__path__"): if not hasattr(module, "__path__"):
return found return found
@ -187,14 +190,17 @@ def explore_module(package, explore_children):
return found return found
def core_requirements(): def core_requirements() -> list[str]:
"""Gather core requirements out of pyproject.toml.""" """Gather core requirements out of pyproject.toml."""
with open("pyproject.toml", "rb") as fp: with open("pyproject.toml", "rb") as fp:
data = tomllib.load(fp) data = tomllib.load(fp)
return data["project"]["dependencies"] dependencies: list[str] = data["project"]["dependencies"]
return dependencies
def gather_recursive_requirements(domain, seen=None): def gather_recursive_requirements(
domain: str, seen: set[str] | None = None
) -> set[str]:
"""Recursively gather requirements from a module.""" """Recursively gather requirements from a module."""
if seen is None: if seen is None:
seen = set() seen = set()
@ -221,18 +227,18 @@ def normalize_package_name(requirement: str) -> str:
return package return package
def comment_requirement(req): def comment_requirement(req: str) -> bool:
"""Comment out requirement. Some don't install on all systems.""" """Comment out requirement. Some don't install on all systems."""
return any( return any(
normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED
) )
def gather_modules(): def gather_modules() -> dict[str, list[str]] | None:
"""Collect the information.""" """Collect the information."""
reqs = {} reqs: dict[str, list[str]] = {}
errors = [] errors: list[str] = []
gather_requirements_from_manifests(errors, reqs) gather_requirements_from_manifests(errors, reqs)
gather_requirements_from_modules(errors, reqs) gather_requirements_from_modules(errors, reqs)
@ -248,7 +254,9 @@ def gather_modules():
return reqs return reqs
def gather_requirements_from_manifests(errors, reqs): def gather_requirements_from_manifests(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Gather all of the requirements from manifests.""" """Gather all of the requirements from manifests."""
integrations = Integration.load_dir(Path("homeassistant/components")) integrations = Integration.load_dir(Path("homeassistant/components"))
for domain in sorted(integrations): for domain in sorted(integrations):
@ -266,7 +274,9 @@ def gather_requirements_from_manifests(errors, reqs):
) )
def gather_requirements_from_modules(errors, reqs): def gather_requirements_from_modules(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Collect the requirements from the modules directly.""" """Collect the requirements from the modules directly."""
for package in sorted( for package in sorted(
explore_module("homeassistant.scripts", True) explore_module("homeassistant.scripts", True)
@ -283,7 +293,12 @@ def gather_requirements_from_modules(errors, reqs):
process_requirements(errors, module.REQUIREMENTS, package, reqs) process_requirements(errors, module.REQUIREMENTS, package, reqs)
def process_requirements(errors, module_requirements, package, reqs): def process_requirements(
errors: list[str],
module_requirements: list[str],
package: str,
reqs: dict[str, list[str]],
) -> None:
"""Process all of the requirements.""" """Process all of the requirements."""
for req in module_requirements: for req in module_requirements:
if "://" in req: if "://" in req:
@ -293,7 +308,7 @@ def process_requirements(errors, module_requirements, package, reqs):
reqs.setdefault(req, []).append(package) reqs.setdefault(req, []).append(package)
def generate_requirements_list(reqs): def generate_requirements_list(reqs: dict[str, list[str]]) -> str:
"""Generate a pip file based on requirements.""" """Generate a pip file based on requirements."""
output = [] output = []
for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]): for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]):
@ -307,7 +322,7 @@ def generate_requirements_list(reqs):
return "".join(output) return "".join(output)
def requirements_output(reqs): def requirements_output() -> str:
"""Generate output for requirements.""" """Generate output for requirements."""
output = [ output = [
"-c homeassistant/package_constraints.txt\n", "-c homeassistant/package_constraints.txt\n",
@ -320,7 +335,7 @@ def requirements_output(reqs):
return "".join(output) return "".join(output)
def requirements_all_output(reqs): def requirements_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for requirements_all.""" """Generate output for requirements_all."""
output = [ output = [
"# Home Assistant Core, full dependency set\n", "# Home Assistant Core, full dependency set\n",
@ -331,7 +346,7 @@ def requirements_all_output(reqs):
return "".join(output) return "".join(output)
def requirements_test_all_output(reqs): def requirements_test_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for test_requirements.""" """Generate output for test_requirements."""
output = [ output = [
"# Home Assistant tests, full dependency set\n", "# Home Assistant tests, full dependency set\n",
@ -356,15 +371,18 @@ def requirements_test_all_output(reqs):
return "".join(output) return "".join(output)
def requirements_pre_commit_output(): def requirements_pre_commit_output() -> str:
"""Generate output for pre-commit dependencies.""" """Generate output for pre-commit dependencies."""
source = ".pre-commit-config.yaml" source = ".pre-commit-config.yaml"
pre_commit_conf = load_yaml(source) pre_commit_conf: dict[str, list[dict[str, Any]]]
reqs = [] pre_commit_conf = load_yaml(source) # type: ignore[assignment]
reqs: list[str] = []
hook: dict[str, Any]
for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")): for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")):
rev: str = repo["rev"]
for hook in repo["hooks"]: for hook in repo["hooks"]:
if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID: if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID:
reqs.append(f"{hook['id']}=={repo['rev'].lstrip('v')}") reqs.append(f"{hook['id']}=={rev.lstrip('v')}")
reqs.extend(x for x in hook.get("additional_dependencies", ())) reqs.extend(x for x in hook.get("additional_dependencies", ()))
output = [ output = [
f"# Automatically generated " f"# Automatically generated "
@ -375,7 +393,7 @@ def requirements_pre_commit_output():
return "\n".join(output) + "\n" return "\n".join(output) + "\n"
def gather_constraints(): def gather_constraints() -> str:
"""Construct output for constraint file.""" """Construct output for constraint file."""
return ( return (
"\n".join( "\n".join(
@ -392,7 +410,7 @@ def gather_constraints():
) )
def diff_file(filename, content): def diff_file(filename: str, content: str) -> list[str]:
"""Diff a file.""" """Diff a file."""
return list( return list(
difflib.context_diff( difflib.context_diff(
@ -404,7 +422,7 @@ def diff_file(filename, content):
) )
def main(validate): def main(validate: bool) -> int:
"""Run the script.""" """Run the script."""
if not os.path.isfile("requirements_all.txt"): if not os.path.isfile("requirements_all.txt"):
print("Run this from HA root dir") print("Run this from HA root dir")
@ -415,7 +433,7 @@ def main(validate):
if data is None: if data is None:
return 1 return 1
reqs_file = requirements_output(data) reqs_file = requirements_output()
reqs_all_file = requirements_all_output(data) reqs_all_file = requirements_all_output(data)
reqs_test_all_file = requirements_test_all_output(data) reqs_test_all_file = requirements_test_all_output(data)
reqs_pre_commit_file = requirements_pre_commit_output() reqs_pre_commit_file = requirements_pre_commit_output()