1use crate::sql::parser::ast::*;
8use std::fmt::Write;
9
/// Options controlling how the AST formatter lays out SQL text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormatConfig {
    /// Text emitted once per nesting level at the start of an indented line.
    pub indent: String,
    /// Largest plain column list that is still written on a single line;
    /// longer lists are broken into one column per line.
    pub items_per_line: usize,
    /// When `true` keywords are emitted in upper case, otherwise in lower case.
    pub uppercase_keywords: bool,
    /// Compact-output flag.
    /// NOTE(review): not read by the formatter code visible here — confirm
    /// whether it is consumed elsewhere.
    pub compact: bool,
}

impl Default for FormatConfig {
    /// Returns the stock configuration: upper-case keywords, at most five
    /// plain columns per line, non-compact output.
    fn default() -> Self {
        Self {
            indent: " ".to_string(),
            items_per_line: 5,
            uppercase_keywords: true,
            compact: false,
        }
    }
}
32
33pub fn format_select_statement(stmt: &SelectStatement) -> String {
35 format_select_with_config(stmt, &FormatConfig::default())
36}
37
38pub fn format_select_with_config(stmt: &SelectStatement, config: &FormatConfig) -> String {
40 let formatter = AstFormatter::new(config);
41 formatter.format_select(stmt, 0)
42}
43
44struct AstFormatter<'a> {
45 config: &'a FormatConfig,
46}
47
48impl<'a> AstFormatter<'a> {
49 fn new(config: &'a FormatConfig) -> Self {
50 Self { config }
51 }
52
53 fn keyword(&self, word: &str) -> String {
54 if self.config.uppercase_keywords {
55 word.to_uppercase()
56 } else {
57 word.to_lowercase()
58 }
59 }
60
61 fn indent(&self, level: usize) -> String {
62 self.config.indent.repeat(level)
63 }
64
65 fn format_select(&self, stmt: &SelectStatement, indent_level: usize) -> String {
66 let mut result = String::new();
67 let indent = self.indent(indent_level);
68
69 if !stmt.ctes.is_empty() {
71 writeln!(&mut result, "{}{}", indent, self.keyword("WITH")).unwrap();
72 for (i, cte) in stmt.ctes.iter().enumerate() {
73 let is_last = i == stmt.ctes.len() - 1;
74 self.format_cte(&mut result, cte, indent_level + 1, is_last);
75 }
76 }
77
78 write!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
80 if stmt.distinct {
81 write!(&mut result, " {}", self.keyword("DISTINCT")).unwrap();
82 }
83
84 if stmt.select_items.is_empty() && !stmt.columns.is_empty() {
86 self.format_column_list(&mut result, &stmt.columns, indent_level);
88 } else {
89 self.format_select_items(&mut result, &stmt.select_items, indent_level);
90 }
91
92 if let Some(ref table) = stmt.from_table {
94 writeln!(&mut result).unwrap();
95 write!(&mut result, "{}{} {}", indent, self.keyword("FROM"), table).unwrap();
96 } else if let Some(ref subquery) = stmt.from_subquery {
97 writeln!(&mut result).unwrap();
98 write!(&mut result, "{}{} (", indent, self.keyword("FROM")).unwrap();
99 writeln!(&mut result).unwrap();
100 let subquery_sql = self.format_select(subquery, indent_level + 1);
101 write!(&mut result, "{}", subquery_sql).unwrap();
102 write!(&mut result, "\n{}", indent).unwrap();
103 write!(&mut result, ")").unwrap();
104 if let Some(ref alias) = stmt.from_alias {
105 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
106 }
107 } else if let Some(ref func) = stmt.from_function {
108 writeln!(&mut result).unwrap();
109 write!(&mut result, "{}{} ", indent, self.keyword("FROM")).unwrap();
110 self.format_table_function(&mut result, func);
111 if let Some(ref alias) = stmt.from_alias {
112 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
113 }
114 }
115
116 for join in &stmt.joins {
118 writeln!(&mut result).unwrap();
119 self.format_join(&mut result, join, indent_level);
120 }
121
122 if let Some(ref where_clause) = stmt.where_clause {
124 writeln!(&mut result).unwrap();
125 write!(&mut result, "{}{}", indent, self.keyword("WHERE")).unwrap();
126 self.format_where_clause(&mut result, where_clause, indent_level);
127 }
128
129 if let Some(ref group_by) = stmt.group_by {
131 writeln!(&mut result).unwrap();
132 write!(&mut result, "{}{} ", indent, self.keyword("GROUP BY")).unwrap();
133 for (i, expr) in group_by.iter().enumerate() {
134 if i > 0 {
135 write!(&mut result, ", ").unwrap();
136 }
137 write!(&mut result, "{}", self.format_expression(expr)).unwrap();
138 }
139 }
140
141 if let Some(ref having) = stmt.having {
143 writeln!(&mut result).unwrap();
144 write!(
145 &mut result,
146 "{}{} {}",
147 indent,
148 self.keyword("HAVING"),
149 self.format_expression(having)
150 )
151 .unwrap();
152 }
153
154 if let Some(ref order_by) = stmt.order_by {
156 writeln!(&mut result).unwrap();
157 write!(&mut result, "{}{} ", indent, self.keyword("ORDER BY")).unwrap();
158 for (i, col) in order_by.iter().enumerate() {
159 if i > 0 {
160 write!(&mut result, ", ").unwrap();
161 }
162 write!(&mut result, "{}", col.column).unwrap();
163 match col.direction {
164 SortDirection::Asc => write!(&mut result, " {}", self.keyword("ASC")).unwrap(),
165 SortDirection::Desc => {
166 write!(&mut result, " {}", self.keyword("DESC")).unwrap()
167 }
168 }
169 }
170 }
171
172 if let Some(limit) = stmt.limit {
174 writeln!(&mut result).unwrap();
175 write!(&mut result, "{}{} {}", indent, self.keyword("LIMIT"), limit).unwrap();
176 }
177
178 if let Some(offset) = stmt.offset {
180 writeln!(&mut result).unwrap();
181 write!(
182 &mut result,
183 "{}{} {}",
184 indent,
185 self.keyword("OFFSET"),
186 offset
187 )
188 .unwrap();
189 }
190
191 result
192 }
193
194 fn format_cte(&self, result: &mut String, cte: &CTE, indent_level: usize, is_last: bool) {
195 let indent = self.indent(indent_level);
196
197 let is_web = matches!(&cte.cte_type, crate::sql::parser::ast::CTEType::Web(_));
199 if is_web {
200 write!(result, "{}{} {}", indent, self.keyword("WEB"), cte.name).unwrap();
201 } else {
202 write!(result, "{}{}", indent, cte.name).unwrap();
203 }
204
205 if let Some(ref columns) = cte.column_list {
206 write!(result, "(").unwrap();
207 for (i, col) in columns.iter().enumerate() {
208 if i > 0 {
209 write!(result, ", ").unwrap();
210 }
211 write!(result, "{}", col).unwrap();
212 }
213 write!(result, ")").unwrap();
214 }
215
216 writeln!(result, " {} (", self.keyword("AS")).unwrap();
217 let cte_sql = match &cte.cte_type {
218 crate::sql::parser::ast::CTEType::Standard(query) => {
219 self.format_select(query, indent_level + 1)
220 }
221 crate::sql::parser::ast::CTEType::Web(web_spec) => {
222 let mut web_str = format!(
224 "{}{} '{}'",
225 " ".repeat(indent_level + 1),
226 self.keyword("URL"),
227 web_spec.url
228 );
229
230 if let Some(format) = &web_spec.format {
231 web_str.push_str(&format!(
232 " {} {}",
233 self.keyword("FORMAT"),
234 match format {
235 crate::sql::parser::ast::DataFormat::CSV => "CSV",
236 crate::sql::parser::ast::DataFormat::JSON => "JSON",
237 crate::sql::parser::ast::DataFormat::Auto => "AUTO",
238 }
239 ));
240 }
241
242 if let Some(cache) = web_spec.cache_seconds {
243 web_str.push_str(&format!(" {} {}", self.keyword("CACHE"), cache));
244 }
245
246 if !web_spec.headers.is_empty() {
247 web_str.push_str(&format!(" {} (", self.keyword("HEADERS")));
248 for (i, (key, value)) in web_spec.headers.iter().enumerate() {
249 if i > 0 {
250 web_str.push_str(", ");
251 }
252 web_str.push_str(&format!("{} = '{}'", key, value));
253 }
254 web_str.push(')');
255 }
256
257 web_str
258 }
259 };
260 write!(result, "{}", cte_sql).unwrap();
261 writeln!(result).unwrap();
262 write!(result, "{}", indent).unwrap();
263 if is_last {
264 writeln!(result, ")").unwrap();
265 } else {
266 writeln!(result, "),").unwrap();
267 }
268 }
269
270 fn format_column_list(&self, result: &mut String, columns: &[String], indent_level: usize) {
271 if columns.len() <= self.config.items_per_line {
272 write!(result, " ").unwrap();
274 for (i, col) in columns.iter().enumerate() {
275 if i > 0 {
276 write!(result, ", ").unwrap();
277 }
278 write!(result, "{}", col).unwrap();
279 }
280 } else {
281 writeln!(result).unwrap();
283 let indent = self.indent(indent_level + 1);
284 for (i, col) in columns.iter().enumerate() {
285 write!(result, "{}{}", indent, col).unwrap();
286 if i < columns.len() - 1 {
287 writeln!(result, ",").unwrap();
288 }
289 }
290 }
291 }
292
293 fn format_select_items(&self, result: &mut String, items: &[SelectItem], indent_level: usize) {
294 if items.is_empty() {
295 write!(result, " *").unwrap();
296 return;
297 }
298
299 let non_star_count = items
301 .iter()
302 .filter(|i| !matches!(i, SelectItem::Star))
303 .count();
304
305 let has_complex_items = items.iter().any(|item| match item {
307 SelectItem::Expression { expr, .. } => self.is_complex_expression(expr),
308 _ => false,
309 });
310
311 let single_line_length: usize = items
313 .iter()
314 .map(|item| {
315 match item {
316 SelectItem::Star => 1,
317 SelectItem::Column(col) => col.len(),
318 SelectItem::Expression { expr, alias } => {
319 self.format_expression(expr).len() + 4 + alias.len() }
321 }
322 })
323 .sum::<usize>()
324 + (items.len() - 1) * 2; let use_single_line = match items.len() {
330 1 => !has_complex_items, 2..=3 => !has_complex_items && single_line_length < 40, _ => false, };
334
335 if !use_single_line {
336 writeln!(result).unwrap();
338 let indent = self.indent(indent_level + 1);
339 for (i, item) in items.iter().enumerate() {
340 write!(result, "{}", indent).unwrap();
341 self.format_select_item(result, item);
342 if i < items.len() - 1 {
343 writeln!(result, ",").unwrap();
344 }
345 }
346 } else {
347 write!(result, " ").unwrap();
349 for (i, item) in items.iter().enumerate() {
350 if i > 0 {
351 write!(result, ", ").unwrap();
352 }
353 self.format_select_item(result, item);
354 }
355 }
356 }
357
358 fn is_complex_expression(&self, expr: &SqlExpression) -> bool {
359 match expr {
360 SqlExpression::CaseExpression { .. } => true,
361 SqlExpression::FunctionCall { .. } => true,
362 SqlExpression::WindowFunction { .. } => true,
363 SqlExpression::ScalarSubquery { .. } => true,
364 SqlExpression::InSubquery { .. } => true,
365 SqlExpression::NotInSubquery { .. } => true,
366 SqlExpression::BinaryOp { left, right, .. } => {
367 self.is_complex_expression(left) || self.is_complex_expression(right)
368 }
369 _ => false,
370 }
371 }
372
373 fn format_select_item(&self, result: &mut String, item: &SelectItem) {
374 match item {
375 SelectItem::Star => write!(result, "*").unwrap(),
376 SelectItem::Column(col) => write!(result, "{}", col).unwrap(),
377 SelectItem::Expression { expr, alias } => {
378 write!(
379 result,
380 "{} {} {}",
381 self.format_expression(expr),
382 self.keyword("AS"),
383 alias
384 )
385 .unwrap();
386 }
387 }
388 }
389
390 fn format_expression(&self, expr: &SqlExpression) -> String {
391 match expr {
392 SqlExpression::Column(name) => name.clone(),
393 SqlExpression::StringLiteral(s) => format!("'{}'", s),
394 SqlExpression::NumberLiteral(n) => n.clone(),
395 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
396 SqlExpression::Null => self.keyword("NULL"),
397 SqlExpression::BinaryOp { left, op, right } => {
398 if op == "IS NULL" || op == "IS NOT NULL" {
400 format!("{} {}", self.format_expression(left), op)
401 } else {
402 format!(
403 "{} {} {}",
404 self.format_expression(left),
405 op,
406 self.format_expression(right)
407 )
408 }
409 }
410 SqlExpression::FunctionCall {
411 name,
412 args,
413 distinct,
414 } => {
415 let mut result = name.clone();
416 result.push('(');
417 if *distinct {
418 result.push_str(&self.keyword("DISTINCT"));
419 result.push(' ');
420 }
421 for (i, arg) in args.iter().enumerate() {
422 if i > 0 {
423 result.push_str(", ");
424 }
425 result.push_str(&self.format_expression(arg));
426 }
427 result.push(')');
428 result
429 }
430 SqlExpression::CaseExpression {
431 when_branches,
432 else_branch,
433 } => {
434 let mut result = String::new();
436 result.push_str(&self.keyword("CASE"));
437 result.push('\n');
438
439 for branch in when_branches {
441 result.push_str(" "); result.push_str(&format!(
443 "{} {} {} {}",
444 self.keyword("WHEN"),
445 self.format_expression(&branch.condition),
446 self.keyword("THEN"),
447 self.format_expression(&branch.result)
448 ));
449 result.push('\n');
450 }
451
452 if let Some(else_expr) = else_branch {
454 result.push_str(" "); result.push_str(&format!(
456 "{} {}",
457 self.keyword("ELSE"),
458 self.format_expression(else_expr)
459 ));
460 result.push('\n');
461 }
462
463 result.push_str(" "); result.push_str(&self.keyword("END"));
465 result
466 }
467 SqlExpression::Between { expr, lower, upper } => {
468 format!(
469 "{} {} {} {} {}",
470 self.format_expression(expr),
471 self.keyword("BETWEEN"),
472 self.format_expression(lower),
473 self.keyword("AND"),
474 self.format_expression(upper)
475 )
476 }
477 SqlExpression::InList { expr, values } => {
478 let mut result =
479 format!("{} {} (", self.format_expression(expr), self.keyword("IN"));
480 for (i, val) in values.iter().enumerate() {
481 if i > 0 {
482 result.push_str(", ");
483 }
484 result.push_str(&self.format_expression(val));
485 }
486 result.push(')');
487 result
488 }
489 SqlExpression::NotInList { expr, values } => {
490 let mut result = format!(
491 "{} {} {} (",
492 self.format_expression(expr),
493 self.keyword("NOT"),
494 self.keyword("IN")
495 );
496 for (i, val) in values.iter().enumerate() {
497 if i > 0 {
498 result.push_str(", ");
499 }
500 result.push_str(&self.format_expression(val));
501 }
502 result.push(')');
503 result
504 }
505 SqlExpression::Not { expr } => {
506 format!("{} {}", self.keyword("NOT"), self.format_expression(expr))
507 }
508 SqlExpression::ScalarSubquery { query } => {
509 let subquery_str = self.format_select(query, 0);
511 if subquery_str.contains('\n') || subquery_str.len() > 60 {
512 format!("(\n{}\n)", self.format_select(query, 1))
514 } else {
515 format!("({})", subquery_str)
517 }
518 }
519 SqlExpression::InSubquery { expr, subquery } => {
520 let subquery_str = self.format_select(subquery, 0);
521 if subquery_str.contains('\n') || subquery_str.len() > 60 {
522 format!(
524 "{} {} (\n{}\n)",
525 self.format_expression(expr),
526 self.keyword("IN"),
527 self.format_select(subquery, 1)
528 )
529 } else {
530 format!(
532 "{} {} ({})",
533 self.format_expression(expr),
534 self.keyword("IN"),
535 subquery_str
536 )
537 }
538 }
539 SqlExpression::NotInSubquery { expr, subquery } => {
540 let subquery_str = self.format_select(subquery, 0);
541 if subquery_str.contains('\n') || subquery_str.len() > 60 {
542 format!(
544 "{} {} {} (\n{}\n)",
545 self.format_expression(expr),
546 self.keyword("NOT"),
547 self.keyword("IN"),
548 self.format_select(subquery, 1)
549 )
550 } else {
551 format!(
553 "{} {} {} ({})",
554 self.format_expression(expr),
555 self.keyword("NOT"),
556 self.keyword("IN"),
557 subquery_str
558 )
559 }
560 }
561 SqlExpression::MethodCall {
562 object,
563 method,
564 args,
565 } => {
566 let mut result = format!("{}.{}", object, method);
567 result.push('(');
568 for (i, arg) in args.iter().enumerate() {
569 if i > 0 {
570 result.push_str(", ");
571 }
572 result.push_str(&self.format_expression(arg));
573 }
574 result.push(')');
575 result
576 }
577 SqlExpression::ChainedMethodCall { base, method, args } => {
578 let mut result = format!("{}.{}", self.format_expression(base), method);
579 result.push('(');
580 for (i, arg) in args.iter().enumerate() {
581 if i > 0 {
582 result.push_str(", ");
583 }
584 result.push_str(&self.format_expression(arg));
585 }
586 result.push(')');
587 result
588 }
589 SqlExpression::WindowFunction {
590 name,
591 args,
592 window_spec,
593 } => {
594 let mut result = format!("{}(", name);
595
596 for (i, arg) in args.iter().enumerate() {
598 if i > 0 {
599 result.push_str(", ");
600 }
601 result.push_str(&self.format_expression(arg));
602 }
603 result.push_str(") ");
604 result.push_str(&self.keyword("OVER"));
605 result.push_str(" (");
606
607 if !window_spec.partition_by.is_empty() {
609 result.push_str(&self.keyword("PARTITION BY"));
610 result.push(' ');
611 for (i, col) in window_spec.partition_by.iter().enumerate() {
612 if i > 0 {
613 result.push_str(", ");
614 }
615 result.push_str(col);
616 }
617 }
618
619 if !window_spec.order_by.is_empty() {
621 if !window_spec.partition_by.is_empty() {
622 result.push(' ');
623 }
624 result.push_str(&self.keyword("ORDER BY"));
625 result.push(' ');
626 for (i, col) in window_spec.order_by.iter().enumerate() {
627 if i > 0 {
628 result.push_str(", ");
629 }
630 result.push_str(&col.column);
631 match col.direction {
632 SortDirection::Asc => {
633 result.push(' ');
634 result.push_str(&self.keyword("ASC"));
635 }
636 SortDirection::Desc => {
637 result.push(' ');
638 result.push_str(&self.keyword("DESC"));
639 }
640 }
641 }
642 }
643
644 result.push(')');
645 result
646 }
647 _ => format!("{:?}", expr), }
649 }
650
651 fn format_where_clause(
652 &self,
653 result: &mut String,
654 where_clause: &WhereClause,
655 indent_level: usize,
656 ) {
657 let needs_multiline = where_clause.conditions.len() > 1;
658
659 if needs_multiline {
660 writeln!(result).unwrap();
661 let indent = self.indent(indent_level + 1);
662 for (i, condition) in where_clause.conditions.iter().enumerate() {
663 if i > 0 {
664 if let Some(ref connector) = where_clause.conditions[i - 1].connector {
665 let connector_str = match connector {
666 LogicalOp::And => self.keyword("AND"),
667 LogicalOp::Or => self.keyword("OR"),
668 };
669 writeln!(result).unwrap();
670 write!(result, "{}{} ", indent, connector_str).unwrap();
671 }
672 } else {
673 write!(result, "{}", indent).unwrap();
674 }
675 write!(result, "{}", self.format_expression(&condition.expr)).unwrap();
676 }
677 } else if let Some(condition) = where_clause.conditions.first() {
678 write!(result, " {}", self.format_expression(&condition.expr)).unwrap();
679 }
680 }
681
682 fn format_join(&self, result: &mut String, join: &JoinClause, indent_level: usize) {
683 let indent = self.indent(indent_level);
684 let join_type = match join.join_type {
685 JoinType::Inner => self.keyword("INNER JOIN"),
686 JoinType::Left => self.keyword("LEFT JOIN"),
687 JoinType::Right => self.keyword("RIGHT JOIN"),
688 JoinType::Full => self.keyword("FULL JOIN"),
689 JoinType::Cross => self.keyword("CROSS JOIN"),
690 };
691
692 write!(result, "{}{} ", indent, join_type).unwrap();
693
694 match &join.table {
695 TableSource::Table(name) => write!(result, "{}", name).unwrap(),
696 TableSource::DerivedTable { query, alias } => {
697 writeln!(result, "(").unwrap();
698 let subquery_sql = self.format_select(query, indent_level + 1);
699 write!(result, "{}", subquery_sql).unwrap();
700 writeln!(result).unwrap();
701 write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
702 }
703 }
704
705 if let Some(ref alias) = join.alias {
706 write!(result, " {} {}", self.keyword("AS"), alias).unwrap();
707 }
708
709 write!(
710 result,
711 " {} {} {} {}",
712 self.keyword("ON"),
713 join.condition.left_column,
714 self.format_join_operator(&join.condition.operator),
715 join.condition.right_column
716 )
717 .unwrap();
718 }
719
720 fn format_join_operator(&self, op: &JoinOperator) -> String {
721 match op {
722 JoinOperator::Equal => "=",
723 JoinOperator::NotEqual => "!=",
724 JoinOperator::LessThan => "<",
725 JoinOperator::GreaterThan => ">",
726 JoinOperator::LessThanOrEqual => "<=",
727 JoinOperator::GreaterThanOrEqual => ">=",
728 }
729 .to_string()
730 }
731
732 fn format_table_function(&self, result: &mut String, func: &TableFunction) {
733 match func {
734 TableFunction::Range { start, end, step } => {
735 write!(result, "{}(", self.keyword("RANGE")).unwrap();
736 write!(
737 result,
738 "{}, {}",
739 self.format_expression(start),
740 self.format_expression(end)
741 )
742 .unwrap();
743 if let Some(step_expr) = step {
744 write!(result, ", {}", self.format_expression(step_expr)).unwrap();
745 }
746 write!(result, ")").unwrap();
747 }
748 }
749 }
750}
751
752pub fn format_sql_ast(query: &str) -> Result<String, String> {
754 use crate::sql::recursive_parser::Parser;
755
756 let mut parser = Parser::new(query);
757 match parser.parse() {
758 Ok(stmt) => Ok(format_select_statement(&stmt)),
759 Err(e) => Err(format!("Parse error: {}", e)),
760 }
761}
762
763pub fn format_sql_ast_with_config(query: &str, config: &FormatConfig) -> Result<String, String> {
765 use crate::sql::recursive_parser::Parser;
766
767 let mut parser = Parser::new(query);
768 match parser.parse() {
769 Ok(stmt) => Ok(format_select_with_config(&stmt, &config)),
770 Err(e) => Err(format!("Parse error: {}", e)),
771 }
772}