1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for expr in group_by {
82 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 SqlExpression::SimpleCaseExpression {
282 expr,
283 when_branches,
284 else_branch,
285 } => {
286 let mut result = format!(
287 "SimpleCaseExpression {{ expr: {}, when_branches: [",
288 format_expression_ast(expr)
289 );
290 for branch in when_branches {
291 result.push_str(&format!(
292 " {{ value: {}, result: {} }},",
293 format_expression_ast(&branch.value),
294 format_expression_ast(&branch.result)
295 ));
296 }
297 result.push_str("], else_branch: ");
298 if let Some(else_expr) = else_branch {
299 result.push_str(&format_expression_ast(else_expr));
300 } else {
301 result.push_str("None");
302 }
303 result.push_str(" }");
304 result
305 }
306 SqlExpression::ScalarSubquery { query: _ } => {
307 format!("ScalarSubquery {{ query: <SelectStatement> }}")
308 }
309 SqlExpression::InSubquery { expr, subquery: _ } => {
310 format!(
311 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
312 format_expression_ast(expr)
313 )
314 }
315 SqlExpression::NotInSubquery { expr, subquery: _ } => {
316 format!(
317 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
318 format_expression_ast(expr)
319 )
320 }
321 }
322}
323
/// Returns the part of `text` covered by the byte range `start..end`.
///
/// An empty string is returned when the range is empty, reversed, out of
/// bounds, or does not fall on UTF-8 character boundaries, so multi-byte
/// input can never trigger a slicing panic.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    // `str::get` performs the ordering, bounds and char-boundary checks in one step.
    text.get(start..end).unwrap_or("").to_string()
}
331
332fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
334 let mut lexer = Lexer::new(query);
335 let mut found_count = 0;
336
337 loop {
338 let pos = lexer.get_position();
339 let token = lexer.next_token();
340 if token == Token::Eof {
341 break;
342 }
343 if token == target {
344 if found_count == skip_count {
345 return Some(pos);
346 }
347 found_count += 1;
348 }
349 }
350 None
351}
352
353pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
354 let mut parser = Parser::new(query);
355 let stmt = match parser.parse() {
356 Ok(s) => s,
357 Err(_) => return vec![query.to_string()],
358 };
359
360 let mut lines = Vec::new();
361 let mut lexer = Lexer::new(query);
362 let mut tokens_with_pos = Vec::new();
363
364 loop {
366 let pos = lexer.get_position();
367 let token = lexer.next_token();
368 if token == Token::Eof {
369 break;
370 }
371 tokens_with_pos.push((token, pos));
372 }
373
374 let mut i = 0;
376 while i < tokens_with_pos.len() {
377 match &tokens_with_pos[i].0 {
378 Token::Select => {
379 let _select_start = tokens_with_pos[i].1;
380 i += 1;
381
382 let has_distinct = if i < tokens_with_pos.len() {
384 matches!(tokens_with_pos[i].0, Token::Distinct)
385 } else {
386 false
387 };
388
389 if has_distinct {
390 i += 1;
391 }
392
393 let _select_end = query.len();
395 let _col_count = 0;
396 let _current_line_cols: Vec<String> = Vec::new();
397 let mut all_select_lines = Vec::new();
398
399 let use_pretty_format = stmt.columns.len() > cols_per_line;
401
402 if use_pretty_format {
403 let select_text = if has_distinct {
405 "SELECT DISTINCT".to_string()
406 } else {
407 "SELECT".to_string()
408 };
409 all_select_lines.push(select_text);
410
411 for (idx, col) in stmt.columns.iter().enumerate() {
413 let is_last = idx == stmt.columns.len() - 1;
414 let formatted_col = if needs_quotes(col) {
416 format!("\"{}\"", col)
417 } else {
418 col.clone()
419 };
420 let col_text = if is_last {
421 format!(" {}", formatted_col)
422 } else {
423 format!(" {},", formatted_col)
424 };
425 all_select_lines.push(col_text);
426 }
427 } else {
428 let mut select_line = if has_distinct {
430 "SELECT DISTINCT ".to_string()
431 } else {
432 "SELECT ".to_string()
433 };
434
435 for (idx, col) in stmt.columns.iter().enumerate() {
436 if idx > 0 {
437 select_line.push_str(", ");
438 }
439 if needs_quotes(col) {
441 select_line.push_str(&format!("\"{}\"", col));
442 } else {
443 select_line.push_str(col);
444 }
445 }
446 all_select_lines.push(select_line);
447 }
448
449 lines.extend(all_select_lines);
450
451 while i < tokens_with_pos.len() {
453 match &tokens_with_pos[i].0 {
454 Token::From => break,
455 _ => i += 1,
456 }
457 }
458 }
459 Token::From => {
460 let from_start = tokens_with_pos[i].1;
461 i += 1;
462
463 let mut from_end = query.len();
465 while i < tokens_with_pos.len() {
466 match &tokens_with_pos[i].0 {
467 Token::Where
468 | Token::GroupBy
469 | Token::OrderBy
470 | Token::Limit
471 | Token::Having
472 | Token::Eof => {
473 from_end = tokens_with_pos[i].1;
474 break;
475 }
476 _ => i += 1,
477 }
478 }
479
480 let from_text = extract_text_between_positions(query, from_start, from_end);
481 lines.push(from_text.trim().to_string());
482 }
483 Token::Where => {
484 let where_start = tokens_with_pos[i].1;
485 i += 1;
486
487 let mut where_end = query.len();
489 let mut paren_depth = 0;
490 while i < tokens_with_pos.len() {
491 match &tokens_with_pos[i].0 {
492 Token::LeftParen => {
493 paren_depth += 1;
494 i += 1;
495 }
496 Token::RightParen => {
497 paren_depth -= 1;
498 i += 1;
499 }
500 Token::GroupBy
501 | Token::OrderBy
502 | Token::Limit
503 | Token::Having
504 | Token::Eof
505 if paren_depth == 0 =>
506 {
507 where_end = tokens_with_pos[i].1;
508 break;
509 }
510 _ => i += 1,
511 }
512 }
513
514 let where_text = extract_text_between_positions(query, where_start, where_end);
515 let formatted_where = format_where_clause_with_parens(where_text.trim());
516 lines.extend(formatted_where);
517 }
518 Token::GroupBy => {
519 let group_start = tokens_with_pos[i].1;
520 i += 1;
521
522 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
524 i += 1;
525 }
526
527 while i < tokens_with_pos.len() {
529 match &tokens_with_pos[i].0 {
530 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
531 _ => i += 1,
532 }
533 }
534
535 if i > 0 {
536 let group_text = extract_text_between_positions(
537 query,
538 group_start,
539 tokens_with_pos[i - 1].1,
540 );
541 lines.push(format!("GROUP BY {}", group_text.trim()));
542 }
543 }
544 _ => i += 1,
545 }
546 }
547
548 lines
549}
550
/// Normalizes a `WHERE` clause to a single line starting with `WHERE `,
/// keeping the condition text (parentheses, string literals) exactly as written.
///
/// A leading `WHERE` keyword in any letter case is replaced by the canonical
/// upper-case form; text without the keyword simply gets it prepended.
/// Surrounding whitespace is removed. The result always contains exactly one
/// line (just `WHERE` when the condition is empty).
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim();

    // The keyword only counts when it is a whole word, so an identifier such
    // as `whereabouts` is left intact.
    let has_keyword = text
        .get(..KEYWORD.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(KEYWORD))
        && !text[KEYWORD.len()..]
            .chars()
            .next()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');

    let condition = if has_keyword {
        text[KEYWORD.len()..].trim_start()
    } else {
        text
    };

    let line = format!("{KEYWORD} {condition}");
    vec![line.trim_end().to_string()]
}
615
616#[must_use]
617pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
618 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
620
621 formatted
623 .into_iter()
624 .filter(|line| !line.trim().is_empty())
625 .map(|line| {
626 let mut result = line;
628 for keyword in &[
629 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
630 ] {
631 let pattern = format!("{keyword}");
632 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
633 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
634 }
635 }
636 result
637 })
638 .collect()
639}
640
641pub fn format_expression(expr: &SqlExpression) -> String {
642 match expr {
643 SqlExpression::Column(column_ref) => {
644 column_ref.to_sql()
646 }
647 SqlExpression::StringLiteral(value) => format!("'{value}'"),
648 SqlExpression::NumberLiteral(value) => value.clone(),
649 SqlExpression::BinaryOp { left, op, right } => {
650 format!(
651 "{} {} {}",
652 format_expression(left),
653 op,
654 format_expression(right)
655 )
656 }
657 SqlExpression::FunctionCall {
658 name,
659 args,
660 distinct,
661 } => {
662 let args_str = args
663 .iter()
664 .map(format_expression)
665 .collect::<Vec<_>>()
666 .join(", ");
667 if *distinct {
668 format!("{name}(DISTINCT {args_str})")
669 } else {
670 format!("{name}({args_str})")
671 }
672 }
673 SqlExpression::MethodCall {
674 object,
675 method,
676 args,
677 } => {
678 let args_str = args
679 .iter()
680 .map(format_expression)
681 .collect::<Vec<_>>()
682 .join(", ");
683 if args.is_empty() {
684 format!("{object}.{method}()")
685 } else {
686 format!("{object}.{method}({args_str})")
687 }
688 }
689 SqlExpression::InList { expr, values } => {
690 let values_str = values
691 .iter()
692 .map(format_expression)
693 .collect::<Vec<_>>()
694 .join(", ");
695 format!("{} IN ({})", format_expression(expr), values_str)
696 }
697 SqlExpression::NotInList { expr, values } => {
698 let values_str = values
699 .iter()
700 .map(format_expression)
701 .collect::<Vec<_>>()
702 .join(", ");
703 format!("{} NOT IN ({})", format_expression(expr), values_str)
704 }
705 SqlExpression::Between { expr, lower, upper } => {
706 format!(
707 "{} BETWEEN {} AND {}",
708 format_expression(expr),
709 format_expression(lower),
710 format_expression(upper)
711 )
712 }
713 SqlExpression::Null => "NULL".to_string(),
714 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
715 SqlExpression::DateTimeConstructor {
716 year,
717 month,
718 day,
719 hour,
720 minute,
721 second,
722 } => {
723 let time_part = match (hour, minute, second) {
724 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
725 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
726 _ => String::new(),
727 };
728 format!("DATETIME({year}, {month}, {day}{time_part})")
729 }
730 SqlExpression::DateTimeToday {
731 hour,
732 minute,
733 second,
734 } => {
735 let time_part = match (hour, minute, second) {
736 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
737 (Some(h), Some(m), None) => format!(", {h}, {m}"),
738 (Some(h), None, None) => format!(", {h}"),
739 _ => String::new(),
740 };
741 format!("TODAY({time_part})")
742 }
743 SqlExpression::WindowFunction {
744 name,
745 args,
746 window_spec,
747 } => {
748 let args_str = args
749 .iter()
750 .map(format_expression)
751 .collect::<Vec<_>>()
752 .join(", ");
753
754 let mut result = format!("{name}({args_str}) OVER (");
755
756 if !window_spec.partition_by.is_empty() {
758 result.push_str("PARTITION BY ");
759 result.push_str(&window_spec.partition_by.join(", "));
760 }
761
762 if !window_spec.order_by.is_empty() {
764 if !window_spec.partition_by.is_empty() {
765 result.push(' ');
766 }
767 result.push_str("ORDER BY ");
768 let order_strs: Vec<String> = window_spec
769 .order_by
770 .iter()
771 .map(|col| {
772 let dir = match col.direction {
773 SortDirection::Asc => " ASC",
774 SortDirection::Desc => " DESC",
775 };
776 format!("{}{}", col.column, dir)
777 })
778 .collect();
779 result.push_str(&order_strs.join(", "));
780 }
781
782 result.push(')');
783 result
784 }
785 SqlExpression::ChainedMethodCall { base, method, args } => {
786 let base_str = format_expression(base);
787 let args_str = args
788 .iter()
789 .map(format_expression)
790 .collect::<Vec<_>>()
791 .join(", ");
792 if args.is_empty() {
793 format!("{base_str}.{method}()")
794 } else {
795 format!("{base_str}.{method}({args_str})")
796 }
797 }
798 SqlExpression::Not { expr } => {
799 format!("NOT {}", format_expression(expr))
800 }
801 SqlExpression::CaseExpression {
802 when_branches,
803 else_branch,
804 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
805 SqlExpression::SimpleCaseExpression {
806 expr,
807 when_branches,
808 else_branch,
809 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
810 SqlExpression::ScalarSubquery { query: _ } => {
811 "(SELECT ...)".to_string()
813 }
814 SqlExpression::InSubquery { expr, subquery: _ } => {
815 format!("{} IN (SELECT ...)", format_expression(expr))
816 }
817 SqlExpression::NotInSubquery { expr, subquery: _ } => {
818 format!("{} NOT IN (SELECT ...)", format_expression(expr))
819 }
820 }
821}
822
823fn format_token(token: &Token) -> String {
824 match token {
825 Token::Identifier(s) => s.clone(),
826 Token::QuotedIdentifier(s) => format!("\"{s}\""),
827 Token::StringLiteral(s) => format!("'{s}'"),
828 Token::NumberLiteral(n) => n.clone(),
829 Token::DateTime => "DateTime".to_string(),
830 Token::Case => "CASE".to_string(),
831 Token::When => "WHEN".to_string(),
832 Token::Then => "THEN".to_string(),
833 Token::Else => "ELSE".to_string(),
834 Token::End => "END".to_string(),
835 Token::Distinct => "DISTINCT".to_string(),
836 Token::Over => "OVER".to_string(),
837 Token::Partition => "PARTITION".to_string(),
838 Token::By => "BY".to_string(),
839 Token::LeftParen => "(".to_string(),
840 Token::RightParen => ")".to_string(),
841 Token::Comma => ",".to_string(),
842 Token::Dot => ".".to_string(),
843 Token::Equal => "=".to_string(),
844 Token::NotEqual => "!=".to_string(),
845 Token::LessThan => "<".to_string(),
846 Token::GreaterThan => ">".to_string(),
847 Token::LessThanOrEqual => "<=".to_string(),
848 Token::GreaterThanOrEqual => ">=".to_string(),
849 Token::In => "IN".to_string(),
850 _ => format!("{token:?}").to_uppercase(),
851 }
852}
853
/// Reports whether a column name must be wrapped in double quotes to be
/// emitted safely.
///
/// Quoting is required when the name starts with an ASCII digit, contains any
/// character other than ASCII letters, digits or `_`, or matches one of the
/// reserved words (compared case-insensitively). The bare wildcard `*` is
/// never quoted, because quoting it would turn it into an identifier.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    if name == "*" {
        return false;
    }

    let starts_with_digit = name.chars().next().map_or(false, |c| c.is_ascii_digit());
    // Covers '-', ' ', '.', '/' and every other non-identifier character.
    let has_special_char = !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');
    // Reserved words are pure ASCII, so an ASCII-only comparison is sufficient
    // and avoids allocating an upper-cased copy of `name`.
    let is_reserved = RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name));

    starts_with_digit || has_special_char || is_reserved
}
883
884fn format_case_expression(
886 when_branches: &[crate::sql::recursive_parser::WhenBranch],
887 else_branch: Option<&SqlExpression>,
888) -> String {
889 let is_simple = when_branches.len() <= 1
891 && when_branches
892 .iter()
893 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
894 && else_branch.map_or(true, expr_is_simple);
895
896 if is_simple {
897 let mut result = String::from("CASE");
899 for branch in when_branches {
900 result.push_str(&format!(
901 " WHEN {} THEN {}",
902 format_expression(&branch.condition),
903 format_expression(&branch.result)
904 ));
905 }
906 if let Some(else_expr) = else_branch {
907 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
908 }
909 result.push_str(" END");
910 result
911 } else {
912 let mut result = String::from("CASE");
914 for branch in when_branches {
915 result.push_str(&format!(
916 "\n WHEN {} THEN {}",
917 format_expression(&branch.condition),
918 format_expression(&branch.result)
919 ));
920 }
921 if let Some(else_expr) = else_branch {
922 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
923 }
924 result.push_str("\n END");
925 result
926 }
927}
928
929fn format_simple_case_expression(
931 expr: &SqlExpression,
932 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
933 else_branch: Option<&SqlExpression>,
934) -> String {
935 let is_simple = when_branches.len() <= 2
937 && expr_is_simple(expr)
938 && when_branches
939 .iter()
940 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
941 && else_branch.map_or(true, expr_is_simple);
942
943 if is_simple {
944 let mut result = format!("CASE {}", format_expression(expr));
946 for branch in when_branches {
947 result.push_str(&format!(
948 " WHEN {} THEN {}",
949 format_expression(&branch.value),
950 format_expression(&branch.result)
951 ));
952 }
953 if let Some(else_expr) = else_branch {
954 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
955 }
956 result.push_str(" END");
957 result
958 } else {
959 let mut result = format!("CASE {}", format_expression(expr));
961 for branch in when_branches {
962 result.push_str(&format!(
963 "\n WHEN {} THEN {}",
964 format_expression(&branch.value),
965 format_expression(&branch.result)
966 ));
967 }
968 if let Some(else_expr) = else_branch {
969 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
970 }
971 result.push_str("\n END");
972 result
973 }
974}
975
976fn expr_is_simple(expr: &SqlExpression) -> bool {
978 match expr {
979 SqlExpression::Column(_)
980 | SqlExpression::StringLiteral(_)
981 | SqlExpression::NumberLiteral(_)
982 | SqlExpression::BooleanLiteral(_)
983 | SqlExpression::Null => true,
984 SqlExpression::BinaryOp { left, right, .. } => {
985 expr_is_simple(left) && expr_is_simple(right)
986 }
987 SqlExpression::FunctionCall { args, .. } => {
988 args.len() <= 2 && args.iter().all(expr_is_simple)
989 }
990 _ => false,
991 }
992}