Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 189 additions & 21 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::physical_expr::physical_exprs_bag_equal;

use arrow::array::*;
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::kernels::boolean::{not, or_kleene};
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::cmp::eq as arrow_eq;
use arrow::compute::{SortOptions, take};
use arrow::datatypes::*;
Expand Down Expand Up @@ -79,15 +79,23 @@ struct ArrayStaticFilter {
/// Note: usize::hash is not used, instead the raw entry
/// API is used to store entries w.r.t their value
map: HashMap<usize, (), ()>,
has_nulls: bool,
}

impl StaticFilter for ArrayStaticFilter {
fn null_count(&self) -> usize {
self.in_array.null_count()
usize::from(self.has_nulls)
}

/// Checks if values in `v` are contained in the `in_array` using this hash set for lookup.
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
if let (Some(v), Some(in_array)) = (
v.as_any().downcast_ref::<StructArray>(),
self.in_array.as_any().downcast_ref::<StructArray>(),
) {
return struct_in_list(v, in_array, negated);
}

// Null type comparisons always return null (SQL three-valued logic)
if v.data_type() == &DataType::Null
|| self.in_array.data_type() == &DataType::Null
Expand Down Expand Up @@ -161,6 +169,91 @@ fn supports_arrow_eq(dt: &DataType) -> bool {
}
}

/// Reports whether `array` contains a null at any nesting level.
///
/// A non-zero `null_count()` on the array itself is sufficient. Otherwise, if
/// the array downcasts to a `StructArray`, every child column is inspected
/// recursively; any other array with a zero null count yields `false`.
fn array_has_nulls(array: &dyn Array) -> bool {
    if array.null_count() != 0 {
        return true;
    }

    match array.as_any().downcast_ref::<StructArray>() {
        // Only struct children are searched for nested nulls.
        Some(nested) => nested
            .columns()
            .iter()
            .any(|child| array_has_nulls(child.as_ref())),
        None => false,
    }
}

fn arrays_eq_kleene(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
match (
lhs.as_any().downcast_ref::<StructArray>(),
rhs.as_any().downcast_ref::<StructArray>(),
) {
(Some(lhs), Some(rhs)) => struct_eq_kleene(lhs, rhs),
_ => {
let cmp = make_comparator(lhs, rhs, SortOptions::default())?;
let buffer = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_eq());
let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
Ok(BooleanArray::new(buffer, nulls))
}
}
}

/// Field-wise equality of two struct arrays, combined with Kleene AND.
///
/// Each pair of child columns is compared with `arrays_eq_kleene` and the
/// per-field results are folded together with `and_kleene`. A struct with no
/// columns compares as all-true. The validity of the final result is the
/// union of the folded result's nulls and the top-level nulls of both inputs.
///
/// # Errors
///
/// Returns an internal error if the inputs differ in length or in number of
/// columns, and propagates any error from the per-field comparison.
fn struct_eq_kleene(lhs: &StructArray, rhs: &StructArray) -> Result<BooleanArray> {
    assert_or_internal_err!(
        lhs.len() == rhs.len(),
        "struct equality requires arrays of the same length"
    );
    assert_or_internal_err!(
        lhs.num_columns() == rhs.num_columns(),
        "struct equality requires arrays with the same number of fields"
    );

    let struct_level_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());

    // Column counts were checked above, so zipping visits every field pair.
    let folded = lhs.columns().iter().zip(rhs.columns().iter()).try_fold(
        None::<BooleanArray>,
        |acc, (left, right)| -> Result<Option<BooleanArray>> {
            let field_eq = arrays_eq_kleene(left.as_ref(), right.as_ref())?;
            let merged = match acc {
                Some(previous) => and_kleene(&previous, &field_eq)?,
                None => field_eq,
            };
            Ok(Some(merged))
        },
    )?;

    let combined = match folded {
        Some(array) => array,
        // No fields at all: every row trivially compares equal.
        None => BooleanArray::new(BooleanBuffer::new_set(lhs.len()), None),
    };

    let validity = NullBuffer::union(combined.nulls(), struct_level_nulls.as_ref());
    Ok(BooleanArray::new(combined.values().clone(), validity))
}

/// Evaluates `needle IN (rows of haystack)` for struct values.
///
/// Every haystack row is broadcast to the needle's length and compared with
/// `struct_eq_kleene`; the per-row results are accumulated with `or_kleene`.
/// When `negated` is set, the accumulated result is inverted with `not`.
///
/// # Errors
///
/// Propagates failures from extracting or broadcasting a haystack row and
/// from the comparison kernels, and returns an execution error if a broadcast
/// row is not a struct array.
fn struct_in_list(
    needle: &StructArray,
    haystack: &StructArray,
    negated: bool,
) -> Result<BooleanArray> {
    let num_rows = needle.len();
    // Start from "nothing matched": all false, no nulls.
    let mut acc = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);

    for candidate_idx in 0..haystack.len() {
        // Materialize the candidate row as an array as long as the needle.
        let broadcast = ScalarValue::try_from_array(haystack, candidate_idx)?
            .to_array_of_size(num_rows)?;
        let candidate = match broadcast.as_any().downcast_ref::<StructArray>() {
            Some(candidate) => candidate,
            None => {
                return Err(exec_datafusion_err!(
                    "Expected struct array while comparing structs"
                ));
            }
        };

        let row_matches = struct_eq_kleene(needle, candidate)?;
        acc = or_kleene(&acc, &row_matches)?;

        // Stop early once every needle row is a definite match.
        let all_matched = acc.null_count() == 0 && acc.true_count() == num_rows;
        if all_matched {
            break;
        }
    }

    if !negated {
        return Ok(acc);
    }
    Ok(not(&acc)?)
}

fn instantiate_static_filter(
in_array: ArrayRef,
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
Expand Down Expand Up @@ -198,11 +291,13 @@ impl ArrayStaticFilter {
in_array,
state: RandomState::default(),
map: HashMap::with_hasher(()),
has_nulls: true,
});
}

let state = RandomState::default();
let mut map: HashMap<usize, (), ()> = HashMap::with_hasher(());
let has_nulls = array_has_nulls(&in_array);

with_hashes([&in_array], &state, |hashes| -> Result<()> {
let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?;
Expand Down Expand Up @@ -232,6 +327,7 @@ impl ArrayStaticFilter {
in_array,
state,
map,
has_nulls,
})
}
}
Expand Down Expand Up @@ -805,7 +901,12 @@ impl PhysicalExpr for InListExpr {
let compare_one = |expr: &Arc<dyn PhysicalExpr>| -> Result<BooleanArray> {
match expr.evaluate(batch)? {
ColumnarValue::Array(array) => {
if lhs_supports_arrow_eq
if let (Some(lhs), Some(rhs)) = (
value.as_any().downcast_ref::<StructArray>(),
array.as_any().downcast_ref::<StructArray>(),
) {
Ok(arrays_eq_kleene(lhs, rhs)?)
} else if lhs_supports_arrow_eq
&& supports_arrow_eq(array.data_type())
{
Ok(arrow_eq(&value, &array)?)
Expand All @@ -828,20 +929,32 @@ impl PhysicalExpr for InListExpr {
if scalar.is_null() {
// If scalar is null, all comparisons return null
Ok(BooleanArray::from(vec![None; num_rows]))
} else if let Some(value_struct) =
value.as_any().downcast_ref::<StructArray>()
{
let scalar_array = scalar.to_array_of_size(num_rows)?;
let scalar_struct = scalar_array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| exec_datafusion_err!(
"Expected struct array while comparing structs"
))?;
Ok(arrays_eq_kleene(value_struct, scalar_struct)?)
} else if lhs_supports_arrow_eq {
let scalar_datum = scalar.to_scalar()?;
Ok(arrow_eq(&value, &scalar_datum)?)
} else {
// Convert scalar to 1-element array
let array = scalar.to_array()?;
// Broadcast the scalar to match the input length so the
// fallback comparator can evaluate row-by-row.
let array = scalar.to_array_of_size(num_rows)?;
let cmp = make_comparator(
value.as_ref(),
array.as_ref(),
SortOptions::default(),
)?;
// Compare each row of value with the single scalar element
// Compare each row of value with the broadcast scalar value.
let buffer = BooleanBuffer::collect_bool(num_rows, |i| {
cmp(i, 0).is_eq()
cmp(i, i).is_eq()
});
Ok(BooleanArray::new(buffer, value.nulls().cloned()))
}
Expand Down Expand Up @@ -2024,6 +2137,61 @@ mod tests {
Ok(())
}

#[test]
fn in_list_struct_field_null_in_list() -> Result<()> {
    // Regression test for `(row(...)) IN ((..., NULL))`.
    // A NULL struct field in a list entry makes the comparison NULL only when
    // every other field compares equal; a definite mismatch in any field
    // still yields FALSE.
    let fields = Fields::from(vec![
        Field::new("empno", DataType::Int32, false),
        Field::new("deptno", DataType::Int32, true),
    ]);
    let schema = Schema::new(vec![Field::new(
        "a",
        DataType::Struct(fields.clone()),
        false,
    )]);

    // Input rows: (7521, 30), (1111, 30), (7521, 31).
    let empno_values = Arc::new(Int32Array::from(vec![7521, 1111, 7521]));
    let deptno_values = Arc::new(Int32Array::from(vec![30, 30, 31]));
    let input =
        StructArray::new(fields.clone(), vec![empno_values, deptno_values], None);
    let batch =
        RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
    let col_a = col("a", &schema)?;

    // The single list entry is (7521, NULL).
    let list_entry = ScalarValue::Struct(Arc::new(StructArray::new(
        fields.clone(),
        vec![
            Arc::new(Int32Array::from(vec![7521])),
            Arc::new(Int32Array::from(vec![None])),
        ],
        None,
    )));
    let list = vec![lit(list_entry.clone())];

    // IN: rows whose `empno` matches are NULL, the mismatching row is FALSE.
    in_list_raw!(
        batch,
        list.clone(),
        &false,
        vec![None, Some(false), None],
        Arc::clone(&col_a),
        &schema
    );

    // NOT IN: NULL stays NULL, FALSE flips to TRUE.
    in_list_raw!(
        batch,
        list,
        &true,
        vec![None, Some(true), None],
        Arc::clone(&col_a),
        &schema
    );

    Ok(())
}

#[test]
fn in_list_nested_struct() -> Result<()> {
// Create nested struct schema
Expand Down Expand Up @@ -2916,25 +3084,25 @@ mod tests {
Ok(())
};

// (NULL, NULL) IN ((1, 2)) => FALSE (tuples don't match)
// (NULL, NULL) IN ((1, 2)) => NULL
run_tuple_test(
make_struct(None, None),
vec![make_struct(Some(1), Some(2))],
vec![Some(false)],
vec![None],
)?;

// (NULL, NULL) IN ((NULL, 1)) => FALSE
// (NULL, NULL) IN ((NULL, 1)) => NULL
run_tuple_test(
make_struct(None, None),
vec![make_struct(None, Some(1))],
vec![Some(false)],
vec![None],
)?;

// (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match including nulls)
// (NULL, NULL) IN ((NULL, NULL)) => NULL
run_tuple_test(
make_struct(None, None),
vec![make_struct(None, None)],
vec![Some(true)],
vec![None],
)?;

// (NULL, 1) IN ((1, 2)) => FALSE
Expand All @@ -2944,18 +3112,18 @@ mod tests {
vec![Some(false)],
)?;

// (NULL, 1) IN ((NULL, 1)) => TRUE (exact match)
// (NULL, 1) IN ((NULL, 1)) => NULL
run_tuple_test(
make_struct(None, Some(1)),
vec![make_struct(None, Some(1))],
vec![Some(true)],
vec![None],
)?;

// (NULL, 1) IN ((NULL, NULL)) => FALSE
// (NULL, 1) IN ((NULL, NULL)) => NULL
run_tuple_test(
make_struct(None, Some(1)),
vec![make_struct(None, None)],
vec![Some(false)],
vec![None],
)?;

// (1, 2) IN ((1, 2)) => TRUE
Expand All @@ -2979,18 +3147,18 @@ mod tests {
vec![Some(false)],
)?;

// (1, 1) IN ((NULL, 1)) => FALSE
// (1, 1) IN ((NULL, 1)) => NULL
run_tuple_test(
make_struct(Some(1), Some(1)),
vec![make_struct(None, Some(1))],
vec![Some(false)],
vec![None],
)?;

// (1, 1) IN ((NULL, NULL)) => FALSE
// (1, 1) IN ((NULL, NULL)) => NULL
run_tuple_test(
make_struct(Some(1), Some(1)),
vec![make_struct(None, None)],
vec![Some(false)],
vec![None],
)?;

Ok(())
Expand Down
Loading
Loading