diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index b39232ad5..085d1598c 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -72,4 +72,11 @@ impl Dialect for HiveDialect { fn supports_group_by_with_modifier(&self) -> bool { true } + + // TODO: The parsing of the FROM keyword seems wrong, as it happens within the CTE. + // See https://github.com/apache/datafusion-sqlparser-rs/issues/2236 for more details. + /// See + fn supports_from_first_insert(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index bcca455ec..3eef5b49c 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -648,6 +648,21 @@ pub trait Dialect: Debug + Any { false } + /// Return true if the dialect supports "FROM-first" inserts. + /// + /// Example: + /// ```sql + /// WITH cte AS (SELECT key FROM src) + /// FROM cte + /// INSERT OVERWRITE table my_table + /// SELECT * + /// + /// See + /// ``` + fn supports_from_first_insert(&self) -> bool { + false + } + /// Return true if the dialect supports pipe operator. /// /// Example: diff --git a/src/parser/mod.rs b/src/parser/mod.rs index bb11d79c2..7d988cfe2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -13958,7 +13958,7 @@ impl<'a> Parser<'a> { closing_paren_token: closing_paren_token.into(), } }; - if self.parse_keyword(Keyword::FROM) { + if self.dialect.supports_from_first_insert() && self.parse_keyword(Keyword::FROM) { cte.from = Some(self.parse_identifier()?); } Ok(cte) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a3b5404d3..19ea751b3 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -16069,10 +16069,37 @@ fn test_select_from_first() { pipe_operators: vec![], }; assert_eq!(expected, ast); - assert_eq!(ast.to_string(), q); } } +#[test] +fn test_select_from_first_with_cte() { + let dialects = all_dialects_where(|d| d.supports_from_first_select()); + let q = "WITH test AS (FROM t SELECT a) FROM test SELECT 1"; + + let ast = dialects.verified_query(q); + + let ast_select = ast.body.as_select().unwrap(); + + let expected_body_select_projection = + vec![SelectItem::UnnamedExpr(Expr::Value(ValueWithSpan { + value: test_utils::number("1"), + span: Span::empty(), + }))]; + + let expected_body_from = vec![TableWithJoins { + relation: table_from_name(ObjectName::from(vec![Ident { + value: "test".to_string(), + quote_style: None, + span: Span::empty(), + }])), + joins: vec![], + }]; + + assert_eq!(ast_select.projection, expected_body_select_projection); + assert_eq!(ast_select.from, expected_body_from); +} + #[test] fn test_geometric_unary_operators() { // Number of points in path or polygon