sql_cli/sql/parser/
formatter.rs

1// SQL Formatter Module
2// Handles pretty-printing of SQL queries and AST structures
3
4use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10    format_sql_pretty_compact(query, 5) // Default to 5 columns per line
11}
12
13// Pretty print AST for debug visualization
14#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16    let mut parser = Parser::new(query);
17    match parser.parse() {
18        Ok(stmt) => format_select_statement(&stmt, 0),
19        Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️  The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20    }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24    let mut result = String::new();
25    let indent_str = "  ".repeat(indent);
26
27    result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29    // Format columns
30    result.push_str(&format!("{indent_str}  columns: ["));
31    if stmt.columns.is_empty() {
32        result.push_str("],\n");
33    } else {
34        result.push('\n');
35        for col in &stmt.columns {
36            result.push_str(&format!("{indent_str}    \"{col}\",\n"));
37        }
38        result.push_str(&format!("{indent_str}  ],\n"));
39    }
40
41    // Format from table
42    if let Some(table) = &stmt.from_table {
43        result.push_str(&format!("{indent_str}  from_table: \"{table}\",\n"));
44    }
45
46    // Format where clause
47    if let Some(where_clause) = &stmt.where_clause {
48        result.push_str(&format!("{indent_str}  where_clause: {{\n"));
49        result.push_str(&format_where_clause(where_clause, indent + 2));
50        result.push_str(&format!("{indent_str}  }},\n"));
51    }
52
53    // Format order by
54    if let Some(order_by) = &stmt.order_by {
55        result.push_str(&format!("{indent_str}  order_by: ["));
56        if order_by.is_empty() {
57            result.push_str("],\n");
58        } else {
59            result.push('\n');
60            for col in order_by {
61                let dir = match col.direction {
62                    SortDirection::Asc => "ASC",
63                    SortDirection::Desc => "DESC",
64                };
65                result.push_str(&format!(
66                    "{indent_str}    {{ column: \"{}\", direction: {dir} }},\n",
67                    col.column
68                ));
69            }
70            result.push_str(&format!("{indent_str}  ],\n"));
71        }
72    }
73
74    // Format group by
75    if let Some(group_by) = &stmt.group_by {
76        result.push_str(&format!("{indent_str}  group_by: ["));
77        if group_by.is_empty() {
78            result.push_str("],\n");
79        } else {
80            result.push('\n');
81            for col in group_by {
82                result.push_str(&format!("{indent_str}    \"{col}\",\n"));
83            }
84            result.push_str(&format!("{indent_str}  ],\n"));
85        }
86    }
87
88    // Format limit
89    if let Some(limit) = stmt.limit {
90        result.push_str(&format!("{indent_str}  limit: {limit},\n"));
91    }
92
93    // Format distinct
94    if stmt.distinct {
95        result.push_str(&format!("{indent_str}  distinct: true,\n"));
96    }
97
98    result.push_str(&format!("{indent_str}}}\n"));
99    result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103    let mut result = String::new();
104    let indent_str = "  ".repeat(indent);
105
106    result.push_str(&format!("{indent_str}conditions: [\n"));
107    for (i, condition) in clause.conditions.iter().enumerate() {
108        result.push_str(&format!("{indent_str}  {{\n"));
109        result.push_str(&format!(
110            "{indent_str}    expr: {},\n",
111            format_expression_ast(&condition.expr)
112        ));
113
114        if let Some(connector) = &condition.connector {
115            let conn_str = match connector {
116                LogicalOp::And => "AND",
117                LogicalOp::Or => "OR",
118            };
119            result.push_str(&format!("{indent_str}    connector: {conn_str},\n"));
120        }
121
122        result.push_str(&format!("{indent_str}  }}"));
123        if i < clause.conditions.len() - 1 {
124            result.push(',');
125        }
126        result.push('\n');
127    }
128    result.push_str(&format!("{indent_str}]\n"));
129
130    result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134    match expr {
135        SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136        SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137        SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138        SqlExpression::BinaryOp { left, op, right } => {
139            format!(
140                "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141                format_expression_ast(left),
142                format_expression_ast(right)
143            )
144        }
145        SqlExpression::FunctionCall {
146            name,
147            args,
148            distinct,
149        } => {
150            let args_str = args
151                .iter()
152                .map(format_expression_ast)
153                .collect::<Vec<_>>()
154                .join(", ");
155            if *distinct {
156                format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157            } else {
158                format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159            }
160        }
161        SqlExpression::MethodCall {
162            object,
163            method,
164            args,
165        } => {
166            let args_str = args
167                .iter()
168                .map(format_expression_ast)
169                .collect::<Vec<_>>()
170                .join(", ");
171            format!(
172                "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173            )
174        }
175        SqlExpression::InList { expr, values } => {
176            let values_str = values
177                .iter()
178                .map(format_expression_ast)
179                .collect::<Vec<_>>()
180                .join(", ");
181            format!(
182                "InList {{ expr: {}, values: [{values_str}] }}",
183                format_expression_ast(expr)
184            )
185        }
186        SqlExpression::NotInList { expr, values } => {
187            let values_str = values
188                .iter()
189                .map(format_expression_ast)
190                .collect::<Vec<_>>()
191                .join(", ");
192            format!(
193                "NotInList {{ expr: {}, values: [{values_str}] }}",
194                format_expression_ast(expr)
195            )
196        }
197        SqlExpression::Between { expr, lower, upper } => {
198            format!(
199                "Between {{ expr: {}, lower: {}, upper: {} }}",
200                format_expression_ast(expr),
201                format_expression_ast(lower),
202                format_expression_ast(upper)
203            )
204        }
205        SqlExpression::Null => "Null".to_string(),
206        SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207        SqlExpression::DateTimeConstructor {
208            year,
209            month,
210            day,
211            hour,
212            minute,
213            second,
214        } => {
215            let time_part = match (hour, minute, second) {
216                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218                _ => String::new(),
219            };
220            format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221        }
222        SqlExpression::DateTimeToday {
223            hour,
224            minute,
225            second,
226        } => {
227            let time_part = match (hour, minute, second) {
228                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230                _ => String::new(),
231            };
232            format!("DateTimeToday({time_part})")
233        }
234        SqlExpression::WindowFunction {
235            name,
236            args,
237            window_spec: _,
238        } => {
239            let args_str = args
240                .iter()
241                .map(format_expression_ast)
242                .collect::<Vec<_>>()
243                .join(", ");
244            format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245        }
246        SqlExpression::ChainedMethodCall { base, method, args } => {
247            let args_str = args
248                .iter()
249                .map(format_expression_ast)
250                .collect::<Vec<_>>()
251                .join(", ");
252            format!(
253                "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254                format_expression_ast(base)
255            )
256        }
257        SqlExpression::Not { expr } => {
258            format!("Not {{ expr: {} }}", format_expression_ast(expr))
259        }
260        SqlExpression::CaseExpression {
261            when_branches,
262            else_branch,
263        } => {
264            let mut result = String::from("CaseExpression { when_branches: [");
265            for branch in when_branches {
266                result.push_str(&format!(
267                    " {{ condition: {}, result: {} }},",
268                    format_expression_ast(&branch.condition),
269                    format_expression_ast(&branch.result)
270                ));
271            }
272            result.push_str("], else_branch: ");
273            if let Some(else_expr) = else_branch {
274                result.push_str(&format_expression_ast(else_expr));
275            } else {
276                result.push_str("None");
277            }
278            result.push_str(" }");
279            result
280        }
281    }
282}
283
/// Returns the byte range `start..end` of `text` as an owned string.
///
/// Yields an empty string instead of panicking in these cases:
/// - the range is empty or reversed;
/// - the range extends past the end of `text`;
/// - either bound falls inside a multi-byte UTF-8 character. Plain
///   `text[start..end]` would panic here on non-ASCII queries.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    text.get(start..end).map(str::to_owned).unwrap_or_default()
}
291
292// Helper to find the position of a specific token in the query
293fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
294    let mut lexer = Lexer::new(query);
295    let mut found_count = 0;
296
297    loop {
298        let pos = lexer.get_position();
299        let token = lexer.next_token();
300        if token == Token::Eof {
301            break;
302        }
303        if token == target {
304            if found_count == skip_count {
305                return Some(pos);
306            }
307            found_count += 1;
308        }
309    }
310    None
311}
312
313pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
314    let mut parser = Parser::new(query);
315    let stmt = match parser.parse() {
316        Ok(s) => s,
317        Err(_) => return vec![query.to_string()],
318    };
319
320    let mut lines = Vec::new();
321    let mut lexer = Lexer::new(query);
322    let mut tokens_with_pos = Vec::new();
323
324    // Collect all tokens with their positions
325    loop {
326        let pos = lexer.get_position();
327        let token = lexer.next_token();
328        if token == Token::Eof {
329            break;
330        }
331        tokens_with_pos.push((token, pos));
332    }
333
334    // Process SELECT clause
335    let mut i = 0;
336    while i < tokens_with_pos.len() {
337        match &tokens_with_pos[i].0 {
338            Token::Select => {
339                let _select_start = tokens_with_pos[i].1;
340                i += 1;
341
342                // Check for DISTINCT
343                let has_distinct = if i < tokens_with_pos.len() {
344                    matches!(tokens_with_pos[i].0, Token::Distinct)
345                } else {
346                    false
347                };
348
349                if has_distinct {
350                    i += 1;
351                }
352
353                // Find the end of SELECT clause (before FROM)
354                let _select_end = query.len();
355                let _col_count = 0;
356                let _current_line_cols: Vec<String> = Vec::new();
357                let mut all_select_lines = Vec::new();
358
359                // Determine if we should use pretty formatting
360                let use_pretty_format = stmt.columns.len() > cols_per_line;
361
362                if use_pretty_format {
363                    // Multi-line formatting
364                    let select_text = if has_distinct {
365                        "SELECT DISTINCT".to_string()
366                    } else {
367                        "SELECT".to_string()
368                    };
369                    all_select_lines.push(select_text);
370
371                    // Process columns with proper indentation
372                    for (idx, col) in stmt.columns.iter().enumerate() {
373                        let is_last = idx == stmt.columns.len() - 1;
374                        let col_text = if is_last {
375                            format!("    {col}")
376                        } else {
377                            format!("    {col},")
378                        };
379                        all_select_lines.push(col_text);
380                    }
381                } else {
382                    // Single-line formatting for few columns
383                    let mut select_line = if has_distinct {
384                        "SELECT DISTINCT ".to_string()
385                    } else {
386                        "SELECT ".to_string()
387                    };
388
389                    for (idx, col) in stmt.columns.iter().enumerate() {
390                        if idx > 0 {
391                            select_line.push_str(", ");
392                        }
393                        select_line.push_str(col);
394                    }
395                    all_select_lines.push(select_line);
396                }
397
398                lines.extend(all_select_lines);
399
400                // Skip tokens until we reach FROM
401                while i < tokens_with_pos.len() {
402                    match &tokens_with_pos[i].0 {
403                        Token::From => break,
404                        _ => i += 1,
405                    }
406                }
407            }
408            Token::From => {
409                let from_start = tokens_with_pos[i].1;
410                i += 1;
411
412                // Find the end of FROM clause
413                let mut from_end = query.len();
414                while i < tokens_with_pos.len() {
415                    match &tokens_with_pos[i].0 {
416                        Token::Where
417                        | Token::GroupBy
418                        | Token::OrderBy
419                        | Token::Limit
420                        | Token::Having
421                        | Token::Eof => {
422                            from_end = tokens_with_pos[i].1;
423                            break;
424                        }
425                        _ => i += 1,
426                    }
427                }
428
429                let from_text = extract_text_between_positions(query, from_start, from_end);
430                lines.push(from_text.trim().to_string());
431            }
432            Token::Where => {
433                let where_start = tokens_with_pos[i].1;
434                i += 1;
435
436                // Find the end of WHERE clause
437                let mut where_end = query.len();
438                let mut paren_depth = 0;
439                while i < tokens_with_pos.len() {
440                    match &tokens_with_pos[i].0 {
441                        Token::LeftParen => {
442                            paren_depth += 1;
443                            i += 1;
444                        }
445                        Token::RightParen => {
446                            paren_depth -= 1;
447                            i += 1;
448                        }
449                        Token::GroupBy
450                        | Token::OrderBy
451                        | Token::Limit
452                        | Token::Having
453                        | Token::Eof
454                            if paren_depth == 0 =>
455                        {
456                            where_end = tokens_with_pos[i].1;
457                            break;
458                        }
459                        _ => i += 1,
460                    }
461                }
462
463                let where_text = extract_text_between_positions(query, where_start, where_end);
464                let formatted_where = format_where_clause_with_parens(where_text.trim());
465                lines.extend(formatted_where);
466            }
467            Token::GroupBy => {
468                let group_start = tokens_with_pos[i].1;
469                i += 1;
470
471                // Skip BY token
472                if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
473                    i += 1;
474                }
475
476                // Find the end of GROUP BY clause
477                while i < tokens_with_pos.len() {
478                    match &tokens_with_pos[i].0 {
479                        Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
480                        _ => i += 1,
481                    }
482                }
483
484                if i > 0 {
485                    let group_text = extract_text_between_positions(
486                        query,
487                        group_start,
488                        tokens_with_pos[i - 1].1,
489                    );
490                    lines.push(format!("GROUP BY {}", group_text.trim()));
491                }
492            }
493            _ => i += 1,
494        }
495    }
496
497    lines
498}
499
/// Normalises a WHERE clause to a single `WHERE <condition>` line.
///
/// A leading `WHERE` keyword is replaced by the canonical upper-case keyword.
/// It is recognised in any letter case, ignoring surrounding whitespace, but
/// only as a whole word. Text without the keyword is treated as the bare
/// condition. The condition text, including parentheses and string literals,
/// is kept verbatim apart from trimming its ends. The returned vector always
/// holds exactly one line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim();
    let condition = match text.get(..KEYWORD.len()) {
        Some(head) if head.eq_ignore_ascii_case(KEYWORD) => {
            let rest = &text[KEYWORD.len()..];
            // Text like `whereabouts = 1` starts with the letters but not with the keyword.
            let glued_to_identifier =
                rest.starts_with(|c: char| c.is_alphanumeric() || c == '_');
            if glued_to_identifier {
                text
            } else {
                rest.trim_start()
            }
        }
        _ => text,
    };

    let line = format!("{KEYWORD} {condition}");
    vec![line.trim_end().to_string()]
}
563
564#[must_use]
565pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
566    // Use preserved parentheses formatting
567    let formatted = format_sql_with_preserved_parens(query, cols_per_line);
568
569    // Post-process to ensure clean output
570    formatted
571        .into_iter()
572        .filter(|line| !line.trim().is_empty())
573        .map(|line| {
574            // Ensure proper spacing after keywords
575            let mut result = line;
576            for keyword in &[
577                "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
578            ] {
579                let pattern = format!("{keyword}");
580                if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
581                    result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
582                }
583            }
584            result
585        })
586        .collect()
587}
588
589pub fn format_expression(expr: &SqlExpression) -> String {
590    match expr {
591        SqlExpression::Column(name) => name.clone(),
592        SqlExpression::StringLiteral(value) => format!("'{value}'"),
593        SqlExpression::NumberLiteral(value) => value.clone(),
594        SqlExpression::BinaryOp { left, op, right } => {
595            format!(
596                "{} {} {}",
597                format_expression(left),
598                op,
599                format_expression(right)
600            )
601        }
602        SqlExpression::FunctionCall {
603            name,
604            args,
605            distinct,
606        } => {
607            let args_str = args
608                .iter()
609                .map(format_expression)
610                .collect::<Vec<_>>()
611                .join(", ");
612            if *distinct {
613                format!("{name}(DISTINCT {args_str})")
614            } else {
615                format!("{name}({args_str})")
616            }
617        }
618        SqlExpression::MethodCall {
619            object,
620            method,
621            args,
622        } => {
623            let args_str = args
624                .iter()
625                .map(format_expression)
626                .collect::<Vec<_>>()
627                .join(", ");
628            if args.is_empty() {
629                format!("{object}.{method}()")
630            } else {
631                format!("{object}.{method}({args_str})")
632            }
633        }
634        SqlExpression::InList { expr, values } => {
635            let values_str = values
636                .iter()
637                .map(format_expression)
638                .collect::<Vec<_>>()
639                .join(", ");
640            format!("{} IN ({})", format_expression(expr), values_str)
641        }
642        SqlExpression::NotInList { expr, values } => {
643            let values_str = values
644                .iter()
645                .map(format_expression)
646                .collect::<Vec<_>>()
647                .join(", ");
648            format!("{} NOT IN ({})", format_expression(expr), values_str)
649        }
650        SqlExpression::Between { expr, lower, upper } => {
651            format!(
652                "{} BETWEEN {} AND {}",
653                format_expression(expr),
654                format_expression(lower),
655                format_expression(upper)
656            )
657        }
658        SqlExpression::Null => "NULL".to_string(),
659        SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
660        SqlExpression::DateTimeConstructor {
661            year,
662            month,
663            day,
664            hour,
665            minute,
666            second,
667        } => {
668            let time_part = match (hour, minute, second) {
669                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
670                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
671                _ => String::new(),
672            };
673            format!("DATETIME({year}, {month}, {day}{time_part})")
674        }
675        SqlExpression::DateTimeToday {
676            hour,
677            minute,
678            second,
679        } => {
680            let time_part = match (hour, minute, second) {
681                (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
682                (Some(h), Some(m), None) => format!(", {h}, {m}"),
683                (Some(h), None, None) => format!(", {h}"),
684                _ => String::new(),
685            };
686            format!("TODAY({time_part})")
687        }
688        SqlExpression::WindowFunction {
689            name,
690            args,
691            window_spec,
692        } => {
693            let args_str = args
694                .iter()
695                .map(format_expression)
696                .collect::<Vec<_>>()
697                .join(", ");
698
699            let mut result = format!("{name}({args_str}) OVER (");
700
701            // Format partition by
702            if !window_spec.partition_by.is_empty() {
703                result.push_str("PARTITION BY ");
704                result.push_str(&window_spec.partition_by.join(", "));
705            }
706
707            // Format order by
708            if !window_spec.order_by.is_empty() {
709                if !window_spec.partition_by.is_empty() {
710                    result.push(' ');
711                }
712                result.push_str("ORDER BY ");
713                let order_strs: Vec<String> = window_spec
714                    .order_by
715                    .iter()
716                    .map(|col| {
717                        let dir = match col.direction {
718                            SortDirection::Asc => " ASC",
719                            SortDirection::Desc => " DESC",
720                        };
721                        format!("{}{}", col.column, dir)
722                    })
723                    .collect();
724                result.push_str(&order_strs.join(", "));
725            }
726
727            result.push(')');
728            result
729        }
730        SqlExpression::ChainedMethodCall { base, method, args } => {
731            let base_str = format_expression(base);
732            let args_str = args
733                .iter()
734                .map(format_expression)
735                .collect::<Vec<_>>()
736                .join(", ");
737            if args.is_empty() {
738                format!("{base_str}.{method}()")
739            } else {
740                format!("{base_str}.{method}({args_str})")
741            }
742        }
743        SqlExpression::Not { expr } => {
744            format!("NOT {}", format_expression(expr))
745        }
746        SqlExpression::CaseExpression {
747            when_branches,
748            else_branch,
749        } => {
750            let mut result = String::from("CASE");
751            for branch in when_branches {
752                result.push_str(&format!(
753                    " WHEN {} THEN {}",
754                    format_expression(&branch.condition),
755                    format_expression(&branch.result)
756                ));
757            }
758            if let Some(else_expr) = else_branch {
759                result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
760            }
761            result.push_str(" END");
762            result
763        }
764    }
765}
766
767fn format_token(token: &Token) -> String {
768    match token {
769        Token::Identifier(s) => s.clone(),
770        Token::QuotedIdentifier(s) => format!("\"{s}\""),
771        Token::StringLiteral(s) => format!("'{s}'"),
772        Token::NumberLiteral(n) => n.clone(),
773        Token::DateTime => "DateTime".to_string(),
774        Token::Case => "CASE".to_string(),
775        Token::When => "WHEN".to_string(),
776        Token::Then => "THEN".to_string(),
777        Token::Else => "ELSE".to_string(),
778        Token::End => "END".to_string(),
779        Token::Distinct => "DISTINCT".to_string(),
780        Token::Over => "OVER".to_string(),
781        Token::Partition => "PARTITION".to_string(),
782        Token::By => "BY".to_string(),
783        Token::LeftParen => "(".to_string(),
784        Token::RightParen => ")".to_string(),
785        Token::Comma => ",".to_string(),
786        Token::Dot => ".".to_string(),
787        Token::Equal => "=".to_string(),
788        Token::NotEqual => "!=".to_string(),
789        Token::LessThan => "<".to_string(),
790        Token::GreaterThan => ">".to_string(),
791        Token::LessThanOrEqual => "<=".to_string(),
792        Token::GreaterThanOrEqual => ">=".to_string(),
793        Token::In => "IN".to_string(),
794        _ => format!("{token:?}").to_uppercase(),
795    }
796}