1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for expr in group_by {
82 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 SqlExpression::SimpleCaseExpression {
282 expr,
283 when_branches,
284 else_branch,
285 } => {
286 let mut result = format!(
287 "SimpleCaseExpression {{ expr: {}, when_branches: [",
288 format_expression_ast(expr)
289 );
290 for branch in when_branches {
291 result.push_str(&format!(
292 " {{ value: {}, result: {} }},",
293 format_expression_ast(&branch.value),
294 format_expression_ast(&branch.result)
295 ));
296 }
297 result.push_str("], else_branch: ");
298 if let Some(else_expr) = else_branch {
299 result.push_str(&format_expression_ast(else_expr));
300 } else {
301 result.push_str("None");
302 }
303 result.push_str(" }");
304 result
305 }
306 SqlExpression::ScalarSubquery { query: _ } => {
307 format!("ScalarSubquery {{ query: <SelectStatement> }}")
308 }
309 SqlExpression::InSubquery { expr, subquery: _ } => {
310 format!(
311 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
312 format_expression_ast(expr)
313 )
314 }
315 SqlExpression::NotInSubquery { expr, subquery: _ } => {
316 format!(
317 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
318 format_expression_ast(expr)
319 )
320 }
321 }
322}
323
/// Returns an owned copy of `text[start..end]`, treating `start`/`end` as byte offsets.
///
/// Returns an empty string in these cases:
/// - the range is empty or reversed (`start >= end`);
/// - the range runs past the end of `text`;
/// - either offset falls inside a multi-byte UTF-8 character.
///
/// Plain slicing would panic in the last case. `str::get` reports it as `None` instead.
///
/// NOTE(review): callers pass `Lexer::get_position()` values. Confirm these are byte offsets
/// rather than char indices; with char indices, non-ASCII queries would slice the wrong text.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    text.get(start..end).unwrap_or("").to_string()
}
331
332fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
334 let mut lexer = Lexer::new(query);
335 let mut found_count = 0;
336
337 loop {
338 let pos = lexer.get_position();
339 let token = lexer.next_token();
340 if token == Token::Eof {
341 break;
342 }
343 if token == target {
344 if found_count == skip_count {
345 return Some(pos);
346 }
347 found_count += 1;
348 }
349 }
350 None
351}
352
353pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
354 let mut parser = Parser::new(query);
355 let stmt = match parser.parse() {
356 Ok(s) => s,
357 Err(_) => return vec![query.to_string()],
358 };
359
360 let mut lines = Vec::new();
361 let mut lexer = Lexer::new(query);
362 let mut tokens_with_pos = Vec::new();
363
364 loop {
366 let pos = lexer.get_position();
367 let token = lexer.next_token();
368 if token == Token::Eof {
369 break;
370 }
371 tokens_with_pos.push((token, pos));
372 }
373
374 let mut i = 0;
376 while i < tokens_with_pos.len() {
377 match &tokens_with_pos[i].0 {
378 Token::Select => {
379 let _select_start = tokens_with_pos[i].1;
380 i += 1;
381
382 let has_distinct = if i < tokens_with_pos.len() {
384 matches!(tokens_with_pos[i].0, Token::Distinct)
385 } else {
386 false
387 };
388
389 if has_distinct {
390 i += 1;
391 }
392
393 let _select_end = query.len();
395 let _col_count = 0;
396 let _current_line_cols: Vec<String> = Vec::new();
397 let mut all_select_lines = Vec::new();
398
399 let use_pretty_format = stmt.columns.len() > cols_per_line;
401
402 if use_pretty_format {
403 let select_text = if has_distinct {
405 "SELECT DISTINCT".to_string()
406 } else {
407 "SELECT".to_string()
408 };
409 all_select_lines.push(select_text);
410
411 for (idx, col) in stmt.columns.iter().enumerate() {
413 let is_last = idx == stmt.columns.len() - 1;
414 let formatted_col = if needs_quotes(col) {
416 format!("\"{}\"", col)
417 } else {
418 col.clone()
419 };
420 let col_text = if is_last {
421 format!(" {}", formatted_col)
422 } else {
423 format!(" {},", formatted_col)
424 };
425 all_select_lines.push(col_text);
426 }
427 } else {
428 let mut select_line = if has_distinct {
430 "SELECT DISTINCT ".to_string()
431 } else {
432 "SELECT ".to_string()
433 };
434
435 for (idx, col) in stmt.columns.iter().enumerate() {
436 if idx > 0 {
437 select_line.push_str(", ");
438 }
439 if needs_quotes(col) {
441 select_line.push_str(&format!("\"{}\"", col));
442 } else {
443 select_line.push_str(col);
444 }
445 }
446 all_select_lines.push(select_line);
447 }
448
449 lines.extend(all_select_lines);
450
451 while i < tokens_with_pos.len() {
453 match &tokens_with_pos[i].0 {
454 Token::From => break,
455 _ => i += 1,
456 }
457 }
458 }
459 Token::From => {
460 let from_start = tokens_with_pos[i].1;
461 i += 1;
462
463 let mut from_end = query.len();
465 while i < tokens_with_pos.len() {
466 match &tokens_with_pos[i].0 {
467 Token::Where
468 | Token::GroupBy
469 | Token::OrderBy
470 | Token::Limit
471 | Token::Having
472 | Token::Eof => {
473 from_end = tokens_with_pos[i].1;
474 break;
475 }
476 _ => i += 1,
477 }
478 }
479
480 let from_text = extract_text_between_positions(query, from_start, from_end);
481 lines.push(from_text.trim().to_string());
482 }
483 Token::Where => {
484 let where_start = tokens_with_pos[i].1;
485 i += 1;
486
487 let mut where_end = query.len();
489 let mut paren_depth = 0;
490 while i < tokens_with_pos.len() {
491 match &tokens_with_pos[i].0 {
492 Token::LeftParen => {
493 paren_depth += 1;
494 i += 1;
495 }
496 Token::RightParen => {
497 paren_depth -= 1;
498 i += 1;
499 }
500 Token::GroupBy
501 | Token::OrderBy
502 | Token::Limit
503 | Token::Having
504 | Token::Eof
505 if paren_depth == 0 =>
506 {
507 where_end = tokens_with_pos[i].1;
508 break;
509 }
510 _ => i += 1,
511 }
512 }
513
514 let where_text = extract_text_between_positions(query, where_start, where_end);
515 let formatted_where = format_where_clause_with_parens(where_text.trim());
516 lines.extend(formatted_where);
517 }
518 Token::GroupBy => {
519 let group_start = tokens_with_pos[i].1;
520 i += 1;
521
522 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
524 i += 1;
525 }
526
527 while i < tokens_with_pos.len() {
529 match &tokens_with_pos[i].0 {
530 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
531 _ => i += 1,
532 }
533 }
534
535 if i > 0 {
536 let group_text = extract_text_between_positions(
537 query,
538 group_start,
539 tokens_with_pos[i - 1].1,
540 );
541 lines.push(format!("GROUP BY {}", group_text.trim()));
542 }
543 }
544 _ => i += 1,
545 }
546 }
547
548 lines
549}
550
/// Renders a WHERE clause as a single `WHERE <condition>` line.
///
/// The condition text is kept verbatim, so its parentheses, quoting and operators are
/// preserved exactly.
///
/// `where_text` may or may not begin with the `WHERE` keyword. If it does, the keyword is
/// matched case-insensitively (`WHERE`, `where`, `Where`, ...), removed, and replaced by an
/// upper-case `WHERE` followed by exactly one space. Surrounding whitespace is dropped.
///
/// The result always contains exactly one line. An empty condition yields `"WHERE"`.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim_start();
    // `get` returns None if byte 5 is not a char boundary, which cannot be a WHERE prefix
    // anyway. When it succeeds, slicing at the same offset is safe.
    let condition = match text.get(..KEYWORD.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(KEYWORD) => &text[KEYWORD.len()..],
        _ => text,
    }
    .trim();

    if condition.is_empty() {
        vec![KEYWORD.to_string()]
    } else {
        vec![format!("WHERE {condition}")]
    }
}
614
615#[must_use]
616pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
617 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
619
620 formatted
622 .into_iter()
623 .filter(|line| !line.trim().is_empty())
624 .map(|line| {
625 let mut result = line;
627 for keyword in &[
628 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
629 ] {
630 let pattern = format!("{keyword}");
631 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
632 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
633 }
634 }
635 result
636 })
637 .collect()
638}
639
640pub fn format_expression(expr: &SqlExpression) -> String {
641 match expr {
642 SqlExpression::Column(name) => {
643 if needs_quotes(name) {
645 format!("\"{}\"", name)
646 } else {
647 name.clone()
648 }
649 }
650 SqlExpression::StringLiteral(value) => format!("'{value}'"),
651 SqlExpression::NumberLiteral(value) => value.clone(),
652 SqlExpression::BinaryOp { left, op, right } => {
653 format!(
654 "{} {} {}",
655 format_expression(left),
656 op,
657 format_expression(right)
658 )
659 }
660 SqlExpression::FunctionCall {
661 name,
662 args,
663 distinct,
664 } => {
665 let args_str = args
666 .iter()
667 .map(format_expression)
668 .collect::<Vec<_>>()
669 .join(", ");
670 if *distinct {
671 format!("{name}(DISTINCT {args_str})")
672 } else {
673 format!("{name}({args_str})")
674 }
675 }
676 SqlExpression::MethodCall {
677 object,
678 method,
679 args,
680 } => {
681 let args_str = args
682 .iter()
683 .map(format_expression)
684 .collect::<Vec<_>>()
685 .join(", ");
686 if args.is_empty() {
687 format!("{object}.{method}()")
688 } else {
689 format!("{object}.{method}({args_str})")
690 }
691 }
692 SqlExpression::InList { expr, values } => {
693 let values_str = values
694 .iter()
695 .map(format_expression)
696 .collect::<Vec<_>>()
697 .join(", ");
698 format!("{} IN ({})", format_expression(expr), values_str)
699 }
700 SqlExpression::NotInList { expr, values } => {
701 let values_str = values
702 .iter()
703 .map(format_expression)
704 .collect::<Vec<_>>()
705 .join(", ");
706 format!("{} NOT IN ({})", format_expression(expr), values_str)
707 }
708 SqlExpression::Between { expr, lower, upper } => {
709 format!(
710 "{} BETWEEN {} AND {}",
711 format_expression(expr),
712 format_expression(lower),
713 format_expression(upper)
714 )
715 }
716 SqlExpression::Null => "NULL".to_string(),
717 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
718 SqlExpression::DateTimeConstructor {
719 year,
720 month,
721 day,
722 hour,
723 minute,
724 second,
725 } => {
726 let time_part = match (hour, minute, second) {
727 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
728 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
729 _ => String::new(),
730 };
731 format!("DATETIME({year}, {month}, {day}{time_part})")
732 }
733 SqlExpression::DateTimeToday {
734 hour,
735 minute,
736 second,
737 } => {
738 let time_part = match (hour, minute, second) {
739 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
740 (Some(h), Some(m), None) => format!(", {h}, {m}"),
741 (Some(h), None, None) => format!(", {h}"),
742 _ => String::new(),
743 };
744 format!("TODAY({time_part})")
745 }
746 SqlExpression::WindowFunction {
747 name,
748 args,
749 window_spec,
750 } => {
751 let args_str = args
752 .iter()
753 .map(format_expression)
754 .collect::<Vec<_>>()
755 .join(", ");
756
757 let mut result = format!("{name}({args_str}) OVER (");
758
759 if !window_spec.partition_by.is_empty() {
761 result.push_str("PARTITION BY ");
762 result.push_str(&window_spec.partition_by.join(", "));
763 }
764
765 if !window_spec.order_by.is_empty() {
767 if !window_spec.partition_by.is_empty() {
768 result.push(' ');
769 }
770 result.push_str("ORDER BY ");
771 let order_strs: Vec<String> = window_spec
772 .order_by
773 .iter()
774 .map(|col| {
775 let dir = match col.direction {
776 SortDirection::Asc => " ASC",
777 SortDirection::Desc => " DESC",
778 };
779 format!("{}{}", col.column, dir)
780 })
781 .collect();
782 result.push_str(&order_strs.join(", "));
783 }
784
785 result.push(')');
786 result
787 }
788 SqlExpression::ChainedMethodCall { base, method, args } => {
789 let base_str = format_expression(base);
790 let args_str = args
791 .iter()
792 .map(format_expression)
793 .collect::<Vec<_>>()
794 .join(", ");
795 if args.is_empty() {
796 format!("{base_str}.{method}()")
797 } else {
798 format!("{base_str}.{method}({args_str})")
799 }
800 }
801 SqlExpression::Not { expr } => {
802 format!("NOT {}", format_expression(expr))
803 }
804 SqlExpression::CaseExpression {
805 when_branches,
806 else_branch,
807 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
808 SqlExpression::SimpleCaseExpression {
809 expr,
810 when_branches,
811 else_branch,
812 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
813 SqlExpression::ScalarSubquery { query: _ } => {
814 "(SELECT ...)".to_string()
816 }
817 SqlExpression::InSubquery { expr, subquery: _ } => {
818 format!("{} IN (SELECT ...)", format_expression(expr))
819 }
820 SqlExpression::NotInSubquery { expr, subquery: _ } => {
821 format!("{} NOT IN (SELECT ...)", format_expression(expr))
822 }
823 }
824}
825
826fn format_token(token: &Token) -> String {
827 match token {
828 Token::Identifier(s) => s.clone(),
829 Token::QuotedIdentifier(s) => format!("\"{s}\""),
830 Token::StringLiteral(s) => format!("'{s}'"),
831 Token::NumberLiteral(n) => n.clone(),
832 Token::DateTime => "DateTime".to_string(),
833 Token::Case => "CASE".to_string(),
834 Token::When => "WHEN".to_string(),
835 Token::Then => "THEN".to_string(),
836 Token::Else => "ELSE".to_string(),
837 Token::End => "END".to_string(),
838 Token::Distinct => "DISTINCT".to_string(),
839 Token::Over => "OVER".to_string(),
840 Token::Partition => "PARTITION".to_string(),
841 Token::By => "BY".to_string(),
842 Token::LeftParen => "(".to_string(),
843 Token::RightParen => ")".to_string(),
844 Token::Comma => ",".to_string(),
845 Token::Dot => ".".to_string(),
846 Token::Equal => "=".to_string(),
847 Token::NotEqual => "!=".to_string(),
848 Token::LessThan => "<".to_string(),
849 Token::GreaterThan => ">".to_string(),
850 Token::LessThanOrEqual => "<=".to_string(),
851 Token::GreaterThanOrEqual => ">=".to_string(),
852 Token::In => "IN".to_string(),
853 _ => format!("{token:?}").to_uppercase(),
854 }
855}
856
/// Reports whether `name` must be wrapped in double quotes to be used as an identifier.
///
/// Quoting is required when the name:
/// - contains `-`, a space, `.` or `/`;
/// - starts with an ASCII digit;
/// - is a reserved word (compared case-insensitively);
/// - contains any character other than ASCII alphanumerics and `_`.
///
/// The wildcard `*` (as in `SELECT *` or `COUNT(*)`) is never quoted.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    // `*` is the wildcard, not an identifier. Quoting it would turn `SELECT *` into a
    // reference to a column literally named "*".
    if name == "*" {
        return true == false;
    }

    if name.contains(|c: char| matches!(c, '-' | ' ' | '.' | '/')) {
        return true;
    }

    if name.starts_with(|c: char| c.is_ascii_digit()) {
        return true;
    }

    // The reserved words are ASCII, and any non-ASCII name is caught by the final check
    // anyway. Case-insensitive ASCII comparison is therefore equivalent to upper-casing the
    // name, without allocating.
    if RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name))
    {
        return true;
    }

    !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
}
886
887fn format_case_expression(
889 when_branches: &[crate::sql::recursive_parser::WhenBranch],
890 else_branch: Option<&SqlExpression>,
891) -> String {
892 let is_simple = when_branches.len() <= 1
894 && when_branches
895 .iter()
896 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
897 && else_branch.map_or(true, expr_is_simple);
898
899 if is_simple {
900 let mut result = String::from("CASE");
902 for branch in when_branches {
903 result.push_str(&format!(
904 " WHEN {} THEN {}",
905 format_expression(&branch.condition),
906 format_expression(&branch.result)
907 ));
908 }
909 if let Some(else_expr) = else_branch {
910 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
911 }
912 result.push_str(" END");
913 result
914 } else {
915 let mut result = String::from("CASE");
917 for branch in when_branches {
918 result.push_str(&format!(
919 "\n WHEN {} THEN {}",
920 format_expression(&branch.condition),
921 format_expression(&branch.result)
922 ));
923 }
924 if let Some(else_expr) = else_branch {
925 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
926 }
927 result.push_str("\n END");
928 result
929 }
930}
931
932fn format_simple_case_expression(
934 expr: &SqlExpression,
935 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
936 else_branch: Option<&SqlExpression>,
937) -> String {
938 let is_simple = when_branches.len() <= 2
940 && expr_is_simple(expr)
941 && when_branches
942 .iter()
943 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
944 && else_branch.map_or(true, expr_is_simple);
945
946 if is_simple {
947 let mut result = format!("CASE {}", format_expression(expr));
949 for branch in when_branches {
950 result.push_str(&format!(
951 " WHEN {} THEN {}",
952 format_expression(&branch.value),
953 format_expression(&branch.result)
954 ));
955 }
956 if let Some(else_expr) = else_branch {
957 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
958 }
959 result.push_str(" END");
960 result
961 } else {
962 let mut result = format!("CASE {}", format_expression(expr));
964 for branch in when_branches {
965 result.push_str(&format!(
966 "\n WHEN {} THEN {}",
967 format_expression(&branch.value),
968 format_expression(&branch.result)
969 ));
970 }
971 if let Some(else_expr) = else_branch {
972 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
973 }
974 result.push_str("\n END");
975 result
976 }
977}
978
979fn expr_is_simple(expr: &SqlExpression) -> bool {
981 match expr {
982 SqlExpression::Column(_)
983 | SqlExpression::StringLiteral(_)
984 | SqlExpression::NumberLiteral(_)
985 | SqlExpression::BooleanLiteral(_)
986 | SqlExpression::Null => true,
987 SqlExpression::BinaryOp { left, right, .. } => {
988 expr_is_simple(left) && expr_is_simple(right)
989 }
990 SqlExpression::FunctionCall { args, .. } => {
991 args.len() <= 2 && args.iter().all(expr_is_simple)
992 }
993 _ => false,
994 }
995}