From ebd242213defe7c048dbf5ffb5752c036b7ce205 Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Thu, 19 Mar 2026 23:15:57 +0800 Subject: [PATCH 1/2] fix: Tuple IN null semantics for struct comparisons --- .../physical-expr/src/expressions/in_list.rs | 182 +++++++++++++++++- datafusion/sqllogictest/test_files/expr.slt | 40 ++-- 2 files changed, 201 insertions(+), 21 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index ca89a3ab1ef4..63752607b13f 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -27,7 +27,7 @@ use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; -use arrow::compute::kernels::boolean::{not, or_kleene}; +use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::eq as arrow_eq; use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; @@ -79,15 +79,23 @@ struct ArrayStaticFilter { /// Note: usize::hash is not used, instead the raw entry /// API is used to store entries w.r.t their value map: HashMap, + has_nulls: bool, } impl StaticFilter for ArrayStaticFilter { fn null_count(&self) -> usize { - self.in_array.null_count() + usize::from(self.has_nulls) } /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. fn contains(&self, v: &dyn Array, negated: bool) -> Result { + if let (Some(v), Some(in_array)) = ( + v.as_any().downcast_ref::(), + self.in_array.as_any().downcast_ref::(), + ) { + return struct_in_list(v, in_array, negated); + } + // Null type comparisons always return null (SQL three-valued logic) if v.data_type() == &DataType::Null || self.in_array.data_type() == &DataType::Null @@ -161,6 +169,91 @@ fn supports_arrow_eq(dt: &DataType) -> bool { } } +fn array_has_nulls(array: &dyn Array) -> bool { + if array.null_count() > 0 { + return true; + } + + if let Some(struct_array) = array.as_any().downcast_ref::() { + return struct_array + .columns() + .iter() + .any(|column| array_has_nulls(column.as_ref())); + } + + false +} + +fn arrays_eq_kleene(lhs: &dyn Array, rhs: &dyn Array) -> Result { + match ( + lhs.as_any().downcast_ref::(), + rhs.as_any().downcast_ref::(), + ) { + (Some(lhs), Some(rhs)) => struct_eq_kleene(lhs, rhs), + _ => { + let cmp = make_comparator(lhs, rhs, SortOptions::default())?; + let buffer = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_eq()); + let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + Ok(BooleanArray::new(buffer, nulls)) + } + } +} + +fn struct_eq_kleene(lhs: &StructArray, rhs: &StructArray) -> Result { + assert_or_internal_err!( + lhs.len() == rhs.len(), + "struct equality requires arrays of the same length" + ); + assert_or_internal_err!( + lhs.num_columns() == rhs.num_columns(), + "struct equality requires arrays with the same number of fields" + ); + + let top_level_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + + let mut result = None; + for index in 0..lhs.num_columns() { + let field_result = + arrays_eq_kleene(lhs.column(index).as_ref(), rhs.column(index).as_ref())?; + result = Some(match result { + None => field_result, + Some(acc) => and_kleene(&acc, &field_result)?, + }); + } + + let result = result + .unwrap_or_else(|| BooleanArray::new(BooleanBuffer::new_set(lhs.len()), None)); + let nulls = NullBuffer::union(result.nulls(), top_level_nulls.as_ref()); + Ok(BooleanArray::new(result.values().clone(), nulls)) +} + +fn struct_in_list( + needle: &StructArray, + haystack: &StructArray, + negated: bool, +) -> Result { + let mut found = BooleanArray::new(BooleanBuffer::new_unset(needle.len()), None); + + for row in 0..haystack.len() { + let scalar = ScalarValue::try_from_array(haystack, row)?; + let scalar_array = scalar.to_array_of_size(needle.len())?; + let scalar_struct = scalar_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!("Expected struct array while comparing structs") + })?; + let matches = struct_eq_kleene(needle, scalar_struct)?; + found = or_kleene(&found, &matches)?; + + if found.null_count() == 0 && found.true_count() == needle.len() { + break; + } + } + + if negated { Ok(not(&found)?) } else { Ok(found) } +} + fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { @@ -198,11 +291,13 @@ impl ArrayStaticFilter { in_array, state: RandomState::default(), map: HashMap::with_hasher(()), + has_nulls: true, }); } let state = RandomState::default(); let mut map: HashMap = HashMap::with_hasher(()); + let has_nulls = array_has_nulls(&in_array); with_hashes([&in_array], &state, |hashes| -> Result<()> { let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; @@ -232,6 +327,7 @@ impl ArrayStaticFilter { in_array, state, map, + has_nulls, }) } } @@ -805,7 +901,12 @@ impl PhysicalExpr for InListExpr { let compare_one = |expr: &Arc| -> Result { match expr.evaluate(batch)? { ColumnarValue::Array(array) => { - if lhs_supports_arrow_eq + if let (Some(lhs), Some(rhs)) = ( + value.as_any().downcast_ref::(), + array.as_any().downcast_ref::(), + ) { + Ok(arrays_eq_kleene(lhs, rhs)?) + } else if lhs_supports_arrow_eq && supports_arrow_eq(array.data_type()) { Ok(arrow_eq(&value, &array)?) @@ -828,20 +929,32 @@ impl PhysicalExpr for InListExpr { if scalar.is_null() { // If scalar is null, all comparisons return null Ok(BooleanArray::from(vec![None; num_rows])) + } else if let Some(value_struct) = + value.as_any().downcast_ref::() + { + let scalar_array = scalar.to_array_of_size(num_rows)?; + let scalar_struct = scalar_array + .as_any() + .downcast_ref::() + .ok_or_else(|| exec_datafusion_err!( + "Expected struct array while comparing structs" + ))?; + Ok(arrays_eq_kleene(value_struct, scalar_struct)?) } else if lhs_supports_arrow_eq { let scalar_datum = scalar.to_scalar()?; Ok(arrow_eq(&value, &scalar_datum)?) } else { - // Convert scalar to 1-element array - let array = scalar.to_array()?; + // Broadcast the scalar to match the input length so the + // fallback comparator can evaluate row-by-row. + let array = scalar.to_array_of_size(num_rows)?; let cmp = make_comparator( value.as_ref(), array.as_ref(), SortOptions::default(), )?; - // Compare each row of value with the single scalar element + // Compare each row of value with the broadcast scalar value. let buffer = BooleanBuffer::collect_bool(num_rows, |i| { - cmp(i, 0).is_eq() + cmp(i, i).is_eq() }); Ok(BooleanArray::new(buffer, value.nulls().cloned())) } @@ -2024,6 +2137,61 @@ mod tests { Ok(()) } + #[test] + fn in_list_struct_field_null_in_list() -> Result<()> { + // Regresses `(row(...)) IN ((..., NULL))`. + // A NULL inside a struct field must produce NULL only when all + // preceding fields are equal. + let struct_fields = Fields::from(vec![ + Field::new("empno", DataType::Int32, false), + Field::new("deptno", DataType::Int32, true), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + false, + )]); + + let empno = Arc::new(Int32Array::from(vec![7521, 1111, 7521])); + let deptno = Arc::new(Int32Array::from(vec![30, 30, 31])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![empno, deptno], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + let null_deptno = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![7521])), + Arc::new(Int32Array::from(vec![None])), + ], + None, + ))); + + let list = vec![lit(null_deptno.clone())]; + in_list_raw!( + batch, + list.clone(), + &false, + vec![None, Some(false), None], + Arc::clone(&col_a), + &schema + ); + + in_list_raw!( + batch, + list, + &true, + vec![None, Some(true), None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + #[test] fn in_list_nested_struct() -> Result<()> { // Create nested struct schema diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a6341bc686f7..278bebcadacd 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1269,23 +1269,23 @@ NULL # Test tuple/row-wise IN comparisons using struct syntax # Note: Using arrow_cast for precise type control -# (NULL, NULL) IN ((1, 2)) => FALSE +# (NULL, NULL) IN ((1, 2)) => NULL query B SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(1, 2)) ---- -false +NULL -# (NULL, NULL) IN ((NULL, 1)) => FALSE +# (NULL, NULL) IN ((NULL, 1)) => NULL query B SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(arrow_cast(NULL, 'Int32'), 1)) ---- -false +NULL -# (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match) +# (NULL, NULL) IN ((NULL, NULL)) => NULL query B SELECT struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32')) IN (struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32'))) ---- -true +NULL # (NULL, 1) IN ((1, 2)) => FALSE query B @@ -1293,17 +1293,17 @@ SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(1, 2)) ---- false -# (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) +# (NULL, 1) IN ((NULL, 1)) => NULL query B SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(arrow_cast(NULL, 'Int32'), 1)) ---- -true +NULL -# (NULL, 1) IN ((NULL, NULL)) => FALSE +# (NULL, 1) IN ((NULL, NULL)) => NULL query B SELECT struct(arrow_cast(NULL, 'Int32'), 1) IN (struct(arrow_cast(NULL, 'Int32'), arrow_cast(NULL, 'Int32'))) ---- -false +NULL # (1, 2) IN ((1, 2)) => TRUE query B @@ -1323,17 +1323,29 @@ SELECT struct(4, 4) IN (struct(1, 2)) ---- false -# (1, 1) IN ((NULL, 1)) => FALSE +# (1, 1) IN ((NULL, 1)) => NULL query B SELECT struct(1, 1) IN (struct(NULL, 1)) ---- -false +NULL -# (1, 1) IN ((NULL, NULL)) => FALSE +# (1, 1) IN ((NULL, NULL)) => NULL query B SELECT struct(1, 1) IN (struct(NULL, NULL)) ---- -false +NULL + +# (7521, 30) IN ((7521, NULL)) => NULL +query B +SELECT struct(7521, 30) IN (struct(7521, NULL)) +---- +NULL + +# (7521, 30) NOT IN ((7521, NULL)) => NULL +query B +SELECT struct(7521, 30) NOT IN (struct(7521, NULL)) +---- +NULL # Cleanup test tables From 9fec6b0243b519568fd02594a3cfb20246234018 Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Thu, 19 Mar 2026 23:35:10 +0800 Subject: [PATCH 2/2] fix test --- .../physical-expr/src/expressions/in_list.rs | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 63752607b13f..399a78ec30d3 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -3084,25 +3084,25 @@ mod tests { Ok(()) }; - // (NULL, NULL) IN ((1, 2)) => FALSE (tuples don't match) + // (NULL, NULL) IN ((1, 2)) => NULL run_tuple_test( make_struct(None, None), vec![make_struct(Some(1), Some(2))], - vec![Some(false)], + vec![None], )?; - // (NULL, NULL) IN ((NULL, 1)) => FALSE + // (NULL, NULL) IN ((NULL, 1)) => NULL run_tuple_test( make_struct(None, None), vec![make_struct(None, Some(1))], - vec![Some(false)], + vec![None], )?; - // (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match including nulls) + // (NULL, NULL) IN ((NULL, NULL)) => NULL run_tuple_test( make_struct(None, None), vec![make_struct(None, None)], - vec![Some(true)], + vec![None], )?; // (NULL, 1) IN ((1, 2)) => FALSE @@ -3112,18 +3112,18 @@ mod tests { vec![Some(false)], )?; - // (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) + // (NULL, 1) IN ((NULL, 1)) => NULL run_tuple_test( make_struct(None, Some(1)), vec![make_struct(None, Some(1))], - vec![Some(true)], + vec![None], )?; - // (NULL, 1) IN ((NULL, NULL)) => FALSE + // (NULL, 1) IN ((NULL, NULL)) => NULL run_tuple_test( make_struct(None, Some(1)), vec![make_struct(None, None)], - vec![Some(false)], + vec![None], )?; // (1, 2) IN ((1, 2)) => TRUE @@ -3147,18 +3147,18 @@ mod tests { vec![Some(false)], )?; - // (1, 1) IN ((NULL, 1)) => FALSE + // (1, 1) IN ((NULL, 1)) => NULL run_tuple_test( make_struct(Some(1), Some(1)), vec![make_struct(None, Some(1))], - vec![Some(false)], + vec![None], )?; - // (1, 1) IN ((NULL, NULL)) => FALSE + // (1, 1) IN ((NULL, NULL)) => NULL run_tuple_test( make_struct(Some(1), Some(1)), vec![make_struct(None, None)], - vec![Some(false)], + vec![None], )?; Ok(())