1use crate::sql::parser::ast::*;
8use std::fmt::Write;
9
/// Layout options for the AST-based SQL formatter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormatConfig {
    /// Indentation unit, repeated once per nesting level at the start of indented lines.
    pub indent: String,
    /// Maximum number of select-list entries (not counting `*`) kept on the
    /// `SELECT` line; longer lists are emitted one entry per line.
    pub items_per_line: usize,
    /// Emit SQL keywords in upper case when `true`, lower case otherwise.
    pub uppercase_keywords: bool,
    /// Compact-output flag.
    ///
    /// NOTE(review): the AST formatter in this module never reads this field;
    /// confirm whether other code consumes it before relying on it.
    pub compact: bool,
}
21
22impl Default for FormatConfig {
23 fn default() -> Self {
24 Self {
25 indent: " ".to_string(),
26 items_per_line: 5,
27 uppercase_keywords: true,
28 compact: false,
29 }
30 }
31}
32
33pub fn format_select_statement(stmt: &SelectStatement) -> String {
35 format_select_with_config(stmt, &FormatConfig::default())
36}
37
38pub fn format_select_with_config(stmt: &SelectStatement, config: &FormatConfig) -> String {
40 let formatter = AstFormatter::new(config);
41 formatter.format_select(stmt, 0)
42}
43
44struct AstFormatter<'a> {
45 config: &'a FormatConfig,
46}
47
48impl<'a> AstFormatter<'a> {
49 fn new(config: &'a FormatConfig) -> Self {
50 Self { config }
51 }
52
53 fn keyword(&self, word: &str) -> String {
54 if self.config.uppercase_keywords {
55 word.to_uppercase()
56 } else {
57 word.to_lowercase()
58 }
59 }
60
61 fn indent(&self, level: usize) -> String {
62 self.config.indent.repeat(level)
63 }
64
65 fn format_select(&self, stmt: &SelectStatement, indent_level: usize) -> String {
66 let mut result = String::new();
67 let indent = self.indent(indent_level);
68
69 if !stmt.ctes.is_empty() {
71 writeln!(&mut result, "{}{}", indent, self.keyword("WITH")).unwrap();
72 for (i, cte) in stmt.ctes.iter().enumerate() {
73 let is_last = i == stmt.ctes.len() - 1;
74 self.format_cte(&mut result, cte, indent_level + 1, is_last);
75 }
76 }
77
78 write!(&mut result, "{}{}", indent, self.keyword("SELECT")).unwrap();
80 if stmt.distinct {
81 write!(&mut result, " {}", self.keyword("DISTINCT")).unwrap();
82 }
83
84 if stmt.select_items.is_empty() && !stmt.columns.is_empty() {
86 self.format_column_list(&mut result, &stmt.columns, indent_level);
88 } else {
89 self.format_select_items(&mut result, &stmt.select_items, indent_level);
90 }
91
92 if let Some(ref table) = stmt.from_table {
94 writeln!(&mut result).unwrap();
95 write!(&mut result, "{}{} {}", indent, self.keyword("FROM"), table).unwrap();
96 } else if let Some(ref subquery) = stmt.from_subquery {
97 writeln!(&mut result).unwrap();
98 write!(&mut result, "{}{} (", indent, self.keyword("FROM")).unwrap();
99 writeln!(&mut result).unwrap();
100 let subquery_sql = self.format_select(subquery, indent_level + 1);
101 write!(&mut result, "{}", subquery_sql).unwrap();
102 write!(&mut result, "\n{}", indent).unwrap();
103 write!(&mut result, ")").unwrap();
104 if let Some(ref alias) = stmt.from_alias {
105 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
106 }
107 } else if let Some(ref func) = stmt.from_function {
108 writeln!(&mut result).unwrap();
109 write!(&mut result, "{}{} ", indent, self.keyword("FROM")).unwrap();
110 self.format_table_function(&mut result, func);
111 if let Some(ref alias) = stmt.from_alias {
112 write!(&mut result, " {} {}", self.keyword("AS"), alias).unwrap();
113 }
114 }
115
116 for join in &stmt.joins {
118 writeln!(&mut result).unwrap();
119 self.format_join(&mut result, join, indent_level);
120 }
121
122 if let Some(ref where_clause) = stmt.where_clause {
124 writeln!(&mut result).unwrap();
125 write!(&mut result, "{}{}", indent, self.keyword("WHERE")).unwrap();
126 self.format_where_clause(&mut result, where_clause, indent_level);
127 }
128
129 if let Some(ref group_by) = stmt.group_by {
131 writeln!(&mut result).unwrap();
132 write!(&mut result, "{}{} ", indent, self.keyword("GROUP BY")).unwrap();
133 for (i, expr) in group_by.iter().enumerate() {
134 if i > 0 {
135 write!(&mut result, ", ").unwrap();
136 }
137 write!(&mut result, "{}", self.format_expression(expr)).unwrap();
138 }
139 }
140
141 if let Some(ref having) = stmt.having {
143 writeln!(&mut result).unwrap();
144 write!(
145 &mut result,
146 "{}{} {}",
147 indent,
148 self.keyword("HAVING"),
149 self.format_expression(having)
150 )
151 .unwrap();
152 }
153
154 if let Some(ref order_by) = stmt.order_by {
156 writeln!(&mut result).unwrap();
157 write!(&mut result, "{}{} ", indent, self.keyword("ORDER BY")).unwrap();
158 for (i, col) in order_by.iter().enumerate() {
159 if i > 0 {
160 write!(&mut result, ", ").unwrap();
161 }
162 write!(&mut result, "{}", col.column).unwrap();
163 match col.direction {
164 SortDirection::Asc => write!(&mut result, " {}", self.keyword("ASC")).unwrap(),
165 SortDirection::Desc => {
166 write!(&mut result, " {}", self.keyword("DESC")).unwrap()
167 }
168 }
169 }
170 }
171
172 if let Some(limit) = stmt.limit {
174 writeln!(&mut result).unwrap();
175 write!(&mut result, "{}{} {}", indent, self.keyword("LIMIT"), limit).unwrap();
176 }
177
178 if let Some(offset) = stmt.offset {
180 writeln!(&mut result).unwrap();
181 write!(
182 &mut result,
183 "{}{} {}",
184 indent,
185 self.keyword("OFFSET"),
186 offset
187 )
188 .unwrap();
189 }
190
191 result
192 }
193
194 fn format_cte(&self, result: &mut String, cte: &CTE, indent_level: usize, is_last: bool) {
195 let indent = self.indent(indent_level);
196 write!(result, "{}{}", indent, cte.name).unwrap();
197
198 if let Some(ref columns) = cte.column_list {
199 write!(result, "(").unwrap();
200 for (i, col) in columns.iter().enumerate() {
201 if i > 0 {
202 write!(result, ", ").unwrap();
203 }
204 write!(result, "{}", col).unwrap();
205 }
206 write!(result, ")").unwrap();
207 }
208
209 writeln!(result, " {} (", self.keyword("AS")).unwrap();
210 let cte_sql = self.format_select(&cte.query, indent_level + 1);
211 write!(result, "{}", cte_sql).unwrap();
212 writeln!(result).unwrap();
213 write!(result, "{}", indent).unwrap();
214 if is_last {
215 writeln!(result, ")").unwrap();
216 } else {
217 writeln!(result, "),").unwrap();
218 }
219 }
220
221 fn format_column_list(&self, result: &mut String, columns: &[String], indent_level: usize) {
222 if columns.len() <= self.config.items_per_line {
223 write!(result, " ").unwrap();
225 for (i, col) in columns.iter().enumerate() {
226 if i > 0 {
227 write!(result, ", ").unwrap();
228 }
229 write!(result, "{}", col).unwrap();
230 }
231 } else {
232 writeln!(result).unwrap();
234 let indent = self.indent(indent_level + 1);
235 for (i, col) in columns.iter().enumerate() {
236 write!(result, "{}{}", indent, col).unwrap();
237 if i < columns.len() - 1 {
238 writeln!(result, ",").unwrap();
239 }
240 }
241 }
242 }
243
244 fn format_select_items(&self, result: &mut String, items: &[SelectItem], indent_level: usize) {
245 if items.is_empty() {
246 write!(result, " *").unwrap();
247 return;
248 }
249
250 let non_star_count = items
252 .iter()
253 .filter(|i| !matches!(i, SelectItem::Star))
254 .count();
255
256 if non_star_count <= self.config.items_per_line {
257 write!(result, " ").unwrap();
259 for (i, item) in items.iter().enumerate() {
260 if i > 0 {
261 write!(result, ", ").unwrap();
262 }
263 self.format_select_item(result, item);
264 }
265 } else {
266 writeln!(result).unwrap();
268 let indent = self.indent(indent_level + 1);
269 for (i, item) in items.iter().enumerate() {
270 write!(result, "{}", indent).unwrap();
271 self.format_select_item(result, item);
272 if i < items.len() - 1 {
273 writeln!(result, ",").unwrap();
274 }
275 }
276 }
277 }
278
279 fn format_select_item(&self, result: &mut String, item: &SelectItem) {
280 match item {
281 SelectItem::Star => write!(result, "*").unwrap(),
282 SelectItem::Column(col) => write!(result, "{}", col).unwrap(),
283 SelectItem::Expression { expr, alias } => {
284 write!(
285 result,
286 "{} {} {}",
287 self.format_expression(expr),
288 self.keyword("AS"),
289 alias
290 )
291 .unwrap();
292 }
293 }
294 }
295
296 fn format_expression(&self, expr: &SqlExpression) -> String {
297 match expr {
298 SqlExpression::Column(name) => name.clone(),
299 SqlExpression::StringLiteral(s) => format!("'{}'", s),
300 SqlExpression::NumberLiteral(n) => n.clone(),
301 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
302 SqlExpression::Null => self.keyword("NULL"),
303 SqlExpression::BinaryOp { left, op, right } => {
304 format!(
305 "{} {} {}",
306 self.format_expression(left),
307 op,
308 self.format_expression(right)
309 )
310 }
311 SqlExpression::FunctionCall {
312 name,
313 args,
314 distinct,
315 } => {
316 let mut result = name.clone();
317 result.push('(');
318 if *distinct {
319 result.push_str(&self.keyword("DISTINCT"));
320 result.push(' ');
321 }
322 for (i, arg) in args.iter().enumerate() {
323 if i > 0 {
324 result.push_str(", ");
325 }
326 result.push_str(&self.format_expression(arg));
327 }
328 result.push(')');
329 result
330 }
331 SqlExpression::CaseExpression {
332 when_branches,
333 else_branch,
334 } => {
335 let mut result = self.keyword("CASE");
336 for branch in when_branches {
337 result.push_str(&format!(
338 " {} {} {} {}",
339 self.keyword("WHEN"),
340 self.format_expression(&branch.condition),
341 self.keyword("THEN"),
342 self.format_expression(&branch.result)
343 ));
344 }
345 if let Some(else_expr) = else_branch {
346 result.push_str(&format!(
347 " {} {}",
348 self.keyword("ELSE"),
349 self.format_expression(else_expr)
350 ));
351 }
352 result.push_str(&format!(" {}", self.keyword("END")));
353 result
354 }
355 SqlExpression::Between { expr, lower, upper } => {
356 format!(
357 "{} {} {} {} {}",
358 self.format_expression(expr),
359 self.keyword("BETWEEN"),
360 self.format_expression(lower),
361 self.keyword("AND"),
362 self.format_expression(upper)
363 )
364 }
365 SqlExpression::InList { expr, values } => {
366 let mut result =
367 format!("{} {} (", self.format_expression(expr), self.keyword("IN"));
368 for (i, val) in values.iter().enumerate() {
369 if i > 0 {
370 result.push_str(", ");
371 }
372 result.push_str(&self.format_expression(val));
373 }
374 result.push(')');
375 result
376 }
377 SqlExpression::NotInList { expr, values } => {
378 let mut result = format!(
379 "{} {} {} (",
380 self.format_expression(expr),
381 self.keyword("NOT"),
382 self.keyword("IN")
383 );
384 for (i, val) in values.iter().enumerate() {
385 if i > 0 {
386 result.push_str(", ");
387 }
388 result.push_str(&self.format_expression(val));
389 }
390 result.push(')');
391 result
392 }
393 SqlExpression::Not { expr } => {
394 format!("{} {}", self.keyword("NOT"), self.format_expression(expr))
395 }
396 SqlExpression::ScalarSubquery { query } => {
397 let subquery_str = self.format_select(query, 0);
399 if subquery_str.contains('\n') || subquery_str.len() > 60 {
400 format!("(\n{}\n)", self.format_select(query, 1))
402 } else {
403 format!("({})", subquery_str)
405 }
406 }
407 SqlExpression::InSubquery { expr, subquery } => {
408 let subquery_str = self.format_select(subquery, 0);
409 if subquery_str.contains('\n') || subquery_str.len() > 60 {
410 format!(
412 "{} {} (\n{}\n)",
413 self.format_expression(expr),
414 self.keyword("IN"),
415 self.format_select(subquery, 1)
416 )
417 } else {
418 format!(
420 "{} {} ({})",
421 self.format_expression(expr),
422 self.keyword("IN"),
423 subquery_str
424 )
425 }
426 }
427 SqlExpression::NotInSubquery { expr, subquery } => {
428 let subquery_str = self.format_select(subquery, 0);
429 if subquery_str.contains('\n') || subquery_str.len() > 60 {
430 format!(
432 "{} {} {} (\n{}\n)",
433 self.format_expression(expr),
434 self.keyword("NOT"),
435 self.keyword("IN"),
436 self.format_select(subquery, 1)
437 )
438 } else {
439 format!(
441 "{} {} {} ({})",
442 self.format_expression(expr),
443 self.keyword("NOT"),
444 self.keyword("IN"),
445 subquery_str
446 )
447 }
448 }
449 SqlExpression::MethodCall {
450 object,
451 method,
452 args,
453 } => {
454 let mut result = format!("{}.{}", object, method);
455 result.push('(');
456 for (i, arg) in args.iter().enumerate() {
457 if i > 0 {
458 result.push_str(", ");
459 }
460 result.push_str(&self.format_expression(arg));
461 }
462 result.push(')');
463 result
464 }
465 SqlExpression::ChainedMethodCall { base, method, args } => {
466 let mut result = format!("{}.{}", self.format_expression(base), method);
467 result.push('(');
468 for (i, arg) in args.iter().enumerate() {
469 if i > 0 {
470 result.push_str(", ");
471 }
472 result.push_str(&self.format_expression(arg));
473 }
474 result.push(')');
475 result
476 }
477 _ => format!("{:?}", expr), }
479 }
480
481 fn format_where_clause(
482 &self,
483 result: &mut String,
484 where_clause: &WhereClause,
485 indent_level: usize,
486 ) {
487 let needs_multiline = where_clause.conditions.len() > 1;
488
489 if needs_multiline {
490 writeln!(result).unwrap();
491 let indent = self.indent(indent_level + 1);
492 for (i, condition) in where_clause.conditions.iter().enumerate() {
493 if i > 0 {
494 if let Some(ref connector) = where_clause.conditions[i - 1].connector {
495 let connector_str = match connector {
496 LogicalOp::And => self.keyword("AND"),
497 LogicalOp::Or => self.keyword("OR"),
498 };
499 writeln!(result).unwrap();
500 write!(result, "{}{} ", indent, connector_str).unwrap();
501 }
502 } else {
503 write!(result, "{}", indent).unwrap();
504 }
505 write!(result, "{}", self.format_expression(&condition.expr)).unwrap();
506 }
507 } else if let Some(condition) = where_clause.conditions.first() {
508 write!(result, " {}", self.format_expression(&condition.expr)).unwrap();
509 }
510 }
511
512 fn format_join(&self, result: &mut String, join: &JoinClause, indent_level: usize) {
513 let indent = self.indent(indent_level);
514 let join_type = match join.join_type {
515 JoinType::Inner => self.keyword("INNER JOIN"),
516 JoinType::Left => self.keyword("LEFT JOIN"),
517 JoinType::Right => self.keyword("RIGHT JOIN"),
518 JoinType::Full => self.keyword("FULL JOIN"),
519 JoinType::Cross => self.keyword("CROSS JOIN"),
520 };
521
522 write!(result, "{}{} ", indent, join_type).unwrap();
523
524 match &join.table {
525 TableSource::Table(name) => write!(result, "{}", name).unwrap(),
526 TableSource::DerivedTable { query, alias } => {
527 writeln!(result, "(").unwrap();
528 let subquery_sql = self.format_select(query, indent_level + 1);
529 write!(result, "{}", subquery_sql).unwrap();
530 writeln!(result).unwrap();
531 write!(result, "{}) {} {}", indent, self.keyword("AS"), alias).unwrap();
532 }
533 }
534
535 if let Some(ref alias) = join.alias {
536 write!(result, " {} {}", self.keyword("AS"), alias).unwrap();
537 }
538
539 write!(
540 result,
541 " {} {} {} {}",
542 self.keyword("ON"),
543 join.condition.left_column,
544 self.format_join_operator(&join.condition.operator),
545 join.condition.right_column
546 )
547 .unwrap();
548 }
549
550 fn format_join_operator(&self, op: &JoinOperator) -> String {
551 match op {
552 JoinOperator::Equal => "=",
553 JoinOperator::NotEqual => "!=",
554 JoinOperator::LessThan => "<",
555 JoinOperator::GreaterThan => ">",
556 JoinOperator::LessThanOrEqual => "<=",
557 JoinOperator::GreaterThanOrEqual => ">=",
558 }
559 .to_string()
560 }
561
562 fn format_table_function(&self, result: &mut String, func: &TableFunction) {
563 match func {
564 TableFunction::Range { start, end, step } => {
565 write!(result, "{}(", self.keyword("RANGE")).unwrap();
566 write!(
567 result,
568 "{}, {}",
569 self.format_expression(start),
570 self.format_expression(end)
571 )
572 .unwrap();
573 if let Some(step_expr) = step {
574 write!(result, ", {}", self.format_expression(step_expr)).unwrap();
575 }
576 write!(result, ")").unwrap();
577 }
578 }
579 }
580}
581
582pub fn format_sql_ast(query: &str) -> Result<String, String> {
584 use crate::sql::recursive_parser::Parser;
585
586 let mut parser = Parser::new(query);
587 match parser.parse() {
588 Ok(stmt) => Ok(format_select_statement(&stmt)),
589 Err(e) => Err(format!("Parse error: {}", e)),
590 }
591}
592
593pub fn format_sql_ast_with_config(query: &str, config: &FormatConfig) -> Result<String, String> {
595 use crate::sql::recursive_parser::Parser;
596
597 let mut parser = Parser::new(query);
598 match parser.parse() {
599 Ok(stmt) => Ok(format_select_with_config(&stmt, &config)),
600 Err(e) => Err(format!("Parse error: {}", e)),
601 }
602}