diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs index 8e823555..b94cfd0b 100644 --- a/src/binder/alter_table.rs +++ b/src/binder/alter_table.rs @@ -14,6 +14,7 @@ use sqlparser::ast::{AlterColumnOperation, AlterTableOperation, ObjectName}; +use std::borrow::Cow; use std::sync::Arc; use super::{attach_span_if_absent, is_valid_identifier, Binder}; @@ -44,12 +45,7 @@ impl> Binder<'_, '_, T, A> "column is not allowed to exist in default".to_string(), )); } - if expr.return_type() != *ty { - expr = ScalarExpression::TypeCast { - expr: Box::new(expr), - ty: ty.clone(), - }; - } + expr = ScalarExpression::type_cast(expr, Cow::Borrowed(ty))?; Ok(expr) } diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 4179c345..8d5d6cd7 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -25,6 +25,7 @@ use crate::types::value::DataValue; use crate::types::LogicalType; use itertools::Itertools; use sqlparser::ast::{ColumnDef, ColumnOption, Expr, IndexColumn, ObjectName, TableConstraint}; +use std::borrow::Cow; use std::collections::HashSet; use std::sync::Arc; @@ -161,12 +162,10 @@ impl> Binder<'_, '_, T, A> "column is not allowed to exist in `default`".to_string(), )); } - if expr.return_type() != column_desc.column_datatype { - expr = ScalarExpression::TypeCast { - expr: Box::new(expr), - ty: column_desc.column_datatype.clone(), - } - } + expr = ScalarExpression::type_cast( + expr, + Cow::Borrowed(&column_desc.column_datatype), + )?; column_desc.default = Some(expr); } option => { diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 4d442a31..0e894165 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -21,6 +21,7 @@ use sqlparser::ast::{ BinaryOperator, CharLengthUnits, DataType, DuplicateTreatment, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value, }; +use std::borrow::Cow; use std::collections::HashMap; use std::slice; @@ -284,7 +285,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } => { let left_expr = self.bind_expr(expr)?; let (sub_query, column, correlated) = - self.bind_subquery(Some(left_expr.return_type()), subquery)?; + self.bind_subquery(Some(left_expr.return_type().as_ref()), subquery)?; if !self.context.is_step(&QueryBindStep::Where) { return Err(DatabaseError::UnsupportedStmt( @@ -354,7 +355,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let mut expr_pairs = Vec::with_capacity(conditions.len()); for when in conditions { let result = self.bind_expr(&when.result)?; - let result_ty = result.return_type(); + let result_ty = result.return_type().into_owned(); fn_check_ty(&mut ty, result_ty)?; expr_pairs.push((self.bind_expr(&when.condition)?, result)) @@ -363,7 +364,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let mut else_expr = None; if let Some(expr) = else_result { let temp_expr = Box::new(self.bind_expr(expr)?); - let else_ty = temp_expr.return_type(); + let else_ty = temp_expr.return_type().into_owned(); fn_check_ty(&mut ty, else_ty)?; else_expr = Some(temp_expr); @@ -426,7 +427,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T fn bind_subquery( &mut self, - in_ty: Option, + in_ty: Option<&LogicalType>, subquery: &Query, ) -> Result<(LogicalPlan, ScalarExpression, bool), DatabaseError> { let BinderContext { @@ -606,18 +607,19 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let left_expr = Box::new(self.bind_expr(left)?); let right_expr = Box::new(self.bind_expr(right)?); + let left_ty = left_expr.return_type(); + let right_ty = right_expr.return_type(); let ty = match op { BinaryOperator::Plus | BinaryOperator::Minus | BinaryOperator::Multiply | BinaryOperator::Modulo => { - LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())? + LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned() } BinaryOperator::Divide => { - if let LogicalType::Decimal(precision, scale) = LogicalType::max_logical_type( - &left_expr.return_type(), - &right_expr.return_type(), - )? { + if let LogicalType::Decimal(precision, scale) = + LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned() + { LogicalType::Decimal(precision, scale) } else { LogicalType::Double @@ -654,7 +656,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let ty = if let UnaryOperator::Not = op { LogicalType::Boolean } else { - expr.return_type() + expr.return_type().into_owned() }; Ok(ScalarExpression::Unary { @@ -714,7 +716,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of sum() parameters", "1")); } - let ty = args[0].return_type(); + let ty = args[0].return_type().into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -727,7 +729,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of min() parameters", "1")); } - let ty = args[0].return_type(); + let ty = args[0].return_type().into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -740,7 +742,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of max() parameters", "1")); } - let ty = args[0].return_type(); + let ty = args[0].return_type().into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -815,10 +817,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let mut ty = LogicalType::SqlNull; if !args.is_empty() { - ty = args[0].return_type(); + ty = args[0].return_type().into_owned(); for arg in args.iter_mut() { - let temp_ty = arg.return_type(); + let temp_ty = arg.return_type().into_owned(); if temp_ty == LogicalType::SqlNull { continue; @@ -826,7 +828,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if ty == LogicalType::SqlNull && temp_ty != LogicalType::SqlNull { ty = temp_ty; } else if ty != temp_ty { - ty = LogicalType::max_logical_type(&ty, &temp_ty)?; + ty = LogicalType::max_logical_type(&ty, &temp_ty)?.into_owned(); } } } @@ -834,7 +836,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } _ => (), } - let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec(); + let arg_types = args + .iter() + .map(|arg| arg.return_type().into_owned()) + .collect_vec(); let summary = FunctionSummary { name: function_name.into(), arg_types, @@ -870,10 +875,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let temp_ty_1 = expr_1.return_type(); let temp_ty_2 = expr_2.return_type(); - match (temp_ty_1, temp_ty_2) { + match (temp_ty_1.as_ref(), temp_ty_2.as_ref()) { (LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull), - (ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty), - (ty_1, ty_2) => LogicalType::max_logical_type(&ty_1, &ty_2), + (ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty.clone()), + (ty_1, ty_2) => Ok(LogicalType::max_logical_type(ty_1, ty_2)?.into_owned()), } } @@ -904,10 +909,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } fn bind_cast(&mut self, expr: &Expr, ty: &DataType) -> Result { - Ok(ScalarExpression::TypeCast { - expr: Box::new(self.bind_expr(expr)?), - ty: LogicalType::try_from(ty.clone())?, - }) + ScalarExpression::type_cast( + self.bind_expr(expr)?, + Cow::Owned(LogicalType::try_from(ty.clone())?), + ) } fn wildcard_expr() -> ScalarExpression { diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 07096c79..fe23a716 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -96,9 +96,7 @@ impl> Binder<'_, '_, T, A> ScalarExpression::Constant(mut value) => { let ty = schema_ref[i].datatype(); - if &value.logical_type() != ty { - value = value.cast(ty)?; - } + value = value.cast(ty)?; // Check if the value length is too long value.check_len(ty)?; if value.is_null() && !schema_ref[i].nullable() { diff --git a/src/binder/select.rs b/src/binder/select.rs index e378861e..8f114c0c 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -560,7 +560,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' inferred_types[col_index] = match &inferred_types[col_index] { Some(existing) => { - Some(LogicalType::max_logical_type(existing, &value_type)?) + Some(LogicalType::max_logical_type(existing, &value_type)?.into_owned()) } None => Some(value_type), }; @@ -609,22 +609,19 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' { let cast_type = LogicalType::max_logical_type(left_schema.datatype(), right_schema.datatype())?; - if &cast_type != left_schema.datatype() { - left_cast.push(ScalarExpression::TypeCast { - expr: Box::new(ScalarExpression::column_expr(left_schema.clone(), position)), - ty: cast_type.clone(), - }); + if cast_type.as_ref() != left_schema.datatype() { + left_cast.push(ScalarExpression::type_cast( + ScalarExpression::column_expr(left_schema.clone(), position), + cast_type.clone(), + )?); } else { left_cast.push(ScalarExpression::column_expr(left_schema.clone(), position)); } - if &cast_type != right_schema.datatype() { - right_cast.push(ScalarExpression::TypeCast { - expr: Box::new(ScalarExpression::column_expr( - right_schema.clone(), - position, - )), - ty: cast_type.clone(), - }); + if cast_type.as_ref() != right_schema.datatype() { + right_cast.push(ScalarExpression::type_cast( + ScalarExpression::column_expr(right_schema.clone(), position), + cast_type.clone(), + )?); } else { right_cast.push(ScalarExpression::column_expr( right_schema.clone(), diff --git a/src/binder/update.rs b/src/binder/update.rs index 35245bb6..141349d0 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -25,6 +25,7 @@ use crate::types::value::DataValue; use sqlparser::ast::{ Assignment, AssignmentTarget, Expr, Ident, ObjectName, TableFactor, TableWithJoins, }; +use std::borrow::Cow; use std::slice; use std::sync::Arc; @@ -91,12 +92,10 @@ impl> Binder<'_, '_, T, A> } else { expression.clone() }; - if &expr.return_type() != column.datatype() { - expr = ScalarExpression::TypeCast { - expr: Box::new(expr), - ty: column.datatype().clone(), - } - } + expr = ScalarExpression::type_cast( + expr, + Cow::Borrowed(column.datatype()), + )?; value_exprs.push((column, expr)); } _ => { diff --git a/src/errors.rs b/src/errors.rs index 83ea114d..a1dadcaa 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -17,6 +17,7 @@ use crate::types::tuple::TupleId; use crate::types::LogicalType; use chrono::ParseError; use sqlparser::parser::ParserError; +use std::convert::Infallible; use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; use std::str::{ParseBoolError, Utf8Error}; use std::string::FromUtf8Error; @@ -268,6 +269,12 @@ pub enum DatabaseError { ViewNotFound, } +impl From for DatabaseError { + fn from(value: Infallible) -> Self { + match value {} + } +} + impl DatabaseError { pub fn invalid_column(name: impl Into) -> Self { Self::InvalidColumn { diff --git a/src/execution/ddl/create_index.rs b/src/execution/ddl/create_index.rs index 004de6f0..525583fb 100644 --- a/src/execution/ddl/create_index.rs +++ b/src/execution/ddl/create_index.rs @@ -107,11 +107,9 @@ impl CreateIndex { while arena.next_tuple(self.input)? { let (value, tuple_pk) = { let tuple = arena.result_tuple(); - let Some(value) = DataValue::values_to_tuple(Projection::projection( - tuple, - &column_exprs, - &self.input_schema, - )?) else { + let Some(value) = + DataValue::values_to_tuple(Projection::projection(tuple, &column_exprs)?) + else { continue; }; let Some(tuple_pk) = tuple.pk.clone() else { diff --git a/src/execution/dml/analyze.rs b/src/execution/dml/analyze.rs index 280dc8ba..6a6226e7 100644 --- a/src/execution/dml/analyze.rs +++ b/src/execution/dml/analyze.rs @@ -23,7 +23,6 @@ use crate::planner::operator::analyze::AnalyzeOperator; use crate::planner::LogicalPlan; use crate::storage::{StatisticsMetaCache, Transaction}; use crate::types::index::IndexId; -use crate::types::tuple::SchemaRef; use crate::types::value::{DataValue, Utf8Type}; use itertools::Itertools; use sqlparser::ast::CharLengthUnits; @@ -34,7 +33,6 @@ const DEFAULT_NUM_OF_BUCKETS: usize = 100; pub struct Analyze { table_name: TableName, - input_schema: SchemaRef, input_plan: LogicalPlan, input: Option, histogram_buckets: Option, @@ -48,13 +46,12 @@ impl From<(AnalyzeOperator, LogicalPlan)> for Analyze { index_metas, histogram_buckets, }, - mut input, + input, ): (AnalyzeOperator, LogicalPlan), ) -> Self { let _ = index_metas; Analyze { table_name, - input_schema: input.output_schema().clone(), input_plan: input, input: None, histogram_buckets, @@ -108,7 +105,7 @@ impl Analyze { while arena.next_tuple(input)? { let tuple = arena.result_tuple(); for State { exprs, builder, .. } in builders.iter_mut() { - let values = Projection::projection(tuple, exprs, &self.input_schema)?; + let values = Projection::projection(tuple, exprs)?; if values.len() == 1 { builder.append(&values[0])?; diff --git a/src/execution/dml/delete.rs b/src/execution/dml/delete.rs index c0f31059..b925730b 100644 --- a/src/execution/dml/delete.rs +++ b/src/execution/dml/delete.rs @@ -21,23 +21,20 @@ use crate::planner::operator::delete::DeleteOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::index::{Index, IndexId, IndexType}; -use crate::types::tuple::SchemaRef; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; use std::collections::HashMap; pub struct Delete { table_name: TableName, - input_schema: SchemaRef, input_plan: LogicalPlan, input: Option, } impl From<(DeleteOperator, LogicalPlan)> for Delete { - fn from((DeleteOperator { table_name, .. }, mut input): (DeleteOperator, LogicalPlan)) -> Self { + fn from((DeleteOperator { table_name, .. }, input): (DeleteOperator, LogicalPlan)) -> Self { Delete { table_name, - input_schema: input.output_schema().clone(), input_plan: input, input: None, } @@ -83,22 +80,18 @@ impl Delete { let tuple = arena.result_tuple().clone(); for index_meta in table.indexes() { if let Some(Value { exprs, values, .. }) = indexes.get_mut(&index_meta.id) { - let Some(data_value) = DataValue::values_to_tuple(Projection::projection( - &tuple, - exprs, - &self.input_schema, - )?) else { + let Some(data_value) = + DataValue::values_to_tuple(Projection::projection(&tuple, exprs)?) + else { continue; }; values.push(data_value); } else { let mut values = Vec::with_capacity(table.indexes().len()); let exprs = index_meta.column_exprs(table)?; - let Some(data_value) = DataValue::values_to_tuple(Projection::projection( - &tuple, - &exprs, - &self.input_schema, - )?) else { + let Some(data_value) = + DataValue::values_to_tuple(Projection::projection(&tuple, &exprs)?) + else { continue; }; values.push(data_value); diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs index ff4a9036..20a4d23c 100644 --- a/src/execution/dml/insert.rs +++ b/src/execution/dml/insert.rs @@ -111,7 +111,6 @@ impl Insert { return Err(DatabaseError::not_null()); } - let table_schema = table_catalog.schema_ref(); let mut index_metas = Vec::new(); for index_meta in table_catalog.indexes() { let exprs = index_meta.column_exprs(&table_catalog)?; @@ -143,9 +142,7 @@ impl Insert { } value.unwrap_or(DataValue::Null) }; - if !value.is_null() && &value.logical_type() != col.datatype() { - value = value.cast(col.datatype())?; - } + value = value.cast(col.datatype())?; value.check_len(col.datatype())?; if value.is_null() && !col.nullable() { return Err(DatabaseError::not_null_column(col.name().to_string())); @@ -156,7 +153,7 @@ impl Insert { let tuple = Tuple::new(Some(pk), values); for (index_meta, exprs) in index_metas.iter() { - let values = Projection::projection(&tuple, exprs, table_schema.as_slice())?; + let values = Projection::projection(&tuple, exprs)?; let Some(value) = DataValue::values_to_tuple(values) else { continue; }; diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index d3d4b4b4..9eba12fd 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -112,7 +112,7 @@ impl Update { let old_pk = tuple.pk.clone().ok_or(DatabaseError::PrimaryKeyNotFound)?; for (index_meta, exprs) in index_metas.iter() { - let values = Projection::projection(&tuple, exprs, &self.input_schema)?; + let values = Projection::projection(&tuple, exprs)?; let Some(value) = DataValue::values_to_tuple(values) else { continue; }; @@ -141,7 +141,7 @@ impl Update { is_overwrite = false; } for (index_meta, exprs) in index_metas.iter() { - let values = Projection::projection(&tuple, exprs, &self.input_schema)?; + let values = Projection::projection(&tuple, exprs)?; let Some(value) = DataValue::values_to_tuple(values) else { continue; }; diff --git a/src/execution/dql/aggregate/avg.rs b/src/execution/dql/aggregate/avg.rs index db10a202..1457b341 100644 --- a/src/execution/dql/aggregate/avg.rs +++ b/src/execution/dql/aggregate/avg.rs @@ -16,8 +16,9 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::sum::SumAccumulator; use crate::execution::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; -use crate::types::evaluator::EvaluatorFactory; +use crate::types::evaluator::binary_create; use crate::types::value::DataValue; +use std::borrow::Cow; pub struct AvgAccumulator { inner: Option, @@ -40,7 +41,7 @@ impl Accumulator for AvgAccumulator { inner } else { self.inner - .get_or_insert(SumAccumulator::new(&value.logical_type())?) + .get_or_insert(SumAccumulator::new(Cow::Owned(value.logical_type()))?) }; acc.update_value(value)?; self.count += 1; @@ -69,7 +70,7 @@ impl Accumulator for AvgAccumulator { if value_ty != quantity_ty { value = value.cast(&quantity_ty)? } - let evaluator = EvaluatorFactory::binary_create(quantity_ty, BinaryOperator::Divide)?; + let evaluator = binary_create(Cow::Owned(quantity_ty), BinaryOperator::Divide)?; evaluator.0.binary_eval(&value, &quantity) } } diff --git a/src/execution/dql/aggregate/min_max.rs b/src/execution/dql/aggregate/min_max.rs index d2436570..163594d7 100644 --- a/src/execution/dql/aggregate/min_max.rs +++ b/src/execution/dql/aggregate/min_max.rs @@ -15,8 +15,9 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; -use crate::types::evaluator::EvaluatorFactory; +use crate::types::evaluator::binary_create; use crate::types::value::DataValue; +use std::borrow::Cow; pub struct MinMaxAccumulator { inner: Option, @@ -39,7 +40,7 @@ impl Accumulator for MinMaxAccumulator { fn update_value(&mut self, value: &DataValue) -> Result<(), DatabaseError> { if !value.is_null() { if let Some(inner_value) = &self.inner { - let evaluator = EvaluatorFactory::binary_create(value.logical_type(), self.op)?; + let evaluator = binary_create(Cow::Owned(value.logical_type()), self.op)?; if let DataValue::Boolean(result) = evaluator.0.binary_eval(inner_value, value)? { result } else { diff --git a/src/execution/dql/aggregate/mod.rs b/src/execution/dql/aggregate/mod.rs index 4bbba9e6..78d6a409 100644 --- a/src/execution/dql/aggregate/mod.rs +++ b/src/execution/dql/aggregate/mod.rs @@ -29,6 +29,7 @@ use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; use crate::types::value::DataValue; use itertools::Itertools; +use std::borrow::Cow; /// Tips: Idea for sqlrs /// An accumulator represents a stateful object that lives throughout the evaluation of multiple @@ -49,7 +50,7 @@ fn create_accumulator(expr: &ScalarExpression) -> Result, D Ok(match (kind, distinct) { (AggKind::Count, false) => Box::new(CountAccumulator::new()), (AggKind::Count, true) => Box::new(DistinctCountAccumulator::new()), - (AggKind::Sum, false) => Box::new(SumAccumulator::new(ty)?), + (AggKind::Sum, false) => Box::new(SumAccumulator::new(Cow::Borrowed(ty))?), (AggKind::Sum, true) => Box::new(DistinctSumAccumulator::new(ty)?), (AggKind::Min, _) => Box::new(MinMaxAccumulator::new(false)), (AggKind::Max, _) => Box::new(MinMaxAccumulator::new(true)), diff --git a/src/execution/dql/aggregate/sum.rs b/src/execution/dql/aggregate/sum.rs index 5d249b47..39cfe55b 100644 --- a/src/execution/dql/aggregate/sum.rs +++ b/src/execution/dql/aggregate/sum.rs @@ -15,10 +15,11 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; -use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; +use crate::types::evaluator::{binary_create, BinaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; use ahash::RandomState; +use std::borrow::Cow; use std::collections::HashSet; pub struct SumAccumulator { @@ -27,12 +28,12 @@ pub struct SumAccumulator { } impl SumAccumulator { - pub fn new(ty: &LogicalType) -> Result { + pub fn new(ty: Cow<'_, LogicalType>) -> Result { debug_assert!(ty.is_numeric()); Ok(Self { result: DataValue::Null, - evaluator: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Plus)?, + evaluator: binary_create(ty, BinaryOperator::Plus)?, }) } } @@ -64,7 +65,7 @@ impl DistinctSumAccumulator { pub fn new(ty: &LogicalType) -> Result { Ok(Self { distinct_values: HashSet::default(), - inner: SumAccumulator::new(ty)?, + inner: SumAccumulator::new(Cow::Borrowed(ty))?, }) } } diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index cab57fa8..087a94e5 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -16,7 +16,6 @@ //! [`JoinType::RightOuter`], [`JoinType::Cross`], [`JoinType::Full`]. use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; use crate::execution::{ build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor, }; @@ -24,18 +23,17 @@ use crate::expression::ScalarExpression; use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::tuple::{Schema, SchemaRef, SplitTupleRef, Tuple}; +use crate::types::tuple::{Schema, SplitTupleRef, Tuple}; use crate::types::value::DataValue; use fixedbitset::FixedBitSet; use itertools::Itertools; -use std::sync::Arc; /// Equivalent condition struct EqualCondition { on_left_keys: Vec, on_right_keys: Vec, - left_schema: SchemaRef, - right_schema: SchemaRef, + left_len: usize, + right_len: usize, } impl EqualCondition { @@ -45,8 +43,8 @@ impl EqualCondition { fn new( on_left_keys: Vec, on_right_keys: Vec, - left_schema: Arc, - right_schema: Arc, + left_schema: &Schema, + right_schema: &Schema, ) -> EqualCondition { if !on_left_keys.is_empty() && on_left_keys.len() != on_right_keys.len() { unreachable!("Unexpected join on condition.") @@ -54,8 +52,8 @@ impl EqualCondition { EqualCondition { on_left_keys, on_right_keys, - left_schema, - right_schema, + left_len: left_schema.len(), + right_len: right_schema.len(), } } @@ -66,12 +64,14 @@ impl EqualCondition { if self.on_left_keys.is_empty() { return Ok(true); } - let left_values = - Projection::projection(left_tuple, &self.on_left_keys, &self.left_schema)?; - let right_values = - Projection::projection(right_tuple, &self.on_right_keys, &self.right_schema)?; - Ok(left_values == right_values) + for (left_expr, right_expr) in self.on_left_keys.iter().zip(self.on_right_keys.iter()) { + if left_expr.eval(Some(left_tuple))? != right_expr.eval(Some(right_tuple))? { + return Ok(false); + } + } + + Ok(true) } } @@ -141,12 +141,7 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { std::mem::swap(&mut left_schema, &mut right_schema); } - let eq_cond = EqualCondition::new( - on_left_keys, - on_right_keys, - left_schema.clone(), - right_schema.clone(), - ); + let eq_cond = EqualCondition::new(on_left_keys, on_right_keys, &left_schema, &right_schema); NestedLoopJoin { left_input_plan: left_input, @@ -330,7 +325,7 @@ impl NestedLoopJoin { right_bitmap = Some(bits); } } - let right_schema_len = self.eq_cond.right_schema.len(); + let right_schema_len = self.eq_cond.right_len; let tuple = match self.ty { JoinType::LeftOuter | JoinType::RightOuter | JoinType::Full if !active_left.has_matched => @@ -374,7 +369,7 @@ impl NestedLoopJoin { right_emit_index += 1; if !right_bitmap.contains(idx) { - let mut values = vec![DataValue::Null; self.eq_cond.left_schema.len()]; + let mut values = vec![DataValue::Null; self.eq_cond.left_len]; values.append(&mut right_tuple.values); self.state = NestedLoopJoinState::EmitRightUnmatched { right_input, @@ -466,6 +461,7 @@ mod test { use crate::utils::lru::SharedLruCache; use std::collections::HashSet; use std::hash::RandomState; + use std::sync::Arc; use tempfile::TempDir; fn optimize_exprs(plan: LogicalPlan) -> Result { diff --git a/src/execution/dql/mark_apply.rs b/src/execution/dql/mark_apply.rs index d1dd9c4f..6d54514e 100644 --- a/src/execution/dql/mark_apply.rs +++ b/src/execution/dql/mark_apply.rs @@ -275,11 +275,12 @@ mod tests { use crate::planner::{Childrens, LogicalPlan}; use crate::storage::rocksdb::RocksStorage; use crate::storage::{StatisticsMetaCache, Storage, TableCache, ViewCache}; - use crate::types::evaluator::EvaluatorFactory; + use crate::types::evaluator::binary_create; use crate::types::index::RuntimeIndexProbe; use crate::types::tuple::Tuple; use crate::types::LogicalType; use crate::utils::lru::SharedLruCache; + use std::borrow::Cow; use std::hash::RandomState; use std::sync::Arc; use tempfile::TempDir; @@ -349,8 +350,8 @@ mod tests { op: BinaryOperator::Eq, left_expr: Box::new(ScalarExpression::column_expr(left_column, left_position)), right_expr: Box::new(ScalarExpression::column_expr(right_column, right_position)), - evaluator: Some(EvaluatorFactory::binary_create( - LogicalType::Integer, + evaluator: Some(binary_create( + Cow::Owned(LogicalType::Integer), BinaryOperator::Eq, )?), ty: LogicalType::Boolean, @@ -680,8 +681,8 @@ mod tests { op: BinaryOperator::Eq, left_expr: Box::new(ScalarExpression::column_expr(right_flag_column, 2)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), - evaluator: Some(EvaluatorFactory::binary_create( - LogicalType::Integer, + evaluator: Some(binary_create( + std::borrow::Cow::Owned(LogicalType::Integer), BinaryOperator::Eq, )?), ty: LogicalType::Boolean, diff --git a/src/execution/dql/projection.rs b/src/execution/dql/projection.rs index 80dda08d..4d44d3d3 100644 --- a/src/execution/dql/projection.rs +++ b/src/execution/dql/projection.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; use crate::expression::ScalarExpression; @@ -21,10 +20,10 @@ use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::tuple::Tuple; use crate::types::value::DataValue; + pub struct Projection { exprs: Vec, input: ExecId, - scratch: Tuple, } impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Projection { @@ -37,11 +36,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Projection { transaction: *mut T, ) -> ExecId { let input = build_read(arena, input, cache, transaction); - arena.push(ExecNode::Projection(Projection { - exprs, - input, - scratch: Tuple::default(), - })) + arena.push(ExecNode::Projection(Projection { exprs, input })) } fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { @@ -50,15 +45,14 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Projection { return Ok(()); } - std::mem::swap(&mut self.scratch, arena.result_tuple_mut()); - let tuple = &self.scratch; - let output = arena.result_tuple_mut(); - output.pk.clone_from(&tuple.pk); - output.values.clear(); - output.values.reserve(self.exprs.len()); - for expr in self.exprs.iter() { - output.values.push(expr.eval(Some(tuple))?); - } + arena.with_projection_tmp(|tuple, projection_tmp| { + projection_tmp.clear(); + projection_tmp.reserve(self.exprs.len()); + for expr in self.exprs.iter() { + projection_tmp.push(expr.eval(Some(tuple))?); + } + Ok::<_, DatabaseError>(()) + })?; arena.resume(); Ok(()) } @@ -68,7 +62,6 @@ impl Projection { pub fn projection( tuple: &Tuple, exprs: &[ScalarExpression], - _schema: &[ColumnRef], ) -> Result, DatabaseError> { let mut values = Vec::with_capacity(exprs.len()); diff --git a/src/execution/dql/values.rs b/src/execution/dql/values.rs index 492b3a22..bcb7b76e 100644 --- a/src/execution/dql/values.rs +++ b/src/execution/dql/values.rs @@ -75,9 +75,7 @@ impl Values { for (i, value) in values.iter_mut().enumerate() { let ty = self.schema_ref[i].datatype().clone(); - if value.logical_type() != ty { - *value = mem::replace(value, DataValue::Null).cast(&ty)?; - } + *value = mem::replace(value, DataValue::Null).cast(&ty)?; } let output = arena.result_tuple_mut(); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 8c0c5caa..385f823b 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -63,6 +63,7 @@ use crate::planner::LogicalPlan; use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::types::index::RuntimeIndexProbe; use crate::types::tuple::Tuple; +use crate::types::value::DataValue; pub(crate) type ExecutionCaches<'a> = (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache); pub(crate) type ExecId = usize; @@ -264,6 +265,7 @@ impl<'a, T: Transaction + 'a> ExecNode<'a, T> { pub(crate) struct ExecArena<'a, T: Transaction + 'a> { nodes: Vec>, result: ExecResult, + projection_tmp: Vec, cache: Option>, transaction: *mut T, runtime_probe_stack: Vec, @@ -274,6 +276,7 @@ impl<'a, T: Transaction + 'a> Default for ExecArena<'a, T> { Self { nodes: Vec::new(), result: ExecResult::default(), + projection_tmp: Vec::new(), cache: None, transaction: std::ptr::null_mut(), runtime_probe_stack: Vec::new(), @@ -344,6 +347,21 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { &mut self.result.tuple } + #[inline] + pub(crate) fn with_projection_tmp( + &mut self, + f: impl FnOnce(&Tuple, &mut Vec) -> Result, + ) -> Result { + let ExecArena { + result, + projection_tmp, + .. + } = self; + let ret = f(&result.tuple, projection_tmp)?; + std::mem::swap(&mut result.tuple.values, projection_tmp); + Ok(ret) + } + #[inline] pub(crate) fn resume(&mut self) { self.result.status = Some(ExecStatus::Continue); diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 8f0f8a33..62cae29c 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -15,12 +15,13 @@ use crate::errors::DatabaseError; use crate::expression::function::scala::ScalarFunction; use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; -use crate::types::evaluator::EvaluatorFactory; +use crate::types::evaluator::binary_create; use crate::types::tuple::TupleLike; use crate::types::value::{DataValue, Utf8Type}; use crate::types::LogicalType; use regex::Regex; use sqlparser::ast::{CharLengthUnits, TrimWhereField}; +use std::borrow::Cow; use std::cmp; use std::cmp::Ordering; @@ -36,13 +37,6 @@ macro_rules! eval_to_num { impl ScalarExpression { pub fn eval(&self, tuple: Option) -> Result { - let check_cast = |value: DataValue, return_type: &LogicalType| { - if value.logical_type() != *return_type { - return value.cast(return_type); - } - Ok(value) - }; - match self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef { position, .. } => { @@ -64,7 +58,16 @@ impl ScalarExpression { expr.eval(Some(tuple)) } } - ScalarExpression::TypeCast { expr, ty, .. } => Ok(expr.eval(tuple)?.cast(ty)?), + ScalarExpression::TypeCast { + expr, evaluator, .. + } => { + let value = expr.eval(tuple)?; + if let Some(evaluator) = evaluator { + evaluator.eval_cast(&value) + } else { + Ok(value) + } + } ScalarExpression::Binary { left_expr, right_expr, @@ -277,9 +280,9 @@ impl ScalarExpression { ty, } => { if condition.eval(tuple)?.is_true()? { - check_cast(left_expr.eval(tuple)?, ty) + left_expr.eval(tuple)?.cast(ty) } else { - check_cast(right_expr.eval(tuple)?, ty) + right_expr.eval(tuple)?.cast(ty) } } ScalarExpression::IfNull { @@ -292,7 +295,7 @@ impl ScalarExpression { if value.is_null() { value = right_expr.eval(tuple)?; } - check_cast(value, ty) + value.cast(ty) } ScalarExpression::NullIf { left_expr, @@ -304,7 +307,7 @@ impl ScalarExpression { if right_expr.eval(tuple)? == value { value = DataValue::Null; } - check_cast(value, ty) + value.cast(ty) } ScalarExpression::Coalesce { exprs, ty } => { let mut value = None; @@ -317,7 +320,7 @@ impl ScalarExpression { break; } } - check_cast(value.unwrap_or(DataValue::Null), ty) + value.unwrap_or(DataValue::Null).cast(ty) } ScalarExpression::CaseWhen { operand_expr, @@ -335,12 +338,8 @@ impl ScalarExpression { let mut when_value = when_expr.eval(tuple)?; let is_true = if let Some(operand_value) = &operand_value { let ty = operand_value.logical_type(); - let evaluator = - EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Eq)?; - - if when_value.logical_type() != ty { - when_value = when_value.cast(&ty)?; - } + when_value = when_value.cast(&ty)?; + let evaluator = binary_create(Cow::Owned(ty), BinaryOperator::Eq)?; evaluator .0 .binary_eval(operand_value, &when_value)? @@ -358,7 +357,7 @@ impl ScalarExpression { result = Some(expr.eval(tuple)?); } } - check_cast(result.unwrap_or(DataValue::Null), ty) + result.unwrap_or(DataValue::Null).cast(ty) } ScalarExpression::TableFunction(_) => unreachable!(), } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index a8a5a27e..a313fc43 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -19,7 +19,10 @@ use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::visitor::{walk_expr, Visitor}; use crate::expression::visitor_mut::VisitorMut; -use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory, UnaryEvaluatorBox}; +use crate::types::evaluator::{ + binary_create, cast_create, unary_create, BinaryEvaluatorBox, CastEvaluatorBox, + UnaryEvaluatorBox, +}; use crate::types::value::DataValue; use crate::types::LogicalType; use itertools::Itertools; @@ -28,6 +31,7 @@ use sqlparser::ast::TrimWhereField; use sqlparser::ast::{ BinaryOperator as SqlBinaryOperator, CharLengthUnits, UnaryOperator as SqlUnaryOperator, }; +use std::borrow::Cow; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::{fmt, mem}; @@ -64,6 +68,7 @@ pub enum ScalarExpression { TypeCast { expr: Box, ty: LogicalType, + evaluator: Option, }, IsNull { negated: bool, @@ -149,6 +154,23 @@ pub enum ScalarExpression { pub struct BindEvaluator; impl VisitorMut<'_> for BindEvaluator { + fn visit_type_cast( + &mut self, + expr: &'_ mut ScalarExpression, + ty: &'_ mut LogicalType, + evaluator: &'_ mut Option, + ) -> Result<(), DatabaseError> { + self.visit(expr)?; + let from = expr.return_type(); + *evaluator = if from.as_ref() == ty { + None + } else { + Some(cast_create(from, Cow::Borrowed(ty))?) + }; + + Ok(()) + } + fn visit_unary( &mut self, op: &'_ mut UnaryOperator, @@ -160,18 +182,19 @@ impl VisitorMut<'_> for BindEvaluator { let ty = expr.return_type(); if ty.is_unsigned_numeric() { - *expr = ScalarExpression::TypeCast { - expr: Box::new(mem::replace(expr, ScalarExpression::Empty)), - ty: match ty { - LogicalType::UTinyint => LogicalType::Tinyint, - LogicalType::USmallint => LogicalType::Smallint, - LogicalType::UInteger => LogicalType::Integer, - LogicalType::UBigint => LogicalType::Bigint, - _ => unreachable!(), - }, - } + let target_ty = match ty.as_ref() { + LogicalType::UTinyint => LogicalType::Tinyint, + LogicalType::USmallint => LogicalType::Smallint, + LogicalType::UInteger => LogicalType::Integer, + LogicalType::UBigint => LogicalType::Bigint, + _ => unreachable!(), + }; + *expr = ScalarExpression::type_cast( + mem::replace(expr, ScalarExpression::Empty), + Cow::Owned(target_ty), + )?; } - *evaluator = Some(EvaluatorFactory::unary_create(ty, *op)?); + *evaluator = Some(unary_create(expr.return_type(), *op)?); Ok(()) } @@ -187,20 +210,21 @@ impl VisitorMut<'_> for BindEvaluator { self.visit(left_expr)?; self.visit(right_expr)?; - let ty = - LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())?; - let fn_cast = |expr: &mut ScalarExpression, ty: LogicalType| { - if expr.return_type() != ty { - *expr = ScalarExpression::TypeCast { - expr: Box::new(mem::replace(expr, ScalarExpression::Empty)), - ty, - } - } - }; - fn_cast(left_expr, ty.clone()); - fn_cast(right_expr, ty.clone()); + let left_ty = left_expr.return_type().into_owned(); + let right_ty = right_expr.return_type().into_owned(); + let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?; + let fn_cast = + |expr: &mut ScalarExpression, ty: &LogicalType| -> Result<(), DatabaseError> { + *expr = ScalarExpression::type_cast( + mem::replace(expr, ScalarExpression::Empty), + Cow::Borrowed(ty), + )?; + Ok(()) + }; + fn_cast(left_expr, ty.as_ref())?; + fn_cast(right_expr, ty.as_ref())?; - *evaluator = Some(EvaluatorFactory::binary_create(ty, *op)?); + *evaluator = Some(binary_create(ty, *op)?); Ok(()) } @@ -240,6 +264,23 @@ impl ScalarExpression { ScalarExpression::ColumnRef { column, position } } + pub fn type_cast( + expr: ScalarExpression, + ty: Cow<'_, LogicalType>, + ) -> Result { + let from = expr.return_type(); + if from.as_ref() == ty.as_ref() { + return Ok(expr); + } + let evaluator = Some(cast_create(from, ty.clone())?); + + Ok(ScalarExpression::TypeCast { + expr: Box::new(expr), + ty: ty.into_owned(), + evaluator, + }) + } + pub(crate) fn eq_ignore_colref_pos(&self, other: &ScalarExpression) -> bool { match (self.unpack_alias_ref(), other.unpack_alias_ref()) { ( @@ -282,10 +323,10 @@ impl ScalarExpression { } } - pub fn return_type(&self) -> LogicalType { + pub fn return_type(&self) -> Cow<'_, LogicalType> { match self { - ScalarExpression::Constant(v) => v.logical_type(), - ScalarExpression::ColumnRef { column, .. } => column.datatype().clone(), + ScalarExpression::Constant(v) => Cow::Owned(v.logical_type()), + ScalarExpression::ColumnRef { column, .. } => Cow::Borrowed(column.datatype()), ScalarExpression::Binary { ty: return_type, .. } @@ -312,26 +353,29 @@ impl ScalarExpression { } | ScalarExpression::CaseWhen { ty: return_type, .. - } => return_type.clone(), + } => Cow::Borrowed(return_type), ScalarExpression::IsNull { .. } | ScalarExpression::In { .. } - | ScalarExpression::Between { .. } => LogicalType::Boolean, + | ScalarExpression::Between { .. } => Cow::Owned(LogicalType::Boolean), ScalarExpression::SubString { .. } => { - LogicalType::Varchar(None, CharLengthUnits::Characters) + Cow::Owned(LogicalType::Varchar(None, CharLengthUnits::Characters)) } - ScalarExpression::Position { .. } => LogicalType::Integer, + ScalarExpression::Position { .. } => Cow::Owned(LogicalType::Integer), ScalarExpression::Trim { .. } => { - LogicalType::Varchar(None, CharLengthUnits::Characters) + Cow::Owned(LogicalType::Varchar(None, CharLengthUnits::Characters)) } ScalarExpression::Alias { expr, .. } => expr.return_type(), ScalarExpression::Empty | ScalarExpression::TableFunction(_) => unreachable!(), ScalarExpression::Tuple(exprs) => { - let types = exprs.iter().map(|expr| expr.return_type()).collect_vec(); + let types = exprs + .iter() + .map(|expr| expr.return_type().into_owned()) + .collect_vec(); - LogicalType::Tuple(types) + Cow::Owned(LogicalType::Tuple(types)) } ScalarExpression::ScalaFunction(ScalarFunction { inner, .. }) => { - inner.return_type().clone() + Cow::Borrowed(inner.return_type()) } } } @@ -477,7 +521,7 @@ impl ScalarExpression { format!("({}) as ({})", expr, alias_expr.output_name()) } }, - ScalarExpression::TypeCast { expr, ty } => { + ScalarExpression::TypeCast { expr, ty, .. } => { format!("cast ({} as {})", expr.output_name(), ty) } ScalarExpression::IsNull { expr, negated } => { @@ -669,7 +713,7 @@ impl ScalarExpression { self.output_name(), true, // SAFETY: default expr must not be [`ScalarExpression::ColumnRef`] - ColumnDesc::new(self.return_type(), None, false, None).unwrap(), + ColumnDesc::new(self.return_type().into_owned(), None, false, None).unwrap(), )), } } @@ -813,11 +857,12 @@ mod test { use crate::storage::{Storage, TableCache, Transaction}; use crate::types::evaluator::boolean::BooleanNotUnaryEvaluator; use crate::types::evaluator::int32::Int32PlusBinaryEvaluator; - use crate::types::evaluator::{BinaryEvaluatorBox, UnaryEvaluatorBox}; + use crate::types::evaluator::{cast_create, BinaryEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::LogicalType; use crate::utils::lru::SharedLruCache; use sqlparser::ast::{CharLengthUnits, TrimWhereField}; + use std::borrow::Cow; use std::hash::RandomState; use std::io::{Cursor, Seek, SeekFrom}; use std::sync::Arc; @@ -974,6 +1019,10 @@ mod test { ScalarExpression::TypeCast { expr: Box::new(ScalarExpression::Empty), ty: LogicalType::Integer, + evaluator: Some(cast_create( + Cow::Owned(LogicalType::Integer), + Cow::Owned(LogicalType::Integer), + )?), }, Some((&transaction, &table_cache)), &mut reference_tables, diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index 0bb3e725..ca431c2a 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -727,9 +727,7 @@ impl<'a> RangeDetacher<'a> { if !Self::_is_belong(self.table_name, &col) || col.id() != Some(*self.column_id) { return Ok(None); } - if &val.logical_type() != col.datatype() { - val = val.cast(col.datatype())? - } + val = val.cast(col.datatype())?; if is_flip { op = match op { BinaryOperator::Gt => BinaryOperator::Lt, diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 55907286..b03476a6 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -16,9 +16,10 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; -use crate::types::evaluator::EvaluatorFactory; +use crate::types::evaluator::{binary_create, unary_create}; use crate::types::value::DataValue; use crate::types::LogicalType; +use std::borrow::Cow; use std::mem; #[derive(Debug)] @@ -60,7 +61,7 @@ impl VisitorMut<'_> for ConstantCalculator { let value = if let Some(evaluator) = evaluator { evaluator.0.unary_eval(unary_val) } else { - EvaluatorFactory::unary_create(ty.clone(), *op)? + unary_create(Cow::Borrowed(ty), *op)? .0 .unary_eval(unary_val) }; @@ -73,10 +74,9 @@ impl VisitorMut<'_> for ConstantCalculator { right_expr, .. } => { - let ty = LogicalType::max_logical_type( - &left_expr.return_type(), - &right_expr.return_type(), - )?; + let left_ty = left_expr.return_type(); + let right_ty = right_expr.return_type(); + let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned(); self.visit(left_expr)?; self.visit(right_expr)?; @@ -85,19 +85,17 @@ impl VisitorMut<'_> for ConstantCalculator { ScalarExpression::Constant(right_val), ) = (left_expr.as_mut(), right_expr.as_mut()) { - let evaluator = EvaluatorFactory::binary_create(ty.clone(), *op)?; + let evaluator = binary_create(Cow::Borrowed(&ty), *op)?; - if left_val.logical_type() != ty { - *left_val = left_val.clone().cast(&ty)?; - } - if right_val.logical_type() != ty { - *right_val = right_val.clone().cast(&ty)?; - } + *left_val = mem::replace(left_val, DataValue::Null).cast(&ty)?; + *right_val = mem::replace(right_val, DataValue::Null).cast(&ty)?; let value = evaluator.0.binary_eval(left_val, right_val)?; let _ = mem::replace(expr, ScalarExpression::Constant(value)); } } - ScalarExpression::TypeCast { expr: arg_expr, ty } => { + ScalarExpression::TypeCast { + expr: arg_expr, ty, .. + } => { self.visit(arg_expr)?; if let ScalarExpression::Constant(value) = arg_expr.as_mut() { @@ -472,7 +470,7 @@ impl ScalarExpression { let unary_value = if let Some(evaluator) = evaluator { evaluator.0.unary_eval(&value) } else { - EvaluatorFactory::unary_create(ty.clone(), *op) + unary_create(Cow::Borrowed(ty), *op) .ok()? .0 .unary_eval(&value) @@ -489,16 +487,12 @@ impl ScalarExpression { } => { let mut left = left_expr.unpack_val()?; let mut right = right_expr.unpack_val()?; - if &left.logical_type() != ty { - left = left.cast(ty).ok()?; - } - if &right.logical_type() != ty { - right = right.cast(ty).ok()?; - } + left = left.cast(ty).ok()?; + right = right.cast(ty).ok()?; if let Some(evaluator) = evaluator { evaluator.0.binary_eval(&left, &right) } else { - EvaluatorFactory::binary_create(ty.clone(), *op) + binary_create(Cow::Borrowed(ty), *op) .ok()? .0 .binary_eval(&left, &right) diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index ab56977c..b024e99c 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -18,7 +18,7 @@ use crate::expression::agg::AggKind; use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; -use crate::types::evaluator::{BinaryEvaluatorBox, UnaryEvaluatorBox}; +use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; use sqlparser::ast::TrimWhereField; @@ -51,6 +51,7 @@ pub trait Visitor<'a>: Sized { &mut self, expr: &'a ScalarExpression, _ty: &'a LogicalType, + _evaluator: Option<&'a CastEvaluatorBox>, ) -> Result<(), DatabaseError> { self.visit(expr) } @@ -273,7 +274,11 @@ pub fn walk_expr<'a, V: Visitor<'a>>( ScalarExpression::Constant(value) => visitor.visit_constant(value), ScalarExpression::ColumnRef { column, .. } => visitor.visit_column_ref(column), ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), - ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), + ScalarExpression::TypeCast { + expr, + ty, + evaluator, + } => visitor.visit_type_cast(expr, ty, evaluator.as_ref()), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), ScalarExpression::Unary { op, diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index 36222659..4619d86e 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -18,7 +18,7 @@ use crate::expression::agg::AggKind; use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; -use crate::types::evaluator::{BinaryEvaluatorBox, UnaryEvaluatorBox}; +use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; use crate::types::value::DataValue; use crate::types::LogicalType; use sqlparser::ast::TrimWhereField; @@ -74,6 +74,7 @@ pub trait VisitorMut<'a>: Sized { &mut self, expr: &'a mut ScalarExpression, _ty: &'a mut LogicalType, + _evaluator: &'a mut Option, ) -> Result<(), DatabaseError> { self.visit(expr) } @@ -298,7 +299,11 @@ pub fn walk_mut_expr<'a, V: VisitorMut<'a>>( visitor.visit_column_ref(column, position) } ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), - ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), + ScalarExpression::TypeCast { + expr, + ty, + evaluator, + } => visitor.visit_type_cast(expr, ty, evaluator), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), ScalarExpression::Unary { op, diff --git a/src/function/numbers.rs b/src/function/numbers.rs index 2a21e994..e3d3cdfd 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -68,9 +68,7 @@ impl TableFunctionImpl for Numbers { ) -> Result>>, DatabaseError> { let mut value = args[0].eval::<&Tuple>(None)?; - if value.logical_type() != LogicalType::Integer { - value = value.cast(&LogicalType::Integer)?; - } + value = value.cast(&LogicalType::Integer)?; let num = value .i32() .ok_or_else(|| DatabaseError::not_null_column("numbers() arg"))?; diff --git a/src/macros/mod.rs b/src/macros/mod.rs index a83e2d0e..f18a43c8 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -114,9 +114,7 @@ macro_rules! scala_function { let mut value = args[_index].eval(tuple)?; _index += 1; - if value.logical_type() != $arg_ty { - value = value.cast(&$arg_ty)?; - } + value = value.cast(&$arg_ty)?; value }, )*) } @@ -201,9 +199,7 @@ macro_rules! table_function { let mut value = args[_index].eval::<&::kite_sql::types::tuple::Tuple>(None)?; _index += 1; - if value.logical_type() != $arg_ty { - value = value.cast(&$arg_ty)?; - } + value = value.cast(&$arg_ty)?; value }, )*) } diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index c61e83f3..f2a737db 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -18,13 +18,14 @@ use crate::expression::range_detacher::Range; use crate::expression::BinaryOperator; use crate::optimizer::core::cm_sketch::CountMinSketch; use crate::storage::table_codec::BumpBytes; -use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; +use crate::types::evaluator::{binary_create, BinaryEvaluatorBox}; use crate::types::index::{IndexId, IndexMeta}; use crate::types::value::DataValue; use crate::types::LogicalType; use bumpalo::Bump; use kite_sql_serde_macros::ReferenceSerialization; use ordered_float::OrderedFloat; +use std::borrow::Cow; use std::collections::Bound; use std::{cmp, mem}; @@ -232,10 +233,10 @@ impl HistogramBuilder { impl BoundComparator { fn new(ty: LogicalType) -> Result { Ok(Self { - lt: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Lt)?, - lte: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::LtEq)?, - gt: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Gt)?, - gte: EvaluatorFactory::binary_create(ty, BinaryOperator::GtEq)?, + lt: binary_create(Cow::Borrowed(&ty), BinaryOperator::Lt)?, + lte: binary_create(Cow::Borrowed(&ty), BinaryOperator::LtEq)?, + gt: binary_create(Cow::Borrowed(&ty), BinaryOperator::Gt)?, + gte: binary_create(Cow::Owned(ty), BinaryOperator::GtEq)?, }) } diff --git a/src/serdes/evaluator.rs b/src/serdes/evaluator.rs index 56acb932..71b5e4ff 100644 --- a/src/serdes/evaluator.rs +++ b/src/serdes/evaluator.rs @@ -13,7 +13,8 @@ // limitations under the License. use crate::implement_serialization_by_bincode; -use crate::types::evaluator::{BinaryEvaluatorBox, UnaryEvaluatorBox}; +use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; implement_serialization_by_bincode!(UnaryEvaluatorBox); implement_serialization_by_bincode!(BinaryEvaluatorBox); +implement_serialization_by_bincode!(CastEvaluatorBox); diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 95c2cbd2..6a7c8b37 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -40,6 +40,7 @@ use crate::types::value::{DataValue, TupleMappingRef}; use crate::types::{ColumnId, LogicalType}; use crate::utils::lru::SharedLruCache; use itertools::Itertools; +use std::borrow::Cow; use std::collections::{BTreeMap, Bound}; use std::fmt::{self, Display, Formatter}; use std::io::Cursor; @@ -411,12 +412,10 @@ pub trait Transaction: Sized { match default_change { DefaultChange::NoChange => { if let Some(default_expr) = new_column.desc().default.clone() { - if default_expr.return_type() != *new_data_type { - new_column.desc_mut().default = Some(ScalarExpression::TypeCast { - expr: Box::new(default_expr), - ty: new_data_type.clone(), - }); - } + new_column.desc_mut().default = Some(ScalarExpression::type_cast( + default_expr, + Cow::Borrowed(new_data_type), + )?); } } DefaultChange::Set(default_expr) => { @@ -1125,11 +1124,11 @@ fn encode_bound<'a>( ) -> Result, DatabaseError> { match bound { Bound::Included(val) => { - inner.bound_key(params, &val, is_upper, buffer)?; + inner.bound_key(params, val, is_upper, buffer)?; Ok(Bound::Included(buffer.as_slice())) } Bound::Excluded(val) => { - inner.bound_key(params, &val, is_upper, buffer)?; + inner.bound_key(params, val, is_upper, buffer)?; Ok(Bound::Excluded(buffer.as_slice())) } Bound::Unbounded => Ok(Bound::Unbounded), diff --git a/src/storage/rocksdb.rs b/src/storage/rocksdb.rs index b8377c62..429d05a9 100644 --- a/src/storage/rocksdb.rs +++ b/src/storage/rocksdb.rs @@ -361,9 +361,9 @@ fn cleanup_failed_checkpoint_dir(path: &Path) -> Result<(), DatabaseError> { #[cfg(not(feature = "unsafe_txdb_checkpoint"))] fn unsupported_transactiondb_checkpoint_error() -> DatabaseError { - DatabaseError::UnsupportedStmt(format!( - "rocksdb TransactionDB checkpoint is disabled; enable the `unsafe_txdb_checkpoint` feature to opt in to the current implementation", - )) + DatabaseError::UnsupportedStmt( + "rocksdb TransactionDB checkpoint is disabled; enable the `unsafe_txdb_checkpoint` feature to opt in to the current implementation".to_string() + ) } #[cfg(feature = "unsafe_txdb_checkpoint")] @@ -510,7 +510,7 @@ impl CheckpointableStorage for RocksStorage { #[cfg(not(feature = "unsafe_txdb_checkpoint"))] { - return Err(unsupported_transactiondb_checkpoint_error()); + Err(unsupported_transactiondb_checkpoint_error()) } } } diff --git a/src/types/evaluator/binary.rs b/src/types/evaluator/binary.rs new file mode 100644 index 00000000..8b8d12c8 --- /dev/null +++ b/src/types/evaluator/binary.rs @@ -0,0 +1,393 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::errors::DatabaseError; +use crate::expression::BinaryOperator; +use crate::types::evaluator::boolean::*; +use crate::types::evaluator::date::*; +use crate::types::evaluator::datetime::*; +use crate::types::evaluator::decimal::*; +use crate::types::evaluator::float32::*; +use crate::types::evaluator::float64::*; +use crate::types::evaluator::int16::*; +use crate::types::evaluator::int32::*; +use crate::types::evaluator::int64::*; +use crate::types::evaluator::int8::*; +use crate::types::evaluator::null::NullBinaryEvaluator; +use crate::types::evaluator::time32::*; +use crate::types::evaluator::time64::*; +use crate::types::evaluator::tuple::{ + TupleEqBinaryEvaluator, TupleGtBinaryEvaluator, TupleGtEqBinaryEvaluator, + TupleLtBinaryEvaluator, TupleLtEqBinaryEvaluator, TupleNotEqBinaryEvaluator, +}; +use crate::types::evaluator::uint16::*; +use crate::types::evaluator::uint32::*; +use crate::types::evaluator::uint64::*; +use crate::types::evaluator::uint8::*; +use crate::types::evaluator::utf8::*; +use crate::types::evaluator::BinaryEvaluatorBox; +use crate::types::LogicalType; +use paste::paste; +use std::borrow::Cow; +use std::sync::Arc; + +macro_rules! numeric_binary_evaluator { + ($value_type:ident, $op:expr, $ty:expr) => { + paste! { + match $op { + BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type PlusBinaryEvaluator>]))), + BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MinusBinaryEvaluator>]))), + BinaryOperator::Multiply => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MultiplyBinaryEvaluator>]))), + BinaryOperator::Divide => Ok(BinaryEvaluatorBox(Arc::new([<$value_type DivideBinaryEvaluator>]))), + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtBinaryEvaluator>]))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtEqBinaryEvaluator>]))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtBinaryEvaluator>]))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtEqBinaryEvaluator>]))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type EqBinaryEvaluator>]))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type NotEqBinaryEvaluator>]))), + BinaryOperator::Modulo => Ok(BinaryEvaluatorBox(Arc::new([<$value_type ModBinaryEvaluator>]))), + _ => Err(DatabaseError::UnsupportedBinaryOperator($ty.clone(), $op)), + } + } + }; +} + +pub fn binary_create( + ty: Cow<'_, LogicalType>, + op: BinaryOperator, +) -> Result { + let ty = ty.as_ref(); + match ty { + LogicalType::Tinyint => numeric_binary_evaluator!(Int8, op, ty), + LogicalType::Smallint => numeric_binary_evaluator!(Int16, op, ty), + LogicalType::Integer => numeric_binary_evaluator!(Int32, op, ty), + LogicalType::Bigint => numeric_binary_evaluator!(Int64, op, ty), + LogicalType::UTinyint => numeric_binary_evaluator!(UInt8, op, ty), + LogicalType::USmallint => numeric_binary_evaluator!(UInt16, op, ty), + LogicalType::UInteger => numeric_binary_evaluator!(UInt32, op, ty), + LogicalType::UBigint => numeric_binary_evaluator!(UInt64, op, ty), + LogicalType::Float => numeric_binary_evaluator!(Float32, op, ty), + LogicalType::Double => numeric_binary_evaluator!(Float64, op, ty), + LogicalType::Date => numeric_binary_evaluator!(Date, op, ty), + LogicalType::DateTime => numeric_binary_evaluator!(DateTime, op, ty), + LogicalType::Time(_) => match op { + BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new(TimePlusBinaryEvaluator))), + BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new(TimeMinusBinaryEvaluator))), + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TimeGtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeGtEqBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TimeLtBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeLtEqBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TimeEqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(TimeNotEqBinaryEvaluator))), + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + }, + LogicalType::TimeStamp(_, _) => match op { + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Time64GtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64GtEqBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Time64LtBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64LtEqBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Time64EqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Time64NotEqBinaryEvaluator))), + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + }, + LogicalType::Decimal(_, _) => numeric_binary_evaluator!(Decimal, op, ty), + LogicalType::Boolean => match op { + BinaryOperator::And => Ok(BinaryEvaluatorBox(Arc::new(BooleanAndBinaryEvaluator))), + BinaryOperator::Or => Ok(BinaryEvaluatorBox(Arc::new(BooleanOrBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(BooleanEqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(BooleanNotEqBinaryEvaluator))), + _ => Err(DatabaseError::UnsupportedBinaryOperator( + LogicalType::Boolean, + op, + )), + }, + LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => match op { + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtEqBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtEqBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Utf8EqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8NotEqBinaryEvaluator))), + BinaryOperator::StringConcat => Ok(BinaryEvaluatorBox(Arc::new( + Utf8StringConcatBinaryEvaluator, + ))), + BinaryOperator::Like(escape_char) => { + Ok(BinaryEvaluatorBox(Arc::new(Utf8LikeBinaryEvaluator { + escape_char, + }))) + } + BinaryOperator::NotLike(escape_char) => { + Ok(BinaryEvaluatorBox(Arc::new(Utf8NotLikeBinaryEvaluator { + escape_char, + }))) + } + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + }, + LogicalType::SqlNull => Ok(BinaryEvaluatorBox(Arc::new(NullBinaryEvaluator))), + LogicalType::Tuple(_) => match op { + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TupleEqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(TupleNotEqBinaryEvaluator))), + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TupleGtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleGtEqBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TupleLtBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleLtEqBinaryEvaluator))), + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + }, + } +} + +#[macro_export] +macro_rules! numeric_binary_evaluator_definition { + ($value_type:ident, $compute_type:path) => { + paste::paste! { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type PlusBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type MinusBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type MultiplyBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type DivideBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type GtBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type GtEqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type LtBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type LtEqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type EqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type NotEqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type ModBinaryEvaluator>]; + + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_add(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_sub(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_mul(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(*v1 as f64 / *v2 as f64)), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 > v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 >= v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 < v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 <= v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 == v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 != v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + #[typetag::serde] + impl $crate::types::evaluator::BinaryEvaluator for [<$value_type ModBinaryEvaluator>] { + fn binary_eval( + &self, + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(*v1 % *v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + } + } + }; +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::binary_create; + use crate::errors::DatabaseError; + use crate::expression::BinaryOperator; + use crate::serdes::{ReferenceSerialization, ReferenceTables}; + use crate::storage::rocksdb::RocksTransaction; + use crate::types::evaluator::BinaryEvaluatorBox; + use crate::types::LogicalType; + use std::borrow::Cow; + use std::io::{Cursor, Seek, SeekFrom}; + + fn create(ty: LogicalType, op: BinaryOperator) -> Result { + binary_create(Cow::Owned(ty), op) + } + + #[test] + fn test_binary_evaluator_serialization() -> Result<(), DatabaseError> { + let evaluator = create(LogicalType::Boolean, BinaryOperator::NotEq)?; + let mut cursor = Cursor::new(Vec::new()); + let mut reference_tables = ReferenceTables::new(); + + evaluator.encode(&mut cursor, false, &mut reference_tables)?; + cursor.seek(SeekFrom::Start(0))?; + + assert_eq!( + BinaryEvaluatorBox::decode::( + &mut cursor, + None, + &reference_tables + )?, + evaluator + ); + + Ok(()) + } +} diff --git a/src/types/evaluator/boolean.rs b/src/types/evaluator/boolean.rs index 8d6f99ec..897dac74 100644 --- a/src/types/evaluator/boolean.rs +++ b/src/types/evaluator/boolean.rs @@ -13,9 +13,12 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::DataValue; use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -82,6 +85,51 @@ impl BinaryEvaluator for BooleanEqBinaryEvaluator { } } +crate::define_cast_evaluator!(BooleanToTinyintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Int8(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToUTinyintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::UInt8(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToSmallintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Int16(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToUSmallintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::UInt16(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToIntegerCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Int32(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToUIntegerCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::UInt32(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToBigintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Int64(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToUBigintCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::UInt64(if *value { 1 } else { 0 })) +}); +crate::define_cast_evaluator!(BooleanToFloatCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Float32(OrderedFloat(if *value { 1.0 } else { 0.0 }))) +}); +crate::define_cast_evaluator!(BooleanToDoubleCastEvaluator, DataValue::Boolean(value) => { + Ok(DataValue::Float64(OrderedFloat(if *value { 1.0 } else { 0.0 }))) +}); +crate::define_cast_evaluator!( + BooleanToCharCastEvaluator { + len: u32, + unit: CharLengthUnits + }, + DataValue::Boolean(value) => |this| to_char(value.to_string(), this.len, this.unit) +); +crate::define_cast_evaluator!( + BooleanToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits + }, + DataValue::Boolean(value) => |this| to_varchar(value.to_string(), this.len, this.unit) +); + #[typetag::serde] impl BinaryEvaluator for BooleanNotEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -94,3 +142,123 @@ impl BinaryEvaluator for BooleanNotEqBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use crate::types::value::Utf8Type; + + #[test] + fn test_boolean_binary_evaluators() { + assert_eq!( + BooleanAndBinaryEvaluator + .binary_eval(&DataValue::Boolean(true), &DataValue::Boolean(true)) + .unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + BooleanAndBinaryEvaluator + .binary_eval(&DataValue::Boolean(false), &DataValue::Null) + .unwrap(), + DataValue::Boolean(false) + ); + assert_eq!( + BooleanOrBinaryEvaluator + .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(true)) + .unwrap(), + DataValue::Boolean(true) + ); + } + + #[test] + fn test_boolean_cast_evaluators() { + let value = DataValue::Boolean(true); + + assert_eq!( + BooleanToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + BooleanToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + BooleanToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + BooleanToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + BooleanToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + BooleanToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + BooleanToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + BooleanToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + BooleanToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + BooleanToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + BooleanToCharCastEvaluator { + len: 4, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "true".to_string(), + ty: Utf8Type::Fixed(4), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + BooleanToVarcharCastEvaluator { + len: Some(4), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "true".to_string(), + ty: Utf8Type::Variable(Some(4)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + BooleanToDoubleCastEvaluator + .eval_cast(&DataValue::Boolean(false)) + .unwrap(), + DataValue::Float64(OrderedFloat(0.0)) + ); + assert_eq!( + BooleanToVarcharCastEvaluator { + len: None, + unit: CharLengthUnits::Characters, + } + .eval_cast(&DataValue::Boolean(false)) + .unwrap(), + DataValue::Utf8 { + value: "false".to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + } + ); + } +} diff --git a/src/types/evaluator/cast.rs b/src/types/evaluator/cast.rs new file mode 100644 index 00000000..40beee08 --- /dev/null +++ b/src/types/evaluator/cast.rs @@ -0,0 +1,786 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::errors::DatabaseError; +use crate::types::evaluator::boolean::*; +use crate::types::evaluator::date::*; +use crate::types::evaluator::datetime::*; +use crate::types::evaluator::decimal::*; +use crate::types::evaluator::float32::*; +use crate::types::evaluator::float64::*; +use crate::types::evaluator::int16::*; +use crate::types::evaluator::int32::*; +use crate::types::evaluator::int64::*; +use crate::types::evaluator::int8::*; +use crate::types::evaluator::null::{NullCastEvaluator, ToSqlNullCastEvaluator}; +use crate::types::evaluator::time32::*; +use crate::types::evaluator::time64::*; +use crate::types::evaluator::tuple::TupleCastEvaluator; +use crate::types::evaluator::uint16::*; +use crate::types::evaluator::uint32::*; +use crate::types::evaluator::uint64::*; +use crate::types::evaluator::uint8::*; +use crate::types::evaluator::utf8::*; +use crate::types::evaluator::{CastEvaluator, CastEvaluatorBox}; +use crate::types::value::{DataValue, Utf8Type}; +use crate::types::LogicalType; +use paste::paste; +use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; +use std::borrow::Cow; +use std::sync::Arc; + +pub(crate) fn cast_fail(from: LogicalType, to: LogicalType) -> DatabaseError { + DatabaseError::CastFail { + from, + to, + span: None, + } +} + +pub(crate) fn to_char( + value: String, + len: u32, + unit: CharLengthUnits, +) -> Result { + if DataValue::check_string_len(&value, len as usize, unit) { + return Err(DatabaseError::TooLong); + } + + Ok(DataValue::Utf8 { + value, + ty: Utf8Type::Fixed(len), + unit, + }) +} + +pub(crate) fn to_varchar( + value: String, + len: Option, + unit: CharLengthUnits, +) -> Result { + if let Some(len) = len { + if DataValue::check_string_len(&value, len as usize, unit) { + return Err(DatabaseError::TooLong); + } + } + + Ok(DataValue::Utf8 { + value, + ty: Utf8Type::Variable(len), + unit, + }) +} + +#[macro_export] +macro_rules! numeric_to_boolean_cast { + ($value:expr, $from:expr) => { + match $value { + 0 => Ok($crate::types::value::DataValue::Boolean(false)), + 1 => Ok($crate::types::value::DataValue::Boolean(true)), + _ => Err($crate::types::evaluator::cast::cast_fail( + $from, + $crate::types::LogicalType::Boolean, + )), + } + }; +} + +#[macro_export] +macro_rules! float_to_int_cast { + ($float_value:expr, $int_type:ty, $float_type:ty) => {{ + let float_value: $float_type = $float_value; + if float_value.is_nan() { + Ok(0) + } else if float_value <= 0.0 || float_value > <$int_type>::MAX as $float_type { + Err($crate::errors::DatabaseError::OverFlow) + } else { + Ok(float_value as $int_type) + } + }}; +} + +#[macro_export] +macro_rules! decimal_to_int_cast { + ($decimal:expr, $int_type:ty) => {{ + let d = $decimal; + if d.is_sign_negative() { + if <$int_type>::MIN == 0 { + 0 + } else { + let min = rust_decimal::Decimal::from(<$int_type>::MIN); + if d <= min { + <$int_type>::MIN + } else { + d.to_i128().unwrap() as $int_type + } + } + } else { + let max = rust_decimal::Decimal::from(<$int_type>::MAX); + if d >= max { + <$int_type>::MAX + } else { + d.to_i128().unwrap() as $int_type + } + } + }}; +} + +#[macro_export] +macro_rules! define_cast_evaluator { + ($name:ident, $pattern:pat => $body:block) => { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name; + + #[typetag::serde] + impl $crate::types::evaluator::CastEvaluator for $name { + fn eval_cast( + &self, + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => $body, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + }; + ($name:ident, $pattern:pat => |$this:ident| $body:expr) => { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name; + + #[typetag::serde] + impl $crate::types::evaluator::CastEvaluator for $name { + fn eval_cast( + &self, + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => { + let $this = self; + $body + } + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + }; + ($name:ident, $pattern:pat => |$this:ident| $body:block) => { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name; + + #[typetag::serde] + impl $crate::types::evaluator::CastEvaluator for $name { + fn eval_cast( + &self, + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => { + let $this = self; + $body + } + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + }; + ($name:ident { $($field:ident : $field_ty:ty),+ $(,)? }, $pattern:pat => |$this:ident| $body:expr) => { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name { + $(pub $field: $field_ty),+ + } + + #[typetag::serde] + impl $crate::types::evaluator::CastEvaluator for $name { + fn eval_cast( + &self, + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => { + let $this = self; + $body + } + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + }; + ($name:ident { $($field:ident : $field_ty:ty),+ $(,)? }, $pattern:pat => $body:block) => { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name { + $(pub $field: $field_ty),+ + } + + #[typetag::serde] + impl $crate::types::evaluator::CastEvaluator for $name { + fn eval_cast( + &self, + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => $body, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + }; +} + +#[macro_export] +macro_rules! define_integer_cast_evaluators { + ($prefix:ident, $variant:ident, $src_ty:ty, $from_ty:expr) => { + paste::paste! { + $crate::define_cast_evaluator!([<$prefix ToBooleanCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::numeric_to_boolean_cast!(*value, $from_ty) + }); + $crate::define_cast_evaluator!([<$prefix ToTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int8(i8::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt8(u8::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int16(i16::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt16(u16::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int32(i32::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt32(u32::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int64(i64::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt64(u64::try_from(*value)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToFloatCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Float32(ordered_float::OrderedFloat(*value as f32))) + }); + $crate::define_cast_evaluator!([<$prefix ToDoubleCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(*value as f64))) + }); + $crate::define_cast_evaluator!( + [<$prefix ToCharCastEvaluator>] { + len: u32, + unit: sqlparser::ast::CharLengthUnits + }, + $crate::types::value::DataValue::$variant(value) => |this| { + $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) + } + ); + $crate::define_cast_evaluator!( + [<$prefix ToVarcharCastEvaluator>] { + len: Option, + unit: sqlparser::ast::CharLengthUnits + }, + $crate::types::value::DataValue::$variant(value) => |this| { + $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) + } + ); + $crate::define_cast_evaluator!( + [<$prefix ToDecimalCastEvaluator>] { + scale: Option + }, + $crate::types::value::DataValue::$variant(value) => |this| { + let mut decimal = rust_decimal::Decimal::from(*value); + $crate::types::value::DataValue::decimal_round_i(&this.scale, &mut decimal); + Ok($crate::types::value::DataValue::Decimal(decimal)) + } + ); + } + }; +} + +#[macro_export] +macro_rules! define_float_cast_evaluators { + ($prefix:ident, $variant:ident, $src_ty:ty, $from_ty:expr, $into_decimal:ident) => { + paste::paste! { + $crate::define_cast_evaluator!([<$prefix ToFloatCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::$variant(*value)) + }); + $crate::define_cast_evaluator!([<$prefix ToDoubleCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(value.0 as f64))) + }); + $crate::define_cast_evaluator!([<$prefix ToTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int8($crate::float_to_int_cast!(value.into_inner(), i8, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int16($crate::float_to_int_cast!(value.into_inner(), i16, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int32($crate::float_to_int_cast!(value.into_inner(), i32, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::Int64($crate::float_to_int_cast!(value.into_inner(), i64, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt8($crate::float_to_int_cast!(value.into_inner(), u8, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt16($crate::float_to_int_cast!(value.into_inner(), u16, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt32($crate::float_to_int_cast!(value.into_inner(), u32, $src_ty)?)) + }); + $crate::define_cast_evaluator!([<$prefix ToUBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + Ok($crate::types::value::DataValue::UInt64($crate::float_to_int_cast!(value.into_inner(), u64, $src_ty)?)) + }); + $crate::define_cast_evaluator!( + [<$prefix ToCharCastEvaluator>] { + len: u32, + unit: sqlparser::ast::CharLengthUnits + }, + $crate::types::value::DataValue::$variant(value) => |this| { + $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) + } + ); + $crate::define_cast_evaluator!( + [<$prefix ToVarcharCastEvaluator>] { + len: Option, + unit: sqlparser::ast::CharLengthUnits + }, + $crate::types::value::DataValue::$variant(value) => |this| { + $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) + } + ); + $crate::define_cast_evaluator!( + [<$prefix ToDecimalCastEvaluator>] { + scale: Option, + to: $crate::types::LogicalType + }, + $crate::types::value::DataValue::$variant(value) => |this| { + let mut decimal = rust_decimal::Decimal::$into_decimal(value.0).ok_or_else(|| { + $crate::types::evaluator::cast::cast_fail($from_ty, this.to.clone()) + })?; + $crate::types::value::DataValue::decimal_round_f(&this.scale, &mut decimal); + Ok($crate::types::value::DataValue::Decimal(decimal)) + } + ); + } + }; +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct IdentityCastEvaluator; + +#[typetag::serde] +impl CastEvaluator for IdentityCastEvaluator { + fn eval_cast(&self, value: &DataValue) -> Result { + Ok(value.clone()) + } +} + +macro_rules! box_cast { + ($evaluator:expr) => { + Ok(CastEvaluatorBox(Arc::new($evaluator))) + }; +} + +macro_rules! build_integer_cast { + ($prefix:ident, $to:expr, $from:expr) => {{ + paste! { + match $to { + LogicalType::SqlNull => box_cast!(ToSqlNullCastEvaluator), + LogicalType::Boolean => box_cast!([<$prefix ToBooleanCastEvaluator>]), + LogicalType::Tinyint => box_cast!([<$prefix ToTinyintCastEvaluator>]), + LogicalType::UTinyint => box_cast!([<$prefix ToUTinyintCastEvaluator>]), + LogicalType::Smallint => box_cast!([<$prefix ToSmallintCastEvaluator>]), + LogicalType::USmallint => box_cast!([<$prefix ToUSmallintCastEvaluator>]), + LogicalType::Integer => box_cast!([<$prefix ToIntegerCastEvaluator>]), + LogicalType::UInteger => box_cast!([<$prefix ToUIntegerCastEvaluator>]), + LogicalType::Bigint => box_cast!([<$prefix ToBigintCastEvaluator>]), + LogicalType::UBigint => box_cast!([<$prefix ToUBigintCastEvaluator>]), + LogicalType::Float => box_cast!([<$prefix ToFloatCastEvaluator>]), + LogicalType::Double => box_cast!([<$prefix ToDoubleCastEvaluator>]), + LogicalType::Char(len, unit) => box_cast!([<$prefix ToCharCastEvaluator>] { len: *len, unit: *unit }), + LogicalType::Varchar(len, unit) => box_cast!([<$prefix ToVarcharCastEvaluator>] { len: *len, unit: *unit }), + LogicalType::Decimal(_, scale) => box_cast!([<$prefix ToDecimalCastEvaluator>] { scale: *scale }), + _ => Err(cast_fail($from.clone(), $to.clone())), + } + } + }}; +} + +pub fn cast_create( + from: Cow<'_, LogicalType>, + to: Cow<'_, LogicalType>, +) -> Result { + let from = from.as_ref(); + let to = to.as_ref(); + if from == to { + return box_cast!(IdentityCastEvaluator); + } + + match (from, to) { + (LogicalType::SqlNull, _) => box_cast!(NullCastEvaluator), + (_, LogicalType::SqlNull) => box_cast!(ToSqlNullCastEvaluator), + (LogicalType::Boolean, LogicalType::Tinyint) => box_cast!(BooleanToTinyintCastEvaluator), + (LogicalType::Boolean, LogicalType::UTinyint) => box_cast!(BooleanToUTinyintCastEvaluator), + (LogicalType::Boolean, LogicalType::Smallint) => box_cast!(BooleanToSmallintCastEvaluator), + (LogicalType::Boolean, LogicalType::USmallint) => { + box_cast!(BooleanToUSmallintCastEvaluator) + } + (LogicalType::Boolean, LogicalType::Integer) => box_cast!(BooleanToIntegerCastEvaluator), + (LogicalType::Boolean, LogicalType::UInteger) => box_cast!(BooleanToUIntegerCastEvaluator), + (LogicalType::Boolean, LogicalType::Bigint) => box_cast!(BooleanToBigintCastEvaluator), + (LogicalType::Boolean, LogicalType::UBigint) => box_cast!(BooleanToUBigintCastEvaluator), + (LogicalType::Boolean, LogicalType::Float) => box_cast!(BooleanToFloatCastEvaluator), + (LogicalType::Boolean, LogicalType::Double) => box_cast!(BooleanToDoubleCastEvaluator), + (LogicalType::Boolean, LogicalType::Char(len, unit)) => { + box_cast!(BooleanToCharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Boolean, LogicalType::Varchar(len, unit)) => { + box_cast!(BooleanToVarcharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Tinyint, _) => build_integer_cast!(Int8, to, from), + (LogicalType::Smallint, _) => build_integer_cast!(Int16, to, from), + (LogicalType::Integer, _) => build_integer_cast!(Int32, to, from), + (LogicalType::Bigint, _) => build_integer_cast!(Int64, to, from), + (LogicalType::UTinyint, _) => build_integer_cast!(UInt8, to, from), + (LogicalType::USmallint, _) => build_integer_cast!(UInt16, to, from), + (LogicalType::UInteger, _) => build_integer_cast!(UInt32, to, from), + (LogicalType::UBigint, _) => build_integer_cast!(UInt64, to, from), + (LogicalType::Float, LogicalType::Tinyint) => box_cast!(Float32ToTinyintCastEvaluator), + (LogicalType::Float, LogicalType::UTinyint) => box_cast!(Float32ToUTinyintCastEvaluator), + (LogicalType::Float, LogicalType::Smallint) => box_cast!(Float32ToSmallintCastEvaluator), + (LogicalType::Float, LogicalType::USmallint) => box_cast!(Float32ToUSmallintCastEvaluator), + (LogicalType::Float, LogicalType::Integer) => box_cast!(Float32ToIntegerCastEvaluator), + (LogicalType::Float, LogicalType::UInteger) => box_cast!(Float32ToUIntegerCastEvaluator), + (LogicalType::Float, LogicalType::Bigint) => box_cast!(Float32ToBigintCastEvaluator), + (LogicalType::Float, LogicalType::UBigint) => box_cast!(Float32ToUBigintCastEvaluator), + (LogicalType::Float, LogicalType::Double) => box_cast!(Float32ToDoubleCastEvaluator), + (LogicalType::Float, LogicalType::Char(len, unit)) => { + box_cast!(Float32ToCharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Float, LogicalType::Varchar(len, unit)) => { + box_cast!(Float32ToVarcharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Float, LogicalType::Decimal(_, scale)) => { + box_cast!(Float32ToDecimalCastEvaluator { + scale: *scale, + to: to.clone() + }) + } + (LogicalType::Double, LogicalType::Float) => box_cast!(Float64ToFloatCastEvaluator), + (LogicalType::Double, LogicalType::Tinyint) => box_cast!(Float64ToTinyintCastEvaluator), + (LogicalType::Double, LogicalType::UTinyint) => box_cast!(Float64ToUTinyintCastEvaluator), + (LogicalType::Double, LogicalType::Smallint) => box_cast!(Float64ToSmallintCastEvaluator), + (LogicalType::Double, LogicalType::USmallint) => box_cast!(Float64ToUSmallintCastEvaluator), + (LogicalType::Double, LogicalType::Integer) => box_cast!(Float64ToIntegerCastEvaluator), + (LogicalType::Double, LogicalType::UInteger) => box_cast!(Float64ToUIntegerCastEvaluator), + (LogicalType::Double, LogicalType::Bigint) => box_cast!(Float64ToBigintCastEvaluator), + (LogicalType::Double, LogicalType::UBigint) => box_cast!(Float64ToUBigintCastEvaluator), + (LogicalType::Double, LogicalType::Char(len, unit)) => { + box_cast!(Float64ToCharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Double, LogicalType::Varchar(len, unit)) => { + box_cast!(Float64ToVarcharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Double, LogicalType::Decimal(_, scale)) => { + box_cast!(Float64ToDecimalCastEvaluator { + scale: *scale, + to: to.clone() + }) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Boolean) => { + box_cast!(Utf8ToBooleanCastEvaluator { from: from.clone() }) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Tinyint) => { + box_cast!(Utf8ToTinyintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UTinyint) => { + box_cast!(Utf8ToUTinyintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Smallint) => { + box_cast!(Utf8ToSmallintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::USmallint) => { + box_cast!(Utf8ToUSmallintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Integer) => { + box_cast!(Utf8ToIntegerCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UInteger) => { + box_cast!(Utf8ToUIntegerCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Bigint) => { + box_cast!(Utf8ToBigintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UBigint) => { + box_cast!(Utf8ToUBigintCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Float) => { + box_cast!(Utf8ToFloatCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Double) => { + box_cast!(Utf8ToDoubleCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Char(len, unit)) => { + box_cast!(Utf8ToCharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Varchar(len, unit)) => { + box_cast!(Utf8ToVarcharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Date) => { + box_cast!(Utf8ToDateCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::DateTime) => { + box_cast!(Utf8ToDatetimeCastEvaluator) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Time(precision)) => { + box_cast!(Utf8ToTimeCastEvaluator { + precision: *precision + }) + } + ( + LogicalType::Char(_, _) | LogicalType::Varchar(_, _), + LogicalType::TimeStamp(precision, zone), + ) => { + box_cast!(Utf8ToTimestampCastEvaluator { + precision: *precision, + zone: *zone, + to: to.clone() + }) + } + (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Decimal(_, _)) => { + box_cast!(Utf8ToDecimalCastEvaluator) + } + (LogicalType::Date, LogicalType::Char(len, unit)) => { + box_cast!(Date32ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::Date, LogicalType::Varchar(len, unit)) => { + box_cast!(Date32ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::Date, LogicalType::DateTime) => { + box_cast!(Date32ToDatetimeCastEvaluator { to: to.clone() }) + } + (LogicalType::DateTime, LogicalType::Char(len, unit)) => { + box_cast!(Date64ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::DateTime, LogicalType::Varchar(len, unit)) => { + box_cast!(Date64ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::DateTime, LogicalType::Date) => { + box_cast!(Date64ToDateCastEvaluator { to: to.clone() }) + } + (LogicalType::DateTime, LogicalType::Time(precision)) => { + box_cast!(Date64ToTimeCastEvaluator { + precision: *precision, + to: to.clone() + }) + } + (LogicalType::DateTime, LogicalType::TimeStamp(precision, zone)) => { + box_cast!(Date64ToTimestampCastEvaluator { + precision: *precision, + zone: *zone + }) + } + (LogicalType::Time(_), LogicalType::Char(len, unit)) => { + box_cast!(Time32ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::Time(_), LogicalType::Varchar(len, unit)) => { + box_cast!(Time32ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::Time(_), LogicalType::Time(precision)) => { + box_cast!(Time32ToTimeCastEvaluator { + precision: *precision + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::Char(len, unit)) => { + box_cast!(Time64ToCharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::Varchar(len, unit)) => { + box_cast!(Time64ToVarcharCastEvaluator { + len: *len, + unit: *unit, + to: to.clone() + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::Date) => { + box_cast!(Time64ToDateCastEvaluator { + from: from.clone(), + to: to.clone() + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::DateTime) => { + box_cast!(Time64ToDatetimeCastEvaluator { + from: from.clone(), + to: to.clone() + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::Time(precision)) => { + box_cast!(Time64ToTimeCastEvaluator { + precision: *precision, + from: from.clone(), + to: to.clone() + }) + } + (LogicalType::TimeStamp(_, _), LogicalType::TimeStamp(precision, zone)) => { + box_cast!(Time64ToTimestampCastEvaluator { + precision: *precision, + zone: *zone + }) + } + (LogicalType::Decimal(_, _), LogicalType::Float) => box_cast!(DecimalToFloatCastEvaluator), + (LogicalType::Decimal(_, _), LogicalType::Double) => { + box_cast!(DecimalToDoubleCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::Decimal(_, _)) => { + box_cast!(DecimalToDecimalCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::Char(len, unit)) => { + box_cast!(DecimalToCharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Decimal(_, _), LogicalType::Varchar(len, unit)) => { + box_cast!(DecimalToVarcharCastEvaluator { + len: *len, + unit: *unit + }) + } + (LogicalType::Decimal(_, _), LogicalType::Tinyint) => { + box_cast!(DecimalToTinyintCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::Smallint) => { + box_cast!(DecimalToSmallintCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::Integer) => { + box_cast!(DecimalToIntegerCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::Bigint) => { + box_cast!(DecimalToBigintCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::UTinyint) => { + box_cast!(DecimalToUTinyintCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::USmallint) => { + box_cast!(DecimalToUSmallintCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::UInteger) => { + box_cast!(DecimalToUIntegerCastEvaluator) + } + (LogicalType::Decimal(_, _), LogicalType::UBigint) => { + box_cast!(DecimalToUBigintCastEvaluator) + } + (LogicalType::Tuple(from_types), LogicalType::Tuple(to_types)) => { + let evaluators = from_types + .iter() + .zip(to_types.iter()) + .map(|(from, to)| cast_create(Cow::Borrowed(from), Cow::Borrowed(to))) + .collect::, _>>()?; + box_cast!(TupleCastEvaluator { + element_evaluators: evaluators + }) + } + _ => Err(cast_fail(from.clone(), to.clone())), + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::cast_create; + use crate::errors::DatabaseError; + use crate::serdes::{ReferenceSerialization, ReferenceTables}; + use crate::storage::rocksdb::RocksTransaction; + use crate::types::evaluator::CastEvaluatorBox; + use crate::types::LogicalType; + use std::borrow::Cow; + use std::io::{Cursor, Seek, SeekFrom}; + + fn create(from: LogicalType, to: LogicalType) -> Result { + cast_create(Cow::Owned(from), Cow::Owned(to)) + } + + #[test] + fn test_cast_evaluator_serialization() -> Result<(), DatabaseError> { + let evaluator = create(LogicalType::Integer, LogicalType::Bigint)?; + let mut cursor = Cursor::new(Vec::new()); + let mut reference_tables = ReferenceTables::new(); + + evaluator.encode(&mut cursor, false, &mut reference_tables)?; + cursor.seek(SeekFrom::Start(0))?; + + assert_eq!( + CastEvaluatorBox::decode::(&mut cursor, None, &reference_tables)?, + evaluator + ); + + Ok(()) + } +} diff --git a/src/types/evaluator/date.rs b/src/types/evaluator/date.rs index 7289ab44..4eafe0c9 100644 --- a/src/types/evaluator/date.rs +++ b/src/types/evaluator/date.rs @@ -13,11 +13,113 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; +use chrono::NaiveDate; +use sqlparser::ast::CharLengthUnits; numeric_binary_evaluator_definition!(Date, DataValue::Date32); +crate::define_cast_evaluator!( + Date32ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Date32(value) => |this| { + to_char( + DataValue::format_date(*value).ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Date32ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Date32(value) => |this| { + to_varchar( + DataValue::format_date(*value).ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Date32ToDatetimeCastEvaluator { + to: LogicalType + }, + DataValue::Date32(value) => |this| { + let value = NaiveDate::from_num_days_from_ce_opt(*value) + .ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))? + .and_hms_opt(0, 0, 0) + .ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))? + .and_utc() + .timestamp(); + + Ok(DataValue::Date64(value)) + } +); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use chrono::Datelike; + + #[test] + fn test_date_cast_evaluators() { + let value = DataValue::Date32( + NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .num_days_from_ce(), + ); + assert_eq!( + Date32ToCharCastEvaluator { + len: 10, + unit: CharLengthUnits::Characters, + to: LogicalType::Char(10, CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02".to_string(), + ty: Utf8Type::Fixed(10), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Date32ToVarcharCastEvaluator { + len: Some(10), + unit: CharLengthUnits::Characters, + to: LogicalType::Varchar(Some(10), CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02".to_string(), + ty: Utf8Type::Variable(Some(10)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Date32ToDatetimeCastEvaluator { + to: LogicalType::DateTime, + } + .eval_cast(&value) + .unwrap(), + DataValue::Date64( + NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp() + ) + ); + } +} diff --git a/src/types/evaluator/datetime.rs b/src/types/evaluator/datetime.rs index 708c005a..b2c428d7 100644 --- a/src/types/evaluator/datetime.rs +++ b/src/types/evaluator/datetime.rs @@ -13,11 +13,162 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; +use chrono::{DateTime, Datelike, Timelike}; +use sqlparser::ast::CharLengthUnits; numeric_binary_evaluator_definition!(DateTime, DataValue::Date64); +crate::define_cast_evaluator!( + Date64ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Date64(value) => |this| { + to_char( + DataValue::format_datetime(*value).ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Date64ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Date64(value) => |this| { + to_varchar( + DataValue::format_datetime(*value).ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Date64ToDateCastEvaluator { + to: LogicalType + }, + DataValue::Date64(value) => |this| { + let value = DateTime::from_timestamp(*value, 0) + .ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))? + .naive_utc() + .date() + .num_days_from_ce(); + + Ok(DataValue::Date32(value)) + } +); +crate::define_cast_evaluator!( + Date64ToTimeCastEvaluator { + precision: Option, + to: LogicalType + }, + DataValue::Date64(value) => |this| { + let precision = this.precision.unwrap_or(0); + let value = DateTime::from_timestamp(*value, 0) + .map(|date_time| date_time.time().num_seconds_from_midnight()) + .ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?; + + Ok(DataValue::Time32(DataValue::pack(value, 0, 0), precision)) + } +); +crate::define_cast_evaluator!( + Date64ToTimestampCastEvaluator { + precision: Option, + zone: bool + }, + DataValue::Date64(value) => |this| { + Ok(DataValue::Time64(*value, this.precision.unwrap_or(0), this.zone)) + } +); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_datetime_cast_evaluators() { + let value = DataValue::Date64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .unwrap() + .and_utc() + .timestamp(), + ); + assert_eq!( + Date64ToCharCastEvaluator { + len: 19, + unit: CharLengthUnits::Characters, + to: LogicalType::Char(19, CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02 03:04:05".to_string(), + ty: Utf8Type::Fixed(19), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Date64ToVarcharCastEvaluator { + len: Some(19), + unit: CharLengthUnits::Characters, + to: LogicalType::Varchar(Some(19), CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02 03:04:05".to_string(), + ty: Utf8Type::Variable(Some(19)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Date64ToDateCastEvaluator { + to: LogicalType::Date + } + .eval_cast(&value) + .unwrap(), + DataValue::Date32( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .num_days_from_ce() + ) + ); + assert_eq!( + Date64ToTimeCastEvaluator { + precision: Some(0), + to: LogicalType::Time(Some(0)), + } + .eval_cast(&value) + .unwrap(), + DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 0, 0), 0) + ); + assert_eq!( + Date64ToTimestampCastEvaluator { + precision: Some(0), + zone: true, + } + .eval_cast(&value) + .unwrap(), + DataValue::Time64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .unwrap() + .and_utc() + .timestamp(), + 0, + true, + ) + ); + } +} diff --git a/src/types/evaluator/decimal.rs b/src/types/evaluator/decimal.rs index 720aa65d..80554863 100644 --- a/src/types/evaluator/decimal.rs +++ b/src/types/evaluator/decimal.rs @@ -13,9 +13,13 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use ordered_float::OrderedFloat; +use rust_decimal::prelude::ToPrimitive; use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -65,6 +69,65 @@ impl BinaryEvaluator for DecimalMinusBinaryEvaluator { }) } } + +crate::define_cast_evaluator!(DecimalToFloatCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Float32(OrderedFloat(value.to_f32().ok_or_else(|| { + crate::types::evaluator::cast::cast_fail( + crate::types::LogicalType::Decimal(None, None), + crate::types::LogicalType::Float, + ) + })?))) +}); +crate::define_cast_evaluator!(DecimalToDoubleCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Float64(OrderedFloat(value.to_f64().ok_or_else(|| { + crate::types::evaluator::cast::cast_fail( + crate::types::LogicalType::Decimal(None, None), + crate::types::LogicalType::Double, + ) + })?))) +}); +crate::define_cast_evaluator!(DecimalToDecimalCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Decimal(*value)) +}); +crate::define_cast_evaluator!( + DecimalToCharCastEvaluator { + len: u32, + unit: CharLengthUnits + }, + DataValue::Decimal(value) => |this| to_char(value.to_string(), this.len, this.unit) +); +crate::define_cast_evaluator!( + DecimalToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits + }, + DataValue::Decimal(value) => |this| to_varchar(value.to_string(), this.len, this.unit) +); +crate::define_cast_evaluator!(DecimalToTinyintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Int8(crate::decimal_to_int_cast!(*value, i8))) +}); +crate::define_cast_evaluator!(DecimalToSmallintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Int16(crate::decimal_to_int_cast!(*value, i16))) +}); +crate::define_cast_evaluator!(DecimalToIntegerCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Int32(crate::decimal_to_int_cast!(*value, i32))) +}); +crate::define_cast_evaluator!(DecimalToBigintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::Int64(crate::decimal_to_int_cast!(*value, i64))) +}); +crate::define_cast_evaluator!(DecimalToUTinyintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::UInt8(crate::decimal_to_int_cast!(*value, u8))) +}); +crate::define_cast_evaluator!(DecimalToUSmallintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::UInt16(crate::decimal_to_int_cast!(*value, u16))) +}); +crate::define_cast_evaluator!(DecimalToUIntegerCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::UInt32(crate::decimal_to_int_cast!(*value, u32))) +}); +crate::define_cast_evaluator!(DecimalToUBigintCastEvaluator, DataValue::Decimal(value) => { + Ok(DataValue::UInt64(crate::decimal_to_int_cast!(*value, u64))) +}); + #[typetag::serde] impl BinaryEvaluator for DecimalMultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -173,3 +236,88 @@ impl BinaryEvaluator for DecimalModBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_decimal_cast_evaluators() { + let value = DataValue::Decimal(Decimal::new(125, 1)); + + assert_eq!( + DecimalToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(12.5)) + ); + assert_eq!( + DecimalToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(12.5)) + ); + assert_eq!( + DecimalToDecimalCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Decimal(Decimal::new(125, 1)) + ); + assert_eq!( + DecimalToCharCastEvaluator { + len: 4, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "12.5".to_string(), + ty: Utf8Type::Fixed(4), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + DecimalToVarcharCastEvaluator { + len: Some(4), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "12.5".to_string(), + ty: Utf8Type::Variable(Some(4)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + DecimalToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(12) + ); + assert_eq!( + DecimalToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(12) + ); + assert_eq!( + DecimalToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(12) + ); + assert_eq!( + DecimalToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(12) + ); + assert_eq!( + DecimalToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(12) + ); + assert_eq!( + DecimalToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(12) + ); + assert_eq!( + DecimalToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(12) + ); + assert_eq!( + DecimalToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(12) + ); + } +} diff --git a/src/types/evaluator/float32.rs b/src/types/evaluator/float32.rs index be278132..8c630eed 100644 --- a/src/types/evaluator/float32.rs +++ b/src/types/evaluator/float32.rs @@ -15,6 +15,8 @@ use crate::errors::DatabaseError; use crate::types::evaluator::DataValue; use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::types::LogicalType; +use rust_decimal::prelude::FromPrimitive; use serde::{Deserialize, Serialize}; use std::hint; @@ -87,6 +89,8 @@ impl BinaryEvaluator for Float32MinusBinaryEvaluator { }) } } + +crate::define_float_cast_evaluators!(Float32, Float32, f32, LogicalType::Float, from_f32); #[typetag::serde] impl BinaryEvaluator for Float32MultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -197,3 +201,94 @@ impl BinaryEvaluator for Float32ModBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_float32_cast_evaluators() { + let value = DataValue::Float32(OrderedFloat(1.5)); + + assert_eq!( + Float32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.5)) + ); + assert_eq!( + Float32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.5)) + ); + assert_eq!( + Float32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Float32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Float32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Float32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Float32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Float32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Float32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Float32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Float32ToCharCastEvaluator { + len: 3, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1.5".to_string(), + ty: Utf8Type::Fixed(3), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Float32ToVarcharCastEvaluator { + len: Some(3), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1.5".to_string(), + ty: Utf8Type::Variable(Some(3)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Float32ToDecimalCastEvaluator { + scale: Some(1), + to: LogicalType::Decimal(None, Some(1)), + } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(15, 1)) + ); + } +} diff --git a/src/types/evaluator/float64.rs b/src/types/evaluator/float64.rs index 35a9e164..219a440c 100644 --- a/src/types/evaluator/float64.rs +++ b/src/types/evaluator/float64.rs @@ -13,9 +13,14 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::DataValue; use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::types::LogicalType; +use rust_decimal::prelude::FromPrimitive; +use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -87,6 +92,66 @@ impl BinaryEvaluator for Float64MinusBinaryEvaluator { }) } } + +crate::define_cast_evaluator!(Float64ToFloatCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Float32(ordered_float::OrderedFloat(value.0 as f32))) +}); +crate::define_cast_evaluator!(Float64ToDoubleCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Float64(*value)) +}); +crate::define_cast_evaluator!(Float64ToTinyintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Int8(crate::float_to_int_cast!(value.into_inner(), i8, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToSmallintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Int16(crate::float_to_int_cast!(value.into_inner(), i16, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToIntegerCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Int32(crate::float_to_int_cast!(value.into_inner(), i32, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToBigintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::Int64(crate::float_to_int_cast!(value.into_inner(), i64, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToUTinyintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::UInt8(crate::float_to_int_cast!(value.into_inner(), u8, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToUSmallintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::UInt16(crate::float_to_int_cast!(value.into_inner(), u16, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToUIntegerCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::UInt32(crate::float_to_int_cast!(value.into_inner(), u32, f64)?)) +}); +crate::define_cast_evaluator!(Float64ToUBigintCastEvaluator, DataValue::Float64(value) => { + Ok(DataValue::UInt64(crate::float_to_int_cast!(value.into_inner(), u64, f64)?)) +}); +crate::define_cast_evaluator!( + Float64ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits + }, + DataValue::Float64(value) => |this| to_char(value.to_string(), this.len, this.unit) +); +crate::define_cast_evaluator!( + Float64ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits + }, + DataValue::Float64(value) => |this| to_varchar(value.to_string(), this.len, this.unit) +); +crate::define_cast_evaluator!( + Float64ToDecimalCastEvaluator { + scale: Option, + to: LogicalType + }, + DataValue::Float64(value) => |this| { + let mut decimal = Decimal::from_f64(value.0).ok_or_else(|| { + cast_fail(LogicalType::Double, this.to.clone()) + })?; + DataValue::decimal_round_f(&this.scale, &mut decimal); + + Ok(DataValue::Decimal(decimal)) + } +); + #[typetag::serde] impl BinaryEvaluator for Float64MultiplyBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -197,3 +262,102 @@ impl BinaryEvaluator for Float64ModBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use crate::types::value::Utf8Type; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_float64_binary_and_cast_evaluators() { + let value = DataValue::Float64(ordered_float::OrderedFloat(1.5)); + + assert_eq!( + Float64MultiplyBinaryEvaluator + .binary_eval( + &DataValue::Float64(ordered_float::OrderedFloat(1.5)), + &DataValue::Float64(ordered_float::OrderedFloat(2.0)), + ) + .unwrap(), + DataValue::Float64(ordered_float::OrderedFloat(3.0)) + ); + assert_eq!( + Float64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(ordered_float::OrderedFloat(1.5)) + ); + assert_eq!( + Float64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(ordered_float::OrderedFloat(1.5)) + ); + assert_eq!( + Float64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Float64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Float64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Float64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Float64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Float64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Float64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Float64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Float64ToCharCastEvaluator { + len: 3, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1.5".to_string(), + ty: Utf8Type::Fixed(3), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Float64ToVarcharCastEvaluator { + len: Some(3), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1.5".to_string(), + ty: Utf8Type::Variable(Some(3)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Float64ToDecimalCastEvaluator { + scale: Some(1), + to: LogicalType::Decimal(None, Some(1)), + } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(15, 1)) + ); + } +} diff --git a/src/types/evaluator/int16.rs b/src/types/evaluator/int16.rs index 976c1502..aa9f95c6 100644 --- a/src/types/evaluator/int16.rs +++ b/src/types/evaluator/int16.rs @@ -13,12 +13,101 @@ // limitations under the License. use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; -use crate::types::DatabaseError; +use crate::types::LogicalType; use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; numeric_unary_evaluator_definition!(Int16, DataValue::Int16); numeric_binary_evaluator_definition!(Int16, DataValue::Int16); +crate::define_integer_cast_evaluators!(Int16, Int16, i16, LogicalType::Smallint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_int16_cast_evaluators() { + let value = DataValue::Int16(1); + + assert_eq!( + Int16ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Int16ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Int16ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Int16ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Int16ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Int16ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Int16ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Int16ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Int16ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Int16ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + Int16ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + Int16ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int16ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int16ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/int32.rs b/src/types/evaluator/int32.rs index 03ffddd4..3fbce73a 100644 --- a/src/types/evaluator/int32.rs +++ b/src/types/evaluator/int32.rs @@ -13,12 +13,129 @@ // limitations under the License. use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; -use crate::types::DatabaseError; +use crate::types::LogicalType; use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; numeric_unary_evaluator_definition!(Int32, DataValue::Int32); numeric_binary_evaluator_definition!(Int32, DataValue::Int32); +crate::define_integer_cast_evaluators!(Int32, Int32, i32, LogicalType::Integer); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_int32_binary_evaluators() { + assert_eq!( + Int32PlusBinaryEvaluator + .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) + .unwrap(), + DataValue::Int32(2) + ); + assert_eq!( + Int32MinusBinaryEvaluator + .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) + .unwrap(), + DataValue::Int32(0) + ); + assert_eq!( + Int32EqBinaryEvaluator + .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) + .unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Int32GtBinaryEvaluator + .binary_eval(&DataValue::Int32(1), &DataValue::Int32(0)) + .unwrap(), + DataValue::Boolean(true) + ); + } + + #[test] + fn test_int32_cast_evaluators() { + let value = DataValue::Int32(1); + + assert_eq!( + Int32ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Int32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Int32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Int32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Int32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Int32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Int32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Int32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Int32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Int32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + Int32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + Int32ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int32ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int32ToDecimalCastEvaluator { scale: Some(1) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(10, 1)) + ); + } +} diff --git a/src/types/evaluator/int64.rs b/src/types/evaluator/int64.rs index 9a4727dd..7e67eb09 100644 --- a/src/types/evaluator/int64.rs +++ b/src/types/evaluator/int64.rs @@ -13,12 +13,101 @@ // limitations under the License. use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; -use crate::types::DatabaseError; +use crate::types::LogicalType; use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; numeric_unary_evaluator_definition!(Int64, DataValue::Int64); numeric_binary_evaluator_definition!(Int64, DataValue::Int64); +crate::define_integer_cast_evaluators!(Int64, Int64, i64, LogicalType::Bigint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_int64_cast_evaluators() { + let value = DataValue::Int64(1); + + assert_eq!( + Int64ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Int64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Int64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Int64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Int64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Int64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Int64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Int64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Int64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Int64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + Int64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + Int64ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int64ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int64ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/int8.rs b/src/types/evaluator/int8.rs index fe5a7ccb..9482c989 100644 --- a/src/types/evaluator/int8.rs +++ b/src/types/evaluator/int8.rs @@ -13,12 +13,101 @@ // limitations under the License. use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; -use crate::types::DatabaseError; +use crate::types::LogicalType; use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; numeric_unary_evaluator_definition!(Int8, DataValue::Int8); numeric_binary_evaluator_definition!(Int8, DataValue::Int8); +crate::define_integer_cast_evaluators!(Int8, Int8, i8, LogicalType::Tinyint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_int8_cast_evaluators() { + let value = DataValue::Int8(1); + + assert_eq!( + Int8ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Int8ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Int8ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Int8ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Int8ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Int8ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Int8ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Int8ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Int8ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Int8ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + Int8ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + Int8ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int8ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Int8ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs index c66a60a7..5544a142 100644 --- a/src/types/evaluator/mod.rs +++ b/src/types/evaluator/mod.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod binary; pub mod boolean; +pub mod cast; pub mod date; pub mod datetime; pub mod decimal; @@ -30,39 +32,15 @@ pub mod uint16; pub mod uint32; pub mod uint64; pub mod uint8; +pub mod unary; pub mod utf8; +pub use self::binary::binary_create; +pub use self::cast::cast_create; +pub use self::unary::unary_create; + use crate::errors::DatabaseError; -use crate::expression::{BinaryOperator, UnaryOperator}; -use crate::types::evaluator::boolean::*; -use crate::types::evaluator::date::*; -use crate::types::evaluator::datetime::*; -use crate::types::evaluator::decimal::*; -use crate::types::evaluator::float32::*; -use crate::types::evaluator::float64::*; -use crate::types::evaluator::int16::*; -use crate::types::evaluator::int32::*; -use crate::types::evaluator::int64::*; -use crate::types::evaluator::int8::*; -use crate::types::evaluator::null::NullBinaryEvaluator; -use crate::types::evaluator::time32::*; -use crate::types::evaluator::time64::*; -use crate::types::evaluator::tuple::{ - TupleEqBinaryEvaluator, TupleGtBinaryEvaluator, TupleGtEqBinaryEvaluator, - TupleLtBinaryEvaluator, TupleLtEqBinaryEvaluator, TupleNotEqBinaryEvaluator, -}; -use crate::types::evaluator::uint16::*; -use crate::types::evaluator::uint32::*; -use crate::types::evaluator::uint64::*; -use crate::types::evaluator::uint8::*; -use crate::types::evaluator::utf8::*; -use crate::types::evaluator::utf8::{ - Utf8EqBinaryEvaluator, Utf8GtBinaryEvaluator, Utf8GtEqBinaryEvaluator, Utf8LtBinaryEvaluator, - Utf8LtEqBinaryEvaluator, Utf8NotEqBinaryEvaluator, Utf8StringConcatBinaryEvaluator, -}; use crate::types::value::DataValue; -use crate::types::LogicalType; -use paste::paste; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; @@ -79,6 +57,11 @@ pub trait UnaryEvaluator: Send + Sync + Debug { fn unary_eval(&self, value: &DataValue) -> DataValue; } +#[typetag::serde(tag = "cast")] +pub trait CastEvaluator: Send + Sync + Debug { + fn eval_cast(&self, value: &DataValue) -> Result; +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BinaryEvaluatorBox(pub Arc); @@ -109,6 +92,23 @@ impl UnaryEvaluatorBox { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CastEvaluatorBox(pub Arc); + +impl Deref for CastEvaluatorBox { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl CastEvaluatorBox { + pub fn eval_cast(&self, value: &DataValue) -> Result { + self.0.eval_cast(value) + } +} + impl PartialEq for BinaryEvaluatorBox { fn eq(&self, _: &Self) -> bool { // FIXME @@ -139,1135 +139,17 @@ impl Hash for UnaryEvaluatorBox { } } -macro_rules! numeric_binary_evaluator { - ($value_type:ident, $op:expr, $ty:expr) => { - paste! { - match $op { - BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type PlusBinaryEvaluator>]))), - BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MinusBinaryEvaluator>]))), - BinaryOperator::Multiply => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MultiplyBinaryEvaluator>]))), - BinaryOperator::Divide => Ok(BinaryEvaluatorBox(Arc::new([<$value_type DivideBinaryEvaluator>]))), - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtBinaryEvaluator>]))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtEqBinaryEvaluator>]))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtBinaryEvaluator>]))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtEqBinaryEvaluator>]))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type EqBinaryEvaluator>]))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type NotEqBinaryEvaluator>]))), - BinaryOperator::Modulo => Ok(BinaryEvaluatorBox(Arc::new([<$value_type ModBinaryEvaluator>]))), - _ => { - return Err(DatabaseError::UnsupportedBinaryOperator( - $ty, - $op, - )) - } - } - } - }; -} - -macro_rules! numeric_unary_evaluator { - ($value_type:ident, $op:expr, $ty:expr) => { - paste! { - match $op { - UnaryOperator::Plus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type PlusUnaryEvaluator>]))), - UnaryOperator::Minus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type MinusUnaryEvaluator>]))), - _ => { - return Err(DatabaseError::UnsupportedUnaryOperator( - $ty, - $op, - )) - } - } - } - }; -} - -pub struct EvaluatorFactory; - -impl EvaluatorFactory { - pub fn unary_create( - ty: LogicalType, - op: UnaryOperator, - ) -> Result { - match ty { - LogicalType::Tinyint => numeric_unary_evaluator!(Int8, op, LogicalType::Tinyint), - LogicalType::Smallint => numeric_unary_evaluator!(Int16, op, LogicalType::Smallint), - LogicalType::Integer => numeric_unary_evaluator!(Int32, op, LogicalType::Integer), - LogicalType::Bigint => numeric_unary_evaluator!(Int64, op, LogicalType::Bigint), - LogicalType::Boolean => match op { - UnaryOperator::Not => Ok(UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator))), - _ => Err(DatabaseError::UnsupportedUnaryOperator(ty, op)), - }, - LogicalType::Float => numeric_unary_evaluator!(Float32, op, LogicalType::Float), - LogicalType::Double => numeric_unary_evaluator!(Float64, op, LogicalType::Double), - _ => Err(DatabaseError::UnsupportedUnaryOperator(ty, op)), - } - } - pub fn binary_create( - ty: LogicalType, - op: BinaryOperator, - ) -> Result { - match ty { - LogicalType::Tinyint => numeric_binary_evaluator!(Int8, op, LogicalType::Tinyint), - LogicalType::Smallint => numeric_binary_evaluator!(Int16, op, LogicalType::Smallint), - LogicalType::Integer => numeric_binary_evaluator!(Int32, op, LogicalType::Integer), - LogicalType::Bigint => numeric_binary_evaluator!(Int64, op, LogicalType::Bigint), - LogicalType::UTinyint => numeric_binary_evaluator!(UInt8, op, LogicalType::UTinyint), - LogicalType::USmallint => numeric_binary_evaluator!(UInt16, op, LogicalType::USmallint), - LogicalType::UInteger => numeric_binary_evaluator!(UInt32, op, LogicalType::UInteger), - LogicalType::UBigint => numeric_binary_evaluator!(UInt64, op, LogicalType::UBigint), - LogicalType::Float => numeric_binary_evaluator!(Float32, op, LogicalType::Float), - LogicalType::Double => numeric_binary_evaluator!(Float64, op, LogicalType::Double), - LogicalType::Date => numeric_binary_evaluator!(Date, op, LogicalType::Date), - LogicalType::DateTime => numeric_binary_evaluator!(DateTime, op, LogicalType::DateTime), - LogicalType::Time(_) => match op { - BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new(TimePlusBinaryEvaluator))), - BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new(TimeMinusBinaryEvaluator))), - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TimeGtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeGtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TimeLtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TimeLtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TimeEqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(TimeNotEqBinaryEvaluator))), - _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), - }, - LogicalType::TimeStamp(_, _) => match op { - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Time64GtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64GtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Time64LtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Time64LtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Time64EqBinaryEvaluator))), - BinaryOperator::NotEq => { - Ok(BinaryEvaluatorBox(Arc::new(Time64NotEqBinaryEvaluator))) - } - _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), - }, - LogicalType::Decimal(_, _) => numeric_binary_evaluator!(Decimal, op, ty), - LogicalType::Boolean => match op { - BinaryOperator::And => Ok(BinaryEvaluatorBox(Arc::new(BooleanAndBinaryEvaluator))), - BinaryOperator::Or => Ok(BinaryEvaluatorBox(Arc::new(BooleanOrBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(BooleanEqBinaryEvaluator))), - BinaryOperator::NotEq => { - Ok(BinaryEvaluatorBox(Arc::new(BooleanNotEqBinaryEvaluator))) - } - _ => Err(DatabaseError::UnsupportedBinaryOperator( - LogicalType::Boolean, - op, - )), - }, - LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => match op { - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtEqBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtEqBinaryEvaluator))), - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Utf8EqBinaryEvaluator))), - BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8NotEqBinaryEvaluator))), - BinaryOperator::StringConcat => Ok(BinaryEvaluatorBox(Arc::new( - Utf8StringConcatBinaryEvaluator, - ))), - BinaryOperator::Like(escape_char) => { - Ok(BinaryEvaluatorBox(Arc::new(Utf8LikeBinaryEvaluator { - escape_char, - }))) - } - BinaryOperator::NotLike(escape_char) => { - Ok(BinaryEvaluatorBox(Arc::new(Utf8NotLikeBinaryEvaluator { - escape_char, - }))) - } - _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), - }, - LogicalType::SqlNull => Ok(BinaryEvaluatorBox(Arc::new(NullBinaryEvaluator))), - LogicalType::Tuple(_) => match op { - BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TupleEqBinaryEvaluator))), - BinaryOperator::NotEq => { - Ok(BinaryEvaluatorBox(Arc::new(TupleNotEqBinaryEvaluator))) - } - BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TupleGtBinaryEvaluator))), - BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleGtEqBinaryEvaluator))), - BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TupleLtBinaryEvaluator))), - BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleLtEqBinaryEvaluator))), - _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), - }, - } +impl PartialEq for CastEvaluatorBox { + fn eq(&self, _: &Self) -> bool { + // FIXME + true } } -#[macro_export] -macro_rules! numeric_unary_evaluator_definition { - ($value_type:ident, $compute_type:path) => { - paste! { - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type PlusUnaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type MinusUnaryEvaluator>]; - - #[typetag::serde] - impl UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { - fn unary_eval(&self, value: &DataValue) -> DataValue { - value.clone() - } - } - #[typetag::serde] - impl UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { - fn unary_eval(&self, value: &DataValue) -> DataValue { - match value { - $compute_type(value) => $compute_type(-value), - DataValue::Null => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - } - } - } - } - }; -} - -#[macro_export] -macro_rules! numeric_binary_evaluator_definition { - ($value_type:ident, $compute_type:path) => { - paste! { - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type PlusBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type MinusBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type MultiplyBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type DivideBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type GtBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type GtEqBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type LtBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type LtEqBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type EqBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type NotEqBinaryEvaluator>]; - #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] - pub struct [<$value_type ModBinaryEvaluator>]; - - #[typetag::serde] - impl BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_add(*v2).ok_or(DatabaseError::OverFlow)?), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_sub(*v2).ok_or(DatabaseError::OverFlow)?), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_mul(*v2).ok_or(DatabaseError::OverFlow)?), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Float64(ordered_float::OrderedFloat(*v1 as f64 / *v2 as f64)), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 > v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 >= v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 < v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 <= v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 == v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => DataValue::Boolean(v1 != v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - #[typetag::serde] - impl BinaryEvaluator for [<$value_type ModBinaryEvaluator>] { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(*v1 % *v2), - ($compute_type(_), DataValue::Null) | (DataValue::Null, $compute_type(_)) | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } - } - } - }; -} - -#[cfg(all(test, not(target_arch = "wasm32")))] -mod test { - use crate::errors::DatabaseError; - use crate::expression::BinaryOperator; - use crate::serdes::{ReferenceSerialization, ReferenceTables}; - use crate::storage::rocksdb::RocksTransaction; - use crate::types::evaluator::boolean::{BooleanNotEqBinaryEvaluator, BooleanNotUnaryEvaluator}; - use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory, UnaryEvaluatorBox}; - use crate::types::value::{DataValue, Utf8Type}; - use crate::types::LogicalType; - use ordered_float::OrderedFloat; - use sqlparser::ast::CharLengthUnits; - use std::io::{Cursor, Seek, SeekFrom}; - use std::sync::Arc; - - #[test] - fn test_binary_op_arithmetic_plus() -> Result<(), DatabaseError> { - let plus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; - let plus_i32_1 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let plus_i32_2 = plus_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Null)?; - let plus_i32_3 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1))?; - let plus_i32_4 = plus_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1))?; - - assert_eq!(plus_i32_1, plus_i32_2); - assert_eq!(plus_i32_2, plus_i32_3); - assert_eq!(plus_i32_4, DataValue::Int32(2)); - - let plus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Plus)?; - let plus_i64_1 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let plus_i64_2 = plus_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Null)?; - let plus_i64_3 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int64(1))?; - let plus_i64_4 = plus_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Int64(1))?; - - assert_eq!(plus_i64_1, plus_i64_2); - assert_eq!(plus_i64_2, plus_i64_3); - assert_eq!(plus_i64_4, DataValue::Int64(2)); - - let plus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Plus)?; - let plus_f64_1 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let plus_f64_2 = plus_evaluator - .0 - .binary_eval(&DataValue::Float64(OrderedFloat(1.0)), &DataValue::Null)?; - let plus_f64_3 = plus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Float64(OrderedFloat(1.0)))?; - let plus_f64_4 = plus_evaluator.0.binary_eval( - &DataValue::Float64(OrderedFloat(1.0)), - &DataValue::Float64(OrderedFloat(1.0)), - )?; - - assert_eq!(plus_f64_1, plus_f64_2); - assert_eq!(plus_f64_2, plus_f64_3); - assert_eq!(plus_f64_4, DataValue::Float64(OrderedFloat(2.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_minus() -> Result<(), DatabaseError> { - let minus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Minus)?; - let minus_i32_1 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let minus_i32_2 = minus_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Null)?; - let minus_i32_3 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1))?; - let minus_i32_4 = minus_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1))?; - - assert_eq!(minus_i32_1, minus_i32_2); - assert_eq!(minus_i32_2, minus_i32_3); - assert_eq!(minus_i32_4, DataValue::Int32(0)); - - let minus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Minus)?; - let minus_i64_1 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let minus_i64_2 = minus_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Null)?; - let minus_i64_3 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int64(1))?; - let minus_i64_4 = minus_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Int64(1))?; - - assert_eq!(minus_i64_1, minus_i64_2); - assert_eq!(minus_i64_2, minus_i64_3); - assert_eq!(minus_i64_4, DataValue::Int64(0)); +impl Eq for CastEvaluatorBox {} - let minus_evaluator = - EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Minus)?; - let minus_f64_1 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let minus_f64_2 = minus_evaluator - .0 - .binary_eval(&DataValue::Float64(OrderedFloat(1.0)), &DataValue::Null)?; - let minus_f64_3 = minus_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Float64(OrderedFloat(1.0)))?; - let minus_f64_4 = minus_evaluator.0.binary_eval( - &DataValue::Float64(OrderedFloat(1.0)), - &DataValue::Float64(OrderedFloat(1.0)), - )?; - - assert_eq!(minus_f64_1, minus_f64_2); - assert_eq!(minus_f64_2, minus_f64_3); - assert_eq!(minus_f64_4, DataValue::Float64(OrderedFloat(0.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_multiply() -> Result<(), DatabaseError> { - let multiply_evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Multiply)?; - let multiply_i32_1 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let multiply_i32_2 = multiply_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Null)?; - let multiply_i32_3 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1))?; - let multiply_i32_4 = multiply_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1))?; - - assert_eq!(multiply_i32_1, multiply_i32_2); - assert_eq!(multiply_i32_2, multiply_i32_3); - assert_eq!(multiply_i32_4, DataValue::Int32(1)); - - let multiply_evaluator = - EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Multiply)?; - let multiply_i64_1 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let multiply_i64_2 = multiply_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Null)?; - let multiply_i64_3 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int64(1))?; - let multiply_i64_4 = multiply_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Int64(1))?; - - assert_eq!(multiply_i64_1, multiply_i64_2); - assert_eq!(multiply_i64_2, multiply_i64_3); - assert_eq!(multiply_i64_4, DataValue::Int64(1)); - - let multiply_evaluator = - EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Multiply)?; - let multiply_f64_1 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let multiply_f64_2 = multiply_evaluator - .0 - .binary_eval(&DataValue::Float64(OrderedFloat(1.0)), &DataValue::Null)?; - let multiply_f64_3 = multiply_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Float64(OrderedFloat(1.0)))?; - let multiply_f64_4 = multiply_evaluator.0.binary_eval( - &DataValue::Float64(OrderedFloat(1.0)), - &DataValue::Float64(OrderedFloat(1.0)), - )?; - - assert_eq!(multiply_f64_1, multiply_f64_2); - assert_eq!(multiply_f64_2, multiply_f64_3); - assert_eq!(multiply_f64_4, DataValue::Float64(OrderedFloat(1.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_divide() -> Result<(), DatabaseError> { - let divide_evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Divide)?; - let divide_i32_1 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let divide_i32_2 = divide_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Null)?; - let divide_i32_3 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1))?; - let divide_i32_4 = divide_evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1))?; - - assert_eq!(divide_i32_1, divide_i32_2); - assert_eq!(divide_i32_2, divide_i32_3); - assert_eq!(divide_i32_4, DataValue::Float64(OrderedFloat(1.0))); - - let divide_evaluator = - EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Divide)?; - let divide_i64_1 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let divide_i64_2 = divide_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Null)?; - let divide_i64_3 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int64(1))?; - let divide_i64_4 = divide_evaluator - .0 - .binary_eval(&DataValue::Int64(1), &DataValue::Int64(1))?; - - assert_eq!(divide_i64_1, divide_i64_2); - assert_eq!(divide_i64_2, divide_i64_3); - assert_eq!(divide_i64_4, DataValue::Float64(OrderedFloat(1.0))); - - let divide_evaluator = - EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Divide)?; - let divide_f64_1 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null)?; - let divide_f64_2 = divide_evaluator - .0 - .binary_eval(&DataValue::Float64(OrderedFloat(1.0)), &DataValue::Null)?; - let divide_f64_3 = divide_evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Float64(OrderedFloat(1.0)))?; - let divide_f64_4 = divide_evaluator.0.binary_eval( - &DataValue::Float64(OrderedFloat(1.0)), - &DataValue::Float64(OrderedFloat(1.0)), - )?; - - assert_eq!(divide_f64_1, divide_f64_2); - assert_eq!(divide_f64_2, divide_f64_3); - assert_eq!(divide_f64_4, DataValue::Float64(OrderedFloat(1.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_i32_compare() -> Result<(), DatabaseError> { - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Gt)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(0),)?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Lt)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(0),)?, - DataValue::Boolean(false) - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::GtEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1),)?, - DataValue::Boolean(true) - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::LtEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1),)?, - DataValue::Boolean(true) - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::NotEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1),)?, - DataValue::Boolean(false) - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1),)?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Gt)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(0),)?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Lt)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(0),)?, - DataValue::Null - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::GtEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1),)?, - DataValue::Null - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::LtEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1),)?, - DataValue::Null - ); - let evaluator = - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::NotEq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1),)?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Int32(1),)?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Null,)?, - DataValue::Null - ); - - Ok(()) - } - - #[test] - fn test_binary_op_bool_compare() -> Result<(), DatabaseError> { - let evaluator = EvaluatorFactory::binary_create(LogicalType::Boolean, BinaryOperator::And)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(true), &DataValue::Boolean(true),)?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(true),)?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(false),)?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Boolean(true),)?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create(LogicalType::Boolean, BinaryOperator::Or)?; - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(true), &DataValue::Boolean(true),)?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(true),)?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(false),)?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator - .0 - .binary_eval(&DataValue::Null, &DataValue::Boolean(true),)?, - DataValue::Boolean(true) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_utf8_compare() -> Result<(), DatabaseError> { - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::Gt, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "b".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(false) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::Lt, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "b".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::GtEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::LtEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::NotEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(false) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::Eq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Boolean(true) - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::Gt, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Null, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::Lt, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Null, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::GtEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Null, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::LtEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Null, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Null - ); - let evaluator = EvaluatorFactory::binary_create( - LogicalType::Varchar(None, CharLengthUnits::Characters), - BinaryOperator::NotEq, - )?; - assert_eq!( - evaluator.0.binary_eval( - &DataValue::Null, - &DataValue::Utf8 { - value: "a".to_string(), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - )?, - DataValue::Null - ); - - Ok(()) - } - - #[test] - fn test_binary_op_time32_and_time64() -> Result<(), DatabaseError> { - let evaluator_time32 = - EvaluatorFactory::binary_create(LogicalType::Time(None), BinaryOperator::Plus)?; - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2621204256, 4), - )?, - DataValue::Time32(2618593017, 4) - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190175696, 3), - &DataValue::Time32(2621224256, 4), - )?, - DataValue::Null - ); - - let evaluator_time32 = - EvaluatorFactory::binary_create(LogicalType::Time(None), BinaryOperator::Minus)?; - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2621204256, 4), - )?, - DataValue::Null - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(2621204256, 4), - &DataValue::Time32(4190119896, 3), - )?, - DataValue::Time32(2375496, 4) - ); - - let evaluator_time32 = - EvaluatorFactory::binary_create(LogicalType::Time(None), BinaryOperator::Gt)?; - let evaluator_time64 = EvaluatorFactory::binary_create( - LogicalType::TimeStamp(None, false), - BinaryOperator::Gt, - )?; - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(2621204256, 4), - &DataValue::Time32(4190119896, 3), - )?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2621204256, 4), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1736055775154814, 6, false), - &DataValue::Time64(1738734177256, 3, false), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1738734177256, 3, false), - &DataValue::Time64(1736055775154814, 6, false), - )?, - DataValue::Boolean(true) - ); - - let evaluator_time32 = - EvaluatorFactory::binary_create(LogicalType::Time(None), BinaryOperator::GtEq)?; - let evaluator_time64 = EvaluatorFactory::binary_create( - LogicalType::TimeStamp(None, false), - BinaryOperator::GtEq, - )?; - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(2621204256, 4), - &DataValue::Time32(4190119896, 3), - )?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2621204256, 4), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2618828760, 4), - )?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1736055775154814, 6, false), - &DataValue::Time64(1738734177256, 3, false), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1738734177256, 3, false), - &DataValue::Time64(1736055775154814, 6, false), - )?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1738734177256, 3, false), - &DataValue::Time64(1738734177256000, 6, false), - )?, - DataValue::Boolean(true) - ); - - let evaluator_time32 = - EvaluatorFactory::binary_create(LogicalType::Time(None), BinaryOperator::Eq)?; - let evaluator_time64 = EvaluatorFactory::binary_create( - LogicalType::TimeStamp(None, false), - BinaryOperator::Eq, - )?; - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2621204256, 4), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time32.0.binary_eval( - &DataValue::Time32(4190119896, 3), - &DataValue::Time32(2618828760, 4), - )?, - DataValue::Boolean(true) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1738734177256, 3, false), - &DataValue::Time64(1736055775154814, 6, false), - )?, - DataValue::Boolean(false) - ); - assert_eq!( - evaluator_time64.0.binary_eval( - &DataValue::Time64(1738734177256, 3, false), - &DataValue::Time64(1738734177256000, 6, false), - )?, - DataValue::Boolean(true) - ); - - Ok(()) - } - - #[test] - fn test_reference_serialization() -> Result<(), DatabaseError> { - let mut cursor = Cursor::new(Vec::new()); - let mut reference_tables = ReferenceTables::new(); - - let binary_evaluator = BinaryEvaluatorBox(Arc::new(BooleanNotEqBinaryEvaluator)); - binary_evaluator.encode(&mut cursor, false, &mut reference_tables)?; - - cursor.seek(SeekFrom::Start(0))?; - assert_eq!( - BinaryEvaluatorBox::decode::( - &mut cursor, - None, - &reference_tables - )?, - binary_evaluator - ); - cursor.seek(SeekFrom::Start(0))?; - let unary_evaluator = UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator)); - unary_evaluator.encode(&mut cursor, false, &mut reference_tables)?; - cursor.seek(SeekFrom::Start(0))?; - assert_eq!( - UnaryEvaluatorBox::decode::(&mut cursor, None, &reference_tables)?, - unary_evaluator - ); - - Ok(()) +impl Hash for CastEvaluatorBox { + fn hash(&self, state: &mut H) { + state.write_i8(42) } } diff --git a/src/types/evaluator/null.rs b/src/types/evaluator/null.rs index a9343055..670c1b22 100644 --- a/src/types/evaluator/null.rs +++ b/src/types/evaluator/null.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use serde::{Deserialize, Serialize}; /// Tips: @@ -28,3 +28,42 @@ impl BinaryEvaluator for NullBinaryEvaluator { Ok(DataValue::Null) } } + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct ToSqlNullCastEvaluator; + +#[typetag::serde] +impl CastEvaluator for ToSqlNullCastEvaluator { + fn eval_cast(&self, _value: &DataValue) -> Result { + Ok(DataValue::Null) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct NullCastEvaluator; + +#[typetag::serde] +impl CastEvaluator for NullCastEvaluator { + fn eval_cast(&self, _value: &DataValue) -> Result { + Ok(DataValue::Null) + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + + #[test] + fn test_null_cast_evaluators() { + assert_eq!( + ToSqlNullCastEvaluator + .eval_cast(&DataValue::Int32(1)) + .unwrap(), + DataValue::Null + ); + assert_eq!( + NullCastEvaluator.eval_cast(&DataValue::Null).unwrap(), + DataValue::Null + ); + } +} diff --git a/src/types/evaluator/time32.rs b/src/types/evaluator/time32.rs index b9be3128..18277430 100644 --- a/src/types/evaluator/time32.rs +++ b/src/types/evaluator/time32.rs @@ -13,10 +13,13 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::value::{ONE_DAY_TO_SEC, ONE_SEC_TO_NANO}; +use crate::types::LogicalType; use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -118,6 +121,48 @@ impl BinaryEvaluator for TimeGtEqBinaryEvaluator { }) } } + +crate::define_cast_evaluator!( + Time32ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Time32(value, precision) => |this| { + to_char( + DataValue::format_time(*value, *precision).ok_or_else(|| { + cast_fail(LogicalType::Time(Some(*precision)), this.to.clone()) + })?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Time32ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Time32(value, precision) => |this| { + to_varchar( + DataValue::format_time(*value, *precision).ok_or_else(|| { + cast_fail(LogicalType::Time(Some(*precision)), this.to.clone()) + })?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Time32ToTimeCastEvaluator { + precision: Option + }, + DataValue::Time32(value, _precision) => |this| { + Ok(DataValue::Time32(*value, this.precision.unwrap_or(0))) + } +); + #[typetag::serde] impl BinaryEvaluator for TimeLtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -182,3 +227,71 @@ impl BinaryEvaluator for TimeNotEqBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use crate::types::value::Utf8Type; + + #[test] + fn test_time32_binary_evaluators() { + assert_eq!( + TimePlusBinaryEvaluator + .binary_eval( + &DataValue::Time32(4_190_119_896, 3), + &DataValue::Time32(2_621_204_256, 4), + ) + .unwrap(), + DataValue::Time32(2_618_593_017, 4) + ); + assert_eq!( + TimeGtBinaryEvaluator + .binary_eval( + &DataValue::Time32(2_621_204_256, 4), + &DataValue::Time32(4_190_119_896, 3), + ) + .unwrap(), + DataValue::Boolean(true) + ); + } + + #[test] + fn test_time32_cast_evaluators() { + let value = DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 123_000_000, 3), 3); + assert_eq!( + Time32ToCharCastEvaluator { + len: 12, + unit: CharLengthUnits::Characters, + to: LogicalType::Char(12, CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "03:04:05.123".to_string(), + ty: Utf8Type::Fixed(12), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Time32ToVarcharCastEvaluator { + len: Some(12), + unit: CharLengthUnits::Characters, + to: LogicalType::Varchar(Some(12), CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "03:04:05.123".to_string(), + ty: Utf8Type::Variable(Some(12)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Time32ToTimeCastEvaluator { precision: Some(3) } + .eval_cast(&value) + .unwrap(), + value + ); + } +} diff --git a/src/types/evaluator/time64.rs b/src/types/evaluator/time64.rs index ee1dc763..5e7de3c4 100644 --- a/src/types/evaluator/time64.rs +++ b/src/types/evaluator/time64.rs @@ -13,9 +13,13 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::LogicalType; +use chrono::{Datelike, Timelike}; use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; use std::hint; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -81,6 +85,100 @@ impl BinaryEvaluator for Time64GtEqBinaryEvaluator { }) } } + +crate::define_cast_evaluator!( + Time64ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Time64(value, precision, _) => |this| { + to_char( + DataValue::format_timestamp(*value, *precision).ok_or_else(|| { + cast_fail(LogicalType::TimeStamp(Some(*precision), false), this.to.clone()) + })?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Time64ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits, + to: LogicalType + }, + DataValue::Time64(value, precision, _) => |this| { + to_varchar( + DataValue::format_timestamp(*value, *precision).ok_or_else(|| { + cast_fail(LogicalType::TimeStamp(Some(*precision), false), this.to.clone()) + })?, + this.len, + this.unit, + ) + } +); +crate::define_cast_evaluator!( + Time64ToDateCastEvaluator { + from: LogicalType, + to: LogicalType + }, + DataValue::Time64(value, precision, _) => |this| { + let value = DataValue::from_timestamp_precision(*value, *precision) + .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))? + .naive_utc() + .date() + .num_days_from_ce(); + + Ok(DataValue::Date32(value)) + } +); +crate::define_cast_evaluator!( + Time64ToDatetimeCastEvaluator { + from: LogicalType, + to: LogicalType + }, + DataValue::Time64(value, precision, _) => |this| { + let value = DataValue::from_timestamp_precision(*value, *precision) + .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))? + .timestamp(); + + Ok(DataValue::Date64(value)) + } +); +crate::define_cast_evaluator!( + Time64ToTimeCastEvaluator { + precision: Option, + from: LogicalType, + to: LogicalType + }, + DataValue::Time64(value, precision, _) => |this| { + let target_precision = this.precision.unwrap_or(0); + let (value, nano) = DataValue::from_timestamp_precision(*value, *precision) + .map(|date_time| { + ( + date_time.time().num_seconds_from_midnight(), + date_time.time().nanosecond(), + ) + }) + .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))?; + + Ok(DataValue::Time32( + DataValue::pack(value, nano, target_precision), + target_precision, + )) + } +); +crate::define_cast_evaluator!( + Time64ToTimestampCastEvaluator { + precision: Option, + zone: bool + }, + DataValue::Time64(value, _precision, _) => |this| { + Ok(DataValue::Time64(*value, this.precision.unwrap_or(0), this.zone)) + } +); + #[typetag::serde] impl BinaryEvaluator for Time64LtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -181,3 +279,111 @@ impl BinaryEvaluator for Time64NotEqBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use crate::types::value::Utf8Type; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_time64_binary_evaluators() { + assert_eq!( + Time64EqBinaryEvaluator + .binary_eval( + &DataValue::Time64(1_738_734_177_256, 3, false), + &DataValue::Time64(1_738_734_177_256_000, 6, false), + ) + .unwrap(), + DataValue::Boolean(true) + ); + } + + #[test] + fn test_time64_cast_evaluators() { + let timestamp = chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_milli_opt(3, 4, 5, 123) + .unwrap() + .and_utc() + .timestamp_millis(); + let value = DataValue::Time64(timestamp, 3, false); + assert_eq!( + Time64ToCharCastEvaluator { + len: 23, + unit: CharLengthUnits::Characters, + to: LogicalType::Char(23, CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02 03:04:05.123".to_string(), + ty: Utf8Type::Fixed(23), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Time64ToVarcharCastEvaluator { + len: Some(23), + unit: CharLengthUnits::Characters, + to: LogicalType::Varchar(Some(23), CharLengthUnits::Characters), + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "2024-01-02 03:04:05.123".to_string(), + ty: Utf8Type::Variable(Some(23)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Time64ToDateCastEvaluator { + from: LogicalType::TimeStamp(Some(3), false), + to: LogicalType::Date, + } + .eval_cast(&value) + .unwrap(), + DataValue::Date32( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .num_days_from_ce() + ) + ); + assert_eq!( + Time64ToDatetimeCastEvaluator { + from: LogicalType::TimeStamp(Some(3), false), + to: LogicalType::DateTime, + } + .eval_cast(&value) + .unwrap(), + DataValue::Date64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .unwrap() + .and_utc() + .timestamp() + ) + ); + assert_eq!( + Time64ToTimeCastEvaluator { + precision: Some(3), + from: LogicalType::TimeStamp(Some(3), false), + to: LogicalType::Time(Some(3)), + } + .eval_cast(&value) + .unwrap(), + DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 123_000_000, 3), 3) + ); + assert_eq!( + Time64ToTimestampCastEvaluator { + precision: Some(3), + zone: true, + } + .eval_cast(&value) + .unwrap(), + DataValue::Time64(timestamp, 3, true) + ); + } +} diff --git a/src/types/evaluator/tuple.rs b/src/types/evaluator/tuple.rs index 4561850f..f5145496 100644 --- a/src/types/evaluator/tuple.rs +++ b/src/types/evaluator/tuple.rs @@ -13,8 +13,8 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, CastEvaluator, CastEvaluatorBox}; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::hint; @@ -87,6 +87,31 @@ impl BinaryEvaluator for TupleNotEqBinaryEvaluator { }) } } + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleCastEvaluator { + pub element_evaluators: Vec, +} + +#[typetag::serde] +impl CastEvaluator for TupleCastEvaluator { + fn eval_cast(&self, value: &DataValue) -> Result { + match value { + DataValue::Null => Ok(DataValue::Null), + DataValue::Tuple(values, is_upper) => { + let mut casted = Vec::with_capacity(values.len()); + + for (value, evaluator) in values.iter().zip(self.element_evaluators.iter()) { + casted.push(evaluator.eval_cast(value)?); + } + + Ok(DataValue::Tuple(casted, *is_upper)) + } + _ => unsafe { hint::unreachable_unchecked() }, + } + } +} + #[typetag::serde] impl BinaryEvaluator for TupleGtBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -151,3 +176,44 @@ impl BinaryEvaluator for TupleLtEqBinaryEvaluator { }) } } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::cast_create; + use crate::types::LogicalType; + use sqlparser::ast::CharLengthUnits; + use std::borrow::Cow; + + #[test] + fn test_tuple_cast_evaluator() { + let evaluator = cast_create( + Cow::Owned(LogicalType::Tuple(vec![ + LogicalType::Integer, + LogicalType::Varchar(None, CharLengthUnits::Characters), + ])), + Cow::Owned(LogicalType::Tuple(vec![ + LogicalType::Bigint, + LogicalType::Integer, + ])), + ) + .unwrap(); + + assert_eq!( + evaluator + .eval_cast(&DataValue::Tuple( + vec![ + DataValue::Int32(1), + DataValue::Utf8 { + value: "2".to_string(), + ty: crate::types::value::Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ], + false, + )) + .unwrap(), + DataValue::Tuple(vec![DataValue::Int64(1), DataValue::Int32(2)], false) + ); + } +} diff --git a/src/types/evaluator/uint16.rs b/src/types/evaluator/uint16.rs index 70c1f13c..ff6789da 100644 --- a/src/types/evaluator/uint16.rs +++ b/src/types/evaluator/uint16.rs @@ -13,11 +13,100 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; numeric_binary_evaluator_definition!(UInt16, DataValue::UInt16); +crate::define_integer_cast_evaluators!(UInt16, UInt16, u16, LogicalType::USmallint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_uint16_cast_evaluators() { + let value = DataValue::UInt16(1); + + assert_eq!( + UInt16ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + UInt16ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + UInt16ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + UInt16ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + UInt16ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + UInt16ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + UInt16ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + UInt16ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + UInt16ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + UInt16ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + UInt16ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + UInt16ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt16ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt16ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/uint32.rs b/src/types/evaluator/uint32.rs index 43409673..058570ba 100644 --- a/src/types/evaluator/uint32.rs +++ b/src/types/evaluator/uint32.rs @@ -13,11 +13,100 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; numeric_binary_evaluator_definition!(UInt32, DataValue::UInt32); +crate::define_integer_cast_evaluators!(UInt32, UInt32, u32, LogicalType::UInteger); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_uint32_cast_evaluators() { + let value = DataValue::UInt32(1); + + assert_eq!( + UInt32ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + UInt32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + UInt32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + UInt32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + UInt32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + UInt32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + UInt32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + UInt32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + UInt32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + UInt32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + UInt32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + UInt32ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt32ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt32ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/uint64.rs b/src/types/evaluator/uint64.rs index 723b272f..2bfbe787 100644 --- a/src/types/evaluator/uint64.rs +++ b/src/types/evaluator/uint64.rs @@ -13,11 +13,100 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; numeric_binary_evaluator_definition!(UInt64, DataValue::UInt64); +crate::define_integer_cast_evaluators!(UInt64, UInt64, u64, LogicalType::UBigint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_uint64_cast_evaluators() { + let value = DataValue::UInt64(1); + + assert_eq!( + UInt64ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + UInt64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + UInt64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + UInt64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + UInt64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + UInt64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + UInt64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + UInt64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + UInt64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + UInt64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + UInt64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + UInt64ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt64ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt64ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/uint8.rs b/src/types/evaluator/uint8.rs index 6266bf71..135a5a18 100644 --- a/src/types/evaluator/uint8.rs +++ b/src/types/evaluator/uint8.rs @@ -13,11 +13,100 @@ // limitations under the License. use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; -use crate::types::DatabaseError; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::hint; +use crate::types::LogicalType; numeric_binary_evaluator_definition!(UInt8, DataValue::UInt8); +crate::define_integer_cast_evaluators!(UInt8, UInt8, u8, LogicalType::UTinyint); + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::CastEvaluator; + use crate::types::value::Utf8Type; + use ordered_float::OrderedFloat; + use rust_decimal::Decimal; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_uint8_cast_evaluators() { + let value = DataValue::UInt8(1); + + assert_eq!( + UInt8ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + UInt8ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + UInt8ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + UInt8ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + UInt8ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + UInt8ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + UInt8ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + UInt8ToBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + UInt8ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + UInt8ToFloatCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float32(OrderedFloat(1.0)) + ); + assert_eq!( + UInt8ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + DataValue::Float64(OrderedFloat(1.0)) + ); + assert_eq!( + UInt8ToCharCastEvaluator { + len: 1, + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Fixed(1), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt8ToVarcharCastEvaluator { + len: Some(1), + unit: CharLengthUnits::Characters, + } + .eval_cast(&value) + .unwrap(), + DataValue::Utf8 { + value: "1".to_string(), + ty: Utf8Type::Variable(Some(1)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + UInt8ToDecimalCastEvaluator { scale: Some(2) } + .eval_cast(&value) + .unwrap(), + DataValue::Decimal(Decimal::new(100, 2)) + ); + } +} diff --git a/src/types/evaluator/unary.rs b/src/types/evaluator/unary.rs new file mode 100644 index 00000000..d201cd64 --- /dev/null +++ b/src/types/evaluator/unary.rs @@ -0,0 +1,183 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::errors::DatabaseError; +use crate::expression::UnaryOperator; +use crate::types::evaluator::boolean::BooleanNotUnaryEvaluator; +use crate::types::evaluator::float32::*; +use crate::types::evaluator::float64::*; +use crate::types::evaluator::int16::*; +use crate::types::evaluator::int32::*; +use crate::types::evaluator::int64::*; +use crate::types::evaluator::int8::*; +use crate::types::evaluator::UnaryEvaluatorBox; +use crate::types::LogicalType; +use paste::paste; +use std::borrow::Cow; +use std::sync::Arc; + +macro_rules! numeric_unary_evaluator { + ($value_type:ident, $op:expr, $ty:expr) => { + paste! { + match $op { + UnaryOperator::Plus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type PlusUnaryEvaluator>]))), + UnaryOperator::Minus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type MinusUnaryEvaluator>]))), + _ => Err(DatabaseError::UnsupportedUnaryOperator($ty.clone(), $op)), + } + } + }; +} + +pub fn unary_create( + ty: Cow<'_, LogicalType>, + op: UnaryOperator, +) -> Result { + let ty = ty.as_ref(); + match ty { + LogicalType::Tinyint => numeric_unary_evaluator!(Int8, op, ty), + LogicalType::Smallint => numeric_unary_evaluator!(Int16, op, ty), + LogicalType::Integer => numeric_unary_evaluator!(Int32, op, ty), + LogicalType::Bigint => numeric_unary_evaluator!(Int64, op, ty), + LogicalType::Boolean => match op { + UnaryOperator::Not => Ok(UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator))), + _ => Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), + }, + LogicalType::Float => numeric_unary_evaluator!(Float32, op, ty), + LogicalType::Double => numeric_unary_evaluator!(Float64, op, ty), + _ => Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), + } +} + +#[macro_export] +macro_rules! numeric_unary_evaluator_definition { + ($value_type:ident, $compute_type:path) => { + paste::paste! { + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type PlusUnaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] + pub struct [<$value_type MinusUnaryEvaluator>]; + + #[typetag::serde] + impl $crate::types::evaluator::UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { + fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { + value.clone() + } + } + #[typetag::serde] + impl $crate::types::evaluator::UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { + fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { + match value { + $compute_type(value) => $compute_type(-value), + $crate::types::value::DataValue::Null => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + } + }; +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::unary_create; + use crate::errors::DatabaseError; + use crate::expression::UnaryOperator; + use crate::serdes::{ReferenceSerialization, ReferenceTables}; + use crate::storage::rocksdb::RocksTransaction; + use crate::types::evaluator::UnaryEvaluatorBox; + use crate::types::value::DataValue; + use crate::types::LogicalType; + use ordered_float::OrderedFloat; + use std::borrow::Cow; + use std::io::{Cursor, Seek, SeekFrom}; + + fn create(ty: LogicalType, op: UnaryOperator) -> Result { + unary_create(Cow::Owned(ty), op) + } + + #[test] + fn test_numeric_unary_evaluators() -> Result<(), DatabaseError> { + let cases = vec![ + ( + LogicalType::Integer, + UnaryOperator::Plus, + DataValue::Int32(7), + DataValue::Int32(7), + ), + ( + LogicalType::Integer, + UnaryOperator::Minus, + DataValue::Int32(7), + DataValue::Int32(-7), + ), + ( + LogicalType::Bigint, + UnaryOperator::Minus, + DataValue::Int64(7), + DataValue::Int64(-7), + ), + ( + LogicalType::Double, + UnaryOperator::Minus, + DataValue::Float64(OrderedFloat(1.5)), + DataValue::Float64(OrderedFloat(-1.5)), + ), + ]; + + for (ty, op, value, expected) in cases { + let evaluator = create(ty, op)?; + assert_eq!(evaluator.unary_eval(&value), expected); + assert_eq!(evaluator.unary_eval(&DataValue::Null), DataValue::Null); + } + + Ok(()) + } + + #[test] + fn test_boolean_unary_evaluator() -> Result<(), DatabaseError> { + let evaluator = create(LogicalType::Boolean, UnaryOperator::Not)?; + assert_eq!( + evaluator.unary_eval(&DataValue::Boolean(true)), + DataValue::Boolean(false) + ); + assert_eq!(evaluator.unary_eval(&DataValue::Null), DataValue::Null); + Ok(()) + } + + #[test] + fn test_unary_evaluator_rejects_unsupported_operator() { + let err = create(LogicalType::Boolean, UnaryOperator::Plus).unwrap_err(); + assert!(matches!( + err, + DatabaseError::UnsupportedUnaryOperator(LogicalType::Boolean, UnaryOperator::Plus) + )); + } + + #[test] + fn test_unary_evaluator_serialization() -> Result<(), DatabaseError> { + let evaluator = create(LogicalType::Boolean, UnaryOperator::Not)?; + let mut cursor = Cursor::new(Vec::new()); + let mut reference_tables = ReferenceTables::new(); + + evaluator.encode(&mut cursor, false, &mut reference_tables)?; + cursor.seek(SeekFrom::Start(0))?; + + assert_eq!( + UnaryEvaluatorBox::decode::(&mut cursor, None, &reference_tables)?, + evaluator + ); + + Ok(()) + } +} diff --git a/src/types/evaluator/utf8.rs b/src/types/evaluator/utf8.rs index 31af9db8..ad487d2f 100644 --- a/src/types/evaluator/utf8.rs +++ b/src/types/evaluator/utf8.rs @@ -13,13 +13,19 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::value::Utf8Type; +use crate::types::LogicalType; +use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; +use ordered_float::OrderedFloat; use regex::Regex; +use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; use sqlparser::ast::CharLengthUnits; use std::hint; +use std::str::FromStr; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct Utf8GtBinaryEvaluator; @@ -86,6 +92,156 @@ impl BinaryEvaluator for Utf8LtBinaryEvaluator { }) } } + +crate::define_cast_evaluator!( + Utf8ToBooleanCastEvaluator { + from: LogicalType + }, + DataValue::Utf8 { value, .. } => { + Ok(DataValue::Boolean(bool::from_str(value)?)) + } +); +crate::define_cast_evaluator!(Utf8ToTinyintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Int8(i8::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToUTinyintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::UInt8(u8::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToSmallintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Int16(i16::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToUSmallintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::UInt16(u16::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToIntegerCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Int32(i32::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToUIntegerCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::UInt32(u32::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToBigintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Int64(i64::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToUBigintCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::UInt64(u64::from_str(value)?)) +}); +crate::define_cast_evaluator!(Utf8ToFloatCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Float32(OrderedFloat(f32::from_str(value)?))) +}); +crate::define_cast_evaluator!(Utf8ToDoubleCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Float64(OrderedFloat(f64::from_str(value)?))) +}); +crate::define_cast_evaluator!( + Utf8ToCharCastEvaluator { + len: u32, + unit: CharLengthUnits + }, + DataValue::Utf8 { value, .. } => |this| to_char(value.clone(), this.len, this.unit) +); +crate::define_cast_evaluator!( + Utf8ToVarcharCastEvaluator { + len: Option, + unit: CharLengthUnits + }, + DataValue::Utf8 { value, .. } => |this| to_varchar(value.clone(), this.len, this.unit) +); +crate::define_cast_evaluator!(Utf8ToDateCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Date32( + NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT)?.num_days_from_ce(), + )) +}); +crate::define_cast_evaluator!(Utf8ToDatetimeCastEvaluator, DataValue::Utf8 { value, .. } => { + let value = NaiveDateTime::parse_from_str(value, crate::types::value::DATE_TIME_FMT) + .or_else(|_| { + NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT) + .map(|date| date.and_hms_opt(0, 0, 0).unwrap()) + })? + .and_utc() + .timestamp(); + + Ok(DataValue::Date64(value)) +}); +crate::define_cast_evaluator!( + Utf8ToTimeCastEvaluator { + precision: Option + }, + DataValue::Utf8 { value, .. } => |this| { + let precision = this.precision.unwrap_or(0); + let fmt = if precision == 0 { + crate::types::value::TIME_FMT + } else { + crate::types::value::TIME_FMT_WITHOUT_ZONE + }; + let (value, nano) = match precision { + 0 => ( + NaiveTime::parse_from_str(value, fmt) + .map(|time| time.num_seconds_from_midnight())?, + 0, + ), + _ => NaiveTime::parse_from_str(value, fmt) + .map(|time| (time.num_seconds_from_midnight(), time.nanosecond()))?, + }; + + Ok(DataValue::Time32(DataValue::pack(value, nano, precision), precision)) + } +); +crate::define_cast_evaluator!( + Utf8ToTimestampCastEvaluator { + precision: Option, + zone: bool, + to: LogicalType + }, + DataValue::Utf8 { value, .. } => |this| { + let precision = this.precision.unwrap_or(0); + let fmt = match (precision, this.zone) { + (0, false) => crate::types::value::DATE_TIME_FMT, + (0, true) => crate::types::value::TIME_STAMP_FMT_WITHOUT_PRECISION, + (3 | 6 | 9, false) => crate::types::value::TIME_STAMP_FMT_WITHOUT_ZONE, + _ => crate::types::value::TIME_STAMP_FMT_WITH_ZONE, + }; + let complete_value = if this.zone { + if value.contains("+") { + value.clone() + } else { + format!("{value}+00:00") + } + } else { + value.clone() + }; + + if !this.zone { + let value = NaiveDateTime::parse_from_str(&complete_value, fmt)?.and_utc(); + let value = match precision { + 3 => value.timestamp_millis(), + 6 => value.timestamp_micros(), + 9 => value + .timestamp_nanos_opt() + .ok_or_else(|| cast_fail(this.to.clone(), this.to.clone()))?, + 0 => value.timestamp(), + _ => unreachable!(), + }; + + return Ok(DataValue::Time64(value, precision, false)); + } + + let value = DateTime::parse_from_str(&complete_value, fmt); + let value = match precision { + 3 => value.map(|date_time| date_time.timestamp_millis())?, + 6 => value.map(|date_time| date_time.timestamp_micros())?, + 9 => value + .map(|date_time| date_time.timestamp_nanos_opt())? + .ok_or_else(|| cast_fail(this.to.clone(), this.to.clone()))?, + 0 => value.map(|date_time| date_time.timestamp())?, + _ => unreachable!(), + }; + + Ok(DataValue::Time64(value, precision, this.zone)) + } +); +crate::define_cast_evaluator!(Utf8ToDecimalCastEvaluator, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Decimal(Decimal::from_str(value)?)) +}); + #[typetag::serde] impl BinaryEvaluator for Utf8LtEqBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { @@ -193,3 +349,206 @@ fn string_like(value: &str, pattern: &str, escape_char: Option) -> bool { } Regex::new(®ex_pattern).unwrap().is_match(value) } + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod test { + use super::*; + use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + + fn utf8(value: &str) -> DataValue { + DataValue::Utf8 { + value: value.to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + } + } + + #[test] + fn test_utf8_binary_evaluators() { + assert_eq!( + Utf8LtBinaryEvaluator + .binary_eval(&utf8("a"), &utf8("b")) + .unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Utf8StringConcatBinaryEvaluator + .binary_eval(&utf8("ab"), &utf8("cd")) + .unwrap(), + utf8("abcd") + ); + assert_eq!( + Utf8LikeBinaryEvaluator { escape_char: None } + .binary_eval(&utf8("kite"), &utf8("ki%")) + .unwrap(), + DataValue::Boolean(true) + ); + } + + #[test] + fn test_utf8_cast_evaluators() { + assert_eq!( + Utf8ToBooleanCastEvaluator { + from: LogicalType::Varchar(None, CharLengthUnits::Characters), + } + .eval_cast(&utf8("true")) + .unwrap(), + DataValue::Boolean(true) + ); + assert_eq!( + Utf8ToTinyintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::Int8(1) + ); + assert_eq!( + Utf8ToUTinyintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::UInt8(1) + ); + assert_eq!( + Utf8ToSmallintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::Int16(1) + ); + assert_eq!( + Utf8ToUSmallintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::UInt16(1) + ); + assert_eq!( + Utf8ToIntegerCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::Int32(1) + ); + assert_eq!( + Utf8ToUIntegerCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::UInt32(1) + ); + assert_eq!( + Utf8ToBigintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::Int64(1) + ); + assert_eq!( + Utf8ToUBigintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + DataValue::UInt64(1) + ); + assert_eq!( + Utf8ToFloatCastEvaluator.eval_cast(&utf8("1.5")).unwrap(), + DataValue::Float32(OrderedFloat(1.5)) + ); + assert_eq!( + Utf8ToDoubleCastEvaluator.eval_cast(&utf8("1.5")).unwrap(), + DataValue::Float64(OrderedFloat(1.5)) + ); + assert_eq!( + Utf8ToCharCastEvaluator { + len: 2, + unit: CharLengthUnits::Characters, + } + .eval_cast(&utf8("ab")) + .unwrap(), + DataValue::Utf8 { + value: "ab".to_string(), + ty: Utf8Type::Fixed(2), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Utf8ToVarcharCastEvaluator { + len: Some(2), + unit: CharLengthUnits::Characters, + } + .eval_cast(&utf8("ab")) + .unwrap(), + DataValue::Utf8 { + value: "ab".to_string(), + ty: Utf8Type::Variable(Some(2)), + unit: CharLengthUnits::Characters, + } + ); + assert_eq!( + Utf8ToDateCastEvaluator + .eval_cast(&utf8("2024-01-02")) + .unwrap(), + DataValue::Date32( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .num_days_from_ce() + ) + ); + assert_eq!( + Utf8ToDatetimeCastEvaluator + .eval_cast(&utf8("2024-01-02 03:04:05")) + .unwrap(), + DataValue::Date64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .unwrap() + .and_utc() + .timestamp() + ) + ); + assert_eq!( + Utf8ToTimeCastEvaluator { precision: Some(0) } + .eval_cast(&utf8("03:04:05")) + .unwrap(), + DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 0, 0), 0) + ); + assert_eq!( + Utf8ToTimeCastEvaluator { precision: Some(3) } + .eval_cast(&utf8("03:04:05.123")) + .unwrap(), + { + let time = chrono::NaiveTime::parse_from_str( + "03:04:05.123", + crate::types::value::TIME_FMT_WITHOUT_ZONE, + ) + .unwrap(); + DataValue::Time32( + DataValue::pack(time.num_seconds_from_midnight(), time.nanosecond(), 3), + 3, + ) + } + ); + assert_eq!( + Utf8ToTimestampCastEvaluator { + precision: Some(3), + zone: false, + to: LogicalType::TimeStamp(Some(3), false), + } + .eval_cast(&utf8("2024-01-02 03:04:05.123")) + .unwrap(), + DataValue::Time64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_milli_opt(3, 4, 5, 123) + .unwrap() + .and_utc() + .timestamp_millis(), + 3, + false, + ) + ); + assert_eq!( + Utf8ToTimestampCastEvaluator { + precision: Some(0), + zone: true, + to: LogicalType::TimeStamp(Some(0), true), + } + .eval_cast(&utf8("2024-01-02 03:04:05+00:00")) + .unwrap(), + DataValue::Time64( + chrono::NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .unwrap() + .and_utc() + .timestamp(), + 0, + true, + ) + ); + assert_eq!( + Utf8ToDecimalCastEvaluator + .eval_cast(&utf8("12.34")) + .unwrap(), + DataValue::Decimal(Decimal::from_str("12.34").unwrap()) + ); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index fc70d0f1..6f1681f9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -23,6 +23,7 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; use std::any::TypeId; +use std::borrow::Cow; use std::cmp; use crate::errors::DatabaseError; @@ -196,22 +197,22 @@ impl LogicalType { matches!(self, LogicalType::Float | LogicalType::Double) } - pub fn max_logical_type( - left: &LogicalType, - right: &LogicalType, - ) -> Result { + pub fn max_logical_type<'a>( + left: &'a LogicalType, + right: &'a LogicalType, + ) -> Result, DatabaseError> { if left == right { - return Ok(left.clone()); + return Ok(Cow::Borrowed(left)); } match (left, right) { // SqlNull type can be cast to anything - (LogicalType::SqlNull, _) => return Ok(right.clone()), - (_, LogicalType::SqlNull) => return Ok(left.clone()), + (LogicalType::SqlNull, _) => return Ok(Cow::Borrowed(right)), + (_, LogicalType::SqlNull) => return Ok(Cow::Borrowed(left)), (LogicalType::Tuple(types_0), LogicalType::Tuple(types_1)) => { if types_0.len() > types_1.len() { - return Ok(left.clone()); + return Ok(Cow::Borrowed(left)); } else { - return Ok(right.clone()); + return Ok(Cow::Borrowed(right)); } } _ => {} @@ -224,37 +225,40 @@ impl LogicalType { (LogicalType::Date, LogicalType::Varchar(..)) | (LogicalType::Varchar(..), LogicalType::Date) ) { - return Ok(LogicalType::Date); + return Ok(Cow::Owned(LogicalType::Date)); } if matches!( (left, right), (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date) ) { - return Ok(LogicalType::DateTime); + return Ok(Cow::Owned(LogicalType::DateTime)); } if matches!( (left, right), (LogicalType::DateTime, LogicalType::Varchar(..)) | (LogicalType::Varchar(..), LogicalType::DateTime) ) { - return Ok(LogicalType::DateTime); + return Ok(Cow::Owned(LogicalType::DateTime)); } if let (LogicalType::Char(..), LogicalType::Varchar(..)) | (LogicalType::Varchar(..), LogicalType::Char(..)) | (LogicalType::Char(..), LogicalType::Char(..)) | (LogicalType::Varchar(..), LogicalType::Varchar(..)) = (left, right) { - return Ok(LogicalType::Varchar(None, CharLengthUnits::Characters)); + return Ok(Cow::Owned(LogicalType::Varchar( + None, + CharLengthUnits::Characters, + ))); } Err(DatabaseError::Incomparable(left.clone(), right.clone())) } - fn combine_numeric_types( - left: &LogicalType, - right: &LogicalType, - ) -> Result { + fn combine_numeric_types<'a>( + left: &'a LogicalType, + right: &'a LogicalType, + ) -> Result, DatabaseError> { if left == right { - return Ok(left.clone()); + return Ok(Cow::Borrowed(left)); } if left.is_signed_numeric() && right.is_unsigned_numeric() { // this method is symmetric @@ -264,20 +268,28 @@ impl LogicalType { } if LogicalType::can_implicit_cast(left, right) { - return Ok(right.clone()); + return Ok(Cow::Borrowed(right)); } if LogicalType::can_implicit_cast(right, left) { - return Ok(left.clone()); + return Ok(Cow::Borrowed(left)); } // we can't cast implicitly either way and types are not equal // this happens when left is signed and right is unsigned // e.g. INTEGER and UINTEGER // in this case we need to upcast to make sure the types fit match (left, right) { - (LogicalType::Bigint, _) | (_, LogicalType::UBigint) => Ok(LogicalType::Double), - (LogicalType::Integer, _) | (_, LogicalType::UInteger) => Ok(LogicalType::Bigint), - (LogicalType::Smallint, _) | (_, LogicalType::USmallint) => Ok(LogicalType::Integer), - (LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => Ok(LogicalType::Smallint), + (LogicalType::Bigint, _) | (_, LogicalType::UBigint) => { + Ok(Cow::Owned(LogicalType::Double)) + } + (LogicalType::Integer, _) | (_, LogicalType::UInteger) => { + Ok(Cow::Owned(LogicalType::Bigint)) + } + (LogicalType::Smallint, _) | (_, LogicalType::USmallint) => { + Ok(Cow::Owned(LogicalType::Integer)) + } + (LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => { + Ok(Cow::Owned(LogicalType::Smallint)) + } ( LogicalType::Decimal(precision_0, scale_0), LogicalType::Decimal(precision_1, scale_1), @@ -287,10 +299,10 @@ impl LogicalType { (Some(num), None) | (None, Some(num)) => Some(*num), (None, None) => None, }; - Ok(LogicalType::Decimal( + Ok(Cow::Owned(LogicalType::Decimal( fn_option(precision_0, precision_1), fn_option(scale_0, scale_1), - )) + ))) } _ => Err(DatabaseError::Incomparable(left.clone(), right.clone())), } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 284edc07..bee845aa 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -31,7 +31,6 @@ pub type Schema = Vec; pub type SchemaRef = Arc; pub trait TupleLike { - fn len(&self) -> usize; fn value_at(&self, index: usize) -> &DataValue; #[inline] @@ -68,11 +67,6 @@ pub struct Tuple { } impl TupleLike for Tuple { - #[inline] - fn len(&self) -> usize { - self.values.len() - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { &self.values[index] @@ -85,11 +79,6 @@ impl TupleLike for Tuple { } impl TupleLike for [DataValue] { - #[inline] - fn len(&self) -> usize { - <[DataValue]>::len(self) - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { &self[index] @@ -102,11 +91,6 @@ impl TupleLike for [DataValue] { } impl TupleLike for &Tuple { - #[inline] - fn len(&self) -> usize { - self.values.len() - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { &self.values[index] @@ -119,11 +103,6 @@ impl TupleLike for &Tuple { } impl TupleLike for &[DataValue] { - #[inline] - fn len(&self) -> usize { - <[DataValue]>::len(self) - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { &self[index] @@ -136,11 +115,6 @@ impl TupleLike for &[DataValue] { } impl TupleLike for &dyn TupleLike { - #[inline] - fn len(&self) -> usize { - (*self).len() - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { (*self).value_at(index) @@ -153,11 +127,6 @@ impl TupleLike for &dyn TupleLike { } impl TupleLike for SplitTupleRef<'_> { - #[inline] - fn len(&self) -> usize { - self.left_len + self.right.len() - } - #[inline] fn value_at(&self, index: usize) -> &DataValue { if index < self.left_len { diff --git a/src/types/value.rs b/src/types/value.rs index 7accc5de..dd945f59 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -15,19 +15,19 @@ use super::LogicalType; use crate::errors::DatabaseError; use crate::storage::table_codec::{BumpBytes, BOUND_MAX_TAG, NOTNULL_TAG, NULL_TAG}; +use crate::types::evaluator::cast_create; use byteorder::ReadBytesExt; use chrono::format::{DelayedFormat, StrftimeItems}; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use itertools::Itertools; use ordered_float::OrderedFloat; -use rust_decimal::prelude::{FromPrimitive, ToPrimitive}; use rust_decimal::Decimal; use sqlparser::ast::CharLengthUnits; +use std::borrow::Cow; use std::cmp::Ordering; use std::fmt::Formatter; use std::hash::Hash; use std::io::{Read, Write}; -use std::str::FromStr; use std::sync::LazyLock; use std::{cmp, fmt, mem}; @@ -374,74 +374,6 @@ impl Hash for DataValue { } } } -macro_rules! varchar_cast { - ($value:expr, $len:expr, $ty:expr, $unit:expr) => {{ - let s_value = $value.to_string(); - if let Some(len) = $len { - if Self::check_string_len(&s_value, *len as usize, $unit) { - return Err(DatabaseError::TooLong); - } - } - Ok(DataValue::Utf8 { - value: s_value, - ty: $ty, - unit: $unit, - }) - }}; -} - -macro_rules! numeric_to_boolean { - ($value:expr, $from_ty:expr) => { - match $value { - 0 => Ok(DataValue::Boolean(false)), - 1 => Ok(DataValue::Boolean(true)), - _ => Err(DatabaseError::CastFail { - from: $from_ty, - to: LogicalType::Boolean, - span: None, - }), - } - }; -} - -macro_rules! float_to_int { - ($float_value:expr, $int_type:ty, $float_type:ty) => {{ - let float_value: $float_type = $float_value; - if float_value.is_nan() { - Ok(0) - } else if float_value <= 0.0 || float_value > <$int_type>::MAX as $float_type { - Err(DatabaseError::OverFlow) - } else { - Ok(float_value as $int_type) - } - }}; -} - -macro_rules! decimal_to_int { - ($decimal:expr, $int_type:ty) => {{ - let d = $decimal; - if d.is_sign_negative() { - if <$int_type>::MIN == 0 { - 0 - } else { - let min = Decimal::from(<$int_type>::MIN); - if d <= min { - <$int_type>::MIN - } else { - d.to_i128().unwrap() as $int_type - } - } - } else { - let max = Decimal::from(<$int_type>::MAX); - if d >= max { - <$int_type>::MAX - } else { - d.to_i128().unwrap() as $int_type - } - } - }}; -} - impl DataValue { pub fn float(&self) -> Option { if let DataValue::Float32(val) = self { @@ -595,19 +527,19 @@ impl DataValue { (b, scaled_a * (1000000000 / 10_u32.pow(precision as u32))) } - fn format_date(value: i32) -> Option { + pub(crate) fn format_date(value: i32) -> Option { Self::date_format(value).map(|fmt| format!("{fmt}")) } - fn format_datetime(value: i64) -> Option { + pub(crate) fn format_datetime(value: i64) -> Option { Self::date_time_format(value).map(|fmt| format!("{fmt}")) } - fn format_time(value: u32, precision: u64) -> Option { + pub(crate) fn format_time(value: u32, precision: u64) -> Option { Self::time_format(value, precision).map(|fmt| format!("{fmt}")) } - fn format_timestamp(value: i64, precision: u64) -> Option { + pub(crate) fn format_timestamp(value: i64, precision: u64) -> Option { Self::time_stamp_format(value, precision, false).map(|fmt| format!("{fmt}")) } @@ -1154,811 +1086,13 @@ impl DataValue { } pub fn cast(self, to: &LogicalType) -> Result { - let value = match self { - DataValue::Null => Ok(DataValue::Null), - DataValue::Boolean(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Boolean => Ok(DataValue::Boolean(value)), - LogicalType::Tinyint => Ok(DataValue::Int8(value.into())), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.into())), - LogicalType::Smallint => Ok(DataValue::Int16(value.into())), - LogicalType::USmallint => Ok(DataValue::UInt16(value.into())), - LogicalType::Integer => Ok(DataValue::Int32(value.into())), - LogicalType::UInteger => Ok(DataValue::UInt32(value.into())), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::UBigint => Ok(DataValue::UInt64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(value.into())), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Float32(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Float => Ok(DataValue::Float32(value)), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value.0.into()))), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = - Decimal::from_f32(value.0).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?; - Self::decimal_round_f(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Tinyint => { - Ok(DataValue::Int8(float_to_int!(value.into_inner(), i8, f32)?)) - } - LogicalType::Smallint => Ok(DataValue::Int16(float_to_int!( - value.into_inner(), - i16, - f32 - )?)), - LogicalType::Integer => Ok(DataValue::Int32(float_to_int!( - value.into_inner(), - i32, - f32 - )?)), - LogicalType::Bigint => Ok(DataValue::Int64(float_to_int!( - value.into_inner(), - i64, - f32 - )?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(float_to_int!( - value.into_inner(), - u8, - f32 - )?)), - LogicalType::USmallint => Ok(DataValue::UInt16(float_to_int!( - value.into_inner(), - u16, - f32 - )?)), - LogicalType::UInteger => Ok(DataValue::UInt32(float_to_int!( - value.into_inner(), - u32, - f32 - )?)), - LogicalType::UBigint => Ok(DataValue::UInt64(float_to_int!( - value.into_inner(), - u64, - f32 - )?)), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Float64(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value.0 as f32))), - LogicalType::Double => Ok(DataValue::Float64(value)), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = - Decimal::from_f64(value.0).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?; - Self::decimal_round_f(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Tinyint => { - Ok(DataValue::Int8(float_to_int!(value.into_inner(), i8, f64)?)) - } - LogicalType::Smallint => Ok(DataValue::Int16(float_to_int!( - value.into_inner(), - i16, - f64 - )?)), - LogicalType::Integer => Ok(DataValue::Int32(float_to_int!( - value.into_inner(), - i32, - f64 - )?)), - LogicalType::Bigint => Ok(DataValue::Int64(float_to_int!( - value.into_inner(), - i64, - f64 - )?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(float_to_int!( - value.into_inner(), - u8, - f64 - )?)), - LogicalType::USmallint => Ok(DataValue::UInt16(float_to_int!( - value.into_inner(), - u16, - f64 - )?)), - LogicalType::UInteger => Ok(DataValue::UInt32(float_to_int!( - value.into_inner(), - u32, - f64 - )?)), - LogicalType::UBigint => Ok(DataValue::UInt64(float_to_int!( - value.into_inner(), - u64, - f64 - )?)), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Int8(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(value)), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::try_from(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(u64::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(value.into())), - LogicalType::Integer => Ok(DataValue::Int32(value.into())), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(value.into())), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Int16(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::try_from(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(u64::try_from(value)?)), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(value)), - LogicalType::Integer => Ok(DataValue::Int32(value.into())), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(value.into())), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Int32(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::try_from(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(u64::try_from(value)?)), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::try_from(value)?)), - LogicalType::Integer => Ok(DataValue::Int32(value)), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value as f32))), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Int64(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::try_from(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(u64::try_from(value)?)), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::try_from(value)?)), - LogicalType::Integer => Ok(DataValue::Int32(i32::try_from(value)?)), - LogicalType::Bigint => Ok(DataValue::Int64(value)), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value as f32))), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value as f64))), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::UInt8(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(value)), - LogicalType::Smallint => Ok(DataValue::Int16(value.into())), - LogicalType::USmallint => Ok(DataValue::UInt16(value.into())), - LogicalType::Integer => Ok(DataValue::Int32(value.into())), - LogicalType::UInteger => Ok(DataValue::UInt32(value.into())), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::UBigint => Ok(DataValue::UInt64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(value.into())), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::UInt16(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value)), - LogicalType::Integer => Ok(DataValue::Int32(value.into())), - LogicalType::UInteger => Ok(DataValue::UInt32(value.into())), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::UBigint => Ok(DataValue::UInt64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(value.into())), - LogicalType::Double => Ok(DataValue::Float64(value.into())), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::UInt32(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::Integer => Ok(DataValue::Int32(i32::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value)), - LogicalType::Bigint => Ok(DataValue::Int64(value.into())), - LogicalType::UBigint => Ok(DataValue::UInt64(value.into())), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value as f32))), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value.into()))), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); - - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::UInt64(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::try_from(value)?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::try_from(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::try_from(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::try_from(value)?)), - LogicalType::Integer => Ok(DataValue::Int32(i32::try_from(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::try_from(value)?)), - LogicalType::Bigint => Ok(DataValue::Int64(i64::try_from(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value)), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value as f32))), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value as f64))), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Decimal(_, option) => { - let mut decimal = Decimal::from(value); - Self::decimal_round_i(option, &mut decimal); + let from = self.logical_type(); + if &from == to { + return Ok(self); + } + let evaluator = cast_create(Cow::Owned(from), Cow::Borrowed(to))?; - Ok(DataValue::Decimal(decimal)) - } - LogicalType::Boolean => numeric_to_boolean!(value, self.logical_type()), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Utf8 { ref value, .. } => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Boolean => Ok(DataValue::Boolean(bool::from_str(value)?)), - LogicalType::Tinyint => Ok(DataValue::Int8(i8::from_str(value)?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(u8::from_str(value)?)), - LogicalType::Smallint => Ok(DataValue::Int16(i16::from_str(value)?)), - LogicalType::USmallint => Ok(DataValue::UInt16(u16::from_str(value)?)), - LogicalType::Integer => Ok(DataValue::Int32(i32::from_str(value)?)), - LogicalType::UInteger => Ok(DataValue::UInt32(u32::from_str(value)?)), - LogicalType::Bigint => Ok(DataValue::Int64(i64::from_str(value)?)), - LogicalType::UBigint => Ok(DataValue::UInt64(u64::from_str(value)?)), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(f32::from_str(value)?))), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(f64::from_str(value)?))), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Date => { - let value = NaiveDate::parse_from_str(value, DATE_FMT)?.num_days_from_ce(); - Ok(DataValue::Date32(value)) - } - LogicalType::DateTime => { - let value = NaiveDateTime::parse_from_str(value, DATE_TIME_FMT) - .or_else(|_| { - NaiveDate::parse_from_str(value, DATE_FMT) - .map(|date| date.and_hms_opt(0, 0, 0).unwrap()) - }) - .map(|date_time| date_time.and_utc().timestamp())?; - - Ok(DataValue::Date64(value)) - } - LogicalType::Time(precision) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - let fmt = if precision == 0 { - TIME_FMT - } else { - TIME_FMT_WITHOUT_ZONE - }; - let (value, nano) = match precision { - 0 => ( - NaiveTime::parse_from_str(value, fmt) - .map(|time| time.num_seconds_from_midnight())?, - 0, - ), - _ => NaiveTime::parse_from_str(value, fmt) - .map(|time| (time.num_seconds_from_midnight(), time.nanosecond()))?, - }; - Ok(DataValue::Time32( - Self::pack(value, nano, precision), - precision, - )) - } - LogicalType::TimeStamp(precision, zone) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - let fmt = match (precision, *zone) { - (0, false) => DATE_TIME_FMT, - (0, true) => TIME_STAMP_FMT_WITHOUT_PRECISION, - (3 | 6 | 9, false) => TIME_STAMP_FMT_WITHOUT_ZONE, - _ => TIME_STAMP_FMT_WITH_ZONE, - }; - let complete_value = if *zone { - match value.contains("+") { - false => format!("{}+00:00", value.clone()), - true => value.clone(), - } - } else { - value.clone() - }; - if precision == 0 && !*zone { - return Ok(DataValue::Time64( - NaiveDateTime::parse_from_str(&complete_value, fmt) - .map(|date_time| date_time.and_utc().timestamp())?, - precision, - *zone, - )); - } - let value = DateTime::parse_from_str(&complete_value, fmt); - let value = match precision { - 3 => value.map(|date_time| date_time.timestamp_millis())?, - 6 => value.map(|date_time| date_time.timestamp_micros())?, - 9 => { - if let Some(value) = - value.map(|date_time| date_time.timestamp_nanos_opt())? - { - value - } else { - return Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }); - } - } - 0 => value.map(|date_time| date_time.timestamp())?, - _ => unreachable!(), - }; - Ok(DataValue::Time64(value, precision, *zone)) - } - LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(Decimal::from_str(value)?)), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Date32(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Char(len, unit) => { - varchar_cast!( - Self::format_date(value).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - Some(len), - Utf8Type::Fixed(*len), - *unit - ) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!( - Self::format_date(value).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - len, - Utf8Type::Variable(*len), - *unit - ) - } - LogicalType::Date => Ok(DataValue::Date32(value)), - LogicalType::DateTime => { - let value = NaiveDate::from_num_days_from_ce_opt(value) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })? - .and_hms_opt(0, 0, 0) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })? - .and_utc() - .timestamp(); - - Ok(DataValue::Date64(value)) - } - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Date64(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Char(len, unit) => { - varchar_cast!( - Self::format_datetime(value).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - Some(len), - Utf8Type::Fixed(*len), - *unit - ) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!( - Self::format_datetime(value).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - len, - Utf8Type::Variable(*len), - *unit - ) - } - LogicalType::Date => { - let value = DateTime::from_timestamp(value, 0) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })? - .naive_utc() - .date() - .num_days_from_ce(); - - Ok(DataValue::Date32(value)) - } - LogicalType::DateTime => Ok(DataValue::Date64(value)), - LogicalType::Time(precision) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - let value = DateTime::from_timestamp(value, 0) - .map(|date_time| date_time.time().num_seconds_from_midnight()) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?; - - Ok(DataValue::Time32(Self::pack(value, 0, 0), precision)) - } - LogicalType::TimeStamp(precision, zone) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - Ok(DataValue::Time64(value, precision, *zone)) - } - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Time32(value, precision) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Char(len, unit) => { - varchar_cast!( - Self::format_time(value, precision).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - Some(len), - Utf8Type::Fixed(*len), - *unit - ) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!( - Self::format_time(value, precision).ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?, - len, - Utf8Type::Variable(*len), - *unit - ) - } - LogicalType::Time(to_precision) => { - Ok(DataValue::Time32(value, to_precision.unwrap_or(0))) - } - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Time64(value, precision, _) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Char(len, unit) => { - varchar_cast!( - Self::format_timestamp(value, precision).ok_or( - DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - } - )?, - Some(len), - Utf8Type::Fixed(*len), - *unit - ) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!( - Self::format_timestamp(value, precision).ok_or( - DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - } - )?, - len, - Utf8Type::Variable(*len), - *unit - ) - } - LogicalType::Date => { - let value = Self::from_timestamp_precision(value, precision) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })? - .naive_utc() - .date() - .num_days_from_ce(); - - Ok(DataValue::Date32(value)) - } - LogicalType::DateTime => { - let value = Self::from_timestamp_precision(value, precision) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })? - .timestamp(); - Ok(DataValue::Date64(value)) - } - LogicalType::Time(p) => { - let p = p.unwrap_or(0); - let (value, nano) = Self::from_timestamp_precision(value, precision) - .map(|date_time| { - ( - date_time.time().num_seconds_from_midnight(), - date_time.time().nanosecond(), - ) - }) - .ok_or(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - })?; - Ok(DataValue::Time32(Self::pack(value, nano, p), p)) - } - LogicalType::TimeStamp(to_precision, zone) => { - Ok(DataValue::Time64(value, to_precision.unwrap_or(0), *zone)) - } - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Decimal(value) => match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Float => Ok(DataValue::Float32(OrderedFloat(value.to_f32().ok_or( - DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }, - )?))), - LogicalType::Double => Ok(DataValue::Float64(OrderedFloat(value.to_f64().ok_or( - DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }, - )?))), - LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), - LogicalType::Char(len, unit) => { - varchar_cast!(value, Some(len), Utf8Type::Fixed(*len), *unit) - } - LogicalType::Varchar(len, unit) => { - varchar_cast!(value, len, Utf8Type::Variable(*len), *unit) - } - LogicalType::Tinyint => Ok(DataValue::Int8(decimal_to_int!(value, i8))), - LogicalType::Smallint => Ok(DataValue::Int16(decimal_to_int!(value, i16))), - LogicalType::Integer => Ok(DataValue::Int32(decimal_to_int!(value, i32))), - LogicalType::Bigint => Ok(DataValue::Int64(decimal_to_int!(value, i64))), - LogicalType::UTinyint => Ok(DataValue::UInt8(decimal_to_int!(value, u8))), - LogicalType::USmallint => Ok(DataValue::UInt16(decimal_to_int!(value, u16))), - LogicalType::UInteger => Ok(DataValue::UInt32(decimal_to_int!(value, u32))), - LogicalType::UBigint => Ok(DataValue::UInt64(decimal_to_int!(value, u64))), - _ => Err(DatabaseError::CastFail { - from: self.logical_type(), - to: to.clone(), - span: None, - }), - }, - DataValue::Tuple(mut values, is_upper) => match to { - LogicalType::Tuple(types) => { - for (i, value) in values.iter_mut().enumerate() { - if types[i] != value.logical_type() { - *value = mem::replace(value, DataValue::Null).cast(&types[i])?; - } - } - Ok(DataValue::Tuple(values, is_upper)) - } - _ => Err(DatabaseError::CastFail { - from: LogicalType::Tuple(values.iter().map(DataValue::logical_type).collect()), - to: to.clone(), - span: None, - }), - }, - }?; - value.check_len(to)?; - Ok(value) + evaluator.eval_cast(&self) } #[inline] @@ -1998,14 +1132,14 @@ impl DataValue { } } - fn decimal_round_i(option: &Option, decimal: &mut Decimal) { + pub(crate) fn decimal_round_i(option: &Option, decimal: &mut Decimal) { if let Some(scale) = option { let new_decimal = decimal.trunc_with_scale(*scale as u32); let _ = mem::replace(decimal, new_decimal); } } - fn decimal_round_f(option: &Option, decimal: &mut Decimal) { + pub(crate) fn decimal_round_f(option: &Option, decimal: &mut Decimal) { if let Some(scale) = option { let new_decimal = decimal.round_dp_with_strategy( *scale as u32, diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index fd96ce3c..ada09c87 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -26,7 +26,7 @@ mod test { use kite_sql::expression::ScalarExpression; use kite_sql::orm::{case_when, count_all, func, max, min, sum, QueryValue}; use kite_sql::storage::rocksdb::RocksStorage; - use kite_sql::types::evaluator::EvaluatorFactory; + use kite_sql::types::evaluator::binary_create; use kite_sql::types::tuple::{SchemaRef, Tuple}; use kite_sql::types::value::{DataValue, Utf8Type}; use kite_sql::types::LogicalType; @@ -1917,7 +1917,7 @@ mod test { } scala_function!(MyScalaFunction::SUM(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: DataValue, v2: DataValue| { - EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?.binary_eval(&v1, &v2) + binary_create(std::borrow::Cow::Owned(LogicalType::Integer), BinaryOperator::Plus)?.binary_eval(&v1, &v2) })); scala_function!(MyOrmFunction::ADD_ONE(LogicalType::Integer) -> LogicalType::Integer => (|v1: DataValue| {