sql_cli/sql/parser/
formatter.rs

1// SQL Formatter Module
2// Handles pretty-printing of SQL queries and AST structures
3
4use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10    format_sql_pretty_compact(query, 5) // Default to 5 columns per line
11}
12
13// Pretty print AST for debug visualization
14#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16    let mut parser = Parser::new(query);
17    match parser.parse() {
18        Ok(stmt) => format_select_statement(&stmt, 0),
19        Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️  The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20    }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24    let mut result = String::new();
25    let indent_str = "  ".repeat(indent);
26
27    result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29    // Format columns
30    result.push_str(&format!("{indent_str}  columns: ["));
31    if stmt.columns.is_empty() {
32        result.push_str("],\n");
33    } else {
34        result.push('\n');
35        for col in &stmt.columns {
36            result.push_str(&format!("{indent_str}    \"{col}\",\n"));
37        }
38        result.push_str(&format!("{indent_str}  ],\n"));
39    }
40
41    // Format from table
42    if let Some(table) = &stmt.from_table {
43        result.push_str(&format!("{indent_str}  from_table: \"{table}\",\n"));
44    }
45
46    // Format where clause
47    if let Some(where_clause) = &stmt.where_clause {
48        result.push_str(&format!("{indent_str}  where_clause: {{\n"));
49        result.push_str(&format_where_clause(where_clause, indent + 2));
50        result.push_str(&format!("{indent_str}  }},\n"));
51    }
52
53    // Format order by
54    if let Some(order_by) = &stmt.order_by {
55        result.push_str(&format!("{indent_str}  order_by: ["));
56        if order_by.is_empty() {
57            result.push_str("],\n");
58        } else {
59            result.push('\n');
60            for col in order_by {
61                let dir = match col.direction {
62                    SortDirection::Asc => "ASC",
63                    SortDirection::Desc => "DESC",
64                };
65                result.push_str(&format!(
66                    "{indent_str}    {{ column: \"{}\", direction: {dir} }},\n",
67                    col.column
68                ));
69            }
70            result.push_str(&format!("{indent_str}  ],\n"));
71        }
72    }
73
74    // Format group by
75    if let Some(group_by) = &stmt.group_by {
76        result.push_str(&format!("{indent_str}  group_by: ["));
77        if group_by.is_empty() {
78            result.push_str("],\n");
79        } else {
80            result.push('\n');
81            for expr in group_by {
82                result.push_str(&format!("{indent_str}    \"{:?}\",\n", expr));
83            }
84            result.push_str(&format!("{indent_str}  ],\n"));
85        }
86    }
87
88    // Format limit
89    if let Some(limit) = stmt.limit {
90        result.push_str(&format!("{indent_str}  limit: {limit},\n"));
91    }
92
93    // Format distinct
94    if stmt.distinct {
95        result.push_str(&format!("{indent_str}  distinct: true,\n"));
96    }
97
98    result.push_str(&format!("{indent_str}}}\n"));
99    result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103    let mut result = String::new();
104    let indent_str = "  ".repeat(indent);
105
106    result.push_str(&format!("{indent_str}conditions: [\n"));
107    for (i, condition) in clause.conditions.iter().enumerate() {
108        result.push_str(&format!("{indent_str}  {{\n"));
109        result.push_str(&format!(
110            "{indent_str}    expr: {},\n",
111            format_expression_ast(&condition.expr)
112        ));
113
114        if let Some(connector) = &condition.connector {
115            let conn_str = match connector {
116                LogicalOp::And => "AND",
117                LogicalOp::Or => "OR",
118            };
119            result.push_str(&format!("{indent_str}    connector: {conn_str},\n"));
120        }
121
122        result.push_str(&format!("{indent_str}  }}"));
123        if i < clause.conditions.len() - 1 {
124            result.push(',');
125        }
126        result.push('\n');
127    }
128    result.push_str(&format!("{indent_str}]\n"));
129
130    result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134    match expr {
135        SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136        SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137        SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138        SqlExpression::BinaryOp { left, op, right } => {
139            format!(
140                "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141                format_expression_ast(left),
142                format_expression_ast(right)
143            )
144        }
145        SqlExpression::FunctionCall {
146            name,
147            args,
148            distinct,
149        } => {
150            let args_str = args
151                .iter()
152                .map(format_expression_ast)
153                .collect::<Vec<_>>()
154                .join(", ");
155            if *distinct {
156                format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157            } else {
158                format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159            }
160        }
161        SqlExpression::MethodCall {
162            object,
163            method,
164            args,
165        } => {
166            let args_str = args
167                .iter()
168                .map(format_expression_ast)
169                .collect::<Vec<_>>()
170                .join(", ");
171            format!(
172                "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173            )
174        }
175        SqlExpression::InList { expr, values } => {
176            let values_str = values
177                .iter()
178                .map(format_expression_ast)
179                .collect::<Vec<_>>()
180                .join(", ");
181            format!(
182                "InList {{ expr: {}, values: [{values_str}] }}",
183                format_expression_ast(expr)
184            )
185        }
186        SqlExpression::NotInList { expr, values } => {
187            let values_str = values
188                .iter()
189                .map(format_expression_ast)
190                .collect::<Vec<_>>()
191                .join(", ");
192            format!(
193                "NotInList {{ expr: {}, values: [{values_str}] }}",
194                format_expression_ast(expr)
195            )
196        }
197        SqlExpression::Between { expr, lower, upper } => {
198            format!(
199                "Between {{ expr: {}, lower: {}, upper: {} }}",
200                format_expression_ast(expr),
201                format_expression_ast(lower),
202                format_expression_ast(upper)
203            )
204        }
205        SqlExpression::Null => "Null".to_string(),
206        SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207        SqlExpression::DateTimeConstructor {
208            year,
209            month,
210            day,
211            hour,
212            minute,
213            second,
214        } => {
215            let time_part = match (hour, minute, second) {
216                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218                _ => String::new(),
219            };
220            format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221        }
222        SqlExpression::DateTimeToday {
223            hour,
224            minute,
225            second,
226        } => {
227            let time_part = match (hour, minute, second) {
228                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230                _ => String::new(),
231            };
232            format!("DateTimeToday({time_part})")
233        }
234        SqlExpression::WindowFunction {
235            name,
236            args,
237            window_spec: _,
238        } => {
239            let args_str = args
240                .iter()
241                .map(format_expression_ast)
242                .collect::<Vec<_>>()
243                .join(", ");
244            format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245        }
246        SqlExpression::ChainedMethodCall { base, method, args } => {
247            let args_str = args
248                .iter()
249                .map(format_expression_ast)
250                .collect::<Vec<_>>()
251                .join(", ");
252            format!(
253                "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254                format_expression_ast(base)
255            )
256        }
257        SqlExpression::Not { expr } => {
258            format!("Not {{ expr: {} }}", format_expression_ast(expr))
259        }
260        SqlExpression::CaseExpression {
261            when_branches,
262            else_branch,
263        } => {
264            let mut result = String::from("CaseExpression { when_branches: [");
265            for branch in when_branches {
266                result.push_str(&format!(
267                    " {{ condition: {}, result: {} }},",
268                    format_expression_ast(&branch.condition),
269                    format_expression_ast(&branch.result)
270                ));
271            }
272            result.push_str("], else_branch: ");
273            if let Some(else_expr) = else_branch {
274                result.push_str(&format_expression_ast(else_expr));
275            } else {
276                result.push_str("None");
277            }
278            result.push_str(" }");
279            result
280        }
281        SqlExpression::SimpleCaseExpression {
282            expr,
283            when_branches,
284            else_branch,
285        } => {
286            let mut result = format!(
287                "SimpleCaseExpression {{ expr: {}, when_branches: [",
288                format_expression_ast(expr)
289            );
290            for branch in when_branches {
291                result.push_str(&format!(
292                    " {{ value: {}, result: {} }},",
293                    format_expression_ast(&branch.value),
294                    format_expression_ast(&branch.result)
295                ));
296            }
297            result.push_str("], else_branch: ");
298            if let Some(else_expr) = else_branch {
299                result.push_str(&format_expression_ast(else_expr));
300            } else {
301                result.push_str("None");
302            }
303            result.push_str(" }");
304            result
305        }
306        SqlExpression::ScalarSubquery { query: _ } => {
307            format!("ScalarSubquery {{ query: <SelectStatement> }}")
308        }
309        SqlExpression::InSubquery { expr, subquery: _ } => {
310            format!(
311                "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
312                format_expression_ast(expr)
313            )
314        }
315        SqlExpression::NotInSubquery { expr, subquery: _ } => {
316            format!(
317                "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
318                format_expression_ast(expr)
319            )
320        }
321    }
322}
323
/// Returns the text of `text` in the byte range `start..end` as an owned
/// `String`.
///
/// An empty string is returned when the range is empty or reversed, when it
/// reaches past the end of `text`, or when either offset falls inside a
/// multi-byte UTF-8 character. Using `str::get` instead of direct slicing is
/// what keeps the last case from panicking.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    text.get(start..end).unwrap_or_default().to_string()
}
331
332// Helper to find the position of a specific token in the query
333fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
334    let mut lexer = Lexer::new(query);
335    let mut found_count = 0;
336
337    loop {
338        let pos = lexer.get_position();
339        let token = lexer.next_token();
340        if token == Token::Eof {
341            break;
342        }
343        if token == target {
344            if found_count == skip_count {
345                return Some(pos);
346            }
347            found_count += 1;
348        }
349    }
350    None
351}
352
353pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
354    let mut parser = Parser::new(query);
355    let stmt = match parser.parse() {
356        Ok(s) => s,
357        Err(_) => return vec![query.to_string()],
358    };
359
360    let mut lines = Vec::new();
361    let mut lexer = Lexer::new(query);
362    let mut tokens_with_pos = Vec::new();
363
364    // Collect all tokens with their positions
365    loop {
366        let pos = lexer.get_position();
367        let token = lexer.next_token();
368        if token == Token::Eof {
369            break;
370        }
371        tokens_with_pos.push((token, pos));
372    }
373
374    // Process SELECT clause
375    let mut i = 0;
376    while i < tokens_with_pos.len() {
377        match &tokens_with_pos[i].0 {
378            Token::Select => {
379                let _select_start = tokens_with_pos[i].1;
380                i += 1;
381
382                // Check for DISTINCT
383                let has_distinct = if i < tokens_with_pos.len() {
384                    matches!(tokens_with_pos[i].0, Token::Distinct)
385                } else {
386                    false
387                };
388
389                if has_distinct {
390                    i += 1;
391                }
392
393                // Find the end of SELECT clause (before FROM)
394                let _select_end = query.len();
395                let _col_count = 0;
396                let _current_line_cols: Vec<String> = Vec::new();
397                let mut all_select_lines = Vec::new();
398
399                // Determine if we should use pretty formatting
400                let use_pretty_format = stmt.columns.len() > cols_per_line;
401
402                if use_pretty_format {
403                    // Multi-line formatting
404                    let select_text = if has_distinct {
405                        "SELECT DISTINCT".to_string()
406                    } else {
407                        "SELECT".to_string()
408                    };
409                    all_select_lines.push(select_text);
410
411                    // Process columns with proper indentation
412                    for (idx, col) in stmt.columns.iter().enumerate() {
413                        let is_last = idx == stmt.columns.len() - 1;
414                        // Check if column needs quotes
415                        let formatted_col = if needs_quotes(col) {
416                            format!("\"{}\"", col)
417                        } else {
418                            col.clone()
419                        };
420                        let col_text = if is_last {
421                            format!("    {}", formatted_col)
422                        } else {
423                            format!("    {},", formatted_col)
424                        };
425                        all_select_lines.push(col_text);
426                    }
427                } else {
428                    // Single-line formatting for few columns
429                    let mut select_line = if has_distinct {
430                        "SELECT DISTINCT ".to_string()
431                    } else {
432                        "SELECT ".to_string()
433                    };
434
435                    for (idx, col) in stmt.columns.iter().enumerate() {
436                        if idx > 0 {
437                            select_line.push_str(", ");
438                        }
439                        // Check if column needs quotes
440                        if needs_quotes(col) {
441                            select_line.push_str(&format!("\"{}\"", col));
442                        } else {
443                            select_line.push_str(col);
444                        }
445                    }
446                    all_select_lines.push(select_line);
447                }
448
449                lines.extend(all_select_lines);
450
451                // Skip tokens until we reach FROM
452                while i < tokens_with_pos.len() {
453                    match &tokens_with_pos[i].0 {
454                        Token::From => break,
455                        _ => i += 1,
456                    }
457                }
458            }
459            Token::From => {
460                let from_start = tokens_with_pos[i].1;
461                i += 1;
462
463                // Find the end of FROM clause
464                let mut from_end = query.len();
465                while i < tokens_with_pos.len() {
466                    match &tokens_with_pos[i].0 {
467                        Token::Where
468                        | Token::GroupBy
469                        | Token::OrderBy
470                        | Token::Limit
471                        | Token::Having
472                        | Token::Eof => {
473                            from_end = tokens_with_pos[i].1;
474                            break;
475                        }
476                        _ => i += 1,
477                    }
478                }
479
480                let from_text = extract_text_between_positions(query, from_start, from_end);
481                lines.push(from_text.trim().to_string());
482            }
483            Token::Where => {
484                let where_start = tokens_with_pos[i].1;
485                i += 1;
486
487                // Find the end of WHERE clause
488                let mut where_end = query.len();
489                let mut paren_depth = 0;
490                while i < tokens_with_pos.len() {
491                    match &tokens_with_pos[i].0 {
492                        Token::LeftParen => {
493                            paren_depth += 1;
494                            i += 1;
495                        }
496                        Token::RightParen => {
497                            paren_depth -= 1;
498                            i += 1;
499                        }
500                        Token::GroupBy
501                        | Token::OrderBy
502                        | Token::Limit
503                        | Token::Having
504                        | Token::Eof
505                            if paren_depth == 0 =>
506                        {
507                            where_end = tokens_with_pos[i].1;
508                            break;
509                        }
510                        _ => i += 1,
511                    }
512                }
513
514                let where_text = extract_text_between_positions(query, where_start, where_end);
515                let formatted_where = format_where_clause_with_parens(where_text.trim());
516                lines.extend(formatted_where);
517            }
518            Token::GroupBy => {
519                let group_start = tokens_with_pos[i].1;
520                i += 1;
521
522                // Skip BY token
523                if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
524                    i += 1;
525                }
526
527                // Find the end of GROUP BY clause
528                while i < tokens_with_pos.len() {
529                    match &tokens_with_pos[i].0 {
530                        Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
531                        _ => i += 1,
532                    }
533                }
534
535                if i > 0 {
536                    let group_text = extract_text_between_positions(
537                        query,
538                        group_start,
539                        tokens_with_pos[i - 1].1,
540                    );
541                    lines.push(format!("GROUP BY {}", group_text.trim()));
542                }
543            }
544            _ => i += 1,
545        }
546    }
547
548    lines
549}
550
/// Normalizes the raw text of a WHERE clause into a single output line that
/// starts with an upper-case `WHERE`, leaving the condition text (including
/// parentheses and string literals) exactly as written.
///
/// A leading `WHERE` keyword in `where_text` is recognized in any letter case
/// and regardless of surrounding whitespace, and is replaced rather than
/// duplicated. The keyword is only stripped when it is a whole word, so a
/// condition that merely starts with those letters (e.g. `whereabouts`) is
/// kept intact. Empty input yields the bare line `WHERE`.
///
/// Always returns exactly one line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let trimmed = where_text.trim();

    // `get` returns None when the text is shorter than the keyword or when
    // the cut would split a multi-byte character, so the slicing below is safe.
    let has_keyword = trimmed
        .get(..KEYWORD.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(KEYWORD))
        && !trimmed[KEYWORD.len()..].starts_with(|c: char| c.is_alphanumeric() || c == '_');

    let condition = if has_keyword {
        trimmed[KEYWORD.len()..].trim_start()
    } else {
        trimmed
    };

    let line = if condition.is_empty() {
        KEYWORD.to_string()
    } else {
        format!("{KEYWORD} {condition}")
    };
    vec![line]
}
614
615#[must_use]
616pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
617    // Use preserved parentheses formatting
618    let formatted = format_sql_with_preserved_parens(query, cols_per_line);
619
620    // Post-process to ensure clean output
621    formatted
622        .into_iter()
623        .filter(|line| !line.trim().is_empty())
624        .map(|line| {
625            // Ensure proper spacing after keywords
626            let mut result = line;
627            for keyword in &[
628                "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
629            ] {
630                let pattern = format!("{keyword}");
631                if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
632                    result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
633                }
634            }
635            result
636        })
637        .collect()
638}
639
640pub fn format_expression(expr: &SqlExpression) -> String {
641    match expr {
642        SqlExpression::Column(column_ref) => {
643            // Use the to_sql() method which handles quoting based on quote_style
644            column_ref.to_sql()
645        }
646        SqlExpression::StringLiteral(value) => format!("'{value}'"),
647        SqlExpression::NumberLiteral(value) => value.clone(),
648        SqlExpression::BinaryOp { left, op, right } => {
649            format!(
650                "{} {} {}",
651                format_expression(left),
652                op,
653                format_expression(right)
654            )
655        }
656        SqlExpression::FunctionCall {
657            name,
658            args,
659            distinct,
660        } => {
661            let args_str = args
662                .iter()
663                .map(format_expression)
664                .collect::<Vec<_>>()
665                .join(", ");
666            if *distinct {
667                format!("{name}(DISTINCT {args_str})")
668            } else {
669                format!("{name}({args_str})")
670            }
671        }
672        SqlExpression::MethodCall {
673            object,
674            method,
675            args,
676        } => {
677            let args_str = args
678                .iter()
679                .map(format_expression)
680                .collect::<Vec<_>>()
681                .join(", ");
682            if args.is_empty() {
683                format!("{object}.{method}()")
684            } else {
685                format!("{object}.{method}({args_str})")
686            }
687        }
688        SqlExpression::InList { expr, values } => {
689            let values_str = values
690                .iter()
691                .map(format_expression)
692                .collect::<Vec<_>>()
693                .join(", ");
694            format!("{} IN ({})", format_expression(expr), values_str)
695        }
696        SqlExpression::NotInList { expr, values } => {
697            let values_str = values
698                .iter()
699                .map(format_expression)
700                .collect::<Vec<_>>()
701                .join(", ");
702            format!("{} NOT IN ({})", format_expression(expr), values_str)
703        }
704        SqlExpression::Between { expr, lower, upper } => {
705            format!(
706                "{} BETWEEN {} AND {}",
707                format_expression(expr),
708                format_expression(lower),
709                format_expression(upper)
710            )
711        }
712        SqlExpression::Null => "NULL".to_string(),
713        SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
714        SqlExpression::DateTimeConstructor {
715            year,
716            month,
717            day,
718            hour,
719            minute,
720            second,
721        } => {
722            let time_part = match (hour, minute, second) {
723                (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
724                (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
725                _ => String::new(),
726            };
727            format!("DATETIME({year}, {month}, {day}{time_part})")
728        }
729        SqlExpression::DateTimeToday {
730            hour,
731            minute,
732            second,
733        } => {
734            let time_part = match (hour, minute, second) {
735                (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
736                (Some(h), Some(m), None) => format!(", {h}, {m}"),
737                (Some(h), None, None) => format!(", {h}"),
738                _ => String::new(),
739            };
740            format!("TODAY({time_part})")
741        }
742        SqlExpression::WindowFunction {
743            name,
744            args,
745            window_spec,
746        } => {
747            let args_str = args
748                .iter()
749                .map(format_expression)
750                .collect::<Vec<_>>()
751                .join(", ");
752
753            let mut result = format!("{name}({args_str}) OVER (");
754
755            // Format partition by
756            if !window_spec.partition_by.is_empty() {
757                result.push_str("PARTITION BY ");
758                result.push_str(&window_spec.partition_by.join(", "));
759            }
760
761            // Format order by
762            if !window_spec.order_by.is_empty() {
763                if !window_spec.partition_by.is_empty() {
764                    result.push(' ');
765                }
766                result.push_str("ORDER BY ");
767                let order_strs: Vec<String> = window_spec
768                    .order_by
769                    .iter()
770                    .map(|col| {
771                        let dir = match col.direction {
772                            SortDirection::Asc => " ASC",
773                            SortDirection::Desc => " DESC",
774                        };
775                        format!("{}{}", col.column, dir)
776                    })
777                    .collect();
778                result.push_str(&order_strs.join(", "));
779            }
780
781            result.push(')');
782            result
783        }
784        SqlExpression::ChainedMethodCall { base, method, args } => {
785            let base_str = format_expression(base);
786            let args_str = args
787                .iter()
788                .map(format_expression)
789                .collect::<Vec<_>>()
790                .join(", ");
791            if args.is_empty() {
792                format!("{base_str}.{method}()")
793            } else {
794                format!("{base_str}.{method}({args_str})")
795            }
796        }
797        SqlExpression::Not { expr } => {
798            format!("NOT {}", format_expression(expr))
799        }
800        SqlExpression::CaseExpression {
801            when_branches,
802            else_branch,
803        } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
804        SqlExpression::SimpleCaseExpression {
805            expr,
806            when_branches,
807            else_branch,
808        } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
809        SqlExpression::ScalarSubquery { query: _ } => {
810            // For now, just format as a placeholder - proper SQL formatting would need the full query
811            "(SELECT ...)".to_string()
812        }
813        SqlExpression::InSubquery { expr, subquery: _ } => {
814            format!("{} IN (SELECT ...)", format_expression(expr))
815        }
816        SqlExpression::NotInSubquery { expr, subquery: _ } => {
817            format!("{} NOT IN (SELECT ...)", format_expression(expr))
818        }
819    }
820}
821
822fn format_token(token: &Token) -> String {
823    match token {
824        Token::Identifier(s) => s.clone(),
825        Token::QuotedIdentifier(s) => format!("\"{s}\""),
826        Token::StringLiteral(s) => format!("'{s}'"),
827        Token::NumberLiteral(n) => n.clone(),
828        Token::DateTime => "DateTime".to_string(),
829        Token::Case => "CASE".to_string(),
830        Token::When => "WHEN".to_string(),
831        Token::Then => "THEN".to_string(),
832        Token::Else => "ELSE".to_string(),
833        Token::End => "END".to_string(),
834        Token::Distinct => "DISTINCT".to_string(),
835        Token::Over => "OVER".to_string(),
836        Token::Partition => "PARTITION".to_string(),
837        Token::By => "BY".to_string(),
838        Token::LeftParen => "(".to_string(),
839        Token::RightParen => ")".to_string(),
840        Token::Comma => ",".to_string(),
841        Token::Dot => ".".to_string(),
842        Token::Equal => "=".to_string(),
843        Token::NotEqual => "!=".to_string(),
844        Token::LessThan => "<".to_string(),
845        Token::GreaterThan => ">".to_string(),
846        Token::LessThanOrEqual => "<=".to_string(),
847        Token::GreaterThanOrEqual => ">=".to_string(),
848        Token::In => "IN".to_string(),
849        _ => format!("{token:?}").to_uppercase(),
850    }
851}
852
/// Reports whether a column name must be wrapped in double quotes to be
/// usable in SQL text.
///
/// A name needs quotes when it
/// * starts with an ASCII digit,
/// * contains any character other than ASCII letters, digits or `_`
///   (this covers `-`, space, `.` and `/`), or
/// * equals, ignoring case, one of the common SQL reserved words listed
///   below.
///
/// The bare `*` wildcard is never quoted: quoting it would turn `SELECT *`
/// into a reference to a column literally named `*`.
fn needs_quotes(name: &str) -> bool {
    // Common SQL reserved words; the comparison is case-insensitive.
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    // The wildcard is not a column name and must stay unquoted.
    if name == "*" {
        return false;
    }

    let starts_with_digit = name.chars().next().map_or(false, |c| c.is_ascii_digit());
    // Valid unquoted identifiers: letters, numbers, underscore.
    let has_special_char = !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');
    let is_reserved = RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name));

    starts_with_digit || has_special_char || is_reserved
}
882
883// Format CASE expressions with proper indentation
884fn format_case_expression(
885    when_branches: &[crate::sql::recursive_parser::WhenBranch],
886    else_branch: Option<&SqlExpression>,
887) -> String {
888    // Check if the CASE expression is simple enough for single line
889    let is_simple = when_branches.len() <= 1
890        && when_branches
891            .iter()
892            .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
893        && else_branch.map_or(true, expr_is_simple);
894
895    if is_simple {
896        // Single line format for simple cases
897        let mut result = String::from("CASE");
898        for branch in when_branches {
899            result.push_str(&format!(
900                " WHEN {} THEN {}",
901                format_expression(&branch.condition),
902                format_expression(&branch.result)
903            ));
904        }
905        if let Some(else_expr) = else_branch {
906            result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
907        }
908        result.push_str(" END");
909        result
910    } else {
911        // Multi-line format for complex cases
912        let mut result = String::from("CASE");
913        for branch in when_branches {
914            result.push_str(&format!(
915                "\n        WHEN {} THEN {}",
916                format_expression(&branch.condition),
917                format_expression(&branch.result)
918            ));
919        }
920        if let Some(else_expr) = else_branch {
921            result.push_str(&format!("\n        ELSE {}", format_expression(else_expr)));
922        }
923        result.push_str("\n    END");
924        result
925    }
926}
927
928// Format simple CASE expressions (CASE expr WHEN val1 THEN result1 ...)
929fn format_simple_case_expression(
930    expr: &SqlExpression,
931    when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
932    else_branch: Option<&SqlExpression>,
933) -> String {
934    // Check if the CASE expression is simple enough for single line
935    let is_simple = when_branches.len() <= 2
936        && expr_is_simple(expr)
937        && when_branches
938            .iter()
939            .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
940        && else_branch.map_or(true, expr_is_simple);
941
942    if is_simple {
943        // Single line format for simple cases
944        let mut result = format!("CASE {}", format_expression(expr));
945        for branch in when_branches {
946            result.push_str(&format!(
947                " WHEN {} THEN {}",
948                format_expression(&branch.value),
949                format_expression(&branch.result)
950            ));
951        }
952        if let Some(else_expr) = else_branch {
953            result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
954        }
955        result.push_str(" END");
956        result
957    } else {
958        // Multi-line format for complex cases
959        let mut result = format!("CASE {}", format_expression(expr));
960        for branch in when_branches {
961            result.push_str(&format!(
962                "\n        WHEN {} THEN {}",
963                format_expression(&branch.value),
964                format_expression(&branch.result)
965            ));
966        }
967        if let Some(else_expr) = else_branch {
968            result.push_str(&format!("\n        ELSE {}", format_expression(else_expr)));
969        }
970        result.push_str("\n    END");
971        result
972    }
973}
974
975// Check if an expression is simple enough for single-line formatting
976fn expr_is_simple(expr: &SqlExpression) -> bool {
977    match expr {
978        SqlExpression::Column(_)
979        | SqlExpression::StringLiteral(_)
980        | SqlExpression::NumberLiteral(_)
981        | SqlExpression::BooleanLiteral(_)
982        | SqlExpression::Null => true,
983        SqlExpression::BinaryOp { left, right, .. } => {
984            expr_is_simple(left) && expr_is_simple(right)
985        }
986        SqlExpression::FunctionCall { args, .. } => {
987            args.len() <= 2 && args.iter().all(expr_is_simple)
988        }
989        _ => false,
990    }
991}