1use crate::sql::parser::ast::*;
8use std::fmt::Write;
9
/// Layout options for the AST-based SQL formatter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormatConfig {
    /// Indentation unit; repeated once per nesting level.
    pub indent: String,
    /// Maximum number of plain column names kept on the `SELECT` line; longer column lists
    /// are broken onto one indented line per column.
    pub items_per_line: usize,
    /// Emit keywords in upper case (`true`) or lower case (`false`).
    pub uppercase_keywords: bool,
    /// Compact-output flag.
    /// NOTE(review): not read anywhere in this module — confirm its intended effect.
    pub compact: bool,
}
21
22impl Default for FormatConfig {
23 fn default() -> Self {
24 Self {
25 indent: " ".to_string(),
26 items_per_line: 5,
27 uppercase_keywords: true,
28 compact: false,
29 }
30 }
31}
32
33pub fn format_select_statement(stmt: &SelectStatement) -> String {
35 format_select_with_config(stmt, &FormatConfig::default())
36}
37
38pub fn format_select_with_config(stmt: &SelectStatement, config: &FormatConfig) -> String {
40 let formatter = AstFormatter::new(config);
41 formatter.format_select(stmt, 0)
42}
43
44pub fn format_expression(expr: &SqlExpression) -> String {
47 let config = FormatConfig::default();
48 let formatter = AstFormatter::new(&config);
49 formatter.format_expression(expr)
50}
51
52struct AstFormatter<'a> {
53 config: &'a FormatConfig,
54}
55
56impl<'a> AstFormatter<'a> {
57 fn new(config: &'a FormatConfig) -> Self {
58 Self { config }
59 }
60
61 fn keyword(&self, word: &str) -> String {
62 if self.config.uppercase_keywords {
63 word.to_uppercase()
64 } else {
65 word.to_lowercase()
66 }
67 }
68
69 fn indent(&self, level: usize) -> String {
70 self.config.indent.repeat(level)
71 }
72
73 fn format_select(&self, stmt: &SelectStatement, indent_level: usize) -> String {
74 let mut result = String::new();
75 let indent = self.indent(indent_level);
76
77 for comment in &stmt.leading_comments {
79 self.format_comment(&mut result, comment, &indent);
80 }
81
82 if !stmt.ctes.is_empty() {
84 writeln!(&mut result, "{}{}", indent, self.keyword("WITH")).unwrap();
85 for (i, cte) in stmt.ctes.iter().enumerate() {
86 let is_last = i == stmt.ctes.len() - 1;
87 self.format_cte(&mut result, cte, indent_level + 1, is_last);
88 }
89 }
90
91 if !stmt.leading_comments.is_empty() || !stmt.ctes.is_empty() {
93 writeln!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
94 } else {
95 write!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
96 }
97 if stmt.distinct {
98 write!(&mut result, " {}", self.keyword("DISTINCT")).unwrap();
99 }
100
101 if stmt.select_items.is_empty() && !stmt.columns.is_empty() {
103 self.format_column_list(&mut result, &stmt.columns, indent_level);
105 } else {
106 self.format_select_items(&mut result, &stmt.select_items, indent_level);
107 }
108
109 if let Some(ref into_table) = stmt.into_table {
111 writeln!(&mut result).unwrap();
112 write!(
113 &mut result,
114 "{}{} {}",
115 indent,
116 self.keyword("INTO"),
117 into_table.name
118 )
119 .unwrap();
120 }
121
122 if let Some(ref table) = stmt.from_table {
124 writeln!(&mut result).unwrap();
125 write!(&mut result, "{}{} {}", indent, self.keyword("FROM"), table).unwrap();
126 } else if let Some(ref subquery) = stmt.from_subquery {
127 writeln!(&mut result).unwrap();
128 write!(&mut result, "{}{} (", indent, self.keyword("FROM")).unwrap();
129 writeln!(&mut result).unwrap();
130 let subquery_sql = self.format_select(subquery, indent_level + 1);
131 write!(&mut result, "{}", subquery_sql).unwrap();
132 write!(&mut result, "\n{}", indent).unwrap();
133 write!(&mut result, ")").unwrap();
134 if let Some(ref alias) = stmt.from_alias {
135 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
136 }
137 } else if let Some(ref func) = stmt.from_function {
138 writeln!(&mut result).unwrap();
139 write!(&mut result, "{}{} ", indent, self.keyword("FROM")).unwrap();
140 self.format_table_function(&mut result, func);
141 if let Some(ref alias) = stmt.from_alias {
142 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
143 }
144 }
145
146 for join in &stmt.joins {
148 writeln!(&mut result).unwrap();
149 self.format_join(&mut result, join, indent_level);
150 }
151
152 if let Some(ref where_clause) = stmt.where_clause {
154 writeln!(&mut result).unwrap();
155 write!(&mut result, "{}{}", indent, self.keyword("WHERE")).unwrap();
156 self.format_where_clause(&mut result, where_clause, indent_level);
157 }
158
159 if let Some(ref group_by) = stmt.group_by {
161 writeln!(&mut result).unwrap();
162 write!(&mut result, "{}{} ", indent, self.keyword("GROUP BY")).unwrap();
163 for (i, expr) in group_by.iter().enumerate() {
164 if i > 0 {
165 write!(&mut result, ", ").unwrap();
166 }
167 write!(&mut result, "{}", self.format_expression(expr)).unwrap();
168 }
169 }
170
171 if let Some(ref having) = stmt.having {
173 writeln!(&mut result).unwrap();
174 write!(
175 &mut result,
176 "{}{} {}",
177 indent,
178 self.keyword("HAVING"),
179 self.format_expression(having)
180 )
181 .unwrap();
182 }
183
184 if let Some(ref order_by) = stmt.order_by {
186 writeln!(&mut result).unwrap();
187 write!(&mut result, "{}{} ", indent, self.keyword("ORDER BY")).unwrap();
188 for (i, col) in order_by.iter().enumerate() {
189 if i > 0 {
190 write!(&mut result, ", ").unwrap();
191 }
192 write!(&mut result, "{}", self.format_expression(&col.expr)).unwrap();
193 match col.direction {
194 SortDirection::Asc => write!(&mut result, " {}", self.keyword("ASC")).unwrap(),
195 SortDirection::Desc => {
196 write!(&mut result, " {}", self.keyword("DESC")).unwrap()
197 }
198 }
199 }
200 }
201
202 if let Some(limit) = stmt.limit {
204 writeln!(&mut result).unwrap();
205 write!(&mut result, "{}{} {}", indent, self.keyword("LIMIT"), limit).unwrap();
206 }
207
208 if let Some(offset) = stmt.offset {
210 writeln!(&mut result).unwrap();
211 write!(
212 &mut result,
213 "{}{} {}",
214 indent,
215 self.keyword("OFFSET"),
216 offset
217 )
218 .unwrap();
219 }
220
221 if let Some(ref comment) = stmt.trailing_comment {
223 write!(&mut result, " ").unwrap();
224 self.format_inline_comment(&mut result, comment);
225 }
226
227 result
228 }
229
230 fn format_comment(&self, result: &mut String, comment: &Comment, indent: &str) {
232 if comment.is_line_comment {
233 writeln!(result, "{}-- {}", indent, comment.text.trim()).unwrap();
234 } else {
235 writeln!(result, "{}/* {} */", indent, comment.text.trim()).unwrap();
237 }
238 }
239
240 fn format_inline_comment(&self, result: &mut String, comment: &Comment) {
242 if comment.is_line_comment {
243 write!(result, "-- {}", comment.text.trim()).unwrap();
244 } else {
245 write!(result, "/* {} */", comment.text.trim()).unwrap();
246 }
247 }
248
249 fn format_cte(&self, result: &mut String, cte: &CTE, indent_level: usize, is_last: bool) {
250 let indent = self.indent(indent_level);
251
252 let is_web = matches!(&cte.cte_type, crate::sql::parser::ast::CTEType::Web(_));
254 if is_web {
255 write!(result, "{}{} {}", indent, self.keyword("WEB"), cte.name).unwrap();
256 } else {
257 write!(result, "{}{}", indent, cte.name).unwrap();
258 }
259
260 if let Some(ref columns) = cte.column_list {
261 write!(result, "(").unwrap();
262 for (i, col) in columns.iter().enumerate() {
263 if i > 0 {
264 write!(result, ", ").unwrap();
265 }
266 write!(result, "{}", col).unwrap();
267 }
268 write!(result, ")").unwrap();
269 }
270
271 writeln!(result, " {} (", self.keyword("AS")).unwrap();
272 let cte_sql = match &cte.cte_type {
273 crate::sql::parser::ast::CTEType::Standard(query) => {
274 self.format_select(query, indent_level + 1)
275 }
276 crate::sql::parser::ast::CTEType::Web(web_spec) => {
277 let mut web_str = format!(
279 "{}{} '{}'",
280 " ".repeat(indent_level + 1),
281 self.keyword("URL"),
282 web_spec.url
283 );
284
285 if let Some(method) = &web_spec.method {
287 web_str.push_str(&format!(
288 " {} {}",
289 self.keyword("METHOD"),
290 match method {
291 crate::sql::parser::ast::HttpMethod::GET => "GET",
292 crate::sql::parser::ast::HttpMethod::POST => "POST",
293 crate::sql::parser::ast::HttpMethod::PUT => "PUT",
294 crate::sql::parser::ast::HttpMethod::DELETE => "DELETE",
295 crate::sql::parser::ast::HttpMethod::PATCH => "PATCH",
296 }
297 ));
298 }
299
300 if let Some(body) = &web_spec.body {
302 let trimmed_body = body.trim();
304 if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
305 || (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
306 {
307 match serde_json::from_str::<serde_json::Value>(trimmed_body) {
309 Ok(json_val) => {
310 match serde_json::to_string_pretty(&json_val) {
312 Ok(pretty_json) => {
313 let is_complex = pretty_json.lines().count() > 1
315 || pretty_json.contains('"')
316 || pretty_json.contains('\\');
317
318 if is_complex {
319 let base_indent = " ".repeat(indent_level + 1);
321 let json_lines: Vec<String> = pretty_json
322 .lines()
323 .enumerate()
324 .map(|(i, line)| {
325 if i == 0 {
326 line.to_string()
327 } else {
328 format!("{}{}", base_indent, line)
329 }
330 })
331 .collect();
332 let formatted_json = json_lines.join("\n");
333
334 web_str.push_str(&format!(
335 " {} $JSON${}\n{}$JSON$\n{}",
336 self.keyword("BODY"),
337 formatted_json,
338 base_indent,
339 base_indent
340 ));
341 } else {
342 web_str.push_str(&format!(
344 " {} '{}'",
345 self.keyword("BODY"),
346 pretty_json
347 ));
348 }
349 }
350 Err(_) => {
351 web_str.push_str(&format!(
353 " {} '{}'",
354 self.keyword("BODY"),
355 body
356 ));
357 }
358 }
359 }
360 Err(_) => {
361 web_str.push_str(&format!(" {} '{}'", self.keyword("BODY"), body));
363 }
364 }
365 } else {
366 web_str.push_str(&format!(" {} '{}'", self.keyword("BODY"), body));
368 }
369 }
370
371 if let Some(format) = &web_spec.format {
373 web_str.push_str(&format!(
374 " {} {}",
375 self.keyword("FORMAT"),
376 match format {
377 crate::sql::parser::ast::DataFormat::CSV => "CSV",
378 crate::sql::parser::ast::DataFormat::JSON => "JSON",
379 crate::sql::parser::ast::DataFormat::Auto => "AUTO",
380 }
381 ));
382 }
383
384 if let Some(json_path) = &web_spec.json_path {
386 web_str.push_str(&format!(" {} '{}'", self.keyword("JSON_PATH"), json_path));
387 }
388
389 if let Some(cache) = web_spec.cache_seconds {
391 web_str.push_str(&format!(" {} {}", self.keyword("CACHE"), cache));
392 }
393
394 for (field_name, file_path) in &web_spec.form_files {
396 web_str.push_str(&format!(
397 "\n{}{} '{}' '{}'",
398 " ".repeat(indent_level + 1),
399 self.keyword("FORM_FILE"),
400 field_name,
401 file_path
402 ));
403 }
404
405 for (field_name, value) in &web_spec.form_fields {
407 let trimmed_value = value.trim();
409 if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
410 || (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
411 {
412 match serde_json::from_str::<serde_json::Value>(trimmed_value) {
414 Ok(json_val) => {
415 match serde_json::to_string_pretty(&json_val) {
417 Ok(pretty_json) => {
418 let is_complex = pretty_json.lines().count() > 1
420 || pretty_json.contains('"')
421 || pretty_json.contains('\\');
422
423 if is_complex {
424 let base_indent = " ".repeat(indent_level + 1);
426 let json_lines: Vec<String> = pretty_json
427 .lines()
428 .enumerate()
429 .map(|(i, line)| {
430 if i == 0 {
431 line.to_string()
432 } else {
433 format!("{}{}", base_indent, line)
434 }
435 })
436 .collect();
437 let formatted_json = json_lines.join("\n");
438
439 web_str.push_str(&format!(
440 "\n{}{} '{}' $JSON${}\n{}$JSON$",
441 base_indent,
442 self.keyword("FORM_FIELD"),
443 field_name,
444 formatted_json,
445 base_indent
446 ));
447 } else {
448 web_str.push_str(&format!(
450 "\n{}{} '{}' '{}'",
451 " ".repeat(indent_level + 1),
452 self.keyword("FORM_FIELD"),
453 field_name,
454 pretty_json
455 ));
456 }
457 }
458 Err(_) => {
459 web_str.push_str(&format!(
461 "\n{}{} '{}' '{}'",
462 " ".repeat(indent_level + 1),
463 self.keyword("FORM_FIELD"),
464 field_name,
465 value
466 ));
467 }
468 }
469 }
470 Err(_) => {
471 web_str.push_str(&format!(
473 "\n{}{} '{}' '{}'",
474 " ".repeat(indent_level + 1),
475 self.keyword("FORM_FIELD"),
476 field_name,
477 value
478 ));
479 }
480 }
481 } else {
482 web_str.push_str(&format!(
484 "\n{}{} '{}' '{}'",
485 " ".repeat(indent_level + 1),
486 self.keyword("FORM_FIELD"),
487 field_name,
488 value
489 ));
490 }
491 }
492
493 if !web_spec.headers.is_empty() {
495 web_str.push_str(&format!(" {} (", self.keyword("HEADERS")));
496 for (i, (key, value)) in web_spec.headers.iter().enumerate() {
497 if i > 0 {
498 web_str.push_str(", ");
499 }
500 web_str.push_str(&format!("'{}': '{}'", key, value));
501 }
502 web_str.push(')');
503 }
504
505 web_str
506 }
507 crate::sql::parser::ast::CTEType::File(file_spec) => {
508 let base_indent = " ".repeat(indent_level + 1);
509 let mut s = format!(
510 "{}{} {} '{}'",
511 base_indent,
512 self.keyword("FILE"),
513 self.keyword("PATH"),
514 file_spec.path
515 );
516 if file_spec.recursive {
517 s.push_str(&format!(" {}", self.keyword("RECURSIVE")));
518 }
519 if let Some(ref g) = file_spec.glob {
520 s.push_str(&format!(" {} '{}'", self.keyword("GLOB"), g));
521 }
522 if let Some(d) = file_spec.max_depth {
523 s.push_str(&format!(" {} {}", self.keyword("MAX_DEPTH"), d));
524 }
525 if let Some(m) = file_spec.max_files {
526 s.push_str(&format!(" {} {}", self.keyword("MAX_FILES"), m));
527 }
528 if file_spec.follow_links {
529 s.push_str(&format!(" {}", self.keyword("FOLLOW_LINKS")));
530 }
531 if file_spec.include_hidden {
532 s.push_str(&format!(" {}", self.keyword("INCLUDE_HIDDEN")));
533 }
534 s
535 }
536 };
537 write!(result, "{}", cte_sql).unwrap();
538 writeln!(result).unwrap();
539 write!(result, "{}", indent).unwrap();
540 if is_last {
541 writeln!(result, ")").unwrap();
542 } else {
543 writeln!(result, "),").unwrap();
544 }
545 }
546
547 fn format_column_list(&self, result: &mut String, columns: &[String], indent_level: usize) {
548 if columns.len() <= self.config.items_per_line {
549 write!(result, " ").unwrap();
551 for (i, col) in columns.iter().enumerate() {
552 if i > 0 {
553 write!(result, ", ").unwrap();
554 }
555 write!(result, "{}", col).unwrap();
556 }
557 } else {
558 writeln!(result).unwrap();
560 let indent = self.indent(indent_level + 1);
561 for (i, col) in columns.iter().enumerate() {
562 write!(result, "{}{}", indent, col).unwrap();
563 if i < columns.len() - 1 {
564 writeln!(result, ",").unwrap();
565 }
566 }
567 }
568 }
569
570 fn format_select_items(&self, result: &mut String, items: &[SelectItem], indent_level: usize) {
571 if items.is_empty() {
572 write!(result, " *").unwrap();
573 return;
574 }
575
576 let _non_star_count = items
578 .iter()
579 .filter(|i| !matches!(i, SelectItem::Star { .. }))
580 .count();
581
582 let has_complex_items = items.iter().any(|item| match item {
584 SelectItem::Expression { expr, .. } => self.is_complex_expression(expr),
585 _ => false,
586 });
587
588 let single_line_length: usize = items
590 .iter()
591 .map(|item| {
592 match item {
593 SelectItem::Star { .. } => 1,
594 SelectItem::StarExclude {
595 excluded_columns, ..
596 } => {
597 11 + excluded_columns.iter().map(|c| c.len()).sum::<usize>()
599 + (excluded_columns.len().saturating_sub(1) * 2) }
601 SelectItem::Column { column: col, .. } => col.name.len(),
602 SelectItem::Expression { expr, alias, .. } => {
603 self.format_expression(expr).len() + 4 + alias.len() }
605 }
606 })
607 .sum::<usize>()
608 + (items.len() - 1) * 2; let use_single_line = match items.len() {
614 1 => !has_complex_items, 2..=3 => !has_complex_items && single_line_length < 40, _ => false, };
618
619 if !use_single_line {
620 writeln!(result).unwrap();
622 let indent = self.indent(indent_level + 1);
623 for (i, item) in items.iter().enumerate() {
624 write!(result, "{}", indent).unwrap();
625 self.format_select_item(result, item);
626 if i < items.len() - 1 {
627 writeln!(result, ",").unwrap();
628 }
629 }
630 } else {
631 write!(result, " ").unwrap();
633 for (i, item) in items.iter().enumerate() {
634 if i > 0 {
635 write!(result, ", ").unwrap();
636 }
637 self.format_select_item(result, item);
638 }
639 }
640 }
641
642 fn is_complex_expression(&self, expr: &SqlExpression) -> bool {
643 match expr {
644 SqlExpression::CaseExpression { .. } => true,
645 SqlExpression::FunctionCall { .. } => true,
646 SqlExpression::WindowFunction { .. } => true,
647 SqlExpression::ScalarSubquery { .. } => true,
648 SqlExpression::InSubquery { .. } => true,
649 SqlExpression::NotInSubquery { .. } => true,
650 SqlExpression::BinaryOp { left, right, .. } => {
651 self.is_complex_expression(left) || self.is_complex_expression(right)
652 }
653 _ => false,
654 }
655 }
656
657 fn format_select_item(&self, result: &mut String, item: &SelectItem) {
658 match item {
659 SelectItem::Star { .. } => write!(result, "*").unwrap(),
660 SelectItem::StarExclude {
661 excluded_columns, ..
662 } => {
663 write!(
664 result,
665 "* {} ({})",
666 self.keyword("EXCLUDE"),
667 excluded_columns.join(", ")
668 )
669 .unwrap();
670 }
671 SelectItem::Column { column: col, .. } => write!(result, "{}", col.to_sql()).unwrap(),
672 SelectItem::Expression { expr, alias, .. } => {
673 write!(
674 result,
675 "{} {} {}",
676 self.format_expression(expr),
677 self.keyword("AS"),
678 alias
679 )
680 .unwrap();
681 }
682 }
683 }
684
685 fn format_expression(&self, expr: &SqlExpression) -> String {
686 match expr {
687 SqlExpression::Column(column_ref) => column_ref.to_sql(),
688 SqlExpression::StringLiteral(s) => format!("'{}'", s),
689 SqlExpression::NumberLiteral(n) => n.clone(),
690 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
691 SqlExpression::Null => self.keyword("NULL"),
692 SqlExpression::BinaryOp { left, op, right } => {
693 if op == "IS NULL" || op == "IS NOT NULL" {
695 format!("{} {}", self.format_expression(left), op)
696 } else {
697 format!(
698 "{} {} {}",
699 self.format_expression(left),
700 op,
701 self.format_expression(right)
702 )
703 }
704 }
705 SqlExpression::FunctionCall {
706 name,
707 args,
708 distinct,
709 } => {
710 let mut result = name.clone();
711 result.push('(');
712 if *distinct {
713 result.push_str(&self.keyword("DISTINCT"));
714 result.push(' ');
715 }
716 for (i, arg) in args.iter().enumerate() {
717 if i > 0 {
718 result.push_str(", ");
719 }
720 result.push_str(&self.format_expression(arg));
721 }
722 result.push(')');
723 result
724 }
725 SqlExpression::CaseExpression {
726 when_branches,
727 else_branch,
728 } => {
729 let mut result = String::new();
731 result.push_str(&self.keyword("CASE"));
732 result.push('\n');
733
734 for branch in when_branches {
736 result.push_str(" "); result.push_str(&format!(
738 "{} {} {} {}",
739 self.keyword("WHEN"),
740 self.format_expression(&branch.condition),
741 self.keyword("THEN"),
742 self.format_expression(&branch.result)
743 ));
744 result.push('\n');
745 }
746
747 if let Some(else_expr) = else_branch {
749 result.push_str(" "); result.push_str(&format!(
751 "{} {}",
752 self.keyword("ELSE"),
753 self.format_expression(else_expr)
754 ));
755 result.push('\n');
756 }
757
758 result.push_str(" "); result.push_str(&self.keyword("END"));
760 result
761 }
762 SqlExpression::SimpleCaseExpression {
763 expr,
764 when_branches,
765 else_branch,
766 } => {
767 let mut result = String::new();
769 result.push_str(&format!(
770 "{} {}",
771 self.keyword("CASE"),
772 self.format_expression(expr)
773 ));
774 result.push('\n');
775
776 for branch in when_branches {
778 result.push_str(" "); result.push_str(&format!(
780 "{} {} {} {}",
781 self.keyword("WHEN"),
782 self.format_expression(&branch.value),
783 self.keyword("THEN"),
784 self.format_expression(&branch.result)
785 ));
786 result.push('\n');
787 }
788
789 if let Some(else_expr) = else_branch {
791 result.push_str(" "); result.push_str(&format!(
793 "{} {}",
794 self.keyword("ELSE"),
795 self.format_expression(else_expr)
796 ));
797 result.push('\n');
798 }
799
800 result.push_str(" "); result.push_str(&self.keyword("END"));
802 result
803 }
804 SqlExpression::Between { expr, lower, upper } => {
805 format!(
806 "{} {} {} {} {}",
807 self.format_expression(expr),
808 self.keyword("BETWEEN"),
809 self.format_expression(lower),
810 self.keyword("AND"),
811 self.format_expression(upper)
812 )
813 }
814 SqlExpression::InList { expr, values } => {
815 let mut result =
816 format!("{} {} (", self.format_expression(expr), self.keyword("IN"));
817 for (i, val) in values.iter().enumerate() {
818 if i > 0 {
819 result.push_str(", ");
820 }
821 result.push_str(&self.format_expression(val));
822 }
823 result.push(')');
824 result
825 }
826 SqlExpression::NotInList { expr, values } => {
827 let mut result = format!(
828 "{} {} {} (",
829 self.format_expression(expr),
830 self.keyword("NOT"),
831 self.keyword("IN")
832 );
833 for (i, val) in values.iter().enumerate() {
834 if i > 0 {
835 result.push_str(", ");
836 }
837 result.push_str(&self.format_expression(val));
838 }
839 result.push(')');
840 result
841 }
842 SqlExpression::Not { expr } => {
843 format!("{} {}", self.keyword("NOT"), self.format_expression(expr))
844 }
845 SqlExpression::ScalarSubquery { query } => {
846 let subquery_str = self.format_select(query, 0);
848 if subquery_str.contains('\n') || subquery_str.len() > 60 {
849 format!("(\n{}\n)", self.format_select(query, 1))
851 } else {
852 format!("({})", subquery_str)
854 }
855 }
856 SqlExpression::InSubquery { expr, subquery } => {
857 let subquery_str = self.format_select(subquery, 0);
858 if subquery_str.contains('\n') || subquery_str.len() > 60 {
859 format!(
861 "{} {} (\n{}\n)",
862 self.format_expression(expr),
863 self.keyword("IN"),
864 self.format_select(subquery, 1)
865 )
866 } else {
867 format!(
869 "{} {} ({})",
870 self.format_expression(expr),
871 self.keyword("IN"),
872 subquery_str
873 )
874 }
875 }
876 SqlExpression::NotInSubquery { expr, subquery } => {
877 let subquery_str = self.format_select(subquery, 0);
878 if subquery_str.contains('\n') || subquery_str.len() > 60 {
879 format!(
881 "{} {} {} (\n{}\n)",
882 self.format_expression(expr),
883 self.keyword("NOT"),
884 self.keyword("IN"),
885 self.format_select(subquery, 1)
886 )
887 } else {
888 format!(
890 "{} {} {} ({})",
891 self.format_expression(expr),
892 self.keyword("NOT"),
893 self.keyword("IN"),
894 subquery_str
895 )
896 }
897 }
898 SqlExpression::MethodCall {
899 object,
900 method,
901 args,
902 } => {
903 let mut result = format!("{}.{}", object, method);
904 result.push('(');
905 for (i, arg) in args.iter().enumerate() {
906 if i > 0 {
907 result.push_str(", ");
908 }
909 result.push_str(&self.format_expression(arg));
910 }
911 result.push(')');
912 result
913 }
914 SqlExpression::ChainedMethodCall { base, method, args } => {
915 let mut result = format!("{}.{}", self.format_expression(base), method);
916 result.push('(');
917 for (i, arg) in args.iter().enumerate() {
918 if i > 0 {
919 result.push_str(", ");
920 }
921 result.push_str(&self.format_expression(arg));
922 }
923 result.push(')');
924 result
925 }
926 SqlExpression::WindowFunction {
927 name,
928 args,
929 window_spec,
930 } => {
931 let mut result = format!("{}(", name);
932
933 for (i, arg) in args.iter().enumerate() {
935 if i > 0 {
936 result.push_str(", ");
937 }
938 result.push_str(&self.format_expression(arg));
939 }
940 result.push_str(") ");
941 result.push_str(&self.keyword("OVER"));
942 result.push_str(" (");
943
944 if !window_spec.partition_by.is_empty() {
946 result.push_str(&self.keyword("PARTITION BY"));
947 result.push(' ');
948 for (i, col) in window_spec.partition_by.iter().enumerate() {
949 if i > 0 {
950 result.push_str(", ");
951 }
952 result.push_str(col);
953 }
954 }
955
956 if !window_spec.order_by.is_empty() {
958 if !window_spec.partition_by.is_empty() {
959 result.push(' ');
960 }
961 result.push_str(&self.keyword("ORDER BY"));
962 result.push(' ');
963 for (i, col) in window_spec.order_by.iter().enumerate() {
964 if i > 0 {
965 result.push_str(", ");
966 }
967 result.push_str(&self.format_expression(&col.expr));
968 match col.direction {
969 SortDirection::Asc => {
970 result.push(' ');
971 result.push_str(&self.keyword("ASC"));
972 }
973 SortDirection::Desc => {
974 result.push(' ');
975 result.push_str(&self.keyword("DESC"));
976 }
977 }
978 }
979 }
980
981 if let Some(frame) = &window_spec.frame {
983 if !window_spec.partition_by.is_empty() || !window_spec.order_by.is_empty() {
985 result.push(' ');
986 }
987
988 match frame.unit {
990 FrameUnit::Rows => result.push_str(&self.keyword("ROWS")),
991 FrameUnit::Range => result.push_str(&self.keyword("RANGE")),
992 }
993
994 result.push(' ');
995
996 if let Some(end) = &frame.end {
998 result.push_str(&self.keyword("BETWEEN"));
1000 result.push(' ');
1001 result.push_str(&self.format_frame_bound(&frame.start));
1002 result.push(' ');
1003 result.push_str(&self.keyword("AND"));
1004 result.push(' ');
1005 result.push_str(&self.format_frame_bound(end));
1006 } else {
1007 result.push_str(&self.format_frame_bound(&frame.start));
1009 }
1010 }
1011
1012 result.push(')');
1013 result
1014 }
1015 SqlExpression::DateTimeConstructor {
1016 year,
1017 month,
1018 day,
1019 hour,
1020 minute,
1021 second,
1022 } => {
1023 if let (Some(h), Some(m), Some(s)) = (hour, minute, second) {
1024 format!(
1025 "DateTime({}, {}, {}, {}, {}, {})",
1026 year, month, day, h, m, s
1027 )
1028 } else {
1029 format!("DateTime({}, {}, {})", year, month, day)
1030 }
1031 }
1032 SqlExpression::DateTimeToday {
1033 hour,
1034 minute,
1035 second,
1036 } => {
1037 if let (Some(h), Some(m), Some(s)) = (hour, minute, second) {
1038 format!("Today({}, {}, {})", h, m, s)
1039 } else {
1040 "Today()".to_string()
1041 }
1042 }
1043 _ => format!("{:?}", expr), }
1045 }
1046
1047 fn format_where_clause(
1048 &self,
1049 result: &mut String,
1050 where_clause: &WhereClause,
1051 indent_level: usize,
1052 ) {
1053 let needs_multiline = where_clause.conditions.len() > 1;
1054
1055 if needs_multiline {
1056 writeln!(result).unwrap();
1057 let indent = self.indent(indent_level + 1);
1058 for (i, condition) in where_clause.conditions.iter().enumerate() {
1059 if i > 0 {
1060 if let Some(ref connector) = where_clause.conditions[i - 1].connector {
1061 let connector_str = match connector {
1062 LogicalOp::And => self.keyword("AND"),
1063 LogicalOp::Or => self.keyword("OR"),
1064 };
1065 writeln!(result).unwrap();
1066 write!(result, "{}{} ", indent, connector_str).unwrap();
1067 }
1068 } else {
1069 write!(result, "{}", indent).unwrap();
1070 }
1071 write!(result, "{}", self.format_expression(&condition.expr)).unwrap();
1072 }
1073 } else if let Some(condition) = where_clause.conditions.first() {
1074 write!(result, " {}", self.format_expression(&condition.expr)).unwrap();
1075 }
1076 }
1077
1078 fn format_frame_bound(&self, bound: &FrameBound) -> String {
1079 match bound {
1080 FrameBound::UnboundedPreceding => self.keyword("UNBOUNDED PRECEDING"),
1081 FrameBound::CurrentRow => self.keyword("CURRENT ROW"),
1082 FrameBound::UnboundedFollowing => self.keyword("UNBOUNDED FOLLOWING"),
1083 FrameBound::Preceding(n) => format!("{} {}", n, self.keyword("PRECEDING")),
1084 FrameBound::Following(n) => format!("{} {}", n, self.keyword("FOLLOWING")),
1085 }
1086 }
1087
1088 fn format_join(&self, result: &mut String, join: &JoinClause, indent_level: usize) {
1089 let indent = self.indent(indent_level);
1090 let join_type = match join.join_type {
1091 JoinType::Inner => self.keyword("INNER JOIN"),
1092 JoinType::Left => self.keyword("LEFT JOIN"),
1093 JoinType::Right => self.keyword("RIGHT JOIN"),
1094 JoinType::Full => self.keyword("FULL JOIN"),
1095 JoinType::Cross => self.keyword("CROSS JOIN"),
1096 };
1097
1098 write!(result, "{}{} ", indent, join_type).unwrap();
1099
1100 match &join.table {
1101 TableSource::Table(name) => write!(result, "{}", name).unwrap(),
1102 TableSource::DerivedTable { query, alias } => {
1103 writeln!(result, "(").unwrap();
1104 let subquery_sql = self.format_select(query, indent_level + 1);
1105 write!(result, "{}", subquery_sql).unwrap();
1106 writeln!(result).unwrap();
1107 write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
1108 }
1109 TableSource::Pivot {
1110 source,
1111 aggregate,
1112 pivot_column,
1113 pivot_values,
1114 alias,
1115 } => {
1116 match source.as_ref() {
1118 TableSource::Table(name) => write!(result, "{}", name).unwrap(),
1119 TableSource::DerivedTable { query, alias } => {
1120 writeln!(result, "(").unwrap();
1121 let subquery_sql = self.format_select(query, indent_level + 1);
1122 write!(result, "{}", subquery_sql).unwrap();
1123 writeln!(result).unwrap();
1124 write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
1125 }
1126 TableSource::Pivot { .. } => {
1127 write!(result, "[NESTED_PIVOT]").unwrap();
1129 }
1130 }
1131
1132 write!(result, " {} (", self.keyword("PIVOT")).unwrap();
1134 write!(result, "{}({}) ", aggregate.function, aggregate.column).unwrap();
1135 write!(
1136 result,
1137 "{} {} {} (",
1138 self.keyword("FOR"),
1139 pivot_column,
1140 self.keyword("IN")
1141 )
1142 .unwrap();
1143 for (i, val) in pivot_values.iter().enumerate() {
1144 if i > 0 {
1145 write!(result, ", ").unwrap();
1146 }
1147 write!(result, "'{}'", val).unwrap();
1148 }
1149 write!(result, "))").unwrap();
1150
1151 if let Some(pivot_alias) = alias {
1152 write!(result, " {} {}", self.keyword("AS"), pivot_alias).unwrap();
1153 }
1154 }
1155 }
1156
1157 if let Some(ref alias) = join.alias {
1158 write!(result, " {} {}", self.keyword("AS"), alias).unwrap();
1159 }
1160
1161 if !join.condition.conditions.is_empty() {
1162 write!(result, " {}", self.keyword("ON")).unwrap();
1163 for (i, condition) in join.condition.conditions.iter().enumerate() {
1164 if i > 0 {
1165 write!(result, " {}", self.keyword("AND")).unwrap();
1166 }
1167 write!(
1168 result,
1169 " {} {} {}",
1170 self.format_expression(&condition.left_expr),
1171 self.format_join_operator(&condition.operator),
1172 self.format_expression(&condition.right_expr)
1173 )
1174 .unwrap();
1175 }
1176 }
1177 }
1178
1179 fn format_join_operator(&self, op: &JoinOperator) -> String {
1180 match op {
1181 JoinOperator::Equal => "=",
1182 JoinOperator::NotEqual => "!=",
1183 JoinOperator::LessThan => "<",
1184 JoinOperator::GreaterThan => ">",
1185 JoinOperator::LessThanOrEqual => "<=",
1186 JoinOperator::GreaterThanOrEqual => ">=",
1187 }
1188 .to_string()
1189 }
1190
1191 fn format_table_function(&self, result: &mut String, func: &TableFunction) {
1192 match func {
1193 TableFunction::Generator { name, args } => {
1194 write!(result, "{}(", self.keyword(&name.to_uppercase())).unwrap();
1195 for (i, arg) in args.iter().enumerate() {
1196 if i > 0 {
1197 write!(result, ", ").unwrap();
1198 }
1199 write!(result, "{}", self.format_expression(arg)).unwrap();
1200 }
1201 write!(result, ")").unwrap();
1202 }
1203 }
1204 }
1205}
1206
1207pub fn format_sql_ast(query: &str) -> Result<String, String> {
1209 use crate::sql::recursive_parser::Parser;
1210
1211 let mut parser = Parser::new(query);
1212 match parser.parse() {
1213 Ok(stmt) => Ok(format_select_statement(&stmt)),
1214 Err(e) => Err(format!("Parse error: {}", e)),
1215 }
1216}
1217
1218pub fn format_sql_ast_with_config(query: &str, config: &FormatConfig) -> Result<String, String> {
1220 use crate::sql::recursive_parser::Parser;
1221
1222 let mut parser = Parser::new(query);
1223 match parser.parse() {
1224 Ok(stmt) => Ok(format_select_with_config(&stmt, &config)),
1225 Err(e) => Err(format!("Parse error: {}", e)),
1226 }
1227}