diff --git a/Cargo.lock b/Cargo.lock index 8bf9cdb4..457c4106 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -844,16 +844,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" -[[package]] -name = "erased-serde" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e004d887f51fcb9fef17317a2f3525c887d8aa3f4f50fed920816a688284a5b7" -dependencies = [ - "serde", - "typeid", -] - [[package]] name = "errno" version = "0.3.13" @@ -1212,15 +1202,6 @@ dependencies = [ "str_stack", ] -[[package]] -name = "inventory" -version = "0.3.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" -dependencies = [ - "rustversion", -] - [[package]] name = "io-uring" version = "0.7.8" @@ -1374,7 +1355,6 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", - "typetag", "ulid", "wasm-bindgen", "wasm-bindgen-test", @@ -1555,7 +1535,6 @@ dependencies = [ "serde", "sqlparser", "tempfile", - "typetag", ] [[package]] @@ -2659,8 +2638,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", - "recursive", - "serde", ] [[package]] @@ -2948,7 +2925,6 @@ dependencies = [ "rand 0.8.5", "rust_decimal", "sqlite", - "sqlparser", "thiserror 1.0.69", ] @@ -2983,42 +2959,12 @@ dependencies = [ "once_cell", ] -[[package]] -name = "typeid" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" - [[package]] name = "typenum" version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" -[[package]] -name = "typetag" -version = "0.2.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73f22b40dd7bfe8c14230cf9702081366421890435b2d625fa92b4acc4c3de6f" -dependencies = [ - "erased-serde", - "inventory", - "once_cell", - "serde", - "typetag-impl", -] - -[[package]] -name = "typetag-impl" -version = "0.2.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35f5380909ffc31b4de4f4bdf96b877175a016aa2ca98cee39fcfd8c4d53d952" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - [[package]] name = "ulid" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index 0f17570d..e19e269c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,9 +66,8 @@ rust_decimal = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } kite_sql_serde_macros = { version = "0.2.0", path = "kite_sql_serde_macros" } siphasher = { version = "1", features = ["serde"] } -sqlparser = { version = "0.61", features = ["serde"] } +sqlparser = { version = "0.61", default-features = false, features = ["std"] } thiserror = { version = "1" } -typetag = { version = "0.2" } ulid = { version = "1", features = ["serde"] } # Feature: net diff --git a/kite_sql_serde_macros/src/reference_serialization.rs b/kite_sql_serde_macros/src/reference_serialization.rs index 2a688ed1..81dc9114 100644 --- a/kite_sql_serde_macros/src/reference_serialization.rs +++ b/kite_sql_serde_macros/src/reference_serialization.rs @@ -142,7 +142,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { fn decode( reader: &mut R, - drive: Option<(&T, &crate::storage::TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &crate::serdes::ReferenceTables, ) -> Result { #(#decode_fields)* @@ -221,7 +221,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { fn decode( reader: &mut R, - drive: Option<(&T, &crate::storage::TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &crate::serdes::ReferenceTables, ) -> Result { let mut type_bytes = [0u8; 1]; diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 8d5d6cd7..6ff89a0c 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -187,9 +187,9 @@ mod tests { use crate::catalog::ColumnDesc; use crate::storage::rocksdb::RocksStorage; use crate::storage::Storage; + use crate::types::CharLengthUnits; use crate::types::LogicalType; use crate::utils::lru::SharedLruCache; - use sqlparser::ast::CharLengthUnits; use std::hash::RandomState; use std::sync::atomic::AtomicUsize; use tempfile::TempDir; diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 72d96a83..51d04a72 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -18,8 +18,8 @@ use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, CharLengthUnits, DataType, DuplicateTreatment, Expr, Function, FunctionArg, - FunctionArgExpr, FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value, + BinaryOperator, DataType, DuplicateTreatment, Expr, Function, FunctionArg, FunctionArgExpr, + FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value, }; use std::borrow::Cow; use std::collections::HashMap; @@ -39,7 +39,7 @@ use crate::planner::{LogicalPlan, SchemaOutput}; use crate::storage::Transaction; use crate::types::tuple::SchemaRef; use crate::types::value::{DataValue, Utf8Type}; -use crate::types::{ColumnId, LogicalType}; +use crate::types::{CharLengthUnits, ColumnId, LogicalType}; macro_rules! try_alias { ($context:expr, $full_name:expr) => { @@ -242,7 +242,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Ok(ScalarExpression::Trim { expr: Box::new(self.bind_expr(expr)?), trim_what_expr, - trim_where: *trim_where, + trim_where: trim_where.map(Into::into), }) } Expr::Exists { subquery, negated } => { diff --git a/src/binder/mod.rs b/src/binder/mod.rs index ec8697ea..92bb0e12 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -340,11 +340,21 @@ impl<'a, T: Transaction> BinderContext<'a, T> { pub fn view(&self, view_name: TableName) -> Result, DatabaseError> { if let Some(real_name) = self.table_aliases.get(view_name.as_ref()) { - self.transaction - .view(self.table_cache, self.view_cache, real_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + real_name.clone(), + ) } else { - self.transaction - .view(self.table_cache, self.view_cache, view_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + view_name.clone(), + ) } } @@ -367,11 +377,21 @@ impl<'a, T: Transaction> BinderContext<'a, T> { if source.is_none() && !only_table { source = if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { - self.transaction - .view(self.table_cache, self.view_cache, real_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + real_name.clone(), + ) } else { - self.transaction - .view(self.table_cache, self.view_cache, table_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + table_name.clone(), + ) }? .map(Source::View); } @@ -396,11 +416,21 @@ impl<'a, T: Transaction> BinderContext<'a, T> { if source.is_none() { source = if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { - self.transaction - .view(self.table_cache, self.view_cache, real_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + real_name.clone(), + ) } else { - self.transaction - .view(self.table_cache, self.view_cache, table_name.clone()) + self.transaction.view( + self.table_cache, + self.view_cache, + self.scala_functions, + self.table_functions, + table_name.clone(), + ) }? .map(Source::View); } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 510f915c..356f0f06 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -17,9 +17,9 @@ use crate::errors::DatabaseError; use crate::expression::ScalarExpression; use crate::types::tuple::Tuple; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::{ColumnId, LogicalType}; use kite_sql_serde_macros::ReferenceSerialization; -use sqlparser::ast::CharLengthUnits; use std::hash::Hash; use std::ops::Deref; use std::sync::Arc; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index dd7eb0be..d1045e7b 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -349,7 +349,7 @@ mod tests { "name".into(), true, ColumnDesc::new( - LogicalType::Varchar(None, sqlparser::ast::CharLengthUnits::Characters), + LogicalType::Varchar(None, crate::types::CharLengthUnits::Characters), None, false, None, diff --git a/src/db.rs b/src/db.rs index d4975118..2fd8138e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -540,7 +540,13 @@ impl State { let root = build_write( &mut arena, plan, - (&self.table_cache, &self.view_cache, &self.meta_cache), + ( + &self.table_cache, + &self.view_cache, + &self.meta_cache, + &self.scala_functions, + &self.table_functions, + ), transaction, ); let executor = Executor::new(arena, root); diff --git a/src/execution/ddl/drop_view.rs b/src/execution/ddl/drop_view.rs index f2c771c1..04f80fdd 100644 --- a/src/execution/ddl/drop_view.rs +++ b/src/execution/ddl/drop_view.rs @@ -53,11 +53,10 @@ impl DropView { return Ok(()); }; - let table_cache = arena.table_cache(); let view_cache = arena.view_cache(); arena .transaction_mut() - .drop_view(view_cache, table_cache, view_name.clone(), if_exists)?; + .drop_view(view_cache, view_name.clone(), if_exists)?; TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{view_name}")); arena.resume(); diff --git a/src/execution/dml/analyze.rs b/src/execution/dml/analyze.rs index 6a6226e7..62370858 100644 --- a/src/execution/dml/analyze.rs +++ b/src/execution/dml/analyze.rs @@ -24,8 +24,8 @@ use crate::planner::LogicalPlan; use crate::storage::{StatisticsMetaCache, Transaction}; use crate::types::index::IndexId; use crate::types::value::{DataValue, Utf8Type}; +use crate::types::CharLengthUnits; use itertools::Itertools; -use sqlparser::ast::CharLengthUnits; use std::fmt::{self, Formatter}; use std::sync::Arc; diff --git a/src/execution/dml/copy_from_file.rs b/src/execution/dml/copy_from_file.rs index d81f0f53..7a478ceb 100644 --- a/src/execution/dml/copy_from_file.rs +++ b/src/execution/dml/copy_from_file.rs @@ -158,8 +158,8 @@ mod tests { use crate::db::DataBaseBuilder; use crate::errors::DatabaseError; use crate::storage::Storage; + use crate::types::CharLengthUnits; use crate::types::LogicalType; - use sqlparser::ast::CharLengthUnits; use std::io::Write; use std::sync::Arc; use tempfile::TempDir; diff --git a/src/execution/dml/copy_to_file.rs b/src/execution/dml/copy_to_file.rs index c1d587ab..79dc6879 100644 --- a/src/execution/dml/copy_to_file.rs +++ b/src/execution/dml/copy_to_file.rs @@ -118,8 +118,8 @@ mod tests { use crate::errors::DatabaseError; use crate::planner::operator::table_scan::TableScanOperator; use crate::storage::Storage; + use crate::types::CharLengthUnits; use crate::types::LogicalType; - use sqlparser::ast::CharLengthUnits; use std::sync::Arc; use tempfile::TempDir; use ulid::Ulid; diff --git a/src/execution/dql/aggregate/avg.rs b/src/execution/dql/aggregate/avg.rs index 1457b341..400a2fa3 100644 --- a/src/execution/dql/aggregate/avg.rs +++ b/src/execution/dql/aggregate/avg.rs @@ -71,6 +71,6 @@ impl Accumulator for AvgAccumulator { value = value.cast(&quantity_ty)? } let evaluator = binary_create(Cow::Owned(quantity_ty), BinaryOperator::Divide)?; - evaluator.0.binary_eval(&value, &quantity) + evaluator.binary_eval(&value, &quantity) } } diff --git a/src/execution/dql/aggregate/min_max.rs b/src/execution/dql/aggregate/min_max.rs index 163594d7..49f9a38a 100644 --- a/src/execution/dql/aggregate/min_max.rs +++ b/src/execution/dql/aggregate/min_max.rs @@ -41,7 +41,7 @@ impl Accumulator for MinMaxAccumulator { if !value.is_null() { if let Some(inner_value) = &self.inner { let evaluator = binary_create(Cow::Owned(value.logical_type()), self.op)?; - if let DataValue::Boolean(result) = evaluator.0.binary_eval(inner_value, value)? { + if let DataValue::Boolean(result) = evaluator.binary_eval(inner_value, value)? { result } else { return Err(DatabaseError::InvalidType); diff --git a/src/execution/dql/aggregate/sum.rs b/src/execution/dql/aggregate/sum.rs index 39cfe55b..7bb6375c 100644 --- a/src/execution/dql/aggregate/sum.rs +++ b/src/execution/dql/aggregate/sum.rs @@ -44,7 +44,7 @@ impl Accumulator for SumAccumulator { if self.result.is_null() { self.result = value.clone(); } else { - self.result = self.evaluator.0.binary_eval(&self.result, value)?; + self.result = self.evaluator.binary_eval(&self.result, value)?; } } diff --git a/src/execution/dql/describe.rs b/src/execution/dql/describe.rs index 666c3721..c1732a59 100644 --- a/src/execution/dql/describe.rs +++ b/src/execution/dql/describe.rs @@ -18,7 +18,7 @@ use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNod use crate::planner::operator::describe::DescribeOperator; use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; -use sqlparser::ast::CharLengthUnits; +use crate::types::CharLengthUnits; use std::sync::LazyLock; static PRIMARY_KEY_TYPE: LazyLock = LazyLock::new(|| DataValue::Utf8 { diff --git a/src/execution/dql/explain.rs b/src/execution/dql/explain.rs index c6277b2d..dcffb297 100644 --- a/src/execution/dql/explain.rs +++ b/src/execution/dql/explain.rs @@ -17,7 +17,7 @@ use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNod use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; -use sqlparser::ast::CharLengthUnits; +use crate::types::CharLengthUnits; pub struct Explain { plan: LogicalPlan, diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 45e862c5..35305e6d 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -189,7 +189,13 @@ impl NestedLoopJoin { &mut self, arena: &mut ExecArena<'a, T>, ) -> ExecId { - let cache = (arena.table_cache(), arena.view_cache(), arena.meta_cache()); + let cache = ( + arena.table_cache(), + arena.view_cache(), + arena.meta_cache(), + arena.scala_functions(), + arena.table_functions(), + ); let transaction = arena.transaction_mut() as *mut T; // Fixme: Executor reset build_read(arena, self.right_input_plan.clone(), cache, transaction) @@ -447,6 +453,7 @@ mod test { use crate::db::DataBaseBuilder; use crate::execution::dql::test::build_integers; use crate::execution::try_collect; + use crate::expression::BinaryOperator; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::normalization::NormalizationRuleImpl; @@ -591,7 +598,11 @@ mod test { ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), 3, )), - evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), + evaluator: Some(BinaryEvaluatorBox::new( + Arc::new(Int32GtBinaryEvaluator), + LogicalType::Integer, + BinaryOperator::Gt, + )), ty: LogicalType::Boolean, }; diff --git a/src/execution/dql/mark_apply.rs b/src/execution/dql/mark_apply.rs index 4ffdf494..8dfa8342 100644 --- a/src/execution/dql/mark_apply.rs +++ b/src/execution/dql/mark_apply.rs @@ -103,7 +103,13 @@ impl MarkApply { arena.push_runtime_probe(runtime_probe); } - let cache = (arena.table_cache(), arena.view_cache(), arena.meta_cache()); + let cache = ( + arena.table_cache(), + arena.view_cache(), + arena.meta_cache(), + arena.scala_functions(), + arena.table_functions(), + ); let transaction = arena.transaction_mut() as *mut T; let result = { let right_input = build_read(arena, self.right_input_plan.clone(), cache, transaction); diff --git a/src/execution/dql/show_table.rs b/src/execution/dql/show_table.rs index 14196691..8f655ecb 100644 --- a/src/execution/dql/show_table.rs +++ b/src/execution/dql/show_table.rs @@ -12,27 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::TableMeta; use crate::errors::DatabaseError; use crate::execution::ExecArena; -use crate::storage::Transaction; +use crate::storage::{TableIter, Transaction}; use crate::types::value::{DataValue, Utf8Type}; -use sqlparser::ast::CharLengthUnits; +use crate::types::CharLengthUnits; -pub struct ShowTables { - pub(crate) metas: Option>, +pub struct ShowTables<'a, T: Transaction + 'a> { + pub(crate) metas: Option>, } -impl ShowTables { - pub(crate) fn next_tuple<'a, T: Transaction>( - &mut self, - arena: &mut ExecArena<'a, T>, - ) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ShowTables<'a, T> { + pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { if self.metas.is_none() { - self.metas = Some(arena.transaction_mut().table_metas()?.into_iter()); + self.metas = Some(arena.transaction().tables()?); } - let Some(TableMeta { table_name }) = self.metas.as_mut().and_then(|metas| metas.next()) + let Some(table) = self + .metas + .as_mut() + .expect("show tables iterator initialized") + .try_next()? else { arena.finish(); return Ok(()); @@ -42,7 +42,7 @@ impl ShowTables { output.pk = None; output.values.clear(); output.values.push(DataValue::Utf8 { - value: table_name.to_string(), + value: table.table_name.to_string(), ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, }); diff --git a/src/execution/dql/show_view.rs b/src/execution/dql/show_view.rs index 6d80fa2f..f9c14f08 100644 --- a/src/execution/dql/show_view.rs +++ b/src/execution/dql/show_view.rs @@ -12,32 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::view::View; use crate::errors::DatabaseError; use crate::execution::ExecArena; -use crate::storage::Transaction; +use crate::storage::{Transaction, ViewIter}; use crate::types::value::{DataValue, Utf8Type}; -use sqlparser::ast::CharLengthUnits; +use crate::types::CharLengthUnits; -pub struct ShowViews { - pub(crate) metas: Option>, +pub struct ShowViews<'a, T: Transaction + 'a> { + pub(crate) metas: Option>, } -impl ShowViews { - pub(crate) fn next_tuple<'a, T: Transaction>( - &mut self, - arena: &mut ExecArena<'a, T>, - ) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ShowViews<'a, T> { + pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { if self.metas.is_none() { - self.metas = Some( - arena - .transaction_mut() - .views(arena.table_cache())? - .into_iter(), - ); + self.metas = Some(arena.transaction().views( + arena.table_cache(), + arena.scala_functions(), + arena.table_functions(), + )?); } - let Some(View { name, .. }) = self.metas.as_mut().and_then(|metas| metas.next()) else { + let Some(view) = self + .metas + .as_mut() + .expect("show views iterator initialized") + .try_next()? + else { arena.finish(); return Ok(()); }; @@ -46,7 +46,7 @@ impl ShowViews { output.pk = None; output.values.clear(); output.values.push(DataValue::Utf8 { - value: name.to_string(), + value: view.name.to_string(), ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, }); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 8df214d6..ab8b60e6 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -21,6 +21,7 @@ use self::ddl::change_column::ChangeColumn; use self::dql::join::nested_loop_join::NestedLoopJoin; use self::dql::mark_apply::MarkApply; use self::dql::scalar_apply::ScalarApply; +use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::execution::ddl::create_index::CreateIndex; use crate::execution::ddl::create_table::CreateTable; @@ -65,7 +66,23 @@ use crate::types::index::RuntimeIndexProbe; use crate::types::tuple::Tuple; use crate::types::value::DataValue; -pub(crate) type ExecutionCaches<'a> = (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache); +pub(crate) type ExecutionCaches<'a> = ( + &'a TableCache, + &'a ViewCache, + &'a StatisticsMetaCache, + &'a ScalaFunctions, + &'a TableFunctions, +); + +pub(crate) trait IntoExecutionCaches<'a> { + fn into_execution_caches(self) -> ExecutionCaches<'a>; +} + +impl<'a> IntoExecutionCaches<'a> for ExecutionCaches<'a> { + fn into_execution_caches(self) -> ExecutionCaches<'a> { + self + } +} pub(crate) type ExecId = usize; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -142,8 +159,8 @@ pub(crate) enum ExecNode<'a, T: Transaction + 'a> { ScalarSubquery(ScalarSubquery), SetMembership(SetMembership), SeqScan(SeqScan<'a, T>), - ShowTables(ShowTables), - ShowViews(ShowViews), + ShowTables(ShowTables<'a, T>), + ShowViews(ShowViews<'a, T>), SimpleAgg(SimpleAggExecutor), Sort(Sort), StreamDistinct(StreamDistinctExecutor), @@ -242,10 +259,10 @@ impl<'a, T: Transaction + 'a> ExecNode<'a, T> { as ExecutorNode<'a, T>>::next_tuple(exec, arena) } ExecNode::ShowTables(exec) => { - >::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena) } ExecNode::ShowViews(exec) => { - >::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena) } ExecNode::SimpleAgg(exec) => { >::next_tuple(exec, arena) @@ -287,11 +304,17 @@ impl<'a, T: Transaction + 'a> Default for ExecArena<'a, T> { } impl<'a, T: Transaction + 'a> ExecArena<'a, T> { - pub(crate) fn init_context(&mut self, cache: ExecutionCaches<'a>, transaction: *mut T) { + pub(crate) fn init_context(&mut self, cache: C, transaction: *mut T) + where + C: IntoExecutionCaches<'a>, + { + let cache = cache.into_execution_caches(); if let Some(current) = self.cache { debug_assert!(std::ptr::eq(current.0, cache.0)); debug_assert!(std::ptr::eq(current.1, cache.1)); debug_assert!(std::ptr::eq(current.2, cache.2)); + debug_assert!(std::ptr::eq(current.3, cache.3)); + debug_assert!(std::ptr::eq(current.4, cache.4)); debug_assert_eq!(self.transaction, transaction); } else { self.cache = Some(cache); @@ -317,6 +340,14 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { self.cache.expect("execution arena context initialized").2 } + pub(crate) fn scala_functions(&self) -> &'a ScalaFunctions { + self.cache.expect("execution arena context initialized").3 + } + + pub(crate) fn table_functions(&self) -> &'a TableFunctions { + self.cache.expect("execution arena context initialized").4 + } + pub(crate) fn transaction(&self) -> &'a T { unsafe { &*self.transaction } } @@ -554,7 +585,7 @@ impl_write_executor_node_via_from!( ) ); -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowTables { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowTables<'a, T> { type Input = Self; fn into_executor( @@ -571,7 +602,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowTables { } } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowViews { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowViews<'a, T> { type Input = Self; fn into_executor( @@ -758,13 +789,13 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( Operator::Values(op) => { >::into_executor(op, arena, cache, transaction) } - Operator::ShowTable => >::into_executor( + Operator::ShowTable => as ExecutorNode<'a, T>>::into_executor( ShowTables { metas: None }, arena, cache, transaction, ), - Operator::ShowView => >::into_executor( + Operator::ShowView => as ExecutorNode<'a, T>>::into_executor( ShowViews { metas: None }, arena, cache, @@ -897,77 +928,122 @@ pub(crate) fn build_write<'a, T: Transaction + 'a>( } #[cfg(all(test, not(target_arch = "wasm32")))] -pub(crate) fn execute<'a, T, E>( - executor: E, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> Executor<'a, T> -where - T: Transaction + 'a, - E: ReadExecutor<'a, T>, -{ - let mut arena = ExecArena::default(); - arena.init_context(cache, transaction); - let root = executor.into_executor(&mut arena, cache, transaction); - Executor::new(arena, root) -} +mod test_utils { + use super::*; + + static EMPTY_SCALA_FUNCTIONS: std::sync::LazyLock = + std::sync::LazyLock::new(ScalaFunctions::default); + static EMPTY_TABLE_FUNCTIONS: std::sync::LazyLock = + std::sync::LazyLock::new(TableFunctions::default); + + impl<'a> IntoExecutionCaches<'a> for (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache) { + fn into_execution_caches(self) -> ExecutionCaches<'a> { + ( + self.0, + self.1, + self.2, + &EMPTY_SCALA_FUNCTIONS, + &EMPTY_TABLE_FUNCTIONS, + ) + } + } -#[cfg(all(test, not(target_arch = "wasm32")))] -pub(crate) fn execute_mut<'a, T, E>( - executor: E, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> Executor<'a, T> -where - T: Transaction + 'a, - E: WriteExecutor<'a, T>, -{ - let mut arena = ExecArena::default(); - arena.init_context(cache, transaction); - let root = executor.into_executor(&mut arena, cache, transaction); - Executor::new(arena, root) -} + impl<'a> IntoExecutionCaches<'a> + for ( + &'a std::sync::Arc, + &'a std::sync::Arc, + &'a std::sync::Arc, + ) + { + fn into_execution_caches(self) -> ExecutionCaches<'a> { + ( + self.0.as_ref(), + self.1.as_ref(), + self.2.as_ref(), + &EMPTY_SCALA_FUNCTIONS, + &EMPTY_TABLE_FUNCTIONS, + ) + } + } -#[cfg(all(test, not(target_arch = "wasm32")))] -pub(crate) fn execute_input<'a, T, E>( - input: E::Input, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> Executor<'a, T> -where - T: Transaction + 'a, - E: ExecutorNode<'a, T>, -{ - let mut arena = ExecArena::default(); - arena.init_context(cache, transaction); - let root = E::into_executor(input, &mut arena, cache, transaction); - Executor::new(arena, root) -} + pub(crate) fn execute<'a, T, E>( + executor: E, + cache: impl IntoExecutionCaches<'a>, + transaction: *mut T, + ) -> Executor<'a, T> + where + T: Transaction + 'a, + E: ReadExecutor<'a, T>, + { + let cache = cache.into_execution_caches(); + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = executor.into_executor(&mut arena, cache, transaction); + Executor::new(arena, root) + } -#[cfg(all(test, not(target_arch = "wasm32")))] -#[allow(dead_code)] -pub(crate) fn execute_input_mut<'a, T, E>( - input: E::Input, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> Executor<'a, T> -where - T: Transaction + 'a, - E: ExecutorNode<'a, T>, -{ - let mut arena = ExecArena::default(); - arena.init_context(cache, transaction); - let root = E::into_executor(input, &mut arena, cache, transaction); - Executor::new(arena, root) -} + pub(crate) fn execute_mut<'a, T, E>( + executor: E, + cache: impl IntoExecutionCaches<'a>, + transaction: *mut T, + ) -> Executor<'a, T> + where + T: Transaction + 'a, + E: WriteExecutor<'a, T>, + { + let cache = cache.into_execution_caches(); + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = executor.into_executor(&mut arena, cache, transaction); + Executor::new(arena, root) + } -#[cfg(all(test, not(target_arch = "wasm32")))] -pub fn try_collect(executor: Executor<'_, T>) -> Result, DatabaseError> { - let mut executor = executor; - let mut tuples = Vec::new(); + pub(crate) fn execute_input<'a, T, E>( + input: E::Input, + cache: impl IntoExecutionCaches<'a>, + transaction: *mut T, + ) -> Executor<'a, T> + where + T: Transaction + 'a, + E: ExecutorNode<'a, T>, + { + let cache = cache.into_execution_caches(); + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = E::into_executor(input, &mut arena, cache, transaction); + Executor::new(arena, root) + } + + #[allow(dead_code)] + pub(crate) fn execute_input_mut<'a, T, E>( + input: E::Input, + cache: impl IntoExecutionCaches<'a>, + transaction: *mut T, + ) -> Executor<'a, T> + where + T: Transaction + 'a, + E: ExecutorNode<'a, T>, + { + let cache = cache.into_execution_caches(); + let mut arena = ExecArena::default(); + arena.init_context(cache, transaction); + let root = E::into_executor(input, &mut arena, cache, transaction); + Executor::new(arena, root) + } + + pub fn try_collect( + executor: Executor<'_, T>, + ) -> Result, DatabaseError> { + let mut executor = executor; + let mut tuples = Vec::new(); - while let Some(tuple) = executor.next_tuple()? { - tuples.push(tuple.clone()); + while let Some(tuple) = executor.next_tuple()? { + tuples.push(tuple.clone()); + } + Ok(tuples) } - Ok(tuples) } + +#[cfg(all(test, not(target_arch = "wasm32")))] +#[allow(unused_imports)] +pub(crate) use test_utils::{execute, execute_input, execute_input_mut, execute_mut, try_collect}; diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 62cae29c..69f93d0e 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -14,13 +14,12 @@ use crate::errors::DatabaseError; use crate::expression::function::scala::ScalarFunction; -use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; +use crate::expression::{AliasType, BinaryOperator, ScalarExpression, TrimWhereField}; use crate::types::evaluator::binary_create; use crate::types::tuple::TupleLike; use crate::types::value::{DataValue, Utf8Type}; -use crate::types::LogicalType; +use crate::types::{CharLengthUnits, LogicalType}; use regex::Regex; -use sqlparser::ast::{CharLengthUnits, TrimWhereField}; use std::borrow::Cow; use std::cmp; use std::cmp::Ordering; @@ -80,7 +79,6 @@ impl ScalarExpression { evaluator .as_ref() .ok_or(DatabaseError::EvaluatorNotFound)? - .0 .binary_eval(&left, &right) } ScalarExpression::IsNull { expr, negated } => { @@ -131,7 +129,6 @@ impl ScalarExpression { Ok(evaluator .as_ref() .ok_or(DatabaseError::EvaluatorNotFound)? - .0 .unary_eval(&value)) } ScalarExpression::AggCall { .. } => { @@ -341,7 +338,6 @@ impl ScalarExpression { when_value = when_value.cast(&ty)?; let evaluator = binary_create(Cow::Owned(ty), BinaryOperator::Eq)?; evaluator - .0 .binary_eval(operand_value, &when_value)? .is_true()? } else { diff --git a/src/expression/function/mod.rs b/src/expression/function/mod.rs index e1c0c3ef..f8f71a77 100644 --- a/src/expression/function/mod.rs +++ b/src/expression/function/mod.rs @@ -13,13 +13,14 @@ // limitations under the License. use crate::types::LogicalType; +use kite_sql_serde_macros::ReferenceSerialization; use serde::{Deserialize, Serialize}; use std::sync::Arc; pub mod scala; pub mod table; -#[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize, Deserialize)] +#[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize, Deserialize, ReferenceSerialization)] pub struct FunctionSummary { pub name: Arc, pub arg_types: Vec, diff --git a/src/expression/function/scala.rs b/src/expression/function/scala.rs index a61c7344..74cac768 100644 --- a/src/expression/function/scala.rs +++ b/src/expression/function/scala.rs @@ -19,7 +19,6 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::LogicalType; use kite_sql_serde_macros::ReferenceSerialization; -use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::ops::Deref; @@ -31,7 +30,7 @@ use std::sync::Arc; /// - `Some(false)` monotonically decreasing pub type FuncMonotonicity = Vec>; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct ArcScalarFunctionImpl(pub Arc); impl Deref for ArcScalarFunctionImpl { @@ -61,8 +60,6 @@ impl Hash for ScalarFunction { self.summary().hash(state); } } - -#[typetag::serde(tag = "scala")] pub trait ScalarFunctionImpl: Debug + Send + Sync { fn eval( &self, diff --git a/src/expression/function/table.rs b/src/expression/function/table.rs index 7fc1cd23..3c991a68 100644 --- a/src/expression/function/table.rs +++ b/src/expression/function/table.rs @@ -18,13 +18,12 @@ use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; use crate::types::tuple::{SchemaRef, Tuple}; use kite_sql_serde_macros::ReferenceSerialization; -use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct ArcTableFunctionImpl(pub Arc); impl Deref for ArcTableFunctionImpl { @@ -54,8 +53,6 @@ impl Hash for TableFunction { self.summary().hash(state); } } - -#[typetag::serde(tag = "table")] pub trait TableFunctionImpl: Debug + Send + Sync { fn eval( &self, diff --git a/src/expression/mod.rs b/src/expression/mod.rs index a313fc43..984e33f2 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -24,13 +24,10 @@ use crate::types::evaluator::{ UnaryEvaluatorBox, }; use crate::types::value::DataValue; -use crate::types::LogicalType; +use crate::types::{CharLengthUnits, LogicalType}; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; -use sqlparser::ast::TrimWhereField; -use sqlparser::ast::{ - BinaryOperator as SqlBinaryOperator, CharLengthUnits, UnaryOperator as SqlUnaryOperator, -}; +use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use std::borrow::Cow; use std::fmt::{Debug, Formatter}; use std::hash::Hash; @@ -44,6 +41,23 @@ pub mod simplify; pub mod visitor; pub mod visitor_mut; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub enum TrimWhereField { + Both, + Leading, + Trailing, +} + +impl From for TrimWhereField { + fn from(value: sqlparser::ast::TrimWhereField) -> Self { + match value { + sqlparser::ast::TrimWhereField::Both => Self::Both, + sqlparser::ast::TrimWhereField::Leading => Self::Leading, + sqlparser::ast::TrimWhereField::Trailing => Self::Trailing, + } + } +} + #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub enum AliasType { Name(String), @@ -845,23 +859,29 @@ impl TryFrom for BinaryOperator { mod test { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; use crate::db::test::build_table; + use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::agg::AggKind; - use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; - use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; + use crate::expression::function::scala::{ + ArcScalarFunctionImpl, ScalarFunction, ScalarFunctionImpl, + }; + use crate::expression::function::table::{ + ArcTableFunctionImpl, TableFunction, TableFunctionImpl, + }; + use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; use crate::function::current_date::CurrentDate; use crate::function::numbers::Numbers; - use crate::serdes::{ReferenceSerialization, ReferenceTables}; + use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; - use crate::storage::{Storage, TableCache, Transaction}; + use crate::storage::{Storage, Transaction}; use crate::types::evaluator::boolean::BooleanNotUnaryEvaluator; use crate::types::evaluator::int32::Int32PlusBinaryEvaluator; use crate::types::evaluator::{cast_create, BinaryEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::{DataValue, Utf8Type}; + use crate::types::CharLengthUnits; use crate::types::LogicalType; use crate::utils::lru::SharedLruCache; - use sqlparser::ast::{CharLengthUnits, TrimWhereField}; use std::borrow::Cow; use std::hash::RandomState; use std::io::{Cursor, Seek, SeekFrom}; @@ -905,7 +925,7 @@ mod test { fn fn_assert( cursor: &mut Cursor>, expr: ScalarExpression, - drive: Option<(&RocksTransaction, &TableCache)>, + drive: Option<&ReferenceDecodeContext<'_, RocksTransaction>>, reference_tables: &mut ReferenceTables, ) -> Result<(), DatabaseError> { expr.encode(cursor, false, reference_tables)?; @@ -924,7 +944,12 @@ mod test { let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - + let mut scala_functions = ScalaFunctions::default(); + let current_date = CurrentDate::new(); + scala_functions.insert(current_date.summary().clone(), current_date); + let mut table_functions = TableFunctions::default(); + let numbers = Numbers::new(); + table_functions.insert(numbers.summary().clone(), numbers); build_table(&table_cache, &mut transaction)?; let mut cursor = Cursor::new(Vec::new()); @@ -935,17 +960,22 @@ mod test { .unwrap(); *table.get_column_id_by_name("c3").unwrap() }; + let context = ReferenceDecodeContext::with_functions( + Some((&transaction, &table_cache)), + &scala_functions, + &table_functions, + ); fn_assert( &mut cursor, ScalarExpression::Constant(DataValue::Null), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( &mut cursor, ScalarExpression::Constant(DataValue::Int32(42)), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -955,7 +985,7 @@ mod test { ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, }), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -976,7 +1006,7 @@ mod test { )), 0, ), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -993,7 +1023,7 @@ mod test { )), 1, ), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1002,7 +1032,7 @@ mod test { expr: Box::new(ScalarExpression::Empty), alias: AliasType::Name("Hello".to_string()), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1011,7 +1041,7 @@ mod test { expr: Box::new(ScalarExpression::Empty), alias: AliasType::Expr(Box::new(ScalarExpression::Empty)), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1024,7 +1054,7 @@ mod test { Cow::Owned(LogicalType::Integer), )?), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1033,7 +1063,7 @@ mod test { negated: true, expr: Box::new(ScalarExpression::Empty), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1041,10 +1071,14 @@ mod test { ScalarExpression::Unary { op: UnaryOperator::Plus, expr: Box::new(ScalarExpression::Empty), - evaluator: Some(UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator))), + evaluator: Some(UnaryEvaluatorBox::new( + Arc::new(BooleanNotUnaryEvaluator), + LogicalType::Boolean, + UnaryOperator::Not, + )), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1055,7 +1089,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1064,10 +1098,14 @@ mod test { op: BinaryOperator::Plus, left_expr: Box::new(ScalarExpression::Empty), right_expr: Box::new(ScalarExpression::Empty), - evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32PlusBinaryEvaluator))), + evaluator: Some(BinaryEvaluatorBox::new( + Arc::new(Int32PlusBinaryEvaluator), + LogicalType::Integer, + BinaryOperator::Plus, + )), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1079,7 +1117,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1090,7 +1128,7 @@ mod test { args: vec![ScalarExpression::Empty], ty: LogicalType::Double, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1100,7 +1138,7 @@ mod test { expr: Box::new(ScalarExpression::Empty), args: vec![ScalarExpression::Empty], }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1111,7 +1149,7 @@ mod test { left_expr: Box::new(ScalarExpression::Empty), right_expr: Box::new(ScalarExpression::Empty), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1121,7 +1159,7 @@ mod test { for_expr: Some(Box::new(ScalarExpression::Empty)), from_expr: Some(Box::new(ScalarExpression::Empty)), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1131,7 +1169,7 @@ mod test { for_expr: None, from_expr: Some(Box::new(ScalarExpression::Empty)), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1141,7 +1179,7 @@ mod test { for_expr: None, from_expr: None, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1150,7 +1188,7 @@ mod test { expr: Box::new(ScalarExpression::Empty), in_expr: Box::new(ScalarExpression::Empty), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1160,7 +1198,7 @@ mod test { trim_what_expr: Some(Box::new(ScalarExpression::Empty)), trim_where: Some(TrimWhereField::Both), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1170,7 +1208,7 @@ mod test { trim_what_expr: None, trim_where: Some(TrimWhereField::Both), }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1180,19 +1218,19 @@ mod test { trim_what_expr: None, trim_where: None, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( &mut cursor, ScalarExpression::Empty, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( &mut cursor, ScalarExpression::Tuple(vec![ScalarExpression::Empty]), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1201,7 +1239,7 @@ mod test { args: vec![ScalarExpression::Empty], inner: ArcScalarFunctionImpl(CurrentDate::new()), }), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1210,7 +1248,7 @@ mod test { args: vec![ScalarExpression::Empty], inner: ArcTableFunctionImpl(Numbers::new()), }), - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1221,7 +1259,7 @@ mod test { right_expr: Box::new(ScalarExpression::Empty), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1231,7 +1269,7 @@ mod test { right_expr: Box::new(ScalarExpression::Empty), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1241,7 +1279,7 @@ mod test { right_expr: Box::new(ScalarExpression::Empty), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1250,7 +1288,7 @@ mod test { exprs: vec![ScalarExpression::Empty], ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1261,7 +1299,7 @@ mod test { else_expr: Some(Box::new(ScalarExpression::Empty)), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1272,7 +1310,7 @@ mod test { else_expr: Some(Box::new(ScalarExpression::Empty)), ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; fn_assert( @@ -1283,7 +1321,7 @@ mod test { else_expr: None, ty: LogicalType::Integer, }, - Some((&transaction, &table_cache)), + Some(&context), &mut reference_tables, )?; diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 78592181..52fee043 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -59,11 +59,9 @@ impl VisitorMut<'_> for ConstantCalculator { if let ScalarExpression::Constant(unary_val) = arg_expr.as_ref() { let value = if let Some(evaluator) = evaluator { - evaluator.0.unary_eval(unary_val) + evaluator.unary_eval(unary_val) } else { - unary_create(Cow::Borrowed(ty), *op)? - .0 - .unary_eval(unary_val) + unary_create(Cow::Borrowed(ty), *op)?.unary_eval(unary_val) }; let _ = mem::replace(expr, ScalarExpression::Constant(value)); } @@ -89,7 +87,7 @@ impl VisitorMut<'_> for ConstantCalculator { *left_val = mem::replace(left_val, DataValue::Null).cast(&ty)?; *right_val = mem::replace(right_val, DataValue::Null).cast(&ty)?; - let value = evaluator.0.binary_eval(left_val, right_val)?; + let value = evaluator.binary_eval(left_val, right_val)?; let _ = mem::replace(expr, ScalarExpression::Constant(value)); } } @@ -129,9 +127,9 @@ impl VisitorMut<'_> for Simplify { let child_expr = arg_expr.as_ref().clone(); let value = if let Some(value) = arg_expr.unpack_val() { Some(if let Some(evaluator) = evaluator { - evaluator.0.unary_eval(&value) + evaluator.unary_eval(&value) } else { - unary_create(Cow::Borrowed(&ty), op)?.0.unary_eval(&value) + unary_create(Cow::Borrowed(&ty), op)?.unary_eval(&value) }) } else { None @@ -565,11 +563,10 @@ impl ScalarExpression { } => { let value = expr.unpack_val()?; let unary_value = if let Some(evaluator) = evaluator { - evaluator.0.unary_eval(&value) + evaluator.unary_eval(&value) } else { unary_create(Cow::Borrowed(ty), *op) .ok()? - .0 .unary_eval(&value) }; Some(unary_value) @@ -587,11 +584,10 @@ impl ScalarExpression { left = left.cast(ty).ok()?; right = right.cast(ty).ok()?; if let Some(evaluator) = evaluator { - evaluator.0.binary_eval(&left, &right) + evaluator.binary_eval(&left, &right) } else { binary_create(Cow::Borrowed(ty), *op) .ok()? - .0 .binary_eval(&left, &right) } .ok() diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index b024e99c..562077e2 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -17,11 +17,11 @@ use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; +use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; -use sqlparser::ast::TrimWhereField; pub trait Visitor<'a>: Sized { fn visit(&mut self, expr: &'a ScalarExpression) -> Result<(), DatabaseError> { diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index 4619d86e..e15b20b6 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -17,11 +17,11 @@ use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; +use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; -use sqlparser::ast::TrimWhereField; pub(crate) struct PositionShift { pub(crate) delta: isize, diff --git a/src/function/char_length.rs b/src/function/char_length.rs index d284a549..2940f769 100644 --- a/src/function/char_length.rs +++ b/src/function/char_length.rs @@ -19,10 +19,10 @@ use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; use crate::types::tuple::TupleLike; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use serde::Deserialize; use serde::Serialize; -use sqlparser::ast::CharLengthUnits; use std::sync::Arc; #[derive(Debug, Serialize, Deserialize)] @@ -41,8 +41,6 @@ impl CharLength { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for CharLength { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/current_date.rs b/src/function/current_date.rs index 52b867c9..99bc308b 100644 --- a/src/function/current_date.rs +++ b/src/function/current_date.rs @@ -43,8 +43,6 @@ impl CurrentDate { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for CurrentDate { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/current_timestamp.rs b/src/function/current_timestamp.rs index e5c00c76..3fac94c7 100644 --- a/src/function/current_timestamp.rs +++ b/src/function/current_timestamp.rs @@ -43,8 +43,6 @@ impl CurrentTimeStamp { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for CurrentTimeStamp { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/lower.rs b/src/function/lower.rs index 2995aae9..021e1e31 100644 --- a/src/function/lower.rs +++ b/src/function/lower.rs @@ -19,10 +19,10 @@ use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; use crate::types::tuple::TupleLike; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use serde::Deserialize; use serde::Serialize; -use sqlparser::ast::CharLengthUnits; use std::sync::Arc; #[derive(Debug, Serialize, Deserialize)] @@ -43,8 +43,6 @@ impl Lower { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for Lower { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/numbers.rs b/src/function/numbers.rs index e3d3cdfd..b5ad1cdf 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -58,8 +58,6 @@ impl Numbers { }) } } - -#[typetag::serde] impl TableFunctionImpl for Numbers { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/octet_length.rs b/src/function/octet_length.rs index e970c336..af38e5ef 100644 --- a/src/function/octet_length.rs +++ b/src/function/octet_length.rs @@ -19,10 +19,10 @@ use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; use crate::types::tuple::TupleLike; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use serde::Deserialize; use serde::Serialize; -use sqlparser::ast::CharLengthUnits; use std::sync::Arc; #[derive(Debug, Serialize, Deserialize)] @@ -42,8 +42,6 @@ impl OctetLength { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for OctetLength { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/function/upper.rs b/src/function/upper.rs index 00bb44ec..417991a9 100644 --- a/src/function/upper.rs +++ b/src/function/upper.rs @@ -19,10 +19,10 @@ use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; use crate::types::tuple::TupleLike; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use serde::Deserialize; use serde::Serialize; -use sqlparser::ast::CharLengthUnits; use std::sync::Arc; #[derive(Debug, Serialize, Deserialize)] @@ -43,8 +43,6 @@ impl Upper { }) } } - -#[typetag::serde] impl ScalarFunctionImpl for Upper { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval( diff --git a/src/macros/mod.rs b/src/macros/mod.rs index f18a43c8..bf373a27 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -102,10 +102,7 @@ macro_rules! scala_function { } }) } - } - - #[typetag::serde] - impl ::kite_sql::expression::function::scala::ScalarFunctionImpl for $struct_name { + } impl ::kite_sql::expression::function::scala::ScalarFunctionImpl for $struct_name { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval(&self, args: &[::kite_sql::expression::ScalarExpression], tuple: Option<&dyn ::kite_sql::types::tuple::TupleLike>) -> Result<::kite_sql::types::value::DataValue, ::kite_sql::errors::DatabaseError> { let mut _index = 0; @@ -187,10 +184,7 @@ macro_rules! table_function { } }) } - } - - #[typetag::serde] - impl ::kite_sql::expression::function::table::TableFunctionImpl for $struct_name { + } impl ::kite_sql::expression::function::table::TableFunctionImpl for $struct_name { #[allow(unused_variables, clippy::redundant_closure_call)] fn eval(&self, args: &[::kite_sql::expression::ScalarExpression]) -> Result>>, ::kite_sql::errors::DatabaseError> { let mut _index = 0; diff --git a/src/optimizer/core/cm_sketch.rs b/src/optimizer/core/cm_sketch.rs index fcc3c03f..56063875 100644 --- a/src/optimizer/core/cm_sketch.rs +++ b/src/optimizer/core/cm_sketch.rs @@ -15,7 +15,7 @@ use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use crate::types::value::DataValue; use kite_sql_serde_macros::ReferenceSerialization; use siphasher::sip::SipHasher13; @@ -352,7 +352,7 @@ impl ReferenceSerialization for CountMinSketch { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let counters = Vec::>::decode(reader, drive, reference_tables)?; diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index b627666d..e68a85a2 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -23,9 +23,9 @@ use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::types::value::{DataValue, Utf8Type}; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use bumpalo::Bump; -use sqlparser::ast::CharLengthUnits; use std::collections::HashSet; #[derive(Clone)] diff --git a/src/orm/mod.rs b/src/orm/mod.rs index 8e903d42..23f61fae 100644 --- a/src/orm/mod.rs +++ b/src/orm/mod.rs @@ -9,6 +9,7 @@ use crate::errors::DatabaseError; use crate::storage::{Storage, Transaction}; use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use rust_decimal::Decimal; @@ -16,11 +17,11 @@ use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::TruncateTableTarget; use sqlparser::ast::{ AlterColumnOperation, AlterTable, AlterTableOperation, Analyze, Assignment, AssignmentTarget, - BinaryOperator as SqlBinaryOperator, CaseWhen, CastKind, CharLengthUnits, ColumnDef, - ColumnOption, ColumnOptionDef, CreateIndex, CreateTable, CreateTableOptions, CreateView, - DataType, Delete, DescribeAlias, Distinct, Expr, FromTable, Function, FunctionArg, - FunctionArgExpr, FunctionArgumentList, FunctionArguments, GroupByExpr, HiveDistributionStyle, - Ident, IndexColumn, Insert, Join, JoinConstraint, JoinOperator, KeyOrIndexDisplay, LimitClause, + BinaryOperator as SqlBinaryOperator, CaseWhen, CastKind, ColumnDef, ColumnOption, + ColumnOptionDef, CreateIndex, CreateTable, CreateTableOptions, CreateView, DataType, Delete, + DescribeAlias, Distinct, Expr, FromTable, Function, FunctionArg, FunctionArgExpr, + FunctionArgumentList, FunctionArguments, GroupByExpr, HiveDistributionStyle, Ident, + IndexColumn, Insert, Join, JoinConstraint, JoinOperator, KeyOrIndexDisplay, LimitClause, NullsDistinctOption, ObjectName, ObjectType, Offset, OffsetRows, OrderBy, OrderByExpr, OrderByKind, OrderByOptions, PrimaryKeyConstraint, Query, Select, SelectFlavor, SelectItem, SetExpr, SetOperator, SetQuantifier, ShowStatementOptions, TableAlias, TableFactor, diff --git a/src/serdes/boolean.rs b/src/serdes/boolean.rs index 5ddc4b2b..fcbb3b8f 100644 --- a/src/serdes/boolean.rs +++ b/src/serdes/boolean.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for bool { @@ -29,7 +29,7 @@ impl ReferenceSerialization for bool { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { Ok(u8::decode(reader, drive, reference_tables)? == 1u8) diff --git a/src/serdes/bound.rs b/src/serdes/bound.rs index 82a9f137..825f2b3d 100644 --- a/src/serdes/bound.rs +++ b/src/serdes/bound.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; use std::ops::Bound; @@ -49,7 +49,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let mut type_bytes = [0u8; 1]; diff --git a/src/serdes/btree_map.rs b/src/serdes/btree_map.rs index df3e4d54..7689ea93 100644 --- a/src/serdes/btree_map.rs +++ b/src/serdes/btree_map.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::collections::BTreeMap; use std::io::{Read, Write}; @@ -39,7 +39,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let len = ::decode(reader, drive, reference_tables)?; diff --git a/src/serdes/char.rs b/src/serdes/char.rs index f126cfbd..c6dcc103 100644 --- a/src/serdes/char.rs +++ b/src/serdes/char.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for char { @@ -32,7 +32,7 @@ impl ReferenceSerialization for char { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let mut buf = [0u8; 2]; diff --git a/src/serdes/char_length_units.rs b/src/serdes/char_length_units.rs index f09e8e9c..b0b4036a 100644 --- a/src/serdes/char_length_units.rs +++ b/src/serdes/char_length_units.rs @@ -14,8 +14,8 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; -use sqlparser::ast::CharLengthUnits; +use crate::storage::Transaction; +use crate::types::CharLengthUnits; use std::io::{Read, Write}; impl ReferenceSerialization for CharLengthUnits { @@ -36,7 +36,7 @@ impl ReferenceSerialization for CharLengthUnits { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let mut one_byte = [0u8; 1]; @@ -55,7 +55,7 @@ pub(crate) mod test { use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; - use sqlparser::ast::CharLengthUnits; + use crate::types::CharLengthUnits; use std::io::{Cursor, Seek, SeekFrom}; #[test] diff --git a/src/serdes/column.rs b/src/serdes/column.rs index 5ce8caae..cd214cc5 100644 --- a/src/serdes/column.rs +++ b/src/serdes/column.rs @@ -14,8 +14,8 @@ use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; use crate::errors::DatabaseError; -use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; use crate::types::ColumnId; use std::io::{Read, Write}; use std::sync::Arc; @@ -48,7 +48,7 @@ impl ReferenceSerialization for ColumnRef { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let summary = ColumnSummary::decode(reader, drive, reference_tables)?; @@ -61,8 +61,10 @@ impl ReferenceSerialization for ColumnRef { is_temp: false, }, Some((transaction, table_cache)), - ) = (&summary.relation, drive) - { + ) = ( + &summary.relation, + drive.and_then(ReferenceDecodeContext::drive), + ) { let table = transaction .table(table_cache, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; @@ -123,7 +125,7 @@ impl ReferenceSerialization for ColumnRelation { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let mut type_bytes = [0u8; 1]; @@ -161,8 +163,7 @@ pub(crate) mod test { use crate::db::test::build_table; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; - use crate::serdes::ReferenceSerialization; - use crate::serdes::ReferenceTables; + use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; use crate::storage::{StatisticsMetaCache, Storage, Transaction}; use crate::types::value::DataValue; @@ -213,19 +214,23 @@ pub(crate) mod test { cursor.seek(SeekFrom::Start(0))?; assert_eq!( - ColumnRef::decode::>>( - &mut cursor, - Some((&transaction, &table_cache)), - &reference_tables - )?, + { + let context = ReferenceDecodeContext::new(Some((&transaction, &table_cache))); + ColumnRef::decode::>>( + &mut cursor, + Some(&context), + &reference_tables, + )? + }, ref_column ); cursor.seek(SeekFrom::Start(0))?; transaction.drop_column(&table_cache, &meta_cache, &table_name, "c3")?; + let context = ReferenceDecodeContext::new(Some((&transaction, &table_cache))); assert!(ColumnRef::decode::>>( &mut cursor, - Some((&transaction, &table_cache)), + Some(&context), &reference_tables ) .is_err()); diff --git a/src/serdes/evaluator.rs b/src/serdes/evaluator.rs index 71b5e4ff..42b84310 100644 --- a/src/serdes/evaluator.rs +++ b/src/serdes/evaluator.rs @@ -12,9 +12,80 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::implement_serialization_by_bincode; -use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; +use crate::errors::DatabaseError; +use crate::expression::{BinaryOperator, UnaryOperator}; +use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; +use crate::types::evaluator::{ + binary_create, cast_create, unary_create, BinaryEvaluatorBox, CastEvaluatorBox, + UnaryEvaluatorBox, +}; +use crate::types::LogicalType; +use std::borrow::Cow; +use std::io::{Read, Write}; -implement_serialization_by_bincode!(UnaryEvaluatorBox); -implement_serialization_by_bincode!(BinaryEvaluatorBox); -implement_serialization_by_bincode!(CastEvaluatorBox); +impl ReferenceSerialization for UnaryEvaluatorBox { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + self.ty.encode(writer, is_direct, reference_tables)?; + self.op.encode(writer, is_direct, reference_tables) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + ) -> Result { + let ty = LogicalType::decode(reader, context, reference_tables)?; + let op = UnaryOperator::decode(reader, context, reference_tables)?; + unary_create(Cow::Owned(ty), op) + } +} + +impl ReferenceSerialization for BinaryEvaluatorBox { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + self.ty.encode(writer, is_direct, reference_tables)?; + self.op.encode(writer, is_direct, reference_tables) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + ) -> Result { + let ty = LogicalType::decode(reader, context, reference_tables)?; + let op = BinaryOperator::decode(reader, context, reference_tables)?; + binary_create(Cow::Owned(ty), op) + } +} + +impl ReferenceSerialization for CastEvaluatorBox { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + self.from.encode(writer, is_direct, reference_tables)?; + self.to.encode(writer, is_direct, reference_tables) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + ) -> Result { + let from = LogicalType::decode(reader, context, reference_tables)?; + let to = LogicalType::decode(reader, context, reference_tables)?; + cast_create(Cow::Owned(from), Cow::Owned(to)) + } +} diff --git a/src/serdes/function.rs b/src/serdes/function.rs index c38ee206..9e45627a 100644 --- a/src/serdes/function.rs +++ b/src/serdes/function.rs @@ -12,9 +12,76 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::errors::DatabaseError; use crate::expression::function::scala::ArcScalarFunctionImpl; use crate::expression::function::table::ArcTableFunctionImpl; -use crate::implement_serialization_by_bincode; +use crate::expression::function::FunctionSummary; +use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; +use std::io::{Read, Write}; -implement_serialization_by_bincode!(ArcScalarFunctionImpl); -implement_serialization_by_bincode!(ArcTableFunctionImpl); +impl ReferenceSerialization for ArcScalarFunctionImpl { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + self.summary().encode(writer, is_direct, reference_tables) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + ) -> Result { + let summary = FunctionSummary::decode(reader, context, reference_tables)?; + let Some(functions) = context.and_then(ReferenceDecodeContext::scala_functions) else { + return Err(DatabaseError::InvalidValue(format!( + "scalar function decode context missing for {}", + summary.name + ))); + }; + let Some(function) = functions.get(&summary) else { + return Err(DatabaseError::InvalidValue(format!( + "scalar function not found when decoding: {}", + summary.name + ))); + }; + + Ok(Self(function.clone())) + } +} + +impl ReferenceSerialization for ArcTableFunctionImpl { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + self.summary().encode(writer, is_direct, reference_tables) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + ) -> Result { + let summary = FunctionSummary::decode(reader, context, reference_tables)?; + let Some(functions) = context.and_then(ReferenceDecodeContext::table_functions) else { + return Err(DatabaseError::InvalidValue(format!( + "table function decode context missing for {}", + summary.name + ))); + }; + let Some(function) = functions.get(&summary) else { + return Err(DatabaseError::InvalidValue(format!( + "table function not found when decoding: {}", + summary.name + ))); + }; + + Ok(Self(function.clone())) + } +} diff --git a/src/serdes/mod.rs b/src/serdes/mod.rs index e0db3192..c92abb3d 100644 --- a/src/serdes/mod.rs +++ b/src/serdes/mod.rs @@ -35,6 +35,7 @@ mod ulid; mod vec; use crate::catalog::TableName; +use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::storage::{TableCache, Transaction}; use std::io; @@ -57,7 +58,7 @@ macro_rules! implement_serialization_by_bincode { fn decode( reader: &mut R, - _: Option<(&T, &$crate::storage::TableCache)>, + _: Option<&$crate::serdes::ReferenceDecodeContext<'_, T>>, _: &$crate::serdes::ReferenceTables, ) -> Result { Ok(bincode::deserialize_from(reader)?) @@ -76,13 +77,53 @@ pub trait ReferenceSerialization { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result where Self: Sized; } +pub struct ReferenceDecodeContext<'a, T: Transaction> { + drive: Option<(&'a T, &'a TableCache)>, + scala_functions: Option<&'a ScalaFunctions>, + table_functions: Option<&'a TableFunctions>, +} + +impl<'a, T: Transaction> ReferenceDecodeContext<'a, T> { + pub fn new(drive: Option<(&'a T, &'a TableCache)>) -> Self { + Self { + drive, + scala_functions: None, + table_functions: None, + } + } + + pub fn with_functions( + drive: Option<(&'a T, &'a TableCache)>, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, + ) -> Self { + Self { + drive, + scala_functions: Some(scala_functions), + table_functions: Some(table_functions), + } + } + + pub fn drive(&self) -> Option<(&'a T, &'a TableCache)> { + self.drive + } + + pub(crate) fn scala_functions(&self) -> Option<&'a ScalaFunctions> { + self.scala_functions + } + + pub(crate) fn table_functions(&self) -> Option<&'a TableFunctions> { + self.table_functions + } +} + #[derive(Debug, Default, Eq, PartialEq)] pub struct ReferenceTables { tables: Vec, diff --git a/src/serdes/num.rs b/src/serdes/num.rs index 9e8a86fb..ef06e1b1 100644 --- a/src/serdes/num.rs +++ b/src/serdes/num.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::Read; use std::io::Write; use std::mem::size_of; @@ -36,7 +36,7 @@ macro_rules! implement_num_serialization { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let mut bytes = [0u8; size_of::()]; @@ -72,7 +72,7 @@ impl ReferenceSerialization for usize { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { Ok(u32::decode(reader, drive, reference_tables)? as usize) diff --git a/src/serdes/option.rs b/src/serdes/option.rs index 6b8ece88..639d6b21 100644 --- a/src/serdes/option.rs +++ b/src/serdes/option.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for Option @@ -40,7 +40,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { match u8::decode(reader, drive, reference_tables)? { diff --git a/src/serdes/pair.rs b/src/serdes/pair.rs index b90aacd3..58c19ac3 100644 --- a/src/serdes/pair.rs +++ b/src/serdes/pair.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for (A, B) @@ -37,7 +37,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let v1 = A::decode(reader, drive, reference_tables)?; diff --git a/src/serdes/phantom.rs b/src/serdes/phantom.rs index a642687d..684097da 100644 --- a/src/serdes/phantom.rs +++ b/src/serdes/phantom.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; use std::marker::PhantomData; @@ -30,7 +30,7 @@ impl ReferenceSerialization for PhantomData { fn decode( _: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { Ok(PhantomData) diff --git a/src/serdes/ptr.rs b/src/serdes/ptr.rs index 4a437795..9cba60ac 100644 --- a/src/serdes/ptr.rs +++ b/src/serdes/ptr.rs @@ -13,7 +13,6 @@ // limitations under the License. use crate::serdes::DatabaseError; -use crate::serdes::TableCache; use crate::serdes::Transaction; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use std::io::{Read, Write}; @@ -37,7 +36,7 @@ macro_rules! implement_ptr_serialization { fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result where diff --git a/src/serdes/slice.rs b/src/serdes/slice.rs index 602eadb1..523ffa9c 100644 --- a/src/serdes/slice.rs +++ b/src/serdes/slice.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for [V; 2] @@ -35,7 +35,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { Ok([ diff --git a/src/serdes/string.rs b/src/serdes/string.rs index 3d221bfd..720b9754 100644 --- a/src/serdes/string.rs +++ b/src/serdes/string.rs @@ -15,7 +15,7 @@ use crate::errors::DatabaseError; use crate::implement_serialization_by_bincode; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::sync::Arc; implement_serialization_by_bincode!(String); @@ -34,7 +34,7 @@ impl ReferenceSerialization for Arc { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let str: String = bincode::deserialize_from(reader)?; diff --git a/src/serdes/trim.rs b/src/serdes/trim.rs index de6b26f3..d9aa39ca 100644 --- a/src/serdes/trim.rs +++ b/src/serdes/trim.rs @@ -13,9 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::expression::TrimWhereField; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; -use sqlparser::ast::TrimWhereField; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for TrimWhereField { @@ -37,7 +37,7 @@ impl ReferenceSerialization for TrimWhereField { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let mut one_byte = [0u8; 1]; diff --git a/src/serdes/ulid.rs b/src/serdes/ulid.rs index d56f93ba..52804d5b 100644 --- a/src/serdes/ulid.rs +++ b/src/serdes/ulid.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; use ulid::Ulid; @@ -32,7 +32,7 @@ impl ReferenceSerialization for Ulid { fn decode( reader: &mut R, - _: Option<(&T, &TableCache)>, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, ) -> Result { let mut buf = [0u8; 16]; diff --git a/src/serdes/vec.rs b/src/serdes/vec.rs index 3f321d46..97f69aff 100644 --- a/src/serdes/vec.rs +++ b/src/serdes/vec.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; -use crate::storage::{TableCache, Transaction}; +use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for Vec @@ -36,7 +36,7 @@ where fn decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, ) -> Result { let len = ::decode(reader, drive, reference_tables)?; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index f6da370e..43ae10f0 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod table_codec; use crate::catalog::view::View; use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableMeta, TableName}; +use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::expression::ScalarExpression; @@ -684,15 +685,13 @@ pub trait Transaction: Sized { fn drop_view( &mut self, view_cache: &ViewCache, - table_cache: &TableCache, view_name: TableName, if_exists: bool, ) -> Result<(), DatabaseError> { self.drop_name_hash(&view_name)?; - if self - .view(table_cache, view_cache, view_name.clone())? - .is_none() - { + let exists = + unsafe { &*self.table_codec() }.with_view_key(&view_name, |key| self.exists(key))?; + if !exists { if if_exists { return Ok(()); } else { @@ -800,6 +799,8 @@ pub trait Transaction: Sized { &'a self, table_cache: &'a TableCache, view_cache: &'a ViewCache, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, view_name: TableName, ) -> Result, DatabaseError> { if let Some(view) = view_cache.get(&view_name) { @@ -810,23 +811,30 @@ pub trait Transaction: Sized { return Ok(None); }; Ok(Some(view_cache.get_or_insert(view_name.clone(), |_| { - TableCodec::decode_view(bytes.as_ref(), (self, table_cache)) + TableCodec::decode_view( + bytes.as_ref(), + (self, table_cache), + scala_functions, + table_functions, + ) })?)) }) } - fn views(&self, table_cache: &TableCache) -> Result, DatabaseError> { - let mut metas = vec![]; + fn views<'a>( + &'a self, + table_cache: &'a TableCache, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, + ) -> Result, DatabaseError> { unsafe { &*self.table_codec() }.with_view_bound(|min, max| { - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - - while let Some((_, value)) = iter.try_next().ok().flatten() { - let meta = TableCodec::decode_view(value, (self, table_cache))?; - - metas.push(meta); - } - - Ok(metas) + Ok(ViewIter { + iter: self.range(Bound::Included(min), Bound::Included(max))?, + transaction: self, + table_cache, + scala_functions, + table_functions, + }) }) } @@ -849,18 +857,11 @@ pub trait Transaction: Sized { .transpose() } - fn table_metas(&self) -> Result, DatabaseError> { - let mut metas = vec![]; + fn tables<'a>(&'a self) -> Result, DatabaseError> { unsafe { &*self.table_codec() }.with_root_table_bound(|min, max| { - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - - while let Some((_, value)) = iter.try_next().ok().flatten() { - let meta = TableCodec::decode_root_table::(value)?; - - metas.push(meta); - } - - Ok(metas) + Ok(TableIter { + iter: self.range(Bound::Included(min), Bound::Included(max))?, + }) }) } @@ -1893,6 +1894,43 @@ pub trait Iter { fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result; } +pub struct TableIter<'a, T: Transaction + 'a> { + iter: T::IterType<'a>, +} + +impl TableIter<'_, T> { + pub fn try_next(&mut self) -> Result, DatabaseError> { + let Some((_, value)) = self.iter.try_next()? else { + return Ok(None); + }; + + Ok(Some(TableCodec::decode_root_table::(value)?)) + } +} + +pub struct ViewIter<'a, T: Transaction + 'a> { + iter: T::IterType<'a>, + transaction: &'a T, + table_cache: &'a TableCache, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, +} + +impl ViewIter<'_, T> { + pub fn try_next(&mut self) -> Result, DatabaseError> { + let Some((_, value)) = self.iter.try_next()? else { + return Ok(None); + }; + + Ok(Some(TableCodec::decode_view( + value, + (self.transaction, self.table_cache), + self.scala_functions, + self.table_functions, + )?)) + } +} + #[cfg(test)] pub(crate) fn next_tuple_for_test(iter: &mut I) -> Result, DatabaseError> { let mut tuple = Tuple::default(); @@ -2612,6 +2650,8 @@ mod test { #[test] fn test_view_create_drop() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let scala_functions = Default::default(); + let table_functions = Default::default(); let view_name: TableName = "v1".to_string().into(); let view = View { @@ -2629,6 +2669,8 @@ mod test { .view( &table_state.table_cache, &table_state.view_cache, + &scala_functions, + &table_functions, view_name.clone(), )? .unwrap() @@ -2639,25 +2681,23 @@ mod test { .view( &Arc::new(SharedLruCache::new(4, 1, RandomState::new())?), &table_state.view_cache, + &scala_functions, + &table_functions, view_name.clone(), )? .unwrap() ); - transaction.drop_view( - &table_state.view_cache, - &table_state.table_cache, - view_name.clone(), - false, - )?; - transaction.drop_view( - &table_state.view_cache, - &table_state.table_cache, - view_name.clone(), - true, - )?; + transaction.drop_view(&table_state.view_cache, view_name.clone(), false)?; + transaction.drop_view(&table_state.view_cache, view_name.clone(), true)?; assert!(transaction - .view(&table_state.table_cache, &table_state.view_cache, view_name)? + .view( + &table_state.table_cache, + &table_state.view_cache, + &scala_functions, + &table_functions, + view_name, + )? .is_none()); Ok(()) diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 78390e1b..fd948373 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -14,11 +14,12 @@ use crate::catalog::view::View; use crate::catalog::{ColumnRef, ColumnRelation, TableMeta}; +use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::optimizer::core::cm_sketch::{CountMinSketchMeta, CountMinSketchPage}; use crate::optimizer::core::histogram::Bucket; use crate::optimizer::core::statistics_meta::StatisticsMetaRoot; -use crate::serdes::{ReferenceSerialization, ReferenceTables}; +use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::{TableCache, Transaction}; use crate::types::index::{Index, IndexId, IndexMeta, IndexType, INDEX_ID_LEN}; use crate::types::serialize::TupleValueSerializableImpl; @@ -817,6 +818,8 @@ impl TableCodec { pub fn decode_view( bytes: &[u8], drive: (&T, &TableCache), + scala_functions: &ScalaFunctions, + table_functions: &TableFunctions, ) -> Result { let mut cursor = Cursor::new(bytes); let reference_tables_pos = { @@ -828,7 +831,9 @@ impl TableCodec { let reference_tables = ReferenceTables::from_raw(&mut cursor)?; cursor.seek(SeekFrom::Start(4))?; - View::decode(&mut cursor, Some(drive), &reference_tables) + let context = + ReferenceDecodeContext::with_functions(Some(drive), scala_functions, table_functions); + View::decode(&mut cursor, Some(&context), &reference_tables) } pub fn encode_root_table_value( @@ -1078,6 +1083,8 @@ mod tests { fn test_table_codec_view() -> Result<(), DatabaseError> { let table_codec = TableCodec::default(); let table_state = build_t1_table()?; + let scala_functions = Default::default(); + let table_functions = Default::default(); // Subquery { println!("==== Subquery"); @@ -1093,7 +1100,12 @@ mod tests { assert_eq!( view, - TableCodec::decode_view(&bytes, (&transaction, &table_state.table_cache))? + TableCodec::decode_view( + &bytes, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + )? ); } // No Join @@ -1109,7 +1121,12 @@ mod tests { assert_eq!( view, - TableCodec::decode_view(&bytes, (&transaction, &table_state.table_cache))? + TableCodec::decode_view( + &bytes, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + )? ); } // Join @@ -1125,7 +1142,12 @@ mod tests { assert_eq!( view, - TableCodec::decode_view(&bytes, (&transaction, &table_state.table_cache))? + TableCodec::decode_view( + &bytes, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + )? ); } diff --git a/src/types/evaluator/binary.rs b/src/types/evaluator/binary.rs index 8b8d12c8..f3fc416e 100644 --- a/src/types/evaluator/binary.rs +++ b/src/types/evaluator/binary.rs @@ -42,21 +42,31 @@ use paste::paste; use std::borrow::Cow; use std::sync::Arc; +macro_rules! box_binary { + ($ty:expr, $op:expr, $evaluator:expr) => { + Ok(BinaryEvaluatorBox::new( + Arc::new($evaluator), + $ty.clone(), + $op, + )) + }; +} + macro_rules! numeric_binary_evaluator { ($value_type:ident, $op:expr, $ty:expr) => { paste! { match $op { - BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type PlusBinaryEvaluator>]))), - BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MinusBinaryEvaluator>]))), - BinaryOperator::Multiply => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MultiplyBinaryEvaluator>]))), - BinaryOperator::Divide => Ok(BinaryEvaluatorBox(Arc::new([<$value_type DivideBinaryEvaluator>]))), - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtBinaryEvaluator>]))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtEqBinaryEvaluator>]))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtBinaryEvaluator>]))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtEqBinaryEvaluator>]))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type EqBinaryEvaluator>]))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type NotEqBinaryEvaluator>]))), - BinaryOperator::Modulo => Ok(BinaryEvaluatorBox(Arc::new([<$value_type ModBinaryEvaluator>]))), + BinaryOperator::Plus => box_binary!($ty, $op, [<$value_type PlusBinaryEvaluator>]), + BinaryOperator::Minus => box_binary!($ty, $op, [<$value_type MinusBinaryEvaluator>]), + BinaryOperator::Multiply => box_binary!($ty, $op, [<$value_type MultiplyBinaryEvaluator>]), + BinaryOperator::Divide => box_binary!($ty, $op, [<$value_type DivideBinaryEvaluator>]), + BinaryOperator::Gt => box_binary!($ty, $op, [<$value_type GtBinaryEvaluator>]), + BinaryOperator::GtEq => box_binary!($ty, $op, [<$value_type GtEqBinaryEvaluator>]), + BinaryOperator::Lt => box_binary!($ty, $op, [<$value_type LtBinaryEvaluator>]), + BinaryOperator::LtEq => box_binary!($ty, $op, [<$value_type LtEqBinaryEvaluator>]), + BinaryOperator::Eq => box_binary!($ty, $op, [<$value_type EqBinaryEvaluator>]), + BinaryOperator::NotEq => box_binary!($ty, $op, [<$value_type NotEqBinaryEvaluator>]), + BinaryOperator::Modulo => box_binary!($ty, $op, [<$value_type ModBinaryEvaluator>]), _ => Err(DatabaseError::UnsupportedBinaryOperator($ty.clone(), $op)), } } @@ -82,66 +92,60 @@ pub fn binary_create( LogicalType::Date => numeric_binary_evaluator!(Date, op, ty), LogicalType::DateTime => numeric_binary_evaluator!(DateTime, op, ty), LogicalType::Time(_) => match op { - BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new(TimePlusBinaryEvaluator))), - BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new(TimeMinusBinaryEvaluator))), - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TimeGtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeGtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TimeLtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeLtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TimeEqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(TimeNotEqBinaryEvaluator))), + BinaryOperator::Plus => box_binary!(ty, op, TimePlusBinaryEvaluator), + BinaryOperator::Minus => box_binary!(ty, op, TimeMinusBinaryEvaluator), + BinaryOperator::Gt => box_binary!(ty, op, TimeGtBinaryEvaluator), + BinaryOperator::GtEq => box_binary!(ty, op, TimeGtEqBinaryEvaluator), + BinaryOperator::Lt => box_binary!(ty, op, TimeLtBinaryEvaluator), + BinaryOperator::LtEq => box_binary!(ty, op, TimeLtEqBinaryEvaluator), + BinaryOperator::Eq => box_binary!(ty, op, TimeEqBinaryEvaluator), + BinaryOperator::NotEq => box_binary!(ty, op, TimeNotEqBinaryEvaluator), _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, LogicalType::TimeStamp(_, _) => match op { - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Time64GtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64GtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Time64LtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64LtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Time64EqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Time64NotEqBinaryEvaluator))), + BinaryOperator::Gt => box_binary!(ty, op, Time64GtBinaryEvaluator), + BinaryOperator::GtEq => box_binary!(ty, op, Time64GtEqBinaryEvaluator), + BinaryOperator::Lt => box_binary!(ty, op, Time64LtBinaryEvaluator), + BinaryOperator::LtEq => box_binary!(ty, op, Time64LtEqBinaryEvaluator), + BinaryOperator::Eq => box_binary!(ty, op, Time64EqBinaryEvaluator), + BinaryOperator::NotEq => box_binary!(ty, op, Time64NotEqBinaryEvaluator), _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, LogicalType::Decimal(_, _) => numeric_binary_evaluator!(Decimal, op, ty), LogicalType::Boolean => match op { - BinaryOperator::And => Ok(BinaryEvaluatorBox(Arc::new(BooleanAndBinaryEvaluator))), - BinaryOperator::Or => Ok(BinaryEvaluatorBox(Arc::new(BooleanOrBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(BooleanEqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(BooleanNotEqBinaryEvaluator))), + BinaryOperator::And => box_binary!(ty, op, BooleanAndBinaryEvaluator), + BinaryOperator::Or => box_binary!(ty, op, BooleanOrBinaryEvaluator), + BinaryOperator::Eq => box_binary!(ty, op, BooleanEqBinaryEvaluator), + BinaryOperator::NotEq => box_binary!(ty, op, BooleanNotEqBinaryEvaluator), _ => Err(DatabaseError::UnsupportedBinaryOperator( LogicalType::Boolean, op, )), }, LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => match op { - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtEqBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Utf8EqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8NotEqBinaryEvaluator))), - BinaryOperator::StringConcat => Ok(BinaryEvaluatorBox(Arc::new( - Utf8StringConcatBinaryEvaluator, - ))), + BinaryOperator::Gt => box_binary!(ty, op, Utf8GtBinaryEvaluator), + BinaryOperator::Lt => box_binary!(ty, op, Utf8LtBinaryEvaluator), + BinaryOperator::GtEq => box_binary!(ty, op, Utf8GtEqBinaryEvaluator), + BinaryOperator::LtEq => box_binary!(ty, op, Utf8LtEqBinaryEvaluator), + BinaryOperator::Eq => box_binary!(ty, op, Utf8EqBinaryEvaluator), + BinaryOperator::NotEq => box_binary!(ty, op, Utf8NotEqBinaryEvaluator), + BinaryOperator::StringConcat => box_binary!(ty, op, Utf8StringConcatBinaryEvaluator), BinaryOperator::Like(escape_char) => { - Ok(BinaryEvaluatorBox(Arc::new(Utf8LikeBinaryEvaluator { - escape_char, - }))) + box_binary!(ty, op, Utf8LikeBinaryEvaluator { escape_char }) } BinaryOperator::NotLike(escape_char) => { - Ok(BinaryEvaluatorBox(Arc::new(Utf8NotLikeBinaryEvaluator { - escape_char, - }))) + box_binary!(ty, op, Utf8NotLikeBinaryEvaluator { escape_char }) } _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, - LogicalType::SqlNull => Ok(BinaryEvaluatorBox(Arc::new(NullBinaryEvaluator))), + LogicalType::SqlNull => box_binary!(ty, op, NullBinaryEvaluator), LogicalType::Tuple(_) => match op { - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TupleEqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(TupleNotEqBinaryEvaluator))), - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TupleGtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleGtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TupleLtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleLtEqBinaryEvaluator))), + BinaryOperator::Eq => box_binary!(ty, op, TupleEqBinaryEvaluator), + BinaryOperator::NotEq => box_binary!(ty, op, TupleNotEqBinaryEvaluator), + BinaryOperator::Gt => box_binary!(ty, op, TupleGtBinaryEvaluator), + BinaryOperator::GtEq => box_binary!(ty, op, TupleGtEqBinaryEvaluator), + BinaryOperator::Lt => box_binary!(ty, op, TupleLtBinaryEvaluator), + BinaryOperator::LtEq => box_binary!(ty, op, TupleLtEqBinaryEvaluator), _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, } @@ -172,10 +176,7 @@ macro_rules! numeric_binary_evaluator_definition { #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct [<$value_type NotEqBinaryEvaluator>]; #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] - pub struct [<$value_type ModBinaryEvaluator>]; - - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { + pub struct [<$value_type ModBinaryEvaluator>]; impl $crate::types::evaluator::BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -189,9 +190,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -205,9 +204,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -221,9 +218,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -237,9 +232,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -253,9 +246,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -269,9 +260,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -285,9 +274,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -301,9 +288,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -317,9 +302,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, @@ -333,9 +316,7 @@ macro_rules! numeric_binary_evaluator_definition { _ => unsafe { std::hint::unreachable_unchecked() }, }) } - } - #[typetag::serde] - impl $crate::types::evaluator::BinaryEvaluator for [<$value_type ModBinaryEvaluator>] { + } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type ModBinaryEvaluator>] { fn binary_eval( &self, left: &$crate::types::value::DataValue, diff --git a/src/types/evaluator/boolean.rs b/src/types/evaluator/boolean.rs index 897dac74..31a530e9 100644 --- a/src/types/evaluator/boolean.rs +++ b/src/types/evaluator/boolean.rs @@ -16,9 +16,9 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::DataValue; use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -31,8 +31,6 @@ pub struct BooleanOrBinaryEvaluator; pub struct BooleanEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct BooleanNotEqBinaryEvaluator; - -#[typetag::serde] impl UnaryEvaluator for BooleanNotUnaryEvaluator { fn unary_eval(&self, value: &DataValue) -> DataValue { match value { @@ -42,7 +40,6 @@ impl UnaryEvaluator for BooleanNotUnaryEvaluator { } } } -#[typetag::serde] impl BinaryEvaluator for BooleanAndBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -56,8 +53,6 @@ impl BinaryEvaluator for BooleanAndBinaryEvaluator { }) } } - -#[typetag::serde] impl BinaryEvaluator for BooleanOrBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -71,8 +66,6 @@ impl BinaryEvaluator for BooleanOrBinaryEvaluator { }) } } - -#[typetag::serde] impl BinaryEvaluator for BooleanEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -129,8 +122,6 @@ crate::define_cast_evaluator!( }, DataValue::Boolean(value) => |this| to_varchar(value.to_string(), this.len, this.unit) ); - -#[typetag::serde] impl BinaryEvaluator for BooleanNotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { diff --git a/src/types/evaluator/cast.rs b/src/types/evaluator/cast.rs index 40beee08..0f337cc7 100644 --- a/src/types/evaluator/cast.rs +++ b/src/types/evaluator/cast.rs @@ -34,10 +34,10 @@ use crate::types::evaluator::uint8::*; use crate::types::evaluator::utf8::*; use crate::types::evaluator::{CastEvaluator, CastEvaluatorBox}; use crate::types::value::{DataValue, Utf8Type}; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use paste::paste; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::borrow::Cow; use std::sync::Arc; @@ -142,8 +142,6 @@ macro_rules! define_cast_evaluator { ($name:ident, $pattern:pat => $body:block) => { #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct $name; - - #[typetag::serde] impl $crate::types::evaluator::CastEvaluator for $name { fn eval_cast( &self, @@ -160,8 +158,6 @@ macro_rules! define_cast_evaluator { ($name:ident, $pattern:pat => |$this:ident| $body:expr) => { #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct $name; - - #[typetag::serde] impl $crate::types::evaluator::CastEvaluator for $name { fn eval_cast( &self, @@ -181,8 +177,6 @@ macro_rules! define_cast_evaluator { ($name:ident, $pattern:pat => |$this:ident| $body:block) => { #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct $name; - - #[typetag::serde] impl $crate::types::evaluator::CastEvaluator for $name { fn eval_cast( &self, @@ -204,8 +198,6 @@ macro_rules! define_cast_evaluator { pub struct $name { $(pub $field: $field_ty),+ } - - #[typetag::serde] impl $crate::types::evaluator::CastEvaluator for $name { fn eval_cast( &self, @@ -227,8 +219,6 @@ macro_rules! define_cast_evaluator { pub struct $name { $(pub $field: $field_ty),+ } - - #[typetag::serde] impl $crate::types::evaluator::CastEvaluator for $name { fn eval_cast( &self, @@ -284,7 +274,7 @@ macro_rules! define_integer_cast_evaluators { $crate::define_cast_evaluator!( [<$prefix ToCharCastEvaluator>] { len: u32, - unit: sqlparser::ast::CharLengthUnits + unit: crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) @@ -293,7 +283,7 @@ macro_rules! define_integer_cast_evaluators { $crate::define_cast_evaluator!( [<$prefix ToVarcharCastEvaluator>] { len: Option, - unit: sqlparser::ast::CharLengthUnits + unit: crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) @@ -350,7 +340,7 @@ macro_rules! define_float_cast_evaluators { $crate::define_cast_evaluator!( [<$prefix ToCharCastEvaluator>] { len: u32, - unit: sqlparser::ast::CharLengthUnits + unit: crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) @@ -359,7 +349,7 @@ macro_rules! define_float_cast_evaluators { $crate::define_cast_evaluator!( [<$prefix ToVarcharCastEvaluator>] { len: Option, - unit: sqlparser::ast::CharLengthUnits + unit: crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) @@ -384,8 +374,6 @@ macro_rules! define_float_cast_evaluators { #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct IdentityCastEvaluator; - -#[typetag::serde] impl CastEvaluator for IdentityCastEvaluator { fn eval_cast(&self, value: &DataValue) -> Result { Ok(value.clone()) @@ -393,8 +381,12 @@ impl CastEvaluator for IdentityCastEvaluator { } macro_rules! box_cast { - ($evaluator:expr) => { - Ok(CastEvaluatorBox(Arc::new($evaluator))) + ($from:expr, $to:expr, $evaluator:expr) => { + Ok(CastEvaluatorBox::new( + Arc::new($evaluator), + $from.clone(), + $to.clone(), + )) }; } @@ -402,21 +394,21 @@ macro_rules! build_integer_cast { ($prefix:ident, $to:expr, $from:expr) => {{ paste! { match $to { - LogicalType::SqlNull => box_cast!(ToSqlNullCastEvaluator), - LogicalType::Boolean => box_cast!([<$prefix ToBooleanCastEvaluator>]), - LogicalType::Tinyint => box_cast!([<$prefix ToTinyintCastEvaluator>]), - LogicalType::UTinyint => box_cast!([<$prefix ToUTinyintCastEvaluator>]), - LogicalType::Smallint => box_cast!([<$prefix ToSmallintCastEvaluator>]), - LogicalType::USmallint => box_cast!([<$prefix ToUSmallintCastEvaluator>]), - LogicalType::Integer => box_cast!([<$prefix ToIntegerCastEvaluator>]), - LogicalType::UInteger => box_cast!([<$prefix ToUIntegerCastEvaluator>]), - LogicalType::Bigint => box_cast!([<$prefix ToBigintCastEvaluator>]), - LogicalType::UBigint => box_cast!([<$prefix ToUBigintCastEvaluator>]), - LogicalType::Float => box_cast!([<$prefix ToFloatCastEvaluator>]), - LogicalType::Double => box_cast!([<$prefix ToDoubleCastEvaluator>]), - LogicalType::Char(len, unit) => box_cast!([<$prefix ToCharCastEvaluator>] { len: *len, unit: *unit }), - LogicalType::Varchar(len, unit) => box_cast!([<$prefix ToVarcharCastEvaluator>] { len: *len, unit: *unit }), - LogicalType::Decimal(_, scale) => box_cast!([<$prefix ToDecimalCastEvaluator>] { scale: *scale }), + LogicalType::SqlNull => box_cast!($from, $to, ToSqlNullCastEvaluator), + LogicalType::Boolean => box_cast!($from, $to, [<$prefix ToBooleanCastEvaluator>]), + LogicalType::Tinyint => box_cast!($from, $to, [<$prefix ToTinyintCastEvaluator>]), + LogicalType::UTinyint => box_cast!($from, $to, [<$prefix ToUTinyintCastEvaluator>]), + LogicalType::Smallint => box_cast!($from, $to, [<$prefix ToSmallintCastEvaluator>]), + LogicalType::USmallint => box_cast!($from, $to, [<$prefix ToUSmallintCastEvaluator>]), + LogicalType::Integer => box_cast!($from, $to, [<$prefix ToIntegerCastEvaluator>]), + LogicalType::UInteger => box_cast!($from, $to, [<$prefix ToUIntegerCastEvaluator>]), + LogicalType::Bigint => box_cast!($from, $to, [<$prefix ToBigintCastEvaluator>]), + LogicalType::UBigint => box_cast!($from, $to, [<$prefix ToUBigintCastEvaluator>]), + LogicalType::Float => box_cast!($from, $to, [<$prefix ToFloatCastEvaluator>]), + LogicalType::Double => box_cast!($from, $to, [<$prefix ToDoubleCastEvaluator>]), + LogicalType::Char(len, unit) => box_cast!($from, $to, [<$prefix ToCharCastEvaluator>] { len: *len, unit: *unit }), + LogicalType::Varchar(len, unit) => box_cast!($from, $to, [<$prefix ToVarcharCastEvaluator>] { len: *len, unit: *unit }), + LogicalType::Decimal(_, scale) => box_cast!($from, $to, [<$prefix ToDecimalCastEvaluator>] { scale: *scale }), _ => Err(cast_fail($from.clone(), $to.clone())), } } @@ -430,35 +422,61 @@ pub fn cast_create( let from = from.as_ref(); let to = to.as_ref(); if from == to { - return box_cast!(IdentityCastEvaluator); + return box_cast!(from, to, IdentityCastEvaluator); } match (from, to) { - (LogicalType::SqlNull, _) => box_cast!(NullCastEvaluator), - (_, LogicalType::SqlNull) => box_cast!(ToSqlNullCastEvaluator), - (LogicalType::Boolean, LogicalType::Tinyint) => box_cast!(BooleanToTinyintCastEvaluator), - (LogicalType::Boolean, LogicalType::UTinyint) => box_cast!(BooleanToUTinyintCastEvaluator), - (LogicalType::Boolean, LogicalType::Smallint) => box_cast!(BooleanToSmallintCastEvaluator), + (LogicalType::SqlNull, _) => box_cast!(from, to, NullCastEvaluator), + (_, LogicalType::SqlNull) => box_cast!(from, to, ToSqlNullCastEvaluator), + (LogicalType::Boolean, LogicalType::Tinyint) => { + box_cast!(from, to, BooleanToTinyintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::UTinyint) => { + box_cast!(from, to, BooleanToUTinyintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Smallint) => { + box_cast!(from, to, BooleanToSmallintCastEvaluator) + } (LogicalType::Boolean, LogicalType::USmallint) => { - box_cast!(BooleanToUSmallintCastEvaluator) - } - (LogicalType::Boolean, LogicalType::Integer) => box_cast!(BooleanToIntegerCastEvaluator), - (LogicalType::Boolean, LogicalType::UInteger) => box_cast!(BooleanToUIntegerCastEvaluator), - (LogicalType::Boolean, LogicalType::Bigint) => box_cast!(BooleanToBigintCastEvaluator), - (LogicalType::Boolean, LogicalType::UBigint) => box_cast!(BooleanToUBigintCastEvaluator), - (LogicalType::Boolean, LogicalType::Float) => box_cast!(BooleanToFloatCastEvaluator), - (LogicalType::Boolean, LogicalType::Double) => box_cast!(BooleanToDoubleCastEvaluator), + box_cast!(from, to, BooleanToUSmallintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Integer) => { + box_cast!(from, to, BooleanToIntegerCastEvaluator) + } + (LogicalType::Boolean, LogicalType::UInteger) => { + box_cast!(from, to, BooleanToUIntegerCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Bigint) => { + box_cast!(from, to, BooleanToBigintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::UBigint) => { + box_cast!(from, to, BooleanToUBigintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Float) => { + box_cast!(from, to, BooleanToFloatCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Double) => { + box_cast!(from, to, BooleanToDoubleCastEvaluator) + } (LogicalType::Boolean, LogicalType::Char(len, unit)) => { - box_cast!(BooleanToCharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + BooleanToCharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Boolean, LogicalType::Varchar(len, unit)) => { - box_cast!(BooleanToVarcharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + BooleanToVarcharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Tinyint, _) => build_integer_cast!(Int8, to, from), (LogicalType::Smallint, _) => build_integer_cast!(Int16, to, from), @@ -468,275 +486,421 @@ pub fn cast_create( (LogicalType::USmallint, _) => build_integer_cast!(UInt16, to, from), (LogicalType::UInteger, _) => build_integer_cast!(UInt32, to, from), (LogicalType::UBigint, _) => build_integer_cast!(UInt64, to, from), - (LogicalType::Float, LogicalType::Tinyint) => box_cast!(Float32ToTinyintCastEvaluator), - (LogicalType::Float, LogicalType::UTinyint) => box_cast!(Float32ToUTinyintCastEvaluator), - (LogicalType::Float, LogicalType::Smallint) => box_cast!(Float32ToSmallintCastEvaluator), - (LogicalType::Float, LogicalType::USmallint) => box_cast!(Float32ToUSmallintCastEvaluator), - (LogicalType::Float, LogicalType::Integer) => box_cast!(Float32ToIntegerCastEvaluator), - (LogicalType::Float, LogicalType::UInteger) => box_cast!(Float32ToUIntegerCastEvaluator), - (LogicalType::Float, LogicalType::Bigint) => box_cast!(Float32ToBigintCastEvaluator), - (LogicalType::Float, LogicalType::UBigint) => box_cast!(Float32ToUBigintCastEvaluator), - (LogicalType::Float, LogicalType::Double) => box_cast!(Float32ToDoubleCastEvaluator), + (LogicalType::Float, LogicalType::Tinyint) => { + box_cast!(from, to, Float32ToTinyintCastEvaluator) + } + (LogicalType::Float, LogicalType::UTinyint) => { + box_cast!(from, to, Float32ToUTinyintCastEvaluator) + } + (LogicalType::Float, LogicalType::Smallint) => { + box_cast!(from, to, Float32ToSmallintCastEvaluator) + } + (LogicalType::Float, LogicalType::USmallint) => { + box_cast!(from, to, Float32ToUSmallintCastEvaluator) + } + (LogicalType::Float, LogicalType::Integer) => { + box_cast!(from, to, Float32ToIntegerCastEvaluator) + } + (LogicalType::Float, LogicalType::UInteger) => { + box_cast!(from, to, Float32ToUIntegerCastEvaluator) + } + (LogicalType::Float, LogicalType::Bigint) => { + box_cast!(from, to, Float32ToBigintCastEvaluator) + } + (LogicalType::Float, LogicalType::UBigint) => { + box_cast!(from, to, Float32ToUBigintCastEvaluator) + } + (LogicalType::Float, LogicalType::Double) => { + box_cast!(from, to, Float32ToDoubleCastEvaluator) + } (LogicalType::Float, LogicalType::Char(len, unit)) => { - box_cast!(Float32ToCharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Float32ToCharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Float, LogicalType::Varchar(len, unit)) => { - box_cast!(Float32ToVarcharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Float32ToVarcharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Float, LogicalType::Decimal(_, scale)) => { - box_cast!(Float32ToDecimalCastEvaluator { - scale: *scale, - to: to.clone() - }) - } - (LogicalType::Double, LogicalType::Float) => box_cast!(Float64ToFloatCastEvaluator), - (LogicalType::Double, LogicalType::Tinyint) => box_cast!(Float64ToTinyintCastEvaluator), - (LogicalType::Double, LogicalType::UTinyint) => box_cast!(Float64ToUTinyintCastEvaluator), - (LogicalType::Double, LogicalType::Smallint) => box_cast!(Float64ToSmallintCastEvaluator), - (LogicalType::Double, LogicalType::USmallint) => box_cast!(Float64ToUSmallintCastEvaluator), - (LogicalType::Double, LogicalType::Integer) => box_cast!(Float64ToIntegerCastEvaluator), - (LogicalType::Double, LogicalType::UInteger) => box_cast!(Float64ToUIntegerCastEvaluator), - (LogicalType::Double, LogicalType::Bigint) => box_cast!(Float64ToBigintCastEvaluator), - (LogicalType::Double, LogicalType::UBigint) => box_cast!(Float64ToUBigintCastEvaluator), + box_cast!( + from, + to, + Float32ToDecimalCastEvaluator { + scale: *scale, + to: to.clone() + } + ) + } + (LogicalType::Double, LogicalType::Float) => { + box_cast!(from, to, Float64ToFloatCastEvaluator) + } + (LogicalType::Double, LogicalType::Tinyint) => { + box_cast!(from, to, Float64ToTinyintCastEvaluator) + } + (LogicalType::Double, LogicalType::UTinyint) => { + box_cast!(from, to, Float64ToUTinyintCastEvaluator) + } + (LogicalType::Double, LogicalType::Smallint) => { + box_cast!(from, to, Float64ToSmallintCastEvaluator) + } + (LogicalType::Double, LogicalType::USmallint) => { + box_cast!(from, to, Float64ToUSmallintCastEvaluator) + } + (LogicalType::Double, LogicalType::Integer) => { + box_cast!(from, to, Float64ToIntegerCastEvaluator) + } + (LogicalType::Double, LogicalType::UInteger) => { + box_cast!(from, to, Float64ToUIntegerCastEvaluator) + } + (LogicalType::Double, LogicalType::Bigint) => { + box_cast!(from, to, Float64ToBigintCastEvaluator) + } + (LogicalType::Double, LogicalType::UBigint) => { + box_cast!(from, to, Float64ToUBigintCastEvaluator) + } (LogicalType::Double, LogicalType::Char(len, unit)) => { - box_cast!(Float64ToCharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Float64ToCharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Double, LogicalType::Varchar(len, unit)) => { - box_cast!(Float64ToVarcharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Float64ToVarcharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Double, LogicalType::Decimal(_, scale)) => { - box_cast!(Float64ToDecimalCastEvaluator { - scale: *scale, - to: to.clone() - }) + box_cast!( + from, + to, + Float64ToDecimalCastEvaluator { + scale: *scale, + to: to.clone() + } + ) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Boolean) => { - box_cast!(Utf8ToBooleanCastEvaluator { from: from.clone() }) + box_cast!(from, to, Utf8ToBooleanCastEvaluator { from: from.clone() }) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Tinyint) => { - box_cast!(Utf8ToTinyintCastEvaluator) + box_cast!(from, to, Utf8ToTinyintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UTinyint) => { - box_cast!(Utf8ToUTinyintCastEvaluator) + box_cast!(from, to, Utf8ToUTinyintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Smallint) => { - box_cast!(Utf8ToSmallintCastEvaluator) + box_cast!(from, to, Utf8ToSmallintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::USmallint) => { - box_cast!(Utf8ToUSmallintCastEvaluator) + box_cast!(from, to, Utf8ToUSmallintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Integer) => { - box_cast!(Utf8ToIntegerCastEvaluator) + box_cast!(from, to, Utf8ToIntegerCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UInteger) => { - box_cast!(Utf8ToUIntegerCastEvaluator) + box_cast!(from, to, Utf8ToUIntegerCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Bigint) => { - box_cast!(Utf8ToBigintCastEvaluator) + box_cast!(from, to, Utf8ToBigintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UBigint) => { - box_cast!(Utf8ToUBigintCastEvaluator) + box_cast!(from, to, Utf8ToUBigintCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Float) => { - box_cast!(Utf8ToFloatCastEvaluator) + box_cast!(from, to, Utf8ToFloatCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Double) => { - box_cast!(Utf8ToDoubleCastEvaluator) + box_cast!(from, to, Utf8ToDoubleCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Char(len, unit)) => { - box_cast!(Utf8ToCharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Utf8ToCharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!(Utf8ToVarcharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + Utf8ToVarcharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Date) => { - box_cast!(Utf8ToDateCastEvaluator) + box_cast!(from, to, Utf8ToDateCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::DateTime) => { - box_cast!(Utf8ToDatetimeCastEvaluator) + box_cast!(from, to, Utf8ToDatetimeCastEvaluator) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Time(precision)) => { - box_cast!(Utf8ToTimeCastEvaluator { - precision: *precision - }) + box_cast!( + from, + to, + Utf8ToTimeCastEvaluator { + precision: *precision + } + ) } ( LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::TimeStamp(precision, zone), ) => { - box_cast!(Utf8ToTimestampCastEvaluator { - precision: *precision, - zone: *zone, - to: to.clone() - }) + box_cast!( + from, + to, + Utf8ToTimestampCastEvaluator { + precision: *precision, + zone: *zone, + to: to.clone() + } + ) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Decimal(_, _)) => { - box_cast!(Utf8ToDecimalCastEvaluator) + box_cast!(from, to, Utf8ToDecimalCastEvaluator) } (LogicalType::Date, LogicalType::Char(len, unit)) => { - box_cast!(Date32ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Date32ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::Date, LogicalType::Varchar(len, unit)) => { - box_cast!(Date32ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Date32ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::Date, LogicalType::DateTime) => { - box_cast!(Date32ToDatetimeCastEvaluator { to: to.clone() }) + box_cast!(from, to, Date32ToDatetimeCastEvaluator { to: to.clone() }) } (LogicalType::DateTime, LogicalType::Char(len, unit)) => { - box_cast!(Date64ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Date64ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::DateTime, LogicalType::Varchar(len, unit)) => { - box_cast!(Date64ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Date64ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::DateTime, LogicalType::Date) => { - box_cast!(Date64ToDateCastEvaluator { to: to.clone() }) + box_cast!(from, to, Date64ToDateCastEvaluator { to: to.clone() }) } (LogicalType::DateTime, LogicalType::Time(precision)) => { - box_cast!(Date64ToTimeCastEvaluator { - precision: *precision, - to: to.clone() - }) + box_cast!( + from, + to, + Date64ToTimeCastEvaluator { + precision: *precision, + to: to.clone() + } + ) } (LogicalType::DateTime, LogicalType::TimeStamp(precision, zone)) => { - box_cast!(Date64ToTimestampCastEvaluator { - precision: *precision, - zone: *zone - }) + box_cast!( + from, + to, + Date64ToTimestampCastEvaluator { + precision: *precision, + zone: *zone + } + ) } (LogicalType::Time(_), LogicalType::Char(len, unit)) => { - box_cast!(Time32ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Time32ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::Time(_), LogicalType::Varchar(len, unit)) => { - box_cast!(Time32ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Time32ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::Time(_), LogicalType::Time(precision)) => { - box_cast!(Time32ToTimeCastEvaluator { - precision: *precision - }) + box_cast!( + from, + to, + Time32ToTimeCastEvaluator { + precision: *precision + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::Char(len, unit)) => { - box_cast!(Time64ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Time64ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!(Time64ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - }) + box_cast!( + from, + to, + Time64ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::Date) => { - box_cast!(Time64ToDateCastEvaluator { - from: from.clone(), - to: to.clone() - }) + box_cast!( + from, + to, + Time64ToDateCastEvaluator { + from: from.clone(), + to: to.clone() + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::DateTime) => { - box_cast!(Time64ToDatetimeCastEvaluator { - from: from.clone(), - to: to.clone() - }) + box_cast!( + from, + to, + Time64ToDatetimeCastEvaluator { + from: from.clone(), + to: to.clone() + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::Time(precision)) => { - box_cast!(Time64ToTimeCastEvaluator { - precision: *precision, - from: from.clone(), - to: to.clone() - }) + box_cast!( + from, + to, + Time64ToTimeCastEvaluator { + precision: *precision, + from: from.clone(), + to: to.clone() + } + ) } (LogicalType::TimeStamp(_, _), LogicalType::TimeStamp(precision, zone)) => { - box_cast!(Time64ToTimestampCastEvaluator { - precision: *precision, - zone: *zone - }) + box_cast!( + from, + to, + Time64ToTimestampCastEvaluator { + precision: *precision, + zone: *zone + } + ) + } + (LogicalType::Decimal(_, _), LogicalType::Float) => { + box_cast!(from, to, DecimalToFloatCastEvaluator) } - (LogicalType::Decimal(_, _), LogicalType::Float) => box_cast!(DecimalToFloatCastEvaluator), (LogicalType::Decimal(_, _), LogicalType::Double) => { - box_cast!(DecimalToDoubleCastEvaluator) + box_cast!(from, to, DecimalToDoubleCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::Decimal(_, _)) => { - box_cast!(DecimalToDecimalCastEvaluator) + box_cast!(from, to, DecimalToDecimalCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::Char(len, unit)) => { - box_cast!(DecimalToCharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + DecimalToCharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Decimal(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!(DecimalToVarcharCastEvaluator { - len: *len, - unit: *unit - }) + box_cast!( + from, + to, + DecimalToVarcharCastEvaluator { + len: *len, + unit: *unit + } + ) } (LogicalType::Decimal(_, _), LogicalType::Tinyint) => { - box_cast!(DecimalToTinyintCastEvaluator) + box_cast!(from, to, DecimalToTinyintCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::Smallint) => { - box_cast!(DecimalToSmallintCastEvaluator) + box_cast!(from, to, DecimalToSmallintCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::Integer) => { - box_cast!(DecimalToIntegerCastEvaluator) + box_cast!(from, to, DecimalToIntegerCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::Bigint) => { - box_cast!(DecimalToBigintCastEvaluator) + box_cast!(from, to, DecimalToBigintCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::UTinyint) => { - box_cast!(DecimalToUTinyintCastEvaluator) + box_cast!(from, to, DecimalToUTinyintCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::USmallint) => { - box_cast!(DecimalToUSmallintCastEvaluator) + box_cast!(from, to, DecimalToUSmallintCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::UInteger) => { - box_cast!(DecimalToUIntegerCastEvaluator) + box_cast!(from, to, DecimalToUIntegerCastEvaluator) } (LogicalType::Decimal(_, _), LogicalType::UBigint) => { - box_cast!(DecimalToUBigintCastEvaluator) + box_cast!(from, to, DecimalToUBigintCastEvaluator) } (LogicalType::Tuple(from_types), LogicalType::Tuple(to_types)) => { let evaluators = from_types @@ -744,9 +908,13 @@ pub fn cast_create( .zip(to_types.iter()) .map(|(from, to)| cast_create(Cow::Borrowed(from), Cow::Borrowed(to))) .collect::, _>>()?; - box_cast!(TupleCastEvaluator { - element_evaluators: evaluators - }) + box_cast!( + from, + to, + TupleCastEvaluator { + element_evaluators: evaluators + } + ) } _ => Err(cast_fail(from.clone(), to.clone())), } diff --git a/src/types/evaluator/date.rs b/src/types/evaluator/date.rs index 4eafe0c9..16457edb 100644 --- a/src/types/evaluator/date.rs +++ b/src/types/evaluator/date.rs @@ -15,9 +15,9 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::NaiveDate; -use sqlparser::ast::CharLengthUnits; numeric_binary_evaluator_definition!(Date, DataValue::Date32); crate::define_cast_evaluator!( diff --git a/src/types/evaluator/datetime.rs b/src/types/evaluator/datetime.rs index b2c428d7..53094446 100644 --- a/src/types/evaluator/datetime.rs +++ b/src/types/evaluator/datetime.rs @@ -15,9 +15,9 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::{DateTime, Datelike, Timelike}; -use sqlparser::ast::CharLengthUnits; numeric_binary_evaluator_definition!(DateTime, DataValue::Date64); crate::define_cast_evaluator!( @@ -91,7 +91,7 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; - use sqlparser::ast::CharLengthUnits; + use crate::types::CharLengthUnits; #[test] fn test_datetime_cast_evaluators() { diff --git a/src/types/evaluator/decimal.rs b/src/types/evaluator/decimal.rs index 80554863..0ba2f514 100644 --- a/src/types/evaluator/decimal.rs +++ b/src/types/evaluator/decimal.rs @@ -16,10 +16,10 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::prelude::ToPrimitive; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -44,8 +44,6 @@ pub struct DecimalEqBinaryEvaluator; pub struct DecimalNotEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct DecimalModBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for DecimalPlusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -57,7 +55,6 @@ impl BinaryEvaluator for DecimalPlusBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalMinusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -127,8 +124,6 @@ crate::define_cast_evaluator!(DecimalToUIntegerCastEvaluator, DataValue::Decimal crate::define_cast_evaluator!(DecimalToUBigintCastEvaluator, DataValue::Decimal(value) => { Ok(DataValue::UInt64(crate::decimal_to_int_cast!(*value, u64))) }); - -#[typetag::serde] impl BinaryEvaluator for DecimalMultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -140,7 +135,6 @@ impl BinaryEvaluator for DecimalMultiplyBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalDivideBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -152,7 +146,6 @@ impl BinaryEvaluator for DecimalDivideBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalGtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -164,7 +157,6 @@ impl BinaryEvaluator for DecimalGtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalGtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -176,7 +168,6 @@ impl BinaryEvaluator for DecimalGtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalLtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -188,7 +179,6 @@ impl BinaryEvaluator for DecimalLtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalLtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -200,7 +190,6 @@ impl BinaryEvaluator for DecimalLtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -212,7 +201,6 @@ impl BinaryEvaluator for DecimalEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalNotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -224,7 +212,6 @@ impl BinaryEvaluator for DecimalNotEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for DecimalModBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -242,8 +229,8 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_decimal_cast_evaluators() { diff --git a/src/types/evaluator/float32.rs b/src/types/evaluator/float32.rs index 8c630eed..d8d9be17 100644 --- a/src/types/evaluator/float32.rs +++ b/src/types/evaluator/float32.rs @@ -24,14 +24,11 @@ use std::hint; pub struct Float32PlusUnaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Float32MinusUnaryEvaluator; - -#[typetag::serde] impl UnaryEvaluator for Float32PlusUnaryEvaluator { fn unary_eval(&self, value: &DataValue) -> DataValue { value.clone() } } -#[typetag::serde] impl UnaryEvaluator for Float32MinusUnaryEvaluator { fn unary_eval(&self, value: &DataValue) -> DataValue { match value { @@ -64,8 +61,6 @@ pub struct Float32EqBinaryEvaluator; pub struct Float32NotEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Float32ModBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for Float32PlusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -77,7 +72,6 @@ impl BinaryEvaluator for Float32PlusBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32MinusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -91,7 +85,6 @@ impl BinaryEvaluator for Float32MinusBinaryEvaluator { } crate::define_float_cast_evaluators!(Float32, Float32, f32, LogicalType::Float, from_f32); -#[typetag::serde] impl BinaryEvaluator for Float32MultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -103,7 +96,6 @@ impl BinaryEvaluator for Float32MultiplyBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32DivideBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -117,7 +109,6 @@ impl BinaryEvaluator for Float32DivideBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32GtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -129,7 +120,6 @@ impl BinaryEvaluator for Float32GtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32GtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -141,7 +131,6 @@ impl BinaryEvaluator for Float32GtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32LtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -153,7 +142,6 @@ impl BinaryEvaluator for Float32LtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32LtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -165,7 +153,6 @@ impl BinaryEvaluator for Float32LtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32EqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -177,7 +164,6 @@ impl BinaryEvaluator for Float32EqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32NotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -189,7 +175,6 @@ impl BinaryEvaluator for Float32NotEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float32ModBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -207,9 +192,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_float32_cast_evaluators() { diff --git a/src/types/evaluator/float64.rs b/src/types/evaluator/float64.rs index 219a440c..54eb8d06 100644 --- a/src/types/evaluator/float64.rs +++ b/src/types/evaluator/float64.rs @@ -16,25 +16,22 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use rust_decimal::prelude::FromPrimitive; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Float64PlusUnaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Float64MinusUnaryEvaluator; - -#[typetag::serde] impl UnaryEvaluator for Float64PlusUnaryEvaluator { fn unary_eval(&self, value: &DataValue) -> DataValue { value.clone() } } -#[typetag::serde] impl UnaryEvaluator for Float64MinusUnaryEvaluator { fn unary_eval(&self, value: &DataValue) -> DataValue { match value { @@ -67,8 +64,6 @@ pub struct Float64EqBinaryEvaluator; pub struct Float64NotEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Float64ModBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for Float64PlusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -80,7 +75,6 @@ impl BinaryEvaluator for Float64PlusBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64MinusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -151,8 +145,6 @@ crate::define_cast_evaluator!( Ok(DataValue::Decimal(decimal)) } ); - -#[typetag::serde] impl BinaryEvaluator for Float64MultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -164,7 +156,6 @@ impl BinaryEvaluator for Float64MultiplyBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64DivideBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -178,7 +169,6 @@ impl BinaryEvaluator for Float64DivideBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64GtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -190,7 +180,6 @@ impl BinaryEvaluator for Float64GtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64GtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -202,7 +191,6 @@ impl BinaryEvaluator for Float64GtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64LtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -214,7 +202,6 @@ impl BinaryEvaluator for Float64LtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64LtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -226,7 +213,6 @@ impl BinaryEvaluator for Float64LtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64EqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -238,7 +224,6 @@ impl BinaryEvaluator for Float64EqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64NotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -250,7 +235,6 @@ impl BinaryEvaluator for Float64NotEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Float64ModBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -268,8 +252,8 @@ mod test { use super::*; use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_float64_binary_and_cast_evaluators() { diff --git a/src/types/evaluator/int16.rs b/src/types/evaluator/int16.rs index aa9f95c6..09759419 100644 --- a/src/types/evaluator/int16.rs +++ b/src/types/evaluator/int16.rs @@ -25,9 +25,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_int16_cast_evaluators() { diff --git a/src/types/evaluator/int32.rs b/src/types/evaluator/int32.rs index 3fbce73a..36471f8b 100644 --- a/src/types/evaluator/int32.rs +++ b/src/types/evaluator/int32.rs @@ -25,9 +25,9 @@ mod test { use super::*; use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_int32_binary_evaluators() { diff --git a/src/types/evaluator/int64.rs b/src/types/evaluator/int64.rs index 7e67eb09..3aace64b 100644 --- a/src/types/evaluator/int64.rs +++ b/src/types/evaluator/int64.rs @@ -25,9 +25,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_int64_cast_evaluators() { diff --git a/src/types/evaluator/int8.rs b/src/types/evaluator/int8.rs index 9482c989..adb29aa9 100644 --- a/src/types/evaluator/int8.rs +++ b/src/types/evaluator/int8.rs @@ -25,9 +25,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_int8_cast_evaluators() { diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs index 5544a142..1583df11 100644 --- a/src/types/evaluator/mod.rs +++ b/src/types/evaluator/mod.rs @@ -40,72 +40,98 @@ pub use self::cast::cast_create; pub use self::unary::unary_create; use crate::errors::DatabaseError; +use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::value::DataValue; -use serde::{Deserialize, Serialize}; +use crate::types::LogicalType; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -#[typetag::serde(tag = "binary")] pub trait BinaryEvaluator: Send + Sync + Debug { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result; } -#[typetag::serde(tag = "unary")] pub trait UnaryEvaluator: Send + Sync + Debug { fn unary_eval(&self, value: &DataValue) -> DataValue; } -#[typetag::serde(tag = "cast")] pub trait CastEvaluator: Send + Sync + Debug { fn eval_cast(&self, value: &DataValue) -> Result; } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BinaryEvaluatorBox(pub Arc); +#[derive(Clone, Debug)] +pub struct BinaryEvaluatorBox { + pub evaluator: Arc, + pub ty: LogicalType, + pub op: BinaryOperator, +} impl Deref for BinaryEvaluatorBox { - type Target = Arc; + type Target = dyn BinaryEvaluator; fn deref(&self) -> &Self::Target { - &self.0 + self.evaluator.as_ref() } } impl BinaryEvaluatorBox { + pub fn new(evaluator: Arc, ty: LogicalType, op: BinaryOperator) -> Self { + Self { evaluator, ty, op } + } + pub fn binary_eval( &self, left: &DataValue, right: &DataValue, ) -> Result { - self.0.binary_eval(left, right) + self.evaluator.binary_eval(left, right) } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct UnaryEvaluatorBox(pub Arc); +#[derive(Clone, Debug)] +pub struct UnaryEvaluatorBox { + pub evaluator: Arc, + pub ty: LogicalType, + pub op: UnaryOperator, +} impl UnaryEvaluatorBox { + pub fn new(evaluator: Arc, ty: LogicalType, op: UnaryOperator) -> Self { + Self { evaluator, ty, op } + } + pub fn unary_eval(&self, value: &DataValue) -> DataValue { - self.0.unary_eval(value) + self.evaluator.unary_eval(value) } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct CastEvaluatorBox(pub Arc); +#[derive(Clone, Debug)] +pub struct CastEvaluatorBox { + pub evaluator: Arc, + pub from: LogicalType, + pub to: LogicalType, +} impl Deref for CastEvaluatorBox { - type Target = Arc; + type Target = dyn CastEvaluator; fn deref(&self) -> &Self::Target { - &self.0 + self.evaluator.as_ref() } } impl CastEvaluatorBox { + pub fn new(evaluator: Arc, from: LogicalType, to: LogicalType) -> Self { + Self { + evaluator, + from, + to, + } + } + pub fn eval_cast(&self, value: &DataValue) -> Result { - self.0.eval_cast(value) + self.evaluator.eval_cast(value) } } diff --git a/src/types/evaluator/null.rs b/src/types/evaluator/null.rs index 670c1b22..563e74cf 100644 --- a/src/types/evaluator/null.rs +++ b/src/types/evaluator/null.rs @@ -21,8 +21,6 @@ use serde::{Deserialize, Serialize}; /// - Null values operate as null values #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct NullBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for NullBinaryEvaluator { fn binary_eval(&self, _: &DataValue, _: &DataValue) -> Result { Ok(DataValue::Null) @@ -31,8 +29,6 @@ impl BinaryEvaluator for NullBinaryEvaluator { #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct ToSqlNullCastEvaluator; - -#[typetag::serde] impl CastEvaluator for ToSqlNullCastEvaluator { fn eval_cast(&self, _value: &DataValue) -> Result { Ok(DataValue::Null) @@ -41,8 +37,6 @@ impl CastEvaluator for ToSqlNullCastEvaluator { #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct NullCastEvaluator; - -#[typetag::serde] impl CastEvaluator for NullCastEvaluator { fn eval_cast(&self, _value: &DataValue) -> Result { Ok(DataValue::Null) diff --git a/src/types/evaluator/time32.rs b/src/types/evaluator/time32.rs index 18277430..574f694b 100644 --- a/src/types/evaluator/time32.rs +++ b/src/types/evaluator/time32.rs @@ -17,9 +17,9 @@ use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::value::{ONE_DAY_TO_SEC, ONE_SEC_TO_NANO}; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -38,8 +38,6 @@ pub struct TimeLtEqBinaryEvaluator; pub struct TimeEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct TimeNotEqBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for TimePlusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -64,7 +62,6 @@ impl BinaryEvaluator for TimePlusBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TimeMinusBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -88,8 +85,6 @@ impl BinaryEvaluator for TimeMinusBinaryEvaluator { }) } } - -#[typetag::serde] impl BinaryEvaluator for TimeGtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -105,7 +100,6 @@ impl BinaryEvaluator for TimeGtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TimeGtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -162,8 +156,6 @@ crate::define_cast_evaluator!( Ok(DataValue::Time32(*value, this.precision.unwrap_or(0))) } ); - -#[typetag::serde] impl BinaryEvaluator for TimeLtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -179,7 +171,6 @@ impl BinaryEvaluator for TimeLtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TimeLtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -195,7 +186,6 @@ impl BinaryEvaluator for TimeLtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TimeEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -211,7 +201,6 @@ impl BinaryEvaluator for TimeEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TimeNotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { diff --git a/src/types/evaluator/time64.rs b/src/types/evaluator/time64.rs index 5e7de3c4..cafdfa0a 100644 --- a/src/types/evaluator/time64.rs +++ b/src/types/evaluator/time64.rs @@ -16,10 +16,10 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::{Datelike, Timelike}; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -34,8 +34,6 @@ pub struct Time64LtEqBinaryEvaluator; pub struct Time64EqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Time64NotEqBinaryEvaluator; - -#[typetag::serde] impl BinaryEvaluator for Time64GtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -60,7 +58,6 @@ impl BinaryEvaluator for Time64GtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Time64GtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -178,8 +175,6 @@ crate::define_cast_evaluator!( Ok(DataValue::Time64(*value, this.precision.unwrap_or(0), this.zone)) } ); - -#[typetag::serde] impl BinaryEvaluator for Time64LtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -204,7 +199,6 @@ impl BinaryEvaluator for Time64LtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Time64LtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -229,7 +223,6 @@ impl BinaryEvaluator for Time64LtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Time64EqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -254,7 +247,6 @@ impl BinaryEvaluator for Time64EqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Time64NotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -285,7 +277,7 @@ mod test { use super::*; use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; - use sqlparser::ast::CharLengthUnits; + use crate::types::CharLengthUnits; #[test] fn test_time64_binary_evaluators() { diff --git a/src/types/evaluator/tuple.rs b/src/types/evaluator/tuple.rs index f5145496..bcbcdfe1 100644 --- a/src/types/evaluator/tuple.rs +++ b/src/types/evaluator/tuple.rs @@ -62,8 +62,6 @@ fn tuple_cmp( } Some(order) } - -#[typetag::serde] impl BinaryEvaluator for TupleEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -75,7 +73,6 @@ impl BinaryEvaluator for TupleEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TupleNotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -88,12 +85,11 @@ impl BinaryEvaluator for TupleNotEqBinaryEvaluator { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct TupleCastEvaluator { pub element_evaluators: Vec, } -#[typetag::serde] impl CastEvaluator for TupleCastEvaluator { fn eval_cast(&self, value: &DataValue) -> Result { match value { @@ -111,8 +107,6 @@ impl CastEvaluator for TupleCastEvaluator { } } } - -#[typetag::serde] impl BinaryEvaluator for TupleGtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -128,7 +122,6 @@ impl BinaryEvaluator for TupleGtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TupleGtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -144,7 +137,6 @@ impl BinaryEvaluator for TupleGtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TupleLtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -160,7 +152,6 @@ impl BinaryEvaluator for TupleLtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for TupleLtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -181,8 +172,8 @@ impl BinaryEvaluator for TupleLtEqBinaryEvaluator { mod test { use super::*; use crate::types::evaluator::cast_create; + use crate::types::CharLengthUnits; use crate::types::LogicalType; - use sqlparser::ast::CharLengthUnits; use std::borrow::Cow; #[test] diff --git a/src/types/evaluator/uint16.rs b/src/types/evaluator/uint16.rs index ff6789da..332013e2 100644 --- a/src/types/evaluator/uint16.rs +++ b/src/types/evaluator/uint16.rs @@ -24,9 +24,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_uint16_cast_evaluators() { diff --git a/src/types/evaluator/uint32.rs b/src/types/evaluator/uint32.rs index 058570ba..9bc08089 100644 --- a/src/types/evaluator/uint32.rs +++ b/src/types/evaluator/uint32.rs @@ -24,9 +24,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_uint32_cast_evaluators() { diff --git a/src/types/evaluator/uint64.rs b/src/types/evaluator/uint64.rs index 2bfbe787..5b2a64d3 100644 --- a/src/types/evaluator/uint64.rs +++ b/src/types/evaluator/uint64.rs @@ -24,9 +24,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_uint64_cast_evaluators() { diff --git a/src/types/evaluator/uint8.rs b/src/types/evaluator/uint8.rs index 135a5a18..355278f9 100644 --- a/src/types/evaluator/uint8.rs +++ b/src/types/evaluator/uint8.rs @@ -24,9 +24,9 @@ mod test { use super::*; use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; + use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; #[test] fn test_uint8_cast_evaluators() { diff --git a/src/types/evaluator/unary.rs b/src/types/evaluator/unary.rs index d201cd64..b4645558 100644 --- a/src/types/evaluator/unary.rs +++ b/src/types/evaluator/unary.rs @@ -27,12 +27,22 @@ use paste::paste; use std::borrow::Cow; use std::sync::Arc; +macro_rules! box_unary { + ($ty:expr, $op:expr, $evaluator:expr) => { + Ok(UnaryEvaluatorBox::new( + Arc::new($evaluator), + $ty.clone(), + $op, + )) + }; +} + macro_rules! numeric_unary_evaluator { ($value_type:ident, $op:expr, $ty:expr) => { paste! { match $op { - UnaryOperator::Plus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type PlusUnaryEvaluator>]))), - UnaryOperator::Minus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type MinusUnaryEvaluator>]))), + UnaryOperator::Plus => box_unary!($ty, $op, [<$value_type PlusUnaryEvaluator>]), + UnaryOperator::Minus => box_unary!($ty, $op, [<$value_type MinusUnaryEvaluator>]), _ => Err(DatabaseError::UnsupportedUnaryOperator($ty.clone(), $op)), } } @@ -50,7 +60,7 @@ pub fn unary_create( LogicalType::Integer => numeric_unary_evaluator!(Int32, op, ty), LogicalType::Bigint => numeric_unary_evaluator!(Int64, op, ty), LogicalType::Boolean => match op { - UnaryOperator::Not => Ok(UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator))), + UnaryOperator::Not => box_unary!(ty, op, BooleanNotUnaryEvaluator), _ => Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), }, LogicalType::Float => numeric_unary_evaluator!(Float32, op, ty), @@ -66,16 +76,11 @@ macro_rules! numeric_unary_evaluator_definition { #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] pub struct [<$value_type PlusUnaryEvaluator>]; #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] - pub struct [<$value_type MinusUnaryEvaluator>]; - - #[typetag::serde] - impl $crate::types::evaluator::UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { + pub struct [<$value_type MinusUnaryEvaluator>]; impl $crate::types::evaluator::UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { value.clone() } - } - #[typetag::serde] - impl $crate::types::evaluator::UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { + } impl $crate::types::evaluator::UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { match value { $compute_type(value) => $compute_type(-value), diff --git a/src/types/evaluator/utf8.rs b/src/types/evaluator/utf8.rs index ad487d2f..78f4dbd1 100644 --- a/src/types/evaluator/utf8.rs +++ b/src/types/evaluator/utf8.rs @@ -17,13 +17,13 @@ use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::value::Utf8Type; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; use ordered_float::OrderedFloat; use regex::Regex; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use sqlparser::ast::CharLengthUnits; use std::hint; use std::str::FromStr; @@ -49,8 +49,6 @@ pub struct Utf8LikeBinaryEvaluator { pub struct Utf8NotLikeBinaryEvaluator { pub(crate) escape_char: Option, } - -#[typetag::serde] impl BinaryEvaluator for Utf8GtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -64,7 +62,6 @@ impl BinaryEvaluator for Utf8GtBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8GtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -78,7 +75,6 @@ impl BinaryEvaluator for Utf8GtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8LtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -241,8 +237,6 @@ crate::define_cast_evaluator!( crate::define_cast_evaluator!(Utf8ToDecimalCastEvaluator, DataValue::Utf8 { value, .. } => { Ok(DataValue::Decimal(Decimal::from_str(value)?)) }); - -#[typetag::serde] impl BinaryEvaluator for Utf8LtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -256,7 +250,6 @@ impl BinaryEvaluator for Utf8LtEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8EqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -270,7 +263,6 @@ impl BinaryEvaluator for Utf8EqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8NotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -284,7 +276,6 @@ impl BinaryEvaluator for Utf8NotEqBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8StringConcatBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -302,7 +293,6 @@ impl BinaryEvaluator for Utf8StringConcatBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8LikeBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { @@ -316,7 +306,6 @@ impl BinaryEvaluator for Utf8LikeBinaryEvaluator { }) } } -#[typetag::serde] impl BinaryEvaluator for Utf8NotLikeBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { Ok(match (left, right) { diff --git a/src/types/mod.rs b/src/types/mod.rs index 6f1681f9..bb88ad78 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -28,11 +28,35 @@ use std::cmp; use crate::errors::DatabaseError; use kite_sql_serde_macros::ReferenceSerialization; -use sqlparser::ast::{CharLengthUnits, ExactNumberInfo, TimezoneInfo}; +use sqlparser::ast::{ExactNumberInfo, TimezoneInfo}; use ulid::Ulid; pub type ColumnId = Ulid; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum CharLengthUnits { + Characters, + Octets, +} + +impl From for CharLengthUnits { + fn from(value: sqlparser::ast::CharLengthUnits) -> Self { + match value { + sqlparser::ast::CharLengthUnits::Characters => Self::Characters, + sqlparser::ast::CharLengthUnits::Octets => Self::Octets, + } + } +} + +impl std::fmt::Display for CharLengthUnits { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Characters => write!(f, "CHARACTERS"), + Self::Octets => write!(f, "OCTETS"), + } + } +} + /// Sqlrs type conversion: /// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType #[derive( @@ -425,7 +449,9 @@ impl TryFrom for LogicalType { } Ok(LogicalType::Char( len as u32, - char_unit.unwrap_or(CharLengthUnits::Characters), + char_unit + .map(Into::into) + .unwrap_or(CharLengthUnits::Characters), )) } sqlparser::ast::DataType::CharVarying(varchar_len) @@ -448,7 +474,9 @@ impl TryFrom for LogicalType { } Ok(LogicalType::Varchar( len, - char_unit.unwrap_or(CharLengthUnits::Characters), + char_unit + .map(Into::into) + .unwrap_or(CharLengthUnits::Characters), )) } sqlparser::ast::DataType::String(_) | sqlparser::ast::DataType::Text => { @@ -597,8 +625,8 @@ pub(crate) mod test { use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; + use crate::types::CharLengthUnits; use crate::types::LogicalType; - use sqlparser::ast::CharLengthUnits; use std::io::{Cursor, Seek, SeekFrom}; #[test] diff --git a/src/types/serialize.rs b/src/types/serialize.rs index 081a05dd..dff46585 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -14,13 +14,13 @@ use crate::errors::DatabaseError; use crate::types::value::{DataValue, Utf8Type}; +use crate::types::CharLengthUnits; use crate::types::LogicalType; use bumpalo::collections::Vec; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use kite_sql_serde_macros::ReferenceSerialization; use ordered_float::OrderedFloat; use rust_decimal::Decimal; -use sqlparser::ast::CharLengthUnits; use std::fmt::Debug; use std::io::{Cursor, Read, Seek, SeekFrom, Write}; diff --git a/src/types/tuple.rs b/src/types/tuple.rs index bee845aa..ae2822d5 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -250,12 +250,12 @@ mod tests { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, Utf8Type}; + use crate::types::CharLengthUnits; use crate::types::LogicalType; use bumpalo::Bump; use itertools::Itertools; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; use std::sync::Arc; #[test] diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index 83c1de4d..f25d03c2 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -16,7 +16,7 @@ use crate::catalog::PrimaryKeyIndices; use crate::errors::DatabaseError; use crate::types::tuple::{Schema, Tuple}; use crate::types::value::{DataValue, Utf8Type}; -use sqlparser::ast::CharLengthUnits; +use crate::types::CharLengthUnits; pub struct TupleBuilder<'a> { schema: &'a Schema, diff --git a/src/types/value.rs b/src/types/value.rs index dd945f59..9c3d613a 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -16,13 +16,13 @@ use super::LogicalType; use crate::errors::DatabaseError; use crate::storage::table_codec::{BumpBytes, BOUND_MAX_TAG, NOTNULL_TAG, NULL_TAG}; use crate::types::evaluator::cast_create; +use crate::types::CharLengthUnits; use byteorder::ReadBytesExt; use chrono::format::{DelayedFormat, StrftimeItems}; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use itertools::Itertools; use ordered_float::OrderedFloat; use rust_decimal::Decimal; -use sqlparser::ast::CharLengthUnits; use std::borrow::Cow; use std::cmp::Ordering; use std::fmt::Formatter; @@ -1464,11 +1464,11 @@ mod test { use crate::errors::DatabaseError; use crate::storage::table_codec::BumpBytes; use crate::types::value::{DataValue, TupleMappingRef, Utf8Type}; + use crate::types::CharLengthUnits; use crate::types::LogicalType; use bumpalo::Bump; use ordered_float::OrderedFloat; use rust_decimal::Decimal; - use sqlparser::ast::CharLengthUnits; use std::io::Cursor; #[test] diff --git a/tests/macros-test/Cargo.toml b/tests/macros-test/Cargo.toml index 3122cab6..036e176e 100644 --- a/tests/macros-test/Cargo.toml +++ b/tests/macros-test/Cargo.toml @@ -8,6 +8,5 @@ edition = "2021" lazy_static = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } rust_decimal = { version = "1" } -sqlparser = { version = "0.61", features = ["serde"] } +sqlparser = { version = "0.61", default-features = false, features = ["std"] } tempfile = { version = "3.10" } -typetag = { version = "0.2" } diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index a0eed830..141eb921 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -29,10 +29,10 @@ mod test { use kite_sql::types::evaluator::binary_create; use kite_sql::types::tuple::{SchemaRef, Tuple}; use kite_sql::types::value::{DataValue, Utf8Type}; - use kite_sql::types::LogicalType; + use kite_sql::types::{CharLengthUnits, LogicalType}; use kite_sql::{from_tuple, scala_function, table_function, Model, Projection}; use rust_decimal::Decimal; - use sqlparser::ast::{CharLengthUnits, DataType as SqlDataType}; + use sqlparser::ast::DataType as SqlDataType; use std::sync::Arc; use tempfile::TempDir; diff --git a/tpcc/Cargo.toml b/tpcc/Cargo.toml index ea399911..4dd6e9e7 100644 --- a/tpcc/Cargo.toml +++ b/tpcc/Cargo.toml @@ -16,7 +16,6 @@ rand = { version = "0.8" } rust_decimal = { version = "1" } thiserror = { version = "1" } sqlite = { version = "0.34" } -sqlparser = { version = "0.61" } [target.'cfg(unix)'.dependencies] pprof = { version = "0.15", features = ["flamegraph"], optional = true } diff --git a/tpcc/src/backend/sqlite.rs b/tpcc/src/backend/sqlite.rs index 89b1ae38..e1b05c78 100644 --- a/tpcc/src/backend/sqlite.rs +++ b/tpcc/src/backend/sqlite.rs @@ -21,9 +21,9 @@ use chrono::{NaiveDateTime, TimeZone, Utc}; use clap::ValueEnum; use kite_sql::types::tuple::Tuple; use kite_sql::types::value::{DataValue, Utf8Type}; +use kite_sql::types::CharLengthUnits; use rust_decimal::Decimal; use sqlite::{Connection, CursorWithOwnership, Row, Statement as SqliteStatement, Value}; -use sqlparser::ast::CharLengthUnits; pub struct SqliteBackend { connection: Connection,