Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13328,7 +13328,7 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options:

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

Expand Down Expand Up @@ -15232,7 +15232,7 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options:

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

Expand Down
31 changes: 24 additions & 7 deletions vortex-array/src/scalar_fn/fns/case_when.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! N-ary CASE WHEN expression for conditional value selection.
//! SQL-style CASE WHEN: evaluates `(condition, value)` pairs in order and returns
//! the value from the first matching condition (first-match-wins). NULL conditions
//! are treated as false. If no ELSE clause is provided, unmatched rows produce NULL;
//! otherwise they get the ELSE value.
//!
//! Unlike SQL which coerces all branches to a common supertype, all THEN/ELSE
//! branches must share the same base dtype (ignoring nullability). The result
//! nullability is the union of all branches (forced nullable if no ELSE).

use std::fmt;
use std::fmt::Formatter;
Expand Down Expand Up @@ -69,9 +76,11 @@ impl ScalarFnVTable for CaseWhen {
ScalarFnId::from("vortex.case_when")
}

fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
// let num_children = options.num_when_then_pairs * 2 + u32::from(options.has_else);
// Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
// stabilize the expr
vortex_bail!("cannot serialize")
}

fn deserialize(
Expand Down Expand Up @@ -147,8 +156,9 @@ impl ScalarFnVTable for CaseWhen {
);
}

// The return dtype is based on the first THEN expression (index 1).
// Validate all other THEN branches match and union their nullability.
// Unlike SQL which coerces all branches to a common supertype, we require
// all THEN/ELSE branches to have the same base dtype (ignoring nullability).
// The result nullability is the union of all branches.
let first_then = &arg_dtypes[1];
let mut result_dtype = first_then.clone();

Expand All @@ -166,7 +176,7 @@ impl ScalarFnVTable for CaseWhen {

if options.has_else {
let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
if !first_then.eq_ignore_nullability(else_dtype) {
if !result_dtype.eq_ignore_nullability(else_dtype) {
vortex_bail!(
"CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
first_then,
Expand Down Expand Up @@ -198,6 +208,9 @@ impl ScalarFnVTable for CaseWhen {
ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
};

// TODO(perf): this reverse-zip approach touches every row for every condition.
// A left-to-right filter approach could maintain an "unmatched" mask, narrow it
// as conditions match, and exit early once all rows are resolved.
for i in (0..num_pairs).rev() {
let condition = args.get(i * 2)?;
let then_value = args.get(i * 2 + 1)?;
Expand Down Expand Up @@ -279,6 +292,7 @@ mod tests {
// ==================== Serialization Tests ====================

#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip() {
let options = CaseWhenOptions {
num_when_then_pairs: 1,
Expand All @@ -292,6 +306,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_no_else() {
let options = CaseWhenOptions {
num_when_then_pairs: 1,
Expand Down Expand Up @@ -448,6 +463,7 @@ mod tests {
// ==================== N-ary Serialization Tests ====================

#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip_nary() {
let options = CaseWhenOptions {
num_when_then_pairs: 3,
Expand All @@ -461,6 +477,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "cannot serialize")]
fn test_serialization_roundtrip_nary_no_else() {
let options = CaseWhenOptions {
num_when_then_pairs: 4,
Expand Down
Loading