diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index e25ad180fa2..11b4f3f7658 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -18,6 +18,7 @@ use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::lit; +use vortex_array::expr::lt; use vortex_array::expr::nested_case_when; use vortex_array::expr::root; use vortex_array::session::ArraySession; @@ -185,6 +186,39 @@ fn case_when_all_true(bencher: Bencher, size: usize) { }); } +/// Benchmark n-ary CASE WHEN where the first branch dominates (~90% of rows). +/// This highlights the remaining-row tracking and deferred merge: later conditions claim only the +/// small unmatched tail (~5%, ~2.5%); the rest hits ELSE (assumes `value` spans `0..size`). +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nary_early_dominant(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value < 90% THEN 1 WHEN value < 95% THEN 2 WHEN value < 97.5% THEN 3 ELSE 4 + let t1 = (size as i32 * 9) / 10; + let t2 = (size as i32 * 19) / 20; + let t3 = (size as i32 * 39) / 40; + + let expr = nested_case_when( + vec![ + (lt(get_item("value", root()), lit(t1)), lit(1i32)), + (lt(get_item("value", root()), lit(t2)), lit(2i32)), + (lt(get_item("value", root()), lit(t3)), lit(3i32)), + ], + Some(lit(4i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} + /// Benchmark CASE WHEN where all conditions are false. 
#[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_false(bencher: Bencher, size: usize) { diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 754a4e10964..a476a04b8ed 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -18,6 +18,8 @@ use std::sync::Arc; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_proto::expr as pb; use vortex_session::VortexSession; @@ -26,6 +28,7 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::builders::builder_with_capacity; use crate::dtype::DType; use crate::expr::Expression; use crate::scalar::Scalar; @@ -198,43 +201,55 @@ impl ScalarFnVTable for CaseWhen { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/ + // + // Implemented: short-circuit early exit; single-pass merge via `merge_case_branches`. + // Partial: single-branch uses `zip_impl` but THEN/ELSE still evaluated on the full batch. + // + // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction). + // TODO: project to only referenced columns before batch reduction (column projection). + // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and merge without scatter. + // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup. let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - let mut result: ArrayRef = if options.has_else { - args.get(num_pairs * 2)? 
- } else { - let then_dtype = args.get(1)?.dtype().as_nullable(); - ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() - }; + let mut remaining = Mask::new_true(row_count); + let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); - // TODO(perf): this reverse-zip approach touches every row for every condition. - // A left-to-right filter approach could maintain an "unmatched" mask, narrow it - // as conditions match, and exit early once all rows are resolved. - for i in (0..num_pairs).rev() { - let condition = args.get(i * 2)?; - let then_value = args.get(i * 2 + 1)?; + for i in 0..num_pairs { + if remaining.all_false() { + break; + } + let condition = args.get(i * 2)?; let cond_bool = condition.execute::(ctx)?; - let mask = cond_bool.to_mask_fill_null_false(); + let cond_mask = cond_bool.to_mask_fill_null_false(); + let effective_mask = &remaining & &cond_mask; - if mask.all_true() { - result = then_value; + if effective_mask.all_false() { continue; } - if mask.all_false() { - continue; - } + let then_value = args.get(i * 2 + 1)?; + remaining = remaining.bitand_not(&cond_mask); + branches.push((effective_mask, then_value)); + } + + let else_value: ArrayRef = if options.has_else { + args.get(num_pairs * 2)? + } else { + let then_dtype = args.get(1)?.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }; - result = zip_impl(&then_value, &result, &mask)?; + if branches.is_empty() { + return Ok(else_value); } - Ok(result) + merge_case_branches(branches, else_value) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - // CaseWhen is null-sensitive because NULL conditions are treated as false true } @@ -243,6 +258,55 @@ impl ScalarFnVTable for CaseWhen { } } +/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. +/// +/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. 
+fn merge_case_branches( + branches: Vec<(Mask, ArrayRef)>, + else_value: ArrayRef, +) -> VortexResult { + if branches.len() == 1 { + let (mask, then_value) = &branches[0]; + return zip_impl(then_value, &else_value, mask); + } + + let row_count = else_value.len(); + + let return_type = branches + .iter() + .fold(else_value.dtype().clone(), |acc, (_, arr)| { + acc.union_nullability(arr.dtype().nullability()) + }); + let mut builder = builder_with_capacity(&return_type, row_count); + + // Collect each branch's true-ranges tagged with branch index, then sort by position. + let mut events: Vec<(usize, usize, usize)> = Vec::new(); + for (branch_idx, (mask, _)) in branches.iter().enumerate() { + match mask.slices() { + AllOr::All => events.push((0, row_count, branch_idx)), + AllOr::None => {} + AllOr::Some(slices) => { + for &(start, end) in slices { + events.push((start, end, branch_idx)); + } + } + } + } + events.sort_unstable_by_key(|&(start, ..)| start); + + for (start, end, branch_idx) in &events { + if builder.len() < *start { + builder.extend_from_array(&else_value.slice(builder.len()..*start)?); + } + builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?); + } + if builder.len() < row_count { + builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); + } + + Ok(builder.finish()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -254,11 +318,11 @@ mod tests { use super::*; use crate::Canonical; use crate::IntoArray; - use crate::ToCanonical; use crate::VortexSessionExecute as _; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -595,8 +659,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + 
assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -615,8 +679,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array()); } #[test] @@ -635,8 +699,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -650,26 +714,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None, Some(100), Some(100)]) + .into_array() ); } @@ -686,8 +734,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array()); } #[test] @@ -703,8 +751,67 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100, 
100, 100].into_array()); + } + + #[test] + fn test_evaluate_all_true_no_else_returns_correct_dtype() { + // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE. + // Result must be Nullable because the implicit ELSE is NULL. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32)); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array() + ); + } + + #[test] + fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> { + // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable, + // the result dtype must still be Nullable. + // + // CASE WHEN value = 0 THEN 10 -- NonNullable + // WHEN value = 1 THEN nullable(20) -- Nullable + // ELSE 0 -- NonNullable + // → result must be Nullable(i32) + let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())]) + .unwrap() + .into_array(); + + let nullable_20 = + Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?; + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(0i32)), lit(10i32)), + (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array() + ); + Ok(()) } #[test] @@ -713,12 +820,7 @@ mod tests { let expr = case_when(lit(true), lit(100i32), lit(0i32)); let result = evaluate_expr(&expr, &test_array); - if let 
Some(constant) = result.as_constant() { - assert_eq!(constant, Scalar::from(100i32)); - } else { - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[100, 100, 100]); - } + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -734,10 +836,10 @@ mod tests { lit(false), ); - let result = evaluate_expr(&expr, &test_array).to_bool(); - assert_eq!( - result.to_bit_buffer().iter().collect::>(), - vec![false, false, true, true, true] + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!( + result, + BoolArray::from_iter([false, false, true, true, true]).into_array() ); } @@ -752,8 +854,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array()); } #[test] @@ -776,8 +878,11 @@ mod tests { ); let result = evaluate_expr(&expr, &test_array); - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)]) + .into_array() + ); } #[test] @@ -791,8 +896,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array()); } // ==================== N-ary Evaluate Tests ==================== @@ -815,26 +920,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(10i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(1).unwrap(), - 
Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::from(30i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::null(result.dtype().clone()) + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]) + .into_array() ); } @@ -857,8 +946,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 30, 40, 50]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array()); } #[test] @@ -878,12 +967,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - for i in 0..3 { - assert_eq!( - result.scalar_at(i).unwrap(), - Scalar::null(result.dtype().clone()) - ); - } + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None]).into_array() + ); } #[test] @@ -905,9 +992,92 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // First matching condition always wins - assert_eq!(result.as_slice::(), &[1, 1, 1]); + assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array()); + } + + #[test] + fn test_evaluate_nary_early_exit_when_remaining_empty() { + // After branch 0 claims all rows, remaining becomes all_false. + // The loop breaks before evaluating branch 1's condition. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(0i32)), lit(100i32)), + // Never evaluated due to early exit; 999 must never appear in output. 
+ (gt(get_item("value", root()), lit(0i32)), lit(999i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); + } + + #[test] + fn test_evaluate_nary_skips_branch_with_empty_effective_mask() { + // Branch 0 claims value=1. Branch 1 targets the same rows but they are already + // matched → effective_mask is all_false → branch 1 is skipped (THEN not used). + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + // Same condition as branch 0 — all matching rows already claimed → skipped. + // 999 must never appear in output. + (eq(get_item("value", root()), lit(1i32)), lit(999i32)), + (eq(get_item("value", root()), lit(2i32)), lit(20i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); + } + + #[test] + fn test_evaluate_nary_string_output() -> VortexResult<()> { + // Exercises merge_case_branches with a non-primitive (Utf8) builder. 
+ let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END + // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4) + // value=3,4 → 'high' (branch 0) + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit("high")), + (gt(get_item("value", root()), lit(0i32)), lit("low")), + ], + Some(lit("none")), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_eq!( + result.scalar_at(0)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(1)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(2)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(3)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + Ok(()) } #[test] @@ -933,10 +1103,10 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // row 0: cond1=true → 10 // row 1: cond1=NULL(→false), cond2=true → 20 // row 2: cond1=false, cond2=NULL(→false) → else=0 - assert_eq!(result.as_slice::(), &[10, 20, 0]); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } } diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7a11dbd04dc..3d73ec3e639 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -126,11 +126,6 @@ impl ScalarFnVTable for Zip { return if_true.cast(return_dtype)?.execute(ctx); } - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); - if mask.all_false() { return if_false.cast(return_dtype)?.execute(ctx); } @@ -189,6 +184,14 @@ pub(crate) fn zip_impl( .dtype() .clone() .union_nullability(if_false.dtype().nullability()); + + if mask.all_true() { + 
return if_true.cast(return_type); + } + if mask.all_false() { + return if_false.cast(return_type); + } + zip_impl_with_builder( if_true, if_false, @@ -204,8 +207,11 @@ fn zip_impl_with_builder( mut builder: Box, ) -> VortexResult { match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), + AllOr::All | AllOr::None => { + unreachable!( + "zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl" + ) + } AllOr::Some(slices) => { for (start, end) in slices { builder.extend_from_array(&if_false.slice(builder.len()..*start)?); @@ -227,6 +233,7 @@ mod tests { use vortex_error::VortexResult; use vortex_mask::Mask; + use super::zip_impl; use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -291,11 +298,57 @@ mod tests { PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array(); assert_arrays_eq!(result, expected); - - // result must be nullable even if_true was not assert_eq!(result.dtype(), if_false.dtype()) } + #[test] + fn test_zip_all_false_widens_nullability() { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = mask.into_array().zip(if_true.clone(), if_false).unwrap(); + let expected = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); + + assert_arrays_eq!(result, expected); + assert_eq!(result.dtype(), if_true.dtype()); + } + + #[test] + fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_true(4); + let if_true = buffer![10i32, 20, 30, 40].into_array(); + let if_false = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)]) + 
.into_array() + ); + assert_eq!(result.dtype(), if_false.dtype()); + Ok(()) + } + + #[test] + fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array() + ); + assert_eq!(result.dtype(), if_true.dtype()); + Ok(()) + } + #[test] #[should_panic] fn test_invalid_lengths() { @@ -339,7 +392,6 @@ mod tests { buffer: views host 1.60 kB (align=16) (96.56%) "); - // test wrapped in a struct let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array(); let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array(); @@ -383,7 +435,6 @@ mod tests { builder.finish() }; - // [1,2,4,5,7,8,..] let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect()); let mask_array = mask.clone().into_array(); @@ -394,7 +445,6 @@ mod tests { .unwrap(); assert_eq!(zipped.nbuffers(), 2); - // assert the result is the same as arrow let expected = arrow_zip( mask.into_array() .into_arrow_preferred() diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index b0cce7abbf9..84aa1cf82f5 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -98,6 +98,10 @@ pub fn vortex_mask::Mask::values(&self) -> core::option::Option<&vortex_mask::Ma impl vortex_mask::Mask +pub fn vortex_mask::Mask::bitand_not(self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask + +impl vortex_mask::Mask + pub fn vortex_mask::Mask::intersect_by_rank(&self, mask: &vortex_mask::Mask) -> vortex_mask::Mask impl vortex_mask::Mask diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index a03778eafab..3ad681db161 100644 --- a/vortex-mask/src/bitops.rs +++ 
b/vortex-mask/src/bitops.rs @@ -46,6 +46,21 @@ impl BitOr for &Mask { } } +impl Mask { + /// Computes `self & !rhs` (AND NOT), equivalent to set difference. + pub fn bitand_not(self, rhs: &Mask) -> Mask { + if self.len() != rhs.len() { + vortex_panic!("Masks must have the same length"); + } + match (self.bit_buffer(), rhs.bit_buffer()) { + (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), + (_, AllOr::None) => self, + (AllOr::All, _) => !rhs, + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs.bitand_not(rhs)), + } + } +} + impl Not for Mask { type Output = Mask; @@ -353,6 +368,34 @@ mod tests { assert!(!result.value(3)); // (!(!false) | false) & !true = (false | false) & false = false } + #[test] + fn test_bitand_not() { + let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); + let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); + let result = a.clone().bitand_not(&b); + assert!(!result.value(0)); // true & !true = false + assert!(result.value(1)); // true & !false = true + assert!(!result.value(2)); // false & !true = false + assert!(!result.value(3)); // false & !false = false + + // bitand_not(All) = None + assert!(a.clone().bitand_not(&Mask::new_true(4)).all_false()); + + // bitand_not(None) = self + let none = Mask::new_false(4); + assert_eq!(a.clone().bitand_not(&none).true_count(), a.true_count()); + + // None.bitand_not(_) = None + assert!(none.bitand_not(&a).all_false()); + + // All.bitand_not(x) = !x + let not_b = !&b; + let all_bitand_not_b = Mask::new_true(4).bitand_not(&b); + for i in 0..4 { + assert_eq!(all_bitand_not_b.value(i), not_b.value(i)); + } + } + #[test] fn test_bitor() { // Test basic OR operations