diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 868b417bf6c..0b4fc201df6 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -23,7 +23,7 @@ const DEFAULT_DUCKDB_VERSION: &str = "1.5.3"; const BUILD_ARTIFACTS: [&str; 3] = ["libduckdb.dylib", "libduckdb.so", "libduckdb_static.a"]; -const SOURCE_FILES: [&str; 17] = [ +const SOURCE_FILES: [&str; 18] = [ "cpp/client_context.cpp", "cpp/config.cpp", "cpp/copy_function.cpp", @@ -33,6 +33,7 @@ const SOURCE_FILES: [&str; 17] = [ "cpp/expr.cpp", "cpp/file_system.cpp", "cpp/logical_type.cpp", + "cpp/optimizer.cpp", "cpp/replacement_scan.cpp", "cpp/reusable_dict.cpp", "cpp/scalar_function.cpp", diff --git a/vortex-duckdb/cpp/include/duckdb_vx.h b/vortex-duckdb/cpp/include/duckdb_vx.h index dcad0ae1487..176b40a415a 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx.h +++ b/vortex-duckdb/cpp/include/duckdb_vx.h @@ -4,6 +4,7 @@ #pragma once #include "duckdb_vx/client_context.h" +#include "duckdb_vx/optimizer.h" #include "duckdb_vx/config.h" #include "duckdb_vx/copy_function.h" #include "duckdb_vx/data.h" diff --git a/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h new file mode 100644 index 00000000000..b7f98d236f2 --- /dev/null +++ b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#pragma once + +#include "duckdb_vx/duckdb_diagnostics.h" +DUCKDB_INCLUDES_BEGIN +#include "duckdb.h" +DUCKDB_INCLUDES_END + +#ifdef __cplusplus +extern "C" { +#endif + +duckdb_state duckdb_vx_optimizer_extension_register(duckdb_database ffi_db); + +#ifdef __cplusplus +} +#endif diff --git a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h index 4b60207a036..1524f066ae4 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h @@ -125,6 +125,7 @@ typedef struct { void (*cardinality)(void *bind_data, duckdb_vx_node_statistics *node_stats_out); + void (*pushdown_column_type)(void *bind_data, idx_t id, duckdb_logical_type type); bool (*pushdown_complex_filter)(void *bind_data, duckdb_vx_expr expr, duckdb_vx_error *error_out); void (*to_string)(void *bind_data, duckdb_vx_string_map map); diff --git a/vortex-duckdb/cpp/optimizer.cpp b/vortex-duckdb/cpp/optimizer.cpp new file mode 100644 index 00000000000..0f80bd6e720 --- /dev/null +++ b/vortex-duckdb/cpp/optimizer.cpp @@ -0,0 +1,210 @@ +#include "duckdb_vx/optimizer.h" +#include "duckdb_vx/duckdb_diagnostics.h" +DUCKDB_INCLUDES_BEGIN +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/optimizer/optimizer_extension.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +DUCKDB_INCLUDES_END + +using namespace duckdb; + +/* + * Until https://github.com/duckdb/duckdb/pull/22788 is merged, and Duckdb + * version used in Vortex is bumped to include this, we'll have our separate + * optimizer pass pushing down types to Vortex. + */ + +// Collect CAST(bound_column, T) patterns where bound_column binds into given GET's index. +// A bare bound_column ref (outside any CAST) is recorded as a conflict: the column is +// consumed at its original type and its scan type must not change. +static void CollectCastTypes(const Expression &expr, + idx_t index, + const vector &column_ids, + unordered_map &cast_map, + unordered_set &conflicts) { + auto collect_children = [&] { + ExpressionIterator::EnumerateChildren(expr, [&](const Expression &child) { + CollectCastTypes(child, index, column_ids, cast_map, conflicts); + }); + }; + + // Bare column ref pointing to this GET: the column is used at its original type. + if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { + auto &colref = expr.Cast(); + if (colref.depth == 0 && colref.binding.table_index == index) { + const column_t proj_id = colref.binding.column_index; + if (!IsVirtualColumn(proj_id) && proj_id < column_ids.size()) { + conflicts.insert(column_ids[proj_id].GetPrimaryIndex()); + } + } + return; + } + + if (expr.GetExpressionClass() != ExpressionClass::BOUND_CAST) { + return collect_children(); + } + auto &bound_cast = expr.Cast(); + + if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return collect_children(); + } + auto &bound_column = bound_cast.child->Cast(); + + if (bound_column.depth > 0 || bound_column.binding.table_index != index) { + return collect_children(); + } + + // We are in a leaf: CAST(colref, T) where colref binds into this GET. + const column_t projection_id = bound_column.binding.column_index; + if (IsVirtualColumn(projection_id)) { + return; + } + D_ASSERT(projection_id < column_ids.size()); + const column_t column_id = column_ids[projection_id].GetPrimaryIndex(); + if (auto it = cast_map.find(column_id); it == cast_map.end()) { + cast_map.emplace(column_id, bound_cast.return_type); + } else if (it->second != bound_cast.return_type) { + conflicts.insert(column_id); + } +} + +// Replace every CAST(bound_column, T) with a bare bound_column at type T when T +// is listed in projection_cast. +static void ReplaceCastTypes(unique_ptr &expr, + idx_t index, + const unordered_map &projection_cast) { + auto replace_children = [&] { + ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { + ReplaceCastTypes(child, index, projection_cast); + }); + }; + + if (expr->GetExpressionClass() != ExpressionClass::BOUND_CAST) { + return replace_children(); + } + auto &bound_cast = expr->Cast(); + + if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return replace_children(); + } + auto &bound_column = bound_cast.child->Cast(); + + if (bound_column.depth > 0 || bound_column.binding.table_index != index) { + return replace_children(); + } + + const column_t projection_id = bound_column.binding.column_index; + auto it = projection_cast.find(projection_id); + if (it == projection_cast.end() || it->second != bound_cast.return_type) { + return replace_children(); + } + + expr = make_uniq(it->second, bound_column.binding); +} + +// Collect cast-type candidates from every operator in the plan tree. +static void CollectFromPlan(LogicalOperator &op, + idx_t index, + const vector &column_ids, + unordered_map &cast_map, + unordered_set &conflicts) { + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *expr_ptr) { + CollectCastTypes(**expr_ptr, index, column_ids, cast_map, conflicts); + }); + for (auto &child : op.children) { + CollectFromPlan(*child, index, column_ids, cast_map, conflicts); + } +} + +// Replace cast expressions in every operator in the plan tree. +static void +ReplaceInPlan(LogicalOperator &op, idx_t index, const unordered_map &proj_to_type) { + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *expr_ptr) { + ReplaceCastTypes(*expr_ptr, index, proj_to_type); + }); + for (auto &child : op.children) { + ReplaceInPlan(*child, index, proj_to_type); + } +} + +static void FindGetWithTypePushdown(LogicalOperator &op, vector &gets) { + if (op.type == LogicalOperatorType::LOGICAL_GET) { + auto &get = op.Cast(); + if (get.function.type_pushdown) { + gets.push_back(&get); + } + } + for (auto &child : op.children) { + FindGetWithTypePushdown(*child, gets); + } +} + +// For each GET that supports type_pushdown, collect CAST(col, T) patterns from +// the *entire* plan. Columns that appear bare (outside any cast) or are cast to +// multiple conflicting types are excluded. The surviving types are pushed into +// the GET's bind_data and returned_types, and the redundant CASTs are stripped +// from all operator expressions throughout the plan. +static unique_ptr TryPushdownCastTypes(ClientContext &context, + unique_ptr plan) { + vector gets; + FindGetWithTypePushdown(*plan, gets); + + for (LogicalGet *get : gets) { + const vector &column_ids = get->GetColumnIds(); + const idx_t index = get->table_index; + unordered_map cast_map; + unordered_set conflicts; + + CollectFromPlan(*plan, index, column_ids, cast_map, conflicts); + + for (column_t col_id : conflicts) { + cast_map.erase(col_id); + } + if (cast_map.empty()) { + continue; + } + + get->function.type_pushdown(context, get->bind_data, cast_map); + for (const auto &[col_id, new_type] : cast_map) { + get->returned_types[col_id] = new_type; + } + + unordered_map proj_to_type; + for (idx_t i = 0; i < column_ids.size(); i++) { + const column_t col_idx = column_ids[i].GetPrimaryIndex(); + if (auto it = cast_map.find(col_idx); it != cast_map.end()) { + proj_to_type[i] = it->second; + } + } + + ReplaceInPlan(*plan, index, proj_to_type); + } + + return plan; +} + +static void VortexOptimizeFunction(OptimizerExtensionInput &input, unique_ptr &plan) { + plan = TryPushdownCastTypes(input.context, std::move(plan)); +} + +struct VortexOptimizerExtension final : OptimizerExtension { + VortexOptimizerExtension() : OptimizerExtension(VortexOptimizeFunction, nullptr, {}) { + } +}; + +extern "C" duckdb_state duckdb_vx_optimizer_extension_register(duckdb_database ffi_db) { + D_ASSERT(ffi_db); + const DatabaseWrapper &wrapper = *reinterpret_cast(ffi_db); + DatabaseInstance &db = *wrapper.database->instance; + try { + DBConfig::GetConfig(db).GetCallbackManager().Register(VortexOptimizerExtension()); + } catch (const std::exception &e) { + ErrorData data(e); + DUCKDB_LOG_ERROR(db, "Failed to create Vortex optimizer extension:\t" + data.Message()); + return DuckDBError; + } + return DuckDBSuccess; +} diff --git a/vortex-duckdb/cpp/table_function.cpp b/vortex-duckdb/cpp/table_function.cpp index 92d989d30d4..41e5a77a7f4 100644 --- a/vortex-duckdb/cpp/table_function.cpp +++ b/vortex-duckdb/cpp/table_function.cpp @@ -263,6 +263,20 @@ void function(ClientContext &, TableFunctionInput &input, DataChunk &output) { } } +void type_pushdown(ClientContext &, optional_ptr bind_data, + const unordered_map &new_column_types) { + const auto &bind = bind_data->Cast(); + void *const ffi_bind = bind.ffi_data->DataPtr(); + for (const auto& [id, type] : new_column_types) { + const duckdb_logical_type casted_type = reinterpret_cast( + // This is a flaw of duckdb api which doesn't allow passing const + // LogicalTypes. We guarantee this variable won't be mutated on + // Rust side + const_cast(&type)); + bind.info.vtab.pushdown_column_type(ffi_bind, id, casted_type); + } +} + void c_pushdown_complex_filter(ClientContext &, LogicalGet &, FunctionData *bind_data, @@ -394,6 +408,7 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d tf.filter_prune = true; tf.sampling_pushdown = false; + tf.type_pushdown = type_pushdown; tf.pushdown_complex_filter = c_pushdown_complex_filter; tf.cardinality = c_cardinality; tf.get_partition_info = get_partition_info; diff --git a/vortex-duckdb/src/convert/dtype.rs b/vortex-duckdb/src/convert/dtype.rs index f1f8ae0de9b..bfd6f20f6f6 100644 --- a/vortex-duckdb/src/convert/dtype.rs +++ b/vortex-duckdb/src/convert/dtype.rs @@ -284,6 +284,14 @@ impl TryFrom<&DType> for LogicalType { } } +impl TryFrom<&LogicalTypeRef> for DType { + type Error = VortexError; + + fn try_from(value: &LogicalTypeRef) -> Result { + DType::from_logical_type(value, Nullability::Nullable) + } +} + impl TryFrom for LogicalType { type Error = VortexError; diff --git a/vortex-duckdb/src/datasource.rs b/vortex-duckdb/src/datasource.rs index a1fa15a1b40..50a2cef88f8 100644 --- a/vortex-duckdb/src/datasource.rs +++ b/vortex-duckdb/src/datasource.rs @@ -30,13 +30,16 @@ use vortex::array::arrays::scalar_fn::ScalarFnArrayExt; use vortex::array::optimizer::ArrayOptimizer; use vortex::array::stats::StatsSet; use vortex::dtype::DType; -use vortex::dtype::FieldNames; +use vortex::dtype::FieldName; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_err; +use vortex::error::vortex_panic; use vortex::expr::Expression; use vortex::expr::and_collect; +use vortex::expr::cast; use vortex::expr::col; +use vortex::expr::get_item; use vortex::expr::merge; use vortex::expr::pack; use vortex::expr::root; @@ -77,6 +80,7 @@ use crate::duckdb::DataChunkRef; use crate::duckdb::DuckdbStringMapRef; use crate::duckdb::ExpressionRef; use crate::duckdb::LogicalType; +use crate::duckdb::LogicalTypeRef; use crate::duckdb::PartitionData; use crate::duckdb::TableFilterClass; use crate::duckdb::TableFilterSetRef; @@ -116,6 +120,7 @@ struct DuckdbField { name: String, logical_type: LogicalType, dtype: DType, + casted: bool, } /// Bind data produced by a [`DataSourceTableFunction`]. @@ -344,6 +349,7 @@ impl TableFunction for T { .has_non_optional_filter .store(true, Ordering::Relaxed); } + println!("{}", projection.display_tree()); debug!( %projection, @@ -632,6 +638,27 @@ impl TableFunction for T { map.push("Filters", &filters.join("\n")); } } + + fn pushdown_column_type( + bind_data: &mut Self::BindData, + column_id: u64, + new_type: &LogicalTypeRef, + ) { + // TODO virtual column count? + let column_id: usize = column_id.as_(); + if column_id >= bind_data.column_fields.len() { + vortex_panic!("column_id {column_id} >= {}", bind_data.column_fields.len()); + } + let field = &mut bind_data.column_fields[column_id]; + let old_dtype = field.dtype.clone(); + // TODO we don't need a copy? + field.logical_type = LogicalType::new(new_type.as_type_id()); + field.dtype = new_type + .try_into() + .vortex_expect("logical type -> dtype conversion failed"); + println!("Cast {} -> {}", old_dtype, field.dtype); + field.casted = true; + } } /// Extracts DuckDB column names and logical types from a Vortex struct DType. @@ -649,6 +676,7 @@ fn extract_schema_from_dtype(dtype: &DType) -> VortexResult> { name: field_name.to_string(), logical_type, dtype: field_dtype, + casted: false, }); } Ok(fields) @@ -660,6 +688,17 @@ struct ProjectionWithVirtualColumns { file_row_number_column_pos: Option, } +fn with_file_row_number(select: Expression, has_file_row_number_column: bool) -> Expression { + // file_index column will be filled later when exporting the chunk. + if has_file_row_number_column { + // row_idx will be moved to correct position in scan(), prepend here + let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); + merge([row_idx_struct, select]) + } else { + select + } +} + fn extract_projection_expr( projection_ids: Option<&[u64]>, column_ids: &[u64], @@ -678,7 +717,9 @@ fn extract_projection_expr( let mut real_column_count = 0; // DuckDB uses u64 as column indices but Rust uses usize - for (column_pos, &column_id) in ids.iter().enumerate() { + for (column_pos, (&column_id, column_field)) in + ids.iter().zip(column_fields.as_ref()).enumerate() + { let column_id = if has_projection_ids { let column_id: usize = column_id.as_(); column_ids[column_id] @@ -695,6 +736,7 @@ fn extract_projection_expr( continue; } + is_star &= !column_field.casted; // In SELECT * DuckDB requests all columns from 0 to column_fields in // increasing order. After removing virtual columns, compare column_id // with (0..column_fields.len()) range. @@ -705,38 +747,59 @@ fn extract_projection_expr( // 5 columns total. is_star &= real_column_count == column_fields.len() as u64; - let select = if is_star { - root() - } else { - let names = ids - .iter() - .map(|&column_id| { - if has_projection_ids { - let column_id: usize = column_id.as_(); - column_ids[column_id] - } else { - column_id - } - }) - .filter(|&col_id| !is_virtual_column(col_id)) - .map(|column_id| { - let column_id: usize = column_id.as_(); - Arc::from(column_fields[column_id].name.as_str()) - }) - .collect::(); + if is_star { + let projection = with_file_row_number(root(), file_row_number_column_pos.is_some()); + return ProjectionWithVirtualColumns { + projection, + file_index_column_pos, + file_row_number_column_pos, + }; + } - select(names, root()) - }; + let mut uncasted_fields = Vec::new(); + let mut casted_fields = Vec::new(); - // file_index column will be filled later when exporting the chunk. - let projection = if file_row_number_column_pos.is_some() { - // row_idx will be moved to correct position in scan(), prepend here - let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); - merge([row_idx_struct, select]) + for &column_id in ids { + let column_id = if has_projection_ids { + let column_id: usize = column_id.as_(); + column_ids[column_id] + } else { + column_id + }; + if is_virtual_column(column_id) { + continue; + } + let column_id: usize = column_id.as_(); + let column_field = &column_fields[column_id]; + let field_name = FieldName::from(Arc::from(column_fields[column_id].name.as_str())); + println!("field: {column_field:?}"); + if column_field.casted { + casted_fields.push((field_name, column_field.dtype.clone())); + } else { + uncasted_fields.push(field_name); + } + } + + println!("casted fields: {casted_fields:?}"); + + let projection = if casted_fields.is_empty() { + select(uncasted_fields, root()) } else { - select + let mut fields = Vec::new(); + for (field, dtype) in casted_fields { + println!("casted {field:?} to {dtype}"); + fields.push((field.clone(), cast(get_item(field, root()), dtype))); + } + if uncasted_fields.is_empty() { + pack(fields, false.into()) + } else { + let select = select(uncasted_fields, root()); + merge([pack(fields, false.into()), select]) + } }; + let projection = with_file_row_number(projection, file_row_number_column_pos.is_some()); + ProjectionWithVirtualColumns { projection, file_index_column_pos, @@ -849,21 +912,24 @@ mod tests { #[test] fn test_select_star() { let ids = [0, 1, 2]; - let fields = [ + let mut fields = [ DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + casted: false, }, DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + casted: false, }, DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + casted: false, }, ]; @@ -904,5 +970,11 @@ mod tests { extract_projection_expr(None, &ids, &fields).projection, root() ); + + fields[0].casted = true; + assert_ne!( + extract_projection_expr(None, &[0, 1, 2], &fields).projection, + root() + ); } } diff --git a/vortex-duckdb/src/duckdb/table_function/mod.rs b/vortex-duckdb/src/duckdb/table_function/mod.rs index 986ac64d100..a817fd53df5 100644 --- a/vortex-duckdb/src/duckdb/table_function/mod.rs +++ b/vortex-duckdb/src/duckdb/table_function/mod.rs @@ -21,6 +21,7 @@ use crate::duckdb::DataChunk; use crate::duckdb::DatabaseRef; use crate::duckdb::Expression; use crate::duckdb::LogicalType; +use crate::duckdb::LogicalTypeRef; use crate::duckdb::Value; use crate::duckdb::client_context::ClientContextRef; use crate::duckdb::data_chunk::DataChunkRef; @@ -103,6 +104,12 @@ pub trait TableFunction: Sized + Debug { /// Return table scanning progress from 0. to 100. fn table_scan_progress(global_state: &Self::GlobalState) -> f64; + fn pushdown_column_type( + bind_data: &mut Self::BindData, + column_id: u64, + new_type: &LogicalTypeRef + ); + /// Pushes down a filter expression to the table function. /// /// Returns `true` if the filter was successfully pushed down (and stored on the bind data), @@ -138,6 +145,14 @@ pub enum Cardinality { } impl DatabaseRef { + pub fn register_optimizer_extension(&self) -> VortexResult<()> { + duckdb_try!( + unsafe { cpp::duckdb_vx_optimizer_extension_register(self.as_ptr()) }, + "Failed to register optimizer extension" + ); + Ok(()) + } + pub fn register_table_function(&self, name: &CStr) -> VortexResult<()> { // Set up the parameters. let parameters = T::parameters(); @@ -157,6 +172,7 @@ impl DatabaseRef { function: Some(function::), statistics: Some(statistics::), cardinality: Some(cardinality_callback::), + pushdown_column_type: Some(pushdown_column_type::), pushdown_complex_filter: Some(pushdown_complex_filter_callback::), to_string: Some(to_string_callback::), table_scan_progress: Some(table_scan_progress_callback::), @@ -225,6 +241,17 @@ unsafe extern "C-unwind" fn get_partition_data_callback( out.file_index = data.file_index; } +unsafe extern "C-unwind" fn pushdown_column_type( + bind_data: *mut c_void, + column_id: u64, + new_type: cpp::duckdb_logical_type +) { + let bind_data = + unsafe { bind_data.cast::().as_mut() }.vortex_expect("bind_data null pointer"); + let new_type = unsafe { LogicalType::borrow(new_type) }; + T::pushdown_column_type(bind_data, column_id, new_type); +} + unsafe extern "C-unwind" fn pushdown_complex_filter_callback( bind_data: *mut c_void, expr: cpp::duckdb_vx_expr, diff --git a/vortex-duckdb/src/lib.rs b/vortex-duckdb/src/lib.rs index 410d241a766..d07d685bc5b 100644 --- a/vortex-duckdb/src/lib.rs +++ b/vortex-duckdb/src/lib.rs @@ -70,6 +70,7 @@ pub fn initialize(db: &DatabaseRef) -> VortexResult<()> { LogicalType::varchar(), Value::from("vortex"), )?; + db.register_optimizer_extension()?; db.register_table_function::(c"vortex_scan")?; db.register_table_function::(c"read_vortex")?; // Register list overloads for multi-glob scanning (e.g., read_vortex(['a.vortex', 'b.vortex']))