1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 let expr_str = match &col.expr {
67 SqlExpression::Column(col_ref) => col_ref.name.clone(),
68 _ => format!("{:?}", col.expr),
69 };
70 result.push_str(&format!(
71 "{indent_str} {{ expr: \"{}\", direction: {dir} }},\n",
72 expr_str
73 ));
74 }
75 result.push_str(&format!("{indent_str} ],\n"));
76 }
77 }
78
79 if let Some(group_by) = &stmt.group_by {
81 result.push_str(&format!("{indent_str} group_by: ["));
82 if group_by.is_empty() {
83 result.push_str("],\n");
84 } else {
85 result.push('\n');
86 for expr in group_by {
87 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
88 }
89 result.push_str(&format!("{indent_str} ],\n"));
90 }
91 }
92
93 if let Some(limit) = stmt.limit {
95 result.push_str(&format!("{indent_str} limit: {limit},\n"));
96 }
97
98 if stmt.distinct {
100 result.push_str(&format!("{indent_str} distinct: true,\n"));
101 }
102
103 result.push_str(&format!("{indent_str}}}\n"));
104 result
105}
106
107fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
108 let mut result = String::new();
109 let indent_str = " ".repeat(indent);
110
111 result.push_str(&format!("{indent_str}conditions: [\n"));
112 for (i, condition) in clause.conditions.iter().enumerate() {
113 result.push_str(&format!("{indent_str} {{\n"));
114 result.push_str(&format!(
115 "{indent_str} expr: {},\n",
116 format_expression_ast(&condition.expr)
117 ));
118
119 if let Some(connector) = &condition.connector {
120 let conn_str = match connector {
121 LogicalOp::And => "AND",
122 LogicalOp::Or => "OR",
123 };
124 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
125 }
126
127 result.push_str(&format!("{indent_str} }}"));
128 if i < clause.conditions.len() - 1 {
129 result.push(',');
130 }
131 result.push('\n');
132 }
133 result.push_str(&format!("{indent_str}]\n"));
134
135 result
136}
137
138pub fn format_expression_ast(expr: &SqlExpression) -> String {
139 match expr {
140 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
141 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
142 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
143 SqlExpression::BinaryOp { left, op, right } => {
144 format!(
145 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
146 format_expression_ast(left),
147 format_expression_ast(right)
148 )
149 }
150 SqlExpression::FunctionCall {
151 name,
152 args,
153 distinct,
154 } => {
155 let args_str = args
156 .iter()
157 .map(format_expression_ast)
158 .collect::<Vec<_>>()
159 .join(", ");
160 if *distinct {
161 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
162 } else {
163 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
164 }
165 }
166 SqlExpression::MethodCall {
167 object,
168 method,
169 args,
170 } => {
171 let args_str = args
172 .iter()
173 .map(format_expression_ast)
174 .collect::<Vec<_>>()
175 .join(", ");
176 format!(
177 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
178 )
179 }
180 SqlExpression::InList { expr, values } => {
181 let values_str = values
182 .iter()
183 .map(format_expression_ast)
184 .collect::<Vec<_>>()
185 .join(", ");
186 format!(
187 "InList {{ expr: {}, values: [{values_str}] }}",
188 format_expression_ast(expr)
189 )
190 }
191 SqlExpression::NotInList { expr, values } => {
192 let values_str = values
193 .iter()
194 .map(format_expression_ast)
195 .collect::<Vec<_>>()
196 .join(", ");
197 format!(
198 "NotInList {{ expr: {}, values: [{values_str}] }}",
199 format_expression_ast(expr)
200 )
201 }
202 SqlExpression::Between { expr, lower, upper } => {
203 format!(
204 "Between {{ expr: {}, lower: {}, upper: {} }}",
205 format_expression_ast(expr),
206 format_expression_ast(lower),
207 format_expression_ast(upper)
208 )
209 }
210 SqlExpression::Null => "Null".to_string(),
211 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
212 SqlExpression::DateTimeConstructor {
213 year,
214 month,
215 day,
216 hour,
217 minute,
218 second,
219 } => {
220 let time_part = match (hour, minute, second) {
221 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
222 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
223 _ => String::new(),
224 };
225 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
226 }
227 SqlExpression::DateTimeToday {
228 hour,
229 minute,
230 second,
231 } => {
232 let time_part = match (hour, minute, second) {
233 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
234 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
235 _ => String::new(),
236 };
237 format!("DateTimeToday({time_part})")
238 }
239 SqlExpression::WindowFunction {
240 name,
241 args,
242 window_spec: _,
243 } => {
244 let args_str = args
245 .iter()
246 .map(format_expression_ast)
247 .collect::<Vec<_>>()
248 .join(", ");
249 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
250 }
251 SqlExpression::ChainedMethodCall { base, method, args } => {
252 let args_str = args
253 .iter()
254 .map(format_expression_ast)
255 .collect::<Vec<_>>()
256 .join(", ");
257 format!(
258 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
259 format_expression_ast(base)
260 )
261 }
262 SqlExpression::Not { expr } => {
263 format!("Not {{ expr: {} }}", format_expression_ast(expr))
264 }
265 SqlExpression::CaseExpression {
266 when_branches,
267 else_branch,
268 } => {
269 let mut result = String::from("CaseExpression { when_branches: [");
270 for branch in when_branches {
271 result.push_str(&format!(
272 " {{ condition: {}, result: {} }},",
273 format_expression_ast(&branch.condition),
274 format_expression_ast(&branch.result)
275 ));
276 }
277 result.push_str("], else_branch: ");
278 if let Some(else_expr) = else_branch {
279 result.push_str(&format_expression_ast(else_expr));
280 } else {
281 result.push_str("None");
282 }
283 result.push_str(" }");
284 result
285 }
286 SqlExpression::SimpleCaseExpression {
287 expr,
288 when_branches,
289 else_branch,
290 } => {
291 let mut result = format!(
292 "SimpleCaseExpression {{ expr: {}, when_branches: [",
293 format_expression_ast(expr)
294 );
295 for branch in when_branches {
296 result.push_str(&format!(
297 " {{ value: {}, result: {} }},",
298 format_expression_ast(&branch.value),
299 format_expression_ast(&branch.result)
300 ));
301 }
302 result.push_str("], else_branch: ");
303 if let Some(else_expr) = else_branch {
304 result.push_str(&format_expression_ast(else_expr));
305 } else {
306 result.push_str("None");
307 }
308 result.push_str(" }");
309 result
310 }
311 SqlExpression::ScalarSubquery { query: _ } => {
312 format!("ScalarSubquery {{ query: <SelectStatement> }}")
313 }
314 SqlExpression::InSubquery { expr, subquery: _ } => {
315 format!(
316 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
317 format_expression_ast(expr)
318 )
319 }
320 SqlExpression::NotInSubquery { expr, subquery: _ } => {
321 format!(
322 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
323 format_expression_ast(expr)
324 )
325 }
326 SqlExpression::InSubqueryTuple { exprs, subquery: _ } => {
327 let formatted: Vec<String> = exprs.iter().map(format_expression_ast).collect();
328 format!(
329 "InSubqueryTuple {{ exprs: ({}), subquery: <SelectStatement> }}",
330 formatted.join(", ")
331 )
332 }
333 SqlExpression::NotInSubqueryTuple { exprs, subquery: _ } => {
334 let formatted: Vec<String> = exprs.iter().map(format_expression_ast).collect();
335 format!(
336 "NotInSubqueryTuple {{ exprs: ({}), subquery: <SelectStatement> }}",
337 formatted.join(", ")
338 )
339 }
340 SqlExpression::Unnest { column, delimiter } => {
341 format!(
342 "Unnest {{ column: {}, delimiter: \"{}\" }}",
343 format_expression_ast(column),
344 delimiter
345 )
346 }
347 }
348}
349
/// Return `text[start..end]` as an owned string, or an empty string when the
/// range is empty, reversed, out of bounds, or does not fall on UTF-8
/// character boundaries.
///
/// Slicing with `[]` panics on a non-char-boundary offset, which a query
/// containing non-ASCII text can trigger. `str::get` returns `None` instead.
///
/// NOTE(review): callers pass positions from `Lexer::get_position`; this
/// treats them as byte offsets — confirm the lexer does not count `char`s.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    text.get(start..end).unwrap_or("").to_string()
}
357
358fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
360 let mut lexer = Lexer::new(query);
361 let mut found_count = 0;
362
363 loop {
364 let pos = lexer.get_position();
365 let token = lexer.next_token();
366 if token == Token::Eof {
367 break;
368 }
369 if token == target {
370 if found_count == skip_count {
371 return Some(pos);
372 }
373 found_count += 1;
374 }
375 }
376 None
377}
378
379pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
380 let mut parser = Parser::new(query);
381 let stmt = match parser.parse() {
382 Ok(s) => s,
383 Err(_) => return vec![query.to_string()],
384 };
385
386 let mut lines = Vec::new();
387 let mut lexer = Lexer::new(query);
388 let mut tokens_with_pos = Vec::new();
389
390 loop {
392 let pos = lexer.get_position();
393 let token = lexer.next_token();
394 if token == Token::Eof {
395 break;
396 }
397 tokens_with_pos.push((token, pos));
398 }
399
400 let mut i = 0;
402 while i < tokens_with_pos.len() {
403 match &tokens_with_pos[i].0 {
404 Token::Select => {
405 let _select_start = tokens_with_pos[i].1;
406 i += 1;
407
408 let has_distinct = if i < tokens_with_pos.len() {
410 matches!(tokens_with_pos[i].0, Token::Distinct)
411 } else {
412 false
413 };
414
415 if has_distinct {
416 i += 1;
417 }
418
419 let _select_end = query.len();
421 let _col_count = 0;
422 let _current_line_cols: Vec<String> = Vec::new();
423 let mut all_select_lines = Vec::new();
424
425 let use_pretty_format = stmt.columns.len() > cols_per_line;
427
428 if use_pretty_format {
429 let select_text = if has_distinct {
431 "SELECT DISTINCT".to_string()
432 } else {
433 "SELECT".to_string()
434 };
435 all_select_lines.push(select_text);
436
437 for (idx, col) in stmt.columns.iter().enumerate() {
439 let is_last = idx == stmt.columns.len() - 1;
440 let formatted_col = if needs_quotes(col) {
442 format!("\"{}\"", col)
443 } else {
444 col.clone()
445 };
446 let col_text = if is_last {
447 format!(" {}", formatted_col)
448 } else {
449 format!(" {},", formatted_col)
450 };
451 all_select_lines.push(col_text);
452 }
453 } else {
454 let mut select_line = if has_distinct {
456 "SELECT DISTINCT ".to_string()
457 } else {
458 "SELECT ".to_string()
459 };
460
461 for (idx, col) in stmt.columns.iter().enumerate() {
462 if idx > 0 {
463 select_line.push_str(", ");
464 }
465 if needs_quotes(col) {
467 select_line.push_str(&format!("\"{}\"", col));
468 } else {
469 select_line.push_str(col);
470 }
471 }
472 all_select_lines.push(select_line);
473 }
474
475 lines.extend(all_select_lines);
476
477 while i < tokens_with_pos.len() {
479 match &tokens_with_pos[i].0 {
480 Token::From => break,
481 _ => i += 1,
482 }
483 }
484 }
485 Token::From => {
486 let from_start = tokens_with_pos[i].1;
487 i += 1;
488
489 let mut from_end = query.len();
491 while i < tokens_with_pos.len() {
492 match &tokens_with_pos[i].0 {
493 Token::Where
494 | Token::GroupBy
495 | Token::OrderBy
496 | Token::Limit
497 | Token::Having
498 | Token::Eof => {
499 from_end = tokens_with_pos[i].1;
500 break;
501 }
502 _ => i += 1,
503 }
504 }
505
506 let from_text = extract_text_between_positions(query, from_start, from_end);
507 lines.push(from_text.trim().to_string());
508 }
509 Token::Where => {
510 let where_start = tokens_with_pos[i].1;
511 i += 1;
512
513 let mut where_end = query.len();
515 let mut paren_depth = 0;
516 while i < tokens_with_pos.len() {
517 match &tokens_with_pos[i].0 {
518 Token::LeftParen => {
519 paren_depth += 1;
520 i += 1;
521 }
522 Token::RightParen => {
523 paren_depth -= 1;
524 i += 1;
525 }
526 Token::GroupBy
527 | Token::OrderBy
528 | Token::Limit
529 | Token::Having
530 | Token::Eof
531 if paren_depth == 0 =>
532 {
533 where_end = tokens_with_pos[i].1;
534 break;
535 }
536 _ => i += 1,
537 }
538 }
539
540 let where_text = extract_text_between_positions(query, where_start, where_end);
541 let formatted_where = format_where_clause_with_parens(where_text.trim());
542 lines.extend(formatted_where);
543 }
544 Token::GroupBy => {
545 let group_start = tokens_with_pos[i].1;
546 i += 1;
547
548 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
550 i += 1;
551 }
552
553 while i < tokens_with_pos.len() {
555 match &tokens_with_pos[i].0 {
556 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
557 _ => i += 1,
558 }
559 }
560
561 if i > 0 {
562 let group_text = extract_text_between_positions(
563 query,
564 group_start,
565 tokens_with_pos[i - 1].1,
566 );
567 lines.push(format!("GROUP BY {}", group_text.trim()));
568 }
569 }
570 _ => i += 1,
571 }
572 }
573
574 lines
575}
576
/// Turn the source text of a WHERE clause into output lines.
///
/// A leading `WHERE` keyword is matched ASCII-case-insensitively, and only as
/// a whole word (so a column such as `wherefore` is left alone). It is
/// replaced by the canonical `WHERE` followed by one space. All whitespace
/// between the keyword and the condition is dropped. The condition text is
/// otherwise copied verbatim (trimmed), so parentheses and string literals
/// are preserved exactly as written. Always returns exactly one line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim_start();
    let condition = match (text.get(..KEYWORD.len()), text.get(KEYWORD.len()..)) {
        (Some(head), Some(rest))
            if head.eq_ignore_ascii_case(KEYWORD)
                && !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_') =>
        {
            rest
        }
        _ => text,
    };
    let condition = condition.trim();

    if condition.is_empty() {
        vec![KEYWORD.to_string()]
    } else {
        vec![format!("WHERE {condition}")]
    }
}
641
642#[must_use]
643pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
644 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
646
647 formatted
649 .into_iter()
650 .filter(|line| !line.trim().is_empty())
651 .map(|line| {
652 let mut result = line;
654 for keyword in &[
655 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
656 ] {
657 let pattern = format!("{keyword}");
658 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
659 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
660 }
661 }
662 result
663 })
664 .collect()
665}
666
667pub fn format_expression(expr: &SqlExpression) -> String {
668 match expr {
669 SqlExpression::Column(column_ref) => {
670 column_ref.to_sql()
672 }
673 SqlExpression::StringLiteral(value) => format!("'{value}'"),
674 SqlExpression::NumberLiteral(value) => value.clone(),
675 SqlExpression::BinaryOp { left, op, right } => {
676 format!(
677 "{} {} {}",
678 format_expression(left),
679 op,
680 format_expression(right)
681 )
682 }
683 SqlExpression::FunctionCall {
684 name,
685 args,
686 distinct,
687 } => {
688 let args_str = args
689 .iter()
690 .map(format_expression)
691 .collect::<Vec<_>>()
692 .join(", ");
693 if *distinct {
694 format!("{name}(DISTINCT {args_str})")
695 } else {
696 format!("{name}({args_str})")
697 }
698 }
699 SqlExpression::MethodCall {
700 object,
701 method,
702 args,
703 } => {
704 let args_str = args
705 .iter()
706 .map(format_expression)
707 .collect::<Vec<_>>()
708 .join(", ");
709 if args.is_empty() {
710 format!("{object}.{method}()")
711 } else {
712 format!("{object}.{method}({args_str})")
713 }
714 }
715 SqlExpression::InList { expr, values } => {
716 let values_str = values
717 .iter()
718 .map(format_expression)
719 .collect::<Vec<_>>()
720 .join(", ");
721 format!("{} IN ({})", format_expression(expr), values_str)
722 }
723 SqlExpression::NotInList { expr, values } => {
724 let values_str = values
725 .iter()
726 .map(format_expression)
727 .collect::<Vec<_>>()
728 .join(", ");
729 format!("{} NOT IN ({})", format_expression(expr), values_str)
730 }
731 SqlExpression::Between { expr, lower, upper } => {
732 format!(
733 "{} BETWEEN {} AND {}",
734 format_expression(expr),
735 format_expression(lower),
736 format_expression(upper)
737 )
738 }
739 SqlExpression::Null => "NULL".to_string(),
740 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
741 SqlExpression::DateTimeConstructor {
742 year,
743 month,
744 day,
745 hour,
746 minute,
747 second,
748 } => {
749 let time_part = match (hour, minute, second) {
750 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
751 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
752 _ => String::new(),
753 };
754 format!("DATETIME({year}, {month}, {day}{time_part})")
755 }
756 SqlExpression::DateTimeToday {
757 hour,
758 minute,
759 second,
760 } => {
761 let time_part = match (hour, minute, second) {
762 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
763 (Some(h), Some(m), None) => format!(", {h}, {m}"),
764 (Some(h), None, None) => format!(", {h}"),
765 _ => String::new(),
766 };
767 format!("TODAY({time_part})")
768 }
769 SqlExpression::WindowFunction {
770 name,
771 args,
772 window_spec,
773 } => {
774 let args_str = args
775 .iter()
776 .map(format_expression)
777 .collect::<Vec<_>>()
778 .join(", ");
779
780 let mut result = format!("{name}({args_str}) OVER (");
781
782 if !window_spec.partition_by.is_empty() {
784 result.push_str("PARTITION BY ");
785 result.push_str(&window_spec.partition_by.join(", "));
786 }
787
788 if !window_spec.order_by.is_empty() {
790 if !window_spec.partition_by.is_empty() {
791 result.push(' ');
792 }
793 result.push_str("ORDER BY ");
794 let order_strs: Vec<String> = window_spec
795 .order_by
796 .iter()
797 .map(|col| {
798 let dir = match col.direction {
799 SortDirection::Asc => " ASC",
800 SortDirection::Desc => " DESC",
801 };
802 let expr_str = match &col.expr {
804 SqlExpression::Column(col_ref) => col_ref.name.clone(),
805 _ => format_expression(&col.expr),
806 };
807 format!("{}{}", expr_str, dir)
808 })
809 .collect();
810 result.push_str(&order_strs.join(", "));
811 }
812
813 result.push(')');
814 result
815 }
816 SqlExpression::ChainedMethodCall { base, method, args } => {
817 let base_str = format_expression(base);
818 let args_str = args
819 .iter()
820 .map(format_expression)
821 .collect::<Vec<_>>()
822 .join(", ");
823 if args.is_empty() {
824 format!("{base_str}.{method}()")
825 } else {
826 format!("{base_str}.{method}({args_str})")
827 }
828 }
829 SqlExpression::Not { expr } => {
830 format!("NOT {}", format_expression(expr))
831 }
832 SqlExpression::CaseExpression {
833 when_branches,
834 else_branch,
835 } => format_case_expression(when_branches, else_branch.as_ref().map(|v| &**v)),
836 SqlExpression::SimpleCaseExpression {
837 expr,
838 when_branches,
839 else_branch,
840 } => format_simple_case_expression(expr, when_branches, else_branch.as_ref().map(|v| &**v)),
841 SqlExpression::ScalarSubquery { query: _ } => {
842 "(SELECT ...)".to_string()
844 }
845 SqlExpression::InSubquery { expr, subquery: _ } => {
846 format!("{} IN (SELECT ...)", format_expression(expr))
847 }
848 SqlExpression::NotInSubquery { expr, subquery: _ } => {
849 format!("{} NOT IN (SELECT ...)", format_expression(expr))
850 }
851 SqlExpression::InSubqueryTuple { exprs, subquery: _ } => {
852 let formatted: Vec<String> = exprs.iter().map(format_expression).collect();
853 format!("({}) IN (SELECT ...)", formatted.join(", "))
854 }
855 SqlExpression::NotInSubqueryTuple { exprs, subquery: _ } => {
856 let formatted: Vec<String> = exprs.iter().map(format_expression).collect();
857 format!("({}) NOT IN (SELECT ...)", formatted.join(", "))
858 }
859 SqlExpression::Unnest { column, delimiter } => {
860 format!("UNNEST({}, '{}')", format_expression(column), delimiter)
861 }
862 }
863}
864
865fn format_token(token: &Token) -> String {
866 match token {
867 Token::Identifier(s) => s.clone(),
868 Token::QuotedIdentifier(s) => format!("\"{s}\""),
869 Token::StringLiteral(s) => format!("'{s}'"),
870 Token::NumberLiteral(n) => n.clone(),
871 Token::DateTime => "DateTime".to_string(),
872 Token::Case => "CASE".to_string(),
873 Token::When => "WHEN".to_string(),
874 Token::Then => "THEN".to_string(),
875 Token::Else => "ELSE".to_string(),
876 Token::End => "END".to_string(),
877 Token::Distinct => "DISTINCT".to_string(),
878 Token::Over => "OVER".to_string(),
879 Token::Partition => "PARTITION".to_string(),
880 Token::By => "BY".to_string(),
881 Token::LeftParen => "(".to_string(),
882 Token::RightParen => ")".to_string(),
883 Token::Comma => ",".to_string(),
884 Token::Dot => ".".to_string(),
885 Token::Equal => "=".to_string(),
886 Token::NotEqual => "!=".to_string(),
887 Token::LessThan => "<".to_string(),
888 Token::GreaterThan => ">".to_string(),
889 Token::LessThanOrEqual => "<=".to_string(),
890 Token::GreaterThanOrEqual => ">=".to_string(),
891 Token::In => "IN".to_string(),
892 _ => format!("{token:?}").to_uppercase(),
893 }
894}
895
/// Whether a column name must be double-quoted when written back into SQL.
///
/// A name is quoted when it:
/// - starts with an ASCII digit,
/// - equals one of the reserved words below (ASCII case-insensitive), or
/// - contains any character outside `[A-Za-z0-9_]` (spaces, `-`, `.`, `/`,
///   non-ASCII letters, ...).
///
/// The bare wildcard `*` is never quoted: `"*"` would refer to a column
/// literally named `*` instead of selecting every column.
fn needs_quotes(name: &str) -> bool {
    const RESERVED_WORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "ORDER", "GROUP", "BY", "HAVING", "INSERT", "UPDATE", "DELETE",
        "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "AND", "OR", "NOT", "IN", "EXISTS",
        "BETWEEN", "LIKE", "CASE", "WHEN", "THEN", "ELSE", "END", "JOIN", "LEFT", "RIGHT", "INNER",
        "OUTER", "ON", "AS", "DISTINCT", "ALL", "TOP", "LIMIT", "OFFSET", "ASC", "DESC",
    ];

    if name == "*" {
        return false;
    }

    let starts_with_digit = name.starts_with(|c: char| c.is_ascii_digit());
    // Non-ASCII names are caught by `has_special_chars`, so an ASCII-only
    // case-insensitive comparison is sufficient here.
    let is_reserved = RESERVED_WORDS
        .iter()
        .any(|word| word.eq_ignore_ascii_case(name));
    let has_special_chars = !name
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_');

    starts_with_digit || is_reserved || has_special_chars
}
925
926fn format_case_expression(
928 when_branches: &[crate::sql::recursive_parser::WhenBranch],
929 else_branch: Option<&SqlExpression>,
930) -> String {
931 let is_simple = when_branches.len() <= 1
933 && when_branches
934 .iter()
935 .all(|b| expr_is_simple(&b.condition) && expr_is_simple(&b.result))
936 && else_branch.map_or(true, expr_is_simple);
937
938 if is_simple {
939 let mut result = String::from("CASE");
941 for branch in when_branches {
942 result.push_str(&format!(
943 " WHEN {} THEN {}",
944 format_expression(&branch.condition),
945 format_expression(&branch.result)
946 ));
947 }
948 if let Some(else_expr) = else_branch {
949 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
950 }
951 result.push_str(" END");
952 result
953 } else {
954 let mut result = String::from("CASE");
956 for branch in when_branches {
957 result.push_str(&format!(
958 "\n WHEN {} THEN {}",
959 format_expression(&branch.condition),
960 format_expression(&branch.result)
961 ));
962 }
963 if let Some(else_expr) = else_branch {
964 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
965 }
966 result.push_str("\n END");
967 result
968 }
969}
970
971fn format_simple_case_expression(
973 expr: &SqlExpression,
974 when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
975 else_branch: Option<&SqlExpression>,
976) -> String {
977 let is_simple = when_branches.len() <= 2
979 && expr_is_simple(expr)
980 && when_branches
981 .iter()
982 .all(|b| expr_is_simple(&b.value) && expr_is_simple(&b.result))
983 && else_branch.map_or(true, expr_is_simple);
984
985 if is_simple {
986 let mut result = format!("CASE {}", format_expression(expr));
988 for branch in when_branches {
989 result.push_str(&format!(
990 " WHEN {} THEN {}",
991 format_expression(&branch.value),
992 format_expression(&branch.result)
993 ));
994 }
995 if let Some(else_expr) = else_branch {
996 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
997 }
998 result.push_str(" END");
999 result
1000 } else {
1001 let mut result = format!("CASE {}", format_expression(expr));
1003 for branch in when_branches {
1004 result.push_str(&format!(
1005 "\n WHEN {} THEN {}",
1006 format_expression(&branch.value),
1007 format_expression(&branch.result)
1008 ));
1009 }
1010 if let Some(else_expr) = else_branch {
1011 result.push_str(&format!("\n ELSE {}", format_expression(else_expr)));
1012 }
1013 result.push_str("\n END");
1014 result
1015 }
1016}
1017
1018fn expr_is_simple(expr: &SqlExpression) -> bool {
1020 match expr {
1021 SqlExpression::Column(_)
1022 | SqlExpression::StringLiteral(_)
1023 | SqlExpression::NumberLiteral(_)
1024 | SqlExpression::BooleanLiteral(_)
1025 | SqlExpression::Null => true,
1026 SqlExpression::BinaryOp { left, right, .. } => {
1027 expr_is_simple(left) && expr_is_simple(right)
1028 }
1029 SqlExpression::FunctionCall { args, .. } => {
1030 args.len() <= 2 && args.iter().all(expr_is_simple)
1031 }
1032 _ => false,
1033 }
1034}