Skip to content
Open
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/vincent-petithory/dataurl v1.0.0
github.com/xeipuuv/gojsonschema v1.2.0
github.com/xeonx/timeago v1.0.0-rc5
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
Expand Down Expand Up @@ -271,6 +270,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/net v0.39.0 // indirect
Expand Down
68 changes: 47 additions & 21 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,29 +273,37 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
return images[0].ImageTag(), nil
}

func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
func tfGPUPackage(ver string, cuda string) (PythonRequirement, error) {
for _, compat := range TFCompatibilityMatrix {
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
name, cpuVersion, _, _, err = SplitPinnedPythonRequirement(compat.TFGPUPackage)
return name, cpuVersion, err
if req := SplitPinnedPythonRequirement(compat.TFGPUPackage); !req.ParsedFieldsValid {
return PythonRequirement{}, fmt.Errorf("Invalid Python requirement for %s version %s", ver, cuda)
} else {
return req, nil
}
}
}
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA(), so fail silently
return "", "", nil
return PythonRequirement{}, nil
}

func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchCPUPackage(ver, goos, goarch string) (req PythonRequirement, err error) {
req.Name = "torch"
req.Version = ver
req.ParsedFieldsValid = true

// The default is to just install the default version. For older pytorch versions, they don't have any CPU versions.
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchVersion() == ver && compat.CUDA == nil {
return "torch", torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
req.Version = torchStripCPUSuffixForM1(compat.Torch, goos, goarch)
req.FindLinks = []string{compat.FindLinks}
req.ExtraIndexURLs = []string{compat.ExtraIndexURL}
}
}

// Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions.
return "torch", ver, "", "", nil
return
}

func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchGPUPackage(ver string, cuda string) (req PythonRequirement, err error) {
// find the torch package that has the requested torch version and the latest cuda version
// that is at most as high as the requested cuda version
var latest *TorchCompatibility
Expand Down Expand Up @@ -324,25 +332,36 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
}
}
}
if latest == nil {
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
return "torch", ver, "", "", nil

req.Name = "torch"
req.ParsedFieldsValid = true
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
if latest != nil {
req.Version = version.StripModifier(latest.Torch)
req.FindLinks = []string{latest.FindLinks}
req.ExtraIndexURLs = []string{latest.ExtraIndexURL}
}

return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
return
}

func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchvisionCPUPackage(ver, goos, goarch string) (req PythonRequirement, err error) {
req.Name = "torchvision"
req.ParsedFieldsValid = true

// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
req.Version = ver
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
req.Version = torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch)
req.FindLinks = []string{compat.FindLinks}
req.ExtraIndexURLs = []string{compat.ExtraIndexURL}
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return "torchvision", ver, "", "", nil
return
}

func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchvisionGPUPackage(ver, cuda string) (req PythonRequirement, err error) {
// find the torchvision package that has the requested
// torchvision version and the latest cuda version that is at
// most as high as the requested cuda version
Expand Down Expand Up @@ -371,13 +390,20 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
}
}
}

req.Name = "torchvision"
req.ParsedFieldsValid = true
if latest == nil {
// TODO: can we suggest a CUDA version known to be compatible?
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver)
return "torchvision", ver, "", "", nil
req.Version = ver
} else {
req.Version = version.StripModifier(latest.Torchvision)
req.FindLinks = []string{latest.FindLinks}
req.ExtraIndexURLs = []string{latest.ExtraIndexURL}
}

return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
return
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
Expand Down
185 changes: 82 additions & 103 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,12 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,
return "", "", "", nil
}

func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
func (c *Config) pythonPackageVersion(name string) (string, bool) {
for _, pkg := range c.Build.pythonRequirementsContent {
pkgName, version, _, _, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// package is not in package==version format
continue
}
if pkgName == name {
return version, true
if req := SplitPinnedPythonRequirement(pkg); !req.ParsedFieldsValid {
return "", false
} else if req.Name == name {
return req.Version, true
}
}
return "", false
Expand Down Expand Up @@ -327,132 +324,114 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
}

// PythonRequirementsForArch returns a requirements.txt file with all the GPU packages resolved for given OS and architecture.
// The packages listed in c.Build.pythonRequirementsContent are user-supplied requirements. Packages listed in the
// `includePackages` parameter are defaults. The two sets are union'd together, with the user's own requirements
// taking precedence (version, find-links, etc) if there is a duplicate.
//
// The method will return the string content of the requirements file, or an error.
func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePackages []string) (string, error) {
packages := []string{}
findLinksSet := map[string]bool{}
extraIndexURLSet := map[string]bool{}

includePackageNames := []string{}
for _, pkg := range includePackages {
packageName, err := PackageName(pkg)
// First, parse all the incoming requirements into PythonRequirements
userRequirements := ParseRequirements(c.Build.pythonRequirementsContent, 0)

// Do the same for the packages we've been asked to include by default, but set their ordering keys using a
// sequence number later than the user requirements. This will ensure that our default requirements come at the
// end of the list, and the order is maintained.
includeRequirements := ParseRequirements(includePackages, len(userRequirements))

// For the user requirements, update them for the given OS and architecture
var err error
for i, req := range userRequirements {
// We're only interested in requirements that we were actually able to parse
if !req.ParsedFieldsValid {
continue
}
userRequirements[i], err = c.pythonPackageForArch(req, goos, goarch)
if err != nil {
return "", err
}
includePackageNames = append(includePackageNames, packageName)
}

// Include all the requirements and remove our include packages if they exist
for _, pkg := range c.Build.pythonRequirementsContent {
archPkg, findLinksList, extraIndexURLs, err := c.pythonPackageForArch(pkg, goos, goarch)
if err != nil {
return "", err
}
packages = append(packages, archPkg)
if len(findLinksList) > 0 {
for _, fl := range findLinksList {
findLinksSet[fl] = true
}
}
if len(extraIndexURLs) > 0 {
for _, u := range extraIndexURLs {
extraIndexURLSet[u] = true
}
}
// We're about to perform deduplication between the user requirements and the provided defaults. There may
// be user requirements that we weren't able to parse though - we will keep a note of those so that we can
// add them back in later.
unparsed := make([]PythonRequirement, 0)

packageName, _ := PackageName(archPkg)
if packageName != "" {
foundIdx := -1
for i, includePkg := range includePackageNames {
if includePkg == packageName {
foundIdx = i
break
}
}
if foundIdx != -1 {
includePackageNames = append(includePackageNames[:foundIdx], includePackageNames[foundIdx+1:]...)
includePackages = append(includePackages[:foundIdx], includePackages[foundIdx+1:]...)
}
}
// Next, build a map of requirements keyed on the requirement name. We'll init this with the requirements
// from `includePackages`, and update it with the user's requirements (which may therefore overwrite the defaults).
finalRequirementsMap := make(map[string]PythonRequirement)
for _, req := range includeRequirements {
finalRequirementsMap[req.Name] = req
}

// If we still have some include packages add them in
packages = append(packages, includePackages...)

// Create final requirements.txt output
// Put index URLs first
lines := []string{}
for findLinks := range findLinksSet {
lines = append(lines, "--find-links "+findLinks)
for _, req := range userRequirements {
if req.ParsedFieldsValid {
finalRequirementsMap[req.Name] = req
} else {
unparsed = append(unparsed, req)
}
}
for extraIndexURL := range extraIndexURLSet {
lines = append(lines, "--extra-index-url "+extraIndexURL)

// Now we can build a real PythonRequirements from the values of the finalRequirementsMap
finalRequirements := make(PythonRequirements, 0, len(finalRequirementsMap)+len(unparsed))
for _, req := range finalRequirementsMap {
finalRequirements = append(finalRequirements, req)
}

// Then, everything else
lines = append(lines, packages...)
// Add the unparsed requirements back in
finalRequirements = append(finalRequirements, unparsed...)

return strings.Join(lines, "\n"), nil
return finalRequirements.RequirementsFileContent(), nil
}

// pythonPackageForArch takes a package==version line and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
}
if len(extraIndexURLs) > 0 {
return name + "==" + version, findLinksList, extraIndexURLs, nil
}

extraIndexURL := ""
findLinks := ""
switch name {
// pythonPackageForArch takes a PythonRequirement and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture. If
// the package is not one of the ones whose version we manage, we return the original requirement.
func (c *Config) pythonPackageForArch(req PythonRequirement, goos, goarch string) (out PythonRequirement, err error) {
switch req.Name {
case "tensorflow":
if c.Build.GPU {
name, version, err = tfGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = tfGPUPackage(req.Version, c.Build.CUDA)
}
// There is no CPU case for tensorflow because the default package is just the CPU package, so no transformation of version is needed
case "torch":
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = torchGPUPackage(req.Version, c.Build.CUDA)
} else {
name, version, findLinks, extraIndexURL, err = torchCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
out, err = torchCPUPackage(req.Version, goos, goarch)
}
case "torchvision":
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = torchvisionGPUPackage(req.Version, c.Build.CUDA)
} else {
name, version, findLinks, extraIndexURL, err = torchvisionCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
out, err = torchvisionCPUPackage(req.Version, goos, goarch)
}
default:
out = req
}

if err != nil {
return PythonRequirement{}, err
}

// Regardless of whether we're using the original or generated requirement, we bring across some user-supplied
// attributes if provided.
out.order = req.order

// We treat version slightly differently, because we may have rewritten the field to include the cpu specifier.
// Therefore, we will only overwrite the output version if the output version is currently empty.
if req.Version != "" && out.Version == "" {
out.Version = req.Version
}
pkgWithVersion := name
if version != "" {
pkgWithVersion += "==" + version
if req.EnvironmentMarkers != "" {
out.EnvironmentMarkers = req.EnvironmentMarkers
}
if extraIndexURL != "" {
extraIndexURLs = []string{extraIndexURL}
if len(req.FindLinks) > 0 {
out.FindLinks = req.FindLinks
}
if findLinks != "" {
findLinksList = []string{findLinks}
if len(req.ExtraIndexURLs) > 0 {
out.ExtraIndexURLs = req.ExtraIndexURLs
}
return pkgWithVersion, findLinksList, extraIndexURLs, nil
return
}

func ValidateCudaVersion(cudaVersion string) error {
Expand Down
20 changes: 20 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,26 @@ func TestBlankBuild(t *testing.T) {
require.Equal(t, false, config.Build.GPU)
}

// TestPythonRequirementsForArchWithPlatform checks that generated requirements don't lose any metadata that
// was supplied in the original requirements.txt, such as platform restrictions. We do expect hashes to be dropped.
func TestPythonRequirementsForArchWithPlatform(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`pywin32==310 ; sys_platform == 'win32' \
--hash=sha256:126298077a9d7c95c53823934f000599f66ec9296b09167810eb24875f32689c`), 0o644)
require.NoError(t, err)
config := &Config{
Build: &Build{
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
require.NoError(t, config.ValidateAndComplete(tmpDir))
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := "pywin32==310 ; sys_platform == 'win32'"
require.Equal(t, expected, requirements)
}

func TestPythonRequirementsForArchWithAddedPackage(t *testing.T) {
config := &Config{
Build: &Build{
Expand Down
Loading