From 768bf0ff930d52658ec95f9706f30a081a631af5 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 16:02:27 +0000 Subject: [PATCH 1/8] forward pass case when Signed-off-by: Baris Palaska --- vortex-array/benches/expr/case_when_bench.rs | 34 +++ vortex-array/src/scalar_fn/fns/case_when.rs | 239 +++++++++++++++++-- vortex-array/src/scalar_fn/fns/zip/mod.rs | 33 ++- 3 files changed, 281 insertions(+), 25 deletions(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index e25ad180fa2..11b4f3f7658 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -18,6 +18,7 @@ use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::lit; +use vortex_array::expr::lt; use vortex_array::expr::nested_case_when; use vortex_array::expr::root; use vortex_array::session::ArraySession; @@ -185,6 +186,39 @@ fn case_when_all_true(bencher: Bencher, size: usize) { }); } +/// Benchmark n-ary CASE WHEN where the first branch dominates (~90% of rows). +/// This highlights the early-exit and deferred-merge optimizations: subsequent conditions +/// match no remaining rows and are skipped entirely. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nary_early_dominant(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value < 90% THEN 1 WHEN value < 95% THEN 2 WHEN value < 97.5% THEN 3 ELSE 4 + let t1 = (size as i32 * 9) / 10; + let t2 = (size as i32 * 19) / 20; + let t3 = (size as i32 * 39) / 40; + + let expr = nested_case_when( + vec![ + (lt(get_item("value", root()), lit(t1)), lit(1i32)), + (lt(get_item("value", root()), lit(t2)), lit(2i32)), + (lt(get_item("value", root()), lit(t3)), lit(3i32)), + ], + Some(lit(4i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} + /// Benchmark CASE WHEN where all conditions are false. #[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_false(bencher: Bencher, size: usize) { diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index f701548f145..ab425f50226 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -11,6 +11,8 @@ use std::sync::Arc; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_proto::expr as pb; use vortex_session::VortexSession; @@ -19,6 +21,7 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::builders::builder_with_capacity; use crate::dtype::DType; use crate::expr::Expression; use crate::scalar::Scalar; @@ -191,37 +194,45 @@ impl ScalarFnVTable for CaseWhen { let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - let mut result: ArrayRef = if options.has_else { - args.get(num_pairs * 2)? - } else { - let then_dtype = args.get(1)?.dtype().as_nullable(); - ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() - }; + // Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins + // and produce disjoint branch masks. + let mut remaining = Mask::new_true(row_count); + let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); - for i in (0..num_pairs).rev() { - let condition = args.get(i * 2)?; - let then_value = args.get(i * 2 + 1)?; + for i in 0..num_pairs { + if remaining.all_false() { + break; + } + let condition = args.get(i * 2)?; let cond_bool = condition.execute::(ctx)?; - let mask = cond_bool.to_mask_fill_null_false(); + let cond_mask = cond_bool.to_mask_fill_null_false(); + let effective_mask = &remaining & &cond_mask; - if mask.all_true() { - result = then_value; + if effective_mask.all_false() { continue; } - if mask.all_false() { - continue; - } + let then_value = args.get(i * 2 + 1)?; + remaining = &remaining & &(!&cond_mask); + branches.push((effective_mask, then_value)); + } - result = zip_impl(&then_value, &result, &mask)?; + let else_value: ArrayRef = if options.has_else { + args.get(num_pairs * 2)? + } else { + let then_dtype = args.get(1)?.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }; + + if branches.is_empty() { + return Ok(else_value); } - Ok(result) + merge_case_branches(branches, else_value) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - // CaseWhen is null-sensitive because NULL conditions are treated as false true } @@ -230,6 +241,55 @@ impl ScalarFnVTable for CaseWhen { } } +/// Merges disjoint `(mask, then_value)` branch pairs with an `else_value` in a single pass. +/// +/// Branch masks are guaranteed disjoint by the remaining-row tracking in [`CaseWhen::execute`]. +fn merge_case_branches( + branches: Vec<(Mask, ArrayRef)>, + else_value: ArrayRef, +) -> VortexResult { + if branches.len() == 1 { + let (mask, then_value) = &branches[0]; + return zip_impl(then_value, &else_value, mask); + } + + let row_count = else_value.len(); + + let return_type = branches + .iter() + .fold(else_value.dtype().clone(), |acc, (_, arr)| { + acc.union_nullability(arr.dtype().nullability()) + }); + let mut builder = builder_with_capacity(&return_type, row_count); + + // Collect each branch's true-ranges tagged with branch index, then sort by position. + let mut events: Vec<(usize, usize, usize)> = Vec::new(); + for (branch_idx, (mask, _)) in branches.iter().enumerate() { + match mask.slices() { + AllOr::All => events.push((0, row_count, branch_idx)), + AllOr::None => {} + AllOr::Some(slices) => { + for &(start, end) in slices { + events.push((start, end, branch_idx)); + } + } + } + } + events.sort_unstable_by_key(|&(start, ..)| start); + + for (start, end, branch_idx) in &events { + if builder.len() < *start { + builder.extend_from_array(&else_value.slice(builder.len()..*start)?); + } + builder.extend_from_array(&branches[*branch_idx].1.slice(*start..*end)?); + } + if builder.len() < row_count { + builder.extend_from_array(&else_value.slice(builder.len()..row_count)?); + } + + Ok(builder.finish()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -246,6 +306,7 @@ mod tests { use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -690,6 +751,65 @@ mod tests { assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); } + #[test] + fn test_evaluate_all_true_no_else_returns_correct_dtype() { + // CASE WHEN value > 0 THEN 100 END — condition is always true, no ELSE. + // Result must be Nullable because the implicit ELSE is NULL. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32)); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + } + + #[test] + fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> { + // When a later THEN branch is Nullable and branches[0] and ELSE are NonNullable, + // the result dtype must still be Nullable. + // + // CASE WHEN value = 0 THEN 10 -- NonNullable + // WHEN value = 1 THEN nullable(20) -- Nullable + // ELSE 0 -- NonNullable + // → result must be Nullable(i32) + let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())]) + .unwrap() + .into_array(); + + let nullable_20 = + Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?; + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(0i32)), lit(10i32)), + (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array); + assert!( + result.dtype().is_nullable(), + "result dtype must be Nullable, got {:?}", + result.dtype() + ); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array() + ); + Ok(()) + } + #[test] fn test_evaluate_with_literal_condition() { let test_array = buffer![1i32, 2, 3].into_array(); @@ -893,6 +1013,89 @@ mod tests { assert_eq!(result.as_slice::(), &[1, 1, 1]); } + #[test] + fn test_evaluate_nary_early_exit_when_remaining_empty() { + // After branch 0 claims all rows, remaining becomes all_false. + // The loop breaks before evaluating branch 1's condition. + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(0i32)), lit(100i32)), + // Never evaluated due to early exit; 999 must never appear in output. + (gt(get_item("value", root()), lit(0i32)), lit(999i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[100, 100, 100]); + } + + #[test] + fn test_evaluate_nary_skips_branch_with_empty_effective_mask() { + // Branch 0 claims value=1. Branch 1 targets the same rows but they are already + // matched → effective_mask is all_false → branch 1 is skipped (THEN not used). + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + // Same condition as branch 0 — all matching rows already claimed → skipped. + // 999 must never appear in output. + (eq(get_item("value", root()), lit(1i32)), lit(999i32)), + (eq(get_item("value", root()), lit(2i32)), lit(20i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[10, 20, 0]); + } + + #[test] + fn test_evaluate_nary_string_output() -> VortexResult<()> { + // Exercises merge_case_branches with a non-primitive (Utf8) builder. + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 2 THEN 'high' WHEN value > 0 THEN 'low' ELSE 'none' END + // value=1,2 → 'low' (branch 1 after branch 0 claims 3,4) + // value=3,4 → 'high' (branch 0) + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit("high")), + (gt(get_item("value", root()), lit(0i32)), lit("low")), + ], + Some(lit("none")), + ); + + let result = evaluate_expr(&expr, &test_array); + assert_eq!( + result.scalar_at(0)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(1)?, + Scalar::utf8("low", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(2)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + assert_eq!( + result.scalar_at(3)?, + Scalar::utf8("high", Nullability::NonNullable) + ); + Ok(()) + } + #[test] fn test_evaluate_nary_with_nullable_conditions() { let test_array = StructArray::from_fields(&[ diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7a11dbd04dc..8b25fec760c 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -126,11 +126,6 @@ impl ScalarFnVTable for Zip { return if_true.cast(return_dtype)?.execute(ctx); } - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); - if mask.all_false() { return if_false.cast(return_dtype)?.execute(ctx); } @@ -204,8 +199,14 @@ fn zip_impl_with_builder( mut builder: Box, ) -> VortexResult { match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), + AllOr::All => { + builder.extend_from_array(if_true); + Ok(builder.finish()) + } + AllOr::None => { + builder.extend_from_array(if_false); + Ok(builder.finish()) + } AllOr::Some(slices) => { for (start, end) in slices { builder.extend_from_array(&if_false.slice(builder.len()..*start)?); @@ -296,6 +297,24 @@ mod tests { assert_eq!(result.dtype(), if_false.dtype()) } + /// When the mask is all-false and `if_true` is Nullable, the result dtype must be Nullable + /// even though `if_false` is NonNullable. + #[test] + fn test_zip_all_false_widens_nullability() { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = mask.into_array().zip(if_true.clone(), if_false).unwrap(); + let expected = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); + + assert_arrays_eq!(result, expected); + // result must be nullable even though if_false was not + assert_eq!(result.dtype(), if_true.dtype()); + } + #[test] #[should_panic] fn test_invalid_lengths() { From 1d4b94706c1d82ccd54183c3ad4876d0e1570f72 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 17:16:31 +0000 Subject: [PATCH 2/8] assert_arrays_eq Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 136 +++++++------------- 1 file changed, 50 insertions(+), 86 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index ab425f50226..e5e29eeee08 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -301,7 +301,6 @@ mod tests { use super::*; use crate::Canonical; use crate::IntoArray; - use crate::ToCanonical; use crate::VortexSessionExecute as _; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; @@ -639,8 +638,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -659,8 +658,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array()); } #[test] @@ -679,8 +678,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array()); } #[test] @@ -694,26 +693,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None, Some(100), Some(100)]) + .into_array() ); } @@ -730,8 +713,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array()); } #[test] @@ -747,8 +730,8 @@ mod tests { lit(0i32), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array()); } #[test] @@ -767,9 +750,9 @@ mod tests { "result dtype must be Nullable, got {:?}", result.dtype() ); - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(100i32).cast(result.dtype()).unwrap() + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array() ); } @@ -816,12 +799,7 @@ mod tests { let expr = case_when(lit(true), lit(100i32), lit(0i32)); let result = evaluate_expr(&expr, &test_array); - if let Some(constant) = result.as_constant() { - assert_eq!(constant, Scalar::from(100i32)); - } else { - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[100, 100, 100]); - } + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -837,10 +815,11 @@ mod tests { lit(false), ); - let result = evaluate_expr(&expr, &test_array).to_bool(); - assert_eq!( - result.to_bit_buffer().iter().collect::>(), - vec![false, false, true, true, true] + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!( + result, + BoolArray::from_iter([Some(false), Some(false), Some(true), Some(true), Some(true)]) + .into_array() ); } @@ -855,8 +834,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array()); } #[test] @@ -879,8 +858,11 @@ mod tests { ); let result = evaluate_expr(&expr, &test_array); - let prim = result.to_primitive(); - assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)]) + .into_array() + ); } #[test] @@ -894,8 +876,8 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[0, 0, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array()); } // ==================== N-ary Evaluate Tests ==================== @@ -918,26 +900,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - - assert_eq!( - result.scalar_at(0).unwrap(), - Scalar::from(10i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(1).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(2).unwrap(), - Scalar::from(30i32).cast(result.dtype()).unwrap() - ); - assert_eq!( - result.scalar_at(3).unwrap(), - Scalar::null(result.dtype().clone()) - ); - assert_eq!( - result.scalar_at(4).unwrap(), - Scalar::null(result.dtype().clone()) + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None]) + .into_array() ); } @@ -960,8 +926,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 30, 40, 50]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array()); } #[test] @@ -981,12 +947,10 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); - for i in 0..3 { - assert_eq!( - result.scalar_at(i).unwrap(), - Scalar::null(result.dtype().clone()) - ); - } + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([None::, None, None]).into_array() + ); } #[test] @@ -1008,9 +972,9 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // First matching condition always wins - assert_eq!(result.as_slice::(), &[1, 1, 1]); + assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array()); } #[test] @@ -1030,8 +994,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[100, 100, 100]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array()); } #[test] @@ -1053,8 +1017,8 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); - assert_eq!(result.as_slice::(), &[10, 20, 0]); + let result = evaluate_expr(&expr, &test_array); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } #[test] @@ -1119,10 +1083,10 @@ mod tests { Some(lit(0i32)), ); - let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result = evaluate_expr(&expr, &test_array); // row 0: cond1=true → 10 // row 1: cond1=NULL(→false), cond2=true → 20 // row 2: cond1=false, cond2=NULL(→false) → else=0 - assert_eq!(result.as_slice::(), &[10, 20, 0]); + assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array()); } } From 91b71a3a9b6aab36e743227b6893ac8bbabd3931 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:09:34 +0000 Subject: [PATCH 3/8] add andnot Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 5 +-- vortex-mask/src/bitops.rs | 44 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index e5e29eeee08..043fff5b8d6 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -214,7 +214,7 @@ impl ScalarFnVTable for CaseWhen { } let then_value = args.get(i * 2 + 1)?; - remaining = &remaining & &(!&cond_mask); + remaining = remaining.andnot(&cond_mask); branches.push((effective_mask, then_value)); } @@ -818,8 +818,7 @@ mod tests { let result = evaluate_expr(&expr, &test_array); assert_arrays_eq!( result, - BoolArray::from_iter([Some(false), Some(false), Some(true), Some(true), Some(true)]) - .into_array() + BoolArray::from_iter([false, false, true, true, true]).into_array() ); } diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index a03778eafab..6a6125a08f3 100644 --- a/vortex-mask/src/bitops.rs +++ b/vortex-mask/src/bitops.rs @@ -46,6 +46,21 @@ impl BitOr for &Mask { } } +impl Mask { + /// Computes `self & !rhs` (AND NOT), equivalent to set difference. + pub fn andnot(&self, rhs: &Mask) -> Mask { + if self.len() != rhs.len() { + vortex_panic!("Masks must have the same length"); + } + match (self.bit_buffer(), rhs.bit_buffer()) { + (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), + (_, AllOr::None) => self.clone(), + (AllOr::All, _) => !rhs, + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs & !rhs), + } + } +} + impl Not for Mask { type Output = Mask; @@ -353,6 +368,35 @@ mod tests { assert!(!result.value(3)); // (!(!false) | false) & !true = (false | false) & false = false } + #[test] + fn test_andnot() { + let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); + let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); + let result = a.andnot(&b); + assert!(!result.value(0)); // true & !true = false + assert!(result.value(1)); // true & !false = true + assert!(!result.value(2)); // false & !true = false + assert!(!result.value(3)); // false & !false = false + + // andnot(All) = None + let all = Mask::new_true(4); + assert!(a.andnot(&all).all_false()); + + // andnot(None) = self + let none = Mask::new_false(4); + assert_eq!(a.andnot(&none).true_count(), a.true_count()); + + // None.andnot(_) = None + assert!(none.andnot(&a).all_false()); + + // All.andnot(x) = !x + let not_b = !&b; + let all_andnot_b = Mask::new_true(4).andnot(&b); + for i in 0..4 { + assert_eq!(all_andnot_b.value(i), not_b.value(i)); + } + } + #[test] fn test_bitor() { // Test basic OR operations From 2506209844009f1818c6324014ccc12cab2afc51 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:16:46 +0000 Subject: [PATCH 4/8] cast in zip_impl Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 8b25fec760c..7ebf24a5aa1 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -184,6 +184,14 @@ pub(crate) fn zip_impl( .dtype() .clone() .union_nullability(if_false.dtype().nullability()); + + if mask.all_true() { + return if_true.cast(return_type); + } + if mask.all_false() { + return if_false.cast(return_type); + } + zip_impl_with_builder( if_true, if_false, @@ -199,13 +207,8 @@ fn zip_impl_with_builder( mut builder: Box, ) -> VortexResult { match mask.slices() { - AllOr::All => { - builder.extend_from_array(if_true); - Ok(builder.finish()) - } - AllOr::None => { - builder.extend_from_array(if_false); - Ok(builder.finish()) + AllOr::All | AllOr::None => { + unreachable!("zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl") } AllOr::Some(slices) => { for (start, end) in slices { From db79be4ddf916e8883a8d890ae2243eedc476d5f Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:21:50 +0000 Subject: [PATCH 5/8] tests Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/zip/mod.rs | 46 ++++++++++++++++++----- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 7ebf24a5aa1..3d73ec3e639 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -208,7 +208,9 @@ fn zip_impl_with_builder( ) -> VortexResult { match mask.slices() { AllOr::All | AllOr::None => { - unreachable!("zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl") + unreachable!( + "zip_impl_with_builder called with all-true or all-false mask; handle in zip_impl" + ) } AllOr::Some(slices) => { for (start, end) in slices { @@ -231,6 +233,7 @@ mod tests { use vortex_error::VortexResult; use vortex_mask::Mask; + use super::zip_impl; use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -295,13 +298,9 @@ mod tests { PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array(); assert_arrays_eq!(result, expected); - - // result must be nullable even if_true was not assert_eq!(result.dtype(), if_false.dtype()) } - /// When the mask is all-false and `if_true` is Nullable, the result dtype must be Nullable - /// even though `if_false` is NonNullable. #[test] fn test_zip_all_false_widens_nullability() { let mask = Mask::new_false(4); @@ -314,10 +313,42 @@ mod tests { PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), Some(4)]).into_array(); assert_arrays_eq!(result, expected); - // result must be nullable even though if_false was not assert_eq!(result.dtype(), if_true.dtype()); } + #[test] + fn test_zip_impl_all_true_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_true(4); + let if_true = buffer![10i32, 20, 30, 40].into_array(); + let if_false = + PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)]) + .into_array() + ); + assert_eq!(result.dtype(), if_false.dtype()); + Ok(()) + } + + #[test] + fn test_zip_impl_all_false_widens_nullability() -> VortexResult<()> { + let mask = Mask::new_false(4); + let if_true = + PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array(); + let if_false = buffer![1i32, 2, 3, 4].into_array(); + + let result = zip_impl(&if_true, &if_false, &mask)?; + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array() + ); + assert_eq!(result.dtype(), if_true.dtype()); + Ok(()) + } + #[test] #[should_panic] fn test_invalid_lengths() { @@ -361,7 +392,6 @@ mod tests { buffer: views host 1.60 kB (align=16) (96.56%) "); - // test wrapped in a struct let wrapped1 = StructArray::try_from_iter([("nested", const1)])?.into_array(); let wrapped2 = StructArray::try_from_iter([("nested", const2)])?.into_array(); @@ -405,7 +435,6 @@ mod tests { builder.finish() }; - // [1,2,4,5,7,8,..] let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect()); let mask_array = mask.clone().into_array(); @@ -416,7 +445,6 @@ mod tests { .unwrap(); assert_eq!(zipped.nbuffers(), 2); - // assert the result is the same as arrow let expected = arrow_zip( mask.into_array() .into_arrow_preferred() From dbc323da75b2b8dedf4e93cca28f7024ebe89ad9 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Thu, 5 Mar 2026 19:32:44 +0000 Subject: [PATCH 6/8] public api Signed-off-by: Baris Palaska --- vortex-mask/public-api.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index b0cce7abbf9..84aa1cf82f5 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -98,6 +98,10 @@ pub fn vortex_mask::Mask::values(&self) -> core::option::Option<&vortex_mask::Ma impl vortex_mask::Mask +pub fn vortex_mask::Mask::andnot(&self, rhs: &vortex_mask::Mask) -> vortex_mask::Mask + +impl vortex_mask::Mask + pub fn vortex_mask::Mask::intersect_by_rank(&self, mask: &vortex_mask::Mask) -> vortex_mask::Mask impl vortex_mask::Mask From 9d57dffbdfad7c19fcc7739f674c2d261aed6ea2 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Fri, 6 Mar 2026 13:37:42 +0000 Subject: [PATCH 7/8] add todo Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 043fff5b8d6..f7384a05932 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -191,11 +191,18 @@ impl ScalarFnVTable for CaseWhen { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Inspired by https://datafusion.apache.org/blog/2026/02/02/datafusion_case/ + // + // Implemented: short-circuit early exit; single-pass merge via `merge_case_branches`. + // Partial: single-branch uses `zip_impl` but THEN/ELSE still evaluated on the full batch. + // + // TODO: shrink input to `remaining` rows between WHEN iterations (batch reduction). + // TODO: project to only referenced columns before batch reduction (column projection). + // TODO: evaluate THEN/ELSE on compact matching/non-matching rows and merge without scatter. + // TODO: for constant WHEN/THEN values, compile to a hash table for a single-pass lookup. let row_count = args.row_count(); let num_pairs = options.num_when_then_pairs as usize; - // Track unmatched rows; AND each condition with `remaining` to enforce first-match-wins - // and produce disjoint branch masks. let mut remaining = Mask::new_true(row_count); let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs); From cb59e4fbf0afc158a0c3ce2ca1579f4c0272cd74 Mon Sep 17 00:00:00 2001 From: Baris Palaska Date: Fri, 6 Mar 2026 14:10:05 +0000 Subject: [PATCH 8/8] mask::bitand_not uses fused bitbuffer method, also owned Signed-off-by: Baris Palaska --- vortex-array/src/scalar_fn/fns/case_when.rs | 2 +- vortex-mask/src/bitops.rs | 29 ++++++++++----------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index d64a6680f68..a476a04b8ed 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -231,7 +231,7 @@ impl ScalarFnVTable for CaseWhen { } let then_value = args.get(i * 2 + 1)?; - remaining = remaining.andnot(&cond_mask); + remaining = remaining.bitand_not(&cond_mask); branches.push((effective_mask, then_value)); } diff --git a/vortex-mask/src/bitops.rs b/vortex-mask/src/bitops.rs index 6a6125a08f3..3ad681db161 100644 --- a/vortex-mask/src/bitops.rs +++ b/vortex-mask/src/bitops.rs @@ -48,15 +48,15 @@ impl BitOr for &Mask { impl Mask { /// Computes `self & !rhs` (AND NOT), equivalent to set difference. - pub fn andnot(&self, rhs: &Mask) -> Mask { + pub fn bitand_not(self, rhs: &Mask) -> Mask { if self.len() != rhs.len() { vortex_panic!("Masks must have the same length"); } match (self.bit_buffer(), rhs.bit_buffer()) { (AllOr::None, _) | (_, AllOr::All) => Mask::new_false(self.len()), - (_, AllOr::None) => self.clone(), + (_, AllOr::None) => self, (AllOr::All, _) => !rhs, - (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs & !rhs), + (AllOr::Some(lhs), AllOr::Some(rhs)) => Mask::from_buffer(lhs.bitand_not(rhs)), } } } @@ -369,31 +369,30 @@ mod tests { } #[test] - fn test_andnot() { + fn test_bitand_not() { let a = Mask::from_buffer(BitBuffer::from_iter([true, true, false, false])); let b = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false])); - let result = a.andnot(&b); + let result = a.clone().bitand_not(&b); assert!(!result.value(0)); // true & !true = false assert!(result.value(1)); // true & !false = true assert!(!result.value(2)); // false & !true = false assert!(!result.value(3)); // false & !false = false - // andnot(All) = None - let all = Mask::new_true(4); - assert!(a.andnot(&all).all_false()); + // bitand_not(All) = None + assert!(a.clone().bitand_not(&Mask::new_true(4)).all_false()); - // andnot(None) = self + // bitand_not(None) = self let none = Mask::new_false(4); - assert_eq!(a.andnot(&none).true_count(), a.true_count()); + assert_eq!(a.clone().bitand_not(&none).true_count(), a.true_count()); - // None.andnot(_) = None - assert!(none.andnot(&a).all_false()); + // None.bitand_not(_) = None + assert!(none.bitand_not(&a).all_false()); - // All.andnot(x) = !x + // All.bitand_not(x) = !x let not_b = !&b; - let all_andnot_b = Mask::new_true(4).andnot(&b); + let all_bitand_not_b = Mask::new_true(4).bitand_not(&b); for i in 0..4 { - assert_eq!(all_andnot_b.value(i), not_b.value(i)); + assert_eq!(all_bitand_not_b.value(i), not_b.value(i)); } }