sql_cli/sql/parser/
ast_formatter.rs

1//! AST-based SQL Formatter
2//!
3//! This module provides proper SQL formatting by traversing the parsed AST,
4//! which is more reliable than regex-based formatting and handles complex
5//! features like CTEs, subqueries, and expressions correctly.
6
7use crate::sql::parser::ast::*;
8use std::fmt::Write;
9
/// Configuration for SQL formatting.
///
/// Start from [`FormatConfig::default`] and override individual fields as
/// needed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormatConfig {
    /// Indentation string repeated once per nesting level
    /// (e.g., "  " for 2 spaces, "\t" for tab)
    pub indent: String,
    /// Largest number of list items (SELECT columns, etc.) kept on a single
    /// line; longer lists are written one item per line
    pub items_per_line: usize,
    /// Whether keywords are uppercased (`true`) or lowercased (`false`)
    pub uppercase_keywords: bool,
    /// Requests a compact layout (presumably fewer newlines between major
    /// clauses — TODO confirm the intended meaning).
    /// NOTE(review): no formatting code visible in this file reads this flag.
    pub compact: bool,
}
21
22impl Default for FormatConfig {
23    fn default() -> Self {
24        Self {
25            indent: "    ".to_string(),
26            items_per_line: 5,
27            uppercase_keywords: true,
28            compact: false,
29        }
30    }
31}
32
33/// Format a SELECT statement into pretty SQL
34pub fn format_select_statement(stmt: &SelectStatement) -> String {
35    format_select_with_config(stmt, &FormatConfig::default())
36}
37
38/// Format a SELECT statement with custom configuration
39pub fn format_select_with_config(stmt: &SelectStatement, config: &FormatConfig) -> String {
40    let formatter = AstFormatter::new(config);
41    formatter.format_select(stmt, 0)
42}
43
44struct AstFormatter<'a> {
45    config: &'a FormatConfig,
46}
47
48impl<'a> AstFormatter<'a> {
49    fn new(config: &'a FormatConfig) -> Self {
50        Self { config }
51    }
52
53    fn keyword(&self, word: &str) -> String {
54        if self.config.uppercase_keywords {
55            word.to_uppercase()
56        } else {
57            word.to_lowercase()
58        }
59    }
60
61    fn indent(&self, level: usize) -> String {
62        self.config.indent.repeat(level)
63    }
64
65    fn format_select(&self, stmt: &SelectStatement, indent_level: usize) -> String {
66        let mut result = String::new();
67        let indent = self.indent(indent_level);
68
69        // CTEs (WITH clause)
70        if !stmt.ctes.is_empty() {
71            writeln!(&mut result, "{}{}", indent, self.keyword("WITH")).unwrap();
72            for (i, cte) in stmt.ctes.iter().enumerate() {
73                let is_last = i == stmt.ctes.len() - 1;
74                self.format_cte(&mut result, cte, indent_level + 1, is_last);
75            }
76        }
77
78        // SELECT clause
79        write!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
80        if stmt.distinct {
81            write!(&mut result, " {}", self.keyword("DISTINCT")).unwrap();
82        }
83
84        // Format select items
85        if stmt.select_items.is_empty() && !stmt.columns.is_empty() {
86            // Legacy columns field
87            self.format_column_list(&mut result, &stmt.columns, indent_level);
88        } else {
89            self.format_select_items(&mut result, &stmt.select_items, indent_level);
90        }
91
92        // FROM clause
93        if let Some(ref table) = stmt.from_table {
94            writeln!(&mut result).unwrap();
95            write!(&mut result, "{}{} {}", indent, self.keyword("FROM"), table).unwrap();
96        } else if let Some(ref subquery) = stmt.from_subquery {
97            writeln!(&mut result).unwrap();
98            write!(&mut result, "{}{} (", indent, self.keyword("FROM")).unwrap();
99            writeln!(&mut result).unwrap();
100            let subquery_sql = self.format_select(subquery, indent_level + 1);
101            write!(&mut result, "{}", subquery_sql).unwrap();
102            write!(&mut result, "\n{}", indent).unwrap();
103            write!(&mut result, ")").unwrap();
104            if let Some(ref alias) = stmt.from_alias {
105                write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
106            }
107        } else if let Some(ref func) = stmt.from_function {
108            writeln!(&mut result).unwrap();
109            write!(&mut result, "{}{} ", indent, self.keyword("FROM")).unwrap();
110            self.format_table_function(&mut result, func);
111            if let Some(ref alias) = stmt.from_alias {
112                write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
113            }
114        }
115
116        // JOIN clauses
117        for join in &stmt.joins {
118            writeln!(&mut result).unwrap();
119            self.format_join(&mut result, join, indent_level);
120        }
121
122        // WHERE clause
123        if let Some(ref where_clause) = stmt.where_clause {
124            writeln!(&mut result).unwrap();
125            write!(&mut result, "{}{}", indent, self.keyword("WHERE")).unwrap();
126            self.format_where_clause(&mut result, where_clause, indent_level);
127        }
128
129        // GROUP BY clause
130        if let Some(ref group_by) = stmt.group_by {
131            writeln!(&mut result).unwrap();
132            write!(&mut result, "{}{} ", indent, self.keyword("GROUP BY")).unwrap();
133            for (i, expr) in group_by.iter().enumerate() {
134                if i > 0 {
135                    write!(&mut result, ", ").unwrap();
136                }
137                write!(&mut result, "{}", self.format_expression(expr)).unwrap();
138            }
139        }
140
141        // HAVING clause
142        if let Some(ref having) = stmt.having {
143            writeln!(&mut result).unwrap();
144            write!(
145                &mut result,
146                "{}{} {}",
147                indent,
148                self.keyword("HAVING"),
149                self.format_expression(having)
150            )
151            .unwrap();
152        }
153
154        // ORDER BY clause
155        if let Some(ref order_by) = stmt.order_by {
156            writeln!(&mut result).unwrap();
157            write!(&mut result, "{}{} ", indent, self.keyword("ORDER BY")).unwrap();
158            for (i, col) in order_by.iter().enumerate() {
159                if i > 0 {
160                    write!(&mut result, ", ").unwrap();
161                }
162                write!(&mut result, "{}", col.column).unwrap();
163                match col.direction {
164                    SortDirection::Asc => write!(&mut result, " {}", self.keyword("ASC")).unwrap(),
165                    SortDirection::Desc => {
166                        write!(&mut result, " {}", self.keyword("DESC")).unwrap()
167                    }
168                }
169            }
170        }
171
172        // LIMIT clause
173        if let Some(limit) = stmt.limit {
174            writeln!(&mut result).unwrap();
175            write!(&mut result, "{}{} {}", indent, self.keyword("LIMIT"), limit).unwrap();
176        }
177
178        // OFFSET clause
179        if let Some(offset) = stmt.offset {
180            writeln!(&mut result).unwrap();
181            write!(
182                &mut result,
183                "{}{} {}",
184                indent,
185                self.keyword("OFFSET"),
186                offset
187            )
188            .unwrap();
189        }
190
191        result
192    }
193
194    fn format_cte(&self, result: &mut String, cte: &CTE, indent_level: usize, is_last: bool) {
195        let indent = self.indent(indent_level);
196        write!(result, "{}{}", indent, cte.name).unwrap();
197
198        if let Some(ref columns) = cte.column_list {
199            write!(result, "(").unwrap();
200            for (i, col) in columns.iter().enumerate() {
201                if i > 0 {
202                    write!(result, ", ").unwrap();
203                }
204                write!(result, "{}", col).unwrap();
205            }
206            write!(result, ")").unwrap();
207        }
208
209        writeln!(result, " {} (", self.keyword("AS")).unwrap();
210        let cte_sql = self.format_select(&cte.query, indent_level + 1);
211        write!(result, "{}", cte_sql).unwrap();
212        writeln!(result).unwrap();
213        write!(result, "{}", indent).unwrap();
214        if is_last {
215            writeln!(result, ")").unwrap();
216        } else {
217            writeln!(result, "),").unwrap();
218        }
219    }
220
221    fn format_column_list(&self, result: &mut String, columns: &[String], indent_level: usize) {
222        if columns.len() <= self.config.items_per_line {
223            // Single line
224            write!(result, " ").unwrap();
225            for (i, col) in columns.iter().enumerate() {
226                if i > 0 {
227                    write!(result, ", ").unwrap();
228                }
229                write!(result, "{}", col).unwrap();
230            }
231        } else {
232            // Multi-line
233            writeln!(result).unwrap();
234            let indent = self.indent(indent_level + 1);
235            for (i, col) in columns.iter().enumerate() {
236                write!(result, "{}{}", indent, col).unwrap();
237                if i < columns.len() - 1 {
238                    writeln!(result, ",").unwrap();
239                }
240            }
241        }
242    }
243
244    fn format_select_items(&self, result: &mut String, items: &[SelectItem], indent_level: usize) {
245        if items.is_empty() {
246            write!(result, " *").unwrap();
247            return;
248        }
249
250        // Count non-star items for formatting decision
251        let non_star_count = items
252            .iter()
253            .filter(|i| !matches!(i, SelectItem::Star))
254            .count();
255
256        if non_star_count <= self.config.items_per_line {
257            // Single line
258            write!(result, " ").unwrap();
259            for (i, item) in items.iter().enumerate() {
260                if i > 0 {
261                    write!(result, ", ").unwrap();
262                }
263                self.format_select_item(result, item);
264            }
265        } else {
266            // Multi-line
267            writeln!(result).unwrap();
268            let indent = self.indent(indent_level + 1);
269            for (i, item) in items.iter().enumerate() {
270                write!(result, "{}", indent).unwrap();
271                self.format_select_item(result, item);
272                if i < items.len() - 1 {
273                    writeln!(result, ",").unwrap();
274                }
275            }
276        }
277    }
278
279    fn format_select_item(&self, result: &mut String, item: &SelectItem) {
280        match item {
281            SelectItem::Star => write!(result, "*").unwrap(),
282            SelectItem::Column(col) => write!(result, "{}", col).unwrap(),
283            SelectItem::Expression { expr, alias } => {
284                write!(
285                    result,
286                    "{} {} {}",
287                    self.format_expression(expr),
288                    self.keyword("AS"),
289                    alias
290                )
291                .unwrap();
292            }
293        }
294    }
295
296    fn format_expression(&self, expr: &SqlExpression) -> String {
297        match expr {
298            SqlExpression::Column(name) => name.clone(),
299            SqlExpression::StringLiteral(s) => format!("'{}'", s),
300            SqlExpression::NumberLiteral(n) => n.clone(),
301            SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
302            SqlExpression::Null => self.keyword("NULL"),
303            SqlExpression::BinaryOp { left, op, right } => {
304                format!(
305                    "{} {} {}",
306                    self.format_expression(left),
307                    op,
308                    self.format_expression(right)
309                )
310            }
311            SqlExpression::FunctionCall {
312                name,
313                args,
314                distinct,
315            } => {
316                let mut result = name.clone();
317                result.push('(');
318                if *distinct {
319                    result.push_str(&self.keyword("DISTINCT"));
320                    result.push(' ');
321                }
322                for (i, arg) in args.iter().enumerate() {
323                    if i > 0 {
324                        result.push_str(", ");
325                    }
326                    result.push_str(&self.format_expression(arg));
327                }
328                result.push(')');
329                result
330            }
331            SqlExpression::CaseExpression {
332                when_branches,
333                else_branch,
334            } => {
335                let mut result = self.keyword("CASE");
336                for branch in when_branches {
337                    result.push_str(&format!(
338                        " {} {} {} {}",
339                        self.keyword("WHEN"),
340                        self.format_expression(&branch.condition),
341                        self.keyword("THEN"),
342                        self.format_expression(&branch.result)
343                    ));
344                }
345                if let Some(else_expr) = else_branch {
346                    result.push_str(&format!(
347                        " {} {}",
348                        self.keyword("ELSE"),
349                        self.format_expression(else_expr)
350                    ));
351                }
352                result.push_str(&format!(" {}", self.keyword("END")));
353                result
354            }
355            SqlExpression::Between { expr, lower, upper } => {
356                format!(
357                    "{} {} {} {} {}",
358                    self.format_expression(expr),
359                    self.keyword("BETWEEN"),
360                    self.format_expression(lower),
361                    self.keyword("AND"),
362                    self.format_expression(upper)
363                )
364            }
365            SqlExpression::InList { expr, values } => {
366                let mut result =
367                    format!("{} {} (", self.format_expression(expr), self.keyword("IN"));
368                for (i, val) in values.iter().enumerate() {
369                    if i > 0 {
370                        result.push_str(", ");
371                    }
372                    result.push_str(&self.format_expression(val));
373                }
374                result.push(')');
375                result
376            }
377            SqlExpression::NotInList { expr, values } => {
378                let mut result = format!(
379                    "{} {} {} (",
380                    self.format_expression(expr),
381                    self.keyword("NOT"),
382                    self.keyword("IN")
383                );
384                for (i, val) in values.iter().enumerate() {
385                    if i > 0 {
386                        result.push_str(", ");
387                    }
388                    result.push_str(&self.format_expression(val));
389                }
390                result.push(')');
391                result
392            }
393            SqlExpression::Not { expr } => {
394                format!("{} {}", self.keyword("NOT"), self.format_expression(expr))
395            }
396            SqlExpression::ScalarSubquery { query } => {
397                // Check if subquery is complex enough to warrant multi-line formatting
398                let subquery_str = self.format_select(query, 0);
399                if subquery_str.contains('\n') || subquery_str.len() > 60 {
400                    // Multi-line formatting
401                    format!("(\n{}\n)", self.format_select(query, 1))
402                } else {
403                    // Inline formatting
404                    format!("({})", subquery_str)
405                }
406            }
407            SqlExpression::InSubquery { expr, subquery } => {
408                let subquery_str = self.format_select(subquery, 0);
409                if subquery_str.contains('\n') || subquery_str.len() > 60 {
410                    // Multi-line formatting
411                    format!(
412                        "{} {} (\n{}\n)",
413                        self.format_expression(expr),
414                        self.keyword("IN"),
415                        self.format_select(subquery, 1)
416                    )
417                } else {
418                    // Inline formatting
419                    format!(
420                        "{} {} ({})",
421                        self.format_expression(expr),
422                        self.keyword("IN"),
423                        subquery_str
424                    )
425                }
426            }
427            SqlExpression::NotInSubquery { expr, subquery } => {
428                let subquery_str = self.format_select(subquery, 0);
429                if subquery_str.contains('\n') || subquery_str.len() > 60 {
430                    // Multi-line formatting
431                    format!(
432                        "{} {} {} (\n{}\n)",
433                        self.format_expression(expr),
434                        self.keyword("NOT"),
435                        self.keyword("IN"),
436                        self.format_select(subquery, 1)
437                    )
438                } else {
439                    // Inline formatting
440                    format!(
441                        "{} {} {} ({})",
442                        self.format_expression(expr),
443                        self.keyword("NOT"),
444                        self.keyword("IN"),
445                        subquery_str
446                    )
447                }
448            }
449            SqlExpression::MethodCall {
450                object,
451                method,
452                args,
453            } => {
454                let mut result = format!("{}.{}", object, method);
455                result.push('(');
456                for (i, arg) in args.iter().enumerate() {
457                    if i > 0 {
458                        result.push_str(", ");
459                    }
460                    result.push_str(&self.format_expression(arg));
461                }
462                result.push(')');
463                result
464            }
465            SqlExpression::ChainedMethodCall { base, method, args } => {
466                let mut result = format!("{}.{}", self.format_expression(base), method);
467                result.push('(');
468                for (i, arg) in args.iter().enumerate() {
469                    if i > 0 {
470                        result.push_str(", ");
471                    }
472                    result.push_str(&self.format_expression(arg));
473                }
474                result.push(')');
475                result
476            }
477            _ => format!("{:?}", expr), // Fallback for unhandled expression types
478        }
479    }
480
481    fn format_where_clause(
482        &self,
483        result: &mut String,
484        where_clause: &WhereClause,
485        indent_level: usize,
486    ) {
487        let needs_multiline = where_clause.conditions.len() > 1;
488
489        if needs_multiline {
490            writeln!(result).unwrap();
491            let indent = self.indent(indent_level + 1);
492            for (i, condition) in where_clause.conditions.iter().enumerate() {
493                if i > 0 {
494                    if let Some(ref connector) = where_clause.conditions[i - 1].connector {
495                        let connector_str = match connector {
496                            LogicalOp::And => self.keyword("AND"),
497                            LogicalOp::Or => self.keyword("OR"),
498                        };
499                        writeln!(result).unwrap();
500                        write!(result, "{}{} ", indent, connector_str).unwrap();
501                    }
502                } else {
503                    write!(result, "{}", indent).unwrap();
504                }
505                write!(result, "{}", self.format_expression(&condition.expr)).unwrap();
506            }
507        } else if let Some(condition) = where_clause.conditions.first() {
508            write!(result, " {}", self.format_expression(&condition.expr)).unwrap();
509        }
510    }
511
512    fn format_join(&self, result: &mut String, join: &JoinClause, indent_level: usize) {
513        let indent = self.indent(indent_level);
514        let join_type = match join.join_type {
515            JoinType::Inner => self.keyword("INNER JOIN"),
516            JoinType::Left => self.keyword("LEFT JOIN"),
517            JoinType::Right => self.keyword("RIGHT JOIN"),
518            JoinType::Full => self.keyword("FULL JOIN"),
519            JoinType::Cross => self.keyword("CROSS JOIN"),
520        };
521
522        write!(result, "{}{} ", indent, join_type).unwrap();
523
524        match &join.table {
525            TableSource::Table(name) => write!(result, "{}", name).unwrap(),
526            TableSource::DerivedTable { query, alias } => {
527                writeln!(result, "(").unwrap();
528                let subquery_sql = self.format_select(query, indent_level + 1);
529                write!(result, "{}", subquery_sql).unwrap();
530                writeln!(result).unwrap();
531                write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
532            }
533        }
534
535        if let Some(ref alias) = join.alias {
536            write!(result, " {} {}", self.keyword("AS"), alias).unwrap();
537        }
538
539        write!(
540            result,
541            " {} {} {} {}",
542            self.keyword("ON"),
543            join.condition.left_column,
544            self.format_join_operator(&join.condition.operator),
545            join.condition.right_column
546        )
547        .unwrap();
548    }
549
550    fn format_join_operator(&self, op: &JoinOperator) -> String {
551        match op {
552            JoinOperator::Equal => "=",
553            JoinOperator::NotEqual => "!=",
554            JoinOperator::LessThan => "<",
555            JoinOperator::GreaterThan => ">",
556            JoinOperator::LessThanOrEqual => "<=",
557            JoinOperator::GreaterThanOrEqual => ">=",
558        }
559        .to_string()
560    }
561
562    fn format_table_function(&self, result: &mut String, func: &TableFunction) {
563        match func {
564            TableFunction::Range { start, end, step } => {
565                write!(result, "{}(", self.keyword("RANGE")).unwrap();
566                write!(
567                    result,
568                    "{}, {}",
569                    self.format_expression(start),
570                    self.format_expression(end)
571                )
572                .unwrap();
573                if let Some(step_expr) = step {
574                    write!(result, ", {}", self.format_expression(step_expr)).unwrap();
575                }
576                write!(result, ")").unwrap();
577            }
578        }
579    }
580}
581
582/// Parse and format SQL query using the AST
583pub fn format_sql_ast(query: &str) -> Result<String, String> {
584    use crate::sql::recursive_parser::Parser;
585
586    let mut parser = Parser::new(query);
587    match parser.parse() {
588        Ok(stmt) => Ok(format_select_statement(&stmt)),
589        Err(e) => Err(format!("Parse error: {}", e)),
590    }
591}
592
593/// Parse and format SQL with custom configuration
594pub fn format_sql_ast_with_config(query: &str, config: &FormatConfig) -> Result<String, String> {
595    use crate::sql::recursive_parser::Parser;
596
597    let mut parser = Parser::new(query);
598    match parser.parse() {
599        Ok(stmt) => Ok(format_select_with_config(&stmt, &config)),
600        Err(e) => Err(format!("Parse error: {}", e)),
601    }
602}