Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions vortex-cuda/cub/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
use std::path::PathBuf;
use std::sync::OnceLock;

use vortex_cuda_macros::cuda_tests;

/// Raw FFI type definitions and dynamically-loaded function pointers from bindgen.
#[allow(
non_upper_case_globals,
Expand Down Expand Up @@ -60,26 +58,26 @@ pub fn cub_library() -> Result<&'static sys::CubLibrary, CubError> {
.map_err(|e| CubError::LibraryLoadError(e.clone()))
}

#[cuda_tests]
#[cfg(test)]
mod tests {
use crate::filter;

#[test]
#[vortex_cuda_macros::test]
fn test_filter_temp_size_u64() -> Result<(), crate::CubError> {
let temp_bytes = filter::filter_get_temp_size_u64(1000)?;
// CUB requires some temporary storage
assert!(temp_bytes > 0);
Ok(())
}

#[test]
#[vortex_cuda_macros::test]
fn test_filter_temp_size_f64() -> Result<(), crate::CubError> {
let temp_bytes = filter::filter_get_temp_size_f64(10000)?;
assert!(temp_bytes > 0);
Ok(())
}

#[test]
#[vortex_cuda_macros::test]
fn test_filter_temp_size_zero_items() -> Result<(), crate::CubError> {
// Just verify the call doesn't fail with zero items
let _temp_bytes = filter::filter_get_temp_size_u8(0)?;
Expand Down
45 changes: 33 additions & 12 deletions vortex-cuda/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
//! #[cuda_not_available]
//! fn fallback_function() { /* ... */ }
//!
//! // Only compiled in test builds when CUDA is available
//! #[cuda_test]
//! mod tests {
//! // ...
//! // Ignore tests when CUDA is not available
//! #[crate::test]
//! async fn my_test() {
//! ...
//! }
//! ```

Expand All @@ -30,7 +30,6 @@ use std::sync::LazyLock;

use proc_macro::TokenStream;
use quote::quote;
use syn::Item;
use syn::parse_macro_input;

/// Cached result of nvcc availability check.
Expand Down Expand Up @@ -61,17 +60,39 @@ pub fn cuda_not_available(_attr: TokenStream, item: TokenStream) -> TokenStream
}
}

/// Conditionally compiles the annotated item only in test builds when CUDA is available.
/// Test attribute to ignore tests if CUDA isn't available. Supports both sync and async tests (using tokio).
///
/// Must be named `test` to work with frameworks like `rstest`.
#[proc_macro_attribute]
pub fn cuda_tests(_attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn test(_attr: TokenStream, item: TokenStream) -> TokenStream {
let item = parse_macro_input!(item as syn::ItemFn);
if *NVCC_AVAILABLE {
let item = parse_macro_input!(item as Item);
quote! {
#[cfg(test)]
#item
if item.sig.asyncness.is_some() {
quote! {
#[tokio::test]
#item
}
} else {
quote! {
#[test]
#item
}
}
.into()
} else {
TokenStream::new()
if item.sig.asyncness.is_some() {
quote! {
#[tokio::test]
#[ignore]
#item
}
} else {
quote! {
#[test]
#[ignore]
#item
}
}
.into()
}
}
6 changes: 2 additions & 4 deletions vortex-cuda/nvcomp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ mod error;
pub mod zstd;

pub use error::NvcompError;
use vortex_cuda_macros::cuda_tests;

/// The loaded nvcomp library instance.
static NVCOMP_LIB: OnceLock<Result<sys::NvcompLibrary, String>> = OnceLock::new();
Expand Down Expand Up @@ -66,13 +65,12 @@ pub fn nvcomp_library() -> Result<&'static sys::NvcompLibrary, NvcompError> {
.as_ref()
.map_err(|e| NvcompError::LibraryLoadError(e.clone()))
}

#[cuda_tests]
#[cfg(test)]
mod tests {
use crate::zstd;

/// Test that we can call nvcompBatchedZstdDecompressGetTempSizeAsync.
#[test]
#[vortex_cuda_macros::test]
fn test_get_decompress_temp_size() {
let num_chunks = 10;
let max_uncompressed_chunk_bytes = 65536; // 64KB recommended chunk size
Expand Down
16 changes: 7 additions & 9 deletions vortex-cuda/src/arrow/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use vortex::error::VortexResult;
use vortex::error::vortex_bail;
use vortex::error::vortex_ensure;
use vortex::extension::datetime::AnyTemporal;
use vortex_cuda_macros::cuda_tests;

use crate::CudaExecutionCtx;
use crate::arrow::ArrowArray;
Expand Down Expand Up @@ -275,8 +274,7 @@ unsafe extern "C" fn release_array(array: *mut ArrowArray) {
}
}

#[cuda_tests]
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex::array::IntoArray;
Expand Down Expand Up @@ -308,7 +306,7 @@ mod tests {
#[case::i64(PrimitiveArray::from_iter(0i64..10).into_array(), 10)]
#[case::f32(PrimitiveArray::from_iter([1.0f32, 2.0, 3.0]).into_array(), 3)]
#[case::f64(PrimitiveArray::from_iter([1.0f64, 2.0, 3.0]).into_array(), 3)]
#[tokio::test]
#[crate::test]
async fn test_export_primitive(
#[case] array: vortex::array::ArrayRef,
#[case] expected_len: i64,
Expand All @@ -330,7 +328,7 @@ mod tests {
Ok(())
}

#[tokio::test]
#[crate::test]
async fn test_export_null() -> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand All @@ -346,7 +344,7 @@ mod tests {
Ok(())
}

#[tokio::test]
#[crate::test]
async fn test_export_decimal() -> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand All @@ -365,7 +363,7 @@ mod tests {
Ok(())
}

#[tokio::test]
#[crate::test]
async fn test_export_temporal() -> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand All @@ -388,7 +386,7 @@ mod tests {
Ok(())
}

#[tokio::test]
#[crate::test]
async fn test_export_varbinview() -> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand All @@ -413,7 +411,7 @@ mod tests {
Ok(())
}

#[tokio::test]
#[crate::test]
async fn test_export_struct() -> VortexResult<()> {
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
.vortex_expect("failed to create execution context");
Expand Down
Loading
Loading