Add type hints to requirements script (#82075)
This commit is contained in:
parent
1582d88957
commit
0538154767
1 changed files with 44 additions and 26 deletions
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Generate an updated requirements_all.txt."""
|
"""Generate updated constraint and requirements files."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
@ -7,6 +9,7 @@ from pathlib import Path
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.util.yaml.loader import load_yaml
|
from homeassistant.util.yaml.loader import load_yaml
|
||||||
from script.hassfest.model import Integration
|
from script.hassfest.model import Integration
|
||||||
|
@ -157,7 +160,7 @@ IGNORE_PRE_COMMIT_HOOK_ID = (
|
||||||
PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$")
|
PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$")
|
||||||
|
|
||||||
|
|
||||||
def has_tests(module: str):
|
def has_tests(module: str) -> bool:
|
||||||
"""Test if a module has tests.
|
"""Test if a module has tests.
|
||||||
|
|
||||||
Module format: homeassistant.components.hue
|
Module format: homeassistant.components.hue
|
||||||
|
@ -169,11 +172,11 @@ def has_tests(module: str):
|
||||||
return path.exists()
|
return path.exists()
|
||||||
|
|
||||||
|
|
||||||
def explore_module(package, explore_children):
|
def explore_module(package: str, explore_children: bool) -> list[str]:
|
||||||
"""Explore the modules."""
|
"""Explore the modules."""
|
||||||
module = importlib.import_module(package)
|
module = importlib.import_module(package)
|
||||||
|
|
||||||
found = []
|
found: list[str] = []
|
||||||
|
|
||||||
if not hasattr(module, "__path__"):
|
if not hasattr(module, "__path__"):
|
||||||
return found
|
return found
|
||||||
|
@ -187,14 +190,17 @@ def explore_module(package, explore_children):
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
def core_requirements():
|
def core_requirements() -> list[str]:
|
||||||
"""Gather core requirements out of pyproject.toml."""
|
"""Gather core requirements out of pyproject.toml."""
|
||||||
with open("pyproject.toml", "rb") as fp:
|
with open("pyproject.toml", "rb") as fp:
|
||||||
data = tomllib.load(fp)
|
data = tomllib.load(fp)
|
||||||
return data["project"]["dependencies"]
|
dependencies: list[str] = data["project"]["dependencies"]
|
||||||
|
return dependencies
|
||||||
|
|
||||||
|
|
||||||
def gather_recursive_requirements(domain, seen=None):
|
def gather_recursive_requirements(
|
||||||
|
domain: str, seen: set[str] | None = None
|
||||||
|
) -> set[str]:
|
||||||
"""Recursively gather requirements from a module."""
|
"""Recursively gather requirements from a module."""
|
||||||
if seen is None:
|
if seen is None:
|
||||||
seen = set()
|
seen = set()
|
||||||
|
@ -221,18 +227,18 @@ def normalize_package_name(requirement: str) -> str:
|
||||||
return package
|
return package
|
||||||
|
|
||||||
|
|
||||||
def comment_requirement(req):
|
def comment_requirement(req: str) -> bool:
|
||||||
"""Comment out requirement. Some don't install on all systems."""
|
"""Comment out requirement. Some don't install on all systems."""
|
||||||
return any(
|
return any(
|
||||||
normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED
|
normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def gather_modules():
|
def gather_modules() -> dict[str, list[str]] | None:
|
||||||
"""Collect the information."""
|
"""Collect the information."""
|
||||||
reqs = {}
|
reqs: dict[str, list[str]] = {}
|
||||||
|
|
||||||
errors = []
|
errors: list[str] = []
|
||||||
|
|
||||||
gather_requirements_from_manifests(errors, reqs)
|
gather_requirements_from_manifests(errors, reqs)
|
||||||
gather_requirements_from_modules(errors, reqs)
|
gather_requirements_from_modules(errors, reqs)
|
||||||
|
@ -248,7 +254,9 @@ def gather_modules():
|
||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
def gather_requirements_from_manifests(errors, reqs):
|
def gather_requirements_from_manifests(
|
||||||
|
errors: list[str], reqs: dict[str, list[str]]
|
||||||
|
) -> None:
|
||||||
"""Gather all of the requirements from manifests."""
|
"""Gather all of the requirements from manifests."""
|
||||||
integrations = Integration.load_dir(Path("homeassistant/components"))
|
integrations = Integration.load_dir(Path("homeassistant/components"))
|
||||||
for domain in sorted(integrations):
|
for domain in sorted(integrations):
|
||||||
|
@ -266,7 +274,9 @@ def gather_requirements_from_manifests(errors, reqs):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def gather_requirements_from_modules(errors, reqs):
|
def gather_requirements_from_modules(
|
||||||
|
errors: list[str], reqs: dict[str, list[str]]
|
||||||
|
) -> None:
|
||||||
"""Collect the requirements from the modules directly."""
|
"""Collect the requirements from the modules directly."""
|
||||||
for package in sorted(
|
for package in sorted(
|
||||||
explore_module("homeassistant.scripts", True)
|
explore_module("homeassistant.scripts", True)
|
||||||
|
@ -283,7 +293,12 @@ def gather_requirements_from_modules(errors, reqs):
|
||||||
process_requirements(errors, module.REQUIREMENTS, package, reqs)
|
process_requirements(errors, module.REQUIREMENTS, package, reqs)
|
||||||
|
|
||||||
|
|
||||||
def process_requirements(errors, module_requirements, package, reqs):
|
def process_requirements(
|
||||||
|
errors: list[str],
|
||||||
|
module_requirements: list[str],
|
||||||
|
package: str,
|
||||||
|
reqs: dict[str, list[str]],
|
||||||
|
) -> None:
|
||||||
"""Process all of the requirements."""
|
"""Process all of the requirements."""
|
||||||
for req in module_requirements:
|
for req in module_requirements:
|
||||||
if "://" in req:
|
if "://" in req:
|
||||||
|
@ -293,7 +308,7 @@ def process_requirements(errors, module_requirements, package, reqs):
|
||||||
reqs.setdefault(req, []).append(package)
|
reqs.setdefault(req, []).append(package)
|
||||||
|
|
||||||
|
|
||||||
def generate_requirements_list(reqs):
|
def generate_requirements_list(reqs: dict[str, list[str]]) -> str:
|
||||||
"""Generate a pip file based on requirements."""
|
"""Generate a pip file based on requirements."""
|
||||||
output = []
|
output = []
|
||||||
for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]):
|
for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]):
|
||||||
|
@ -307,7 +322,7 @@ def generate_requirements_list(reqs):
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def requirements_output(reqs):
|
def requirements_output() -> str:
|
||||||
"""Generate output for requirements."""
|
"""Generate output for requirements."""
|
||||||
output = [
|
output = [
|
||||||
"-c homeassistant/package_constraints.txt\n",
|
"-c homeassistant/package_constraints.txt\n",
|
||||||
|
@ -320,7 +335,7 @@ def requirements_output(reqs):
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def requirements_all_output(reqs):
|
def requirements_all_output(reqs: dict[str, list[str]]) -> str:
|
||||||
"""Generate output for requirements_all."""
|
"""Generate output for requirements_all."""
|
||||||
output = [
|
output = [
|
||||||
"# Home Assistant Core, full dependency set\n",
|
"# Home Assistant Core, full dependency set\n",
|
||||||
|
@ -331,7 +346,7 @@ def requirements_all_output(reqs):
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def requirements_test_all_output(reqs):
|
def requirements_test_all_output(reqs: dict[str, list[str]]) -> str:
|
||||||
"""Generate output for test_requirements."""
|
"""Generate output for test_requirements."""
|
||||||
output = [
|
output = [
|
||||||
"# Home Assistant tests, full dependency set\n",
|
"# Home Assistant tests, full dependency set\n",
|
||||||
|
@ -356,15 +371,18 @@ def requirements_test_all_output(reqs):
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def requirements_pre_commit_output():
|
def requirements_pre_commit_output() -> str:
|
||||||
"""Generate output for pre-commit dependencies."""
|
"""Generate output for pre-commit dependencies."""
|
||||||
source = ".pre-commit-config.yaml"
|
source = ".pre-commit-config.yaml"
|
||||||
pre_commit_conf = load_yaml(source)
|
pre_commit_conf: dict[str, list[dict[str, Any]]]
|
||||||
reqs = []
|
pre_commit_conf = load_yaml(source) # type: ignore[assignment]
|
||||||
|
reqs: list[str] = []
|
||||||
|
hook: dict[str, Any]
|
||||||
for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")):
|
for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")):
|
||||||
|
rev: str = repo["rev"]
|
||||||
for hook in repo["hooks"]:
|
for hook in repo["hooks"]:
|
||||||
if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID:
|
if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID:
|
||||||
reqs.append(f"{hook['id']}=={repo['rev'].lstrip('v')}")
|
reqs.append(f"{hook['id']}=={rev.lstrip('v')}")
|
||||||
reqs.extend(x for x in hook.get("additional_dependencies", ()))
|
reqs.extend(x for x in hook.get("additional_dependencies", ()))
|
||||||
output = [
|
output = [
|
||||||
f"# Automatically generated "
|
f"# Automatically generated "
|
||||||
|
@ -375,7 +393,7 @@ def requirements_pre_commit_output():
|
||||||
return "\n".join(output) + "\n"
|
return "\n".join(output) + "\n"
|
||||||
|
|
||||||
|
|
||||||
def gather_constraints():
|
def gather_constraints() -> str:
|
||||||
"""Construct output for constraint file."""
|
"""Construct output for constraint file."""
|
||||||
return (
|
return (
|
||||||
"\n".join(
|
"\n".join(
|
||||||
|
@ -392,7 +410,7 @@ def gather_constraints():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def diff_file(filename, content):
|
def diff_file(filename: str, content: str) -> list[str]:
|
||||||
"""Diff a file."""
|
"""Diff a file."""
|
||||||
return list(
|
return list(
|
||||||
difflib.context_diff(
|
difflib.context_diff(
|
||||||
|
@ -404,7 +422,7 @@ def diff_file(filename, content):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(validate):
|
def main(validate: bool) -> int:
|
||||||
"""Run the script."""
|
"""Run the script."""
|
||||||
if not os.path.isfile("requirements_all.txt"):
|
if not os.path.isfile("requirements_all.txt"):
|
||||||
print("Run this from HA root dir")
|
print("Run this from HA root dir")
|
||||||
|
@ -415,7 +433,7 @@ def main(validate):
|
||||||
if data is None:
|
if data is None:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
reqs_file = requirements_output(data)
|
reqs_file = requirements_output()
|
||||||
reqs_all_file = requirements_all_output(data)
|
reqs_all_file = requirements_all_output(data)
|
||||||
reqs_test_all_file = requirements_test_all_output(data)
|
reqs_test_all_file = requirements_test_all_output(data)
|
||||||
reqs_pre_commit_file = requirements_pre_commit_output()
|
reqs_pre_commit_file = requirements_pre_commit_output()
|
||||||
|
|
Loading…
Add table
Reference in a new issue