1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 let expr_str = match &col.expr {
67 SqlExpression::Column(col_ref) => col_ref.name.clone(),
68 _ => format!("{:?}", col.expr),
69 };
70 result.push_str(&format!(
71 "{indent_str} {{ expr: \"{}\", direction: {dir} }},\n",
72 expr_str
73 ));
74 }
75 result.push_str(&format!("{indent_str} ],\n"));
76 }
77 }
78
79 if let Some(group_by) = &stmt.group_by {
81 result.push_str(&format!("{indent_str} group_by: ["));
82 if group_by.is_empty() {
83 result.push_str("],\n");
84 } else {
85 result.push('\n');
86 for expr in group_by {
87 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
88 }
89 result.push_str(&format!("{indent_str} ],\n"));
90 }
91 }
92
93 if let Some(limit) = stmt.limit {
95 result.push_str(&format!("{indent_str} limit: {limit},\n"));
96 }
97
98 if stmt.distinct {
100 result.push_str(&format!("{indent_str} distinct: true,\n"));
101 }
102
103 result.push_str(&format!("{indent_str}}}\n"));
104 result
105}
106
107fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
108 let mut result = String::new();
109 let indent_str = " ".repeat(indent);
110
111 result.push_str(&format!("{indent_str}conditions: [\n"));
112 for (i, condition) in clause.conditions.iter().enumerate() {
113 result.push_str(&format!("{indent_str} {{\n"));
114 result.push_str(&format!(
115 "{indent_str} expr: {},\n",
116 format_expression_ast(&condition.expr)
117 ));
118
119 if let Some(connector) = &condition.connector {
120 let conn_str = match connector {
121 LogicalOp::And => "AND",
122 LogicalOp::Or => "OR",
123 };
124 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
125 }
126
127 result.push_str(&format!("{indent_str} }}"));
128 if i < clause.conditions.len() - 1 {
129 result.push(',');
130 }
131 result.push('\n');
132 }
133 result.push_str(&format!("{indent_str}]\n"));
134
135 result
136}
137
138pub fn format_expression_ast(expr: &SqlExpression) -> String {
139 match expr {
140 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
141 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
142 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
143 SqlExpression::BinaryOp { left, op, right } => {
144 format!(
145 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
146 format_expression_ast(left),
147 format_expression_ast(right)
148 )
149 }
150 SqlExpression::FunctionCall {
151 name,
152 args,
153 distinct,
154 } => {
155 let args_str = args
156 .iter()
157 .map(format_expression_ast)
158 .collect::<Vec<_>>()
159 .join(", ");
160 if *distinct {
161 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
162 } else {
163 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
164 }
165 }
166 SqlExpression::MethodCall {
167 object,
168 method,
169 args,
170 } => {
171 let args_str = args
172 .iter()
173 .map(format_expression_ast)
174 .collect::<Vec<_>>()
175 .join(", ");
176 format!(
177 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
178 )
179 }
180 SqlExpression::InList { expr, values } => {
181 let values_str = values
182 .iter()
183 .map(format_expression_ast)
184 .collect::<Vec<_>>()
185 .join(", ");
186 format!(
187 "InList {{ expr: {}, values: [{values_str}] }}",
188 format_expression_ast(expr)
189 )
190 }
191 SqlExpression::NotInList { expr, values } => {
192 let values_str = values
193 .iter()
194 .map(format_expression_ast)
195 .collect::<Vec<_>>()
196 .join(", ");
197 format!(
198 "NotInList {{ expr: {}, values: [{values_str}] }}",
199 format_expression_ast(expr)
200 )
201 }
202 SqlExpression::Between { expr, lower, upper } => {
203 format!(
204 "Between {{ expr: {}, lower: {}, upper: {} }}",
205 format_expression_ast(expr),
206 format_expression_ast(lower),
207 format_expression_ast(upper)
208 )
209 }
210 SqlExpression::Null => "Null".to_string(),
211 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
212 SqlExpression::DateTimeConstructor {
213 year,
214 month,
215 day,
216 hour,
217 minute,
218 second,
219 } => {
220 let time_part = match (hour, minute, second) {
221 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
222 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
223 _ => String::new(),
224 };
225 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
226 }
227 SqlExpression::DateTimeToday {
228 hour,
229 minute,
230 second,
231 } => {
232 let time_part = match (hour, minute, second) {
233 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
234 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
235 _ => String::new(),
236 };
237 format!("DateTimeToday({time_part})")
238 }
239 SqlExpression::WindowFunction {
240 name,
241 args,
242 window_spec: _,
243 } => {
244 let args_str = args
245 .iter()
246 .map(format_expression_ast)
247 .collect::<Vec<_>>()
248 .join(", ");
249 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
250 }
251 SqlExpression::ChainedMethodCall { base, method, args } => {
252 let args_str = args
253 .iter()
254 .map(format_expression_ast)
255 .collect::<Vec<_>>()
256 .join(", ");
257 format!(
258 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
259 format_expression_ast(base)
260 )
261 }
262 SqlExpression::Not { expr } => {
263 format!("Not {{ expr: {} }}", format_expression_ast(expr))
264 }
265 SqlExpression::CaseExpression {
266 when_branches,
267 else_branch,
268 } => {
269 let mut result = String::from("CaseExpression { when_branches: [");
270 for branch in when_branches {
271 result.push_str(&format!(
272 " {{ condition: {}, result: {} }},",
273 format_expression_ast(&branch.condition),
274 format_expression_ast(&branch.result)
275 ));
276 }
277 result.push_str("], else_branch: ");
278 if let Some(else_expr) = else_branch {
279 result.push_str(&format_expression_ast(else_expr));
280 } else {
281 result.push_str("None");
282 }
283 result.push_str(" }");
284 result
285 }
286 SqlExpression::SimpleCaseExpression {
287 expr,
288 when_branches,
289 else_branch,
290 } => {
291 let mut result = format!(
292 "SimpleCaseExpression {{ expr: {}, when_branches: [",
293 format_expression_ast(expr)
294 );
295 for branch in when_branches {
296 result.push_str(&format!(
297 " {{ value: {}, result: {} }},",
298 format_expression_ast(&branch.value),
299 format_expression_ast(&branch.result)
300 ));
301 }
302 result.push_str("], else_branch: ");
303 if let Some(else_expr) = else_branch {
304 result.push_str(&format_expression_ast(else_expr));
305 } else {
306 result.push_str("None");
307 }
308 result.push_str(" }");
309 result
310 }
311 SqlExpression::ScalarSubquery { query: _ } => {
312 format!("ScalarSubquery {{ query: <SelectStatement> }}")
313 }
314 SqlExpression::InSubquery { expr, subquery: _ } => {
315 format!(
316 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
317 format_expression_ast(expr)
318 )
319 }
320 SqlExpression::NotInSubquery { expr, subquery: _ } => {
321 format!(
322 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
323 format_expression_ast(expr)
324 )
325 }
326 SqlExpression::Unnest { column, delimiter } => {
327 format!(
328 "Unnest {{ column: {}, delimiter: \"{}\" }}",
329 format_expression_ast(column),
330 delimiter
331 )
332 }
333 }
334}
335
/// Returns `text[start..end]` as an owned `String`, treating `start` and
/// `end` as byte offsets into `text`.
///
/// Instead of panicking, an empty string is returned in these cases:
/// - the range is empty or reversed;
/// - the range extends past the end of `text`;
/// - either bound falls inside a multi-byte UTF-8 character.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    // `str::get` performs the bounds and char-boundary checks that plain
    // indexing would turn into a panic.
    text.get(start..end).map(str::to_owned).unwrap_or_default()
}
343
344fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
346 let mut lexer = Lexer::new(query);
347 let mut found_count = 0;
348
349 loop {
350 let pos = lexer.get_position();
351 let token = lexer.next_token();
352 if token == Token::Eof {
353 break;
354 }
355 if token == target {
356 if found_count == skip_count {
357 return Some(pos);
358 }
359 found_count += 1;
360 }
361 }
362 None
363}
364
365pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
366 let mut parser = Parser::new(query);
367 let stmt = match parser.parse() {
368 Ok(s) => s,
369 Err(_) => return vec![query.to_string()],
370 };
371
372 let mut lines = Vec::new();
373 let mut lexer = Lexer::new(query);
374 let mut tokens_with_pos = Vec::new();
375
376 loop {
378 let pos = lexer.get_position();
379 let token = lexer.next_token();
380 if token == Token::Eof {
381 break;
382 }
383 tokens_with_pos.push((token, pos));
384 }
385
386 let mut i = 0;
388 while i < tokens_with_pos.len() {
389 match &tokens_with_pos[i].0 {
390 Token::Select => {
391 let _select_start = tokens_with_pos[i].1;
392 i += 1;
393
394 let has_distinct = if i < tokens_with_pos.len() {
396 matches!(tokens_with_pos[i].0, Token::Distinct)
397 } else {
398 false
399 };
400
401 if has_distinct {
402 i += 1;
403 }
404
405 let _select_end = query.len();
407 let _col_count = 0;
408 let _current_line_cols: Vec<String> = Vec::new();
409 let mut all_select_lines = Vec::new();
410
411 let use_pretty_format = stmt.columns.len() > cols_per_line;
413
414 if use_pretty_format {
415 let select_text = if has_distinct {
417 "SELECT DISTINCT".to_string()
418 } else {
419 "SELECT".to_string()
420 };
421 all_select_lines.push(select_text);
422
423 for (idx, col) in stmt.columns.iter().enumerate() {
425 let is_last = idx == stmt.columns.len() - 1;
426 let formatted_col = if needs_quotes(col) {
428 format!("\"{}\"", col)
429 } else {
430 col.clone()
431 };
432 let col_text = if is_last {
433 format!(" {}", formatted_col)
434 } else {
435 format!(" {},", formatted_col)
436 };
437 all_select_lines.push(col_text);
438 }
439 } else {
440 let mut select_line = if has_distinct {
442 "SELECT DISTINCT ".to_string()
443 } else {
444 "SELECT ".to_string()
445 };
446
447 for (idx, col) in stmt.columns.iter().enumerate() {
448 if idx > 0 {
449 select_line.push_str(", ");
450 }
451 if needs_quotes(col) {
453 select_line.push_str(&format!("\"{}\"", col));
454 } else {
455 select_line.push_str(col);
456 }
457 }
458 all_select_lines.push(select_line);
459 }
460
461 lines.extend(all_select_lines);
462
463 while i < tokens_with_pos.len() {
465 match &tokens_with_pos[i].0 {
466 Token::From => break,
467 _ => i += 1,
468 }
469 }
470 }
471 Token::From => {
472 let from_start = tokens_with_pos[i].1;
473 i += 1;
474
475 let mut from_end = query.len();
477 while i < tokens_with_pos.len() {
478 match &tokens_with_pos[i].0 {
479 Token::Where
480 | Token::GroupBy
481 | Token::OrderBy
482 | Token::Limit
483 | Token::Having
484 | Token::Eof => {
485 from_end = tokens_with_pos[i].1;
486 break;
487 }
488 _ => i += 1,
489 }
490 }
491
492 let from_text = extract_text_between_positions(query, from_start, from_end);
493 lines.push(from_text.trim().to_string());
494 }
495 Token::Where => {
496 let where_start = tokens_with_pos[i].1;
497 i += 1;
498
499 let mut where_end = query.len();
501 let mut paren_depth = 0;
502 while i < tokens_with_pos.len() {
503 match &tokens_with_pos[i].0 {
504 Token::LeftParen => {
505 paren_depth += 1;
506 i += 1;
507 }
508 Token::RightParen => {
509 paren_depth -= 1;
510 i += 1;
511 }
512 Token::GroupBy
513 | Token::OrderBy
514 | Token::Limit
515 | Token::Having
516 | Token::Eof
517 if paren_depth == 0 =>
518 {
519 where_end = tokens_with_pos[i].1;
520 break;
521 }
522 _ => i += 1,
523 }
524 }
525
526 let where_text = extract_text_between_positions(query, where_start, where_end);
527 let formatted_where = format_where_clause_with_parens(where_text.trim());
528 lines.extend(formatted_where);
529 }
530 Token::GroupBy => {
531 let group_start = tokens_with_pos[i].1;
532 i += 1;
533
534 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
536 i += 1;
537 }
538
539 while i < tokens_with_pos.len() {
541 match &tokens_with_pos[i].0 {
542 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
543 _ => i += 1,
544 }
545 }
546
547 if i > 0 {
548 let group_text = extract_text_between_positions(
549 query,
550 group_start,
551 tokens_with_pos[i - 1].1,
552 );
553 lines.push(format!("GROUP BY {}", group_text.trim()));
554 }
555 }
556 _ => i += 1,
557 }
558 }
559
560 lines
561}
562
/// Normalizes a WHERE clause's source text into a single `WHERE <condition>`
/// line, leaving the condition text itself (parentheses, quoting, operators)
/// untouched.
///
/// Surrounding whitespace is trimmed. A leading `WHERE` keyword in any letter
/// case is stripped before the upper-case keyword is re-added. It is stripped
/// only when it is a whole word, so an identifier such as `whereabouts` is
/// kept. Returns a one-element vector; the condition part is omitted when
/// empty.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim();
    let body = match text.get(..KEYWORD.len()) {
        Some(prefix)
            if prefix.eq_ignore_ascii_case(KEYWORD)
                && !text[KEYWORD.len()..]
                    .starts_with(|c: char| c.is_alphanumeric() || c == '_') =>
        {
            &text[KEYWORD.len()..]
        }
        _ => text,
    };

    let line = format!("WHERE {}", body.trim());
    vec![line.trim_end().to_string()]
}
627
628#[must_use]
629pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
630 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
632
633 formatted
635 .into_iter()
636 .filter(|line| !line.trim().is_empty())
637 .map(|line| {
638 let mut result = line;
640 for keyword in &[
641 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
642 ] {
643 let pattern = format!("{keyword}");
644 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
645 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
646 }
647 }
648 result
649 })
650 .collect()
651}
652
653pub fn format_expression(expr: &SqlExpression) -> String {
654 match expr {
655 SqlExpression::Column(column_ref) => {
656 column_ref.to_sql()
658 }
659 SqlExpression::StringLiteral(value) => format!("'{value}'"),
660 SqlExpression::NumberLiteral(value) => value.clone(),
661 SqlExpression::BinaryOp { left, op, right } => {
662 format!(
663 "{} {} {}",
664 format_expression(left),
665 op,
666 format_expression(right)
667 )
668 }
669 SqlExpression::FunctionCall {
670 name,
671 args,
672 distinct,
673 } => {
674 let args_str = args
675 .iter()
676 .map(format_expression)
677 .collect::<Vec<_>>()
678 .join(", ");
679 if *distinct {
680 format!("{name}(DISTINCT {args_str})")
681 } else {
682 format!("{name}({args_str})")
683 }
684 }
685 SqlExpression::MethodCall {
686 object,
687 method,
688 args,
689 } => {
690 let args_str = args
691 .iter()
692 .map(format_expression)
693 .collect::<Vec<_>>()
694 .join(", ");
695 if args.is_empty() {
696 format!("{object}.{method}()")
697 } else {
698 format!("{object}.{method}({args_str})")
699 }
700 }
701 SqlExpression::InList { expr, values } => {
702 let values_str = values
703 .iter()
704 .map(format_expression)
705 .collect::<Vec<_>>()
706 .join(", ");
707 format!("{} IN ({})", format_expression(expr), values_str)
708 }
709 SqlExpression::NotInList { expr, values } => {
710 let values_str = values
711 .iter()
712 .map(format_expression)
713 .collect::<Vec<_>>()
714 .join(", ");
715 format!("{} NOT IN ({})", format_expression(expr), values_str)
716 }
717 SqlExpression::Between { expr, lower, upper } => {
718 format!(
719 "{} BETWEEN {} AND {}",
720 format_expression(expr),
721 format_expression(lower),
722 format_expression(upper)
723 )
724 }
725 SqlExpression::Null => "NULL".to_string(),
726 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
727 SqlExpression::DateTimeConstructor {
728 year,
729 month,
730 day,
731 hour,
732 minute,
733 second,
734 } => {
735 let time_part = match (hour, minute, second) {
736 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
737 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
738 _ => String::new(),
739 };
740 format!("DATETIME({year}, {month}, {day}{time_part})")
741 }
742 SqlExpression::DateTimeToday {
743 hour,
744 minute,
745 second,
746 } => {
747 let time_part = match (hour, minute, second) {
748 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
749 (Some(h), Some(m), None) => format!(", {h}, {m}"),
750 (Some(h), None, None) => format!(", {h}"),
751 _ => String::new(),
752 };
753 format!("TODAY({time_part})")
754 }
755 SqlExpression::WindowFunction {
756 name,
757 args,
758 window_spec,
759 } => {
760 let args_str = args
761 .iter()
762 .map(format_expression)
763 .collect::<Vec<_>>()
764 .join(", ");
765
766 let mut result = format!("{name}({args_str}) OVER (");
767
768 if !window_spec.partition_by.is_empty() {
770 result.push_str("PARTITION BY ");
771 result.push_str(&window_spec.partition_by.join(", "));
772 }
773
774 if !window_spec.order_by.is_empty() {
776 if !window_spec.partition_by.is_empty() {
777 result.push(' ');
778 }
779 result.push_str("ORDER BY ");
780 let order_strs: Vec<String> = window_spec
781 .order_by
782 .iter()
783 .map(|col| {
784 let dir = match col.direction {
785 SortDirection::Asc => " ASC",
786 SortDirection::Desc => " DESC",
787 };
788 let expr_str = match &col.expr {
790 SqlExpression::Column(col_ref) => col_ref.name.clone(),
791 _ => format_expression(&col.expr),
792 };
793 format!("{}{}", expr_str, dir)
794 })
795 .collect();
796 result.push_str(&order_strs.join(", "));
797 }
798
799 result.push(')');
800 result
801 }
802 SqlExpression::ChainedMethodCall { base, method, args } => {
803 let base_str = format_expression(base);
804 let args_str = args
805 .iter()
806 .map(format_expression)
807 .collect::<Vec<_>>()
808 .join(", ");
809 if args.is_empty() {
810 format!("{base_str}.{method}()")
811 } else {
812 format!("{base_str}.{method}({args_str})")
813 }
814 }
815 SqlExpression::Not { expr } => {
816 format!("NOT {}", format_expression(expr))
817 }
818 SqlExpression::CaseExpression {
819 when_branches,
820 else_branch,
821 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
822 SqlExpression::SimpleCaseExpression {
823 expr,
824 when_branches,
825 else_branch,
826 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
827 SqlExpression::ScalarSubquery { query: _ } => {
828 "(SELECT ...)".to_string()
830 }
831 SqlExpression::InSubquery { expr, subquery: _ } => {
832 format!("{} IN (SELECT ...)", format_expression(expr))
833 }
834 SqlExpression::NotInSubquery { expr, subquery: _ } => {
835 format!("{} NOT IN (SELECT ...)", format_expression(expr))
836 }
837 SqlExpression::Unnest { column, delimiter } => {
838 format!("UNNEST({}, '{}')", format_expression(column), delimiter)
839 }
840 }
841}
842
843fn format_token(token: &Token) -> String {
844 match token {
845 Token::Identifier(s) => s.clone(),
846 Token::QuotedIdentifier(s) => format!("\"{s}\""),
847 Token::StringLiteral(s) => format!("'{s}'"),
848 Token::NumberLiteral(n) => n.clone(),
849 Token::DateTime => "DateTime".to_string(),
850 Token::Case => "CASE".to_string(),
851 Token::When => "WHEN".to_string(),
852 Token::Then => "THEN".to_string(),
853 Token::Else => "ELSE".to_string(),
854 Token::End => "END".to_string(),
855 Token::Distinct => "DISTINCT".to_string(),
856 Token::Over => "OVER".to_string(),
857 Token::Partition => "PARTITION".to_string(),
858 Token::By => "BY".to_string(),
859 Token::LeftParen => "(".to_string(),
860 Token::RightParen => ")".to_string(),
861 Token::Comma => ",".to_string(),
862 Token::Dot => ".".to_string(),
863 Token::Equal => "=".to_string(),
864 Token::NotEqual => "!=".to_string(),
865 Token::LessThan => "<".to_string(),
866 Token::GreaterThan => ">".to_string(),
867 Token::LessThanOrEqual => "<=".to_string(),
868 Token::GreaterThanOrEqual => ">=".to_string(),
869 Token::In => "IN".to_string(),
870 _ => format!("{token:?}").to_uppercase(),
871 }
872}
873
/// Reports whether `name` must be wrapped in double quotes to be emitted as an
/// identifier.
///
/// A name needs quoting when any of these hold:
/// - it starts with an ASCII digit;
/// - it contains any character other than ASCII letters, ASCII digits and `_`
///   (this covers `-`, spaces, `.`, `/` and all non-ASCII text);
/// - it is one of a fixed set of reserved words, compared case-insensitively.
///
/// The bare wildcard `*` is never quoted: `"*"` would name a column called `*`
/// instead of selecting every column.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    if name == "*" {
        return false;
    }

    let is_plain_identifier_char = |c: char| c.is_ascii_alphanumeric() || c == '_';

    name.starts_with(|c: char| c.is_ascii_digit())
        || !name.chars().all(is_plain_identifier_char)
        // Only all-ASCII names reach this check (others are quoted above), so
        // an ASCII case-insensitive comparison is sufficient.
        || RESERVED_WORDS
            .iter()
            .any(|word| word.eq_ignore_ascii_case(name))
}
903
904fn format_case_expression(
906 when_branches: &[crate::sql::recursive_parser::WhenBranch],
907 else_branch: Option<&SqlExpression>,
908) -> String {
909 let is_simple = when_branches.len() <= 1
911 && when_branches
912 .iter()
913 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
914 && else_branch.map_or(true, expr_is_simple);
915
916 if is_simple {
917 let mut result = String::from("CASE");
919 for branch in when_branches {
920 result.push_str(&format!(
921 " WHEN {} THEN {}",
922 format_expression(&branch.condition),
923 format_expression(&branch.result)
924 ));
925 }
926 if let Some(else_expr) = else_branch {
927 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
928 }
929 result.push_str(" END");
930 result
931 } else {
932 let mut result = String::from("CASE");
934 for branch in when_branches {
935 result.push_str(&format!(
936 "\n WHEN {} THEN {}",
937 format_expression(&branch.condition),
938 format_expression(&branch.result)
939 ));
940 }
941 if let Some(else_expr) = else_branch {
942 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
943 }
944 result.push_str("\n END");
945 result
946 }
947}
948
949fn format_simple_case_expression(
951 expr: &SqlExpression,
952 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
953 else_branch: Option<&SqlExpression>,
954) -> String {
955 let is_simple = when_branches.len() <= 2
957 && expr_is_simple(expr)
958 && when_branches
959 .iter()
960 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
961 && else_branch.map_or(true, expr_is_simple);
962
963 if is_simple {
964 let mut result = format!("CASE {}", format_expression(expr));
966 for branch in when_branches {
967 result.push_str(&format!(
968 " WHEN {} THEN {}",
969 format_expression(&branch.value),
970 format_expression(&branch.result)
971 ));
972 }
973 if let Some(else_expr) = else_branch {
974 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
975 }
976 result.push_str(" END");
977 result
978 } else {
979 let mut result = format!("CASE {}", format_expression(expr));
981 for branch in when_branches {
982 result.push_str(&format!(
983 "\n WHEN {} THEN {}",
984 format_expression(&branch.value),
985 format_expression(&branch.result)
986 ));
987 }
988 if let Some(else_expr) = else_branch {
989 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
990 }
991 result.push_str("\n END");
992 result
993 }
994}
995
996fn expr_is_simple(expr: &SqlExpression) -> bool {
998 match expr {
999 SqlExpression::Column(_)
1000 | SqlExpression::StringLiteral(_)
1001 | SqlExpression::NumberLiteral(_)
1002 | SqlExpression::BooleanLiteral(_)
1003 | SqlExpression::Null => true,
1004 SqlExpression::BinaryOp { left, right, .. } => {
1005 expr_is_simple(left) && expr_is_simple(right)
1006 }
1007 SqlExpression::FunctionCall { args, .. } => {
1008 args.len() <= 2 && args.iter().all(expr_is_simple)
1009 }
1010 _ => false,
1011 }
1012}