diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index edb53df382c62..19d28859cbb2b 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,7 @@ mod distinct_count_string_fuzz; #[expect(clippy::needless_pass_by_value)] mod join_fuzz; mod merge_fuzz; +mod smj_filter_pushdown; #[expect(clippy::needless_pass_by_value)] mod sort_fuzz; #[expect(clippy::needless_pass_by_value)] diff --git a/datafusion/core/tests/fuzz_cases/smj_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/smj_filter_pushdown.rs new file mode 100644 index 0000000000000..37e990cedad57 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/smj_filter_pushdown.rs @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for dynamic filter pushdown through SortMergeJoinExec. +//! +//! These tests verify that when TopK dynamic filters are pushed through +//! SortMergeJoinExec for Inner joins, the query results remain correct. +//! Each test runs the same query with and without dynamic filter pushdown +//! and compares the results. +//! +//! Data is written to in-memory parquet files (via an InMemory object store) +//! 
so that the DataSourceExec supports filter pushdown — in-memory tables +//! do not. + +use std::sync::Arc; + +use arrow::array::{Float64Array, Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::object_store::ObjectStoreUrl; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; +use parquet::arrow::ArrowWriter; + +fn left_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("score", DataType::Float64, false), + ])) +} + +fn left_batch() -> RecordBatch { + RecordBatch::try_new( + left_schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])), + Arc::new(StringArray::from(vec![ + "alice", "bob", "carol", "dave", "eve", "frank", "grace", "heidi", + "ivan", "judy", + ])), + Arc::new(Float64Array::from(vec![ + 90.0, 85.0, 72.0, 95.0, 60.0, 88.0, 77.0, 91.0, 68.0, 83.0, + ])), + ], + ) + .unwrap() +} + +fn right_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("dept", DataType::Utf8, false), + Field::new("rating", DataType::Utf8, false), + ])) +} + +fn right_batch() -> RecordBatch { + RecordBatch::try_new( + right_schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 3, 5, 7, 9, 11, 13])), + Arc::new(StringArray::from(vec![ + "eng", "sales", "eng", "hr", "sales", "eng", "hr", + ])), + Arc::new(StringArray::from(vec!["A", "B", "C", "A", "B", "A", "C"])), + ], + ) + .unwrap() +} + +/// Write a RecordBatch to an in-memory object store as a parquet 
file. +async fn write_parquet(store: &Arc, path: &str, batch: &RecordBatch) { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None).unwrap(); + writer.write(batch).unwrap(); + writer.close().unwrap(); + store + .put(&Path::from(path), PutPayload::from(buf)) + .await + .unwrap(); +} + +/// Register a parquet-backed listing table from an in-memory object store. +fn register_listing_table( + ctx: &SessionContext, + table_name: &str, + schema: Arc, + path: &str, +) { + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let options = ListingOptions::new(format); + let table_path = ListingTableUrl::parse(format!("memory:///{path}")).unwrap(); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + let table = Arc::new(ListingTable::try_new(config).unwrap()); + ctx.register_table(table_name, table).unwrap(); +} + +/// Build a SessionContext backed by in-memory parquet, with SMJ forced. +/// +/// Uses 2 target partitions so that the optimizer inserts hash-repartitioning +/// and sort nodes — exercising the filter-passthrough through these operators — +/// while still producing a SortMergeJoinExec (not CollectLeft HashJoin, which +/// is preferred at target_partitions=1). 
+async fn build_ctx(enable_dynamic_filters: bool) -> SessionContext { + let cfg = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false) + .set_bool( + "datafusion.optimizer.enable_dynamic_filter_pushdown", + enable_dynamic_filters, + ); + let ctx = SessionContext::new_with_config(cfg); + + let store: Arc = Arc::new(InMemory::new()); + let url = ObjectStoreUrl::parse("memory://").unwrap(); + ctx.register_object_store(url.as_ref(), Arc::clone(&store)); + + write_parquet(&store, "left.parquet", &left_batch()).await; + write_parquet(&store, "right.parquet", &right_batch()).await; + + register_listing_table(&ctx, "left_t", left_schema(), "left.parquet"); + register_listing_table(&ctx, "right_t", right_schema(), "right.parquet"); + + ctx +} + +/// Run a query with and without dynamic filter pushdown, assert same results, +/// and verify the plan uses SortMergeJoinExec. +async fn run_and_compare(query: &str) { + // Run without dynamic filters (baseline) + let ctx_off = build_ctx(false).await; + let expected = ctx_off.sql(query).await.unwrap().collect().await.unwrap(); + + // Run with dynamic filters + let ctx_on = build_ctx(true).await; + let actual = ctx_on.sql(query).await.unwrap().collect().await.unwrap(); + + // Verify results match + let expected_str = pretty_format_batches(&expected).unwrap().to_string(); + let actual_str = pretty_format_batches(&actual).unwrap().to_string(); + assert_eq!( + expected_str, actual_str, + "Results differ for query: {query}\n\nExpected:\n{expected_str}\n\nActual:\n{actual_str}" + ); + + // Verify plan uses SortMergeJoinExec + let explain = ctx_on + .sql(&format!("EXPLAIN {query}")) + .await + .unwrap() + .collect() + .await + .unwrap(); + let plan_str = pretty_format_batches(&explain).unwrap().to_string(); + assert!( + plan_str.contains("SortMergeJoinExec"), + "Expected SortMergeJoinExec in plan for query: {query}\n\nPlan:\n{plan_str}" + ); +} + +// ---- Test cases ---- + 
+#[tokio::test] +async fn test_smj_topk_on_left_column() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.name LIMIT 3", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_on_right_column() { + run_and_compare( + "SELECT l.id, l.name, r.dept, r.rating \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY r.rating LIMIT 2", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_on_join_key() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.id LIMIT 3", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_desc_order() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.score DESC LIMIT 2", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_multi_column_order() { + run_and_compare( + "SELECT l.id, l.name, r.dept, r.rating \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY r.dept, l.name LIMIT 3", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_with_where_clause() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + WHERE l.score > 70.0 \ + ORDER BY l.name LIMIT 2", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_limit_one() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.score LIMIT 1", + ) + .await; +} + +#[tokio::test] +async fn test_smj_topk_limit_exceeds_rows() { + run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.id LIMIT 100", + ) + .await; +} + +#[tokio::test] +async fn test_smj_left_join_correctness() { + // Left join should produce correct results with or without dynamic filters. + // DynamicFilter is pushed through SMJ for any join type that preserves the filtered side. 
+ run_and_compare( + "SELECT l.id, l.name, r.dept \ + FROM left_t l LEFT JOIN right_t r ON l.id = r.id \ + ORDER BY l.id LIMIT 3", + ) + .await; +} + +#[tokio::test] +async fn test_smj_nested_joins_topk() { + run_and_compare( + "SELECT l.id, l.name, r1.dept, r2.rating \ + FROM left_t l \ + INNER JOIN right_t r1 ON l.id = r1.id \ + INNER JOIN right_t r2 ON l.id = r2.id \ + ORDER BY l.name LIMIT 3", + ) + .await; +} + +/// Verify that with dynamic filters enabled, the physical plan for an Inner +/// join + TopK query contains a DynamicFilter pushed down to a DataSourceExec. +/// This confirms the feature is effective end-to-end with parquet. +#[tokio::test] +async fn test_smj_dynamic_filter_present_in_plan() { + let query = "SELECT l.id, l.name, r.dept \ + FROM left_t l INNER JOIN right_t r ON l.id = r.id \ + ORDER BY l.name LIMIT 3"; + + let ctx = build_ctx(true).await; + let explain = ctx + .sql(&format!("EXPLAIN {query}")) + .await + .unwrap() + .collect() + .await + .unwrap(); + let plan_str = pretty_format_batches(&explain).unwrap().to_string(); + + assert!( + plan_str.contains("SortMergeJoinExec"), + "Expected SortMergeJoinExec in plan\n\nPlan:\n{plan_str}" + ); + + // With parquet + SMJ + TopK, a DynamicFilter should be pushed through + // the SortMergeJoinExec to one of the DataSourceExec nodes. + let has_dynamic_filter = plan_str + .lines() + .any(|l| l.contains("DataSourceExec") && l.contains("DynamicFilter")); + assert!( + has_dynamic_filter, + "Expected DynamicFilter pushed to DataSourceExec through SortMergeJoinExec\n\nPlan:\n{plan_str}" + ); +} + +/// Regression test for a non-deterministic correctness bug: the SMJ dynamic +/// filter is a single expression shared across the concurrently-executing hash +/// partitions feeding the join. An earlier design advanced a one-sided bound as +/// partitions made progress (and *raised* it on partition exhaustion), which +/// could prune valid join rows depending on thread scheduling. 
+/// +/// This exercises the failure shape directly — join keys with NULLs and values +/// at the range extremes, multiple partitions, and a tiny batch size — and runs +/// it repeatedly. The match set is keys {1, 3}; the previously-buggy filter +/// would intermittently drop the key=3 (max) row. +#[tokio::test] +async fn test_smj_dynamic_filter_multi_partition_null_keys_correct() { + async fn build(enable_dynamic_filters: bool) -> SessionContext { + let cfg = SessionConfig::new() + .with_target_partitions(2) + .with_batch_size(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false) + .set_bool( + "datafusion.optimizer.enable_dynamic_filter_pushdown", + enable_dynamic_filters, + ); + let ctx = SessionContext::new_with_config(cfg); + let store: Arc = Arc::new(InMemory::new()); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + Arc::clone(&store), + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("tag", DataType::Utf8, false), + ])); + let left = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(3)])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + ], + ) + .unwrap(); + let right = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3)])), + Arc::new(StringArray::from(vec!["A", "B", "C", "D"])), + ], + ) + .unwrap(); + write_parquet(&store, "l.parquet", &left).await; + write_parquet(&store, "r.parquet", &right).await; + register_listing_table(&ctx, "l", Arc::clone(&schema), "l.parquet"); + register_listing_table(&ctx, "r", schema, "r.parquet"); + ctx + } + + let query = "SELECT l.id, l.tag, r.tag FROM l INNER JOIN r ON l.id = r.id"; + + // Baseline (filter off) is the ground truth: keys {1, 3} match → 2 rows. 
+ let ctx_off = build(false).await; + let expected_batches = ctx_off.sql(query).await.unwrap().collect().await.unwrap(); + let expected_rows: usize = expected_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(expected_rows, 2, "expected 2 result rows (keys 1 and 3)"); + let expected = pretty_format_batches(&expected_batches) + .unwrap() + .to_string(); + + // Run many times to surface any scheduling-dependent pruning. + for i in 0..40 { + let ctx_on = build(true).await; + let actual = pretty_format_batches( + &ctx_on.sql(query).await.unwrap().collect().await.unwrap(), + ) + .unwrap() + .to_string(); + assert_eq!( + expected, actual, + "dynamic filter dropped valid join rows on iteration {i}\n\nExpected:\n{expected}\n\nActual:\n{actual}" + ); + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs index 99aef6ed82a36..2625ef5acb967 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs @@ -123,6 +123,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::joins::sort_merge_join::shared_bounds::SharedSortMergeBoundsAccumulator; use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays}; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, @@ -294,6 +295,19 @@ pub(crate) struct BitwiseSortMergeJoinStream { baseline_metrics: BaselineMetrics, peak_mem_used: Gauge, + // ======================================================================== + // DYNAMIC FILTER FIELDS: + // These fields manage dynamic filter pushdown. + // ======================================================================== + /// Dynamic filter for the streamed side (here, the `outer` stream), + /// advanced from the buffered side's head values. 
+ streamed_dynamic_filter: Option>, + /// Dynamic filter for the buffered side (here, the `inner` stream), + /// advanced from the streamed side's head values. + buffered_dynamic_filter: Option>, + /// Partition ID of this stream + partition: usize, + // Memory / spill — only the inner key buffer is tracked via reservation, // matching existing SMJ (which tracks only the buffered side). The outer // batch is a single batch at a time and cannot be spilled. @@ -337,6 +351,8 @@ impl BitwiseSortMergeJoinStream { reservation: MemoryReservation, spill_manager: SpillManager, runtime_env: Arc, + streamed_dynamic_filter: Option>, + buffered_dynamic_filter: Option>, ) -> Result { debug_assert!( matches!( @@ -394,6 +410,9 @@ impl BitwiseSortMergeJoinStream { input_rows, baseline_metrics, peak_mem_used, + streamed_dynamic_filter, + buffered_dynamic_filter, + partition, reservation, spill_manager, runtime_env, @@ -484,8 +503,22 @@ impl BitwiseSortMergeJoinStream { /// Poll for the next outer batch. Returns true if a batch was loaded. fn poll_next_outer_batch(&mut self, cx: &mut Context<'_>) -> Poll> { loop { + // The outer side is the streamed side; report its last join key to + // the buffered side's dynamic filter before pulling the next batch. + if let Some(accumulator) = &self.buffered_dynamic_filter + && let Some(batch) = &self.outer_batch + && let Some(key) = self.outer_key_arrays.first().and_then(|arr| { + ScalarValue::try_from_array(arr, batch.num_rows() - 1).ok() + }) + { + accumulator.report_head(self.partition, key)?; + } + match ready!(self.outer.poll_next_unpin(cx)) { None => { + if let Some(accumulator) = &self.buffered_dynamic_filter { + accumulator.mark_exhausted(self.partition)?; + } // Release the outer input pipeline's resources. 
let outer_schema = self.outer.schema(); self.outer = Box::pin(EmptyRecordBatchStream::new(outer_schema)); @@ -500,6 +533,16 @@ impl BitwiseSortMergeJoinStream { continue; } let keys = evaluate_join_keys(&batch, &self.on_outer)?; + + // Report first join key to the buffered side's dynamic filter. + if let Some(accumulator) = &self.buffered_dynamic_filter + && let Some(key) = keys + .first() + .and_then(|arr| ScalarValue::try_from_array(arr, 0).ok()) + { + accumulator.report_head(self.partition, key)?; + } + self.outer_batch = Some(batch); self.outer_offset = 0; self.outer_key_arrays = keys; @@ -517,8 +560,22 @@ impl BitwiseSortMergeJoinStream { /// Poll for the next inner batch. Returns true if a batch was loaded. fn poll_next_inner_batch(&mut self, cx: &mut Context<'_>) -> Poll> { loop { + // The inner side is the buffered side; report its last join key to + // the streamed side's dynamic filter before pulling the next batch. + if let Some(accumulator) = &self.streamed_dynamic_filter + && let Some(batch) = &self.inner_batch + && let Some(key) = self.inner_key_arrays.first().and_then(|arr| { + ScalarValue::try_from_array(arr, batch.num_rows() - 1).ok() + }) + { + accumulator.report_head(self.partition, key)?; + } + match ready!(self.inner.poll_next_unpin(cx)) { None => { + if let Some(accumulator) = &self.streamed_dynamic_filter { + accumulator.mark_exhausted(self.partition)?; + } // Release the inner input pipeline's resources. let inner_schema = self.inner.schema(); self.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema)); @@ -533,6 +590,16 @@ impl BitwiseSortMergeJoinStream { continue; } let keys = evaluate_join_keys(&batch, &self.on_inner)?; + + // Report first join key to the streamed side's dynamic filter. 
+ if let Some(accumulator) = &self.streamed_dynamic_filter + && let Some(key) = keys + .first() + .and_then(|arr| ScalarValue::try_from_array(arr, 0).ok()) + { + accumulator.report_head(self.partition, key)?; + } + self.inner_batch = Some(batch); self.inner_offset = 0; self.inner_key_arrays = keys; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index a8d25fd002b76..e9e102d0b150d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -19,14 +19,20 @@ //! A Sort-Merge join plan consumes two sorted children plans and produces //! joined output by given join type and other options. -use std::fmt::Formatter; -use std::sync::Arc; +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, OnceLock}; use super::bitwise_stream::BitwiseSortMergeJoinStream; use super::materializing_stream::MaterializingSortMergeJoinStream; use super::metrics::SortMergeJoinMetrics; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::expressions::PhysicalSortExpr; +use crate::filter_pushdown::{ + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::joins::sort_merge_join::shared_bounds::SharedSortMergeBoundsAccumulator; use crate::joins::utils::{ JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid, estimate_join_statistics, reorder_output_after_swap, @@ -46,6 +52,7 @@ use crate::{ use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; +use datafusion_common::config::ConfigOptions; use datafusion_common::{ JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err, plan_err, @@ -53,7 +60,10 @@ use datafusion_common::{ use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::MemoryConsumer; use 
datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql}; +use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr_common::physical_expr::{ + PhysicalExpr, PhysicalExprRef, fmt_sql, +}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; /// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge @@ -104,7 +114,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequiremen /// /// Helpful short video demonstration: /// . -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SortMergeJoinExec { /// Left sorted joining execution plan pub left: Arc, @@ -130,6 +140,47 @@ pub struct SortMergeJoinExec { pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. cache: Arc, + /// Dynamic filter for the left side + left_dynamic_filter: Option, + /// Dynamic filter for the right side + right_dynamic_filter: Option, +} + +impl Debug for SortMergeJoinExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SortMergeJoinExec") + .field("left", &self.left) + .field("right", &self.right) + .field("on", &self.on) + .field("filter", &self.filter) + .field("join_type", &self.join_type) + .field("schema", &self.schema) + .field("metrics", &self.metrics) + .field("left_sort_exprs", &self.left_sort_exprs) + .field("right_sort_exprs", &self.right_sort_exprs) + .field("sort_options", &self.sort_options) + .field("null_equality", &self.null_equality) + .field("cache", &self.cache) + // Explicitly exclude `left_dynamic_filter` / `right_dynamic_filter` to avoid runtime state differences in tests + .finish() + } +} + +#[derive(Clone)] +struct SortMergeJoinExecDynamicFilter { + /// Dynamic filter that we'll update with the results of the other side.
+ filter: Arc, + /// Shared bounds accumulator to collect information from each partition. + /// It is lazily initialized during execution. + accumulator: Arc>>, +} + +impl Debug for SortMergeJoinExecDynamicFilter { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SortMergeJoinExecDynamicFilter") + .field("filter", &self.filter) + .finish() + } } impl SortMergeJoinExec { @@ -201,9 +252,57 @@ impl SortMergeJoinExec { sort_options, null_equality, cache: Arc::new(cache), + left_dynamic_filter: None, + right_dynamic_filter: None, }) } + fn allow_join_dynamic_filter_pushdown(&self, config: &ConfigOptions) -> bool { + if !matches!( + self.join_type, + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi + ) || !config.optimizer.enable_join_dynamic_filter_pushdown + { + return false; + } + + true + } + + pub(crate) fn create_dynamic_filter( + on: &JoinOn, + side: JoinSide, + ) -> Arc { + let keys = match side { + JoinSide::Left => on.iter().map(|(l, _)| Arc::clone(l)).collect::>(), + JoinSide::Right => on.iter().map(|(_, r)| Arc::clone(r)).collect::>(), + JoinSide::None => vec![], + }; + Arc::new(DynamicFilterPhysicalExpr::new(keys, lit(true))) + } + + pub fn with_left_dynamic_filter( + mut self, + dynamic_filter: Arc, + ) -> Self { + self.left_dynamic_filter = Some(SortMergeJoinExecDynamicFilter { + filter: dynamic_filter, + accumulator: Arc::new(OnceLock::new()), + }); + self + } + + pub fn with_right_dynamic_filter( + mut self, + dynamic_filter: Arc, + ) -> Self { + self.right_dynamic_filter = Some(SortMergeJoinExecDynamicFilter { + filter: dynamic_filter, + accumulator: Arc::new(OnceLock::new()), + }); + self + } + /// Get probe side (e.g streaming side) information for this sort merge join. /// In current implementation, probe side is determined according to join type. 
pub fn probe_side(join_type: &JoinType) -> JoinSide { @@ -376,6 +475,9 @@ impl DisplayAs for SortMergeJoinExec { } else { "" }; + // Note: like HashJoinExec, the join's own dynamic filter is not + // rendered on the join node. Its effect is visible on the + // DataSourceExec(s) the filter is pushed down to. write!( f, "{}: join_type={:?}, on=[{}]{}{}", @@ -456,19 +558,49 @@ impl ExecutionPlan for SortMergeJoinExec { ) -> Result> { check_if_same_properties!(self, children); match &children[..] { - [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - self.on.clone(), - self.filter.clone(), - self.join_type, - self.sort_options.clone(), - self.null_equality, - )?)), + [left, right] => { + let mut node = SortMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.filter.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equality, + )?; + node.left_dynamic_filter + .clone_from(&self.left_dynamic_filter); + node.right_dynamic_filter + .clone_from(&self.right_dynamic_filter); + Ok(Arc::new(node)) + } _ => internal_err!("SortMergeJoin wrong number of children"), } } + fn reset_state(self: Arc) -> Result> { + let mut new_node = (*self).clone(); + + // Reset dynamic filters by creating new containers with fresh OnceLocks + if let Some(f) = &self.left_dynamic_filter { + new_node.left_dynamic_filter = Some(SortMergeJoinExecDynamicFilter { + filter: Arc::clone(&f.filter), + accumulator: Arc::new(OnceLock::new()), + }); + } + if let Some(f) = &self.right_dynamic_filter { + new_node.right_dynamic_filter = Some(SortMergeJoinExecDynamicFilter { + filter: Arc::clone(&f.filter), + accumulator: Arc::new(OnceLock::new()), + }); + } + + // Reset metrics + new_node.metrics = ExecutionPlanMetricsSet::new(); + + Ok(Arc::new(new_node)) + } + fn execute( &self, partition: usize, @@ -482,24 +614,58 @@ impl ExecutionPlan for SortMergeJoinExec { "Invalid SortMergeJoinExec, partition count mismatch 
{left_partitions}!={right_partitions},\ consider using RepartitionExec" ); - let (on_left, on_right) = self.on.iter().cloned().unzip(); + let (on_left, on_right): (Vec<_>, Vec<_>) = self.on.iter().cloned().unzip(); let (streamed, buffered, on_streamed, on_buffered) = if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { ( Arc::clone(&self.left), Arc::clone(&self.right), - on_left, - on_right, + on_left.clone(), + on_right.clone(), ) } else { ( Arc::clone(&self.right), Arc::clone(&self.left), - on_right, - on_left, + on_right.clone(), + on_left.clone(), ) }; + let metrics = SortMergeJoinMetrics::new(partition, &self.metrics); + + // Initialize dynamic filters if they exist + let left_dynamic_filter = self.left_dynamic_filter.as_ref().map(|f| { + let accumulator = f.accumulator.get_or_init(|| { + Arc::new(SharedSortMergeBoundsAccumulator::new( + left_partitions, + Arc::clone(&on_left[0]), + Arc::clone(&f.filter), + Some(metrics.dynamic_filter_updates()), + )) + }); + Arc::clone(accumulator) + }); + + let right_dynamic_filter = self.right_dynamic_filter.as_ref().map(|f| { + let accumulator = f.accumulator.get_or_init(|| { + Arc::new(SharedSortMergeBoundsAccumulator::new( + right_partitions, + Arc::clone(&on_right[0]), + Arc::clone(&f.filter), + Some(metrics.dynamic_filter_updates()), + )) + }); + Arc::clone(accumulator) + }); + + let (streamed_dynamic_filter, buffered_dynamic_filter) = + if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { + (left_dynamic_filter, right_dynamic_filter) + } else { + (right_dynamic_filter, left_dynamic_filter) + }; + // execute children plans let streamed = streamed.execute(partition, Arc::clone(&context))?; let buffered = buffered.execute(partition, Arc::clone(&context))?; @@ -539,6 +705,8 @@ impl ExecutionPlan for SortMergeJoinExec { reservation, spill_manager, context.runtime_env(), + streamed_dynamic_filter, + buffered_dynamic_filter, )?)) } else { 
Ok(Box::pin(MaterializingSortMergeJoinStream::try_new( @@ -552,10 +720,13 @@ impl ExecutionPlan for SortMergeJoinExec { self.filter.clone(), self.join_type, batch_size, - SortMergeJoinMetrics::new(partition, &self.metrics), + metrics, reservation, spill_manager, context.runtime_env(), + streamed_dynamic_filter, + buffered_dynamic_filter, + partition, )?)) } } @@ -632,7 +803,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.children()[1], )?; - Ok(Some(Arc::new(SortMergeJoinExec::try_new( + let mut node = SortMergeJoinExec::try_new( Arc::new(new_left), Arc::new(new_right), new_on, @@ -640,6 +811,162 @@ impl ExecutionPlan for SortMergeJoinExec { self.join_type, self.sort_options.clone(), self.null_equality, - )?))) + )?; + node.left_dynamic_filter + .clone_from(&self.left_dynamic_filter); + node.right_dynamic_filter + .clone_from(&self.right_dynamic_filter); + Ok(Some(Arc::new(node))) + } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &ConfigOptions, + ) -> Result { + // This is the physical-plan equivalent of `push_down_all_join` in + // `datafusion/optimizer/src/push_down_filter.rs`. + // + // We determine which parent filters can be pushed down to which child based on two criteria: + // 1. **Column Preservation**: A side is "preserved" if its columns are present in the join output. + // For example, in a `Left` join, the left side is preserved but the right side is not (right columns are null-padded). + // We can only push filters to preserved sides because non-preserved sides may need all rows + // to correctly produce null-padded matches. + // 2. **Column References**: `ChildFilterDescription::from_child_with_allowed_indices` is used so that a filter is only + // routed to a child if every column it references is one of the output columns that originate from that child.
+ let (left_preserved, right_preserved) = match self.join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + (false, true) + } + }; + + // For Inner joins, the output schema is [left_cols..., right_cols...]. For any join type, + // `build_join_schema` reports which side each output column came from, so build + // per-side allowed index sets that let `from_child_with_allowed_indices` route each + // parent filter to the correct child based on column references. + let column_indices = + build_join_schema(&self.left.schema(), &self.right.schema(), &self.join_type) + .1; + let (mut left_allowed, mut right_allowed) = (HashSet::new(), HashSet::new()); + column_indices + .iter() + .enumerate() + .for_each(|(output_idx, ci)| { + match ci.side { + JoinSide::Left => left_allowed.insert(output_idx), + JoinSide::Right => right_allowed.insert(output_idx), + JoinSide::None => false, + }; + }); + + let mut left_child = if left_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + left_allowed, + &self.left, + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; + + let mut right_child = if right_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + right_allowed, + &self.right, + )?
+ } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; + + // Add this join's own dynamic filters in the Post phase when `allow_join_dynamic_filter_pushdown` permits it + if matches!(phase, FilterPushdownPhase::Post) + && self.allow_join_dynamic_filter_pushdown(config) + { + if left_preserved { + let dynamic_filter = + Self::create_dynamic_filter(&self.on, JoinSide::Left); + left_child = left_child.with_self_filter(dynamic_filter); + } + if right_preserved { + let dynamic_filter = + Self::create_dynamic_filter(&self.on, JoinSide::Right); + right_child = right_child.with_self_filter(dynamic_filter); + } + } + + Ok(FilterDescription::new() + .with_child(left_child) + .with_child(right_child)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // This method performs the "Upward Handshake" of the physical optimizer. + // It receives the result of pushing filters down to our children and decides + // whether to update this node to reflect those changes. + let mut result: FilterPushdownPropagation> = + FilterPushdownPropagation::if_any(child_pushdown_result.clone()); + assert_eq!(child_pushdown_result.self_filters.len(), 2); + + let left_child_self_filters = &child_pushdown_result.self_filters[0]; + let right_child_self_filters = &child_pushdown_result.self_filters[1]; + + let mut node = (*self).clone(); + let mut node_updated_with_filters = false; + + // 1. Pick up the dynamic filters we generated in `gather_filters_for_pushdown` (the first self-filter of each child). + // The child's pushdown verdict is not inspected here; if the predicate downcasts to our dynamic filter type, we store it in this node so we can update it during execution.
+ if let Some(filter) = left_child_self_filters.first() { + let predicate = Arc::clone(&filter.predicate); + if let Ok(dynamic_filter) = + Arc::downcast::(predicate) + { + node = node.with_left_dynamic_filter(dynamic_filter); + node_updated_with_filters = true; + } + } + + if let Some(filter) = right_child_self_filters.first() { + let predicate = Arc::clone(&filter.predicate); + if let Ok(dynamic_filter) = + Arc::downcast::(predicate) + { + node = node.with_right_dynamic_filter(dynamic_filter); + node_updated_with_filters = true; + } + } + + // 2. Determine the final updated node to return to the parent. + if let Some(updated_child_plan) = result.updated_node.take() { + // Case A: The optimizer rule already provided an updated version of this node + // (e.g., because children were replaced). We must ensure our dynamic filters + // are applied to that specific version to maintain the chain. + let mut final_node = updated_child_plan + .downcast_ref::() + .expect("updated_node must be SortMergeJoinExec") + .clone(); + + if node_updated_with_filters { + final_node.left_dynamic_filter = node.left_dynamic_filter; + final_node.right_dynamic_filter = node.right_dynamic_filter; + } + result.updated_node = Some(Arc::new(final_node) as _); + } else if node_updated_with_filters { + // Case B: Children didn't change, but we added dynamic filters to ourselves. 
+ result.updated_node = Some(Arc::new(node) as _); + } + + Ok(result) } } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index f1a18aac762f5..d8fc2c513dddb 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -37,6 +37,7 @@ use crate::joins::sort_merge_join::filter::{ get_filter_columns, needs_deferred_filtering, }; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; +use crate::joins::sort_merge_join::shared_bounds::SharedSortMergeBoundsAccumulator; use crate::joins::utils::{JoinFilter, JoinKeyComparator}; use crate::metrics::RecordOutput; use crate::spill::spill_manager::SpillManager; @@ -50,7 +51,9 @@ use arrow::compute::{ }; use arrow::datatypes::SchemaRef; use datafusion_common::cast::as_uint64_array; -use datafusion_common::{JoinType, NullEquality, Result, exec_err, internal_err}; +use datafusion_common::{ + JoinType, NullEquality, Result, ScalarValue, exec_err, internal_err, +}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::runtime_env::RuntimeEnv; @@ -155,6 +158,26 @@ impl StreamedBatch { } } + /// Returns the first join key value in this batch + fn first_join_key(&self) -> Option { + if self.batch.num_rows() == 0 { + return None; + } + self.join_arrays + .first() + .and_then(|arr| ScalarValue::try_from_array(arr, 0).ok()) + } + + /// Returns the last join key value in this batch + fn last_join_key(&self) -> Option { + if self.batch.num_rows() == 0 { + return None; + } + self.join_arrays.first().and_then(|arr| { + ScalarValue::try_from_array(arr, self.batch.num_rows() - 1).ok() + }) + } + /// Number of unfrozen output pairs in this streamed batch fn num_output_rows(&self) -> usize { self.num_output_rows @@ -293,6 
+316,26 @@ impl BufferedBatch { num_rows, } } + + /// Returns the first join key value in this batch + fn first_join_key(&self) -> Option { + if self.num_rows == 0 { + return None; + } + self.join_arrays + .first() + .and_then(|arr| ScalarValue::try_from_array(arr, 0).ok()) + } + + /// Returns the last join key value in this batch + fn last_join_key(&self) -> Option { + if self.num_rows == 0 { + return None; + } + self.join_arrays + .first() + .and_then(|arr| ScalarValue::try_from_array(arr, self.num_rows - 1).ok()) + } } // TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429) @@ -385,6 +428,19 @@ pub(super) struct MaterializingSortMergeJoinStream { /// Tracks the number of batches currently spilled pub spilled_batch_count: usize, + // ======================================================================== + // DYNAMIC FILTER FIELDS: + // These fields manage dynamic filter pushdown. + // ======================================================================== + /// Dynamic filter for the streamed side, advanced from the buffered side's + /// head values. + pub streamed_dynamic_filter: Option>, + /// Dynamic filter for the buffered side, advanced from the streamed side's + /// head values. + pub buffered_dynamic_filter: Option>, + /// Partition ID of this stream + pub partition_id: usize, + // ======================================================================== // CACHED COMPARATORS: // Pre-built comparators to avoid per-row type dispatch in hot loops. 
@@ -827,6 +883,9 @@ impl MaterializingSortMergeJoinStream { reservation: MemoryReservation, spill_manager: SpillManager, runtime_env: Arc, + streamed_dynamic_filter: Option>, + buffered_dynamic_filter: Option>, + partition_id: usize, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -872,6 +931,9 @@ impl MaterializingSortMergeJoinStream { spill_manager, spill_stream: None, spilled_batch_count: 0, + streamed_dynamic_filter, + buffered_dynamic_filter, + partition_id, streamed_buffered_cmp: None, buffered_equality_cmp: None, streamed_batch_counter: AtomicUsize::new(0), @@ -1035,6 +1097,12 @@ impl MaterializingSortMergeJoinStream { self.streamed_state = StreamedState::Ready; return Poll::Ready(Some(Ok(()))); } else { + // Report last join key before pulling next batch + if let Some(accumulator) = &self.buffered_dynamic_filter + && let Some(key) = self.streamed_batch.last_join_key() + { + accumulator.report_head(self.partition_id, key)?; + } self.streamed_state = StreamedState::Polling; } } @@ -1055,6 +1123,9 @@ impl MaterializingSortMergeJoinStream { self.streamed = Box::pin(EmptyRecordBatchStream::new(streamed_schema)); self.streamed_state = StreamedState::Exhausted; + if let Some(accumulator) = &self.buffered_dynamic_filter { + accumulator.mark_exhausted(self.partition_id)?; + } } Poll::Ready(Some(batch)) => { if batch.num_rows() > 0 { @@ -1063,6 +1134,16 @@ impl MaterializingSortMergeJoinStream { self.join_metrics.input_rows().add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + + // Report first join key to the buffered side's + // dynamic filter + if let Some(accumulator) = &self.buffered_dynamic_filter + && let Some(key) = + self.streamed_batch.first_join_key() + { + accumulator.report_head(self.partition_id, key)?; + } + self.rebuild_streamed_buffered_cmp()?; // Every incoming streaming batch should have its unique id // Check 
`JoinedRecordBatches.self.streamed_batch_counter` documentation @@ -1202,6 +1283,9 @@ impl MaterializingSortMergeJoinStream { self.buffered = Box::pin(EmptyRecordBatchStream::new(buffered_schema)); self.buffered_state = BufferedState::Exhausted; + if let Some(accumulator) = &self.streamed_dynamic_filter { + accumulator.mark_exhausted(self.partition_id)?; + } return Poll::Ready(None); } Poll::Ready(Some(batch)) => { @@ -1212,6 +1296,14 @@ impl MaterializingSortMergeJoinStream { let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); + // Report first join key to the streamed side's + // dynamic filter + if let Some(accumulator) = &self.streamed_dynamic_filter + && let Some(key) = buffered_batch.first_join_key() + { + accumulator.report_head(self.partition_id, key)?; + } + self.allocate_reservation(buffered_batch)?; self.streamed_buffered_cmp = None; self.buffered_state = BufferedState::PollingRest; @@ -1239,6 +1331,14 @@ impl MaterializingSortMergeJoinStream { } } } else { + // Report last join key before pulling next batch + if let Some(accumulator) = &self.streamed_dynamic_filter + && let Some(key) = + self.buffered_data.tail_batch().last_join_key() + { + accumulator.report_head(self.partition_id, key)?; + } + match self.buffered.poll_next_unpin(cx)? 
{ Poll::Pending => { return Poll::Pending; @@ -1250,6 +1350,9 @@ impl MaterializingSortMergeJoinStream { buffered_schema, )); self.buffered_state = BufferedState::Ready; + if let Some(accumulator) = &self.streamed_dynamic_filter { + accumulator.mark_exhausted(self.partition_id)?; + } } Poll::Ready(Some(batch)) => { // Polling batches coming concurrently as multiple partitions @@ -1261,6 +1364,17 @@ impl MaterializingSortMergeJoinStream { 0..0, &self.on_buffered, ); + + // Report first join key to the streamed side's + // dynamic filter + if let Some(accumulator) = + &self.streamed_dynamic_filter + && let Some(key) = buffered_batch.first_join_key() + { + accumulator + .report_head(self.partition_id, key)?; + } + self.allocate_reservation(buffered_batch)?; self.buffered_equality_cmp = None; } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs index 62efb77f877ab..99d7c40b59f1f 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs @@ -35,6 +35,8 @@ pub(super) struct SortMergeJoinMetrics { /// Peak memory used for buffered data. 
/// Calculated as sum of peak memory values across partitions peak_mem_used: Gauge, + /// Number of times the dynamic filter was tightened + dynamic_filter_updates: Count, } impl SortMergeJoinMetrics { @@ -49,6 +51,8 @@ impl SortMergeJoinMetrics { let peak_mem_used = MetricBuilder::new(metrics) .with_category(MetricCategory::Bytes) .gauge("peak_mem_used", partition); + let dynamic_filter_updates = + MetricBuilder::new(metrics).counter("dynamic_filter_updates", partition); let baseline_metrics = BaselineMetrics::new(metrics, partition); @@ -58,6 +62,7 @@ impl SortMergeJoinMetrics { input_rows, baseline_metrics, peak_mem_used, + dynamic_filter_updates, } } @@ -80,4 +85,8 @@ impl SortMergeJoinMetrics { pub fn peak_mem_used(&self) -> Gauge { self.peak_mem_used.clone() } + + pub fn dynamic_filter_updates(&self) -> Count { + self.dynamic_filter_updates.clone() + } } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs index 2fdb0924e723d..b5cb8691dc3d7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs @@ -24,6 +24,7 @@ mod exec; mod filter; pub(crate) mod materializing_stream; mod metrics; +mod shared_bounds; #[cfg(test)] mod tests; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/sort_merge_join/shared_bounds.rs new file mode 100644 index 0000000000000..dd6033854f355 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/shared_bounds.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared bounds for Sort-Merge Join dynamic filter pushdown. +//! +//! This mirrors the correctness model used by `HashJoinExec`'s +//! [`crate::joins::hash_join::shared_bounds`]: a dynamic filter must be built +//! from *complete* information about the side it summarizes, and published +//! **exactly once**. Publishing a partial or mid-stream bound — or advancing a +//! bound as partitions make progress — can incorrectly eliminate valid join +//! results, because the single dynamic filter is shared across all of the +//! (concurrently executing) hash partitions feeding the join. +//! +//! Accordingly, each partition streams its join-key values to the accumulator, +//! which tracks the global `[min, max]` range across all partitions. Only once +//! every partition has been fully consumed is the filter published as a static +//! range predicate `col >= min AND col <= max` (a superset that never prunes a +//! matchable row), after which it is marked complete. + +use crate::metrics::Count; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr::expressions::{BinaryExpr, DynamicFilterPhysicalExpr, lit}; +use parking_lot::Mutex; +use std::sync::Arc; + +/// Coordinates dynamic filter updates for Sort-Merge Join across multiple partitions. 
+/// +/// The filter summarizes the join keys observed on one side of the join and is +/// pushed down to prune the *other* side. It is published once, after all +/// partitions have reported their data, as a static range predicate. See the +/// module documentation for why mid-stream/advancing bounds are unsafe here. +#[derive(Debug)] +pub(crate) struct SharedSortMergeBoundsAccumulator { + /// Shared state for all partitions. + state: Mutex, + /// Total number of partitions that will report to this accumulator. + num_partitions: usize, + /// Dynamic filter to update (once, when all partitions are exhausted). + dynamic_filter: Arc, + /// Join key expression on the side being filtered. + on_expr: PhysicalExprRef, + /// Metric to track number of filter updates. + metrics: Option, +} + +#[derive(Debug)] +struct AccumulatorState { + /// Smallest (non-null) join key value observed across all partitions. + min: Option, + /// Largest (non-null) join key value observed across all partitions. + max: Option, + /// Number of partitions that have been fully consumed. + exhausted_count: usize, + /// Whether the filter has already been published. + published: bool, +} + +impl SharedSortMergeBoundsAccumulator { + pub fn new( + num_partitions: usize, + on_expr: PhysicalExprRef, + dynamic_filter: Arc, + metrics: Option, + ) -> Self { + Self { + state: Mutex::new(AccumulatorState { + min: None, + max: None, + exhausted_count: 0, + published: false, + }), + num_partitions, + dynamic_filter, + on_expr, + metrics, + } + } + + /// Report a join key value observed by a partition. + /// + /// Null keys are ignored: they never participate in equi-join matches, so + /// they neither widen the bounds nor need to be preserved by the filter. 
+ pub fn report_head(&self, _partition_id: usize, head: ScalarValue) -> Result<()> { + if head.is_null() { + return Ok(()); + } + let mut state = self.state.lock(); + match &state.min { + Some(min) if min <= &head => {} + _ => state.min = Some(head.clone()), + } + match &state.max { + Some(max) if max >= &head => {} + _ => state.max = Some(head), + } + Ok(()) + } + + /// Mark a partition as fully consumed. + /// + /// Once every partition has been consumed, the accumulated `[min, max]` + /// range is published to the dynamic filter exactly once and the filter is + /// marked complete. + pub fn mark_exhausted(&self, _partition_id: usize) -> Result<()> { + let mut state = self.state.lock(); + state.exhausted_count += 1; + + if state.exhausted_count < self.num_partitions || state.published { + return Ok(()); + } + state.published = true; + + // Build the final, static predicate from complete information. + let filter_expr = match (state.min.take(), state.max.take()) { + (Some(min), Some(max)) => { + // col >= min AND col <= max + let lower = Arc::new(BinaryExpr::new( + Arc::clone(&self.on_expr), + Operator::GtEq, + lit(min), + )); + let upper = Arc::new(BinaryExpr::new( + Arc::clone(&self.on_expr), + Operator::LtEq, + lit(max), + )); + Arc::new(BinaryExpr::new(lower, Operator::And, upper)) as _ + } + // No non-null keys were observed on this side: for the join types + // that enable this filter (Inner / LeftSemi / RightSemi) nothing on + // the other side can match, so prune everything. 
+ _ => lit(false), + }; + + self.dynamic_filter.update(filter_expr)?; + self.dynamic_filter.mark_complete(); + + if let Some(m) = &self.metrics { + m.add(1); + } + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 338c5111d223d..ee7123ce55370 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -4231,6 +4231,8 @@ async fn filter_buffer_pending_loses_inner_rows() -> Result<()> { reservation, spill_manager, runtime_env, + None, + None, )?; let batches = collect_stream(stream).await?; @@ -4330,6 +4332,8 @@ async fn no_filter_boundary_pending_loses_outer_rows() -> Result<()> { reservation, spill_manager, runtime_env, + None, + None, )?; let batches = collect_stream(stream).await?; @@ -4443,6 +4447,8 @@ async fn filtered_boundary_pending_outer_rows() -> Result<()> { reservation, spill_manager, runtime_env, + None, + None, )?; let batches = collect_stream(stream).await?; @@ -4717,6 +4723,8 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { reservation, spill_manager, Arc::clone(&runtime), + None, + None, )?; let batches = collect_stream(stream).await?; diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 8588c0e7ba2ae..063982ada8dac 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -1176,12 +1176,18 @@ physical_plan 09)│ c1@0 ASC ││ c1@0 ASC │ 10)└─────────────┬─────────────┘└─────────────┬─────────────┘ 11)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -12)│ DataSourceExec ││ DataSourceExec │ +12)│ FilterExec ││ FilterExec │ 13)│ -------------------- ││ -------------------- │ -14)│ bytes: 5932 ││ bytes: 5932 │ -15)│ format: memory ││ format: memory │ -16)│ rows: 1 ││ rows: 1 │ 
-17)└───────────────────────────┘└───────────────────────────┘ +14)│ predicate: ││ predicate: │ +15)│ DynamicFilter [ empty ] ││ DynamicFilter [ empty ] │ +16)└─────────────┬─────────────┘└─────────────┬─────────────┘ +17)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +18)│ DataSourceExec ││ DataSourceExec │ +19)│ -------------------- ││ -------------------- │ +20)│ bytes: 5932 ││ bytes: 5932 │ +21)│ format: memory ││ format: memory │ +22)│ rows: 1 ││ rows: 1 │ +23)└───────────────────────────┘└───────────────────────────┘ statement ok set datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 082b10167274c..2d88320179418 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2814,11 +2814,13 @@ logical_plan physical_plan 01)SortMergeJoinExec: join_type=Inner, on=[(c1@0, c1@0)] 02)--SortExec: expr=[c1@0 ASC], preserve_partitioning=[true] -03)----RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=1 -04)------DataSourceExec: partitions=1, partition_sizes=[1] -05)--SortExec: expr=[c1@0 ASC], preserve_partitioning=[true] -06)----RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=1 -07)------DataSourceExec: partitions=1, partition_sizes=[1] +03)----FilterExec: DynamicFilter [ empty ] +04)------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=1 +05)--------DataSourceExec: partitions=1, partition_sizes=[1] +06)--SortExec: expr=[c1@0 ASC], preserve_partitioning=[true] +07)----FilterExec: DynamicFilter [ empty ] +08)------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=1 +09)--------DataSourceExec: partitions=1, partition_sizes=[1] # sort_merge_join_on_date32 inner sort merge join on data type (Date32) query DDRTDDRT rowsort diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt 
index 6244ab70c1eb5..b6fbdef009b80 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -39,9 +39,11 @@ logical_plan physical_plan 01)SortMergeJoinExec: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64) 02)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false] -03)----DataSourceExec: partitions=1, partition_sizes=[1] -04)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false] -05)----DataSourceExec: partitions=1, partition_sizes=[1] +03)----FilterExec: DynamicFilter [ empty ] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false] +06)----FilterExec: DynamicFilter [ empty ] +07)------DataSourceExec: partitions=1, partition_sizes=[1] # inner join with join filter query TITI rowsort