1use crate::sql::parser::ast::*;
8use std::fmt::Write;
9
/// Layout options for [`format_select_with_config`].
///
/// Use [`FormatConfig::default`] for the standard layout and override
/// individual fields with struct-update syntax when needed.
#[derive(Debug, Clone)]
pub struct FormatConfig {
    /// Text repeated once per nesting level to indent clauses, list items and
    /// subqueries.
    pub indent: String,
    /// Maximum number of raw column names kept on the `SELECT` line before the
    /// list is broken one-per-line. Only consulted when the statement carries
    /// plain column names rather than parsed select items.
    pub items_per_line: usize,
    /// Emit keywords in upper case (`SELECT`) when `true`, lower case
    /// (`select`) otherwise.
    pub uppercase_keywords: bool,
    /// Compact-output flag.
    /// NOTE(review): not read by the formatter code in this module section —
    /// confirm whether it is honoured anywhere.
    pub compact: bool,
}

impl Default for FormatConfig {
    /// Upper-case keywords, single-unit indent, up to five raw columns per
    /// line, non-compact output.
    fn default() -> Self {
        Self {
            indent: " ".to_string(),
            items_per_line: 5,
            uppercase_keywords: true,
            compact: false,
        }
    }
}
32
33pub fn format_select_statement(stmt: &SelectStatement) -> String {
35 format_select_with_config(stmt, &FormatConfig::default())
36}
37
38pub fn format_select_with_config(stmt: &SelectStatement, config: &FormatConfig) -> String {
40 let formatter = AstFormatter::new(config);
41 formatter.format_select(stmt, 0)
42}
43
44struct AstFormatter<'a> {
45 config: &'a FormatConfig,
46}
47
48impl<'a> AstFormatter<'a> {
49 fn new(config: &'a FormatConfig) -> Self {
50 Self { config }
51 }
52
53 fn keyword(&self, word: &str) -> String {
54 if self.config.uppercase_keywords {
55 word.to_uppercase()
56 } else {
57 word.to_lowercase()
58 }
59 }
60
61 fn indent(&self, level: usize) -> String {
62 self.config.indent.repeat(level)
63 }
64
65 fn format_select(&self, stmt: &SelectStatement, indent_level: usize) -> String {
66 let mut result = String::new();
67 let indent = self.indent(indent_level);
68
69 if !stmt.ctes.is_empty() {
71 writeln!(&mut result, "{}{}", indent, self.keyword("WITH")).unwrap();
72 for (i, cte) in stmt.ctes.iter().enumerate() {
73 let is_last = i == stmt.ctes.len() - 1;
74 self.format_cte(&mut result, cte, indent_level + 1, is_last);
75 }
76 }
77
78 write!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
80 if stmt.distinct {
81 write!(&mut result, " {}", self.keyword("DISTINCT")).unwrap();
82 }
83
84 if stmt.select_items.is_empty() && !stmt.columns.is_empty() {
86 self.format_column_list(&mut result, &stmt.columns, indent_level);
88 } else {
89 self.format_select_items(&mut result, &stmt.select_items, indent_level);
90 }
91
92 if let Some(ref table) = stmt.from_table {
94 writeln!(&mut result).unwrap();
95 write!(&mut result, "{}{} {}", indent, self.keyword("FROM"), table).unwrap();
96 } else if let Some(ref subquery) = stmt.from_subquery {
97 writeln!(&mut result).unwrap();
98 write!(&mut result, "{}{} (", indent, self.keyword("FROM")).unwrap();
99 writeln!(&mut result).unwrap();
100 let subquery_sql = self.format_select(subquery, indent_level + 1);
101 write!(&mut result, "{}", subquery_sql).unwrap();
102 write!(&mut result, "\n{}", indent).unwrap();
103 write!(&mut result, ")").unwrap();
104 if let Some(ref alias) = stmt.from_alias {
105 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
106 }
107 } else if let Some(ref func) = stmt.from_function {
108 writeln!(&mut result).unwrap();
109 write!(&mut result, "{}{} ", indent, self.keyword("FROM")).unwrap();
110 self.format_table_function(&mut result, func);
111 if let Some(ref alias) = stmt.from_alias {
112 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
113 }
114 }
115
116 for join in &stmt.joins {
118 writeln!(&mut result).unwrap();
119 self.format_join(&mut result, join, indent_level);
120 }
121
122 if let Some(ref where_clause) = stmt.where_clause {
124 writeln!(&mut result).unwrap();
125 write!(&mut result, "{}{}", indent, self.keyword("WHERE")).unwrap();
126 self.format_where_clause(&mut result, where_clause, indent_level);
127 }
128
129 if let Some(ref group_by) = stmt.group_by {
131 writeln!(&mut result).unwrap();
132 write!(&mut result, "{}{} ", indent, self.keyword("GROUP BY")).unwrap();
133 for (i, expr) in group_by.iter().enumerate() {
134 if i > 0 {
135 write!(&mut result, ", ").unwrap();
136 }
137 write!(&mut result, "{}", self.format_expression(expr)).unwrap();
138 }
139 }
140
141 if let Some(ref having) = stmt.having {
143 writeln!(&mut result).unwrap();
144 write!(
145 &mut result,
146 "{}{} {}",
147 indent,
148 self.keyword("HAVING"),
149 self.format_expression(having)
150 )
151 .unwrap();
152 }
153
154 if let Some(ref order_by) = stmt.order_by {
156 writeln!(&mut result).unwrap();
157 write!(&mut result, "{}{} ", indent, self.keyword("ORDER BY")).unwrap();
158 for (i, col) in order_by.iter().enumerate() {
159 if i > 0 {
160 write!(&mut result, ", ").unwrap();
161 }
162 write!(&mut result, "{}", col.column).unwrap();
163 match col.direction {
164 SortDirection::Asc => write!(&mut result, " {}", self.keyword("ASC")).unwrap(),
165 SortDirection::Desc => {
166 write!(&mut result, " {}", self.keyword("DESC")).unwrap()
167 }
168 }
169 }
170 }
171
172 if let Some(limit) = stmt.limit {
174 writeln!(&mut result).unwrap();
175 write!(&mut result, "{}{} {}", indent, self.keyword("LIMIT"), limit).unwrap();
176 }
177
178 if let Some(offset) = stmt.offset {
180 writeln!(&mut result).unwrap();
181 write!(
182 &mut result,
183 "{}{} {}",
184 indent,
185 self.keyword("OFFSET"),
186 offset
187 )
188 .unwrap();
189 }
190
191 result
192 }
193
194 fn format_cte(&self, result: &mut String, cte: &CTE, indent_level: usize, is_last: bool) {
195 let indent = self.indent(indent_level);
196
197 let is_web = matches!(&cte.cte_type, crate::sql::parser::ast::CTEType::Web(_));
199 if is_web {
200 write!(result, "{}{} {}", indent, self.keyword("WEB"), cte.name).unwrap();
201 } else {
202 write!(result, "{}{}", indent, cte.name).unwrap();
203 }
204
205 if let Some(ref columns) = cte.column_list {
206 write!(result, "(").unwrap();
207 for (i, col) in columns.iter().enumerate() {
208 if i > 0 {
209 write!(result, ", ").unwrap();
210 }
211 write!(result, "{}", col).unwrap();
212 }
213 write!(result, ")").unwrap();
214 }
215
216 writeln!(result, " {} (", self.keyword("AS")).unwrap();
217 let cte_sql = match &cte.cte_type {
218 crate::sql::parser::ast::CTEType::Standard(query) => {
219 self.format_select(query, indent_level + 1)
220 }
221 crate::sql::parser::ast::CTEType::Web(web_spec) => {
222 let mut web_str = format!(
224 "{}{} '{}'",
225 " ".repeat(indent_level + 1),
226 self.keyword("URL"),
227 web_spec.url
228 );
229
230 if let Some(method) = &web_spec.method {
232 web_str.push_str(&format!(
233 " {} {}",
234 self.keyword("METHOD"),
235 match method {
236 crate::sql::parser::ast::HttpMethod::GET => "GET",
237 crate::sql::parser::ast::HttpMethod::POST => "POST",
238 crate::sql::parser::ast::HttpMethod::PUT => "PUT",
239 crate::sql::parser::ast::HttpMethod::DELETE => "DELETE",
240 crate::sql::parser::ast::HttpMethod::PATCH => "PATCH",
241 }
242 ));
243 }
244
245 if let Some(body) = &web_spec.body {
247 web_str.push_str(&format!(" {} '{}'", self.keyword("BODY"), body));
248 }
249
250 if let Some(format) = &web_spec.format {
252 web_str.push_str(&format!(
253 " {} {}",
254 self.keyword("FORMAT"),
255 match format {
256 crate::sql::parser::ast::DataFormat::CSV => "CSV",
257 crate::sql::parser::ast::DataFormat::JSON => "JSON",
258 crate::sql::parser::ast::DataFormat::Auto => "AUTO",
259 }
260 ));
261 }
262
263 if let Some(json_path) = &web_spec.json_path {
265 web_str.push_str(&format!(" {} '{}'", self.keyword("JSON_PATH"), json_path));
266 }
267
268 if let Some(cache) = web_spec.cache_seconds {
270 web_str.push_str(&format!(" {} {}", self.keyword("CACHE"), cache));
271 }
272
273 if !web_spec.headers.is_empty() {
275 web_str.push_str(&format!(" {} (", self.keyword("HEADERS")));
276 for (i, (key, value)) in web_spec.headers.iter().enumerate() {
277 if i > 0 {
278 web_str.push_str(", ");
279 }
280 web_str.push_str(&format!("'{}': '{}'", key, value));
281 }
282 web_str.push(')');
283 }
284
285 web_str
286 }
287 };
288 write!(result, "{}", cte_sql).unwrap();
289 writeln!(result).unwrap();
290 write!(result, "{}", indent).unwrap();
291 if is_last {
292 writeln!(result, ")").unwrap();
293 } else {
294 writeln!(result, "),").unwrap();
295 }
296 }
297
298 fn format_column_list(&self, result: &mut String, columns: &[String], indent_level: usize) {
299 if columns.len() <= self.config.items_per_line {
300 write!(result, " ").unwrap();
302 for (i, col) in columns.iter().enumerate() {
303 if i > 0 {
304 write!(result, ", ").unwrap();
305 }
306 write!(result, "{}", col).unwrap();
307 }
308 } else {
309 writeln!(result).unwrap();
311 let indent = self.indent(indent_level + 1);
312 for (i, col) in columns.iter().enumerate() {
313 write!(result, "{}{}", indent, col).unwrap();
314 if i < columns.len() - 1 {
315 writeln!(result, ",").unwrap();
316 }
317 }
318 }
319 }
320
321 fn format_select_items(&self, result: &mut String, items: &[SelectItem], indent_level: usize) {
322 if items.is_empty() {
323 write!(result, " *").unwrap();
324 return;
325 }
326
327 let non_star_count = items
329 .iter()
330 .filter(|i| !matches!(i, SelectItem::Star))
331 .count();
332
333 let has_complex_items = items.iter().any(|item| match item {
335 SelectItem::Expression { expr, .. } => self.is_complex_expression(expr),
336 _ => false,
337 });
338
339 let single_line_length: usize = items
341 .iter()
342 .map(|item| {
343 match item {
344 SelectItem::Star => 1,
345 SelectItem::Column(col) => col.name.len(),
346 SelectItem::Expression { expr, alias } => {
347 self.format_expression(expr).len() + 4 + alias.len() }
349 }
350 })
351 .sum::<usize>()
352 + (items.len() - 1) * 2; let use_single_line = match items.len() {
358 1 => !has_complex_items, 2..=3 => !has_complex_items && single_line_length < 40, _ => false, };
362
363 if !use_single_line {
364 writeln!(result).unwrap();
366 let indent = self.indent(indent_level + 1);
367 for (i, item) in items.iter().enumerate() {
368 write!(result, "{}", indent).unwrap();
369 self.format_select_item(result, item);
370 if i < items.len() - 1 {
371 writeln!(result, ",").unwrap();
372 }
373 }
374 } else {
375 write!(result, " ").unwrap();
377 for (i, item) in items.iter().enumerate() {
378 if i > 0 {
379 write!(result, ", ").unwrap();
380 }
381 self.format_select_item(result, item);
382 }
383 }
384 }
385
386 fn is_complex_expression(&self, expr: &SqlExpression) -> bool {
387 match expr {
388 SqlExpression::CaseExpression { .. } => true,
389 SqlExpression::FunctionCall { .. } => true,
390 SqlExpression::WindowFunction { .. } => true,
391 SqlExpression::ScalarSubquery { .. } => true,
392 SqlExpression::InSubquery { .. } => true,
393 SqlExpression::NotInSubquery { .. } => true,
394 SqlExpression::BinaryOp { left, right, .. } => {
395 self.is_complex_expression(left) || self.is_complex_expression(right)
396 }
397 _ => false,
398 }
399 }
400
401 fn format_select_item(&self, result: &mut String, item: &SelectItem) {
402 match item {
403 SelectItem::Star => write!(result, "*").unwrap(),
404 SelectItem::Column(col) => write!(result, "{}", col.to_sql()).unwrap(),
405 SelectItem::Expression { expr, alias } => {
406 write!(
407 result,
408 "{} {} {}",
409 self.format_expression(expr),
410 self.keyword("AS"),
411 alias
412 )
413 .unwrap();
414 }
415 }
416 }
417
418 fn format_expression(&self, expr: &SqlExpression) -> String {
419 match expr {
420 SqlExpression::Column(column_ref) => column_ref.to_sql(),
421 SqlExpression::StringLiteral(s) => format!("'{}'", s),
422 SqlExpression::NumberLiteral(n) => n.clone(),
423 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
424 SqlExpression::Null => self.keyword("NULL"),
425 SqlExpression::BinaryOp { left, op, right } => {
426 if op == "IS NULL" || op == "IS NOT NULL" {
428 format!("{} {}", self.format_expression(left), op)
429 } else {
430 format!(
431 "{} {} {}",
432 self.format_expression(left),
433 op,
434 self.format_expression(right)
435 )
436 }
437 }
438 SqlExpression::FunctionCall {
439 name,
440 args,
441 distinct,
442 } => {
443 let mut result = name.clone();
444 result.push('(');
445 if *distinct {
446 result.push_str(&self.keyword("DISTINCT"));
447 result.push(' ');
448 }
449 for (i, arg) in args.iter().enumerate() {
450 if i > 0 {
451 result.push_str(", ");
452 }
453 result.push_str(&self.format_expression(arg));
454 }
455 result.push(')');
456 result
457 }
458 SqlExpression::CaseExpression {
459 when_branches,
460 else_branch,
461 } => {
462 let mut result = String::new();
464 result.push_str(&self.keyword("CASE"));
465 result.push('\n');
466
467 for branch in when_branches {
469 result.push_str(" "); result.push_str(&format!(
471 "{} {} {} {}",
472 self.keyword("WHEN"),
473 self.format_expression(&branch.condition),
474 self.keyword("THEN"),
475 self.format_expression(&branch.result)
476 ));
477 result.push('\n');
478 }
479
480 if let Some(else_expr) = else_branch {
482 result.push_str(" "); result.push_str(&format!(
484 "{} {}",
485 self.keyword("ELSE"),
486 self.format_expression(else_expr)
487 ));
488 result.push('\n');
489 }
490
491 result.push_str(" "); result.push_str(&self.keyword("END"));
493 result
494 }
495 SqlExpression::Between { expr, lower, upper } => {
496 format!(
497 "{} {} {} {} {}",
498 self.format_expression(expr),
499 self.keyword("BETWEEN"),
500 self.format_expression(lower),
501 self.keyword("AND"),
502 self.format_expression(upper)
503 )
504 }
505 SqlExpression::InList { expr, values } => {
506 let mut result =
507 format!("{} {} (", self.format_expression(expr), self.keyword("IN"));
508 for (i, val) in values.iter().enumerate() {
509 if i > 0 {
510 result.push_str(", ");
511 }
512 result.push_str(&self.format_expression(val));
513 }
514 result.push(')');
515 result
516 }
517 SqlExpression::NotInList { expr, values } => {
518 let mut result = format!(
519 "{} {} {} (",
520 self.format_expression(expr),
521 self.keyword("NOT"),
522 self.keyword("IN")
523 );
524 for (i, val) in values.iter().enumerate() {
525 if i > 0 {
526 result.push_str(", ");
527 }
528 result.push_str(&self.format_expression(val));
529 }
530 result.push(')');
531 result
532 }
533 SqlExpression::Not { expr } => {
534 format!("{} {}", self.keyword("NOT"), self.format_expression(expr))
535 }
536 SqlExpression::ScalarSubquery { query } => {
537 let subquery_str = self.format_select(query, 0);
539 if subquery_str.contains('\n') || subquery_str.len() > 60 {
540 format!("(\n{}\n)", self.format_select(query, 1))
542 } else {
543 format!("({})", subquery_str)
545 }
546 }
547 SqlExpression::InSubquery { expr, subquery } => {
548 let subquery_str = self.format_select(subquery, 0);
549 if subquery_str.contains('\n') || subquery_str.len() > 60 {
550 format!(
552 "{} {} (\n{}\n)",
553 self.format_expression(expr),
554 self.keyword("IN"),
555 self.format_select(subquery, 1)
556 )
557 } else {
558 format!(
560 "{} {} ({})",
561 self.format_expression(expr),
562 self.keyword("IN"),
563 subquery_str
564 )
565 }
566 }
567 SqlExpression::NotInSubquery { expr, subquery } => {
568 let subquery_str = self.format_select(subquery, 0);
569 if subquery_str.contains('\n') || subquery_str.len() > 60 {
570 format!(
572 "{} {} {} (\n{}\n)",
573 self.format_expression(expr),
574 self.keyword("NOT"),
575 self.keyword("IN"),
576 self.format_select(subquery, 1)
577 )
578 } else {
579 format!(
581 "{} {} {} ({})",
582 self.format_expression(expr),
583 self.keyword("NOT"),
584 self.keyword("IN"),
585 subquery_str
586 )
587 }
588 }
589 SqlExpression::MethodCall {
590 object,
591 method,
592 args,
593 } => {
594 let mut result = format!("{}.{}", object, method);
595 result.push('(');
596 for (i, arg) in args.iter().enumerate() {
597 if i > 0 {
598 result.push_str(", ");
599 }
600 result.push_str(&self.format_expression(arg));
601 }
602 result.push(')');
603 result
604 }
605 SqlExpression::ChainedMethodCall { base, method, args } => {
606 let mut result = format!("{}.{}", self.format_expression(base), method);
607 result.push('(');
608 for (i, arg) in args.iter().enumerate() {
609 if i > 0 {
610 result.push_str(", ");
611 }
612 result.push_str(&self.format_expression(arg));
613 }
614 result.push(')');
615 result
616 }
617 SqlExpression::WindowFunction {
618 name,
619 args,
620 window_spec,
621 } => {
622 let mut result = format!("{}(", name);
623
624 for (i, arg) in args.iter().enumerate() {
626 if i > 0 {
627 result.push_str(", ");
628 }
629 result.push_str(&self.format_expression(arg));
630 }
631 result.push_str(") ");
632 result.push_str(&self.keyword("OVER"));
633 result.push_str(" (");
634
635 if !window_spec.partition_by.is_empty() {
637 result.push_str(&self.keyword("PARTITION BY"));
638 result.push(' ');
639 for (i, col) in window_spec.partition_by.iter().enumerate() {
640 if i > 0 {
641 result.push_str(", ");
642 }
643 result.push_str(col);
644 }
645 }
646
647 if !window_spec.order_by.is_empty() {
649 if !window_spec.partition_by.is_empty() {
650 result.push(' ');
651 }
652 result.push_str(&self.keyword("ORDER BY"));
653 result.push(' ');
654 for (i, col) in window_spec.order_by.iter().enumerate() {
655 if i > 0 {
656 result.push_str(", ");
657 }
658 result.push_str(&col.column);
659 match col.direction {
660 SortDirection::Asc => {
661 result.push(' ');
662 result.push_str(&self.keyword("ASC"));
663 }
664 SortDirection::Desc => {
665 result.push(' ');
666 result.push_str(&self.keyword("DESC"));
667 }
668 }
669 }
670 }
671
672 result.push(')');
673 result
674 }
675 _ => format!("{:?}", expr), }
677 }
678
679 fn format_where_clause(
680 &self,
681 result: &mut String,
682 where_clause: &WhereClause,
683 indent_level: usize,
684 ) {
685 let needs_multiline = where_clause.conditions.len() > 1;
686
687 if needs_multiline {
688 writeln!(result).unwrap();
689 let indent = self.indent(indent_level + 1);
690 for (i, condition) in where_clause.conditions.iter().enumerate() {
691 if i > 0 {
692 if let Some(ref connector) = where_clause.conditions[i - 1].connector {
693 let connector_str = match connector {
694 LogicalOp::And => self.keyword("AND"),
695 LogicalOp::Or => self.keyword("OR"),
696 };
697 writeln!(result).unwrap();
698 write!(result, "{}{} ", indent, connector_str).unwrap();
699 }
700 } else {
701 write!(result, "{}", indent).unwrap();
702 }
703 write!(result, "{}", self.format_expression(&condition.expr)).unwrap();
704 }
705 } else if let Some(condition) = where_clause.conditions.first() {
706 write!(result, " {}", self.format_expression(&condition.expr)).unwrap();
707 }
708 }
709
710 fn format_join(&self, result: &mut String, join: &JoinClause, indent_level: usize) {
711 let indent = self.indent(indent_level);
712 let join_type = match join.join_type {
713 JoinType::Inner => self.keyword("INNER JOIN"),
714 JoinType::Left => self.keyword("LEFT JOIN"),
715 JoinType::Right => self.keyword("RIGHT JOIN"),
716 JoinType::Full => self.keyword("FULL JOIN"),
717 JoinType::Cross => self.keyword("CROSS JOIN"),
718 };
719
720 write!(result, "{}{} ", indent, join_type).unwrap();
721
722 match &join.table {
723 TableSource::Table(name) => write!(result, "{}", name).unwrap(),
724 TableSource::DerivedTable { query, alias } => {
725 writeln!(result, "(").unwrap();
726 let subquery_sql = self.format_select(query, indent_level + 1);
727 write!(result, "{}", subquery_sql).unwrap();
728 writeln!(result).unwrap();
729 write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
730 }
731 }
732
733 if let Some(ref alias) = join.alias {
734 write!(result, " {} {}", self.keyword("AS"), alias).unwrap();
735 }
736
737 if !join.condition.conditions.is_empty() {
738 write!(result, " {}", self.keyword("ON")).unwrap();
739 for (i, condition) in join.condition.conditions.iter().enumerate() {
740 if i > 0 {
741 write!(result, " {}", self.keyword("AND")).unwrap();
742 }
743 write!(
744 result,
745 " {} {} {}",
746 condition.left_column,
747 self.format_join_operator(&condition.operator),
748 condition.right_column
749 )
750 .unwrap();
751 }
752 }
753 }
754
755 fn format_join_operator(&self, op: &JoinOperator) -> String {
756 match op {
757 JoinOperator::Equal => "=",
758 JoinOperator::NotEqual => "!=",
759 JoinOperator::LessThan => "<",
760 JoinOperator::GreaterThan => ">",
761 JoinOperator::LessThanOrEqual => "<=",
762 JoinOperator::GreaterThanOrEqual => ">=",
763 }
764 .to_string()
765 }
766
767 fn format_table_function(&self, result: &mut String, func: &TableFunction) {
768 match func {
769 TableFunction::Generator { name, args } => {
770 write!(result, "{}(", self.keyword(&name.to_uppercase())).unwrap();
771 for (i, arg) in args.iter().enumerate() {
772 if i > 0 {
773 write!(result, ", ").unwrap();
774 }
775 write!(result, "{}", self.format_expression(arg)).unwrap();
776 }
777 write!(result, ")").unwrap();
778 }
779 }
780 }
781}
782
783pub fn format_sql_ast(query: &str) -> Result<String, String> {
785 use crate::sql::recursive_parser::Parser;
786
787 let mut parser = Parser::new(query);
788 match parser.parse() {
789 Ok(stmt) => Ok(format_select_statement(&stmt)),
790 Err(e) => Err(format!("Parse error: {}", e)),
791 }
792}
793
794pub fn format_sql_ast_with_config(query: &str, config: &FormatConfig) -> Result<String, String> {
796 use crate::sql::recursive_parser::Parser;
797
798 let mut parser = Parser::new(query);
799 match parser.parse() {
800 Ok(stmt) => Ok(format_select_with_config(&stmt, &config)),
801 Err(e) => Err(format!("Parse error: {}", e)),
802 }
803}