1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for expr in group_by {
82 result.push_str(&format!("{indent_str} \"{:?}\",\n", expr));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 SqlExpression::ScalarSubquery { query: _ } => {
282 format!("ScalarSubquery {{ query: <SelectStatement> }}")
283 }
284 SqlExpression::InSubquery { expr, subquery: _ } => {
285 format!(
286 "InSubquery {{ expr: {}, subquery: <SelectStatement> }}",
287 format_expression_ast(expr)
288 )
289 }
290 SqlExpression::NotInSubquery { expr, subquery: _ } => {
291 format!(
292 "NotInSubquery {{ expr: {}, subquery: <SelectStatement> }}",
293 format_expression_ast(expr)
294 )
295 }
296 }
297}
298
/// Returns the substring of `text` between byte offsets `start` (inclusive) and `end`
/// (exclusive).
///
/// The result is an empty string when the range is empty, reversed, out of bounds, or does not
/// fall on UTF-8 character boundaries.
///
/// NOTE(review): callers pass positions taken from `Lexer::get_position`. This assumes those are
/// byte offsets into the query — confirm. Character indices would misalign on non-ASCII input.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    // `str::get` performs the bounds and char-boundary checks that plain indexing would panic on.
    text.get(start..end).map(str::to_owned).unwrap_or_default()
}
306
307fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
309 let mut lexer = Lexer::new(query);
310 let mut found_count = 0;
311
312 loop {
313 let pos = lexer.get_position();
314 let token = lexer.next_token();
315 if token == Token::Eof {
316 break;
317 }
318 if token == target {
319 if found_count == skip_count {
320 return Some(pos);
321 }
322 found_count += 1;
323 }
324 }
325 None
326}
327
328pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
329 let mut parser = Parser::new(query);
330 let stmt = match parser.parse() {
331 Ok(s) => s,
332 Err(_) => return vec![query.to_string()],
333 };
334
335 let mut lines = Vec::new();
336 let mut lexer = Lexer::new(query);
337 let mut tokens_with_pos = Vec::new();
338
339 loop {
341 let pos = lexer.get_position();
342 let token = lexer.next_token();
343 if token == Token::Eof {
344 break;
345 }
346 tokens_with_pos.push((token, pos));
347 }
348
349 let mut i = 0;
351 while i < tokens_with_pos.len() {
352 match &tokens_with_pos[i].0 {
353 Token::Select => {
354 let _select_start = tokens_with_pos[i].1;
355 i += 1;
356
357 let has_distinct = if i < tokens_with_pos.len() {
359 matches!(tokens_with_pos[i].0, Token::Distinct)
360 } else {
361 false
362 };
363
364 if has_distinct {
365 i += 1;
366 }
367
368 let _select_end = query.len();
370 let _col_count = 0;
371 let _current_line_cols: Vec<String> = Vec::new();
372 let mut all_select_lines = Vec::new();
373
374 let use_pretty_format = stmt.columns.len() > cols_per_line;
376
377 if use_pretty_format {
378 let select_text = if has_distinct {
380 "SELECT DISTINCT".to_string()
381 } else {
382 "SELECT".to_string()
383 };
384 all_select_lines.push(select_text);
385
386 for (idx, col) in stmt.columns.iter().enumerate() {
388 let is_last = idx == stmt.columns.len() - 1;
389 let col_text = if is_last {
390 format!(" {col}")
391 } else {
392 format!(" {col},")
393 };
394 all_select_lines.push(col_text);
395 }
396 } else {
397 let mut select_line = if has_distinct {
399 "SELECT DISTINCT ".to_string()
400 } else {
401 "SELECT ".to_string()
402 };
403
404 for (idx, col) in stmt.columns.iter().enumerate() {
405 if idx > 0 {
406 select_line.push_str(", ");
407 }
408 select_line.push_str(col);
409 }
410 all_select_lines.push(select_line);
411 }
412
413 lines.extend(all_select_lines);
414
415 while i < tokens_with_pos.len() {
417 match &tokens_with_pos[i].0 {
418 Token::From => break,
419 _ => i += 1,
420 }
421 }
422 }
423 Token::From => {
424 let from_start = tokens_with_pos[i].1;
425 i += 1;
426
427 let mut from_end = query.len();
429 while i < tokens_with_pos.len() {
430 match &tokens_with_pos[i].0 {
431 Token::Where
432 | Token::GroupBy
433 | Token::OrderBy
434 | Token::Limit
435 | Token::Having
436 | Token::Eof => {
437 from_end = tokens_with_pos[i].1;
438 break;
439 }
440 _ => i += 1,
441 }
442 }
443
444 let from_text = extract_text_between_positions(query, from_start, from_end);
445 lines.push(from_text.trim().to_string());
446 }
447 Token::Where => {
448 let where_start = tokens_with_pos[i].1;
449 i += 1;
450
451 let mut where_end = query.len();
453 let mut paren_depth = 0;
454 while i < tokens_with_pos.len() {
455 match &tokens_with_pos[i].0 {
456 Token::LeftParen => {
457 paren_depth += 1;
458 i += 1;
459 }
460 Token::RightParen => {
461 paren_depth -= 1;
462 i += 1;
463 }
464 Token::GroupBy
465 | Token::OrderBy
466 | Token::Limit
467 | Token::Having
468 | Token::Eof
469 if paren_depth == 0 =>
470 {
471 where_end = tokens_with_pos[i].1;
472 break;
473 }
474 _ => i += 1,
475 }
476 }
477
478 let where_text = extract_text_between_positions(query, where_start, where_end);
479 let formatted_where = format_where_clause_with_parens(where_text.trim());
480 lines.extend(formatted_where);
481 }
482 Token::GroupBy => {
483 let group_start = tokens_with_pos[i].1;
484 i += 1;
485
486 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
488 i += 1;
489 }
490
491 while i < tokens_with_pos.len() {
493 match &tokens_with_pos[i].0 {
494 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
495 _ => i += 1,
496 }
497 }
498
499 if i > 0 {
500 let group_text = extract_text_between_positions(
501 query,
502 group_start,
503 tokens_with_pos[i - 1].1,
504 );
505 lines.push(format!("GROUP BY {}", group_text.trim()));
506 }
507 }
508 _ => i += 1,
509 }
510 }
511
512 lines
513}
514
/// Renders an extracted WHERE clause as a single `WHERE <condition>` line.
///
/// A leading `WHERE` keyword in any ASCII casing (`WHERE`, `where`, `Where`, ...) is stripped
/// before the upper-case keyword is re-added, so it is never duplicated. It is only treated as
/// the keyword when it is not the start of a longer identifier such as `where_col`.
///
/// The condition text, including its parentheses and string literals, is kept verbatim apart
/// from surrounding whitespace. Always returns exactly one line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim_start();
    // `get` returns None when the 5th byte is not a char boundary, i.e. no ASCII keyword there.
    let condition = match text.get(..KEYWORD.len()) {
        Some(head)
            if head.eq_ignore_ascii_case(KEYWORD)
                && !text[KEYWORD.len()..]
                    .starts_with(|c: char| c.is_alphanumeric() || c == '_') =>
        {
            &text[KEYWORD.len()..]
        }
        _ => text,
    };

    // With an empty condition this yields just "WHERE" (no trailing space).
    vec![format!("WHERE {}", condition.trim()).trim_end().to_string()]
}
578
579#[must_use]
580pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
581 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
583
584 formatted
586 .into_iter()
587 .filter(|line| !line.trim().is_empty())
588 .map(|line| {
589 let mut result = line;
591 for keyword in &[
592 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
593 ] {
594 let pattern = format!("{keyword}");
595 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
596 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
597 }
598 }
599 result
600 })
601 .collect()
602}
603
604pub fn format_expression(expr: &SqlExpression) -> String {
605 match expr {
606 SqlExpression::Column(name) => name.clone(),
607 SqlExpression::StringLiteral(value) => format!("'{value}'"),
608 SqlExpression::NumberLiteral(value) => value.clone(),
609 SqlExpression::BinaryOp { left, op, right } => {
610 format!(
611 "{} {} {}",
612 format_expression(left),
613 op,
614 format_expression(right)
615 )
616 }
617 SqlExpression::FunctionCall {
618 name,
619 args,
620 distinct,
621 } => {
622 let args_str = args
623 .iter()
624 .map(format_expression)
625 .collect::<Vec<_>>()
626 .join(", ");
627 if *distinct {
628 format!("{name}(DISTINCT {args_str})")
629 } else {
630 format!("{name}({args_str})")
631 }
632 }
633 SqlExpression::MethodCall {
634 object,
635 method,
636 args,
637 } => {
638 let args_str = args
639 .iter()
640 .map(format_expression)
641 .collect::<Vec<_>>()
642 .join(", ");
643 if args.is_empty() {
644 format!("{object}.{method}()")
645 } else {
646 format!("{object}.{method}({args_str})")
647 }
648 }
649 SqlExpression::InList { expr, values } => {
650 let values_str = values
651 .iter()
652 .map(format_expression)
653 .collect::<Vec<_>>()
654 .join(", ");
655 format!("{} IN ({})", format_expression(expr), values_str)
656 }
657 SqlExpression::NotInList { expr, values } => {
658 let values_str = values
659 .iter()
660 .map(format_expression)
661 .collect::<Vec<_>>()
662 .join(", ");
663 format!("{} NOT IN ({})", format_expression(expr), values_str)
664 }
665 SqlExpression::Between { expr, lower, upper } => {
666 format!(
667 "{} BETWEEN {} AND {}",
668 format_expression(expr),
669 format_expression(lower),
670 format_expression(upper)
671 )
672 }
673 SqlExpression::Null => "NULL".to_string(),
674 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
675 SqlExpression::DateTimeConstructor {
676 year,
677 month,
678 day,
679 hour,
680 minute,
681 second,
682 } => {
683 let time_part = match (hour, minute, second) {
684 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
685 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
686 _ => String::new(),
687 };
688 format!("DATETIME({year}, {month}, {day}{time_part})")
689 }
690 SqlExpression::DateTimeToday {
691 hour,
692 minute,
693 second,
694 } => {
695 let time_part = match (hour, minute, second) {
696 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
697 (Some(h), Some(m), None) => format!(", {h}, {m}"),
698 (Some(h), None, None) => format!(", {h}"),
699 _ => String::new(),
700 };
701 format!("TODAY({time_part})")
702 }
703 SqlExpression::WindowFunction {
704 name,
705 args,
706 window_spec,
707 } => {
708 let args_str = args
709 .iter()
710 .map(format_expression)
711 .collect::<Vec<_>>()
712 .join(", ");
713
714 let mut result = format!("{name}({args_str}) OVER (");
715
716 if !window_spec.partition_by.is_empty() {
718 result.push_str("PARTITION BY ");
719 result.push_str(&window_spec.partition_by.join(", "));
720 }
721
722 if !window_spec.order_by.is_empty() {
724 if !window_spec.partition_by.is_empty() {
725 result.push(' ');
726 }
727 result.push_str("ORDER BY ");
728 let order_strs: Vec<String> = window_spec
729 .order_by
730 .iter()
731 .map(|col| {
732 let dir = match col.direction {
733 SortDirection::Asc => " ASC",
734 SortDirection::Desc => " DESC",
735 };
736 format!("{}{}", col.column, dir)
737 })
738 .collect();
739 result.push_str(&order_strs.join(", "));
740 }
741
742 result.push(')');
743 result
744 }
745 SqlExpression::ChainedMethodCall { base, method, args } => {
746 let base_str = format_expression(base);
747 let args_str = args
748 .iter()
749 .map(format_expression)
750 .collect::<Vec<_>>()
751 .join(", ");
752 if args.is_empty() {
753 format!("{base_str}.{method}()")
754 } else {
755 format!("{base_str}.{method}({args_str})")
756 }
757 }
758 SqlExpression::Not { expr } => {
759 format!("NOT {}", format_expression(expr))
760 }
761 SqlExpression::CaseExpression {
762 when_branches,
763 else_branch,
764 } => {
765 let mut result = String::from("CASE");
766 for branch in when_branches {
767 result.push_str(&format!(
768 " WHEN {} THEN {}",
769 format_expression(&branch.condition),
770 format_expression(&branch.result)
771 ));
772 }
773 if let Some(else_expr) = else_branch {
774 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
775 }
776 result.push_str(" END");
777 result
778 }
779 SqlExpression::ScalarSubquery { query: _ } => {
780 "(SELECT ...)".to_string()
782 }
783 SqlExpression::InSubquery { expr, subquery: _ } => {
784 format!("{} IN (SELECT ...)", format_expression(expr))
785 }
786 SqlExpression::NotInSubquery { expr, subquery: _ } => {
787 format!("{} NOT IN (SELECT ...)", format_expression(expr))
788 }
789 }
790}
791
792fn format_token(token: &Token) -> String {
793 match token {
794 Token::Identifier(s) => s.clone(),
795 Token::QuotedIdentifier(s) => format!("\"{s}\""),
796 Token::StringLiteral(s) => format!("'{s}'"),
797 Token::NumberLiteral(n) => n.clone(),
798 Token::DateTime => "DateTime".to_string(),
799 Token::Case => "CASE".to_string(),
800 Token::When => "WHEN".to_string(),
801 Token::Then => "THEN".to_string(),
802 Token::Else => "ELSE".to_string(),
803 Token::End => "END".to_string(),
804 Token::Distinct => "DISTINCT".to_string(),
805 Token::Over => "OVER".to_string(),
806 Token::Partition => "PARTITION".to_string(),
807 Token::By => "BY".to_string(),
808 Token::LeftParen => "(".to_string(),
809 Token::RightParen => ")".to_string(),
810 Token::Comma => ",".to_string(),
811 Token::Dot => ".".to_string(),
812 Token::Equal => "=".to_string(),
813 Token::NotEqual => "!=".to_string(),
814 Token::LessThan => "<".to_string(),
815 Token::GreaterThan => ">".to_string(),
816 Token::LessThanOrEqual => "<=".to_string(),
817 Token::GreaterThanOrEqual => ">=".to_string(),
818 Token::In => "IN".to_string(),
819 _ => format!("{token:?}").to_uppercase(),
820 }
821}