1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for expr in group_by {
82 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 SqlExpression::SimpleCaseExpression {
282 expr,
283 when_branches,
284 else_branch,
285 } => {
286 let mut result = format!(
287 "SimpleCaseExpression {{ expr: {}, when_branches: [",
288 format_expression_ast(expr)
289 );
290 for branch in when_branches {
291 result.push_str(&format!(
292 " {{ value: {}, result: {} }},",
293 format_expression_ast(&branch.value),
294 format_expression_ast(&branch.result)
295 ));
296 }
297 result.push_str("], else_branch: ");
298 if let Some(else_expr) = else_branch {
299 result.push_str(&format_expression_ast(else_expr));
300 } else {
301 result.push_str("None");
302 }
303 result.push_str(" }");
304 result
305 }
306 SqlExpression::ScalarSubquery { query: _ } => {
307 format!("ScalarSubquery {{ query: <SelectStatement> }}")
308 }
309 SqlExpression::InSubquery { expr, subquery: _ } => {
310 format!(
311 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
312 format_expression_ast(expr)
313 )
314 }
315 SqlExpression::NotInSubquery { expr, subquery: _ } => {
316 format!(
317 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
318 format_expression_ast(expr)
319 )
320 }
321 SqlExpression::Unnest { column, delimiter } => {
322 format!(
323 "Unnest {{ column: {}, delimiter: \"{}\" }}",
324 format_expression_ast(column),
325 delimiter
326 )
327 }
328 }
329}
330
/// Returns `text[start..end]` as an owned `String`.
///
/// An empty string is returned when the range is empty or inverted, extends
/// past the end of `text`, or would split a multi-byte UTF-8 character. The
/// last case previously panicked.
///
/// NOTE(review): callers pass positions from `Lexer::get_position`. If those
/// are character indices rather than byte offsets, queries containing
/// non-ASCII text will map to the wrong span. Confirm against the lexer.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    if start >= end {
        return String::new();
    }
    // `str::get` yields `None` instead of panicking when a bound is out of
    // range or not on a char boundary.
    text.get(start..end)
        .map(str::to_owned)
        .unwrap_or_default()
}
338
339fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
341 let mut lexer = Lexer::new(query);
342 let mut found_count = 0;
343
344 loop {
345 let pos = lexer.get_position();
346 let token = lexer.next_token();
347 if token == Token::Eof {
348 break;
349 }
350 if token == target {
351 if found_count == skip_count {
352 return Some(pos);
353 }
354 found_count += 1;
355 }
356 }
357 None
358}
359
360pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
361 let mut parser = Parser::new(query);
362 let stmt = match parser.parse() {
363 Ok(s) => s,
364 Err(_) => return vec![query.to_string()],
365 };
366
367 let mut lines = Vec::new();
368 let mut lexer = Lexer::new(query);
369 let mut tokens_with_pos = Vec::new();
370
371 loop {
373 let pos = lexer.get_position();
374 let token = lexer.next_token();
375 if token == Token::Eof {
376 break;
377 }
378 tokens_with_pos.push((token, pos));
379 }
380
381 let mut i = 0;
383 while i < tokens_with_pos.len() {
384 match &tokens_with_pos[i].0 {
385 Token::Select => {
386 let _select_start = tokens_with_pos[i].1;
387 i += 1;
388
389 let has_distinct = if i < tokens_with_pos.len() {
391 matches!(tokens_with_pos[i].0, Token::Distinct)
392 } else {
393 false
394 };
395
396 if has_distinct {
397 i += 1;
398 }
399
400 let _select_end = query.len();
402 let _col_count = 0;
403 let _current_line_cols: Vec<String> = Vec::new();
404 let mut all_select_lines = Vec::new();
405
406 let use_pretty_format = stmt.columns.len() > cols_per_line;
408
409 if use_pretty_format {
410 let select_text = if has_distinct {
412 "SELECT DISTINCT".to_string()
413 } else {
414 "SELECT".to_string()
415 };
416 all_select_lines.push(select_text);
417
418 for (idx, col) in stmt.columns.iter().enumerate() {
420 let is_last = idx == stmt.columns.len() - 1;
421 let formatted_col = if needs_quotes(col) {
423 format!("\"{}\"", col)
424 } else {
425 col.clone()
426 };
427 let col_text = if is_last {
428 format!(" {}", formatted_col)
429 } else {
430 format!(" {},", formatted_col)
431 };
432 all_select_lines.push(col_text);
433 }
434 } else {
435 let mut select_line = if has_distinct {
437 "SELECT DISTINCT ".to_string()
438 } else {
439 "SELECT ".to_string()
440 };
441
442 for (idx, col) in stmt.columns.iter().enumerate() {
443 if idx > 0 {
444 select_line.push_str(", ");
445 }
446 if needs_quotes(col) {
448 select_line.push_str(&format!("\"{}\"", col));
449 } else {
450 select_line.push_str(col);
451 }
452 }
453 all_select_lines.push(select_line);
454 }
455
456 lines.extend(all_select_lines);
457
458 while i < tokens_with_pos.len() {
460 match &tokens_with_pos[i].0 {
461 Token::From => break,
462 _ => i += 1,
463 }
464 }
465 }
466 Token::From => {
467 let from_start = tokens_with_pos[i].1;
468 i += 1;
469
470 let mut from_end = query.len();
472 while i < tokens_with_pos.len() {
473 match &tokens_with_pos[i].0 {
474 Token::Where
475 | Token::GroupBy
476 | Token::OrderBy
477 | Token::Limit
478 | Token::Having
479 | Token::Eof => {
480 from_end = tokens_with_pos[i].1;
481 break;
482 }
483 _ => i += 1,
484 }
485 }
486
487 let from_text = extract_text_between_positions(query, from_start, from_end);
488 lines.push(from_text.trim().to_string());
489 }
490 Token::Where => {
491 let where_start = tokens_with_pos[i].1;
492 i += 1;
493
494 let mut where_end = query.len();
496 let mut paren_depth = 0;
497 while i < tokens_with_pos.len() {
498 match &tokens_with_pos[i].0 {
499 Token::LeftParen => {
500 paren_depth += 1;
501 i += 1;
502 }
503 Token::RightParen => {
504 paren_depth -= 1;
505 i += 1;
506 }
507 Token::GroupBy
508 | Token::OrderBy
509 | Token::Limit
510 | Token::Having
511 | Token::Eof
512 if paren_depth == 0 =>
513 {
514 where_end = tokens_with_pos[i].1;
515 break;
516 }
517 _ => i += 1,
518 }
519 }
520
521 let where_text = extract_text_between_positions(query, where_start, where_end);
522 let formatted_where = format_where_clause_with_parens(where_text.trim());
523 lines.extend(formatted_where);
524 }
525 Token::GroupBy => {
526 let group_start = tokens_with_pos[i].1;
527 i += 1;
528
529 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
531 i += 1;
532 }
533
534 while i < tokens_with_pos.len() {
536 match &tokens_with_pos[i].0 {
537 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
538 _ => i += 1,
539 }
540 }
541
542 if i > 0 {
543 let group_text = extract_text_between_positions(
544 query,
545 group_start,
546 tokens_with_pos[i - 1].1,
547 );
548 lines.push(format!("GROUP BY {}", group_text.trim()));
549 }
550 }
551 _ => i += 1,
552 }
553 }
554
555 lines
556}
557
/// Renders WHERE-clause text as a single line that starts with an upper-case
/// `WHERE` keyword.
///
/// A leading `WHERE` keyword in any letter case is stripped before the
/// canonical one is prepended, but only when it is a whole word. So both
/// `"where a = 1"` and `"a = 1"` yield `["WHERE a = 1"]`, while
/// `"whereabouts = 1"` keeps its identifier. The rest of the text, including
/// parentheses, string literals and internal line breaks, is kept verbatim
/// apart from trimming surrounding whitespace.
///
/// Always returns exactly one line; an empty condition yields `["WHERE"]`.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";
    let text = where_text.trim();

    let condition = match text.get(..KEYWORD.len()) {
        Some(head) if head.eq_ignore_ascii_case(KEYWORD) => {
            // `head` is ASCII, so slicing at KEYWORD.len() is on a char boundary.
            let rest = &text[KEYWORD.len()..];
            if rest.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_') {
                // The "keyword" is really the start of a longer identifier.
                text
            } else {
                rest
            }
        }
        _ => text,
    };
    let condition = condition.trim();

    if condition.is_empty() {
        vec!["WHERE".to_string()]
    } else {
        vec![format!("WHERE {condition}")]
    }
}
622
623#[must_use]
624pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
625 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
627
628 formatted
630 .into_iter()
631 .filter(|line| !line.trim().is_empty())
632 .map(|line| {
633 let mut result = line;
635 for keyword in &[
636 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
637 ] {
638 let pattern = format!("{keyword}");
639 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
640 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
641 }
642 }
643 result
644 })
645 .collect()
646}
647
648pub fn format_expression(expr: &SqlExpression) -> String {
649 match expr {
650 SqlExpression::Column(column_ref) => {
651 column_ref.to_sql()
653 }
654 SqlExpression::StringLiteral(value) => format!("'{value}'"),
655 SqlExpression::NumberLiteral(value) => value.clone(),
656 SqlExpression::BinaryOp { left, op, right } => {
657 format!(
658 "{} {} {}",
659 format_expression(left),
660 op,
661 format_expression(right)
662 )
663 }
664 SqlExpression::FunctionCall {
665 name,
666 args,
667 distinct,
668 } => {
669 let args_str = args
670 .iter()
671 .map(format_expression)
672 .collect::<Vec<_>>()
673 .join(", ");
674 if *distinct {
675 format!("{name}(DISTINCT {args_str})")
676 } else {
677 format!("{name}({args_str})")
678 }
679 }
680 SqlExpression::MethodCall {
681 object,
682 method,
683 args,
684 } => {
685 let args_str = args
686 .iter()
687 .map(format_expression)
688 .collect::<Vec<_>>()
689 .join(", ");
690 if args.is_empty() {
691 format!("{object}.{method}()")
692 } else {
693 format!("{object}.{method}({args_str})")
694 }
695 }
696 SqlExpression::InList { expr, values } => {
697 let values_str = values
698 .iter()
699 .map(format_expression)
700 .collect::<Vec<_>>()
701 .join(", ");
702 format!("{} IN ({})", format_expression(expr), values_str)
703 }
704 SqlExpression::NotInList { expr, values } => {
705 let values_str = values
706 .iter()
707 .map(format_expression)
708 .collect::<Vec<_>>()
709 .join(", ");
710 format!("{} NOT IN ({})", format_expression(expr), values_str)
711 }
712 SqlExpression::Between { expr, lower, upper } => {
713 format!(
714 "{} BETWEEN {} AND {}",
715 format_expression(expr),
716 format_expression(lower),
717 format_expression(upper)
718 )
719 }
720 SqlExpression::Null => "NULL".to_string(),
721 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
722 SqlExpression::DateTimeConstructor {
723 year,
724 month,
725 day,
726 hour,
727 minute,
728 second,
729 } => {
730 let time_part = match (hour, minute, second) {
731 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
732 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
733 _ => String::new(),
734 };
735 format!("DATETIME({year}, {month}, {day}{time_part})")
736 }
737 SqlExpression::DateTimeToday {
738 hour,
739 minute,
740 second,
741 } => {
742 let time_part = match (hour, minute, second) {
743 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
744 (Some(h), Some(m), None) => format!(", {h}, {m}"),
745 (Some(h), None, None) => format!(", {h}"),
746 _ => String::new(),
747 };
748 format!("TODAY({time_part})")
749 }
750 SqlExpression::WindowFunction {
751 name,
752 args,
753 window_spec,
754 } => {
755 let args_str = args
756 .iter()
757 .map(format_expression)
758 .collect::<Vec<_>>()
759 .join(", ");
760
761 let mut result = format!("{name}({args_str}) OVER (");
762
763 if !window_spec.partition_by.is_empty() {
765 result.push_str("PARTITION BY ");
766 result.push_str(&window_spec.partition_by.join(", "));
767 }
768
769 if !window_spec.order_by.is_empty() {
771 if !window_spec.partition_by.is_empty() {
772 result.push(' ');
773 }
774 result.push_str("ORDER BY ");
775 let order_strs: Vec<String> = window_spec
776 .order_by
777 .iter()
778 .map(|col| {
779 let dir = match col.direction {
780 SortDirection::Asc => " ASC",
781 SortDirection::Desc => " DESC",
782 };
783 format!("{}{}", col.column, dir)
784 })
785 .collect();
786 result.push_str(&order_strs.join(", "));
787 }
788
789 result.push(')');
790 result
791 }
792 SqlExpression::ChainedMethodCall { base, method, args } => {
793 let base_str = format_expression(base);
794 let args_str = args
795 .iter()
796 .map(format_expression)
797 .collect::<Vec<_>>()
798 .join(", ");
799 if args.is_empty() {
800 format!("{base_str}.{method}()")
801 } else {
802 format!("{base_str}.{method}({args_str})")
803 }
804 }
805 SqlExpression::Not { expr } => {
806 format!("NOT {}", format_expression(expr))
807 }
808 SqlExpression::CaseExpression {
809 when_branches,
810 else_branch,
811 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
812 SqlExpression::SimpleCaseExpression {
813 expr,
814 when_branches,
815 else_branch,
816 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
817 SqlExpression::ScalarSubquery { query: _ } => {
818 "(SELECT ...)".to_string()
820 }
821 SqlExpression::InSubquery { expr, subquery: _ } => {
822 format!("{} IN (SELECT ...)", format_expression(expr))
823 }
824 SqlExpression::NotInSubquery { expr, subquery: _ } => {
825 format!("{} NOT IN (SELECT ...)", format_expression(expr))
826 }
827 SqlExpression::Unnest { column, delimiter } => {
828 format!("UNNEST({}, '{}')", format_expression(column), delimiter)
829 }
830 }
831}
832
833fn format_token(token: &Token) -> String {
834 match token {
835 Token::Identifier(s) => s.clone(),
836 Token::QuotedIdentifier(s) => format!("\"{s}\""),
837 Token::StringLiteral(s) => format!("'{s}'"),
838 Token::NumberLiteral(n) => n.clone(),
839 Token::DateTime => "DateTime".to_string(),
840 Token::Case => "CASE".to_string(),
841 Token::When => "WHEN".to_string(),
842 Token::Then => "THEN".to_string(),
843 Token::Else => "ELSE".to_string(),
844 Token::End => "END".to_string(),
845 Token::Distinct => "DISTINCT".to_string(),
846 Token::Over => "OVER".to_string(),
847 Token::Partition => "PARTITION".to_string(),
848 Token::By => "BY".to_string(),
849 Token::LeftParen => "(".to_string(),
850 Token::RightParen => ")".to_string(),
851 Token::Comma => ",".to_string(),
852 Token::Dot => ".".to_string(),
853 Token::Equal => "=".to_string(),
854 Token::NotEqual => "!=".to_string(),
855 Token::LessThan => "<".to_string(),
856 Token::GreaterThan => ">".to_string(),
857 Token::LessThanOrEqual => "<=".to_string(),
858 Token::GreaterThanOrEqual => ">=".to_string(),
859 Token::In => "IN".to_string(),
860 _ => format!("{token:?}").to_uppercase(),
861 }
862}
863
/// Returns `true` when `name` must be wrapped in double quotes to be emitted
/// as a SQL identifier.
///
/// Quoting is required when `name` contains any character other than ASCII
/// letters, digits and `_`, starts with a digit, or matches a reserved word
/// case-insensitively. The bare wildcard `*` is never quoted.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    // `SELECT *` must stay a wildcard; quoting it would name a column "*".
    if name == "*" {
        return false;
    }

    // Also covers '-', ' ', '.', '/' and any non-ASCII character.
    let has_special_char = !name
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_');
    let starts_with_digit = name.chars().next().map_or(false, |c| c.is_ascii_digit());
    // Any non-ASCII name is already caught above, so an ASCII case-insensitive
    // comparison is equivalent to comparing upper-cased strings.
    let is_reserved = RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name));

    has_special_char || starts_with_digit || is_reserved
}
893
894fn format_case_expression(
896 when_branches: &[crate::sql::recursive_parser::WhenBranch],
897 else_branch: Option<&SqlExpression>,
898) -> String {
899 let is_simple = when_branches.len() <= 1
901 && when_branches
902 .iter()
903 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
904 && else_branch.map_or(true, expr_is_simple);
905
906 if is_simple {
907 let mut result = String::from("CASE");
909 for branch in when_branches {
910 result.push_str(&format!(
911 " WHEN {} THEN {}",
912 format_expression(&branch.condition),
913 format_expression(&branch.result)
914 ));
915 }
916 if let Some(else_expr) = else_branch {
917 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
918 }
919 result.push_str(" END");
920 result
921 } else {
922 let mut result = String::from("CASE");
924 for branch in when_branches {
925 result.push_str(&format!(
926 "\n WHEN {} THEN {}",
927 format_expression(&branch.condition),
928 format_expression(&branch.result)
929 ));
930 }
931 if let Some(else_expr) = else_branch {
932 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
933 }
934 result.push_str("\n END");
935 result
936 }
937}
938
939fn format_simple_case_expression(
941 expr: &SqlExpression,
942 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
943 else_branch: Option<&SqlExpression>,
944) -> String {
945 let is_simple = when_branches.len() <= 2
947 && expr_is_simple(expr)
948 && when_branches
949 .iter()
950 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
951 && else_branch.map_or(true, expr_is_simple);
952
953 if is_simple {
954 let mut result = format!("CASE {}", format_expression(expr));
956 for branch in when_branches {
957 result.push_str(&format!(
958 " WHEN {} THEN {}",
959 format_expression(&branch.value),
960 format_expression(&branch.result)
961 ));
962 }
963 if let Some(else_expr) = else_branch {
964 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
965 }
966 result.push_str(" END");
967 result
968 } else {
969 let mut result = format!("CASE {}", format_expression(expr));
971 for branch in when_branches {
972 result.push_str(&format!(
973 "\n WHEN {} THEN {}",
974 format_expression(&branch.value),
975 format_expression(&branch.result)
976 ));
977 }
978 if let Some(else_expr) = else_branch {
979 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
980 }
981 result.push_str("\n END");
982 result
983 }
984}
985
986fn expr_is_simple(expr: &SqlExpression) -> bool {
988 match expr {
989 SqlExpression::Column(_)
990 | SqlExpression::StringLiteral(_)
991 | SqlExpression::NumberLiteral(_)
992 | SqlExpression::BooleanLiteral(_)
993 | SqlExpression::Null => true,
994 SqlExpression::BinaryOp { left, right, .. } => {
995 expr_is_simple(left) && expr_is_simple(right)
996 }
997 SqlExpression::FunctionCall { args, .. } => {
998 args.len() <= 2 && args.iter().all(expr_is_simple)
999 }
1000 _ => false,
1001 }
1002}