Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions ci/tools/validate-release-wheels
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from __future__ import annotations

import argparse
import re
import sys
from collections import defaultdict
from pathlib import Path

from check_release_notes import parse_version_from_tag

COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
"cuda-core": {"cuda_core"},
"cuda-bindings": {"cuda_bindings"},
Expand All @@ -22,11 +23,13 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
"all": {"cuda_core", "cuda_bindings", "cuda_pathfinder", "cuda_python"},
}

TAG_PATTERNS = (
re.compile(r"^v(?P<version>\d+\.\d+\.\d+)"),
re.compile(r"^cuda-core-v(?P<version>\d+\.\d+\.\d+)"),
re.compile(r"^cuda-pathfinder-v(?P<version>\d+\.\d+\.\d+)"),
)
COMPONENT_TO_TAG_COMPONENTS: dict[str, tuple[str, ...]] = {
"cuda-core": ("cuda-core",),
"cuda-bindings": ("cuda-bindings",),
"cuda-pathfinder": ("cuda-pathfinder",),
"cuda-python": ("cuda-python",),
"all": ("cuda-core", "cuda-bindings", "cuda-pathfinder", "cuda-python"),
}


def parse_args() -> argparse.Namespace:
Expand All @@ -42,15 +45,18 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()


def version_from_tag(tag: str) -> str:
for pattern in TAG_PATTERNS:
match = pattern.match(tag)
if match:
return match.group("version")
def version_from_tag(tag: str, component: str) -> str:
versions = {
version
for tag_component in COMPONENT_TO_TAG_COMPONENTS[component]
if (version := parse_version_from_tag(tag, tag_component)) is not None
}
if len(versions) == 1:
return versions.pop()
raise ValueError(
"Unsupported git tag format "
f"{tag!r}; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, "
"or cuda-pathfinder-vX.Y.Z."
f"{tag!r} for component {component!r}; expected vX.Y.Z[.postN], "
"cuda-core-vX.Y.Z[.postN], or cuda-pathfinder-vX.Y.Z[.postN]."
)


Expand All @@ -64,7 +70,12 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]:

def main() -> int:
args = parse_args()
expected_version = version_from_tag(args.git_tag)
try:
expected_version = version_from_tag(args.git_tag, args.component)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1

expected_distributions = COMPONENT_TO_DISTRIBUTIONS[args.component]
wheel_dir = Path(args.wheel_dir)

Expand Down