diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index ba13ef392d912..382c6406d8bfd 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -632,7 +632,14 @@ impl Statistics { col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); - col_stats.distinct_count = Precision::Absent; + // Use max as a conservative lower bound for distinct count + // (can't accurately merge NDV since duplicates may exist across partitions) + col_stats.distinct_count = col_stats + .distinct_count + .get_value() + .max(item_col_stats.distinct_count.get_value()) + .map(|&v| Precision::Inexact(v)) + .unwrap_or(Precision::Absent); col_stats.byte_size = col_stats.byte_size.add(&item_col_stats.byte_size); } @@ -1352,8 +1359,8 @@ mod tests { col_stats.max_value, Precision::Exact(ScalarValue::Int32(Some(20))) ); - // Distinct count should be Absent after merge - assert_eq!(col_stats.distinct_count, Precision::Absent); + // Distinct count should be Inexact(max) after merge as a conservative lower bound + assert_eq!(col_stats.distinct_count, Precision::Inexact(7)); } #[test] diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index b33305c23ede2..567337f91efab 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -150,13 +150,15 @@ mod test { // - null_count = 0 (partition values from paths are never null) // - min/max are the merged partition values across files in the group // - byte_size = num_rows * 4 (Date32 is 4 bytes per row) + // - distinct_count = Inexact(1) per partition file (single partition value per file), + // preserved via max() when merging stats across partitions let date32_byte_size = num_rows * 4; column_stats.push(ColumnStatistics { null_count: Precision::Exact(0), max_value: Precision::Exact(ScalarValue::Date32(Some(max_date))), min_value: Precision::Exact(ScalarValue::Date32(Some(min_date))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), byte_size: Precision::Exact(date32_byte_size), }); } @@ -577,7 +579,7 @@ mod test { max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), byte_size: Precision::Absent, }, // column 2: right.id (Int32, file column from t2) - right partition 0: ids [3,4] @@ -611,7 +613,7 @@ mod test { max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), byte_size: Precision::Absent, }, // column 2: right.id (Int32, file column from t2) - right partition 1: ids [1,2] @@ -1247,7 +1249,7 @@ mod test { DATE_2025_03_01, ))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), byte_size: Precision::Exact(8), }, ColumnStatistics::new_unknown(), // window column @@ -1275,7 +1277,7 @@ mod test { DATE_2025_03_03, ))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), byte_size: Precision::Exact(8), }, ColumnStatistics::new_unknown(), // window column diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index b763f817a0268..68ad091e371bc 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -297,6 +297,8 @@ impl<'a> DFParquetMetadata<'a> { vec![Some(true); logical_file_schema.fields().len()]; let mut is_min_value_exact = vec![Some(true); logical_file_schema.fields().len()]; + let mut distinct_counts_array = + vec![Precision::Absent; logical_file_schema.fields().len()]; logical_file_schema.fields().iter().enumerate().for_each( |(idx, field)| match StatisticsConverter::try_new( field.name(), @@ -311,8 +313,9 @@ impl<'a> DFParquetMetadata<'a> { is_min_value_exact: &mut is_min_value_exact, is_max_value_exact: &mut is_max_value_exact, column_byte_sizes: &mut column_byte_sizes, + distinct_counts_array: &mut distinct_counts_array, }; - summarize_min_max_null_counts( + summarize_column_statistics( file_metadata.schema_descr(), logical_file_schema, &physical_file_schema, @@ -330,15 +333,16 @@ impl<'a> DFParquetMetadata<'a> { }, ); - get_col_stats( - logical_file_schema, - &null_counts_array, - &mut max_accs, - &mut min_accs, - &mut is_max_value_exact, - &mut is_min_value_exact, - &column_byte_sizes, - ) + let mut accumulators = StatisticsAccumulators { + min_accs: &mut min_accs, + max_accs: &mut max_accs, + null_counts_array: &mut null_counts_array, + is_min_value_exact: &mut is_min_value_exact, + is_max_value_exact: &mut is_max_value_exact, + column_byte_sizes: &mut column_byte_sizes, + distinct_counts_array: &mut distinct_counts_array, + }; + accumulators.build_column_statistics(logical_file_schema) } else { // Record column sizes logical_file_schema @@ -411,53 +415,6 @@ fn create_max_min_accs( (max_values, min_values) } -fn get_col_stats( - schema: &Schema, - null_counts: &[Precision], - max_values: &mut [Option], - min_values: &mut [Option], - is_max_value_exact: &mut [Option], - is_min_value_exact: &mut [Option], - column_byte_sizes: &[Precision], -) -> Vec { - (0..schema.fields().len()) - .map(|i| { - let max_value = match ( - max_values.get_mut(i).unwrap(), - is_max_value_exact.get(i).unwrap(), - ) { - (Some(max_value), Some(true)) => { - max_value.evaluate().ok().map(Precision::Exact) - } - (Some(max_value), Some(false)) | (Some(max_value), None) => { - max_value.evaluate().ok().map(Precision::Inexact) - } - (None, _) => None, - }; - let min_value = match ( - min_values.get_mut(i).unwrap(), - is_min_value_exact.get(i).unwrap(), - ) { - (Some(min_value), Some(true)) => { - min_value.evaluate().ok().map(Precision::Exact) - } - (Some(min_value), Some(false)) | (Some(min_value), None) => { - min_value.evaluate().ok().map(Precision::Inexact) - } - (None, _) => None, - }; - ColumnStatistics { - null_count: null_counts[i], - max_value: max_value.unwrap_or(Precision::Absent), - min_value: min_value.unwrap_or(Precision::Absent), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - byte_size: column_byte_sizes[i], - } - }) - .collect() -} - /// Holds the accumulator state for collecting statistics from row groups struct StatisticsAccumulators<'a> { min_accs: &'a mut [Option], @@ -466,9 +423,52 @@ struct StatisticsAccumulators<'a> { is_min_value_exact: &'a mut [Option], is_max_value_exact: &'a mut [Option], column_byte_sizes: &'a mut [Precision], + distinct_counts_array: &'a mut [Precision], } -fn summarize_min_max_null_counts( +impl StatisticsAccumulators<'_> { + /// Converts the accumulated statistics into a vector of `ColumnStatistics` + fn build_column_statistics(&mut self, schema: &Schema) -> Vec { + (0..schema.fields().len()) + .map(|i| { + let max_value = match ( + self.max_accs.get_mut(i).unwrap(), + self.is_max_value_exact.get(i).unwrap(), + ) { + (Some(max_value), Some(true)) => { + max_value.evaluate().ok().map(Precision::Exact) + } + (Some(max_value), Some(false)) | (Some(max_value), None) => { + max_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + let min_value = match ( + self.min_accs.get_mut(i).unwrap(), + self.is_min_value_exact.get(i).unwrap(), + ) { + (Some(min_value), Some(true)) => { + min_value.evaluate().ok().map(Precision::Exact) + } + (Some(min_value), Some(false)) | (Some(min_value), None) => { + min_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + ColumnStatistics { + null_count: self.null_counts_array[i], + max_value: max_value.unwrap_or(Precision::Absent), + min_value: min_value.unwrap_or(Precision::Absent), + sum_value: Precision::Absent, + distinct_count: self.distinct_counts_array[i], + byte_size: self.column_byte_sizes[i], + } + }) + .collect() + } +} + +fn summarize_column_statistics( parquet_schema: &SchemaDescriptor, logical_file_schema: &Schema, physical_file_schema: &Schema, @@ -523,6 +523,36 @@ fn summarize_min_max_null_counts( ) .map(|(idx, _)| idx); + // Extract distinct counts from row group column statistics + accumulators.distinct_counts_array[logical_schema_index] = + if let Some(parquet_idx) = parquet_index { + let distinct_counts: Vec = row_groups_metadata + .iter() + .filter_map(|rg| { + rg.columns() + .get(parquet_idx) + .and_then(|col| col.statistics()) + .and_then(|stats| stats.distinct_count_opt()) + }) + .collect(); + + if distinct_counts.is_empty() { + Precision::Absent + } else if distinct_counts.len() == 1 { + // Single row group with distinct count - use exact value + Precision::Exact(distinct_counts[0] as usize) + } else { + // Multiple row groups - use max as a lower bound estimate + // (can't accurately merge NDV since duplicates may exist across row groups) + match distinct_counts.iter().max() { + Some(&max_ndv) => Precision::Inexact(max_ndv as usize), + None => Precision::Absent, + } + } + } else { + Precision::Absent + }; + let arrow_field = logical_file_schema.field(logical_schema_index); accumulators.column_byte_sizes[logical_schema_index] = compute_arrow_column_size( arrow_field.data_type(), @@ -778,4 +808,354 @@ mod tests { assert_eq!(result, Some(false)); } } + + mod ndv_tests { + use super::*; + use arrow::datatypes::Field; + use parquet::arrow::parquet_to_arrow_schema; + use parquet::basic::Type as PhysicalType; + use parquet::file::metadata::{ColumnChunkMetaData, RowGroupMetaData}; + use parquet::file::reader::{FileReader, SerializedFileReader}; + use parquet::file::statistics::Statistics as ParquetStatistics; + use parquet::schema::types::{SchemaDescriptor, Type as SchemaType}; + use std::fs::File; + use std::path::PathBuf; + + fn create_schema_descr(num_columns: usize) -> Arc { + let fields: Vec> = (0..num_columns) + .map(|i| { + Arc::new( + SchemaType::primitive_type_builder( + &format!("col_{i}"), + PhysicalType::INT32, + ) + .build() + .unwrap(), + ) + }) + .collect(); + + let schema = SchemaType::group_type_builder("schema") + .with_fields(fields) + .build() + .unwrap(); + + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + } + + fn create_arrow_schema(num_columns: usize) -> SchemaRef { + let fields: Vec = (0..num_columns) + .map(|i| Field::new(format!("col_{i}"), DataType::Int32, true)) + .collect(); + Arc::new(Schema::new(fields)) + } + + fn create_row_group_with_stats( + schema_descr: &Arc, + column_stats: Vec>, + num_rows: i64, + ) -> RowGroupMetaData { + let columns: Vec = column_stats + .into_iter() + .enumerate() + .map(|(i, stats)| { + let mut builder = + ColumnChunkMetaData::builder(schema_descr.column(i)); + if let Some(s) = stats { + builder = builder.set_statistics(s); + } + builder.set_num_values(num_rows).build().unwrap() + }) + .collect(); + + RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(num_rows) + .set_total_byte_size(1000) + .set_column_metadata(columns) + .build() + .unwrap() + } + + fn create_parquet_metadata( + schema_descr: Arc, + row_groups: Vec, + ) -> ParquetMetaData { + use parquet::file::metadata::FileMetaData; + + let num_rows: i64 = row_groups.iter().map(|rg| rg.num_rows()).sum(); + let file_meta = FileMetaData::new( + 1, // version + num_rows, // num_rows + None, // created_by + None, // key_value_metadata + schema_descr, // schema_descr + None, // column_orders + ); + + ParquetMetaData::new(file_meta, row_groups) + } + + #[test] + fn test_distinct_count_single_row_group_with_ndv() { + // Single row group with distinct count should return Exact + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create statistics with distinct_count = 42 + let stats = ParquetStatistics::int32( + Some(1), // min + Some(100), // max + Some(42), // distinct_count + Some(0), // null_count + false, // is_deprecated + ); + + let row_group = + create_row_group_with_stats(&schema_descr, vec![Some(stats)], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Exact(42) + ); + } + + #[test] + fn test_distinct_count_multiple_row_groups_with_ndv() { + // Multiple row groups with distinct counts should return Inexact (sum) + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Row group 1: distinct_count = 10 + let stats1 = ParquetStatistics::int32( + Some(1), + Some(50), + Some(10), // distinct_count + Some(0), + false, + ); + + // Row group 2: distinct_count = 20 + let stats2 = ParquetStatistics::int32( + Some(51), + Some(100), + Some(20), // distinct_count + Some(0), + false, + ); + + let row_group1 = + create_row_group_with_stats(&schema_descr, vec![Some(stats1)], 500); + let row_group2 = + create_row_group_with_stats(&schema_descr, vec![Some(stats2)], 500); + let metadata = + create_parquet_metadata(schema_descr, vec![row_group1, row_group2]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + // Max of distinct counts (lower bound since we can't accurately merge NDV) + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + } + + #[test] + fn test_distinct_count_no_ndv_available() { + // No distinct count in statistics should return Absent + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create statistics without distinct_count (None) + let stats = ParquetStatistics::int32( + Some(1), + Some(100), + None, // no distinct_count + Some(0), + false, + ); + + let row_group = + create_row_group_with_stats(&schema_descr, vec![Some(stats)], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent + ); + } + + #[test] + fn test_distinct_count_partial_ndv_in_row_groups() { + // Some row groups have NDV, some don't - should use only those that have it + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Row group 1: has distinct_count = 15 + let stats1 = + ParquetStatistics::int32(Some(1), Some(50), Some(15), Some(0), false); + + // Row group 2: no distinct_count + let stats2 = ParquetStatistics::int32( + Some(51), + Some(100), + None, // no distinct_count + Some(0), + false, + ); + + let row_group1 = + create_row_group_with_stats(&schema_descr, vec![Some(stats1)], 500); + let row_group2 = + create_row_group_with_stats(&schema_descr, vec![Some(stats2)], 500); + let metadata = + create_parquet_metadata(schema_descr, vec![row_group1, row_group2]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + // Only one row group has NDV, so it's Exact(15) + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Exact(15) + ); + } + + #[test] + fn test_distinct_count_multiple_columns() { + // Test with multiple columns, each with different NDV + let schema_descr = create_schema_descr(3); + let arrow_schema = create_arrow_schema(3); + + // col_0: distinct_count = 5 + let stats0 = + ParquetStatistics::int32(Some(1), Some(10), Some(5), Some(0), false); + // col_1: no distinct_count + let stats1 = + ParquetStatistics::int32(Some(1), Some(100), None, Some(0), false); + // col_2: distinct_count = 100 + let stats2 = + ParquetStatistics::int32(Some(1), Some(1000), Some(100), Some(0), false); + + let row_group = create_row_group_with_stats( + &schema_descr, + vec![Some(stats0), Some(stats1), Some(stats2)], + 1000, + ); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Absent + ); + assert_eq!( + result.column_statistics[2].distinct_count, + Precision::Exact(100) + ); + } + + #[test] + fn test_distinct_count_no_statistics_at_all() { + // No statistics in row group should return Absent for all stats + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create row group without any statistics + let row_group = create_row_group_with_stats(&schema_descr, vec![None], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent + ); + } + + /// Integration test that reads a real Parquet file with distinct_count statistics. + /// The test file was created with DuckDB and has known NDV values: + /// - id: NULL (high cardinality, not tracked) + /// - category: 10 distinct values + /// - name: 5 distinct values + #[test] + fn test_distinct_count_from_real_parquet_file() { + // Path to test file created by DuckDB with distinct_count statistics + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("src/test_data/ndv_test.parquet"); + + let file = File::open(&path).expect("Failed to open test parquet file"); + let reader = + SerializedFileReader::new(file).expect("Failed to create reader"); + let parquet_metadata = reader.metadata(); + + // Derive Arrow schema from parquet file metadata + let arrow_schema = Arc::new( + parquet_to_arrow_schema( + parquet_metadata.file_metadata().schema_descr(), + None, + ) + .expect("Failed to convert schema"), + ); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + parquet_metadata, + &arrow_schema, + ) + .expect("Failed to extract statistics"); + + // id: no distinct_count (high cardinality) + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent, + "id column should have Absent distinct_count" + ); + + // category: 10 distinct values + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Exact(10), + "category column should have Exact(10) distinct_count" + ); + + // name: 5 distinct values + assert_eq!( + result.column_statistics[2].distinct_count, + Precision::Exact(5), + "name column should have Exact(5) distinct_count" + ); + } + } } diff --git a/datafusion/datasource-parquet/src/test_data/ndv_test.parquet b/datafusion/datasource-parquet/src/test_data/ndv_test.parquet new file mode 100644 index 0000000000000..3ecbe320f506e Binary files /dev/null and b/datafusion/datasource-parquet/src/test_data/ndv_test.parquet differ diff --git a/datafusion/physical-expr/src/expression_analyzer.rs b/datafusion/physical-expr/src/expression_analyzer.rs new file mode 100644 index 0000000000000..8c458ba0679cb --- /dev/null +++ b/datafusion/physical-expr/src/expression_analyzer.rs @@ -0,0 +1,1222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Pluggable expression-level statistics analysis. +//! +//! This module provides an extensible mechanism for computing expression-level +//! statistics metadata (selectivity, NDV, min/max bounds) following the chain +//! of responsibility pattern. +//! +//! # Overview +//! +//! Different expressions have different statistical properties: +//! +//! - **Injective functions** (UPPER, LOWER, ABS on non-negative): preserve NDV +//! - **Non-injective functions** (FLOOR, YEAR, SUBSTRING): reduce NDV +//! - **Monotonic functions**: allow min/max bound propagation +//! - **Constants**: NDV = 1, selectivity depends on value +//! +//! The default implementation uses classic Selinger-style estimation. Users can +//! register custom [`ExpressionAnalyzer`] implementations to: +//! +//! 1. Provide statistics for custom UDFs +//! 2. Override default estimation with domain-specific knowledge +//! 3. Plug in advanced approaches (e.g., histogram-based estimation) +//! +//! # Example +//! +//! ```ignore +//! use datafusion_physical_plan::expression_analyzer::*; +//! +//! // Create registry with default analyzer +//! let mut registry = ExpressionAnalyzerRegistry::new(); +//! +//! // Register custom analyzer (higher priority) +//! registry.register(Arc::new(MyCustomAnalyzer)); +//! +//! // Query expression statistics +//! let selectivity = registry.get_selectivity(&predicate, &input_stats); +//! ``` + +use std::fmt::Debug; +use std::sync::Arc; + +use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; +use datafusion_expr::Operator; + +use crate::expressions::{BinaryExpr, Column, Literal, NotExpr}; +use crate::{PhysicalExpr, ScalarFunctionExpr}; + +// ============================================================================ +// AnalysisResult: Chain of responsibility result type +// ============================================================================ + +/// Result of expression analysis - either computed or delegate to next analyzer. +#[derive(Debug, Clone)] +pub enum AnalysisResult { + /// Analysis was performed, here's the result + Computed(T), + /// This analyzer doesn't handle this expression; delegate to next + Delegate, +} + +impl AnalysisResult { + /// Convert to Option, returning None for Delegate + pub fn into_option(self) -> Option { + match self { + AnalysisResult::Computed(v) => Some(v), + AnalysisResult::Delegate => None, + } + } + + /// Returns true if this is a Computed result + pub fn is_computed(&self) -> bool { + matches!(self, AnalysisResult::Computed(_)) + } +} + +// ============================================================================ +// ExpressionAnalyzer trait +// ============================================================================ + +/// Expression-level metadata analysis. +/// +/// Implementations can handle specific expression types or provide domain +/// knowledge for custom UDFs. The chain of analyzers is traversed until one +/// returns [`AnalysisResult::Computed`]. +/// +/// # Implementing a Custom Analyzer +/// +/// ```ignore +/// #[derive(Debug)] +/// struct MyUdfAnalyzer; +/// +/// impl ExpressionAnalyzer for MyUdfAnalyzer { +/// fn get_selectivity( +/// &self, +/// expr: &Arc, +/// input_stats: &Statistics, +/// ) -> AnalysisResult { +/// // Recognize my custom is_valid_email() UDF +/// if is_my_email_validator(expr) { +/// return AnalysisResult::Computed(0.8); // ~80% valid +/// } +/// AnalysisResult::Delegate +/// } +/// } +/// ``` +pub trait ExpressionAnalyzer: Debug + Send + Sync { + /// Estimate selectivity when this expression is used as a predicate. + /// + /// Returns a value in [0.0, 1.0] representing the fraction of rows + /// that satisfy the predicate. + fn get_selectivity( + &self, + _expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult { + AnalysisResult::Delegate + } + + /// Estimate the number of distinct values in the expression's output. + /// + /// Properties: + /// - Injective functions preserve input NDV + /// - Non-injective functions reduce NDV (e.g., FLOOR, YEAR) + /// - Constants have NDV = 1 + fn get_distinct_count( + &self, + _expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult { + AnalysisResult::Delegate + } + + /// Estimate min/max bounds of the expression's output. + /// + /// Monotonic functions can transform input bounds: + /// - Increasing: (f(min), f(max)) + /// - Decreasing: (f(max), f(min)) + /// - Non-monotonic: may need wider bounds or return Delegate + fn get_min_max( + &self, + _expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult<(ScalarValue, ScalarValue)> { + AnalysisResult::Delegate + } + + /// Estimate the fraction of null values in the expression's output. + /// + /// Returns a value in [0.0, 1.0]. + fn get_null_fraction( + &self, + _expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult { + AnalysisResult::Delegate + } +} + +// ============================================================================ +// ExpressionAnalyzerRegistry +// ============================================================================ + +/// Registry that chains [`ExpressionAnalyzer`] implementations. +/// +/// Analyzers are tried in order; the first to return [`AnalysisResult::Computed`] +/// wins. Register domain-specific analyzers before the default for override. +#[derive(Debug, Clone)] +pub struct ExpressionAnalyzerRegistry { + analyzers: Vec>, +} + +impl Default for ExpressionAnalyzerRegistry { + fn default() -> Self { + Self::new() + } +} + +impl ExpressionAnalyzerRegistry { + /// Create a new registry with the default expression analyzer. + pub fn new() -> Self { + Self { + analyzers: vec![Arc::new(DefaultExpressionAnalyzer)], + } + } + + /// Create a registry with all built-in analyzers (string, math, datetime, default). + pub fn with_builtin_analyzers() -> Self { + Self { + analyzers: vec![ + Arc::new(StringFunctionAnalyzer), + Arc::new(MathFunctionAnalyzer), + Arc::new(DateTimeFunctionAnalyzer), + Arc::new(DefaultExpressionAnalyzer), + ], + } + } + + /// Create a registry with custom analyzers (no default). + pub fn with_analyzers(analyzers: Vec>) -> Self { + Self { analyzers } + } + + /// Create a registry with custom analyzers plus default as fallback. + pub fn with_analyzers_and_default( + analyzers: impl IntoIterator>, + ) -> Self { + let mut all: Vec> = analyzers.into_iter().collect(); + all.push(Arc::new(DefaultExpressionAnalyzer)); + Self { analyzers: all } + } + + /// Register an analyzer at the front of the chain (higher priority). + pub fn register(&mut self, analyzer: Arc) { + self.analyzers.insert(0, analyzer); + } + + /// Get selectivity through the analyzer chain. + pub fn get_selectivity( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> Option { + for analyzer in &self.analyzers { + if let AnalysisResult::Computed(sel) = + analyzer.get_selectivity(expr, input_stats) + { + return Some(sel); + } + } + None + } + + /// Get distinct count through the analyzer chain. + pub fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> Option { + for analyzer in &self.analyzers { + if let AnalysisResult::Computed(ndv) = + analyzer.get_distinct_count(expr, input_stats) + { + return Some(ndv); + } + } + None + } + + /// Get min/max bounds through the analyzer chain. + pub fn get_min_max( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> Option<(ScalarValue, ScalarValue)> { + for analyzer in &self.analyzers { + if let AnalysisResult::Computed(bounds) = + analyzer.get_min_max(expr, input_stats) + { + return Some(bounds); + } + } + None + } + + /// Get null fraction through the analyzer chain. + pub fn get_null_fraction( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> Option { + for analyzer in &self.analyzers { + if let AnalysisResult::Computed(frac) = + analyzer.get_null_fraction(expr, input_stats) + { + return Some(frac); + } + } + None + } +} + +// ============================================================================ +// DefaultExpressionAnalyzer +// ============================================================================ + +/// Default expression analyzer with Selinger-style estimation. +/// +/// Handles common expression types: +/// - Column references (uses column statistics) +/// - Binary expressions (AND, OR, comparison operators) +/// - Literals (constant selectivity/NDV) +/// - NOT expressions (1 - child selectivity) +#[derive(Debug, Default, Clone)] +pub struct DefaultExpressionAnalyzer; + +impl DefaultExpressionAnalyzer { + /// Get column index from a Column expression + fn get_column_index(expr: &Arc) -> Option { + expr.as_any().downcast_ref::().map(|c| c.index()) + } + + /// Get column statistics for an expression if it's a column reference + fn get_column_stats<'a>( + expr: &Arc, + input_stats: &'a Statistics, + ) -> Option<&'a ColumnStatistics> { + Self::get_column_index(expr) + .and_then(|idx| input_stats.column_statistics.get(idx)) + } + + /// Recursive selectivity estimation + fn estimate_selectivity_recursive( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> f64 { + if let AnalysisResult::Computed(sel) = self.get_selectivity(expr, input_stats) { + return sel; + } + 0.5 // Default fallback + } +} + +impl ExpressionAnalyzer for DefaultExpressionAnalyzer { + fn get_selectivity( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + // Binary expressions: AND, OR, comparisons + if let Some(binary) = expr.as_any().downcast_ref::() { + let left_sel = + self.estimate_selectivity_recursive(binary.left(), input_stats); + let right_sel = + self.estimate_selectivity_recursive(binary.right(), input_stats); + + let sel = match binary.op() { + // Logical operators + Operator::And => left_sel * right_sel, + Operator::Or => left_sel + right_sel - (left_sel * right_sel), + + // Equality: selectivity = 1/NDV + Operator::Eq => { + if let Some(ndv) = Self::get_column_stats(binary.left(), input_stats) + .and_then(|s| s.distinct_count.get_value()) + .filter(|&&ndv| ndv > 0) + { + return AnalysisResult::Computed(1.0 / (*ndv as f64)); + } + 0.1 // Default equality selectivity + } + + // Inequality: selectivity = 1 - 1/NDV + Operator::NotEq => { + if let Some(ndv) = Self::get_column_stats(binary.left(), input_stats) + .and_then(|s| s.distinct_count.get_value()) + .filter(|&&ndv| ndv > 0) + { + return AnalysisResult::Computed(1.0 - (1.0 / (*ndv as f64))); + } + 0.9 + } + + // Range predicates: classic 1/3 estimate + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => 0.33, + + // LIKE: depends on pattern, use conservative estimate + Operator::LikeMatch | Operator::ILikeMatch => 0.25, + Operator::NotLikeMatch | Operator::NotILikeMatch => 0.75, + + // Other operators: default + _ => 0.5, + }; + + return AnalysisResult::Computed(sel); + } + + // NOT expression: 1 - child selectivity + if let Some(not_expr) = expr.as_any().downcast_ref::() { + let child_sel = + self.estimate_selectivity_recursive(not_expr.arg(), input_stats); + return AnalysisResult::Computed(1.0 - child_sel); + } + + // Literal boolean: 0.0 or 1.0 + if let Some(b) = expr + .as_any() + .downcast_ref::() + .and_then(|lit| match lit.value() { + ScalarValue::Boolean(Some(b)) => Some(*b), + _ => None, + }) + { + return AnalysisResult::Computed(if b { 1.0 } else { 0.0 }); + } + + // Column reference as predicate (boolean column) + if expr.as_any().downcast_ref::().is_some() { + return AnalysisResult::Computed(0.5); + } + + AnalysisResult::Delegate + } + + fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + // Column reference: use column NDV + if let Some(ndv) = Self::get_column_stats(expr, input_stats) + .and_then(|col_stats| col_stats.distinct_count.get_value().copied()) + { + return AnalysisResult::Computed(ndv); + } + + // Literal: NDV = 1 + if expr.as_any().downcast_ref::().is_some() { + return AnalysisResult::Computed(1); + } + + // BinaryExpr: for injective operations (arithmetic with literal), preserve NDV + if let Some(binary) = expr.as_any().downcast_ref::() { + let is_arithmetic = matches!( + binary.op(), + Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + ); + + if is_arithmetic { + // If one side is a literal, the operation is injective on the other side + let left_is_literal = binary.left().as_any().is::(); + let right_is_literal = binary.right().as_any().is::(); + + if left_is_literal { + // NDV comes from right side + if let AnalysisResult::Computed(ndv) = + self.get_distinct_count(binary.right(), input_stats) + { + return AnalysisResult::Computed(ndv); + } + } else if right_is_literal { + // NDV comes from left side + if let AnalysisResult::Computed(ndv) = + self.get_distinct_count(binary.left(), input_stats) + { + return AnalysisResult::Computed(ndv); + } + } + // Both sides are non-literals: could combine, but delegate for now + } + } + + AnalysisResult::Delegate + } + + fn get_min_max( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult<(ScalarValue, ScalarValue)> { + // Column reference: use column min/max + if let Some((min, max)) = + Self::get_column_stats(expr, input_stats).and_then(|col_stats| { + match ( + col_stats.min_value.get_value(), + col_stats.max_value.get_value(), + ) { + (Some(min), Some(max)) => Some((min.clone(), max.clone())), + _ => None, + } + }) + { + return AnalysisResult::Computed((min, max)); + } + + // Literal: min = max = value + if let Some(lit_expr) = expr.as_any().downcast_ref::() { + let val = lit_expr.value().clone(); + return AnalysisResult::Computed((val.clone(), val)); + } + + AnalysisResult::Delegate + } + + fn get_null_fraction( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + // Column reference: null_count / num_rows + if let Some(fraction) = + Self::get_column_stats(expr, input_stats).and_then(|col_stats| { + let null_count = col_stats.null_count.get_value().copied()?; + let num_rows = input_stats.num_rows.get_value().copied()?; + if num_rows > 0 { + Some(null_count as f64 / num_rows as f64) + } else { + None + } + }) + { + return AnalysisResult::Computed(fraction); + } + + // Literal: null fraction depends on whether it's null + if let Some(lit_expr) = expr.as_any().downcast_ref::() { + let is_null = lit_expr.value().is_null(); + return AnalysisResult::Computed(if is_null { 1.0 } else { 0.0 }); + } + + AnalysisResult::Delegate + } +} + +// ============================================================================ +// StringFunctionAnalyzer +// ============================================================================ + +/// Analyzer for string functions. +/// +/// - Injective (preserve NDV): UPPER, LOWER, TRIM, LTRIM, RTRIM, REVERSE +/// - Non-injective (reduce NDV): SUBSTRING, LEFT, RIGHT, REPLACE +#[derive(Debug, Default, Clone)] +pub struct StringFunctionAnalyzer; + +impl StringFunctionAnalyzer { + /// Check if a function is injective (one-to-one) + pub fn is_injective(func_name: &str) -> bool { + matches!( + func_name.to_uppercase().as_str(), + "UPPER" | "LOWER" | "TRIM" | "LTRIM" | "RTRIM" | "REVERSE" | "INITCAP" + ) + } + + /// Get NDV reduction factor for non-injective functions + pub fn ndv_reduction_factor(func_name: &str) -> Option { + match func_name.to_uppercase().as_str() { + "SUBSTRING" | "LEFT" | "RIGHT" => Some(0.5), + "REPLACE" => Some(0.8), + _ => None, + } + } +} + +impl ExpressionAnalyzer for StringFunctionAnalyzer { + fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + let Some(func) = expr.as_any().downcast_ref::() else { + return AnalysisResult::Delegate; + }; + + let func_name = func.name(); + let Some(first_arg) = func.args().first() else { + return AnalysisResult::Delegate; + }; + + // Get input NDV + let Some(input_ndv) = DefaultExpressionAnalyzer + .get_distinct_count(first_arg, input_stats) + .into_option() + else { + return AnalysisResult::Delegate; + }; + + // Injective functions preserve NDV + if Self::is_injective(func_name) { + return AnalysisResult::Computed(input_ndv); + } + + // Non-injective functions reduce NDV + if let Some(factor) = Self::ndv_reduction_factor(func_name) { + let reduced = ((input_ndv as f64) * factor).ceil() as usize; + return AnalysisResult::Computed(reduced.max(1)); + } + + AnalysisResult::Delegate + } +} + +// ============================================================================ +// MathFunctionAnalyzer +// ============================================================================ + +/// Analyzer for mathematical functions. +/// +/// - Injective on domain: ABS (on non-negative), SQRT (on non-negative) +/// - Non-injective: FLOOR, CEIL, ROUND, SIGN +/// - Monotonic: EXP, LN, LOG +#[derive(Debug, Default, Clone)] +pub struct MathFunctionAnalyzer; + +impl MathFunctionAnalyzer { + /// Check if function is injective (preserves NDV) + pub fn is_injective(func_name: &str) -> bool { + matches!( + func_name.to_uppercase().as_str(), + "EXP" | "LN" | "LOG" | "LOG2" | "LOG10" + ) + } + + /// Get NDV reduction factor for non-injective functions + pub fn ndv_reduction_factor(func_name: &str) -> Option { + match func_name.to_uppercase().as_str() { + "FLOOR" | "CEIL" | "ROUND" | "TRUNC" => Some(0.1), + "SIGN" => Some(0.01), // Only -1, 0, 1 + "ABS" => Some(0.5), // Roughly halves NDV + _ => None, + } + } +} + +impl ExpressionAnalyzer for MathFunctionAnalyzer { + fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + let Some(func) = expr.as_any().downcast_ref::() else { + return AnalysisResult::Delegate; + }; + + let func_name = func.name(); + let Some(first_arg) = func.args().first() else { + return AnalysisResult::Delegate; + }; + + let Some(input_ndv) = DefaultExpressionAnalyzer + .get_distinct_count(first_arg, input_stats) + .into_option() + else { + return AnalysisResult::Delegate; + }; + + // Injective functions preserve NDV + if Self::is_injective(func_name) { + return AnalysisResult::Computed(input_ndv); + } + + // Non-injective functions reduce NDV + if let Some(factor) = Self::ndv_reduction_factor(func_name) { + let reduced = ((input_ndv as f64) * factor).ceil() as usize; + return AnalysisResult::Computed(reduced.max(1)); + } + + AnalysisResult::Delegate + } +} + +// ============================================================================ +// DateTimeFunctionAnalyzer +// ============================================================================ + +/// Analyzer for date/time functions. +/// +/// - Non-injective with known bounds: YEAR, MONTH, DAY, HOUR, etc. +/// - These extract components with limited cardinality +#[derive(Debug, Default, Clone)] +pub struct DateTimeFunctionAnalyzer; + +impl DateTimeFunctionAnalyzer { + /// Get maximum possible NDV for date/time extraction functions + pub fn max_ndv(func_name: &str) -> Option { + match func_name.to_uppercase().as_str() { + "YEAR" | "EXTRACT_YEAR" => None, // Unbounded, but typically < input NDV + "MONTH" | "EXTRACT_MONTH" => Some(12), + "DAY" | "EXTRACT_DAY" | "DAY_OF_MONTH" => Some(31), + "HOUR" | "EXTRACT_HOUR" => Some(24), + "MINUTE" | "EXTRACT_MINUTE" => Some(60), + "SECOND" | "EXTRACT_SECOND" => Some(60), + "DAYOFWEEK" | "DOW" | "EXTRACT_DOW" => Some(7), + "DAYOFYEAR" | "DOY" | "EXTRACT_DOY" => Some(366), + "WEEK" | "EXTRACT_WEEK" => Some(53), + "QUARTER" | "EXTRACT_QUARTER" => Some(4), + _ => None, + } + } +} + +impl ExpressionAnalyzer for DateTimeFunctionAnalyzer { + fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + let Some(func) = expr.as_any().downcast_ref::() else { + return AnalysisResult::Delegate; + }; + + let func_name = func.name(); + let Some(first_arg) = func.args().first() else { + return AnalysisResult::Delegate; + }; + + // Get max possible NDV for this function + let Some(max_ndv) = Self::max_ndv(func_name) else { + return AnalysisResult::Delegate; + }; + + // Get input NDV if available + let input_ndv = DefaultExpressionAnalyzer + .get_distinct_count(first_arg, input_stats) + .into_option(); + + // NDV is min(input_ndv, max_possible) + let result_ndv = match input_ndv { + Some(ndv) => ndv.min(max_ndv), + None => max_ndv, + }; + + AnalysisResult::Computed(result_ndv) + } +} + +// ============================================================================ +// Utility functions for filter statistics +// ============================================================================ + +/// Estimate selectivity for a filter predicate. +/// +/// This is a convenience function that uses the default analyzer chain. +/// For custom analysis, use [`ExpressionAnalyzerRegistry`] directly. +pub fn estimate_filter_selectivity( + predicate: &Arc, + input_stats: &Statistics, +) -> f64 { + ExpressionAnalyzerRegistry::with_builtin_analyzers() + .get_selectivity(predicate, input_stats) + .unwrap_or(0.5) +} + +/// Estimate NDV after applying a filter with given selectivity. +/// +/// Uses the formula: NDV_after = NDV_before * (1 - (1 - selectivity)^(num_rows / NDV_before)) +/// +/// This models the probability that at least one row with each distinct value survives. +pub fn ndv_after_selectivity( + original_ndv: usize, + original_rows: usize, + selectivity: f64, +) -> usize { + if original_ndv == 0 || original_rows == 0 || selectivity <= 0.0 { + return 0; + } + if selectivity >= 1.0 { + return original_ndv; + } + + // Average rows per distinct value + let rows_per_value = original_rows as f64 / original_ndv as f64; + + // Probability that all rows for a value are filtered out + let prob_all_filtered = (1.0 - selectivity).powf(rows_per_value); + + // Expected number of distinct values remaining + let expected_ndv = (original_ndv as f64) * (1.0 - prob_all_filtered); + + expected_ndv.ceil() as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::stats::Precision; + use std::sync::Arc; + + fn make_stats_with_ndv(num_rows: usize, ndv: usize) -> Statistics { + Statistics { + num_rows: Precision::Exact(num_rows), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Exact(ndv), + byte_size: Precision::Absent, + }], + } + } + + #[test] + fn test_default_analyzer_column_ndv() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + + let analyzer = DefaultExpressionAnalyzer; + let result = analyzer.get_distinct_count(&col, &stats); + + assert!(matches!(result, AnalysisResult::Computed(100))); + } + + #[test] + fn test_default_analyzer_literal_ndv() { + let stats = make_stats_with_ndv(1000, 100); + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + + let analyzer = DefaultExpressionAnalyzer; + let result = analyzer.get_distinct_count(&lit, &stats); + + assert!(matches!(result, AnalysisResult::Computed(1))); + } + + #[test] + fn test_default_analyzer_equality_selectivity() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let eq = + Arc::new(BinaryExpr::new(col, Operator::Eq, lit)) as Arc; + + let analyzer = DefaultExpressionAnalyzer; + let result = analyzer.get_selectivity(&eq, &stats); + + // Selectivity should be 1/NDV = 1/100 = 0.01 + match result { + AnalysisResult::Computed(sel) => { + assert!((sel - 0.01).abs() < 0.001); + } + _ => panic!("Expected Computed result"), + } + } + + #[test] + fn test_registry_chain() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + + let registry = ExpressionAnalyzerRegistry::with_builtin_analyzers(); + let ndv = registry.get_distinct_count(&col, &stats); + + assert_eq!(ndv, Some(100)); + } + + #[test] + fn test_ndv_after_selectivity() { + // 1000 rows, 100 NDV, 10% selectivity + let result = ndv_after_selectivity(100, 1000, 0.1); + // With 10 rows per value and 10% selectivity, most values should survive + assert!(result > 50 && result <= 100); + + // 100% selectivity preserves NDV + assert_eq!(ndv_after_selectivity(100, 1000, 1.0), 100); + + // 0% selectivity gives 0 NDV + assert_eq!(ndv_after_selectivity(100, 1000, 0.0), 0); + } + + #[test] + fn test_datetime_function_analyzer() { + // MONTH should have max NDV of 12 + assert_eq!(DateTimeFunctionAnalyzer::max_ndv("MONTH"), Some(12)); + assert_eq!(DateTimeFunctionAnalyzer::max_ndv("HOUR"), Some(24)); + assert_eq!(DateTimeFunctionAnalyzer::max_ndv("QUARTER"), Some(4)); + } + + #[test] + fn test_string_function_analyzer() { + assert!(StringFunctionAnalyzer::is_injective("UPPER")); + assert!(StringFunctionAnalyzer::is_injective("lower")); + assert!(!StringFunctionAnalyzer::is_injective("SUBSTRING")); + + assert_eq!( + StringFunctionAnalyzer::ndv_reduction_factor("SUBSTRING"), + Some(0.5) + ); + } + + #[test] + fn test_math_function_analyzer() { + assert!(MathFunctionAnalyzer::is_injective("EXP")); + assert!(MathFunctionAnalyzer::is_injective("LN")); + assert!(!MathFunctionAnalyzer::is_injective("FLOOR")); + + assert_eq!( + MathFunctionAnalyzer::ndv_reduction_factor("FLOOR"), + Some(0.1) + ); + assert_eq!( + MathFunctionAnalyzer::ndv_reduction_factor("SIGN"), + Some(0.01) + ); + } + + // ======================================================================== + // Tests for AND/OR/NOT logical operators + // ======================================================================== + + #[test] + fn test_and_selectivity() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + + // a = 42 AND a > 10 + let eq = Arc::new(BinaryExpr::new(Arc::clone(&col), Operator::Eq, lit1)) + as Arc; + let gt = + Arc::new(BinaryExpr::new(col, Operator::Gt, lit2)) as Arc; + let and_expr = + Arc::new(BinaryExpr::new(eq, Operator::And, gt)) as Arc; + + let registry = ExpressionAnalyzerRegistry::new(); + let sel = registry.get_selectivity(&and_expr, &stats).unwrap(); + + // AND: 0.01 * 0.33 = 0.0033 + assert!((sel - 0.0033).abs() < 0.001); + } + + #[test] + fn test_or_selectivity() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + + // a = 42 OR a > 10 + let eq = Arc::new(BinaryExpr::new(Arc::clone(&col), Operator::Eq, lit1)) + as Arc; + let gt = + Arc::new(BinaryExpr::new(col, Operator::Gt, lit2)) as Arc; + let or_expr = + Arc::new(BinaryExpr::new(eq, Operator::Or, gt)) as Arc; + + let registry = ExpressionAnalyzerRegistry::new(); + let sel = registry.get_selectivity(&or_expr, &stats).unwrap(); + + // OR: 0.01 + 0.33 - (0.01 * 0.33) = 0.3367 + assert!((sel - 0.3367).abs() < 0.001); + } + + #[test] + fn test_not_selectivity() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + + // NOT (a = 42) + let eq = + Arc::new(BinaryExpr::new(col, Operator::Eq, lit)) as Arc; + let not_expr = Arc::new(NotExpr::new(eq)) as Arc; + + let registry = ExpressionAnalyzerRegistry::new(); + let sel = registry.get_selectivity(¬_expr, &stats).unwrap(); + + // NOT: 1 - 0.01 = 0.99 + assert!((sel - 0.99).abs() < 0.001); + } + + #[test] + fn test_nested_and_or() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit3 = + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))) as Arc; + + // (a = 1 OR a = 2) AND a > 3 + let eq1 = Arc::new(BinaryExpr::new(Arc::clone(&col), Operator::Eq, lit1)) + as Arc; + let eq2 = Arc::new(BinaryExpr::new(Arc::clone(&col), Operator::Eq, lit2)) + as Arc; + let gt3 = + Arc::new(BinaryExpr::new(col, Operator::Gt, lit3)) as Arc; + + let or_expr = + Arc::new(BinaryExpr::new(eq1, Operator::Or, eq2)) as Arc; + let and_expr = Arc::new(BinaryExpr::new(or_expr, Operator::And, gt3)) + as Arc; + + let registry = ExpressionAnalyzerRegistry::new(); + let sel = registry.get_selectivity(&and_expr, &stats).unwrap(); + + // (0.01 + 0.01 - 0.0001) * 0.33 ≈ 0.0066 + assert!(sel > 0.005 && sel < 0.01); + } + + // ======================================================================== + // Tests for custom analyzer override + // ======================================================================== + + /// Custom analyzer that always returns selectivity of 0.42 for any expression + #[derive(Debug)] + struct FixedSelectivityAnalyzer(f64); + + impl ExpressionAnalyzer for FixedSelectivityAnalyzer { + fn get_selectivity( + &self, + _expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult { + AnalysisResult::Computed(self.0) + } + } + + #[test] + fn test_custom_analyzer_overrides_default() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let eq = + Arc::new(BinaryExpr::new(col, Operator::Eq, lit)) as Arc; + + // Default would give 1/100 = 0.01 + let default_registry = ExpressionAnalyzerRegistry::new(); + let default_sel = default_registry.get_selectivity(&eq, &stats).unwrap(); + assert!((default_sel - 0.01).abs() < 0.001); + + // Custom analyzer overrides to 0.42 + let mut custom_registry = ExpressionAnalyzerRegistry::new(); + custom_registry.register(Arc::new(FixedSelectivityAnalyzer(0.42))); + let custom_sel = custom_registry.get_selectivity(&eq, &stats).unwrap(); + assert!((custom_sel - 0.42).abs() < 0.001); + } + + /// Custom analyzer that only handles specific expressions + #[derive(Debug)] + struct ColumnAOnlyAnalyzer; + + impl ExpressionAnalyzer for ColumnAOnlyAnalyzer { + fn get_selectivity( + &self, + expr: &Arc, + _input_stats: &Statistics, + ) -> AnalysisResult { + // Only handle column "a" equality + if let Some(binary) = expr.as_any().downcast_ref::() + && let Some(col) = binary.left().as_any().downcast_ref::() + && col.name() == "a" + && matches!(binary.op(), Operator::Eq) + { + return AnalysisResult::Computed(0.99); // Override for col a + } + AnalysisResult::Delegate // Let default handle everything else + } + } + + #[test] + fn test_custom_analyzer_delegates_to_default() { + let stats = make_stats_with_ndv(1000, 100); + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + + let eq_a = Arc::new(BinaryExpr::new(col_a, Operator::Eq, Arc::clone(&lit))) + as Arc; + let eq_b = + Arc::new(BinaryExpr::new(col_b, Operator::Eq, lit)) as Arc; + + let mut registry = ExpressionAnalyzerRegistry::new(); + registry.register(Arc::new(ColumnAOnlyAnalyzer)); + + // Column "a" equality uses custom (0.99) + let sel_a = registry.get_selectivity(&eq_a, &stats).unwrap(); + assert!((sel_a - 0.99).abs() < 0.001); + + // Column "b" equality delegates to default (0.01) + let sel_b = registry.get_selectivity(&eq_b, &stats).unwrap(); + assert!((sel_b - 0.01).abs() < 0.001); + } + + #[test] + fn test_registry_with_no_default() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let eq = + Arc::new(BinaryExpr::new(col, Operator::Eq, lit)) as Arc; + + // Registry with only custom analyzer (no default) + let registry = ExpressionAnalyzerRegistry::with_analyzers(vec![Arc::new( + FixedSelectivityAnalyzer(0.77), + )]); + + let sel = registry.get_selectivity(&eq, &stats).unwrap(); + assert!((sel - 0.77).abs() < 0.001); + } + + #[test] + fn test_registry_with_multiple_custom_analyzers() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let eq = + Arc::new(BinaryExpr::new(col, Operator::Eq, lit)) as Arc; + + // First analyzer in chain wins + let registry = ExpressionAnalyzerRegistry::with_analyzers(vec![ + Arc::new(FixedSelectivityAnalyzer(0.11)), + Arc::new(FixedSelectivityAnalyzer(0.22)), + Arc::new(DefaultExpressionAnalyzer), + ]); + + let sel = registry.get_selectivity(&eq, &stats).unwrap(); + assert!((sel - 0.11).abs() < 0.001); // First one wins + } + + #[test] + fn test_custom_ndv_analyzer() { + /// Custom analyzer that doubles NDV + #[derive(Debug)] + struct DoubleNdvAnalyzer; + + impl ExpressionAnalyzer for DoubleNdvAnalyzer { + fn get_distinct_count( + &self, + expr: &Arc, + input_stats: &Statistics, + ) -> AnalysisResult { + // Get default NDV and double it + if let Some(ndv) = DefaultExpressionAnalyzer + .get_distinct_count(expr, input_stats) + .into_option() + { + return AnalysisResult::Computed(ndv * 2); + } + AnalysisResult::Delegate + } + } + + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + + let mut registry = ExpressionAnalyzerRegistry::new(); + registry.register(Arc::new(DoubleNdvAnalyzer)); + + let ndv = registry.get_distinct_count(&col, &stats).unwrap(); + assert_eq!(ndv, 200); // Doubled from 100 + } + + #[test] + fn test_with_analyzers_and_default() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + + // ColumnAOnlyAnalyzer only handles equality on "a", delegates NDV to default + let registry = + ExpressionAnalyzerRegistry::with_analyzers_and_default(vec![Arc::new( + ColumnAOnlyAnalyzer, + ) + as Arc]); + + // NDV should come from default (100) + let ndv = registry.get_distinct_count(&col, &stats).unwrap(); + assert_eq!(ndv, 100); + } + + #[test] + fn test_binary_expr_ndv_arithmetic() { + let stats = make_stats_with_ndv(1000, 100); + let col = Arc::new(Column::new("a", 0)) as Arc; + let lit = + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))) as Arc; + + let registry = ExpressionAnalyzerRegistry::new(); + + // col + 1: injective, preserves NDV + let plus = Arc::new(BinaryExpr::new( + Arc::clone(&col), + Operator::Plus, + Arc::clone(&lit), + )) as Arc; + assert_eq!(registry.get_distinct_count(&plus, &stats), Some(100)); + + // col - 1: injective, preserves NDV + let minus = Arc::new(BinaryExpr::new( + Arc::clone(&col), + Operator::Minus, + Arc::clone(&lit), + )) as Arc; + assert_eq!(registry.get_distinct_count(&minus, &stats), Some(100)); + + // col * 2: injective, preserves NDV + let lit2 = + Arc::new(Literal::new(ScalarValue::Int64(Some(2)))) as Arc; + let mul = Arc::new(BinaryExpr::new(Arc::clone(&col), Operator::Multiply, lit2)) + as Arc; + assert_eq!(registry.get_distinct_count(&mul, &stats), Some(100)); + + // 1 + col: also injective (literal on left) + let plus_rev = + Arc::new(BinaryExpr::new(lit, Operator::Plus, col)) as Arc; + assert_eq!(registry.get_distinct_count(&plus_rev, &stats), Some(100)); + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 988e14c28e17c..d4dea40bf3dce 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -35,6 +35,7 @@ pub mod binary_map { } pub mod async_scalar_function; pub mod equivalence; +pub mod expression_analyzer; pub mod expressions; pub mod intervals; mod partitioning; diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 540fd620c92ce..2fe22eae77ebf 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -21,6 +21,7 @@ use std::ops::Deref; use std::sync::Arc; use crate::PhysicalExpr; +use crate::expression_analyzer::ExpressionAnalyzerRegistry; use crate::expressions::{Column, Literal}; use crate::utils::collect_columns; @@ -660,9 +661,23 @@ impl ProjectionExprs { } } } else { - // TODO stats: estimate more statistics from expressions - // (expressions should compute their statistics themselves) - ColumnStatistics::new_unknown() + // Use ExpressionAnalyzer to estimate NDV for arbitrary expressions + // This handles: + // - Column references (preserves NDV) + // - Literals (NDV = 1) + // - Injective functions like UPPER(col) (preserves NDV) + // - Non-injective functions like FLOOR(col) (reduces NDV) + // - Date/time functions like MONTH(col) (bounded NDV) + let registry = ExpressionAnalyzerRegistry::with_builtin_analyzers(); + let distinct_count = registry + .get_distinct_count(expr, &stats) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent); + + ColumnStatistics { + distinct_count, + ..ColumnStatistics::new_unknown() + } }; column_statistics.push(col_stats); } @@ -2610,10 +2625,11 @@ pub(crate) mod tests { // Should have 2 column statistics assert_eq!(output_stats.column_statistics.len(), 2); - // First column (expression) should have unknown statistics + // First column (expression `col0 + 1`) preserves NDV from the single + // referenced column as a conservative upper bound (marked Inexact) assert_eq!( output_stats.column_statistics[0].distinct_count, - Precision::Absent + Precision::Inexact(5) ); assert_eq!( output_stats.column_statistics[0].max_value, diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 9352a143c11f8..1714261b2f656 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -75,6 +75,11 @@ pub mod display; pub mod empty; pub mod execution_plan; pub mod explain; + +// Re-export expression_analyzer from physical-expr for backwards compatibility +pub mod expression_analyzer { + pub use datafusion_physical_expr::expression_analyzer::*; +} pub mod filter; pub mod filter_pushdown; pub mod joins;