Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 118 additions & 6 deletions crates/cudnn-sys/build/cudnn_sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,38 @@ impl CudnnSdk {
p.join("cudnn.h").is_file() && p.join("cudnn_version.h").is_file()
}

#[cfg(target_os = "windows")]
fn is_vxy_dir_name(name: &str) -> bool {
name.starts_with('v')
&& name[1..]
.split('.')
.all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()))
}

#[cfg(target_os = "windows")]
fn collect_windows_cudnn_include_paths(bases: &[&Path]) -> Vec<path::PathBuf> {
let mut paths = Vec::new();

for base in bases {
if let Ok(entries) = fs::read_dir(base) {
for entry in entries.flatten() {
if let Ok(file_type) = entry.file_type() {
if file_type.is_dir() {
let name = entry.file_name();
if let Some(name_str) = name.to_str() {
if Self::is_vxy_dir_name(name_str) {
paths.push(base.join(name_str).join("include"));
}
}
}
}
}
}
}

paths
}

fn find_cudnn_include_dir() -> Result<path::PathBuf, Box<dyn error::Error>> {
let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR");

Expand All @@ -71,15 +103,32 @@ impl CudnnSdk {
"/usr/local/include/x86_64-linux-gnu",
"/usr/local/include/aarch64-linux-gnu",
];

#[cfg(not(target_os = "windows"))]
let mut cudnn_paths: Vec<path::PathBuf> =
CUDNN_DEFAULT_PATHS.iter().map(Path::new).map(path::PathBuf::from).collect();

#[cfg(target_os = "windows")]
const CUDNN_DEFAULT_PATHS: &[&str] = &[
"C:/Program Files/NVIDIA/CUDNN/v9.x/include",
"C:/Program Files/NVIDIA/CUDNN/v8.x/include",
];
let mut cudnn_paths: Vec<path::PathBuf> = {
// Legacy standalone cuDNN installs following NVIDIA's documentation.
let mut paths = vec![
path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v9.x/include"),
path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v8.x/include"),
];

// Dynamically discover CUDA and cuDNN installs by matching vX.Y-style directories.
let bases = [
Path::new("C:/Program Files/NVIDIA/CUDNN"),
Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"),
];

paths.extend(Self::collect_windows_cudnn_include_paths(&bases));

paths
};

let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect();
if let Some(override_path) = &cudnn_include_dir {
cudnn_paths.push(Path::new(override_path));
cudnn_paths.push(Path::new(override_path).to_path_buf());
}

cudnn_paths
Expand Down Expand Up @@ -108,3 +157,66 @@ impl CudnnSdk {
Ok([major, minor, patch])
}
}

#[cfg(test)]
mod tests {
use super::*;

#[cfg(target_os = "windows")]
#[test]
fn is_vxy_dir_name_accepts_valid_versions() {
assert!(CudnnSdk::is_vxy_dir_name("v9.0"));
assert!(CudnnSdk::is_vxy_dir_name("v10.6"));
assert!(CudnnSdk::is_vxy_dir_name("v12.1"));
assert!(CudnnSdk::is_vxy_dir_name("v13.0"));
assert!(CudnnSdk::is_vxy_dir_name("v9.10.3"));
}

#[cfg(target_os = "windows")]
#[test]
fn is_vxy_dir_name_rejects_invalid_versions() {
assert!(!CudnnSdk::is_vxy_dir_name("v"));
assert!(!CudnnSdk::is_vxy_dir_name("v9."));
assert!(!CudnnSdk::is_vxy_dir_name("v.9"));
assert!(!CudnnSdk::is_vxy_dir_name("v9.a"));
assert!(!CudnnSdk::is_vxy_dir_name("9.0"));
assert!(!CudnnSdk::is_vxy_dir_name("vx.y"));
assert!(!CudnnSdk::is_vxy_dir_name("random"));
}

#[cfg(target_os = "windows")]
#[test]
fn collect_windows_cudnn_include_paths_discovers_multiple_versions() {
use std::fs::create_dir_all;

let tmp_dir = env::temp_dir().join("cudnn_sdk_tests");
if tmp_dir.exists() {
fs::remove_dir_all(&tmp_dir).ok();
}

let cuda_base = tmp_dir.join("CUDA");
let v10_6 = cuda_base.join("v10.6");
let v12_1 = cuda_base.join("v12.1");
let v13_0 = cuda_base.join("v13.0");

for ver in [&v10_6, &v12_1, &v13_0] {
let include_dir = ver.join("include");
create_dir_all(&include_dir).unwrap();
fs::write(include_dir.join("cudnn.h"), "// stub").unwrap();
fs::write(include_dir.join("cudnn_version.h"), "// stub").unwrap();
}

// Also add some junk directories that should be ignored.
create_dir_all(cuda_base.join("vbad")).unwrap();
create_dir_all(cuda_base.join("not-a-version")).unwrap();

let bases: [&Path; 1] = [&cuda_base];
let mut paths = CudnnSdk::collect_windows_cudnn_include_paths(&bases);
paths.sort();

assert!(paths.contains(&v10_6.join("include")));
assert!(paths.contains(&v12_1.join("include")));
assert!(paths.contains(&v13_0.join("include")));
assert_eq!(paths.len(), 3);
}
}