Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/binder/alter_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use sqlparser::ast::{AlterColumnOperation, AlterTableOperation, ObjectName};

use std::borrow::Cow;
use std::sync::Arc;

use super::{attach_span_if_absent, is_valid_identifier, Binder};
Expand Down Expand Up @@ -44,12 +45,7 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
"column is not allowed to exist in default".to_string(),
));
}
if expr.return_type() != *ty {
expr = ScalarExpression::TypeCast {
expr: Box::new(expr),
ty: ty.clone(),
};
}
expr = ScalarExpression::type_cast(expr, Cow::Borrowed(ty))?;

Ok(expr)
}
Expand Down
11 changes: 5 additions & 6 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::types::value::DataValue;
use crate::types::LogicalType;
use itertools::Itertools;
use sqlparser::ast::{ColumnDef, ColumnOption, Expr, IndexColumn, ObjectName, TableConstraint};
use std::borrow::Cow;
use std::collections::HashSet;
use std::sync::Arc;

Expand Down Expand Up @@ -161,12 +162,10 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
"column is not allowed to exist in `default`".to_string(),
));
}
if expr.return_type() != column_desc.column_datatype {
expr = ScalarExpression::TypeCast {
expr: Box::new(expr),
ty: column_desc.column_datatype.clone(),
}
}
expr = ScalarExpression::type_cast(
expr,
Cow::Borrowed(&column_desc.column_datatype),
)?;
column_desc.default = Some(expr);
}
option => {
Expand Down
53 changes: 29 additions & 24 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use sqlparser::ast::{
BinaryOperator, CharLengthUnits, DataType, DuplicateTreatment, Expr, Function, FunctionArg,
FunctionArgExpr, FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value,
};
use std::borrow::Cow;
use std::collections::HashMap;
use std::slice;

Expand Down Expand Up @@ -284,7 +285,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
} => {
let left_expr = self.bind_expr(expr)?;
let (sub_query, column, correlated) =
self.bind_subquery(Some(left_expr.return_type()), subquery)?;
self.bind_subquery(Some(left_expr.return_type().as_ref()), subquery)?;

if !self.context.is_step(&QueryBindStep::Where) {
return Err(DatabaseError::UnsupportedStmt(
Expand Down Expand Up @@ -354,7 +355,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let mut expr_pairs = Vec::with_capacity(conditions.len());
for when in conditions {
let result = self.bind_expr(&when.result)?;
let result_ty = result.return_type();
let result_ty = result.return_type().into_owned();

fn_check_ty(&mut ty, result_ty)?;
expr_pairs.push((self.bind_expr(&when.condition)?, result))
Expand All @@ -363,7 +364,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let mut else_expr = None;
if let Some(expr) = else_result {
let temp_expr = Box::new(self.bind_expr(expr)?);
let else_ty = temp_expr.return_type();
let else_ty = temp_expr.return_type().into_owned();

fn_check_ty(&mut ty, else_ty)?;
else_expr = Some(temp_expr);
Expand Down Expand Up @@ -426,7 +427,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T

fn bind_subquery(
&mut self,
in_ty: Option<LogicalType>,
in_ty: Option<&LogicalType>,
subquery: &Query,
) -> Result<(LogicalPlan, ScalarExpression, bool), DatabaseError> {
let BinderContext {
Expand Down Expand Up @@ -606,18 +607,19 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let left_expr = Box::new(self.bind_expr(left)?);
let right_expr = Box::new(self.bind_expr(right)?);

let left_ty = left_expr.return_type();
let right_ty = right_expr.return_type();
let ty = match op {
BinaryOperator::Plus
| BinaryOperator::Minus
| BinaryOperator::Multiply
| BinaryOperator::Modulo => {
LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())?
LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned()
}
BinaryOperator::Divide => {
if let LogicalType::Decimal(precision, scale) = LogicalType::max_logical_type(
&left_expr.return_type(),
&right_expr.return_type(),
)? {
if let LogicalType::Decimal(precision, scale) =
LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned()
{
LogicalType::Decimal(precision, scale)
} else {
LogicalType::Double
Expand Down Expand Up @@ -654,7 +656,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let ty = if let UnaryOperator::Not = op {
LogicalType::Boolean
} else {
expr.return_type()
expr.return_type().into_owned()
};

Ok(ScalarExpression::Unary {
Expand Down Expand Up @@ -714,7 +716,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of sum() parameters", "1"));
}
let ty = args[0].return_type();
let ty = args[0].return_type().into_owned();

return Ok(ScalarExpression::AggCall {
distinct: is_distinct,
Expand All @@ -727,7 +729,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of min() parameters", "1"));
}
let ty = args[0].return_type();
let ty = args[0].return_type().into_owned();

return Ok(ScalarExpression::AggCall {
distinct: is_distinct,
Expand All @@ -740,7 +742,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of max() parameters", "1"));
}
let ty = args[0].return_type();
let ty = args[0].return_type().into_owned();

return Ok(ScalarExpression::AggCall {
distinct: is_distinct,
Expand Down Expand Up @@ -815,26 +817,29 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let mut ty = LogicalType::SqlNull;

if !args.is_empty() {
ty = args[0].return_type();
ty = args[0].return_type().into_owned();

for arg in args.iter_mut() {
let temp_ty = arg.return_type();
let temp_ty = arg.return_type().into_owned();

if temp_ty == LogicalType::SqlNull {
continue;
}
if ty == LogicalType::SqlNull && temp_ty != LogicalType::SqlNull {
ty = temp_ty;
} else if ty != temp_ty {
ty = LogicalType::max_logical_type(&ty, &temp_ty)?;
ty = LogicalType::max_logical_type(&ty, &temp_ty)?.into_owned();
}
}
}
return Ok(ScalarExpression::Coalesce { exprs: args, ty });
}
_ => (),
}
let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec();
let arg_types = args
.iter()
.map(|arg| arg.return_type().into_owned())
.collect_vec();
let summary = FunctionSummary {
name: function_name.into(),
arg_types,
Expand Down Expand Up @@ -870,10 +875,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
let temp_ty_1 = expr_1.return_type();
let temp_ty_2 = expr_2.return_type();

match (temp_ty_1, temp_ty_2) {
match (temp_ty_1.as_ref(), temp_ty_2.as_ref()) {
(LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull),
(ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty),
(ty_1, ty_2) => LogicalType::max_logical_type(&ty_1, &ty_2),
(ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty.clone()),
(ty_1, ty_2) => Ok(LogicalType::max_logical_type(ty_1, ty_2)?.into_owned()),
}
}

Expand Down Expand Up @@ -904,10 +909,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
}

fn bind_cast(&mut self, expr: &Expr, ty: &DataType) -> Result<ScalarExpression, DatabaseError> {
Ok(ScalarExpression::TypeCast {
expr: Box::new(self.bind_expr(expr)?),
ty: LogicalType::try_from(ty.clone())?,
})
ScalarExpression::type_cast(
self.bind_expr(expr)?,
Cow::Owned(LogicalType::try_from(ty.clone())?),
)
}

fn wildcard_expr() -> ScalarExpression {
Expand Down
4 changes: 1 addition & 3 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
ScalarExpression::Constant(mut value) => {
let ty = schema_ref[i].datatype();

if &value.logical_type() != ty {
value = value.cast(ty)?;
}
value = value.cast(ty)?;
// Check if the value length is too long
value.check_len(ty)?;
if value.is_null() && !schema_ref[i].nullable() {
Expand Down
25 changes: 11 additions & 14 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'

inferred_types[col_index] = match &inferred_types[col_index] {
Some(existing) => {
Some(LogicalType::max_logical_type(existing, &value_type)?)
Some(LogicalType::max_logical_type(existing, &value_type)?.into_owned())
}
None => Some(value_type),
};
Expand Down Expand Up @@ -609,22 +609,19 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
{
let cast_type =
LogicalType::max_logical_type(left_schema.datatype(), right_schema.datatype())?;
if &cast_type != left_schema.datatype() {
left_cast.push(ScalarExpression::TypeCast {
expr: Box::new(ScalarExpression::column_expr(left_schema.clone(), position)),
ty: cast_type.clone(),
});
if cast_type.as_ref() != left_schema.datatype() {
left_cast.push(ScalarExpression::type_cast(
ScalarExpression::column_expr(left_schema.clone(), position),
cast_type.clone(),
)?);
} else {
left_cast.push(ScalarExpression::column_expr(left_schema.clone(), position));
}
if &cast_type != right_schema.datatype() {
right_cast.push(ScalarExpression::TypeCast {
expr: Box::new(ScalarExpression::column_expr(
right_schema.clone(),
position,
)),
ty: cast_type.clone(),
});
if cast_type.as_ref() != right_schema.datatype() {
right_cast.push(ScalarExpression::type_cast(
ScalarExpression::column_expr(right_schema.clone(), position),
cast_type.clone(),
)?);
} else {
right_cast.push(ScalarExpression::column_expr(
right_schema.clone(),
Expand Down
11 changes: 5 additions & 6 deletions src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::types::value::DataValue;
use sqlparser::ast::{
Assignment, AssignmentTarget, Expr, Ident, ObjectName, TableFactor, TableWithJoins,
};
use std::borrow::Cow;
use std::slice;
use std::sync::Arc;

Expand Down Expand Up @@ -91,12 +92,10 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
} else {
expression.clone()
};
if &expr.return_type() != column.datatype() {
expr = ScalarExpression::TypeCast {
expr: Box::new(expr),
ty: column.datatype().clone(),
}
}
expr = ScalarExpression::type_cast(
expr,
Cow::Borrowed(column.datatype()),
)?;
value_exprs.push((column, expr));
}
_ => {
Expand Down
7 changes: 7 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::types::tuple::TupleId;
use crate::types::LogicalType;
use chrono::ParseError;
use sqlparser::parser::ParserError;
use std::convert::Infallible;
use std::num::{ParseFloatError, ParseIntError, TryFromIntError};
use std::str::{ParseBoolError, Utf8Error};
use std::string::FromUtf8Error;
Expand Down Expand Up @@ -268,6 +269,12 @@ pub enum DatabaseError {
ViewNotFound,
}

impl From<Infallible> for DatabaseError {
fn from(value: Infallible) -> Self {
match value {}
}
}

impl DatabaseError {
pub fn invalid_column(name: impl Into<String>) -> Self {
Self::InvalidColumn {
Expand Down
8 changes: 3 additions & 5 deletions src/execution/ddl/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ impl CreateIndex {
while arena.next_tuple(self.input)? {
let (value, tuple_pk) = {
let tuple = arena.result_tuple();
let Some(value) = DataValue::values_to_tuple(Projection::projection(
tuple,
&column_exprs,
&self.input_schema,
)?) else {
let Some(value) =
DataValue::values_to_tuple(Projection::projection(tuple, &column_exprs)?)
else {
continue;
};
let Some(tuple_pk) = tuple.pk.clone() else {
Expand Down
7 changes: 2 additions & 5 deletions src/execution/dml/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use crate::planner::operator::analyze::AnalyzeOperator;
use crate::planner::LogicalPlan;
use crate::storage::{StatisticsMetaCache, Transaction};
use crate::types::index::IndexId;
use crate::types::tuple::SchemaRef;
use crate::types::value::{DataValue, Utf8Type};
use itertools::Itertools;
use sqlparser::ast::CharLengthUnits;
Expand All @@ -34,7 +33,6 @@ const DEFAULT_NUM_OF_BUCKETS: usize = 100;

pub struct Analyze {
table_name: TableName,
input_schema: SchemaRef,
input_plan: LogicalPlan,
input: Option<ExecId>,
histogram_buckets: Option<usize>,
Expand All @@ -48,13 +46,12 @@ impl From<(AnalyzeOperator, LogicalPlan)> for Analyze {
index_metas,
histogram_buckets,
},
mut input,
input,
): (AnalyzeOperator, LogicalPlan),
) -> Self {
let _ = index_metas;
Analyze {
table_name,
input_schema: input.output_schema().clone(),
input_plan: input,
input: None,
histogram_buckets,
Expand Down Expand Up @@ -108,7 +105,7 @@ impl Analyze {
while arena.next_tuple(input)? {
let tuple = arena.result_tuple();
for State { exprs, builder, .. } in builders.iter_mut() {
let values = Projection::projection(tuple, exprs, &self.input_schema)?;
let values = Projection::projection(tuple, exprs)?;

if values.len() == 1 {
builder.append(&values[0])?;
Expand Down
Loading
Loading