Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions sqlparser_bench/benches/sqlparser_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use criterion::{criterion_group, criterion_main, Criterion};
use sqlparser::dialect::GenericDialect;
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect};
use sqlparser::keywords::Keyword;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::{Span, Word};
Expand Down Expand Up @@ -177,11 +177,58 @@ fn parse_compound_chain(c: &mut Criterion) {
group.finish();
}

/// Measures parse time for the pathological input
/// `IF(<keyword-fn>(<keyword-fn>(...x`, which used to cost 2^N work inside
/// `parse_prefix`: every nested `current_time(` segment was explored twice
/// (speculative reserved-word arm, then unreserved-word fallback), so each
/// extra level doubled the work. After the fix the cost should grow linearly
/// with the chain length.
///
/// One benchmark (`chain_<n>`) is registered per chain length.
fn parse_prefix_keyword_call_chain(c: &mut Criterion) {
    let dialect = PostgreSqlDialect {};
    let mut group = c.benchmark_group("parse_prefix_keyword_call_chain");

    for depth in [10usize, 20, 30] {
        // `if(` followed by `depth` unclosed `current_time(` calls and a final `x`.
        let sql = format!("if({}x", "current_time(".repeat(depth));
        let bench_id = format!("chain_{depth}");

        group.bench_function(bench_id, |bencher| {
            bencher.iter(|| {
                // The parse result is discarded; only the time spent parsing is measured.
                let _ = Parser::parse_sql(&dialect, std::hint::black_box(&sql));
            })
        });
    }

    group.finish();
}

/// Measures parse time for the pathological input `case-case-case-...c`,
/// which used to cost 2^N work inside `parse_prefix`. Each `case` token
/// kicked off a speculative `parse_case_expr` that descended the rest of the
/// chain; because the unreserved-word fallback then produced
/// `Identifier(case)`, `parse_prefix` as a whole succeeded and its failure
/// cache never recorded anything. After the fix the per-arm cache
/// short-circuits that speculative descent.
///
/// One benchmark (`chain_<n>`) is registered per chain length.
fn parse_prefix_case_chain(c: &mut Criterion) {
    let dialect = SQLiteDialect {};
    let mut group = c.benchmark_group("parse_prefix_case_chain");

    for depth in [10usize, 20, 30] {
        // `depth` repetitions of `case<TAB>-`, terminated by the identifier `c`.
        let mut sql = "case\t-".repeat(depth);
        sql.push('c');
        let bench_id = format!("chain_{depth}");

        group.bench_function(bench_id, |bencher| {
            bencher.iter(|| {
                // The parse result is discarded; only the time spent parsing is measured.
                let _ = Parser::parse_sql(&dialect, std::hint::black_box(&sql));
            })
        });
    }

    group.finish();
}

criterion_group!(
benches,
basic_queries,
word_to_ident,
parse_many_identifiers,
parse_compound_chain
parse_compound_chain,
parse_prefix_keyword_call_chain,
parse_prefix_case_chain
);
criterion_main!(benches);
74 changes: 73 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#[cfg(not(feature = "std"))]
use alloc::{
boxed::Box,
collections::BTreeMap,
format,
string::{String, ToString},
vec,
Expand All @@ -24,6 +25,9 @@ use core::{
fmt::{self, Display},
str::FromStr,
};
#[cfg(feature = "std")]
use std::collections::BTreeMap;

use helpers::attached_token::AttachedToken;

use log::debug;
Expand Down Expand Up @@ -359,6 +363,29 @@ pub struct Parser<'a> {
options: ParserOptions,
/// Ensures the stack does not overflow by limiting recursion depth.
recursion_counter: RecursionCounter,
/// Cached failures from `parse_prefix` calls that returned `Err`. See
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
failed_prefix_positions: BTreeMap<usize, ExprPrefixError>,
/// Cached failures from the speculative reserved-word prefix arm. See
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
failed_reserved_word_prefix_positions: BTreeMap<usize, ExprPrefixError>,
}

/// Compact, `Copy` stand-in for a [`ParserError`] remembered by the
/// `parse_prefix` failure memoization. Storing this marker instead of the
/// error itself means the caches hold no strings.
#[derive(Clone, Copy, Debug)]
enum ExprPrefixError {
    /// The remembered failure was [`ParserError::RecursionLimitExceeded`].
    RecursionLimitExceeded,
    /// The remembered failure was any other parser error.
    Err,
}

/// Collapses a full [`ParserError`] into the marker kept by the prefix
/// failure caches.
impl From<&ParserError> for ExprPrefixError {
    /// Maps `ParserError::RecursionLimitExceeded` to its dedicated marker and
    /// every other error to the generic [`ExprPrefixError::Err`].
    fn from(e: &ParserError) -> Self {
        if matches!(e, ParserError::RecursionLimitExceeded) {
            Self::RecursionLimitExceeded
        } else {
            Self::Err
        }
    }
}

impl<'a> Parser<'a> {
Expand All @@ -385,6 +412,8 @@ impl<'a> Parser<'a> {
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
failed_prefix_positions: BTreeMap::new(),
failed_reserved_word_prefix_positions: BTreeMap::new(),
}
}

Expand Down Expand Up @@ -446,6 +475,8 @@ impl<'a> Parser<'a> {
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithSpan>) -> Self {
self.tokens = tokens;
self.index = 0;
self.failed_prefix_positions.clear();
self.failed_reserved_word_prefix_positions.clear();
self
}

Expand Down Expand Up @@ -1717,6 +1748,31 @@ impl<'a> Parser<'a> {
return prefix;
}

// Memoize parse_prefix failures to break 2^N speculation when both
// prefix arms fail at every level (e.g. `IF(current_time(...x`).
// The per-arm cache in `parse_prefix_inner` complements this for
// chains where the reserved arm fails but the unreserved fallback
// succeeds (e.g. `case-case-...c`).
let start_index = self.index;
if let Some(&cached) = self.failed_prefix_positions.get(&start_index) {
return Err(self.cached_prefix_error(cached, self.peek_token_ref()));
}
let result = self.parse_prefix_inner();
if let Err(ref e) = result {
self.failed_prefix_positions.insert(start_index, e.into());
}
result
}

/// Turns a cached failure marker back into a [`ParserError`].
///
/// A `RecursionLimitExceeded` marker yields that same error variant. Any
/// other marker yields the error produced by `expected_ref` for
/// "an expression" at the `found` token.
fn cached_prefix_error(&self, cached: ExprPrefixError, found: &TokenWithSpan) -> ParserError {
    match cached {
        ExprPrefixError::RecursionLimitExceeded => ParserError::RecursionLimitExceeded,
        ExprPrefixError::Err => {
            // Only the error side of `expected_ref`'s result is wanted here.
            let outcome = self.expected_ref::<()>("an expression", found);
            outcome.unwrap_err()
        }
    }
}

fn parse_prefix_inner(&mut self) -> Result<Expr, ParserError> {
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
// string literal represents a literal of that type. Some examples:
//
Expand Down Expand Up @@ -1801,7 +1857,21 @@ impl<'a> Parser<'a> {
// We first try to parse the word and following tokens as a special expression, and if that fails,
// we rollback and try to parse it as an identifier.
let w = w.clone();
match self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span)) {
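// Memoize failed speculative reserved-word parses, keyed by token
// position. The failure cache in `parse_prefix` only records
// positions where the whole prefix parse returned `Err`. When the
// reserved-word arm (CASE, CURRENT_TIME, etc.) fails but the
// unreserved-word fallback succeeds, `parse_prefix` returns `Ok`, so
// that cache records nothing and the failing arm is explored again
// on every retry, doubling the work at each nesting level (2^N
// overall). Chains like `case-case-...c` need this per-arm cache to
// avoid that blowup.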
Comment on lines +1860 to +1865
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to follow this comment, what is 'outer cache' and 'break doubling' referring to?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Outer cache" is the failure cache in the parent `parse_prefix` call, which is checked before the recursive call. "Doubling" means that the number of operations doubled with each additional level of nesting, which is where the previous 2^layers complexity came from.

let try_parse_result = if let Some(&cached) = self
.failed_reserved_word_prefix_positions
.get(&next_token_index)
{
Err(self.cached_prefix_error(cached, self.get_current_token()))
} else {
self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span))
};
match try_parse_result {
// This word indicated an expression prefix and parsing was successful
Ok(Some(expr)) => Ok(expr),

Expand All @@ -1815,6 +1885,8 @@ impl<'a> Parser<'a> {
// we rollback and return the parsing error we got from trying to parse a
// special expression (to maintain backwards compatibility of parsing errors).
Err(e) => {
self.failed_reserved_word_prefix_positions
.insert(next_token_index, (&e).into());
if !self.dialect.is_reserved_for_identifier(w.keyword) {
if let Ok(Some(expr)) = self.maybe_parse(|parser| {
parser.parse_expr_prefix_by_unreserved_word(&w, span)
Expand Down
64 changes: 63 additions & 1 deletion tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15564,7 +15564,10 @@ fn parse_create_table_select() {

#[test]
fn test_reserved_keywords_for_identifiers() {
let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL));
let dialects = all_dialects_where(|d| {
d.is_reserved_for_identifier(Keyword::INTERVAL)
&& !d.supports_named_fn_args_with_expr_name()
});
// Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier
let sql = "SELECT MAX(interval) FROM tbl";
assert_eq!(
Expand All @@ -15574,6 +15577,19 @@ fn test_reserved_keywords_for_identifiers() {
))
);

// Dialects with expression-named function arguments parse the argument
// expression twice, so the second attempt reports the memoized failure
// at the start of the expression
let dialects = all_dialects_where(|d| {
d.is_reserved_for_identifier(Keyword::INTERVAL) && d.supports_named_fn_args_with_expr_name()
});
assert_eq!(
dialects.parse_sql_statements(sql),
Err(ParserError::ParserError(
"Expected: an expression, found: interval".to_string()
))
);

// Dialects that do not reserve the word INTERVAL will allow it
let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL));
let sql = "SELECT MAX(interval) FROM tbl";
Expand Down Expand Up @@ -19004,3 +19020,49 @@ fn parse_compound_chain_no_exponential_blowup() {
rx.recv_timeout(Duration::from_secs(5))
.expect("parser should reject this quickly, not loop exponentially");
}

/// Guards against the exponential (2^N) slowdown `parse_prefix` used to hit
/// on nested keyword-function calls such as
/// `IF(current_time(current_time(...x`. Before the fix every `current_time(`
/// level was explored twice (first by the speculative reserved-word arm, then
/// by the unreserved-word fallback), so each extra level doubled the work.
/// With the position-keyed failure cache the failing parse is cut short.
#[test]
fn parse_prefix_keyword_call_chain_no_exponential_blowup() {
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    // `if(` followed by 30 unclosed `current_time(` calls and a final `x`.
    let sql = format!("if({}x", "current_time(".repeat(30));
    let (done_tx, done_rx) = mpsc::channel();

    // Parse on a worker thread so a regression surfaces as a timeout below
    // instead of stalling the test itself.
    thread::spawn(move || {
        let _ = Parser::parse_sql(&PostgreSqlDialect {}, &sql);
        let _ = done_tx.send(());
    });

    done_rx
        .recv_timeout(Duration::from_secs(5))
        .expect("parser should reject this quickly, not loop exponentially");
}

/// Guards against the exponential (2^N) slowdown `parse_prefix` used to hit
/// on inputs shaped like `case-case-case-...c`. Every `case` token starts a
/// speculative `parse_case_expr` that fails, yet the unreserved-word fallback
/// yields `Identifier(case)`, so the outer failure cache never records
/// anything. With the fix, the per-arm cache short-circuits the speculative
/// descent.
#[test]
fn parse_prefix_case_chain_no_exponential_blowup() {
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    // 30 repetitions of `case<TAB>-`, terminated by the identifier `c`.
    let mut sql = "case\t-".repeat(30);
    sql.push('c');
    let (done_tx, done_rx) = mpsc::channel();

    // Parse on a worker thread so a regression surfaces as a timeout below
    // instead of stalling the test itself.
    thread::spawn(move || {
        let _ = Parser::parse_sql(&SQLiteDialect {}, &sql);
        let _ = done_tx.send(());
    });

    done_rx
        .recv_timeout(Duration::from_secs(5))
        .expect("parser should reject this quickly, not loop exponentially");
}
Loading