diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 8a71296c..729ce7dd 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -58,6 +58,38 @@ impl CudnnSdk { p.join("cudnn.h").is_file() && p.join("cudnn_version.h").is_file() } + #[cfg(target_os = "windows")] + fn is_vxy_dir_name(name: &str) -> bool { + name.starts_with('v') + && name[1..] + .split('.') + .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) + } + + #[cfg(target_os = "windows")] + fn collect_windows_cudnn_include_paths(bases: &[&Path]) -> Vec { + let mut paths = Vec::new(); + + for base in bases { + if let Ok(entries) = fs::read_dir(base) { + for entry in entries.flatten() { + if let Ok(file_type) = entry.file_type() { + if file_type.is_dir() { + let name = entry.file_name(); + if let Some(name_str) = name.to_str() { + if Self::is_vxy_dir_name(name_str) { + paths.push(base.join(name_str).join("include")); + } + } + } + } + } + } + } + + paths + } + fn find_cudnn_include_dir() -> Result> { let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR"); @@ -71,15 +103,32 @@ impl CudnnSdk { "/usr/local/include/x86_64-linux-gnu", "/usr/local/include/aarch64-linux-gnu", ]; + + #[cfg(not(target_os = "windows"))] + let mut cudnn_paths: Vec = + CUDNN_DEFAULT_PATHS.iter().map(Path::new).map(path::PathBuf::from).collect(); + #[cfg(target_os = "windows")] - const CUDNN_DEFAULT_PATHS: &[&str] = &[ - "C:/Program Files/NVIDIA/CUDNN/v9.x/include", - "C:/Program Files/NVIDIA/CUDNN/v8.x/include", - ]; + let mut cudnn_paths: Vec = { + // Legacy standalone cuDNN installs following NVIDIA's documentation. + let mut paths = vec![ + path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v9.x/include"), + path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v8.x/include"), + ]; + + // Dynamically discover CUDA and cuDNN installs by matching vX.Y-style directories. + let bases = [ + Path::new("C:/Program Files/NVIDIA/CUDNN"), + Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"), + ]; + + paths.extend(Self::collect_windows_cudnn_include_paths(&bases)); + + paths + }; - let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect(); if let Some(override_path) = &cudnn_include_dir { - cudnn_paths.push(Path::new(override_path)); + cudnn_paths.push(Path::new(override_path).to_path_buf()); } cudnn_paths @@ -108,3 +157,66 @@ impl CudnnSdk { Ok([major, minor, patch]) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(target_os = "windows")] + #[test] + fn is_vxy_dir_name_accepts_valid_versions() { + assert!(CudnnSdk::is_vxy_dir_name("v9.0")); + assert!(CudnnSdk::is_vxy_dir_name("v10.6")); + assert!(CudnnSdk::is_vxy_dir_name("v12.1")); + assert!(CudnnSdk::is_vxy_dir_name("v13.0")); + assert!(CudnnSdk::is_vxy_dir_name("v9.10.3")); + } + + #[cfg(target_os = "windows")] + #[test] + fn is_vxy_dir_name_rejects_invalid_versions() { + assert!(!CudnnSdk::is_vxy_dir_name("v")); + assert!(!CudnnSdk::is_vxy_dir_name("v9.")); + assert!(!CudnnSdk::is_vxy_dir_name("v.9")); + assert!(!CudnnSdk::is_vxy_dir_name("v9.a")); + assert!(!CudnnSdk::is_vxy_dir_name("9.0")); + assert!(!CudnnSdk::is_vxy_dir_name("vx.y")); + assert!(!CudnnSdk::is_vxy_dir_name("random")); + } + + #[cfg(target_os = "windows")] + #[test] + fn collect_windows_cudnn_include_paths_discovers_multiple_versions() { + use std::fs::create_dir_all; + + let tmp_dir = env::temp_dir().join("cudnn_sdk_tests"); + if tmp_dir.exists() { + fs::remove_dir_all(&tmp_dir).ok(); + } + + let cuda_base = tmp_dir.join("CUDA"); + let v10_6 = cuda_base.join("v10.6"); + let v12_1 = cuda_base.join("v12.1"); + let v13_0 = cuda_base.join("v13.0"); + + for ver in [&v10_6, &v12_1, &v13_0] { + let include_dir = ver.join("include"); + create_dir_all(&include_dir).unwrap(); + fs::write(include_dir.join("cudnn.h"), "// stub").unwrap(); + fs::write(include_dir.join("cudnn_version.h"), "// stub").unwrap(); + } + + // Also add some junk directories that should be ignored. + create_dir_all(cuda_base.join("vbad")).unwrap(); + create_dir_all(cuda_base.join("not-a-version")).unwrap(); + + let bases: [&Path; 1] = [&cuda_base]; + let mut paths = CudnnSdk::collect_windows_cudnn_include_paths(&bases); + paths.sort(); + + assert!(paths.contains(&v10_6.join("include"))); + assert!(paths.contains(&v12_1.join("include"))); + assert!(paths.contains(&v13_0.join("include"))); + assert_eq!(paths.len(), 3); + } +}