diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 94d962a4868..6a78c7e3b27 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -13328,7 +13328,7 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType> -pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>> +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>> pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>> @@ -15232,7 +15232,7 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType> -pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>> +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>> pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>> diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index f701548f145..754a4e10964 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -1,7 +1,14 @@ // SPDX-License-Identifier: 
Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! N-ary CASE WHEN expression for conditional value selection. +//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns +//! the value from the first matching condition (first-match-wins). NULL conditions +//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL; +//! otherwise they get the ELSE value. +//! +//! Unlike SQL, which coerces all branches to a common supertype, all THEN/ELSE +//! branches must share the same base dtype (ignoring nullability). The result +//! nullability is the union of all branches (forced nullable if no ELSE). use std::fmt; use std::fmt::Formatter; @@ -69,9 +76,11 @@ impl ScalarFnVTable for CaseWhen { ScalarFnId::from("vortex.case_when") } - fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> { - let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else); - Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec())) + fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> { + // let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else); + // Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec())) + // TODO: serialization is disabled until the expr is stabilized; re-enable the code above then. + vortex_bail!("cannot serialize") } fn deserialize( @@ -147,8 +156,9 @@ impl ScalarFnVTable for CaseWhen { ); } - // The return dtype is based on the first THEN expression (index 1). - // Validate all other THEN branches match and union their nullability. + // Unlike SQL, which coerces all branches to a common supertype, we require + // all THEN/ELSE branches to have the same base dtype (ignoring nullability). + // The result nullability is the union of all branches. 
let first_then = &arg_dtypes[1]; let mut result_dtype = first_then.clone(); @@ -166,7 +176,7 @@ impl ScalarFnVTable for CaseWhen { if options.has_else { let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2]; - if !first_then.eq_ignore_nullability(else_dtype) { + if !result_dtype.eq_ignore_nullability(else_dtype) { vortex_bail!( "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}", first_then, @@ -198,6 +208,9 @@ impl ScalarFnVTable for CaseWhen { ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() }; + // TODO(perf): this reverse-zip approach touches every row for every condition. + // A left-to-right filter approach could maintain an "unmatched" mask, narrow it + // as conditions match, and exit early once all rows are resolved. for i in (0..num_pairs).rev() { let condition = args.get(i * 2)?; let then_value = args.get(i * 2 + 1)?; @@ -279,6 +292,7 @@ mod tests { // ==================== Serialization Tests ==================== #[test] + #[should_panic(expected = "cannot serialize")] fn test_serialization_roundtrip() { let options = CaseWhenOptions { num_when_then_pairs: 1, @@ -292,6 +306,7 @@ mod tests { } #[test] + #[should_panic(expected = "cannot serialize")] fn test_serialization_no_else() { let options = CaseWhenOptions { num_when_then_pairs: 1, @@ -448,6 +463,7 @@ mod tests { // ==================== N-ary Serialization Tests ==================== #[test] + #[should_panic(expected = "cannot serialize")] fn test_serialization_roundtrip_nary() { let options = CaseWhenOptions { num_when_then_pairs: 3, @@ -461,6 +477,7 @@ mod tests { } #[test] + #[should_panic(expected = "cannot serialize")] fn test_serialization_roundtrip_nary_no_else() { let options = CaseWhenOptions { num_when_then_pairs: 4,