From 921f64cd7b8401abab00fafe75601c2de124347e Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Wed, 29 Apr 2026 14:49:23 +0000 Subject: [PATCH 1/3] proto: serialize dynamic filters on Sort, Aggregate, HashJoin Builds on the prior `DynamicFilterPhysicalExpr` proto serialization + dedupe work so plan-node references to a shared dynamic filter survive roundtrip. - Adds `dynamic_filter` to the proto messages for `SortExec`, `AggregateExec`, and `HashJoinExec` and wires it through to/from-proto. - Exposes `dynamic_filter()` / `with_dynamic_filter()` on those plan nodes so the dedupe deserializer can reattach the shared `DynamicFilterPhysicalExpr` after roundtrip. - Extracts `supported_accumulators_info()` on `AggregateExec` and uses it from `init_dynamic_filter` and `with_dynamic_filter`. - Adds `test_hash_join_with_dynamic_filter_roundtrip`, `test_aggregate_with_dynamic_filter_roundtrip`, and `test_sort_topk_with_dynamic_filter_roundtrip` to verify that the plan node and the pushdown-target `ParquetSource` predicate end up pointing at the same `expression_id` after roundtrip. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../physical_optimizer/filter_pushdown.rs | 2 +- .../physical-expr-common/src/physical_expr.rs | 4 +- .../src/expressions/dynamic_filters.rs | 20 +- .../physical-plan/src/aggregates/mod.rs | 205 +++++++++++++- .../physical-plan/src/joins/hash_join/exec.rs | 93 +++++- datafusion/physical-plan/src/sorts/sort.rs | 89 ++++++ datafusion/proto/proto/datafusion.proto | 6 + datafusion/proto/src/generated/pbjson.rs | 54 ++++ datafusion/proto/src/generated/prost.rs | 9 + datafusion/proto/src/physical_plan/mod.rs | 82 +++++- .../tests/cases/roundtrip_physical_plan.rs | 266 +++++++++++++++--- 11 files changed, 749 insertions(+), 81 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs index 4ff1fad8f52b9..5f64c9e4a5400 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -2835,7 +2835,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { // Verify that a dynamic filter was created let dynamic_filter = hash_join - .dynamic_filter_for_test() + .dynamic_filter() .expect("Dynamic filter should be created"); // Verify that is_used() returns the expected value based on probe side support. diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 6595635024ed0..8e3e23c10ed0b 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -72,7 +72,9 @@ pub type PhysicalExprRef = Arc; /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { - /// Get the data type of this expression, given the schema of the input + /// Get the data type of this expression, given the schema of the input. + /// Returns an error if the data type cannot be determined, ex. if the + /// schema is missing a required field. fn data_type(&self, input_schema: &Schema) -> Result { Ok(self.return_field(input_schema)?.data_type().to_owned()) } diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 2db328377a5e1..5b9de882160aa 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -86,7 +86,7 @@ pub struct DynamicFilterPhysicalExpr { /// **Warning:** exposed publicly solely so that proto (de)serialization in /// `datafusion-proto` can read and rebuild this state. Do not treat this type /// or its layout as a stable API. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Inner { /// A unique identifier for the expression. pub expression_id: u64, @@ -100,24 +100,6 @@ pub struct Inner { pub is_complete: bool, } -// TODO: Include expression_id in Debug output. -// -// See https://github.com/apache/datafusion/issues/20418. Currently, plan nodes -// like `HashJoinExec`, `AggregateExec`, `SortExec` do not serialize their -// dynamic filter. They auto-create one on decode with a fresh `expression_id`, -// so a round-trip Debug comparison would diverge purely on the id even when -// the rest of the state is preserved. Excluding it from Debug keeps those -// roundtrip equality assertions meaningful until that work lands. -impl std::fmt::Debug for Inner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Inner") - .field("generation", &self.generation) - .field("expr", &self.expr) - .field("is_complete", &self.is_complete) - .finish() - } -} - impl Inner { fn new(expr: Arc) -> Self { Self { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 76ecb3f1485a4..f311b37e0a0b4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -46,7 +46,8 @@ use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{ - Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err, + Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, + internal_err, not_impl_err, }; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; @@ -1047,6 +1048,47 @@ impl AggregateExec { &self.input_order_mode } + /// Returns the dynamic filter expression for this aggregate, if set. + pub fn dynamic_filter(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + /// Replace the dynamic filter expression. This method errors if the aggregate does not + /// support dynamic filtering or if the filter expression is incompatible with this + /// [`AggregateExec`]. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + // If there is no dynamic filter state initialized via `try_new`, then + // we can safely assume that the aggregate does not support dynamic filtering. + let Some(dyn_filter) = self.dynamic_filter.as_ref() else { + return internal_err!("Aggregate does not support dynamic filtering"); + }; + + // Validate that the filter is compatible with the aggregation columns. + let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info); + if cols.len() != filter.children().len() { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns" + ); + } + for (col, child) in cols.iter().zip(filter.children()) { + if !col.eq(child) { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}" + ); + } + } + + // Overwrite our filter + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter, + supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(), + })); + Ok(self) + } + /// Estimates output statistics for this aggregate node. /// /// For grouped aggregations with known input row count > 1, the output row @@ -1284,6 +1326,28 @@ impl AggregateExec { } } + // Collect column references for the dynamic filter expression from the supported accumulators. + fn cols_for_dynamic_filter( + &self, + supported_accumulators_info: &[PerAccumulatorDynFilter], + ) -> Vec> { + let all_cols: Vec> = supported_accumulators_info + .iter() + .filter_map(|info| { + // This should always be true due to how the supported accumulators + // are constructed. See `init_dynamic_filter` for more details. + if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice() + && arg.is::() + { + return Some(Arc::clone(arg)); + } + None + }) + .collect(); + debug_assert!(all_cols.len() == supported_accumulators_info.len()); + all_cols + } + /// Calculate scaled byte size based on row count ratio. /// Returns `Precision::Absent` if input statistics are insufficient. /// Returns `Precision::Inexact` with the scaled value otherwise. @@ -2177,6 +2241,7 @@ mod tests { use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; use crate::common::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; use crate::metrics::MetricValue; @@ -2202,6 +2267,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; use datafusion_functions_aggregate::median::median_udaf; + use datafusion_functions_aggregate::min_max::min_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::PhysicalSortExpr; @@ -3682,13 +3748,10 @@ mod tests { // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). let aggregates: Vec> = vec![ Arc::new( - AggregateExprBuilder::new( - datafusion_functions_aggregate::min_max::min_udaf(), - vec![col("b", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("MIN(b)") - .build()?, + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, ), Arc::new( AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) @@ -3827,13 +3890,10 @@ mod tests { // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). let aggregates: Vec> = vec![ Arc::new( - AggregateExprBuilder::new( - datafusion_functions_aggregate::min_max::min_udaf(), - vec![col("b", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("MIN(b)") - .build()?, + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, ), Arc::new( AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) @@ -4781,4 +4841,119 @@ mod tests { Ok(()) } + + /// Test that [`AggregateExec::with_dynamic_filter`] overrides the existing dynamic filter + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Partial min aggregate supports dynamic filtering + let agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + + // Assertion 1: A filter with the same children can override the existing + // dynamic filter. + let new_df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &schema)?], + lit(false), + )); + let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?; + + // The aggregate's filter should now resolve to the new inner expression. + let swapped = agg + .dynamic_filter() + .expect("should still have dynamic filter") + .current()?; + assert_eq!(format!("{swapped}"), format!("{}", lit(false))); + + // Assertion 2: A filter that has been through `PhysicalExpr::with_new_children` + // should still be accepted when the new children are equivalent to the originals. + let new_df_as_pexpr: Arc = + Arc::::clone(&new_df); + let remapped_pexpr = + new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?; + let Ok(remapped_df) = (remapped_pexpr + as Arc) + .downcast::() + else { + panic!("should be DynamicFilterPhysicalExpr after with_new_children"); + }; + // Hard to assert this because the filter is identical. No error means + // the filter was accepted. That's a good enough assertion for now. + let _agg = agg.with_dynamic_filter(remapped_df)?; + Ok(()) + } + + /// Test that [`AggregateExec::with_dynamic_filter`] errors when the aggregate does not support dynamic filtering + #[test] + fn test_with_dynamic_filter_error_unsupported() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + ])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Final mode with a group-by does not support dynamic filters. + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]), + vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("sum_b") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + assert!(agg.dynamic_filter().is_none()); + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &schema)?], + lit(true), + )); + assert!(agg.with_dynamic_filter(df).is_err()); + Ok(()) + } + + /// Test that [`AggregateExec::with_dynamic_filter`] errors when the column is not in the schema + #[test] + fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(agg.with_dynamic_filter(df).is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 4ebbf7cb31ccf..1f4ee9c48508a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -903,14 +903,34 @@ impl HashJoinExec { } /// Get the dynamic filter expression for testing purposes. - /// Returns `None` if no dynamic filter has been set. - /// - /// This method is intended for testing only and should not be used in production code. - #[doc(hidden)] - pub fn dynamic_filter_for_test(&self) -> Option<&Arc> { + /// Returns the dynamic filter expression for this hash join, if set. + pub fn dynamic_filter(&self) -> Option<&Arc> { self.dynamic_filter.as_ref().map(|df| &df.filter) } + /// Set the dynamic filter on this hash join. + /// + /// Resets any internal state that depends on any existing dynamic filter. + /// + /// Validates that the filter's children reference valid columns in + /// the probe (right) side's schema. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + let probe_schema = self.right.schema(); + for child in filter.children() { + child.data_type(&probe_schema)?; + } + self.dynamic_filter = Some(HashJoinExecDynamicFilter { + filter, + // Initialize with an empty accumulator which will be lazily populated + // during execution. + build_accumulator: OnceLock::new(), + }); + Ok(self) + } + /// Calculate order preservation flags for this hash join. fn maintains_input_order(join_type: JoinType) -> Vec { vec![ @@ -6307,4 +6327,67 @@ mod tests { assert_eq!(lr_is_preserved(JoinType::RightAnti), (true, true)); assert_eq!(lr_is_preserved(JoinType::RightMark), (false, true)); } + + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let (_, _, on) = build_schema_and_on()?; + let left = build_table(("a1", &vec![1]), ("b1", &vec![1]), ("c1", &vec![1])); + let right = build_table(("a2", &vec![1]), ("b1", &vec![1]), ("c2", &vec![1])); + + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + assert!(join.dynamic_filter().is_none()); + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("b1", 1)) as _], + lit(true), + )); + let join = join.with_dynamic_filter(Arc::clone(&df))?; + + let restored = join.dynamic_filter().expect("should have dynamic filter"); + assert_eq!( + restored + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"), + df.expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"), + ); + Ok(()) + } + + #[test] + fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> { + let (_, _, on) = build_schema_and_on()?; + let left = build_table(("a1", &vec![1]), ("b1", &vec![1]), ("c1", &vec![1])); + let right = build_table(("a2", &vec![1]), ("b1", &vec![1]), ("c2", &vec![1])); + + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + + // Column index 99 is out of bounds for the right (probe) side schema. + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(join.with_dynamic_filter(df).is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 6c02af8dec6d3..f4c764ac73a3f 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -978,6 +978,30 @@ impl SortExec { self.fetch } + /// Returns the dynamic filter expression for this sort (TopK), if set. + pub fn dynamic_filter(&self) -> Option> { + self.filter.as_ref().map(|f| f.read().expr()) + } + + /// Replace the dynamic filter expression for this sort. + /// + /// + /// Resets any internal state which may depend on the previous dynamic filter. + /// + /// Validates that the filter's children reference valid columns in + /// the sort's input schema. + pub fn with_dynamic_filter( + mut self, + filter: Arc, + ) -> Result { + let input_schema = self.input.schema(); + for child in filter.children() { + child.data_type(&input_schema)?; + } + self.filter = Some(Arc::new(RwLock::new(TopKDynamicFilters::new(filter)))); + Ok(self) + } + fn output_partitioning_helper( input: &Arc, preserve_partitioning: bool, @@ -2723,6 +2747,71 @@ mod tests { Ok(()) } + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // SortExec with fetch creates a dynamic filter automatically. + let original_id = sort + .dynamic_filter() + .expect("should have dynamic filter with fetch") + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + + // with_dynamic_filter replaces it with a new TopKDynamicFilters. + let new_df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as _], + lit(true), + )); + let new_id = new_df + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + let sort = sort.with_dynamic_filter(Arc::clone(&new_df))?; + let restored_id = sort + .dynamic_filter() + .expect("should still have dynamic filter") + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + assert_eq!(restored_id, new_id); + assert_ne!(restored_id, original_id); + Ok(()) + } + + #[test] + fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // Column index 99 is out of bounds for the input schema. + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(sort.with_dynamic_filter(df).is_err()); + Ok(()) + } + /// Verifies that `ExternalSorter::sort()` transfers the pre-reserved /// merge bytes to the merge stream via `take()`, rather than leaving /// them in the sorter (via `new_empty()`). diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 511e8eb1b012e..dcc2be9f563eb 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1187,6 +1187,8 @@ message HashJoinExecNode { JoinFilter filter = 8; repeated uint32 projection = 9; bool null_aware = 10; + // Optional dynamic filter expression for pushing down to the probe side. + PhysicalExprNode dynamic_filter = 11; } enum StreamPartitionMode { @@ -1318,6 +1320,8 @@ message AggregateExecNode { repeated MaybeFilter filter_expr = 10; AggLimit limit = 11; bool has_grouping_set = 12; + // Optional dynamic filter expression for pushing down to the child. + PhysicalExprNode dynamic_filter = 13; } message GlobalLimitExecNode { @@ -1339,6 +1343,8 @@ message SortExecNode { // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; bool preserve_partitioning = 4; + // Optional dynamic filter expression for TopK pushdown. + PhysicalExprNode dynamic_filter = 5; } message SortPreservingMergeExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c05d3283eac8e..0650ce740526d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -154,6 +154,9 @@ impl serde::Serialize for AggregateExecNode { if self.has_grouping_set { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -193,6 +196,9 @@ impl serde::Serialize for AggregateExecNode { if self.has_grouping_set { struct_ser.serialize_field("hasGroupingSet", &self.has_grouping_set)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } struct_ser.end() } } @@ -223,6 +229,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "limit", "has_grouping_set", "hasGroupingSet", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] @@ -239,6 +247,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { FilterExpr, Limit, HasGroupingSet, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -272,6 +281,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), "limit" => Ok(GeneratedField::Limit), "hasGroupingSet" | "has_grouping_set" => Ok(GeneratedField::HasGroupingSet), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -303,6 +313,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut filter_expr__ = None; let mut limit__ = None; let mut has_grouping_set__ = None; + let mut dynamic_filter__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -377,6 +388,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } has_grouping_set__ = Some(map_.next_value()?); } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } } } Ok(AggregateExecNode { @@ -392,6 +409,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { filter_expr: filter_expr__.unwrap_or_default(), limit: limit__, has_grouping_set: has_grouping_set__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, }) } } @@ -8817,6 +8835,9 @@ impl serde::Serialize for HashJoinExecNode { if self.null_aware { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; if let Some(v) = self.left.as_ref() { struct_ser.serialize_field("left", v)?; @@ -8851,6 +8872,9 @@ impl serde::Serialize for HashJoinExecNode { if self.null_aware { struct_ser.serialize_field("nullAware", &self.null_aware)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } struct_ser.end() } } @@ -8874,6 +8898,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "projection", "null_aware", "nullAware", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] @@ -8887,6 +8913,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { Filter, Projection, NullAware, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8917,6 +8944,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), "nullAware" | "null_aware" => Ok(GeneratedField::NullAware), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8945,6 +8973,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut filter__ = None; let mut projection__ = None; let mut null_aware__ = None; + let mut dynamic_filter__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { @@ -9004,6 +9033,12 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } null_aware__ = Some(map_.next_value()?); } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } } } Ok(HashJoinExecNode { @@ -9016,6 +9051,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { filter: filter__, projection: projection__.unwrap_or_default(), null_aware: null_aware__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, }) } } @@ -22542,6 +22578,9 @@ impl serde::Serialize for SortExecNode { if self.preserve_partitioning { len += 1; } + if self.dynamic_filter.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.SortExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -22557,6 +22596,9 @@ impl serde::Serialize for SortExecNode { if self.preserve_partitioning { struct_ser.serialize_field("preservePartitioning", &self.preserve_partitioning)?; } + if let Some(v) = self.dynamic_filter.as_ref() { + struct_ser.serialize_field("dynamicFilter", v)?; + } struct_ser.end() } } @@ -22572,6 +22614,8 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { "fetch", "preserve_partitioning", "preservePartitioning", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] @@ -22580,6 +22624,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { Expr, Fetch, PreservePartitioning, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22605,6 +22650,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { "expr" => Ok(GeneratedField::Expr), "fetch" => Ok(GeneratedField::Fetch), "preservePartitioning" | "preserve_partitioning" => Ok(GeneratedField::PreservePartitioning), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22628,6 +22674,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { let mut expr__ = None; let mut fetch__ = None; let mut preserve_partitioning__ = None; + let mut dynamic_filter__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -22656,6 +22703,12 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { } preserve_partitioning__ = Some(map_.next_value()?); } + GeneratedField::DynamicFilter => { + if dynamic_filter__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + dynamic_filter__ = map_.next_value()?; + } } } Ok(SortExecNode { @@ -22663,6 +22716,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { expr: expr__.unwrap_or_default(), fetch: fetch__.unwrap_or_default(), preserve_partitioning: preserve_partitioning__.unwrap_or_default(), + dynamic_filter: dynamic_filter__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index af9b1404bb80a..0d978ffca0797 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1772,6 +1772,9 @@ pub struct HashJoinExecNode { pub projection: ::prost::alloc::vec::Vec, #[prost(bool, tag = "10")] pub null_aware: bool, + /// Optional dynamic filter expression for pushing down to the probe side. + #[prost(message, optional, tag = "11")] + pub dynamic_filter: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { @@ -1951,6 +1954,9 @@ pub struct AggregateExecNode { pub limit: ::core::option::Option, #[prost(bool, tag = "12")] pub has_grouping_set: bool, + /// Optional dynamic filter expression for pushing down to the child. + #[prost(message, optional, tag = "13")] + pub dynamic_filter: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct GlobalLimitExecNode { @@ -1981,6 +1987,9 @@ pub struct SortExecNode { pub fetch: i64, #[prost(bool, tag = "4")] pub preserve_partitioning: bool, + /// Optional dynamic filter expression for TopK pushdown. + #[prost(message, optional, tag = "5")] + pub dynamic_filter: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortPreservingMergeExecNode { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 5172a552fad4f..8994926d173ae 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -57,6 +57,7 @@ use datafusion_functions_table::generate_series::{ }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, @@ -1282,6 +1283,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::, _>>()?; + let physical_schema_ref = Arc::clone(&physical_schema); let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), @@ -1302,6 +1304,23 @@ impl protobuf::PhysicalPlanNode { agg }; + let agg = if let Some(dynamic_filter_proto) = &hash_agg.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + physical_schema_ref.as_ref(), + ctx, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + agg.with_dynamic_filter(df)? + } else { + agg + } + } else { + agg + }; + Ok(Arc::new(agg)) } @@ -1408,7 +1427,7 @@ impl protobuf::PhysicalPlanNode { } else { None }; - Ok(Arc::new(HashJoinExec::try_new( + let mut hash_join = HashJoinExec::try_new( left, right, on, @@ -1418,7 +1437,22 @@ impl protobuf::PhysicalPlanNode { partition_mode, null_equality.into(), hashjoin.null_aware, - )?)) + )?; + + if let Some(dynamic_filter_proto) = &hashjoin.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + right_schema.as_ref(), + ctx, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + hash_join = hash_join.with_dynamic_filter(df)?; + } + } + + Ok(Arc::new(hash_join)) } fn try_into_symmetric_hash_join_physical_plan( @@ -1656,6 +1690,23 @@ impl protobuf::PhysicalPlanNode { .with_fetch(fetch) .with_preserve_partitioning(sort.preserve_partitioning); + let new_sort = if let Some(dynamic_filter_proto) = &sort.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + new_sort.input().schema().as_ref(), + ctx, + )?; + if let Ok(df) = (dynamic_filter_expr as Arc) + .downcast::() + { + new_sort.with_dynamic_filter(df)? + } else { + new_sort + } + } else { + new_sort + }; + Ok(Arc::new(new_sort)) } @@ -2462,6 +2513,15 @@ impl protobuf::PhysicalPlanNode { PartitionMode::Auto => protobuf::PartitionMode::Auto, }; + let dynamic_filter = exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { @@ -2476,6 +2536,7 @@ impl protobuf::PhysicalPlanNode { v.iter().map(|x| *x as u32).collect::>() }), null_aware: exec.null_aware, + dynamic_filter, }, ))), }) @@ -2805,6 +2866,14 @@ impl protobuf::PhysicalPlanNode { groups, limit, has_grouping_set: exec.group_expr().has_grouping_set(), + dynamic_filter: exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?, }, ))), }) @@ -3098,6 +3167,14 @@ impl protobuf::PhysicalPlanNode { }) }) .collect::>>()?; + let dynamic_filter = exec + .dynamic_filter() + .map(|df| { + let df_expr: Arc = df as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( protobuf::SortExecNode { @@ -3108,6 +3185,7 @@ impl protobuf::PhysicalPlanNode { _ => -1, }, preserve_partitioning: exec.preserve_partitioning(), + dynamic_filter, }, ))), }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index fa342ae9079d5..cbb0012024116 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -53,6 +53,8 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_optimizer::filter_pushdown::FilterPushdown; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; @@ -99,6 +101,7 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_datasource::TableSchema; +use datafusion_datasource::file::FileSource; use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{ @@ -2872,21 +2875,26 @@ fn assert_dynamic_filter_update_is_visible( Ok(()) } +/// Extract the dynamic-filter predicate that was pushed down to the parquet +/// scan at the bottom of the plan tree. +fn parquet_source_predicate(child: &Arc) -> Arc { + let data_source = child + .downcast_ref::() + .expect("Child should be DataSourceExec"); + let (_, parquet_source) = data_source + .downcast_to_file_source::() + .expect("Should be ParquetSource"); + parquet_source + .filter() + .expect("ParquetSource should have a predicate after roundtrip") +} + /// Assert that two dynamic filters are equal both structurally (Debug output) /// and by identity (`expression_id`). -/// fn assert_dynamic_filters_equal( expected: &Arc, actual: &Arc, ) { - // TODO: Debug currently omits `expression_id` so the id has to be checked - // separately here. Once plan nodes like `SortExec` / `AggregateExec` / - // `HashJoinExec` serialize their own dynamic filter, Debug can include - // `expression_id`. - // - // See https://github.com/apache/datafusion/issues/20418 - assert_eq!(expected.expression_id(), actual.expression_id()); - // Structural. let expected_dbg = format!("{expected:?}"); let actual_dbg = format!("{actual:?}"); @@ -3465,6 +3473,86 @@ fn test_linearization_stops_at_different_op() -> Result<()> { Ok(()) } +/// Create a DataSourceExec backed by a ParquetSource that accepts filter pushdown, +/// along with a ConfigOptions that enables all dynamic filter pushdown options. +fn datasource_for_dynamic_filter_pushdown( + schema: &Arc, +) -> (Arc, ConfigOptions) { + let mut parquet_options = TableParquetOptions::new(); + parquet_options.global.pushdown_filters = true; + let source = Arc::new( + ParquetSource::new(Arc::clone(schema)) + .with_table_parquet_options(parquet_options), + ); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(PartitionedFile::new("/path/to/file.parquet", 1024)) + .build(); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_join_dynamic_filter_pushdown = true; + config.optimizer.enable_aggregate_dynamic_filter_pushdown = true; + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + + (DataSourceExec::from_data_source(scan_config), config) +} + +/// Test that plan containing a HashJoinExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_hash_join_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); + + let left_child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let (right_child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let on: Vec<(Arc, Arc)> = vec![( + Arc::new(Column::new("col", 0)), + Arc::new(Column::new("col", 0)), + )]; + + let hash_join = Arc::new(HashJoinExec::try_new( + left_child, + right_child, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(hash_join, &config)?; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let deserialized = roundtrip_test_and_return(plan, &ctx, &codec, &converter)?; + + // Extract the deserialized HashJoinExec and its dynamic filter. + let deserialized_join = deserialized + .downcast_ref::() + .expect("Should be HashJoinExec"); + let deserialized_hash_join_df = deserialized_join + .dynamic_filter() + .expect("HashJoinExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the probe side's ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_join.right()); + + // The HashJoinExec's dynamic filter and the probe side's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_hash_join_df.clone(); + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; + + Ok(()) +} + /// returns a SessionContext with an empty `netflow` table registered fn netflow_context() -> Result { let ctx = SessionContext::new(); @@ -3547,39 +3635,141 @@ async fn roundtrip_issue_18602_complex_filter_decode_recursion() -> Result<()> { roundtrip_test_sql_with_context(sql, &ctx).await } +/// Test that plan containing a AggregateExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. #[test] -fn roundtrip_filter_with_none_projection() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ])); - let predicate: Arc = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Gt, - lit(ScalarValue::Int32(Some(0))), - )); - let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); +fn test_aggregate_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); - // Case 1: None projection (return all columns) - roundtrip_test(Arc::new(FilterExec::try_new( - Arc::clone(&predicate), - Arc::clone(&input), - )?))?; + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); - // Case 2: Some(vec![]) — explicitly empty projection - roundtrip_test(Arc::new( - FilterExecBuilder::new(Arc::clone(&predicate), Arc::clone(&input)) - .apply_projection(Some(vec![]))? - .build()?, - ))?; + let agg = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![ + AggregateExprBuilder::new( + datafusion::functions_aggregate::min_max::min_udaf(), + vec![Arc::clone(&col_a)], + ) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .map(Arc::new)?, + ], + vec![None], + child, + Arc::clone(&schema), + )?) as Arc; - // Case 3: Some(vec![2, 0]) — partial projection - roundtrip_test(Arc::new( - FilterExecBuilder::new(Arc::clone(&predicate), Arc::clone(&input)) - .apply_projection(Some(vec![2, 0]))? - .build()?, - ))?; + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(agg, &config)?; + + // Roundtrip with deduplication. + // + // Note: We don't use `roundtrip_test_and_return` here because there's a + // pre-existing issue with PhysicalGroupBy serialization where empty groups + // `[[]]` become `[]` after roundtrip. This behavior is unrelated to this test. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&plan), + &codec, + &converter, + )?; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Extract the deserialized AggregateExec and its dynamic filter. + let deserialized_agg = deserialized + .downcast_ref::() + .expect("Should be AggregateExec"); + let deserialized_agg_df = deserialized_agg + .dynamic_filter() + .expect("AggregateExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the child ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_agg.input()); + + // The AggregateExec's dynamic filter and the child's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_agg_df.clone(); + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; + + Ok(()) +} + +/// Test that plan containing a SortExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_sort_topk_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); + + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let sort = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)), + ) as Arc; + + // Verify the optimizer kept the dynamic filter on the SortExec. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(sort, &config)?; + + // Roundtrip with deduplication. + // + // Note: We don't use `roundtrip_test_and_return` here because + // `DeduplicatingDeserializer` rewrites cache hits via `with_new_children`, + // which sets `remapped_children: Some(...)` on the second encounter of a + // shared `DynamicFilterPhysicalExpr`. SortExec's `Debug` includes its + // dynamic filter, so the original-vs-deserialized structural equality check + // would fail purely on this artifact. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&plan), + &codec, + &converter, + )?; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Extract the deserialized SortExec and its dynamic filter. + let deserialized_sort = deserialized + .downcast_ref::() + .expect("Should be SortExec"); + let deserialized_sort_df = deserialized_sort + .dynamic_filter() + .expect("SortExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the child ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_sort.input()); + + // The SortExec's dynamic filter and the child's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_sort_df; + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; Ok(()) } From 334ca912ca48ecc71304118840b0eeb26931a889 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Mon, 4 May 2026 20:28:00 +0000 Subject: [PATCH 2/3] rename --- .../physical_optimizer/filter_pushdown.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 103 +++++++++--------- .../physical-plan/src/joins/hash_join/exec.rs | 14 ++- datafusion/physical-plan/src/sorts/sort.rs | 12 +- datafusion/proto/src/physical_plan/mod.rs | 12 +- .../tests/cases/roundtrip_physical_plan.rs | 6 +- 6 files changed, 75 insertions(+), 74 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs index 5f64c9e4a5400..f56b8c6d70624 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -2835,7 +2835,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { // Verify that a dynamic filter was created let dynamic_filter = hash_join - .dynamic_filter() + .dynamic_filter_expr() .expect("Dynamic filter should be created"); // Verify that is_used() returns the expected value based on probe side support. diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f311b37e0a0b4..7d0388dddfce4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -893,6 +893,47 @@ impl AggregateExec { &self.filter_expr } + /// Returns the dynamic filter expression for this aggregate, if set. + pub fn dynamic_filter_expr(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + /// Replace the dynamic filter expression. This method errors if the aggregate does not + /// support dynamic filtering or if the filter expression is incompatible with this + /// [`AggregateExec`]. + pub fn with_dynamic_filter_expr( + mut self, + filter: Arc, + ) -> Result { + // If there is no dynamic filter state initialized via `try_new`, then + // we can safely assume that the aggregate does not support dynamic filtering. + let Some(dyn_filter) = self.dynamic_filter.as_ref() else { + return internal_err!("Aggregate does not support dynamic filtering"); + }; + + // Validate that the filter is compatible with the aggregation columns. + let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info); + if cols.len() != filter.children().len() { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns" + ); + } + for (col, child) in cols.iter().zip(filter.children()) { + if !col.eq(child) { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}" + ); + } + } + + // Overwrite our filter + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter, + supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(), + })); + Ok(self) + } + /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -1048,47 +1089,6 @@ impl AggregateExec { &self.input_order_mode } - /// Returns the dynamic filter expression for this aggregate, if set. - pub fn dynamic_filter(&self) -> Option<&Arc> { - self.dynamic_filter.as_ref().map(|df| &df.filter) - } - - /// Replace the dynamic filter expression. This method errors if the aggregate does not - /// support dynamic filtering or if the filter expression is incompatible with this - /// [`AggregateExec`]. - pub fn with_dynamic_filter( - mut self, - filter: Arc, - ) -> Result { - // If there is no dynamic filter state initialized via `try_new`, then - // we can safely assume that the aggregate does not support dynamic filtering. - let Some(dyn_filter) = self.dynamic_filter.as_ref() else { - return internal_err!("Aggregate does not support dynamic filtering"); - }; - - // Validate that the filter is compatible with the aggregation columns. - let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info); - if cols.len() != filter.children().len() { - return internal_err!( - "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns" - ); - } - for (col, child) in cols.iter().zip(filter.children()) { - if !col.eq(child) { - return internal_err!( - "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}" - ); - } - } - - // Overwrite our filter - self.dynamic_filter = Some(Arc::new(AggrDynFilter { - filter, - supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(), - })); - Ok(self) - } - /// Estimates output statistics for this aggregate node. /// /// For grouped aggregations with known input row count > 1, the output row @@ -4842,7 +4842,7 @@ mod tests { Ok(()) } - /// Test that [`AggregateExec::with_dynamic_filter`] overrides the existing dynamic filter + /// Test that [`AggregateExec::with_dynamic_filter_expr`] overrides the existing dynamic filter #[test] fn test_with_dynamic_filter() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); @@ -4869,11 +4869,11 @@ mod tests { vec![col("a", &schema)?], lit(false), )); - let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?; + let agg = agg.with_dynamic_filter_expr(Arc::clone(&new_df))?; // The aggregate's filter should now resolve to the new inner expression. let swapped = agg - .dynamic_filter() + .dynamic_filter_expr() .expect("should still have dynamic filter") .current()?; assert_eq!(format!("{swapped}"), format!("{}", lit(false))); @@ -4884,19 +4884,18 @@ mod tests { Arc::::clone(&new_df); let remapped_pexpr = new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?; - let Ok(remapped_df) = (remapped_pexpr - as Arc) + let Ok(remapped_df) = (remapped_pexpr as Arc) .downcast::() else { panic!("should be DynamicFilterPhysicalExpr after with_new_children"); }; // Hard to assert this because the filter is identical. No error means // the filter was accepted. That's a good enough assertion for now. - let _agg = agg.with_dynamic_filter(remapped_df)?; + let _agg = agg.with_dynamic_filter_expr(remapped_df)?; Ok(()) } - /// Test that [`AggregateExec::with_dynamic_filter`] errors when the aggregate does not support dynamic filtering + /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the aggregate does not support dynamic filtering #[test] fn test_with_dynamic_filter_error_unsupported() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -4919,17 +4918,17 @@ mod tests { child, Arc::clone(&schema), )?; - assert!(agg.dynamic_filter().is_none()); + assert!(agg.dynamic_filter_expr().is_none()); let df = Arc::new(DynamicFilterPhysicalExpr::new( vec![col("a", &schema)?], lit(true), )); - assert!(agg.with_dynamic_filter(df).is_err()); + assert!(agg.with_dynamic_filter_expr(df).is_err()); Ok(()) } - /// Test that [`AggregateExec::with_dynamic_filter`] errors when the column is not in the schema + /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the column is not in the schema #[test] fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); @@ -4953,7 +4952,7 @@ mod tests { vec![Arc::new(Column::new("bad", 99)) as _], lit(true), )); - assert!(agg.with_dynamic_filter(df).is_err()); + assert!(agg.with_dynamic_filter_expr(df).is_err()); Ok(()) } } diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 1f4ee9c48508a..3cdd60d7ab3c8 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -904,7 +904,7 @@ impl HashJoinExec { /// Get the dynamic filter expression for testing purposes. /// Returns the dynamic filter expression for this hash join, if set. - pub fn dynamic_filter(&self) -> Option<&Arc> { + pub fn dynamic_filter_expr(&self) -> Option<&Arc> { self.dynamic_filter.as_ref().map(|df| &df.filter) } @@ -914,7 +914,7 @@ impl HashJoinExec { /// /// Validates that the filter's children reference valid columns in /// the probe (right) side's schema. - pub fn with_dynamic_filter( + pub fn with_dynamic_filter_expr( mut self, filter: Arc, ) -> Result { @@ -6345,15 +6345,17 @@ mod tests { NullEquality::NullEqualsNothing, false, )?; - assert!(join.dynamic_filter().is_none()); + assert!(join.dynamic_filter_expr().is_none()); let df = Arc::new(DynamicFilterPhysicalExpr::new( vec![Arc::new(Column::new("b1", 1)) as _], lit(true), )); - let join = join.with_dynamic_filter(Arc::clone(&df))?; + let join = join.with_dynamic_filter_expr(Arc::clone(&df))?; - let restored = join.dynamic_filter().expect("should have dynamic filter"); + let restored = join + .dynamic_filter_expr() + .expect("should have dynamic filter"); assert_eq!( restored .expression_id() @@ -6387,7 +6389,7 @@ mod tests { vec![Arc::new(Column::new("bad", 99)) as _], lit(true), )); - assert!(join.with_dynamic_filter(df).is_err()); + assert!(join.with_dynamic_filter_expr(df).is_err()); Ok(()) } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index f4c764ac73a3f..33bebfd8edbbd 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -979,7 +979,7 @@ impl SortExec { } /// Returns the dynamic filter expression for this sort (TopK), if set. - pub fn dynamic_filter(&self) -> Option> { + pub fn dynamic_filter_expr(&self) -> Option> { self.filter.as_ref().map(|f| f.read().expr()) } @@ -990,7 +990,7 @@ impl SortExec { /// /// Validates that the filter's children reference valid columns in /// the sort's input schema. - pub fn with_dynamic_filter( + pub fn with_dynamic_filter_expr( mut self, filter: Arc, ) -> Result { @@ -2764,7 +2764,7 @@ mod tests { // SortExec with fetch creates a dynamic filter automatically. let original_id = sort - .dynamic_filter() + .dynamic_filter_expr() .expect("should have dynamic filter with fetch") .expression_id() .expect("DynamicFilterPhysicalExpr always has an expression_id"); @@ -2777,9 +2777,9 @@ mod tests { let new_id = new_df .expression_id() .expect("DynamicFilterPhysicalExpr always has an expression_id"); - let sort = sort.with_dynamic_filter(Arc::clone(&new_df))?; + let sort = sort.with_dynamic_filter_expr(Arc::clone(&new_df))?; let restored_id = sort - .dynamic_filter() + .dynamic_filter_expr() .expect("should still have dynamic filter") .expression_id() .expect("DynamicFilterPhysicalExpr always has an expression_id"); @@ -2808,7 +2808,7 @@ mod tests { vec![Arc::new(Column::new("bad", 99)) as _], lit(true), )); - assert!(sort.with_dynamic_filter(df).is_err()); + assert!(sort.with_dynamic_filter_expr(df).is_err()); Ok(()) } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8994926d173ae..ef8d6305b6201 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1313,7 +1313,7 @@ impl protobuf::PhysicalPlanNode { if let Ok(df) = (dynamic_filter_expr as Arc) .downcast::() { - agg.with_dynamic_filter(df)? + agg.with_dynamic_filter_expr(df)? } else { agg } @@ -1448,7 +1448,7 @@ impl protobuf::PhysicalPlanNode { if let Ok(df) = (dynamic_filter_expr as Arc) .downcast::() { - hash_join = hash_join.with_dynamic_filter(df)?; + hash_join = hash_join.with_dynamic_filter_expr(df)?; } } @@ -1699,7 +1699,7 @@ impl protobuf::PhysicalPlanNode { if let Ok(df) = (dynamic_filter_expr as Arc) .downcast::() { - new_sort.with_dynamic_filter(df)? + new_sort.with_dynamic_filter_expr(df)? } else { new_sort } @@ -2514,7 +2514,7 @@ impl protobuf::PhysicalPlanNode { }; let dynamic_filter = exec - .dynamic_filter() + .dynamic_filter_expr() .map(|df| { let df_expr: Arc = Arc::clone(df) as Arc; @@ -2867,7 +2867,7 @@ impl protobuf::PhysicalPlanNode { limit, has_grouping_set: exec.group_expr().has_grouping_set(), dynamic_filter: exec - .dynamic_filter() + .dynamic_filter_expr() .map(|df| { let df_expr: Arc = Arc::clone(df) as Arc; @@ -3168,7 +3168,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::>>()?; let dynamic_filter = exec - .dynamic_filter() + .dynamic_filter_expr() .map(|df| { let df_expr: Arc = df as Arc; proto_converter.physical_expr_to_proto(&df_expr, codec) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index cbb0012024116..903b459bd17e4 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -3538,7 +3538,7 @@ fn test_hash_join_with_dynamic_filter_roundtrip() -> Result<()> { .downcast_ref::() .expect("Should be HashJoinExec"); let deserialized_hash_join_df = deserialized_join - .dynamic_filter() + .dynamic_filter_expr() .expect("HashJoinExec should have a dynamic filter after roundtrip"); // Extract the dynamic filter pushed down to the probe side's ParquetSource. @@ -3691,7 +3691,7 @@ fn test_aggregate_with_dynamic_filter_roundtrip() -> Result<()> { .downcast_ref::() .expect("Should be AggregateExec"); let deserialized_agg_df = deserialized_agg - .dynamic_filter() + .dynamic_filter_expr() .expect("AggregateExec should have a dynamic filter after roundtrip"); // Extract the dynamic filter pushed down to the child ParquetSource. @@ -3759,7 +3759,7 @@ fn test_sort_topk_with_dynamic_filter_roundtrip() -> Result<()> { .downcast_ref::() .expect("Should be SortExec"); let deserialized_sort_df = deserialized_sort - .dynamic_filter() + .dynamic_filter_expr() .expect("SortExec should have a dynamic filter after roundtrip"); // Extract the dynamic filter pushed down to the child ParquetSource. From 136964d31b36c3ad80622a281be387f92f2b4979 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Mon, 4 May 2026 22:44:01 +0000 Subject: [PATCH 3/3] error if unexpected expression --- datafusion/proto/src/physical_plan/mod.rs | 37 +++++++++++++---------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index ef8d6305b6201..5da90fe6a7533 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1310,13 +1310,14 @@ impl protobuf::PhysicalPlanNode { physical_schema_ref.as_ref(), ctx, )?; - if let Ok(df) = (dynamic_filter_expr as Arc) + let df = (dynamic_filter_expr as Arc) .downcast::() - { - agg.with_dynamic_filter_expr(df)? - } else { - agg - } + .map_err(|_| { + internal_datafusion_err!( + "AggregateExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + agg.with_dynamic_filter_expr(df)? } else { agg }; @@ -1445,11 +1446,14 @@ impl protobuf::PhysicalPlanNode { right_schema.as_ref(), ctx, )?; - if let Ok(df) = (dynamic_filter_expr as Arc) + let df = (dynamic_filter_expr as Arc) .downcast::() - { - hash_join = hash_join.with_dynamic_filter_expr(df)?; - } + .map_err(|_| { + internal_datafusion_err!( + "HashJoinExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + hash_join = hash_join.with_dynamic_filter_expr(df)?; } Ok(Arc::new(hash_join)) @@ -1696,13 +1700,14 @@ impl protobuf::PhysicalPlanNode { new_sort.input().schema().as_ref(), ctx, )?; - if let Ok(df) = (dynamic_filter_expr as Arc) + let df = (dynamic_filter_expr as Arc) .downcast::() - { - new_sort.with_dynamic_filter_expr(df)? - } else { - new_sort - } + .map_err(|_| { + internal_datafusion_err!( + "SortExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + new_sort.with_dynamic_filter_expr(df)? } else { new_sort };