From 772f7b661f118bdd1ece030914dd8b9183cd2c1e Mon Sep 17 00:00:00 2001 From: Charry Wu Date: Sat, 14 Mar 2026 08:27:38 -0700 Subject: [PATCH 1/3] fix(cudnn-sys): update Windows default cuDNN include paths Made-with: Cursor --- crates/cudnn-sys/build/cudnn_sdk.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 8a71296c..2d566aaa 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -73,8 +73,15 @@ impl CudnnSdk { ]; #[cfg(target_os = "windows")] const CUDNN_DEFAULT_PATHS: &[&str] = &[ + // Standalone cuDNN installs following NVIDIA's documentation. "C:/Program Files/NVIDIA/CUDNN/v9.x/include", "C:/Program Files/NVIDIA/CUDNN/v8.x/include", + // CUDA Toolkit installs that bundle cuDNN headers. + // These are the default Windows install locations for recent CUDA versions. + "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.2/include", + "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/include", + "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/include", + "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v13.0/include", ]; let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect(); From dc8b52e5dfa2c52fcc5f6f372e97b9722222797f Mon Sep 17 00:00:00 2001 From: Charry Wu Date: Sat, 14 Mar 2026 08:38:57 -0700 Subject: [PATCH 2/3] feat(cudnn-sys): discover Windows cuDNN paths via vX.Y directories Made-with: Cursor --- crates/cudnn-sys/build/cudnn_sdk.rs | 56 ++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 2d566aaa..da8d5bfa 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -71,22 +71,52 @@ impl CudnnSdk { "/usr/local/include/x86_64-linux-gnu", "/usr/local/include/aarch64-linux-gnu", ]; + + #[cfg(not(target_os = "windows"))] + let mut cudnn_paths: Vec = + CUDNN_DEFAULT_PATHS.iter().map(Path::new).map(path::PathBuf::from).collect(); + #[cfg(target_os = "windows")] - const CUDNN_DEFAULT_PATHS: &[&str] = &[ - // Standalone cuDNN installs following NVIDIA's documentation. - "C:/Program Files/NVIDIA/CUDNN/v9.x/include", - "C:/Program Files/NVIDIA/CUDNN/v8.x/include", - // CUDA Toolkit installs that bundle cuDNN headers. - // These are the default Windows install locations for recent CUDA versions. - "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.2/include", - "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/include", - "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/include", - "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v13.0/include", - ]; + let mut cudnn_paths: Vec = { + // Legacy standalone cuDNN installs following NVIDIA's documentation. + let mut paths = vec![ + path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v9.x/include"), + path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v8.x/include"), + ]; + + // Dynamically discover CUDA and cuDNN installs by matching vX.Y-style directories. + let bases = [ + Path::new("C:/Program Files/NVIDIA/CUDNN"), + Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"), + ]; + + for base in bases { + if let Ok(entries) = fs::read_dir(base) { + for entry in entries.flatten() { + if let Ok(file_type) = entry.file_type() { + if file_type.is_dir() { + let name = entry.file_name(); + if let Some(name_str) = name.to_str() { + // Match directories like v9.0, v10.2, v13.0, etc. + if name_str.starts_with('v') + && name_str[1..] + .split('.') + .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) + { + paths.push(base.join(name_str).join("include")); + } + } + } + } + } + } + } + + paths + }; - let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect(); if let Some(override_path) = &cudnn_include_dir { - cudnn_paths.push(Path::new(override_path)); + cudnn_paths.push(Path::new(override_path).to_path_buf()); } cudnn_paths From a2867b7cb7ca94fc2e25bdc8b40e516babfd7fa1 Mon Sep 17 00:00:00 2001 From: Charry Wu Date: Tue, 17 Mar 2026 14:13:08 -0700 Subject: [PATCH 3/3] test(cudnn-sys): add windows vX.Y path discovery tests Made-with: Cursor --- crates/cudnn-sys/build/cudnn_sdk.rs | 117 +++++++++++++++++++++++----- 1 file changed, 96 insertions(+), 21 deletions(-) diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index da8d5bfa..729ce7dd 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -58,6 +58,38 @@ impl CudnnSdk { p.join("cudnn.h").is_file() && p.join("cudnn_version.h").is_file() } + #[cfg(target_os = "windows")] + fn is_vxy_dir_name(name: &str) -> bool { + name.starts_with('v') + && name[1..] + .split('.') + .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) + } + + #[cfg(target_os = "windows")] + fn collect_windows_cudnn_include_paths(bases: &[&Path]) -> Vec { + let mut paths = Vec::new(); + + for base in bases { + if let Ok(entries) = fs::read_dir(base) { + for entry in entries.flatten() { + if let Ok(file_type) = entry.file_type() { + if file_type.is_dir() { + let name = entry.file_name(); + if let Some(name_str) = name.to_str() { + if Self::is_vxy_dir_name(name_str) { + paths.push(base.join(name_str).join("include")); + } + } + } + } + } + } + } + + paths + } + fn find_cudnn_include_dir() -> Result> { let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR"); @@ -90,27 +122,7 @@ impl CudnnSdk { Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"), ]; - for base in bases { - if let Ok(entries) = fs::read_dir(base) { - for entry in entries.flatten() { - if let Ok(file_type) = entry.file_type() { - if file_type.is_dir() { - let name = entry.file_name(); - if let Some(name_str) = name.to_str() { - // Match directories like v9.0, v10.2, v13.0, etc. - if name_str.starts_with('v') - && name_str[1..] - .split('.') - .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())) - { - paths.push(base.join(name_str).join("include")); - } - } - } - } - } - } - } + paths.extend(Self::collect_windows_cudnn_include_paths(&bases)); paths }; @@ -145,3 +157,66 @@ impl CudnnSdk { Ok([major, minor, patch]) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(target_os = "windows")] + #[test] + fn is_vxy_dir_name_accepts_valid_versions() { + assert!(CudnnSdk::is_vxy_dir_name("v9.0")); + assert!(CudnnSdk::is_vxy_dir_name("v10.6")); + assert!(CudnnSdk::is_vxy_dir_name("v12.1")); + assert!(CudnnSdk::is_vxy_dir_name("v13.0")); + assert!(CudnnSdk::is_vxy_dir_name("v9.10.3")); + } + + #[cfg(target_os = "windows")] + #[test] + fn is_vxy_dir_name_rejects_invalid_versions() { + assert!(!CudnnSdk::is_vxy_dir_name("v")); + assert!(!CudnnSdk::is_vxy_dir_name("v9.")); + assert!(!CudnnSdk::is_vxy_dir_name("v.9")); + assert!(!CudnnSdk::is_vxy_dir_name("v9.a")); + assert!(!CudnnSdk::is_vxy_dir_name("9.0")); + assert!(!CudnnSdk::is_vxy_dir_name("vx.y")); + assert!(!CudnnSdk::is_vxy_dir_name("random")); + } + + #[cfg(target_os = "windows")] + #[test] + fn collect_windows_cudnn_include_paths_discovers_multiple_versions() { + use std::fs::create_dir_all; + + let tmp_dir = env::temp_dir().join("cudnn_sdk_tests"); + if tmp_dir.exists() { + fs::remove_dir_all(&tmp_dir).ok(); + } + + let cuda_base = tmp_dir.join("CUDA"); + let v10_6 = cuda_base.join("v10.6"); + let v12_1 = cuda_base.join("v12.1"); + let v13_0 = cuda_base.join("v13.0"); + + for ver in [&v10_6, &v12_1, &v13_0] { + let include_dir = ver.join("include"); + create_dir_all(&include_dir).unwrap(); + fs::write(include_dir.join("cudnn.h"), "// stub").unwrap(); + fs::write(include_dir.join("cudnn_version.h"), "// stub").unwrap(); + } + + // Also add some junk directories that should be ignored. + create_dir_all(cuda_base.join("vbad")).unwrap(); + create_dir_all(cuda_base.join("not-a-version")).unwrap(); + + let bases: [&Path; 1] = [&cuda_base]; + let mut paths = CudnnSdk::collect_windows_cudnn_include_paths(&bases); + paths.sort(); + + assert!(paths.contains(&v10_6.join("include"))); + assert!(paths.contains(&v12_1.join("include"))); + assert!(paths.contains(&v13_0.join("include"))); + assert_eq!(paths.len(), 3); + } +}