1use super::ast::{LogicalOp, SelectStatement, SortDirection, SqlExpression, WhereClause};
5use super::lexer::{Lexer, Token};
6use crate::sql::recursive_parser::Parser;
7
8#[must_use]
9pub fn format_sql_pretty(query: &str) -> Vec<String> {
10 format_sql_pretty_compact(query, 5) }
12
13#[must_use]
15pub fn format_ast_tree(query: &str) -> String {
16 let mut parser = Parser::new(query);
17 match parser.parse() {
18 Ok(stmt) => format_select_statement(&stmt, 0),
19 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
20 }
21}
22
23fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
24 let mut result = String::new();
25 let indent_str = " ".repeat(indent);
26
27 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
28
29 result.push_str(&format!("{indent_str} columns: ["));
31 if stmt.columns.is_empty() {
32 result.push_str("],\n");
33 } else {
34 result.push('\n');
35 for col in &stmt.columns {
36 result.push_str(&format!("{indent_str} \"{col}\",\n"));
37 }
38 result.push_str(&format!("{indent_str} ],\n"));
39 }
40
41 if let Some(table) = &stmt.from_table {
43 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
44 }
45
46 if let Some(where_clause) = &stmt.where_clause {
48 result.push_str(&format!("{indent_str} where_clause: {{\n"));
49 result.push_str(&format_where_clause(where_clause, indent + 2));
50 result.push_str(&format!("{indent_str} }},\n"));
51 }
52
53 if let Some(order_by) = &stmt.order_by {
55 result.push_str(&format!("{indent_str} order_by: ["));
56 if order_by.is_empty() {
57 result.push_str("],\n");
58 } else {
59 result.push('\n');
60 for col in order_by {
61 let dir = match col.direction {
62 SortDirection::Asc => "ASC",
63 SortDirection::Desc => "DESC",
64 };
65 result.push_str(&format!(
66 "{indent_str} {{ column: \"{}\", direction: {dir} }},\n",
67 col.column
68 ));
69 }
70 result.push_str(&format!("{indent_str} ],\n"));
71 }
72 }
73
74 if let Some(group_by) = &stmt.group_by {
76 result.push_str(&format!("{indent_str} group_by: ["));
77 if group_by.is_empty() {
78 result.push_str("],\n");
79 } else {
80 result.push('\n');
81 for col in group_by {
82 result.push_str(&format!("{indent_str} \"{col}\",\n"));
83 }
84 result.push_str(&format!("{indent_str} ],\n"));
85 }
86 }
87
88 if let Some(limit) = stmt.limit {
90 result.push_str(&format!("{indent_str} limit: {limit},\n"));
91 }
92
93 if stmt.distinct {
95 result.push_str(&format!("{indent_str} distinct: true,\n"));
96 }
97
98 result.push_str(&format!("{indent_str}}}\n"));
99 result
100}
101
102fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
103 let mut result = String::new();
104 let indent_str = " ".repeat(indent);
105
106 result.push_str(&format!("{indent_str}conditions: [\n"));
107 for (i, condition) in clause.conditions.iter().enumerate() {
108 result.push_str(&format!("{indent_str} {{\n"));
109 result.push_str(&format!(
110 "{indent_str} expr: {},\n",
111 format_expression_ast(&condition.expr)
112 ));
113
114 if let Some(connector) = &condition.connector {
115 let conn_str = match connector {
116 LogicalOp::And => "AND",
117 LogicalOp::Or => "OR",
118 };
119 result.push_str(&format!("{indent_str} connector: {conn_str},\n"));
120 }
121
122 result.push_str(&format!("{indent_str} }}"));
123 if i < clause.conditions.len() - 1 {
124 result.push(',');
125 }
126 result.push('\n');
127 }
128 result.push_str(&format!("{indent_str}]\n"));
129
130 result
131}
132
133pub fn format_expression_ast(expr: &SqlExpression) -> String {
134 match expr {
135 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
136 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
137 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
138 SqlExpression::BinaryOp { left, op, right } => {
139 format!(
140 "BinaryOp {{ left: {}, op: \"{op}\", right: {} }}",
141 format_expression_ast(left),
142 format_expression_ast(right)
143 )
144 }
145 SqlExpression::FunctionCall {
146 name,
147 args,
148 distinct,
149 } => {
150 let args_str = args
151 .iter()
152 .map(format_expression_ast)
153 .collect::<Vec<_>>()
154 .join(", ");
155 if *distinct {
156 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}], distinct: true }}")
157 } else {
158 format!("FunctionCall {{ name: \"{name}\", args: [{args_str}] }}")
159 }
160 }
161 SqlExpression::MethodCall {
162 object,
163 method,
164 args,
165 } => {
166 let args_str = args
167 .iter()
168 .map(format_expression_ast)
169 .collect::<Vec<_>>()
170 .join(", ");
171 format!(
172 "MethodCall {{ object: \"{object}\", method: \"{method}\", args: [{args_str}] }}"
173 )
174 }
175 SqlExpression::InList { expr, values } => {
176 let values_str = values
177 .iter()
178 .map(format_expression_ast)
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!(
182 "InList {{ expr: {}, values: [{values_str}] }}",
183 format_expression_ast(expr)
184 )
185 }
186 SqlExpression::NotInList { expr, values } => {
187 let values_str = values
188 .iter()
189 .map(format_expression_ast)
190 .collect::<Vec<_>>()
191 .join(", ");
192 format!(
193 "NotInList {{ expr: {}, values: [{values_str}] }}",
194 format_expression_ast(expr)
195 )
196 }
197 SqlExpression::Between { expr, lower, upper } => {
198 format!(
199 "Between {{ expr: {}, lower: {}, upper: {} }}",
200 format_expression_ast(expr),
201 format_expression_ast(lower),
202 format_expression_ast(upper)
203 )
204 }
205 SqlExpression::Null => "Null".to_string(),
206 SqlExpression::BooleanLiteral(b) => format!("BooleanLiteral({b})"),
207 SqlExpression::DateTimeConstructor {
208 year,
209 month,
210 day,
211 hour,
212 minute,
213 second,
214 } => {
215 let time_part = match (hour, minute, second) {
216 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
217 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
218 _ => String::new(),
219 };
220 format!("DateTimeConstructor({year}-{month:02}-{day:02}{time_part})")
221 }
222 SqlExpression::DateTimeToday {
223 hour,
224 minute,
225 second,
226 } => {
227 let time_part = match (hour, minute, second) {
228 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
229 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
230 _ => String::new(),
231 };
232 format!("DateTimeToday({time_part})")
233 }
234 SqlExpression::WindowFunction {
235 name,
236 args,
237 window_spec: _,
238 } => {
239 let args_str = args
240 .iter()
241 .map(format_expression_ast)
242 .collect::<Vec<_>>()
243 .join(", ");
244 format!("WindowFunction {{ name: \"{name}\", args: [{args_str}], window_spec: ... }}")
245 }
246 SqlExpression::ChainedMethodCall { base, method, args } => {
247 let args_str = args
248 .iter()
249 .map(format_expression_ast)
250 .collect::<Vec<_>>()
251 .join(", ");
252 format!(
253 "ChainedMethodCall {{ base: {}, method: \"{method}\", args: [{args_str}] }}",
254 format_expression_ast(base)
255 )
256 }
257 SqlExpression::Not { expr } => {
258 format!("Not {{ expr: {} }}", format_expression_ast(expr))
259 }
260 SqlExpression::CaseExpression {
261 when_branches,
262 else_branch,
263 } => {
264 let mut result = String::from("CaseExpression { when_branches: [");
265 for branch in when_branches {
266 result.push_str(&format!(
267 " {{ condition: {}, result: {} }},",
268 format_expression_ast(&branch.condition),
269 format_expression_ast(&branch.result)
270 ));
271 }
272 result.push_str("], else_branch: ");
273 if let Some(else_expr) = else_branch {
274 result.push_str(&format_expression_ast(else_expr));
275 } else {
276 result.push_str("None");
277 }
278 result.push_str(" }");
279 result
280 }
281 }
282}
283
/// Returns an owned copy of `text[start..end]`.
///
/// Returns an empty string instead of panicking in these cases:
/// - the range is empty or reversed;
/// - the range extends past the end of `text`;
/// - either offset does not fall on a UTF-8 character boundary. Plain byte
///   slicing would panic here; that can happen with non-ASCII queries if
///   positions are not byte offsets.
fn extract_text_between_positions(text: &str, start: usize, end: usize) -> String {
    // `str::get` performs the bounds, ordering and char-boundary checks at once.
    text.get(start..end).map(str::to_owned).unwrap_or_default()
}
291
292fn find_token_position(query: &str, target: Token, skip_count: usize) -> Option<usize> {
294 let mut lexer = Lexer::new(query);
295 let mut found_count = 0;
296
297 loop {
298 let pos = lexer.get_position();
299 let token = lexer.next_token();
300 if token == Token::Eof {
301 break;
302 }
303 if token == target {
304 if found_count == skip_count {
305 return Some(pos);
306 }
307 found_count += 1;
308 }
309 }
310 None
311}
312
313pub fn format_sql_with_preserved_parens(query: &str, cols_per_line: usize) -> Vec<String> {
314 let mut parser = Parser::new(query);
315 let stmt = match parser.parse() {
316 Ok(s) => s,
317 Err(_) => return vec![query.to_string()],
318 };
319
320 let mut lines = Vec::new();
321 let mut lexer = Lexer::new(query);
322 let mut tokens_with_pos = Vec::new();
323
324 loop {
326 let pos = lexer.get_position();
327 let token = lexer.next_token();
328 if token == Token::Eof {
329 break;
330 }
331 tokens_with_pos.push((token, pos));
332 }
333
334 let mut i = 0;
336 while i < tokens_with_pos.len() {
337 match &tokens_with_pos[i].0 {
338 Token::Select => {
339 let _select_start = tokens_with_pos[i].1;
340 i += 1;
341
342 let has_distinct = if i < tokens_with_pos.len() {
344 matches!(tokens_with_pos[i].0, Token::Distinct)
345 } else {
346 false
347 };
348
349 if has_distinct {
350 i += 1;
351 }
352
353 let _select_end = query.len();
355 let _col_count = 0;
356 let _current_line_cols: Vec<String> = Vec::new();
357 let mut all_select_lines = Vec::new();
358
359 let use_pretty_format = stmt.columns.len() > cols_per_line;
361
362 if use_pretty_format {
363 let select_text = if has_distinct {
365 "SELECT DISTINCT".to_string()
366 } else {
367 "SELECT".to_string()
368 };
369 all_select_lines.push(select_text);
370
371 for (idx, col) in stmt.columns.iter().enumerate() {
373 let is_last = idx == stmt.columns.len() - 1;
374 let col_text = if is_last {
375 format!(" {col}")
376 } else {
377 format!(" {col},")
378 };
379 all_select_lines.push(col_text);
380 }
381 } else {
382 let mut select_line = if has_distinct {
384 "SELECT DISTINCT ".to_string()
385 } else {
386 "SELECT ".to_string()
387 };
388
389 for (idx, col) in stmt.columns.iter().enumerate() {
390 if idx > 0 {
391 select_line.push_str(", ");
392 }
393 select_line.push_str(col);
394 }
395 all_select_lines.push(select_line);
396 }
397
398 lines.extend(all_select_lines);
399
400 while i < tokens_with_pos.len() {
402 match &tokens_with_pos[i].0 {
403 Token::From => break,
404 _ => i += 1,
405 }
406 }
407 }
408 Token::From => {
409 let from_start = tokens_with_pos[i].1;
410 i += 1;
411
412 let mut from_end = query.len();
414 while i < tokens_with_pos.len() {
415 match &tokens_with_pos[i].0 {
416 Token::Where
417 | Token::GroupBy
418 | Token::OrderBy
419 | Token::Limit
420 | Token::Having
421 | Token::Eof => {
422 from_end = tokens_with_pos[i].1;
423 break;
424 }
425 _ => i += 1,
426 }
427 }
428
429 let from_text = extract_text_between_positions(query, from_start, from_end);
430 lines.push(from_text.trim().to_string());
431 }
432 Token::Where => {
433 let where_start = tokens_with_pos[i].1;
434 i += 1;
435
436 let mut where_end = query.len();
438 let mut paren_depth = 0;
439 while i < tokens_with_pos.len() {
440 match &tokens_with_pos[i].0 {
441 Token::LeftParen => {
442 paren_depth += 1;
443 i += 1;
444 }
445 Token::RightParen => {
446 paren_depth -= 1;
447 i += 1;
448 }
449 Token::GroupBy
450 | Token::OrderBy
451 | Token::Limit
452 | Token::Having
453 | Token::Eof
454 if paren_depth == 0 =>
455 {
456 where_end = tokens_with_pos[i].1;
457 break;
458 }
459 _ => i += 1,
460 }
461 }
462
463 let where_text = extract_text_between_positions(query, where_start, where_end);
464 let formatted_where = format_where_clause_with_parens(where_text.trim());
465 lines.extend(formatted_where);
466 }
467 Token::GroupBy => {
468 let group_start = tokens_with_pos[i].1;
469 i += 1;
470
471 if i < tokens_with_pos.len() && matches!(tokens_with_pos[i].0, Token::By) {
473 i += 1;
474 }
475
476 while i < tokens_with_pos.len() {
478 match &tokens_with_pos[i].0 {
479 Token::OrderBy | Token::Limit | Token::Having | Token::Eof => break,
480 _ => i += 1,
481 }
482 }
483
484 if i > 0 {
485 let group_text = extract_text_between_positions(
486 query,
487 group_start,
488 tokens_with_pos[i - 1].1,
489 );
490 lines.push(format!("GROUP BY {}", group_text.trim()));
491 }
492 }
493 _ => i += 1,
494 }
495 }
496
497 lines
498}
499
/// Normalizes the text of a WHERE clause into display lines.
///
/// Handling of the leading keyword:
/// - A leading `WHERE` (any ASCII case, surrounding whitespace ignored) is
///   replaced by a canonical `WHERE `.
/// - The keyword is only stripped when it is a whole word, so an identifier
///   such as `WHEREVER` is left intact.
/// - Text without the keyword is simply prefixed with `WHERE `.
///
/// The condition text itself is kept verbatim apart from trimming outer
/// whitespace, so parentheses, spacing and string literals are untouched.
/// Always returns exactly one line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    const KEYWORD: &str = "WHERE";

    let text = where_text.trim();
    let condition = match text.get(..KEYWORD.len()) {
        Some(head) if head.eq_ignore_ascii_case(KEYWORD) => {
            // `get` succeeded, so KEYWORD.len() is a valid char boundary.
            let rest = &text[KEYWORD.len()..];
            let continues_word =
                matches!(rest.chars().next(), Some(c) if c.is_alphanumeric() || c == '_');
            if continues_word {
                text
            } else {
                rest
            }
        }
        _ => text,
    };

    let line = format!("WHERE {}", condition.trim_start());
    // Trim so a clause with no condition yields "WHERE", not "WHERE ".
    vec![line.trim_end().to_string()]
}
563
564#[must_use]
565pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
566 let formatted = format_sql_with_preserved_parens(query, cols_per_line);
568
569 formatted
571 .into_iter()
572 .filter(|line| !line.trim().is_empty())
573 .map(|line| {
574 let mut result = line;
576 for keyword in &[
577 "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT",
578 ] {
579 let pattern = format!("{keyword}");
580 if result.starts_with(&pattern) && !result.starts_with(&format!("{keyword} ")) {
581 result = format!("{keyword} {}", &result[keyword.len()..].trim_start());
582 }
583 }
584 result
585 })
586 .collect()
587}
588
589pub fn format_expression(expr: &SqlExpression) -> String {
590 match expr {
591 SqlExpression::Column(name) => name.clone(),
592 SqlExpression::StringLiteral(value) => format!("'{value}'"),
593 SqlExpression::NumberLiteral(value) => value.clone(),
594 SqlExpression::BinaryOp { left, op, right } => {
595 format!(
596 "{} {} {}",
597 format_expression(left),
598 op,
599 format_expression(right)
600 )
601 }
602 SqlExpression::FunctionCall {
603 name,
604 args,
605 distinct,
606 } => {
607 let args_str = args
608 .iter()
609 .map(format_expression)
610 .collect::<Vec<_>>()
611 .join(", ");
612 if *distinct {
613 format!("{name}(DISTINCT {args_str})")
614 } else {
615 format!("{name}({args_str})")
616 }
617 }
618 SqlExpression::MethodCall {
619 object,
620 method,
621 args,
622 } => {
623 let args_str = args
624 .iter()
625 .map(format_expression)
626 .collect::<Vec<_>>()
627 .join(", ");
628 if args.is_empty() {
629 format!("{object}.{method}()")
630 } else {
631 format!("{object}.{method}({args_str})")
632 }
633 }
634 SqlExpression::InList { expr, values } => {
635 let values_str = values
636 .iter()
637 .map(format_expression)
638 .collect::<Vec<_>>()
639 .join(", ");
640 format!("{} IN ({})", format_expression(expr), values_str)
641 }
642 SqlExpression::NotInList { expr, values } => {
643 let values_str = values
644 .iter()
645 .map(format_expression)
646 .collect::<Vec<_>>()
647 .join(", ");
648 format!("{} NOT IN ({})", format_expression(expr), values_str)
649 }
650 SqlExpression::Between { expr, lower, upper } => {
651 format!(
652 "{} BETWEEN {} AND {}",
653 format_expression(expr),
654 format_expression(lower),
655 format_expression(upper)
656 )
657 }
658 SqlExpression::Null => "NULL".to_string(),
659 SqlExpression::BooleanLiteral(b) => b.to_string().to_uppercase(),
660 SqlExpression::DateTimeConstructor {
661 year,
662 month,
663 day,
664 hour,
665 minute,
666 second,
667 } => {
668 let time_part = match (hour, minute, second) {
669 (Some(h), Some(m), Some(s)) => format!(" {h:02}:{m:02}:{s:02}"),
670 (Some(h), Some(m), None) => format!(" {h:02}:{m:02}"),
671 _ => String::new(),
672 };
673 format!("DATETIME({year}, {month}, {day}{time_part})")
674 }
675 SqlExpression::DateTimeToday {
676 hour,
677 minute,
678 second,
679 } => {
680 let time_part = match (hour, minute, second) {
681 (Some(h), Some(m), Some(s)) => format!(", {h}, {m}, {s}"),
682 (Some(h), Some(m), None) => format!(", {h}, {m}"),
683 (Some(h), None, None) => format!(", {h}"),
684 _ => String::new(),
685 };
686 format!("TODAY({time_part})")
687 }
688 SqlExpression::WindowFunction {
689 name,
690 args,
691 window_spec,
692 } => {
693 let args_str = args
694 .iter()
695 .map(format_expression)
696 .collect::<Vec<_>>()
697 .join(", ");
698
699 let mut result = format!("{name}({args_str}) OVER (");
700
701 if !window_spec.partition_by.is_empty() {
703 result.push_str("PARTITION BY ");
704 result.push_str(&window_spec.partition_by.join(", "));
705 }
706
707 if !window_spec.order_by.is_empty() {
709 if !window_spec.partition_by.is_empty() {
710 result.push(' ');
711 }
712 result.push_str("ORDER BY ");
713 let order_strs: Vec<String> = window_spec
714 .order_by
715 .iter()
716 .map(|col| {
717 let dir = match col.direction {
718 SortDirection::Asc => " ASC",
719 SortDirection::Desc => " DESC",
720 };
721 format!("{}{}", col.column, dir)
722 })
723 .collect();
724 result.push_str(&order_strs.join(", "));
725 }
726
727 result.push(')');
728 result
729 }
730 SqlExpression::ChainedMethodCall { base, method, args } => {
731 let base_str = format_expression(base);
732 let args_str = args
733 .iter()
734 .map(format_expression)
735 .collect::<Vec<_>>()
736 .join(", ");
737 if args.is_empty() {
738 format!("{base_str}.{method}()")
739 } else {
740 format!("{base_str}.{method}({args_str})")
741 }
742 }
743 SqlExpression::Not { expr } => {
744 format!("NOT {}", format_expression(expr))
745 }
746 SqlExpression::CaseExpression {
747 when_branches,
748 else_branch,
749 } => {
750 let mut result = String::from("CASE");
751 for branch in when_branches {
752 result.push_str(&format!(
753 " WHEN {} THEN {}",
754 format_expression(&branch.condition),
755 format_expression(&branch.result)
756 ));
757 }
758 if let Some(else_expr) = else_branch {
759 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
760 }
761 result.push_str(" END");
762 result
763 }
764 }
765}
766
767fn format_token(token: &Token) -> String {
768 match token {
769 Token::Identifier(s) => s.clone(),
770 Token::QuotedIdentifier(s) => format!("\"{s}\""),
771 Token::StringLiteral(s) => format!("'{s}'"),
772 Token::NumberLiteral(n) => n.clone(),
773 Token::DateTime => "DateTime".to_string(),
774 Token::Case => "CASE".to_string(),
775 Token::When => "WHEN".to_string(),
776 Token::Then => "THEN".to_string(),
777 Token::Else => "ELSE".to_string(),
778 Token::End => "END".to_string(),
779 Token::Distinct => "DISTINCT".to_string(),
780 Token::Over => "OVER".to_string(),
781 Token::Partition => "PARTITION".to_string(),
782 Token::By => "BY".to_string(),
783 Token::LeftParen => "(".to_string(),
784 Token::RightParen => ")".to_string(),
785 Token::Comma => ",".to_string(),
786 Token::Dot => ".".to_string(),
787 Token::Equal => "=".to_string(),
788 Token::NotEqual => "!=".to_string(),
789 Token::LessThan => "<".to_string(),
790 Token::GreaterThan => ">".to_string(),
791 Token::LessThanOrEqual => "<=".to_string(),
792 Token::GreaterThanOrEqual => ">=".to_string(),
793 Token::In => "IN".to_string(),
794 _ => format!("{token:?}").to_uppercase(),
795 }
796}