1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for expr in group_by {
82 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 SqlExpression::SimpleCaseExpression {
282 expr,
283 when_branches,
284 else_branch,
285 } => {
286 let mut result = format!(
287 "SimpleCaseExpression {{ expr: {}, when_branches: [",
288 format_expression_ast(expr)
289 );
290 for branch in when_branches {
291 result.push_str(&format!(
292 " {{ value: {}, result: {} }},",
293 format_expression_ast(&branch.value),
294 format_expression_ast(&branch.result)
295 ));
296 }
297 result.push_str("], else_branch: ");
298 if let Some(else_expr) = else_branch {
299 result.push_str(&format_expression_ast(else_expr));
300 } else {
301 result.push_str("None");
302 }
303 result.push_str(" }");
304 result
305 }
306 SqlExpression::ScalarSubquery { query: _ } => {
307 format!("ScalarSubquery {{ query: <SelectStatement> }}")
308 }
309 SqlExpression::InSubquery { expr, subquery: _ } => {
310 format!(
311 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
312 format_expression_ast(expr)
313 )
314 }
315 SqlExpression::NotInSubquery { expr, subquery: _ } => {
316 format!(
317 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
318 format_expression_ast(expr)
319 )
320 }
321 }
322}
323
/// Returns the text of `text` between byte offsets `start` (inclusive) and
/// `end` (exclusive) as an owned string.
///
/// An empty string is returned when the range is empty, reversed, out of
/// bounds, or when either offset does not fall on a UTF-8 character
/// boundary. The checked `str::get` is used instead of index slicing so that
/// an offset landing inside a multi-byte character cannot panic.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    text.get(start..end).unwrap_or("").to_string()
}
331
332fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
334 let mut lexer = Lexer::new(query);
335 let mut found_count = 0;
336
337 loop {
338 let pos = lexer.get_position();
339 let token = lexer.next_token();
340 if token == Token::Eof {
341 break;
342 }
343 if token == target {
344 if found_count == skip_count {
345 return Some(pos);
346 }
347 found_count += 1;
348 }
349 }
350 None
351}
352
353pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
354 let mut parser = Parser::new(query);
355 let stmt = match parser.parse() {
356 Ok(s) => s,
357 Err(_) => return vec![query.to_string()],
358 };
359
360 let mut lines = Vec::new();
361 let mut lexer = Lexer::new(query);
362 let mut tokens_with_pos = Vec::new();
363
364 loop {
366 let pos = lexer.get_position();
367 let token = lexer.next_token();
368 if token == Token::Eof {
369 break;
370 }
371 tokens_with_pos.push((token, pos));
372 }
373
374 let mut i = 0;
376 while i < tokens_with_pos.len() {
377 match &tokens_with_pos[i].0 {
378 Token::Select => {
379 let _select_start = tokens_with_pos[i].1;
380 i += 1;
381
382 let has_distinct = if i < tokens_with_pos.len() {
384 matches!(tokens_with_pos[i].0, Token::Distinct)
385 } else {
386 false
387 };
388
389 if has_distinct {
390 i += 1;
391 }
392
393 let _select_end = query.len();
395 let _col_count = 0;
396 let _current_line_cols: Vec<String> = Vec::new();
397 let mut all_select_lines = Vec::new();
398
399 let use_pretty_format = stmt.columns.len() > cols_per_line;
401
402 if use_pretty_format {
403 let select_text = if has_distinct {
405 "SELECT DISTINCT".to_string()
406 } else {
407 "SELECT".to_string()
408 };
409 all_select_lines.push(select_text);
410
411 for (idx, col) in stmt.columns.iter().enumerate() {
413 let is_last = idx == stmt.columns.len() - 1;
414 let formatted_col = if needs_quotes(col) {
416 format!("\"{}\"", col)
417 } else {
418 col.clone()
419 };
420 let col_text = if is_last {
421 format!(" {}", formatted_col)
422 } else {
423 format!(" {},", formatted_col)
424 };
425 all_select_lines.push(col_text);
426 }
427 } else {
428 let mut select_line = if has_distinct {
430 "SELECT DISTINCT ".to_string()
431 } else {
432 "SELECT ".to_string()
433 };
434
435 for (idx, col) in stmt.columns.iter().enumerate() {
436 if idx > 0 {
437 select_line.push_str(", ");
438 }
439 if needs_quotes(col) {
441 select_line.push_str(&format!("\"{}\"", col));
442 } else {
443 select_line.push_str(col);
444 }
445 }
446 all_select_lines.push(select_line);
447 }
448
449 lines.extend(all_select_lines);
450
451 while i < tokens_with_pos.len() {
453 match &tokens_with_pos[i].0 {
454 Token::From => break,
455 _ => i += 1,
456 }
457 }
458 }
459 Token::From => {
460 let from_start = tokens_with_pos[i].1;
461 i += 1;
462
463 let mut from_end = query.len();
465 while i < tokens_with_pos.len() {
466 match &tokens_with_pos[i].0 {
467 Token::Where
468 | Token::GroupBy
469 | Token::OrderBy
470 | Token::Limit
471 | Token::Having
472 | Token::Eof => {
473 from_end = tokens_with_pos[i].1;
474 break;
475 }
476 _ => i += 1,
477 }
478 }
479
480 let from_text = extract_text_between_positions(query, from_start, from_end);
481 lines.push(from_text.trim().to_string());
482 }
483 Token::Where => {
484 let where_start = tokens_with_pos[i].1;
485 i += 1;
486
487 let mut where_end = query.len();
489 let mut paren_depth = 0;
490 while i < tokens_with_pos.len() {
491 match &tokens_with_pos[i].0 {
492 Token::LeftParen => {
493 paren_depth += 1;
494 i += 1;
495 }
496 Token::RightParen => {
497 paren_depth -= 1;
498 i += 1;
499 }
500 Token::GroupBy
501 | Token::OrderBy
502 | Token::Limit
503 | Token::Having
504 | Token::Eof
505 if paren_depth == 0 =>
506 {
507 where_end = tokens_with_pos[i].1;
508 break;
509 }
510 _ => i += 1,
511 }
512 }
513
514 let where_text = extract_text_between_positions(query, where_start, where_end);
515 let formatted_where = format_where_clause_with_parens(where_text.trim());
516 lines.extend(formatted_where);
517 }
518 Token::GroupBy => {
519 let group_start = tokens_with_pos[i].1;
520 i += 1;
521
522 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
524 i += 1;
525 }
526
527 while i < tokens_with_pos.len() {
529 match &tokens_with_pos[i].0 {
530 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
531 _ => i += 1,
532 }
533 }
534
535 if i > 0 {
536 let group_text = extract_text_between_positions(
537 query,
538 group_start,
539 tokens_with_pos[i - 1].1,
540 );
541 lines.push(format!("GROUP BY {}", group_text.trim()));
542 }
543 }
544 _ => i += 1,
545 }
546 }
547
548 lines
549}
550
/// Normalises the text of a WHERE clause into a single `WHERE <condition>` line.
///
/// `where_text` may or may not begin with the `WHERE` keyword. The keyword is
/// recognised case-insensitively and only as a whole word, so a condition
/// starting with an identifier such as `wherever` is left intact. The
/// condition itself, including parentheses and string literals, is kept
/// exactly as written. Surrounding whitespace is trimmed.
///
/// The result always holds exactly one line; an empty condition yields the
/// bare keyword `WHERE`.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim();
    // `get` yields `None` when the text is shorter than the keyword or the
    // cut would split a multi-byte character, so slicing below cannot panic.
    let condition = match text.get(..KEYWORD.len()) {
        Some(head) if head.eq_ignore_ascii_case(KEYWORD) => {
            let rest = &text[KEYWORD.len()..];
            let at_word_boundary = rest
                .chars()
                .next()
                .map_or(true, |c| !(c.is_alphanumeric() || c == '_'));
            if at_word_boundary {
                rest.trim_start()
            } else {
                text
            }
        }
        _ => text,
    };

    let line = format!("{KEYWORD} {condition}");
    vec![line.trim_end().to_string()]
}
614
615#[must_use]
616pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
617 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
619
620 formatted
622 .into_iter()
623 .filter(|line| !line.trim().is_empty())
624 .map(|line| {
625 let mut result = line;
627 for keyword in &[
628 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
629 ] {
630 let pattern = format!("{keyword}");
631 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
632 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
633 }
634 }
635 result
636 })
637 .collect()
638}
639
640pub fn format_expression(expr: &SqlExpression) -> String {
641 match expr {
642 SqlExpression::Column(column_ref) => {
643 column_ref.to_sql()
645 }
646 SqlExpression::StringLiteral(value) => format!("'{value}'"),
647 SqlExpression::NumberLiteral(value) => value.clone(),
648 SqlExpression::BinaryOp { left, op, right } => {
649 format!(
650 "{} {} {}",
651 format_expression(left),
652 op,
653 format_expression(right)
654 )
655 }
656 SqlExpression::FunctionCall {
657 name,
658 args,
659 distinct,
660 } => {
661 let args_str = args
662 .iter()
663 .map(format_expression)
664 .collect::<Vec<_>>()
665 .join(", ");
666 if *distinct {
667 format!("{name}(DISTINCT {args_str})")
668 } else {
669 format!("{name}({args_str})")
670 }
671 }
672 SqlExpression::MethodCall {
673 object,
674 method,
675 args,
676 } => {
677 let args_str = args
678 .iter()
679 .map(format_expression)
680 .collect::<Vec<_>>()
681 .join(", ");
682 if args.is_empty() {
683 format!("{object}.{method}()")
684 } else {
685 format!("{object}.{method}({args_str})")
686 }
687 }
688 SqlExpression::InList { expr, values } => {
689 let values_str = values
690 .iter()
691 .map(format_expression)
692 .collect::<Vec<_>>()
693 .join(", ");
694 format!("{} IN ({})", format_expression(expr), values_str)
695 }
696 SqlExpression::NotInList { expr, values } => {
697 let values_str = values
698 .iter()
699 .map(format_expression)
700 .collect::<Vec<_>>()
701 .join(", ");
702 format!("{} NOT IN ({})", format_expression(expr), values_str)
703 }
704 SqlExpression::Between { expr, lower, upper } => {
705 format!(
706 "{} BETWEEN {} AND {}",
707 format_expression(expr),
708 format_expression(lower),
709 format_expression(upper)
710 )
711 }
712 SqlExpression::Null => "NULL".to_string(),
713 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
714 SqlExpression::DateTimeConstructor {
715 year,
716 month,
717 day,
718 hour,
719 minute,
720 second,
721 } => {
722 let time_part = match (hour, minute, second) {
723 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
724 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
725 _ => String::new(),
726 };
727 format!("DATETIME({year}, {month}, {day}{time_part})")
728 }
729 SqlExpression::DateTimeToday {
730 hour,
731 minute,
732 second,
733 } => {
734 let time_part = match (hour, minute, second) {
735 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
736 (Some(h), Some(m), None) => format!(", {h}, {m}"),
737 (Some(h), None, None) => format!(", {h}"),
738 _ => String::new(),
739 };
740 format!("TODAY({time_part})")
741 }
742 SqlExpression::WindowFunction {
743 name,
744 args,
745 window_spec,
746 } => {
747 let args_str = args
748 .iter()
749 .map(format_expression)
750 .collect::<Vec<_>>()
751 .join(", ");
752
753 let mut result = format!("{name}({args_str}) OVER (");
754
755 if !window_spec.partition_by.is_empty() {
757 result.push_str("PARTITION BY ");
758 result.push_str(&window_spec.partition_by.join(", "));
759 }
760
761 if !window_spec.order_by.is_empty() {
763 if !window_spec.partition_by.is_empty() {
764 result.push(' ');
765 }
766 result.push_str("ORDER BY ");
767 let order_strs: Vec<String> = window_spec
768 .order_by
769 .iter()
770 .map(|col| {
771 let dir = match col.direction {
772 SortDirection::Asc => " ASC",
773 SortDirection::Desc => " DESC",
774 };
775 format!("{}{}", col.column, dir)
776 })
777 .collect();
778 result.push_str(&order_strs.join(", "));
779 }
780
781 result.push(')');
782 result
783 }
784 SqlExpression::ChainedMethodCall { base, method, args } => {
785 let base_str = format_expression(base);
786 let args_str = args
787 .iter()
788 .map(format_expression)
789 .collect::<Vec<_>>()
790 .join(", ");
791 if args.is_empty() {
792 format!("{base_str}.{method}()")
793 } else {
794 format!("{base_str}.{method}({args_str})")
795 }
796 }
797 SqlExpression::Not { expr } => {
798 format!("NOT {}", format_expression(expr))
799 }
800 SqlExpression::CaseExpression {
801 when_branches,
802 else_branch,
803 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
804 SqlExpression::SimpleCaseExpression {
805 expr,
806 when_branches,
807 else_branch,
808 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
809 SqlExpression::ScalarSubquery { query: _ } => {
810 "(SELECT ...)".to_string()
812 }
813 SqlExpression::InSubquery { expr, subquery: _ } => {
814 format!("{} IN (SELECT ...)", format_expression(expr))
815 }
816 SqlExpression::NotInSubquery { expr, subquery: _ } => {
817 format!("{} NOT IN (SELECT ...)", format_expression(expr))
818 }
819 }
820}
821
822fn format_token(token: &Token) -> String {
823 match token {
824 Token::Identifier(s) => s.clone(),
825 Token::QuotedIdentifier(s) => format!("\"{s}\""),
826 Token::StringLiteral(s) => format!("'{s}'"),
827 Token::NumberLiteral(n) => n.clone(),
828 Token::DateTime => "DateTime".to_string(),
829 Token::Case => "CASE".to_string(),
830 Token::When => "WHEN".to_string(),
831 Token::Then => "THEN".to_string(),
832 Token::Else => "ELSE".to_string(),
833 Token::End => "END".to_string(),
834 Token::Distinct => "DISTINCT".to_string(),
835 Token::Over => "OVER".to_string(),
836 Token::Partition => "PARTITION".to_string(),
837 Token::By => "BY".to_string(),
838 Token::LeftParen => "(".to_string(),
839 Token::RightParen => ")".to_string(),
840 Token::Comma => ",".to_string(),
841 Token::Dot => ".".to_string(),
842 Token::Equal => "=".to_string(),
843 Token::NotEqual => "!=".to_string(),
844 Token::LessThan => "<".to_string(),
845 Token::GreaterThan => ">".to_string(),
846 Token::LessThanOrEqual => "<=".to_string(),
847 Token::GreaterThanOrEqual => ">=".to_string(),
848 Token::In => "IN".to_string(),
849 _ => format!("{token:?}").to_uppercase(),
850 }
851}
852
/// Reports whether a column name must be wrapped in double quotes when it is
/// written back out as SQL.
///
/// A name needs quoting when it starts with an ASCII digit, contains any
/// character other than ASCII letters, digits or `_`, or matches one of the
/// reserved words case-insensitively. The bare wildcard `*` is never quoted:
/// quoting it would turn "all columns" into a reference to a column literally
/// named `*`.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    if name == "*" {
        return false;
    }

    let starts_with_digit = name.starts_with(|c: char| c.is_ascii_digit());
    // Covers '-', ' ', '.', '/' and every other non-identifier character.
    let has_special_char = !name
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_');
    let is_reserved = RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name));

    starts_with_digit || has_special_char || is_reserved
}
882
883fn format_case_expression(
885 when_branches: &[crate::sql::recursive_parser::WhenBranch],
886 else_branch: Option<&SqlExpression>,
887) -> String {
888 let is_simple = when_branches.len() <= 1
890 && when_branches
891 .iter()
892 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
893 && else_branch.map_or(true, expr_is_simple);
894
895 if is_simple {
896 let mut result = String::from("CASE");
898 for branch in when_branches {
899 result.push_str(&format!(
900 " WHEN {} THEN {}",
901 format_expression(&branch.condition),
902 format_expression(&branch.result)
903 ));
904 }
905 if let Some(else_expr) = else_branch {
906 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
907 }
908 result.push_str(" END");
909 result
910 } else {
911 let mut result = String::from("CASE");
913 for branch in when_branches {
914 result.push_str(&format!(
915 "\n WHEN {} THEN {}",
916 format_expression(&branch.condition),
917 format_expression(&branch.result)
918 ));
919 }
920 if let Some(else_expr) = else_branch {
921 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
922 }
923 result.push_str("\n END");
924 result
925 }
926}
927
928fn format_simple_case_expression(
930 expr: &SqlExpression,
931 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
932 else_branch: Option<&SqlExpression>,
933) -> String {
934 let is_simple = when_branches.len() <= 2
936 && expr_is_simple(expr)
937 && when_branches
938 .iter()
939 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
940 && else_branch.map_or(true, expr_is_simple);
941
942 if is_simple {
943 let mut result = format!("CASE {}", format_expression(expr));
945 for branch in when_branches {
946 result.push_str(&format!(
947 " WHEN {} THEN {}",
948 format_expression(&branch.value),
949 format_expression(&branch.result)
950 ));
951 }
952 if let Some(else_expr) = else_branch {
953 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
954 }
955 result.push_str(" END");
956 result
957 } else {
958 let mut result = format!("CASE {}", format_expression(expr));
960 for branch in when_branches {
961 result.push_str(&format!(
962 "\n WHEN {} THEN {}",
963 format_expression(&branch.value),
964 format_expression(&branch.result)
965 ));
966 }
967 if let Some(else_expr) = else_branch {
968 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
969 }
970 result.push_str("\n END");
971 result
972 }
973}
974
975fn expr_is_simple(expr: &SqlExpression) -> bool {
977 match expr {
978 SqlExpression::Column(_)
979 | SqlExpression::StringLiteral(_)
980 | SqlExpression::NumberLiteral(_)
981 | SqlExpression::BooleanLiteral(_)
982 | SqlExpression::Null => true,
983 SqlExpression::BinaryOp { left, right, .. } => {
984 expr_is_simple(left) && expr_is_simple(right)
985 }
986 SqlExpression::FunctionCall { args, .. } => {
987 args.len() <= 2 && args.iter().all(expr_is_simple)
988 }
989 _ => false,
990 }
991}