// spark_sql_parser/lib.rs
1//! Parse SQL into [sqlparser] AST.
2//!
3//! Supports a Spark-style subset: single-statement SELECT, CREATE SCHEMA/DATABASE,
4//! and DROP TABLE/VIEW/SCHEMA, plus many DDL and utility statements (CREATE/ALTER/DROP
5//! TABLE/VIEW/FUNCTION/SCHEMA, SHOW, INSERT, DESCRIBE, SET, RESET, CACHE, EXPLAIN, etc.).
6//!
7//! # SELECT and query compatibility
8//!
9//! Any statement that [sqlparser] parses as a `Query` (e.g. `SELECT`, `WITH ... SELECT`)
10//! is accepted. Clause support is determined by [sqlparser] and the dialect in use
11//! (this crate uses [GenericDialect](sqlparser::dialect::GenericDialect)).
12//!
13//! ## Known gaps
14//!
15//! Spark-specific query clauses such as `DISTRIBUTE BY`, `CLUSTER BY`, `SORT BY`
16//! may not be recognized by the parser or may be rejected; behavior depends on
17//! the upstream dialect and parser. Use single-statement queries only (one statement
18//! per call).
19
20use sqlparser::ast::{
21    Expr as SqlExpr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, Statement,
22};
23use sqlparser::dialect::GenericDialect;
24use sqlparser::parser::Parser;
25use thiserror::Error;
26
/// Error returned when SQL parsing or validation fails.
///
/// Wraps a single human-readable message; `#[error("{0}")]` makes `Display`
/// print the inner string verbatim, so callers can surface it directly.
/// The field is private — construct only from within this crate.
#[derive(Error, Debug)]
#[error("{0}")]
pub struct ParseError(String);
31
32/// Re-export of [sqlparser::ast] so consumers can depend only on `spark-sql-parser` for SQL AST types.
33pub use sqlparser::ast;
34
/// Spark-oriented statement variants that are not reliably represented by upstream `sqlparser` AST.
///
/// This enum is intended to capture Spark/PySpark command forms where upstream parsing either:
/// - does not accept the syntax, or
/// - accepts it but in a way that's inconvenient for execution-layer parity.
///
/// Anything that upstream handles well is wrapped in [`SparkStatement::Sqlparser`];
/// the remaining variants are produced by fast-path matching in `parse_spark_sql`.
#[derive(Debug, Clone, PartialEq)]
pub enum SparkStatement {
    /// A statement parsed by upstream `sqlparser`.
    /// Boxed to keep this enum small (upstream `Statement` is large).
    Sqlparser(Box<Statement>),
    /// `DESCRIBE DETAIL <table>` (Delta Lake; Spark command).
    DescribeDetail { table: ObjectName },
    /// `SHOW DATABASES` (Spark command).
    ShowDatabases,
    /// `SHOW TABLES` or `SHOW TABLES IN/FROM <db>` (Spark command).
    ShowTables { db: Option<ObjectName> },
    /// `DESCRIBE/DESC [TABLE] [EXTENDED] <table> [<col>]` (Spark/PySpark parity).
    Describe {
        table: ObjectName,
        /// Optional column to describe; `None` describes the whole table.
        col: Option<Ident>,
        /// True when the EXTENDED keyword was present anywhere after DESCRIBE.
        extended: bool,
    },
    /// `CREATE OR REPLACE TABLE <table> USING <format> AS SELECT ...` (Spark/Delta Lake).
    /// This syntax is not supported by upstream sqlparser GenericDialect.
    CreateOrReplaceTableAs {
        table: ObjectName,
        /// Raw format token after USING (e.g. `delta`, `parquet`); not validated here.
        format: String,
        query: Box<Query>,
    },
}
64
65fn parse_one_statement_raw(query: &str) -> Result<Statement, ParseError> {
66    let dialect = GenericDialect {};
67    let stmts = Parser::parse_sql(&dialect, query).map_err(|e| {
68        ParseError(format!(
69            "SQL parse error: {}. Hint: supported statements include SELECT, CREATE TABLE/VIEW/FUNCTION/SCHEMA/DATABASE, DROP TABLE/VIEW/SCHEMA.",
70            e
71        ))
72    })?;
73    if stmts.len() != 1 {
74        return Err(ParseError(format!(
75            "SQL: expected exactly one statement, got {}. Hint: run one statement at a time.",
76            stmts.len()
77        )));
78    }
79    Ok(stmts.into_iter().next().expect("len == 1"))
80}
81
82fn parse_object_name(name: &str) -> Result<ObjectName, ParseError> {
83    let s = name.trim();
84    if s.is_empty() {
85        return Err(ParseError(
86            "SQL: expected an object name, got empty string.".to_string(),
87        ));
88    }
89    // Minimal support for Spark-style qualified names like `schema.table` and `global_temp.gv`.
90    // Backtick quoting is intentionally not handled yet (can be added later).
91    let parts: Vec<Ident> = s
92        .split('.')
93        .map(|p| p.trim())
94        .filter(|p| !p.is_empty())
95        .map(Ident::new)
96        .collect();
97    if parts.is_empty() {
98        return Err(ParseError(format!(
99            "SQL: expected an object name, got '{s}'."
100        )));
101    }
102    Ok(ObjectName::from(parts))
103}
104
/// Split `s` on Unicode whitespace, discarding empty pieces.
fn tokenize_ws(s: &str) -> Vec<&str> {
    let mut toks = Vec::new();
    for tok in s.split_whitespace() {
        toks.push(tok);
    }
    toks
}
108
109/// Try to parse `CREATE OR REPLACE TABLE <table> USING <format> AS SELECT ...`.
110/// Returns None if the pattern doesn't match, Some(SparkStatement) if it does.
111fn try_parse_create_or_replace_table_as(
112    query: &str,
113    toks: &[&str],
114) -> Result<Option<SparkStatement>, ParseError> {
115    // Minimum tokens: CREATE OR REPLACE TABLE <name> USING <format> AS SELECT ...
116    // That's at least 8 tokens before the SELECT part
117    if toks.len() < 8 {
118        return Ok(None);
119    }
120
121    // Check for CREATE OR REPLACE TABLE pattern
122    if !(toks[0].eq_ignore_ascii_case("CREATE")
123        && toks[1].eq_ignore_ascii_case("OR")
124        && toks[2].eq_ignore_ascii_case("REPLACE")
125        && toks[3].eq_ignore_ascii_case("TABLE"))
126    {
127        return Ok(None);
128    }
129
130    // Find USING keyword position (table name is between TABLE and USING)
131    let using_pos = toks[4..]
132        .iter()
133        .position(|t| t.eq_ignore_ascii_case("USING"))
134        .map(|i| i + 4);
135
136    let using_pos = match using_pos {
137        Some(pos) => pos,
138        None => return Ok(None), // No USING keyword, let upstream handle it
139    };
140
141    // Table name is tokens 4..using_pos, joined with dots for qualified names
142    let table_tokens = &toks[4..using_pos];
143    if table_tokens.is_empty() {
144        return Err(ParseError(
145            "SQL: CREATE OR REPLACE TABLE requires a table name.".to_string(),
146        ));
147    }
148
149    // Reconstruct table name (handle qualified names like schema.table)
150    // The tokenizer splits on whitespace, so "schema.table" stays as one token
151    let table_name_str = table_tokens.join(" ");
152    let table = parse_object_name(&table_name_str)?;
153
154    // Format is the token after USING
155    if using_pos + 1 >= toks.len() {
156        return Err(ParseError(
157            "SQL: CREATE OR REPLACE TABLE ... USING requires a format (e.g., delta, parquet)."
158                .to_string(),
159        ));
160    }
161    let format = toks[using_pos + 1].to_string();
162
163    // Find AS keyword position (must be after USING <format>)
164    let as_pos = toks[using_pos + 2..]
165        .iter()
166        .position(|t| t.eq_ignore_ascii_case("AS"))
167        .map(|i| i + using_pos + 2);
168
169    let as_pos = match as_pos {
170        Some(pos) => pos,
171        None => {
172            return Err(ParseError(
173                "SQL: CREATE OR REPLACE TABLE ... USING <format> requires AS SELECT ..."
174                    .to_string(),
175            ));
176        }
177    };
178
179    // Reconstruct the subquery from tokens after AS
180    // This is more reliable than byte-position searching for multiline queries
181    let subquery_tokens = &toks[as_pos + 1..];
182    if subquery_tokens.is_empty() {
183        return Err(ParseError(
184            "SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
185        ));
186    }
187
188    // For complex queries with expressions, we need to extract from original string
189    // Find the AS keyword in original query after the format token
190    let format_token = &toks[using_pos + 1];
191    let query_lower = query.to_lowercase();
192
193    // Find the format token position first
194    let format_byte_pos = query_lower.find(&format_token.to_lowercase()).unwrap_or(0);
195
196    // Then find " AS " or "\nAS " etc. after the format token
197    let search_start = format_byte_pos + format_token.len();
198    let remaining = &query[search_start..];
199    let remaining_lower = remaining.to_lowercase();
200
201    // Look for AS as a standalone word (surrounded by whitespace)
202    let as_offset = find_standalone_as(&remaining_lower);
203
204    let as_offset = match as_offset {
205        Some(offset) => offset,
206        None => {
207            return Err(ParseError(
208                "SQL: Could not locate AS keyword in CREATE OR REPLACE TABLE statement."
209                    .to_string(),
210            ));
211        }
212    };
213
214    // Skip past "AS" and any following whitespace
215    let after_as = &remaining[as_offset..];
216
217    // Find the "AS" keyword and skip past it
218    let as_lower = after_as.to_lowercase();
219    let as_keyword_pos = as_lower.find("as").unwrap_or(0);
220    let after_as_keyword = &after_as[as_keyword_pos + 2..];
221
222    // Find where the actual SELECT starts (skip whitespace after AS)
223    let subquery_str = after_as_keyword.trim_start();
224
225    if subquery_str.is_empty() {
226        return Err(ParseError(
227            "SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
228        ));
229    }
230
231    // Parse the subquery using upstream parser
232    let dialect = GenericDialect {};
233    let stmts = Parser::parse_sql(&dialect, subquery_str).map_err(|e| {
234        ParseError(format!(
235            "SQL parse error in CREATE OR REPLACE TABLE subquery: {}",
236            e
237        ))
238    })?;
239
240    if stmts.len() != 1 {
241        return Err(ParseError(format!(
242            "SQL: CREATE OR REPLACE TABLE subquery must be a single SELECT statement, got {} statements.",
243            stmts.len()
244        )));
245    }
246
247    let stmt = stmts.into_iter().next().expect("len == 1");
248    let query_ast = match stmt {
249        Statement::Query(q) => q,
250        _ => {
251            return Err(ParseError(
252                "SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
253            ));
254        }
255    };
256
257    Ok(Some(SparkStatement::CreateOrReplaceTableAs {
258        table,
259        format,
260        query: query_ast,
261    }))
262}
263
/// Find the byte position of a standalone `AS` keyword (case-insensitive).
///
/// Returns the index of the whitespace byte immediately *before* the keyword
/// when `AS` is preceded by whitespace, or `Some(0)` when the string itself
/// begins with `AS`. The keyword must be followed by whitespace or the end of
/// the string; otherwise it is not a match (e.g. `astute`).
///
/// Comparison is done byte-wise with ASCII case folding, so multi-byte UTF-8
/// characters can never cause an out-of-boundary slice. (The previous
/// implementation sliced `&s[i + 1..i + 3]` and `&s[..2]`, which panicked on
/// inputs like `"x €as y"` where the slice cut a multi-byte char in half.)
fn find_standalone_as(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let len = bytes.len();

    // ASCII case-insensitive byte tests for 'a' and 's'. Non-ASCII bytes
    // (including continuation bytes of multi-byte chars) never match.
    let is_a = |b: u8| b.eq_ignore_ascii_case(&b'a');
    let is_s = |b: u8| b.eq_ignore_ascii_case(&b's');

    for i in 0..len {
        // Only consider positions where a whitespace byte precedes "as".
        if !bytes[i].is_ascii_whitespace() {
            continue;
        }
        // "as" right after the whitespace, followed by whitespace or EOS.
        if i + 3 <= len
            && is_a(bytes[i + 1])
            && is_s(bytes[i + 2])
            && (i + 3 == len || bytes[i + 3].is_ascii_whitespace())
        {
            return Some(i);
        }
    }

    // The string may begin with the keyword itself (no preceding whitespace).
    if len >= 2 && is_a(bytes[0]) && is_s(bytes[1]) && (len == 2 || bytes[2].is_ascii_whitespace())
    {
        return Some(0);
    }

    None
}
295
/// Parse a Spark/PySpark-compatible SQL string.
///
/// - First, fast-path Spark-only command variants (e.g. `DESCRIBE DETAIL`, `SHOW TABLES IN db`,
///   `DESCRIBE t col`).
/// - Otherwise, fall back to upstream `sqlparser` and return `SparkStatement::Sqlparser`.
/// - Always enforces **exactly one statement** per call.
///
/// Branch order matters: `CREATE OR REPLACE ... USING` and `DESCRIBE DETAIL`
/// must be matched before the generic fallthroughs below them.
pub fn parse_spark_sql(query: &str) -> Result<SparkStatement, ParseError> {
    let q = query.trim();
    if q.is_empty() {
        // Let upstream parser produce the most specific error message.
        // For an empty string this always yields the "expected exactly one
        // statement, got 0" error, so `?` returns early here.
        let _ = parse_one_statement_raw(q)?;
    }

    // Tokenize for Spark-only command matching.
    let toks = tokenize_ws(q);

    // CREATE OR REPLACE TABLE <table> USING <format> AS SELECT ...
    // This Spark-specific syntax is not supported by upstream sqlparser GenericDialect.
    if let Some(stmt) = try_parse_create_or_replace_table_as(q, &toks)? {
        return Ok(stmt);
    }

    // SHOW DATABASES — any trailing tokens are ignored.
    if toks.len() >= 2
        && toks[0].eq_ignore_ascii_case("SHOW")
        && toks[1].eq_ignore_ascii_case("DATABASES")
    {
        return Ok(SparkStatement::ShowDatabases);
    }

    // SHOW TABLES [IN|FROM db]
    if toks.len() >= 2
        && toks[0].eq_ignore_ascii_case("SHOW")
        && toks[1].eq_ignore_ascii_case("TABLES")
    {
        let db = if toks.len() >= 4
            && (toks[2].eq_ignore_ascii_case("IN") || toks[2].eq_ignore_ascii_case("FROM"))
        {
            Some(parse_object_name(toks[3])?)
        } else {
            None
        };
        return Ok(SparkStatement::ShowTables { db });
    }

    // DESCRIBE DETAIL <table>
    if toks.len() >= 3
        && toks[0].eq_ignore_ascii_case("DESCRIBE")
        && toks[1].eq_ignore_ascii_case("DETAIL")
    {
        let table = parse_object_name(&toks[2..].join(" "))?;
        return Ok(SparkStatement::DescribeDetail { table });
    }

    // DESC DETAIL is a synonym for DESCRIBE DETAIL in Spark; treat it the same.
    if toks.len() >= 3
        && toks[0].eq_ignore_ascii_case("DESC")
        && toks[1].eq_ignore_ascii_case("DETAIL")
    {
        let table = parse_object_name(&toks[2..].join(" "))?;
        return Ok(SparkStatement::DescribeDetail { table });
    }

    // DESCRIBE/DESC [TABLE] [EXTENDED] <table> [<col>]
    if !toks.is_empty()
        && (toks[0].eq_ignore_ascii_case("DESCRIBE") || toks[0].eq_ignore_ascii_case("DESC"))
    {
        // Exclude DETAIL which is handled above.
        if toks.len() >= 2 && toks[1].eq_ignore_ascii_case("DETAIL") {
            // already handled above; fallthrough defensive.
        } else {
            let rest = &toks[1..];
            if !rest.is_empty() {
                // EXTENDED is detected anywhere after DESCRIBE, even though
                // Spark places it before the table name.
                let extended = rest.iter().any(|t| t.eq_ignore_ascii_case("EXTENDED"));
                // Find the first token that is not TABLE/EXTENDED => table name token.
                let idx = rest.iter().position(|t| {
                    !t.eq_ignore_ascii_case("TABLE") && !t.eq_ignore_ascii_case("EXTENDED")
                });
                if let Some(i) = idx {
                    let table_tok = rest.get(i).copied().unwrap_or("");
                    if !table_tok.is_empty() {
                        let table = parse_object_name(table_tok)?;
                        // NOTE(review): for `DESCRIBE t EXTENDED` the token
                        // after the table is taken as the column, so col
                        // becomes Some("EXTENDED") while extended is also
                        // true — confirm whether that form should be rejected.
                        let col = rest.get(i + 1).map(|c| Ident::new(*c));
                        return Ok(SparkStatement::Describe {
                            table,
                            col,
                            extended,
                        });
                    }
                }
            }
        }
    }

    // Fall back to upstream parsing for everything else.
    let stmt = parse_one_statement_raw(query)?;
    Ok(SparkStatement::Sqlparser(Box::new(stmt)))
}
393
394/// Parse a single SQL expression string (optionally with an alias) into `sqlparser` expression AST.
395///
396/// This is intended for PySpark parity helpers like `selectExpr` and `expr()` where the input is
397/// a *projection expression*, not a full SQL statement.
398pub fn parse_select_expr(expr_str: &str) -> Result<(SqlExpr, Option<Ident>), ParseError> {
399    let e = expr_str.trim();
400    if e.is_empty() {
401        return Err(ParseError(
402            "SQL: expected an expression string, got empty.".to_string(),
403        ));
404    }
405    // Parse by embedding into a query; keep the hack local to this crate.
406    const TMP_TABLE: &str = "__spark_sql_parser_expr_t";
407    let query = format!("SELECT {e} FROM {TMP_TABLE}");
408    let stmt = parse_one_statement_raw(&query)?;
409    let query_ast: &Query = match &stmt {
410        Statement::Query(q) => q.as_ref(),
411        other => {
412            return Err(ParseError(format!(
413                "SQL: expected SELECT when parsing expression, got {other:?}."
414            )));
415        }
416    };
417    let select: &Select = match query_ast.body.as_ref() {
418        SetExpr::Select(s) => s.as_ref(),
419        other => {
420            return Err(ParseError(format!(
421                "SQL: expected SELECT when parsing expression, got {other:?}."
422            )));
423        }
424    };
425    let first: &SelectItem = select.projection.first().ok_or_else(|| {
426        ParseError("SQL: expected non-empty SELECT list when parsing expression.".to_string())
427    })?;
428    match first {
429        SelectItem::UnnamedExpr(ex) => Ok((ex.clone(), None)),
430        SelectItem::ExprWithAlias { expr, alias } => Ok((expr.clone(), Some(alias.clone()))),
431        other => Err(ParseError(format!(
432            "SQL: unsupported expression form in SELECT list: {other:?}."
433        ))),
434    }
435}
436
437/// Parse a single SQL statement (SELECT or DDL: CREATE SCHEMA / CREATE DATABASE / DROP TABLE/VIEW/SCHEMA).
438///
439/// Returns the [sqlparser::ast::Statement] on success. Only one statement per call;
440/// run one statement at a time.
441pub fn parse_sql(query: &str) -> Result<Statement, ParseError> {
442    let stmt = parse_one_statement_raw(query)?;
443    match &stmt {
444        Statement::Query(_) => {}
445        Statement::CreateSchema { .. } | Statement::CreateDatabase { .. } => {}
446        Statement::CreateTable(_) | Statement::CreateView(_) | Statement::CreateFunction(_) => {}
447        Statement::AlterTable(_) | Statement::AlterView { .. } | Statement::AlterSchema(_) => {}
448        Statement::Drop {
449            object_type:
450                sqlparser::ast::ObjectType::Table
451                | sqlparser::ast::ObjectType::View
452                | sqlparser::ast::ObjectType::Schema
453                | sqlparser::ast::ObjectType::Database,
454            ..
455        } => {}
456        Statement::DropFunction(_) => {}
457        Statement::Use(_) | Statement::Truncate(_) | Statement::Declare { .. } => {}
458        Statement::ShowTables { .. }
459        | Statement::ShowDatabases { .. }
460        | Statement::ShowSchemas { .. }
461        | Statement::ShowFunctions { .. }
462        | Statement::ShowColumns { .. }
463        | Statement::ShowViews { .. }
464        | Statement::ShowCreate { .. } => {}
465        Statement::Insert(_) | Statement::Directory { .. } | Statement::LoadData { .. } => {}
466        Statement::Update(_) | Statement::Delete(_) => {}
467        Statement::ExplainTable { .. } => {}
468        Statement::Set(_) | Statement::Reset(_) => {}
469        Statement::Cache { .. } | Statement::UNCache { .. } => {}
470        Statement::Explain { .. } => {}
471        _ => {
472            return Err(ParseError(format!(
473                "SQL: statement type not supported, got {:?}.",
474                stmt
475            )));
476        }
477    }
478    Ok(stmt)
479}
480
481#[cfg(test)]
482mod tests {
483    use super::*;
484    use sqlparser::ast::{ObjectType, Statement};
485
486    /// Assert that `sql` parses to the given statement variant.
487    fn assert_parses_to<F>(sql: &str, check: F)
488    where
489        F: FnOnce(&Statement) -> bool,
490    {
491        let stmt = parse_sql(sql).unwrap_or_else(|e| panic!("parse_sql failed: {e}"));
492        assert!(check(&stmt), "expected match for: {sql}");
493    }
494
    // --- Error handling ---
    // These pin the message fragments callers may grep for.

    #[test]
    fn error_multiple_statements() {
        let err = parse_sql("SELECT 1; SELECT 2").unwrap_err();
        assert!(err.0.contains("expected exactly one statement"));
        assert!(err.0.contains("2"));
    }

    #[test]
    fn error_zero_statements() {
        // Empty input yields either the count error or an upstream parse error.
        let err = parse_sql("").unwrap_err();
        assert!(err.0.contains("expected exactly one statement") || err.0.contains("parse error"));
    }

    #[test]
    fn error_unsupported_statement_type() {
        // COMMIT is parsed by sqlparser but not in our whitelist
        let err = parse_sql("COMMIT").unwrap_err();
        assert!(err.0.contains("not supported"));
    }

    #[test]
    fn error_syntax() {
        let err = parse_sql("SELECT FROM").unwrap_err();
        assert!(!err.0.is_empty());
    }

    // --- Queries ---
    // Everything sqlparser accepts as a Query should pass through parse_sql.

    #[test]
    fn query_select_simple() {
        assert_parses_to("SELECT 1", |s| matches!(s, Statement::Query(_)));
    }

    #[test]
    fn query_select_with_from() {
        assert_parses_to("SELECT a FROM t", |s| matches!(s, Statement::Query(_)));
    }

    #[test]
    fn query_with_cte() {
        assert_parses_to("WITH cte AS (SELECT 1) SELECT * FROM cte", |s| {
            matches!(s, Statement::Query(_))
        });
    }

    #[test]
    fn query_create_schema() {
        assert_parses_to("CREATE SCHEMA s", |s| {
            matches!(s, Statement::CreateSchema { .. })
        });
    }

    #[test]
    fn query_create_database() {
        assert_parses_to("CREATE DATABASE d", |s| {
            matches!(s, Statement::CreateDatabase { .. })
        });
    }
555
    // --- DDL: CREATE (issue #652) ---

    #[test]
    fn test_issue_652_create_table() {
        assert_parses_to("CREATE TABLE t (a INT)", |s| {
            matches!(s, Statement::CreateTable(_))
        });
    }

    #[test]
    fn test_issue_652_create_view() {
        assert_parses_to("CREATE VIEW v AS SELECT 1", |s| {
            matches!(s, Statement::CreateView(_))
        });
    }

    #[test]
    fn test_issue_652_create_function() {
        assert_parses_to("CREATE FUNCTION f() AS 'com.example.UDF'", |s| {
            matches!(s, Statement::CreateFunction(_))
        });
    }

    // --- DDL: ALTER (issue #653) ---

    #[test]
    fn test_issue_653_alter_table() {
        assert_parses_to("ALTER TABLE t ADD COLUMN c INT", |s| {
            matches!(s, Statement::AlterTable(_))
        });
    }

    #[test]
    fn test_issue_653_alter_view() {
        assert_parses_to("ALTER VIEW v AS SELECT 1", |s| {
            matches!(s, Statement::AlterView { .. })
        });
    }

    #[test]
    fn test_issue_653_alter_schema() {
        assert_parses_to("ALTER SCHEMA db RENAME TO db2", |s| {
            matches!(s, Statement::AlterSchema(_))
        });
    }

    // --- DDL: DROP (issue #654) ---
    // DROP tests match on the exact ObjectType, not just the Drop variant.

    #[test]
    fn test_issue_654_drop_table() {
        let stmt = parse_sql("DROP TABLE t").unwrap();
        match &stmt {
            Statement::Drop {
                object_type: ObjectType::Table,
                ..
            } => {}
            _ => panic!("expected Drop Table: {stmt:?}"),
        }
    }

    #[test]
    fn test_issue_654_drop_view() {
        let stmt = parse_sql("DROP VIEW v").unwrap();
        match &stmt {
            Statement::Drop {
                object_type: ObjectType::View,
                ..
            } => {}
            _ => panic!("expected Drop View: {stmt:?}"),
        }
    }

    #[test]
    fn test_issue_654_drop_schema() {
        let stmt = parse_sql("DROP SCHEMA s").unwrap();
        match &stmt {
            Statement::Drop {
                object_type: ObjectType::Schema,
                ..
            } => {}
            _ => panic!("expected Drop Schema: {stmt:?}"),
        }
    }

    #[test]
    fn test_issue_654_drop_function() {
        assert_parses_to("DROP FUNCTION f", |s| {
            matches!(s, Statement::DropFunction(_))
        });
    }
646
    // --- Utility: USE, TRUNCATE, DECLARE (issue #655) ---

    #[test]
    fn test_issue_655_use() {
        assert_parses_to("USE db1", |s| matches!(s, Statement::Use(_)));
    }

    #[test]
    fn test_issue_655_truncate() {
        assert_parses_to("TRUNCATE TABLE t", |s| matches!(s, Statement::Truncate(_)));
    }

    #[test]
    fn test_issue_655_declare() {
        assert_parses_to("DECLARE c CURSOR FOR SELECT 1", |s| {
            matches!(s, Statement::Declare { .. })
        });
    }

    // --- SHOW (issue #656) ---
    // These cover parse_sql's whitelist; the SparkStatement fast-paths for
    // SHOW are tested separately below.

    #[test]
    fn test_issue_656_show_tables() {
        assert_parses_to("SHOW TABLES", |s| matches!(s, Statement::ShowTables { .. }));
    }

    #[test]
    fn test_issue_656_show_databases() {
        assert_parses_to("SHOW DATABASES", |s| {
            matches!(s, Statement::ShowDatabases { .. })
        });
    }

    #[test]
    fn test_issue_656_show_schemas() {
        assert_parses_to("SHOW SCHEMAS", |s| {
            matches!(s, Statement::ShowSchemas { .. })
        });
    }

    #[test]
    fn test_issue_656_show_functions() {
        assert_parses_to("SHOW FUNCTIONS", |s| {
            matches!(s, Statement::ShowFunctions { .. })
        });
    }

    #[test]
    fn test_issue_656_show_columns() {
        assert_parses_to("SHOW COLUMNS FROM t", |s| {
            matches!(s, Statement::ShowColumns { .. })
        });
    }

    #[test]
    fn test_issue_656_show_views() {
        assert_parses_to("SHOW VIEWS", |s| matches!(s, Statement::ShowViews { .. }));
    }

    #[test]
    fn test_issue_656_show_create_table() {
        assert_parses_to("SHOW CREATE TABLE t", |s| {
            matches!(s, Statement::ShowCreate { .. })
        });
    }

    // --- INSERT / DIRECTORY (issue #657) ---

    #[test]
    fn test_issue_657_insert() {
        assert_parses_to("INSERT INTO t SELECT 1", |s| {
            matches!(s, Statement::Insert(_))
        });
    }

    #[test]
    fn test_issue_657_directory() {
        assert_parses_to("INSERT OVERWRITE DIRECTORY '/path' SELECT 1", |s| {
            matches!(s, Statement::Directory { .. })
        });
    }

    // --- DESCRIBE (issue #658) ---
    // sqlparser models plain DESCRIBE as ExplainTable.

    #[test]
    fn test_issue_658_describe_table() {
        assert_parses_to("DESCRIBE t", |s| {
            matches!(s, Statement::ExplainTable { .. })
        });
    }

    // --- SET, RESET, CACHE, UNCACHE (issue #659) ---

    #[test]
    fn test_issue_659_set() {
        assert_parses_to("SET x = 1", |s| matches!(s, Statement::Set(_)));
    }

    #[test]
    fn test_issue_659_reset() {
        assert_parses_to("RESET x", |s| matches!(s, Statement::Reset(_)));
    }

    #[test]
    fn test_issue_659_cache() {
        assert_parses_to("CACHE TABLE t", |s| matches!(s, Statement::Cache { .. }));
    }

    #[test]
    fn test_issue_659_uncache() {
        assert_parses_to("UNCACHE TABLE t", |s| {
            matches!(s, Statement::UNCache { .. })
        });
    }

    #[test]
    fn test_issue_659_uncache_if_exists() {
        assert_parses_to("UNCACHE TABLE IF EXISTS t", |s| {
            matches!(s, Statement::UNCache { .. })
        });
    }

    // --- EXPLAIN (issue #660) ---

    #[test]
    fn test_issue_660_explain() {
        assert_parses_to("EXPLAIN SELECT 1", |s| {
            matches!(s, Statement::Explain { .. })
        });
    }
777
    // --- SparkStatement parsing (Spark/PySpark command variants) ---
    // These go through parse_spark_sql and exercise the fast-path matching
    // that runs before the upstream-parser fallback.

    #[test]
    fn spark_show_databases() {
        let s = parse_spark_sql("SHOW DATABASES").unwrap();
        assert!(matches!(s, SparkStatement::ShowDatabases));
    }

    #[test]
    fn spark_show_tables_in_db() {
        let s = parse_spark_sql("SHOW TABLES IN my_db").unwrap();
        match s {
            SparkStatement::ShowTables { db: Some(db) } => {
                assert_eq!(db.to_string(), "my_db");
            }
            other => panic!("expected ShowTables with db, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_detail() {
        let s = parse_spark_sql("DESCRIBE DETAIL schema1.tbl1").unwrap();
        match s {
            SparkStatement::DescribeDetail { table } => {
                assert_eq!(table.to_string(), "schema1.tbl1");
            }
            other => panic!("expected DescribeDetail, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_optional_col() {
        let s = parse_spark_sql("DESCRIBE t age").unwrap();
        match s {
            SparkStatement::Describe {
                table,
                col: Some(c),
                extended: false,
            } => {
                assert_eq!(table.to_string(), "t");
                assert_eq!(c.value, "age");
            }
            other => panic!("expected Describe with col, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_table_extended() {
        let s = parse_spark_sql("DESCRIBE TABLE EXTENDED t").unwrap();
        match s {
            SparkStatement::Describe {
                table,
                col: None,
                extended: true,
            } => {
                assert_eq!(table.to_string(), "t");
            }
            other => panic!("expected Describe extended, got {other:?}"),
        }
    }

    // --- Expression parsing helper ---

    #[test]
    fn parse_select_expr_with_alias() {
        let (e, a) = parse_select_expr("upper(Name) AS u").unwrap();
        let _ = e; // structure validated by parse
        assert_eq!(a.unwrap().value, "u");
    }

    #[test]
    fn parse_select_expr_without_alias() {
        let (_e, a) = parse_select_expr("ltrim(rtrim(Value))").unwrap();
        assert!(a.is_none());
    }
853
    // ========== Robust parse_spark_sql tests ==========
    // Case-insensitivity and optional-clause coverage for the fast paths.

    #[test]
    fn spark_show_databases_case_insensitive() {
        for sql in ["show databases", "Show Databases", "SHOW DATABASES"] {
            let s = parse_spark_sql(sql).unwrap();
            assert!(
                matches!(s, SparkStatement::ShowDatabases),
                "failed for: {sql}"
            );
        }
    }

    #[test]
    fn spark_show_tables_no_db() {
        let s = parse_spark_sql("SHOW TABLES").unwrap();
        match s {
            SparkStatement::ShowTables { db: None } => {}
            other => panic!("expected ShowTables with db=None, got {other:?}"),
        }
    }

    #[test]
    fn spark_show_tables_from_db() {
        let s = parse_spark_sql("SHOW TABLES FROM other_db").unwrap();
        match s {
            SparkStatement::ShowTables { db: Some(db) } => assert_eq!(db.to_string(), "other_db"),
            other => panic!("expected ShowTables with db, got {other:?}"),
        }
    }

    #[test]
    fn spark_show_tables_in_db_case_insensitive() {
        // Keywords are case-insensitive; the db identifier keeps its case.
        let s = parse_spark_sql("show tables in MySchema").unwrap();
        match s {
            SparkStatement::ShowTables { db: Some(db) } => assert_eq!(db.to_string(), "MySchema"),
            other => panic!("expected ShowTables with db, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_detail_single_table() {
        let s = parse_spark_sql("DESCRIBE DETAIL t").unwrap();
        match s {
            SparkStatement::DescribeDetail { table } => assert_eq!(table.to_string(), "t"),
            other => panic!("expected DescribeDetail, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_detail_case_insensitive() {
        let s = parse_spark_sql("describe detail my_table").unwrap();
        match s {
            SparkStatement::DescribeDetail { table } => assert_eq!(table.to_string(), "my_table"),
            other => panic!("expected DescribeDetail, got {other:?}"),
        }
    }

    #[test]
    fn spark_desc_detail_synonym() {
        let s = parse_spark_sql("DESC DETAIL catalog.schema.tbl").unwrap();
        match s {
            SparkStatement::DescribeDetail { table } => {
                assert_eq!(table.to_string(), "catalog.schema.tbl")
            }
            other => panic!("expected DescribeDetail, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_table_only() {
        let s = parse_spark_sql("DESCRIBE my_tbl").unwrap();
        match s {
            SparkStatement::Describe {
                table,
                col: None,
                extended: false,
            } => assert_eq!(table.to_string(), "my_tbl"),
            other => panic!("expected Describe table only, got {other:?}"),
        }
    }

    #[test]
    fn spark_describe_extended_only() {
        let s = parse_spark_sql("DESCRIBE EXTENDED t").unwrap();
        match s {
            SparkStatement::Describe {
                table,
                col: None,
                extended: true,
            } => assert_eq!(table.to_string(), "t"),
            other => panic!("expected Describe extended, got {other:?}"),
        }
    }
948
949    #[test]
950    fn spark_desc_short_form() {
951        let s = parse_spark_sql("DESC t col_x").unwrap();
952        match s {
953            SparkStatement::Describe {
954                table,
955                col: Some(c),
956                extended: false,
957            } => {
958                assert_eq!(table.to_string(), "t");
959                assert_eq!(c.value, "col_x");
960            }
961            other => panic!("expected Describe with col, got {other:?}"),
962        }
963    }
964
965    #[test]
966    fn spark_describe_qualified_table_with_col() {
967        let s = parse_spark_sql("DESCRIBE global_temp.v id").unwrap();
968        match s {
969            SparkStatement::Describe {
970                table,
971                col: Some(c),
972                extended: false,
973            } => {
974                assert_eq!(table.to_string(), "global_temp.v");
975                assert_eq!(c.value, "id");
976            }
977            other => panic!("expected Describe qualified table + col, got {other:?}"),
978        }
979    }
980
981    #[test]
982    fn spark_parse_spark_sql_empty_fails() {
983        let err = parse_spark_sql("").unwrap_err();
984        assert!(
985            err.0.contains("expected exactly one statement") || err.0.contains("parse error"),
986            "unexpected error: {}",
987            err.0
988        );
989    }
990
991    #[test]
992    fn spark_parse_spark_sql_whitespace_only_fails() {
993        let err = parse_spark_sql("   \t\n  ").unwrap_err();
994        assert!(!err.0.is_empty(), "expected some error message");
995    }
996
997    #[test]
998    fn spark_parse_spark_sql_multiple_statements_fails() {
999        let err = parse_spark_sql("SELECT 1; SELECT 2").unwrap_err();
1000        assert!(err.0.contains("expected exactly one statement"));
1001    }
1002
1003    #[test]
1004    fn spark_parse_spark_sql_fallback_select() {
1005        let s = parse_spark_sql("SELECT 1 AS x").unwrap();
1006        match s {
1007            SparkStatement::Sqlparser(stmt) if matches!(stmt.as_ref(), Statement::Query(_)) => {}
1008            other => panic!("expected Sqlparser(Query), got {other:?}"),
1009        }
1010    }
1011
1012    #[test]
1013    fn spark_parse_spark_sql_fallback_create_schema() {
1014        let s = parse_spark_sql("CREATE SCHEMA foo").unwrap();
1015        match s {
1016            SparkStatement::Sqlparser(stmt)
1017                if matches!(stmt.as_ref(), Statement::CreateSchema { .. }) => {}
1018            other => panic!("expected Sqlparser(CreateSchema), got {other:?}"),
1019        }
1020    }
1021
1022    #[test]
1023    fn spark_parse_spark_sql_fallback_drop_table() {
1024        let s = parse_spark_sql("DROP TABLE IF EXISTS t").unwrap();
1025        match s {
1026            SparkStatement::Sqlparser(stmt) if matches!(stmt.as_ref(), Statement::Drop { .. }) => {}
1027            other => panic!("expected Sqlparser(Drop), got {other:?}"),
1028        }
1029    }
1030
1031    // ========== Robust parse_select_expr tests ==========
1032
1033    #[test]
1034    fn parse_select_expr_empty_fails() {
1035        let err = parse_select_expr("").unwrap_err();
1036        assert!(err.0.contains("expected an expression"));
1037    }
1038
1039    #[test]
1040    fn parse_select_expr_whitespace_only_fails() {
1041        let err = parse_select_expr("   \n\t  ").unwrap_err();
1042        assert!(err.0.contains("expected an expression"));
1043    }
1044
1045    #[test]
1046    fn parse_select_expr_literal_number() {
1047        let (e, a) = parse_select_expr("42").unwrap();
1048        assert!(matches!(e, SqlExpr::Value(_)));
1049        assert!(a.is_none());
1050    }
1051
1052    #[test]
1053    fn parse_select_expr_literal_string() {
1054        let (e, _) = parse_select_expr("'hello'").unwrap();
1055        assert!(matches!(e, SqlExpr::Value(_)));
1056    }
1057
1058    #[test]
1059    fn parse_select_expr_literal_null() {
1060        let (e, _) = parse_select_expr("NULL").unwrap();
1061        assert!(matches!(e, SqlExpr::Value(_)));
1062    }
1063
1064    #[test]
1065    fn parse_select_expr_identifier() {
1066        let (e, _) = parse_select_expr("column_name").unwrap();
1067        assert!(matches!(e, SqlExpr::Identifier(_)));
1068    }
1069
1070    #[test]
1071    fn parse_select_expr_compound_identifier() {
1072        let (e, _) = parse_select_expr("t.id").unwrap();
1073        assert!(matches!(e, SqlExpr::CompoundIdentifier(_)));
1074    }
1075
1076    #[test]
1077    fn parse_select_expr_binary_op() {
1078        let (e, _) = parse_select_expr("a + b").unwrap();
1079        assert!(matches!(e, SqlExpr::BinaryOp { .. }));
1080    }
1081
1082    #[test]
1083    fn parse_select_expr_function_call() {
1084        let (e, a) = parse_select_expr("COUNT(*)").unwrap();
1085        assert!(matches!(e, SqlExpr::Function(_)));
1086        assert!(a.is_none());
1087    }
1088
1089    #[test]
1090    fn parse_select_expr_function_with_alias() {
1091        let (e, a) = parse_select_expr("SUM(amount) AS total").unwrap();
1092        assert!(matches!(e, SqlExpr::Function(_)));
1093        assert_eq!(a.as_ref().map(|i| i.value.as_str()), Some("total"));
1094    }
1095
1096    #[test]
1097    fn parse_select_expr_nested_function() {
1098        let (_e, a) = parse_select_expr("UPPER(TRIM(name))").unwrap();
1099        assert!(a.is_none());
1100    }
1101
1102    #[test]
1103    fn parse_select_expr_case_when() {
1104        let (e, _) = parse_select_expr("CASE WHEN x > 0 THEN 1 ELSE 0 END").unwrap();
1105        assert!(matches!(e, SqlExpr::Case { .. }));
1106    }
1107
1108    #[test]
1109    fn parse_select_expr_comparison() {
1110        let (e, _) = parse_select_expr("id = 1").unwrap();
1111        assert!(matches!(e, SqlExpr::BinaryOp { .. }));
1112    }
1113
1114    #[test]
1115    fn parse_select_expr_invalid_syntax_fails() {
1116        // Unmatched parenthesis or invalid token sequence should yield a parse error.
1117        let err = parse_select_expr("( unclosed").unwrap_err();
1118        assert!(!err.0.is_empty());
1119    }
1120
1121    // ========== CREATE OR REPLACE TABLE ... USING <format> AS SELECT tests (#1462) ==========
1122
1123    #[test]
1124    fn spark_create_or_replace_table_using_delta_as_select() {
1125        let sql = "CREATE OR REPLACE TABLE my_table USING delta AS SELECT id, name FROM source";
1126        let s = parse_spark_sql(sql).unwrap();
1127        match s {
1128            SparkStatement::CreateOrReplaceTableAs {
1129                table,
1130                format,
1131                query,
1132            } => {
1133                assert_eq!(table.to_string(), "my_table");
1134                assert_eq!(format.to_lowercase(), "delta");
1135                // Verify query is a Query
1136                assert!(matches!(query.body.as_ref(), SetExpr::Select(_)));
1137            }
1138            other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
1139        }
1140    }
1141
1142    #[test]
1143    fn spark_create_or_replace_table_qualified_name() {
1144        let sql =
1145            "CREATE OR REPLACE TABLE schema1.my_table USING parquet AS SELECT * FROM other_table";
1146        let s = parse_spark_sql(sql).unwrap();
1147        match s {
1148            SparkStatement::CreateOrReplaceTableAs {
1149                table,
1150                format,
1151                query: _,
1152            } => {
1153                assert_eq!(table.to_string(), "schema1.my_table");
1154                assert_eq!(format.to_lowercase(), "parquet");
1155            }
1156            other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
1157        }
1158    }
1159
1160    #[test]
1161    fn spark_create_or_replace_table_multiline() {
1162        let sql = r#"
1163            CREATE OR REPLACE TABLE clean_events
1164            USING delta AS
1165            SELECT user_id, name, value, '2025-01-01' AS processed_at
1166            FROM raw_events
1167        "#;
1168        let s = parse_spark_sql(sql).unwrap();
1169        match s {
1170            SparkStatement::CreateOrReplaceTableAs {
1171                table,
1172                format,
1173                query: _,
1174            } => {
1175                assert_eq!(table.to_string(), "clean_events");
1176                assert_eq!(format.to_lowercase(), "delta");
1177            }
1178            other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
1179        }
1180    }
1181
1182    #[test]
1183    fn spark_create_or_replace_table_case_insensitive() {
1184        let sql = "create or replace table T using DELTA as select 1";
1185        let s = parse_spark_sql(sql).unwrap();
1186        match s {
1187            SparkStatement::CreateOrReplaceTableAs { table, format, .. } => {
1188                assert_eq!(table.to_string(), "T");
1189                assert_eq!(format.to_uppercase(), "DELTA");
1190            }
1191            other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
1192        }
1193    }
1194
1195    #[test]
1196    fn spark_create_table_without_or_replace_falls_through() {
1197        // Regular CREATE TABLE should fall through to upstream parser
1198        let sql = "CREATE TABLE t (id INT)";
1199        let s = parse_spark_sql(sql).unwrap();
1200        match s {
1201            SparkStatement::Sqlparser(stmt) => {
1202                assert!(matches!(stmt.as_ref(), Statement::CreateTable(_)));
1203            }
1204            other => panic!("expected Sqlparser(CreateTable), got {other:?}"),
1205        }
1206    }
1207}