1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords (matched case-insensitively by the lexer).
    Select,
    From,
    Where,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// The two-word keyword `ORDER BY`.
    OrderBy,
    /// The two-word keyword `GROUP BY`.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,

    // Names and literals.
    /// A bare word that is not a keyword (also used for unrecognised single characters).
    Identifier(String),
    /// Text enclosed in double quotes, without the quotes.
    QuotedIdentifier(String),
    /// Text enclosed in single quotes, without the quotes.
    StringLiteral(String),
    /// Numeric text exactly as written, including a leading `-` for negative literals.
    NumberLiteral(String),
    Star,

    // Punctuation and comparison operators.
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // Arithmetic operators (`*` is reported as `Star`).
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}

/// Character-based tokenizer for the SQL dialect understood by this file.
#[derive(Debug, Clone)]
pub struct Lexer {
    input: Vec<char>,
    position: usize,
    current_char: Option<char>,
    /// True when the previously returned token can end an operand (a name, a
    /// literal, `)` or `END`). A `-` that follows such a token is the binary
    /// minus operator rather than the sign of a numeric literal.
    prev_ends_operand: bool,
}

impl Lexer {
    /// Creates a lexer positioned at the first character of `input`.
    pub fn new(input: &str) -> Self {
        let chars: Vec<char> = input.chars().collect();
        let current = chars.first().copied();
        Self {
            input: chars,
            position: 0,
            current_char: current,
            prev_ends_operand: false,
        }
    }

    /// Moves to the next character (`current_char` becomes `None` at end of input).
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character `offset` positions ahead of the current one, if any.
    fn peek(&self, offset: usize) -> Option<char> {
        self.input.get(self.position + offset).copied()
    }

    /// Skips over any run of whitespace characters.
    fn skip_whitespace(&mut self) {
        while let Some(ch) = self.current_char {
            if !ch.is_whitespace() {
                break;
            }
            self.advance();
        }
    }

    /// Reads a run of alphanumeric characters and underscores.
    fn read_identifier(&mut self) -> String {
        let mut result = String::new();
        while let Some(ch) = self.current_char {
            if ch.is_alphanumeric() || ch == '_' {
                result.push(ch);
                self.advance();
            } else {
                break;
            }
        }
        result
    }

    /// Reads a quoted run starting at the opening quote character and returns
    /// its contents without the surrounding quotes.
    ///
    /// A doubled quote character inside the run (`'it''s'`) is the SQL escape
    /// for one literal quote. An unterminated run extends to the end of input.
    fn read_string(&mut self) -> String {
        let mut result = String::new();
        let quote_char = match self.current_char {
            Some(quote) => quote,
            None => return result,
        };
        self.advance(); // opening quote

        while let Some(ch) = self.current_char {
            if ch == quote_char {
                if self.peek(1) == Some(quote_char) {
                    // Escaped quote: keep one and skip both characters.
                    result.push(quote_char);
                    self.advance();
                    self.advance();
                    continue;
                }
                self.advance(); // closing quote
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads digits with at most one decimal point.
    ///
    /// A `.` is only taken as part of the number when it is the first one and
    /// is directly followed by a digit; otherwise it is left in the input so
    /// that it can be lexed as [`Token::Dot`] (e.g. `202301.Contains(...)`).
    fn read_number(&mut self) -> String {
        let mut result = String::new();
        let mut seen_dot = false;
        while let Some(ch) = self.current_char {
            if ch.is_numeric() {
                result.push(ch);
            } else if ch == '.'
                && !seen_dot
                && self.peek(1).map_or(false, |next| next.is_numeric())
            {
                seen_dot = true;
                result.push(ch);
            } else {
                break;
            }
            self.advance();
        }
        result
    }

    /// Returns the next token, or [`Token::Eof`] once the input is exhausted.
    pub fn next_token(&mut self) -> Token {
        let token = self.scan_token();
        // Remember whether this token can end an operand so that a following
        // `-` is lexed as subtraction (`a-1`) instead of a negative literal.
        self.prev_ends_operand = matches!(
            token,
            Token::Identifier(_)
                | Token::QuotedIdentifier(_)
                | Token::StringLiteral(_)
                | Token::NumberLiteral(_)
                | Token::RightParen
                | Token::End
        );
        token
    }

    /// Scans one token starting at the current position.
    fn scan_token(&mut self) -> Token {
        self.skip_whitespace();

        match self.current_char {
            None => Token::Eof,
            Some('*') => {
                self.advance();
                Token::Star
            }
            Some('+') => {
                self.advance();
                Token::Plus
            }
            Some('/') => {
                self.advance();
                Token::Divide
            }
            Some('.') => {
                self.advance();
                Token::Dot
            }
            Some(',') => {
                self.advance();
                Token::Comma
            }
            Some('(') => {
                self.advance();
                Token::LeftParen
            }
            Some(')') => {
                self.advance();
                Token::RightParen
            }
            Some('=') => {
                self.advance();
                Token::Equal
            }
            Some('<') => {
                self.advance();
                if self.current_char == Some('=') {
                    self.advance();
                    Token::LessThanOrEqual
                } else if self.current_char == Some('>') {
                    self.advance();
                    Token::NotEqual
                } else {
                    Token::LessThan
                }
            }
            Some('>') => {
                self.advance();
                if self.current_char == Some('=') {
                    self.advance();
                    Token::GreaterThanOrEqual
                } else {
                    Token::GreaterThan
                }
            }
            Some('!') if self.peek(1) == Some('=') => {
                self.advance();
                self.advance();
                Token::NotEqual
            }
            Some('"') => {
                let ident_val = self.read_string();
                Token::QuotedIdentifier(ident_val)
            }
            Some('\'') => {
                let string_val = self.read_string();
                Token::StringLiteral(string_val)
            }
            // A sign directly before a digit starts a negative literal, but
            // only where an operand may begin (not right after another operand).
            Some('-')
                if !self.prev_ends_operand
                    && self.peek(1).map_or(false, |c| c.is_numeric()) =>
            {
                self.advance();
                let num = self.read_number();
                Token::NumberLiteral(format!("-{}", num))
            }
            Some('-') => {
                self.advance();
                Token::Minus
            }
            Some(ch) if ch.is_numeric() => {
                let num = self.read_number();
                Token::NumberLiteral(num)
            }
            Some(ch) if ch.is_alphabetic() || ch == '_' => {
                let ident = self.read_identifier();
                match ident.to_uppercase().as_str() {
                    "SELECT" => Token::Select,
                    "FROM" => Token::From,
                    "WHERE" => Token::Where,
                    "AND" => Token::And,
                    "OR" => Token::Or,
                    "IN" => Token::In,
                    "NOT" => Token::Not,
                    "BETWEEN" => Token::Between,
                    "LIKE" => Token::Like,
                    "IS" => Token::Is,
                    "NULL" => Token::Null,
                    // ORDER / GROUP only form a keyword when followed by BY;
                    // otherwise they fall through to plain identifiers.
                    "ORDER" if self.peek_keyword("BY") => {
                        self.skip_whitespace();
                        self.read_identifier(); // consume "BY"
                        Token::OrderBy
                    }
                    "GROUP" if self.peek_keyword("BY") => {
                        self.skip_whitespace();
                        self.read_identifier(); // consume "BY"
                        Token::GroupBy
                    }
                    "HAVING" => Token::Having,
                    "AS" => Token::As,
                    "ASC" => Token::Asc,
                    "DESC" => Token::Desc,
                    "LIMIT" => Token::Limit,
                    "OFFSET" => Token::Offset,
                    "DATETIME" => Token::DateTime,
                    "CASE" => Token::Case,
                    "WHEN" => Token::When,
                    "THEN" => Token::Then,
                    "ELSE" => Token::Else,
                    "END" => Token::End,
                    _ => Token::Identifier(ident),
                }
            }
            // Any other character is passed through as a one-character identifier.
            Some(ch) => {
                self.advance();
                Token::Identifier(ch.to_string())
            }
        }
    }

    /// Reports whether the next word (after optional whitespace) equals
    /// `keyword`, compared case-insensitively, without consuming any input.
    fn peek_keyword(&mut self, keyword: &str) -> bool {
        let saved_pos = self.position;
        let saved_char = self.current_char;

        self.skip_whitespace();
        let next_word = self.read_identifier();
        let matches = next_word.to_uppercase() == keyword;

        // Rewind to where we started.
        self.position = saved_pos;
        self.current_char = saved_char;

        matches
    }

    /// Returns the current character index into the input.
    pub fn get_position(&self) -> usize {
        self.position
    }

    /// Tokenizes the remaining input; the returned vector always ends with [`Token::Eof`].
    pub fn tokenize_all(&mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let done = matches!(token, Token::Eof);
            tokens.push(token);
            if done {
                break;
            }
        }
        tokens
    }

    /// Tokenizes the remaining input, returning `(start, end, token)` triples
    /// where `start..end` is the character range of the token. The trailing
    /// [`Token::Eof`] is not included.
    pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
        let mut tokens = Vec::new();
        loop {
            // Skip whitespace first so that `start_pos` is the token's first character.
            self.skip_whitespace();
            let start_pos = self.position;
            let token = self.next_token();
            let end_pos = self.position;

            if matches!(token, Token::Eof) {
                break;
            }
            tokens.push((start_pos, end_pos, token));
        }
        tokens
    }
}
323
/// Expression tree built by the parser for select items, WHERE conditions and
/// call arguments.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// A column reference by name.
    Column(String),
    /// A string literal (contents without the quotes).
    StringLiteral(String),
    /// A numeric literal, kept as the text the lexer produced.
    NumberLiteral(String),
    /// `DateTime(year, month, day[, hour[, minute[, second]]])`.
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `DateTime()` with no arguments; the parser sets every field to `None`.
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `object.method(args)` where `object` is a column name.
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `base.method(args)` where `base` is any expression other than a plain column.
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// A call to one of the function names the parser recognises; `name` is upper-cased.
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
    },
    /// A binary operation; `op` is the operator text (e.g. `"="`, `"+"`, `"LIKE"`, `"AND"`).
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// `expr IN (values...)`.
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr NOT IN (values...)`.
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr BETWEEN lower AND upper`.
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// `NOT expr`.
    Not {
        expr: Box<SqlExpression>,
    },
    /// `CASE WHEN ... THEN ... [ELSE ...] END`.
    CaseExpression {
        when_branches: Vec<WhenBranch>,
        else_branch: Option<Box<SqlExpression>>,
    },
}
383
/// One `WHEN condition THEN result` arm of a `CASE` expression.
#[derive(Debug, Clone)]
pub struct WhenBranch {
    /// Expression following `WHEN`.
    pub condition: Box<SqlExpression>,
    /// Expression following `THEN`.
    pub result: Box<SqlExpression>,
}
389
/// The parsed contents of a `WHERE` clause.
#[derive(Debug, Clone)]
pub struct WhereClause {
    /// Conditions in source order; each carries the connector that follows it.
    pub conditions: Vec<Condition>,
}
394
/// A single condition of a `WHERE` clause.
#[derive(Debug, Clone)]
pub struct Condition {
    /// The condition's expression.
    pub expr: SqlExpression,
    /// The `AND`/`OR` that follows this condition, or `None` for the last one.
    pub connector: Option<LogicalOp>,
}
400
/// Logical connector between two conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    And,
    Or,
}
406
/// Sort direction of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    Asc,
    Desc,
}
412
/// One entry of an `ORDER BY` list.
#[derive(Debug, Clone)]
pub struct OrderByColumn {
    /// Column name to sort by.
    pub column: String,
    /// Sort direction; the parser uses `Asc` when none is written.
    pub direction: SortDirection,
}
418
/// One item of the `SELECT` list.
#[derive(Debug, Clone)]
pub enum SelectItem {
    /// A plain column selected under its own name.
    Column(String),
    /// A computed expression (or aliased column) together with its output name.
    Expression { expr: SqlExpression, alias: String },
    /// The `*` wildcard.
    Star,
}
429
/// A parsed `SELECT` statement.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// Output names derived from `select_items` (`"*"`, the column name, or the alias).
    pub columns: Vec<String>,
    /// The select list as parsed.
    pub select_items: Vec<SelectItem>,
    /// Table named after `FROM`, if a `FROM` clause was present.
    pub from_table: Option<String>,
    pub where_clause: Option<WhereClause>,
    pub order_by: Option<Vec<OrderByColumn>>,
    pub group_by: Option<Vec<String>>,
    pub limit: Option<usize>,
    pub offset: Option<usize>,
}
441
/// Options passed to the parser.
///
/// The derived `Default` leaves every option switched off.
#[derive(Default)]
pub struct ParserConfig {
    /// Case-insensitivity switch; `false` by default.
    // NOTE(review): not read by any parser code visible in this part of the
    // file — confirm where it takes effect.
    pub case_insensitive: bool,
}
453
/// Recursive-descent parser with a single token of lookahead.
pub struct Parser {
    /// Token source.
    lexer: Lexer,
    /// The lookahead token currently being examined.
    current_token: Token,
    /// Set while method-call arguments are parsed; double-quoted tokens are
    /// then treated as string literals instead of column names.
    in_method_args: bool,
    /// Known column names; a numeric token equal to one of these is treated
    /// as a column reference.
    columns: Vec<String>,
    /// Number of `(` consumed minus number of `)` consumed so far.
    paren_depth: i32,
    /// Parser options supplied at construction.
    config: ParserConfig,
}
462
463impl Parser {
464 pub fn new(input: &str) -> Self {
465 let mut lexer = Lexer::new(input);
466 let current_token = lexer.next_token();
467 Self {
468 lexer,
469 current_token,
470 in_method_args: false,
471 columns: Vec::new(),
472 paren_depth: 0,
473 config: ParserConfig::default(),
474 }
475 }
476
477 pub fn with_config(input: &str, config: ParserConfig) -> Self {
478 let mut lexer = Lexer::new(input);
479 let current_token = lexer.next_token();
480 Self {
481 lexer,
482 current_token,
483 in_method_args: false,
484 columns: Vec::new(),
485 paren_depth: 0,
486 config,
487 }
488 }
489
490 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
491 self.columns = columns;
492 self
493 }
494
495 fn consume(&mut self, expected: Token) -> Result<(), String> {
496 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
497 match &expected {
499 Token::LeftParen => self.paren_depth += 1,
500 Token::RightParen => {
501 self.paren_depth -= 1;
502 if self.paren_depth < 0 {
504 return Err(
505 "Unexpected closing parenthesis - no matching opening parenthesis"
506 .to_string(),
507 );
508 }
509 }
510 _ => {}
511 }
512
513 self.current_token = self.lexer.next_token();
514 Ok(())
515 } else {
516 let error_msg = match (&expected, &self.current_token) {
518 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
519 format!(
520 "Unclosed parenthesis - missing {} closing parenthes{}",
521 self.paren_depth,
522 if self.paren_depth == 1 { "is" } else { "es" }
523 )
524 }
525 (Token::RightParen, _) if self.paren_depth > 0 => {
526 format!(
527 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
528 self.current_token,
529 self.paren_depth,
530 if self.paren_depth == 1 { "is" } else { "es" }
531 )
532 }
533 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
534 };
535 Err(error_msg)
536 }
537 }
538
539 fn advance(&mut self) {
540 match &self.current_token {
542 Token::LeftParen => self.paren_depth += 1,
543 Token::RightParen => {
544 self.paren_depth -= 1;
545 }
548 _ => {}
549 }
550 self.current_token = self.lexer.next_token();
551 }
552
    /// Parses the input as a `SELECT` statement.
    ///
    /// # Errors
    /// Returns a human-readable message when the input cannot be parsed.
    pub fn parse(&mut self) -> Result<SelectStatement, String> {
        self.parse_select_statement()
    }
556
557 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
558 self.consume(Token::Select)?;
559
560 let select_items = self.parse_select_items()?;
562
563 let columns = select_items
565 .iter()
566 .map(|item| match item {
567 SelectItem::Star => "*".to_string(),
568 SelectItem::Column(name) => name.clone(),
569 SelectItem::Expression { alias, .. } => alias.clone(),
570 })
571 .collect();
572
573 let from_table = if matches!(self.current_token, Token::From) {
574 self.advance();
575 match &self.current_token {
576 Token::Identifier(table) => {
577 let table_name = table.clone();
578 self.advance();
579 Some(table_name)
580 }
581 Token::QuotedIdentifier(table) => {
582 let table_name = table.clone();
584 self.advance();
585 Some(table_name)
586 }
587 _ => return Err("Expected table name after FROM".to_string()),
588 }
589 } else {
590 None
591 };
592
593 let where_clause = if matches!(self.current_token, Token::Where) {
594 self.advance();
595 Some(self.parse_where_clause()?)
596 } else {
597 None
598 };
599
600 let order_by = if matches!(self.current_token, Token::OrderBy) {
601 self.advance();
602 Some(self.parse_order_by_list()?)
603 } else {
604 None
605 };
606
607 let group_by = if matches!(self.current_token, Token::GroupBy) {
608 self.advance();
609 Some(self.parse_identifier_list()?)
610 } else {
611 None
612 };
613
614 let limit = if matches!(self.current_token, Token::Limit) {
616 self.advance();
617 match &self.current_token {
618 Token::NumberLiteral(num) => {
619 let limit_val = num
620 .parse::<usize>()
621 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
622 self.advance();
623 Some(limit_val)
624 }
625 _ => return Err("Expected number after LIMIT".to_string()),
626 }
627 } else {
628 None
629 };
630
631 let offset = if matches!(self.current_token, Token::Offset) {
633 self.advance();
634 match &self.current_token {
635 Token::NumberLiteral(num) => {
636 let offset_val = num
637 .parse::<usize>()
638 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
639 self.advance();
640 Some(offset_val)
641 }
642 _ => return Err("Expected number after OFFSET".to_string()),
643 }
644 } else {
645 None
646 };
647
648 if self.paren_depth > 0 {
650 return Err(format!(
651 "Unclosed parenthesis - missing {} closing parenthes{}",
652 self.paren_depth,
653 if self.paren_depth == 1 { "is" } else { "es" }
654 ));
655 } else if self.paren_depth < 0 {
656 return Err(
657 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
658 );
659 }
660
661 Ok(SelectStatement {
662 columns,
663 select_items,
664 from_table,
665 where_clause,
666 order_by,
667 group_by,
668 limit,
669 offset,
670 })
671 }
672
673 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
674 let mut columns = Vec::new();
675
676 if matches!(self.current_token, Token::Star) {
677 columns.push("*".to_string());
678 self.advance();
679 } else {
680 loop {
681 match &self.current_token {
682 Token::Identifier(col) => {
683 columns.push(col.clone());
684 self.advance();
685 }
686 Token::QuotedIdentifier(col) => {
687 columns.push(col.clone());
689 self.advance();
690 }
691 _ => return Err("Expected column name".to_string()),
692 }
693
694 if matches!(self.current_token, Token::Comma) {
695 self.advance();
696 } else {
697 break;
698 }
699 }
700 }
701
702 Ok(columns)
703 }
704
705 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
707 let mut items = Vec::new();
708
709 loop {
710 if matches!(self.current_token, Token::Star) {
713 items.push(SelectItem::Star);
721 self.advance();
722 } else {
723 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
728 self.advance();
729 match &self.current_token {
730 Token::Identifier(alias_name) => {
731 let alias = alias_name.clone();
732 self.advance();
733 alias
734 }
735 Token::QuotedIdentifier(alias_name) => {
736 let alias = alias_name.clone();
737 self.advance();
738 alias
739 }
740 _ => return Err("Expected alias name after AS".to_string()),
741 }
742 } else {
743 match &expr {
745 SqlExpression::Column(col_name) => col_name.clone(),
746 _ => format!("expr_{}", items.len() + 1), }
748 };
749
750 let item = match expr {
752 SqlExpression::Column(col_name) if alias == col_name => {
753 SelectItem::Column(col_name)
755 }
756 _ => {
757 SelectItem::Expression { expr, alias }
759 }
760 };
761
762 items.push(item);
763 }
764
765 if matches!(self.current_token, Token::Comma) {
767 self.advance();
768 } else {
769 break;
770 }
771 }
772
773 Ok(items)
774 }
775
776 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
777 let mut identifiers = Vec::new();
778
779 loop {
780 match &self.current_token {
781 Token::Identifier(id) => {
782 identifiers.push(id.clone());
783 self.advance();
784 }
785 Token::QuotedIdentifier(id) => {
786 identifiers.push(id.clone());
788 self.advance();
789 }
790 _ => return Err("Expected identifier".to_string()),
791 }
792
793 if matches!(self.current_token, Token::Comma) {
794 self.advance();
795 } else {
796 break;
797 }
798 }
799
800 Ok(identifiers)
801 }
802
803 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
804 let mut order_columns = Vec::new();
805
806 loop {
807 let column = match &self.current_token {
808 Token::Identifier(id) => {
809 let col = id.clone();
810 self.advance();
811 col
812 }
813 Token::QuotedIdentifier(id) => {
814 let col = id.clone();
815 self.advance();
816 col
817 }
818 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
819 let col = num.clone();
821 self.advance();
822 col
823 }
824 _ => return Err("Expected column name in ORDER BY".to_string()),
825 };
826
827 let direction = match &self.current_token {
829 Token::Asc => {
830 self.advance();
831 SortDirection::Asc
832 }
833 Token::Desc => {
834 self.advance();
835 SortDirection::Desc
836 }
837 _ => SortDirection::Asc, };
839
840 order_columns.push(OrderByColumn { column, direction });
841
842 if matches!(self.current_token, Token::Comma) {
843 self.advance();
844 } else {
845 break;
846 }
847 }
848
849 Ok(order_columns)
850 }
851
852 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
853 let mut conditions = Vec::new();
854
855 loop {
856 let expr = self.parse_expression()?;
857
858 let connector = match &self.current_token {
859 Token::And => {
860 self.advance();
861 Some(LogicalOp::And)
862 }
863 Token::Or => {
864 self.advance();
865 Some(LogicalOp::Or)
866 }
867 Token::RightParen if self.paren_depth <= 0 => {
868 return Err(
870 "Unexpected closing parenthesis - no matching opening parenthesis"
871 .to_string(),
872 );
873 }
874 _ => None,
875 };
876
877 conditions.push(Condition {
878 expr,
879 connector: connector.clone(),
880 });
881
882 if connector.is_none() {
883 break;
884 }
885 }
886
887 Ok(WhereClause { conditions })
888 }
889
890 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
891 let mut left = self.parse_comparison()?;
892
893 if let Some(op) = self.get_binary_op() {
896 self.advance();
897 let right = self.parse_expression()?;
898 left = SqlExpression::BinaryOp {
899 left: Box::new(left),
900 op,
901 right: Box::new(right),
902 };
903 }
904
905 if matches!(self.current_token, Token::In) {
907 self.advance();
908 self.consume(Token::LeftParen)?;
909 let values = self.parse_expression_list()?;
910 self.consume(Token::RightParen)?;
911
912 left = SqlExpression::InList {
913 expr: Box::new(left),
914 values,
915 };
916 }
917
918 Ok(left)
922 }
923
924 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
925 let mut left = self.parse_additive()?;
926
927 if matches!(self.current_token, Token::Between) {
929 self.advance(); let lower = self.parse_primary()?;
931 self.consume(Token::And)?; let upper = self.parse_primary()?;
933
934 return Ok(SqlExpression::Between {
935 expr: Box::new(left),
936 lower: Box::new(lower),
937 upper: Box::new(upper),
938 });
939 }
940
941 if matches!(self.current_token, Token::Not) {
943 self.advance(); if matches!(self.current_token, Token::In) {
945 self.advance(); self.consume(Token::LeftParen)?;
947 let values = self.parse_expression_list()?;
948 self.consume(Token::RightParen)?;
949
950 return Ok(SqlExpression::NotInList {
951 expr: Box::new(left),
952 values,
953 });
954 } else {
955 return Err("Expected IN after NOT".to_string());
956 }
957 }
958
959 if let Some(op) = self.get_binary_op() {
961 self.advance();
962 let right = self.parse_additive()?;
963 left = SqlExpression::BinaryOp {
964 left: Box::new(left),
965 op,
966 right: Box::new(right),
967 };
968 }
969
970 Ok(left)
971 }
972
973 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
974 let mut left = self.parse_multiplicative()?;
975
976 while matches!(self.current_token, Token::Plus | Token::Minus) {
977 let op = match self.current_token {
978 Token::Plus => "+",
979 Token::Minus => "-",
980 _ => unreachable!(),
981 };
982 self.advance();
983 let right = self.parse_multiplicative()?;
984 left = SqlExpression::BinaryOp {
985 left: Box::new(left),
986 op: op.to_string(),
987 right: Box::new(right),
988 };
989 }
990
991 Ok(left)
992 }
993
994 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
995 let mut left = self.parse_primary()?;
996
997 while matches!(self.current_token, Token::Dot) {
999 self.advance();
1000 if let Token::Identifier(method) = &self.current_token {
1001 let method_name = method.clone();
1002 self.advance();
1003
1004 if matches!(self.current_token, Token::LeftParen) {
1005 self.advance();
1006 let args = self.parse_method_args()?;
1007 self.consume(Token::RightParen)?;
1008
1009 match left {
1011 SqlExpression::Column(obj) => {
1012 left = SqlExpression::MethodCall {
1014 object: obj,
1015 method: method_name,
1016 args,
1017 };
1018 }
1019 SqlExpression::MethodCall { .. }
1020 | SqlExpression::ChainedMethodCall { .. } => {
1021 left = SqlExpression::ChainedMethodCall {
1023 base: Box::new(left),
1024 method: method_name,
1025 args,
1026 };
1027 }
1028 _ => {
1029 left = SqlExpression::ChainedMethodCall {
1031 base: Box::new(left),
1032 method: method_name,
1033 args,
1034 };
1035 }
1036 }
1037 } else {
1038 return Err(format!("Expected '(' after method name '{}'", method_name));
1039 }
1040 } else {
1041 return Err("Expected method name after '.'".to_string());
1042 }
1043 }
1044
1045 while matches!(self.current_token, Token::Star | Token::Divide) {
1046 let op = match self.current_token {
1047 Token::Star => "*",
1048 Token::Divide => "/",
1049 _ => unreachable!(),
1050 };
1051 self.advance();
1052 let right = self.parse_primary()?;
1053 left = SqlExpression::BinaryOp {
1054 left: Box::new(left),
1055 op: op.to_string(),
1056 right: Box::new(right),
1057 };
1058 }
1059
1060 Ok(left)
1061 }
1062
1063 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1064 let mut left = self.parse_logical_and()?;
1065
1066 while matches!(self.current_token, Token::Or) {
1067 self.advance();
1068 let right = self.parse_logical_and()?;
1069 left = SqlExpression::BinaryOp {
1073 left: Box::new(left),
1074 op: "OR".to_string(),
1075 right: Box::new(right),
1076 };
1077 }
1078
1079 Ok(left)
1080 }
1081
1082 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1083 let mut left = self.parse_expression()?;
1084
1085 while matches!(self.current_token, Token::And) {
1086 self.advance();
1087 let right = self.parse_expression()?;
1088 left = SqlExpression::BinaryOp {
1090 left: Box::new(left),
1091 op: "AND".to_string(),
1092 right: Box::new(right),
1093 };
1094 }
1095
1096 Ok(left)
1097 }
1098
1099 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1100 self.consume(Token::Case)?;
1102
1103 let mut when_branches = Vec::new();
1104
1105 while matches!(self.current_token, Token::When) {
1107 self.advance(); let condition = self.parse_expression()?;
1111
1112 self.consume(Token::Then)?;
1114
1115 let result = self.parse_expression()?;
1117
1118 when_branches.push(WhenBranch {
1119 condition: Box::new(condition),
1120 result: Box::new(result),
1121 });
1122 }
1123
1124 if when_branches.is_empty() {
1126 return Err("CASE expression must have at least one WHEN clause".to_string());
1127 }
1128
1129 let else_branch = if matches!(self.current_token, Token::Else) {
1131 self.advance(); Some(Box::new(self.parse_expression()?))
1133 } else {
1134 None
1135 };
1136
1137 self.consume(Token::End)?;
1139
1140 Ok(SqlExpression::CaseExpression {
1141 when_branches,
1142 else_branch,
1143 })
1144 }
1145
    /// Parses a primary expression: a column, a literal, a `CASE` expression,
    /// a `DateTime(...)` constructor, a recognised function call, a
    /// parenthesised logical expression, or a `NOT` expression.
    ///
    /// # Errors
    /// Returns a message naming the unexpected token, or the error of the
    /// sub-parser that failed.
    fn parse_primary(&mut self) -> Result<SqlExpression, String> {
        // A numeric token that exactly matches a known column name is a
        // column reference, not a number.
        if let Token::NumberLiteral(num_str) = &self.current_token {
            if self.columns.iter().any(|col| col == num_str) {
                let expr = SqlExpression::Column(num_str.clone());
                self.advance();
                return Ok(expr);
            }
        }

        match &self.current_token {
            Token::Case => {
                self.parse_case_expression()
            }
            Token::DateTime => {
                self.advance();
                self.consume(Token::LeftParen)?;

                // `DateTime()` with no arguments.
                if matches!(&self.current_token, Token::RightParen) {
                    self.advance();
                    return Ok(SqlExpression::DateTimeToday {
                        hour: None,
                        minute: None,
                        second: None,
                    });
                }

                // Year, month and day are mandatory.
                let year = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<i32>().map_err(|_| "Invalid year")?
                } else {
                    return Err("Expected year in DateTime constructor".to_string());
                };
                self.advance();
                self.consume(Token::Comma)?;

                let month = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<u32>().map_err(|_| "Invalid month")?
                } else {
                    return Err("Expected month in DateTime constructor".to_string());
                };
                self.advance();
                self.consume(Token::Comma)?;

                let day = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<u32>().map_err(|_| "Invalid day")?
                } else {
                    return Err("Expected day in DateTime constructor".to_string());
                };
                self.advance();

                // Hour, minute and second are optional; each is only looked
                // for when the previous component was present.
                let mut hour = None;
                let mut minute = None;
                let mut second = None;

                if matches!(&self.current_token, Token::Comma) {
                    self.advance();
                    if let Token::NumberLiteral(n) = &self.current_token {
                        hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
                        self.advance();

                        if matches!(&self.current_token, Token::Comma) {
                            self.advance();
                            if let Token::NumberLiteral(n) = &self.current_token {
                                minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
                                self.advance();

                                if matches!(&self.current_token, Token::Comma) {
                                    self.advance();
                                    if let Token::NumberLiteral(n) = &self.current_token {
                                        second =
                                            Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
                                        self.advance();
                                    }
                                }
                            }
                        }
                    }
                }

                self.consume(Token::RightParen)?;
                Ok(SqlExpression::DateTimeConstructor {
                    year,
                    month,
                    day,
                    hour,
                    minute,
                    second,
                })
            }
            Token::Identifier(id) => {
                let id_upper = id.to_uppercase();
                let id_clone = id.clone();
                self.advance();

                // An identifier followed by `(` is a function call only when
                // its upper-cased name is in the list below; otherwise it is
                // returned as a column and the `(` is left unconsumed.
                if matches!(self.current_token, Token::LeftParen) {
                    if matches!(
                        id_upper.as_str(),
                        "ROUND"
                            | "ABS"
                            | "FLOOR"
                            | "CEILING"
                            | "CEIL"
                            | "MOD"
                            | "QUOTIENT"
                            | "POWER"
                            | "POW"
                            | "SQRT"
                            | "EXP"
                            | "LN"
                            | "LOG"
                            | "LOG10"
                            | "PI"
                            | "TEXTJOIN"
                            | "DATEDIFF"
                            | "DATEADD"
                            | "NOW"
                            | "TODAY"
                    ) {
                        self.advance();
                        let args = self.parse_function_args()?;
                        self.consume(Token::RightParen)?;
                        return Ok(SqlExpression::FunctionCall {
                            name: id_upper,
                            args,
                        });
                    }
                }

                Ok(SqlExpression::Column(id_clone))
            }
            Token::QuotedIdentifier(id) => {
                // Inside method arguments a double-quoted token is a string
                // literal; elsewhere it names a column.
                let expr = if self.in_method_args {
                    SqlExpression::StringLiteral(id.clone())
                } else {
                    SqlExpression::Column(id.clone())
                };
                self.advance();
                Ok(expr)
            }
            Token::StringLiteral(s) => {
                let expr = SqlExpression::StringLiteral(s.clone());
                self.advance();
                Ok(expr)
            }
            Token::NumberLiteral(n) => {
                let expr = SqlExpression::NumberLiteral(n.clone());
                self.advance();
                Ok(expr)
            }
            Token::LeftParen => {
                self.advance();

                // Parenthesised groups may contain AND/OR chains.
                let expr = self.parse_logical_or()?;

                self.consume(Token::RightParen)?;
                Ok(expr)
            }
            Token::Not => {
                self.advance();
                // Note: the underlying error from `parse_comparison` is
                // replaced by a generic message below.
                if let Ok(inner_expr) = self.parse_comparison() {
                    if matches!(self.current_token, Token::In) {
                        self.advance();
                        self.consume(Token::LeftParen)?;
                        let values = self.parse_expression_list()?;
                        self.consume(Token::RightParen)?;

                        return Ok(SqlExpression::NotInList {
                            expr: Box::new(inner_expr),
                            values,
                        });
                    } else {
                        return Ok(SqlExpression::Not {
                            expr: Box::new(inner_expr),
                        });
                    }
                } else {
                    return Err("Expected expression after NOT".to_string());
                }
            }
            _ => Err(format!("Unexpected token: {:?}", self.current_token)),
        }
    }
1354
1355 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1356 let mut args = Vec::new();
1357
1358 self.in_method_args = true;
1360
1361 if !matches!(self.current_token, Token::RightParen) {
1362 loop {
1363 args.push(self.parse_expression()?);
1364
1365 if matches!(self.current_token, Token::Comma) {
1366 self.advance();
1367 } else {
1368 break;
1369 }
1370 }
1371 }
1372
1373 self.in_method_args = false;
1375
1376 Ok(args)
1377 }
1378
1379 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1380 let mut args = Vec::new();
1381
1382 if !matches!(self.current_token, Token::RightParen) {
1383 loop {
1384 args.push(self.parse_additive()?);
1386
1387 if matches!(self.current_token, Token::Comma) {
1388 self.advance();
1389 } else {
1390 break;
1391 }
1392 }
1393 }
1394
1395 Ok(args)
1396 }
1397
1398 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1399 let mut expressions = Vec::new();
1400
1401 loop {
1402 expressions.push(self.parse_expression()?);
1403
1404 if matches!(self.current_token, Token::Comma) {
1405 self.advance();
1406 } else {
1407 break;
1408 }
1409 }
1410
1411 Ok(expressions)
1412 }
1413
1414 fn get_binary_op(&self) -> Option<String> {
1415 match &self.current_token {
1416 Token::Equal => Some("=".to_string()),
1417 Token::NotEqual => Some("!=".to_string()),
1418 Token::LessThan => Some("<".to_string()),
1419 Token::GreaterThan => Some(">".to_string()),
1420 Token::LessThanOrEqual => Some("<=".to_string()),
1421 Token::GreaterThanOrEqual => Some(">=".to_string()),
1422 Token::Like => Some("LIKE".to_string()),
1423 _ => None,
1424 }
1425 }
1426
1427 fn get_arithmetic_op(&self) -> Option<String> {
1428 match &self.current_token {
1429 Token::Plus => Some("+".to_string()),
1430 Token::Minus => Some("-".to_string()),
1431 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1433 _ => None,
1434 }
1435 }
1436
    /// Returns the lexer's current character position. Because the parser
    /// holds one token of lookahead, this is the position just after the
    /// current token.
    pub fn get_position(&self) -> usize {
        self.lexer.get_position()
    }
1440}
1441
/// Classification of where the cursor sits within a query, as returned by
/// `detect_cursor_context`.
#[derive(Debug, Clone)]
pub enum CursorContext {
    SelectClause,
    FromClause,
    WhereClause,
    OrderByClause,
    /// Payload is presumably the column name preceding the cursor — TODO confirm.
    AfterColumn(String),
    AfterLogicalOp(LogicalOp),
    // NOTE(review): the meaning of the two strings is not visible in this
    // part of the file — presumably (column, operator); confirm against the producer.
    AfterComparisonOp(String, String),
    // NOTE(review): presumably (object, method) — confirm against the producer.
    InMethodCall(String, String),
    InExpression,
    Unknown,
}
1456
/// Returns the prefix of `s` ending at byte offset `pos`, moving the cut
/// backwards to the nearest UTF-8 character boundary when `pos` falls inside
/// a multi-byte character. Offsets at or past the end yield the whole string.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot fail.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
1471
/// Returns the suffix of `s` starting at byte offset `pos`, moving the cut
/// forwards to the nearest UTF-8 character boundary when `pos` falls inside
/// a multi-byte character. Offsets at or past the end yield `""`.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    match (pos..s.len()).find(|&idx| s.is_char_boundary(idx)) {
        Some(start) => &s[start..],
        None => "",
    }
}
1486
1487pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1488 let truncated = safe_slice_to(query, cursor_pos);
1489 let mut parser = Parser::new(truncated);
1490
1491 match parser.parse() {
1493 Ok(stmt) => {
1494 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1495 #[cfg(test)]
1496 println!(
1497 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1498 ctx, partial, truncated
1499 );
1500 (ctx, partial)
1501 }
1502 Err(_) => {
1503 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1505 #[cfg(test)]
1506 println!(
1507 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1508 ctx, partial, truncated
1509 );
1510 (ctx, partial)
1511 }
1512 }
1513}
1514
1515pub fn tokenize_query(query: &str) -> Vec<String> {
1516 let mut lexer = Lexer::new(query);
1517 let tokens = lexer.tokenize_all();
1518 tokens.iter().map(|t| format!("{:?}", t)).collect()
1519}
1520
1521pub fn format_sql_pretty(query: &str) -> Vec<String> {
1522 format_sql_pretty_compact(query, 5) }
1524
1525pub fn format_ast_tree(query: &str) -> String {
1527 let mut parser = Parser::new(query);
1528 match parser.parse() {
1529 Ok(stmt) => format_select_statement(&stmt, 0),
1530 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1531 }
1532}
1533
1534fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1535 let mut result = String::new();
1536 let indent_str = " ".repeat(indent);
1537
1538 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1539
1540 result.push_str(&format!("{indent_str} columns: ["));
1542 if !stmt.columns.is_empty() {
1543 result.push('\n');
1544 for col in &stmt.columns {
1545 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1546 }
1547 result.push_str(&format!("{indent_str} ],\n"));
1548 } else {
1549 result.push_str("],\n");
1550 }
1551
1552 if let Some(table) = &stmt.from_table {
1554 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1555 }
1556
1557 if let Some(where_clause) = &stmt.where_clause {
1559 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1560 result.push_str(&format_where_clause(where_clause, indent + 2));
1561 result.push_str(&format!("{indent_str} }},\n"));
1562 }
1563
1564 if let Some(order_by) = &stmt.order_by {
1566 result.push_str(&format!("{indent_str} order_by: ["));
1567 if !order_by.is_empty() {
1568 result.push('\n');
1569 for col in order_by {
1570 let dir = match col.direction {
1571 SortDirection::Asc => "ASC",
1572 SortDirection::Desc => "DESC",
1573 };
1574 result.push_str(&format!(
1575 "{indent_str} \"{col}\" {dir},\n",
1576 col = col.column
1577 ));
1578 }
1579 result.push_str(&format!("{indent_str} ],\n"));
1580 } else {
1581 result.push_str("],\n");
1582 }
1583 }
1584
1585 if let Some(group_by) = &stmt.group_by {
1587 result.push_str(&format!("{indent_str} group_by: ["));
1588 if !group_by.is_empty() {
1589 result.push('\n');
1590 for col in group_by {
1591 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1592 }
1593 result.push_str(&format!("{indent_str} ],\n"));
1594 } else {
1595 result.push_str("]\n");
1596 }
1597 }
1598
1599 result.push_str(&format!("{indent_str}}}"));
1600 result
1601}
1602
1603fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1604 let mut result = String::new();
1605 let indent_str = " ".repeat(indent);
1606
1607 result.push_str(&format!("{indent_str}conditions: [\n"));
1608
1609 for condition in &clause.conditions {
1610 result.push_str(&format!("{indent_str} {{\n"));
1611 result.push_str(&format!(
1612 "{indent_str} expr: {},\n",
1613 format_expression_ast(&condition.expr)
1614 ));
1615
1616 if let Some(connector) = &condition.connector {
1617 let connector_str = match connector {
1618 LogicalOp::And => "AND",
1619 LogicalOp::Or => "OR",
1620 };
1621 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1622 }
1623
1624 result.push_str(&format!("{indent_str} }},\n"));
1625 }
1626
1627 result.push_str(&format!("{indent_str}]\n"));
1628 result
1629}
1630
1631fn format_expression_ast(expr: &SqlExpression) -> String {
1632 match expr {
1633 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1634 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1635 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1636 SqlExpression::DateTimeConstructor {
1637 year,
1638 month,
1639 day,
1640 hour,
1641 minute,
1642 second,
1643 } => {
1644 format!(
1645 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1646 year,
1647 month,
1648 day,
1649 hour.unwrap_or(0),
1650 minute.unwrap_or(0),
1651 second.unwrap_or(0)
1652 )
1653 }
1654 SqlExpression::DateTimeToday {
1655 hour,
1656 minute,
1657 second,
1658 } => {
1659 format!(
1660 "DateTimeToday({:02}:{:02}:{:02})",
1661 hour.unwrap_or(0),
1662 minute.unwrap_or(0),
1663 second.unwrap_or(0)
1664 )
1665 }
1666 SqlExpression::MethodCall {
1667 object,
1668 method,
1669 args,
1670 } => {
1671 let args_str = args
1672 .iter()
1673 .map(|a| format_expression_ast(a))
1674 .collect::<Vec<_>>()
1675 .join(", ");
1676 format!("MethodCall({}.{}({}))", object, method, args_str)
1677 }
1678 SqlExpression::ChainedMethodCall { base, method, args } => {
1679 let args_str = args
1680 .iter()
1681 .map(|a| format_expression_ast(a))
1682 .collect::<Vec<_>>()
1683 .join(", ");
1684 format!(
1685 "ChainedMethodCall({}.{}({}))",
1686 format_expression_ast(base),
1687 method,
1688 args_str
1689 )
1690 }
1691 SqlExpression::FunctionCall { name, args } => {
1692 let args_str = args
1693 .iter()
1694 .map(|a| format_expression_ast(a))
1695 .collect::<Vec<_>>()
1696 .join(", ");
1697 format!("FunctionCall({}({}))", name, args_str)
1698 }
1699 SqlExpression::BinaryOp { left, op, right } => {
1700 format!(
1701 "BinaryOp({} {} {})",
1702 format_expression_ast(left),
1703 op,
1704 format_expression_ast(right)
1705 )
1706 }
1707 SqlExpression::InList { expr, values } => {
1708 let list_str = values
1709 .iter()
1710 .map(|e| format_expression_ast(e))
1711 .collect::<Vec<_>>()
1712 .join(", ");
1713 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1714 }
1715 SqlExpression::NotInList { expr, values } => {
1716 let list_str = values
1717 .iter()
1718 .map(|e| format_expression_ast(e))
1719 .collect::<Vec<_>>()
1720 .join(", ");
1721 format!(
1722 "NotInList({} NOT IN [{}])",
1723 format_expression_ast(expr),
1724 list_str
1725 )
1726 }
1727 SqlExpression::Between { expr, lower, upper } => {
1728 format!(
1729 "Between({} BETWEEN {} AND {})",
1730 format_expression_ast(expr),
1731 format_expression_ast(lower),
1732 format_expression_ast(upper)
1733 )
1734 }
1735 SqlExpression::Not { expr } => {
1736 format!("Not({})", format_expression_ast(expr))
1737 }
1738 SqlExpression::CaseExpression {
1739 when_branches,
1740 else_branch,
1741 } => {
1742 let when_strs: Vec<String> = when_branches
1743 .iter()
1744 .map(|branch| {
1745 format!(
1746 "WHEN {} THEN {}",
1747 format_expression_ast(&branch.condition),
1748 format_expression_ast(&branch.result)
1749 )
1750 })
1751 .collect();
1752 let else_str = else_branch
1753 .as_ref()
1754 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
1755 .unwrap_or_default();
1756 format!("CASE {} {} END", when_strs.join(" "), else_str)
1757 }
1758 }
1759}
1760
1761pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1763 match expr {
1764 SqlExpression::DateTimeConstructor {
1765 year,
1766 month,
1767 day,
1768 hour,
1769 minute,
1770 second,
1771 } => {
1772 let h = hour.unwrap_or(0);
1773 let m = minute.unwrap_or(0);
1774 let s = second.unwrap_or(0);
1775
1776 if let Ok(dt) = NaiveDateTime::parse_from_str(
1778 &format!(
1779 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1780 year, month, day, h, m, s
1781 ),
1782 "%Y-%m-%d %H:%M:%S",
1783 ) {
1784 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1785 } else {
1786 None
1787 }
1788 }
1789 SqlExpression::DateTimeToday {
1790 hour,
1791 minute,
1792 second,
1793 } => {
1794 let now = Local::now();
1795 let h = hour.unwrap_or(0);
1796 let m = minute.unwrap_or(0);
1797 let s = second.unwrap_or(0);
1798
1799 if let Ok(dt) = NaiveDateTime::parse_from_str(
1801 &format!(
1802 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1803 now.year(),
1804 now.month(),
1805 now.day(),
1806 h,
1807 m,
1808 s
1809 ),
1810 "%Y-%m-%d %H:%M:%S",
1811 ) {
1812 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1813 } else {
1814 None
1815 }
1816 }
1817 _ => None,
1818 }
1819}
1820
1821fn format_sql_with_preserved_parens(
1823 query: &str,
1824 cols_per_line: usize,
1825) -> Result<Vec<String>, String> {
1826 let mut lines = Vec::new();
1827 let mut lexer = Lexer::new(query);
1828 let tokens_with_pos = lexer.tokenize_all_with_positions();
1829
1830 if tokens_with_pos.is_empty() {
1831 return Err("No tokens found".to_string());
1832 }
1833
1834 let mut i = 0;
1835 let cols_per_line = cols_per_line.max(1);
1836
1837 while i < tokens_with_pos.len() {
1838 let (start, _end, ref token) = tokens_with_pos[i];
1839
1840 match token {
1841 Token::Select => {
1842 lines.push("SELECT".to_string());
1843 i += 1;
1844
1845 let mut columns = Vec::new();
1847 let mut col_start = i;
1848 while i < tokens_with_pos.len() {
1849 match &tokens_with_pos[i].2 {
1850 Token::From | Token::Eof => break,
1851 Token::Comma => {
1852 if col_start < i {
1854 let col_text = extract_text_between_positions(
1855 query,
1856 tokens_with_pos[col_start].0,
1857 tokens_with_pos[i - 1].1,
1858 );
1859 columns.push(col_text);
1860 }
1861 i += 1;
1862 col_start = i;
1863 }
1864 _ => i += 1,
1865 }
1866 }
1867 if col_start < i && i > 0 {
1869 let col_text = extract_text_between_positions(
1870 query,
1871 tokens_with_pos[col_start].0,
1872 tokens_with_pos[i - 1].1,
1873 );
1874 columns.push(col_text);
1875 }
1876
1877 for chunk in columns.chunks(cols_per_line) {
1879 let mut line = " ".to_string();
1880 for (idx, col) in chunk.iter().enumerate() {
1881 if idx > 0 {
1882 line.push_str(", ");
1883 }
1884 line.push_str(col.trim());
1885 }
1886 let is_last_chunk = chunk.as_ptr() as usize
1888 + chunk.len() * std::mem::size_of::<String>()
1889 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1890 if !is_last_chunk && columns.len() > cols_per_line {
1891 line.push(',');
1892 }
1893 lines.push(line);
1894 }
1895 }
1896 Token::From => {
1897 i += 1;
1898 if i < tokens_with_pos.len() {
1899 let table_start = tokens_with_pos[i].0;
1900 while i < tokens_with_pos.len() {
1902 match &tokens_with_pos[i].2 {
1903 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1904 _ => i += 1,
1905 }
1906 }
1907 if i > 0 {
1908 let table_text = extract_text_between_positions(
1909 query,
1910 table_start,
1911 tokens_with_pos[i - 1].1,
1912 );
1913 lines.push(format!("FROM {}", table_text.trim()));
1914 }
1915 }
1916 }
1917 Token::Where => {
1918 lines.push("WHERE".to_string());
1919 i += 1;
1920
1921 let where_start = if i < tokens_with_pos.len() {
1923 tokens_with_pos[i].0
1924 } else {
1925 start
1926 };
1927
1928 let mut where_end = query.len();
1930 while i < tokens_with_pos.len() {
1931 match &tokens_with_pos[i].2 {
1932 Token::OrderBy | Token::GroupBy | Token::Eof => {
1933 if i > 0 {
1934 where_end = tokens_with_pos[i - 1].1;
1935 }
1936 break;
1937 }
1938 _ => i += 1,
1939 }
1940 }
1941
1942 let where_text = extract_text_between_positions(query, where_start, where_end);
1944
1945 let formatted_where = format_where_clause_with_parens(&where_text);
1947 for line in formatted_where {
1948 lines.push(format!(" {}", line));
1949 }
1950 }
1951 Token::OrderBy => {
1952 i += 1;
1953 let order_start = if i < tokens_with_pos.len() {
1954 tokens_with_pos[i].0
1955 } else {
1956 start
1957 };
1958
1959 while i < tokens_with_pos.len() {
1961 match &tokens_with_pos[i].2 {
1962 Token::GroupBy | Token::Eof => break,
1963 _ => i += 1,
1964 }
1965 }
1966
1967 if i > 0 {
1968 let order_text = extract_text_between_positions(
1969 query,
1970 order_start,
1971 tokens_with_pos[i - 1].1,
1972 );
1973 lines.push(format!("ORDER BY {}", order_text.trim()));
1974 }
1975 }
1976 Token::GroupBy => {
1977 i += 1;
1978 let group_start = if i < tokens_with_pos.len() {
1979 tokens_with_pos[i].0
1980 } else {
1981 start
1982 };
1983
1984 while i < tokens_with_pos.len() {
1986 match &tokens_with_pos[i].2 {
1987 Token::Having | Token::Eof => break,
1988 _ => i += 1,
1989 }
1990 }
1991
1992 if i > 0 {
1993 let group_text = extract_text_between_positions(
1994 query,
1995 group_start,
1996 tokens_with_pos[i - 1].1,
1997 );
1998 lines.push(format!("GROUP BY {}", group_text.trim()));
1999 }
2000 }
2001 _ => i += 1,
2002 }
2003 }
2004
2005 Ok(lines)
2006}
2007
/// Returns the characters of `query` in the half-open range `start..end`,
/// where both bounds are *character* (not byte) positions.
///
/// Out-of-range bounds are clamped to the end of the query, and an inverted
/// range (`start >= end`) yields an empty string instead of panicking. An
/// inverted range can occur when a clause keyword is the last real token: the
/// following token may start after trailing whitespace, i.e. beyond the end
/// of the token before it.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    // `saturating_sub` turns an inverted range into a zero-length take, and
    // `skip`/`take` stop at the end of the string on their own.
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2015
/// Splits the text of a WHERE clause into display lines at top-level
/// `AND` / `OR` connectors, leaving everything else untouched.
///
/// A connector is recognized (case-insensitively) only when it is surrounded
/// by single spaces, is not nested inside parentheses, and is not inside a
/// quoted literal or quoted identifier. Each condition is emitted trimmed on
/// its own line, with the connector (`AND` / `OR`) on a line of its own
/// between them.
///
/// Quotes are tracked the same way the lexer reads them: a literal opened by
/// `'` or `"` ends at the next occurrence of the same quote character. An
/// unmatched `)` is ignored for depth tracking so it cannot suppress splitting
/// in the rest of the clause.
///
/// If nothing at all is produced (blank input), the result is a single line
/// holding the trimmed input.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// Case-insensitively checks whether `window` begins with the ASCII `keyword`.
    fn starts_with_keyword(window: &[char], keyword: &str) -> bool {
        window.len() >= keyword.len()
            && window
                .iter()
                .zip(keyword.chars())
                .all(|(have, want)| have.eq_ignore_ascii_case(&want))
    }

    // Checked in this order; both are pure ASCII, so byte length == char count.
    const CONNECTORS: [&str; 2] = [" AND ", " OR "];

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        // Connectors only split the clause at the top level, outside quotes.
        if paren_depth == 0 && open_quote.is_none() {
            let connector = CONNECTORS
                .iter()
                .copied()
                .find(|keyword| starts_with_keyword(&chars[i..], keyword));
            if let Some(keyword) = connector {
                if !current_line.trim().is_empty() {
                    lines.push(current_line.trim().to_string());
                }
                lines.push(keyword.trim().to_string());
                current_line.clear();
                i += keyword.len();
                continue;
            }
        }

        let c = chars[i];
        match open_quote {
            // Closing quote of the literal we are currently inside.
            Some(quote) if c == quote => open_quote = None,
            // Any other character inside a literal is copied verbatim.
            Some(_) => {}
            None => match c {
                '\'' | '"' => open_quote = Some(c),
                '(' => paren_depth += 1,
                ')' => paren_depth = paren_depth.saturating_sub(1),
                _ => {}
            },
        }
        current_line.push(c);
        i += 1;
    }

    if !current_line.trim().is_empty() {
        lines.push(current_line.trim().to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2081
2082pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2083 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2085 return lines;
2086 }
2087
2088 let mut lines = Vec::new();
2090 let mut parser = Parser::new(query);
2091
2092 let cols_per_line = cols_per_line.max(1);
2094
2095 match parser.parse() {
2096 Ok(stmt) => {
2097 if !stmt.columns.is_empty() {
2099 lines.push("SELECT".to_string());
2100
2101 for chunk in stmt.columns.chunks(cols_per_line) {
2103 let mut line = " ".to_string();
2104 for (i, col) in chunk.iter().enumerate() {
2105 if i > 0 {
2106 line.push_str(", ");
2107 }
2108 line.push_str(col);
2109 }
2110 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2112 let current_chunk_idx =
2113 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2114 if current_chunk_idx < last_chunk_idx {
2115 line.push(',');
2116 }
2117 lines.push(line);
2118 }
2119 }
2120
2121 if let Some(table) = &stmt.from_table {
2123 lines.push(format!("FROM {}", table));
2124 }
2125
2126 if let Some(where_clause) = &stmt.where_clause {
2128 lines.push("WHERE".to_string());
2129 for (i, condition) in where_clause.conditions.iter().enumerate() {
2130 if i > 0 {
2131 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2133 if let Some(connector) = &prev_condition.connector {
2134 match connector {
2135 LogicalOp::And => lines.push(" AND".to_string()),
2136 LogicalOp::Or => lines.push(" OR".to_string()),
2137 }
2138 }
2139 }
2140 }
2141 lines.push(format!(" {}", format_expression(&condition.expr)));
2142 }
2143 }
2144
2145 if let Some(order_by) = &stmt.order_by {
2147 let order_str = order_by
2148 .iter()
2149 .map(|col| {
2150 let dir = match col.direction {
2151 SortDirection::Asc => " ASC",
2152 SortDirection::Desc => " DESC",
2153 };
2154 format!("{}{}", col.column, dir)
2155 })
2156 .collect::<Vec<_>>()
2157 .join(", ");
2158 lines.push(format!("ORDER BY {}", order_str));
2159 }
2160
2161 if let Some(group_by) = &stmt.group_by {
2163 let group_str = group_by.join(", ");
2164 lines.push(format!("GROUP BY {}", group_str));
2165 }
2166 }
2167 Err(_) => {
2168 let mut lexer = Lexer::new(query);
2170 let tokens = lexer.tokenize_all();
2171 let mut current_line = String::new();
2172 let mut indent = 0;
2173
2174 for token in tokens {
2175 match &token {
2176 Token::Select
2177 | Token::From
2178 | Token::Where
2179 | Token::OrderBy
2180 | Token::GroupBy => {
2181 if !current_line.is_empty() {
2182 lines.push(current_line.trim().to_string());
2183 current_line.clear();
2184 }
2185 lines.push(format!("{:?}", token).to_uppercase());
2186 indent = 1;
2187 }
2188 Token::And | Token::Or => {
2189 if !current_line.is_empty() {
2190 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2191 current_line.clear();
2192 }
2193 lines.push(format!(" {:?}", token).to_uppercase());
2194 }
2195 Token::Comma => {
2196 current_line.push(',');
2197 if indent > 0 {
2198 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2199 current_line.clear();
2200 }
2201 }
2202 Token::Eof => break,
2203 _ => {
2204 if !current_line.is_empty() {
2205 current_line.push(' ');
2206 }
2207 current_line.push_str(&format_token(&token));
2208 }
2209 }
2210 }
2211
2212 if !current_line.is_empty() {
2213 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2214 }
2215 }
2216 }
2217
2218 lines
2219}
2220
2221fn format_expression(expr: &SqlExpression) -> String {
2222 match expr {
2223 SqlExpression::Column(name) => name.clone(),
2224 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2225 SqlExpression::NumberLiteral(n) => n.clone(),
2226 SqlExpression::DateTimeConstructor {
2227 year,
2228 month,
2229 day,
2230 hour,
2231 minute,
2232 second,
2233 } => {
2234 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2235 if let Some(h) = hour {
2236 result.push_str(&format!(", {}", h));
2237 if let Some(m) = minute {
2238 result.push_str(&format!(", {}", m));
2239 if let Some(s) = second {
2240 result.push_str(&format!(", {}", s));
2241 }
2242 }
2243 }
2244 result.push(')');
2245 result
2246 }
2247 SqlExpression::DateTimeToday {
2248 hour,
2249 minute,
2250 second,
2251 } => {
2252 let mut result = "DateTime()".to_string();
2253 if let Some(h) = hour {
2254 result = format!("DateTime(TODAY, {}", h);
2255 if let Some(m) = minute {
2256 result.push_str(&format!(", {}", m));
2257 if let Some(s) = second {
2258 result.push_str(&format!(", {}", s));
2259 }
2260 }
2261 result.push(')');
2262 }
2263 result
2264 }
2265 SqlExpression::MethodCall {
2266 object,
2267 method,
2268 args,
2269 } => {
2270 let args_str = args
2271 .iter()
2272 .map(|arg| format_expression(arg))
2273 .collect::<Vec<_>>()
2274 .join(", ");
2275 format!("{}.{}({})", object, method, args_str)
2276 }
2277 SqlExpression::BinaryOp { left, op, right } => {
2278 if op == "OR" || op == "AND" {
2281 format!(
2284 "({} {} {})",
2285 format_expression(left),
2286 op,
2287 format_expression(right)
2288 )
2289 } else {
2290 format!(
2291 "{} {} {}",
2292 format_expression(left),
2293 op,
2294 format_expression(right)
2295 )
2296 }
2297 }
2298 SqlExpression::InList { expr, values } => {
2299 let values_str = values
2300 .iter()
2301 .map(|v| format_expression(v))
2302 .collect::<Vec<_>>()
2303 .join(", ");
2304 format!("{} IN ({})", format_expression(expr), values_str)
2305 }
2306 SqlExpression::NotInList { expr, values } => {
2307 let values_str = values
2308 .iter()
2309 .map(|v| format_expression(v))
2310 .collect::<Vec<_>>()
2311 .join(", ");
2312 format!("{} NOT IN ({})", format_expression(expr), values_str)
2313 }
2314 SqlExpression::Between { expr, lower, upper } => {
2315 format!(
2316 "{} BETWEEN {} AND {}",
2317 format_expression(expr),
2318 format_expression(lower),
2319 format_expression(upper)
2320 )
2321 }
2322 SqlExpression::Not { expr } => {
2323 format!("NOT {}", format_expression(expr))
2324 }
2325 SqlExpression::ChainedMethodCall { base, method, args } => {
2326 let args_str = args
2327 .iter()
2328 .map(|arg| format_expression(arg))
2329 .collect::<Vec<_>>()
2330 .join(", ");
2331 format!("{}.{}({})", format_expression(base), method, args_str)
2332 }
2333 SqlExpression::FunctionCall { name, args } => {
2334 let args_str = args
2335 .iter()
2336 .map(|arg| format_expression(arg))
2337 .collect::<Vec<_>>()
2338 .join(", ");
2339 format!("{}({})", name, args_str)
2340 }
2341 SqlExpression::CaseExpression {
2342 when_branches,
2343 else_branch,
2344 } => {
2345 let mut result = String::from("CASE");
2346 for branch in when_branches {
2347 result.push_str(&format!(
2348 " WHEN {} THEN {}",
2349 format_expression(&branch.condition),
2350 format_expression(&branch.result)
2351 ));
2352 }
2353 if let Some(else_expr) = else_branch {
2354 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2355 }
2356 result.push_str(" END");
2357 result
2358 }
2359 }
2360}
2361
2362fn format_token(token: &Token) -> String {
2363 match token {
2364 Token::Identifier(s) => s.clone(),
2365 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2366 Token::StringLiteral(s) => format!("'{}'", s),
2367 Token::NumberLiteral(n) => n.clone(),
2368 Token::DateTime => "DateTime".to_string(),
2369 Token::Case => "CASE".to_string(),
2370 Token::When => "WHEN".to_string(),
2371 Token::Then => "THEN".to_string(),
2372 Token::Else => "ELSE".to_string(),
2373 Token::End => "END".to_string(),
2374 Token::LeftParen => "(".to_string(),
2375 Token::RightParen => ")".to_string(),
2376 Token::Comma => ",".to_string(),
2377 Token::Dot => ".".to_string(),
2378 Token::Equal => "=".to_string(),
2379 Token::NotEqual => "!=".to_string(),
2380 Token::LessThan => "<".to_string(),
2381 Token::GreaterThan => ">".to_string(),
2382 Token::LessThanOrEqual => "<=".to_string(),
2383 Token::GreaterThanOrEqual => ">=".to_string(),
2384 Token::In => "IN".to_string(),
2385 _ => format!("{:?}", token).to_uppercase(),
2386 }
2387}
2388
2389fn analyze_statement(
2390 stmt: &SelectStatement,
2391 query: &str,
2392 _cursor_pos: usize,
2393) -> (CursorContext, Option<String>) {
2394 let trimmed = query.trim();
2396
2397 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2399 for op in &comparison_ops {
2400 if let Some(op_pos) = query.rfind(op) {
2401 let before_op = safe_slice_to(query, op_pos);
2402 let after_op_start = op_pos + op.len();
2403 let after_op = if after_op_start < query.len() {
2404 &query[after_op_start..]
2405 } else {
2406 ""
2407 };
2408
2409 if let Some(col_name) = before_op.split_whitespace().last() {
2411 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2412 let after_op_trimmed = after_op.trim();
2414 if after_op_trimmed.is_empty()
2415 || (after_op_trimmed
2416 .chars()
2417 .all(|c| c.is_alphanumeric() || c == '_')
2418 && !after_op_trimmed.contains('('))
2419 {
2420 let partial = if after_op_trimmed.is_empty() {
2421 None
2422 } else {
2423 Some(after_op_trimmed.to_string())
2424 };
2425 return (
2426 CursorContext::AfterComparisonOp(
2427 col_name.to_string(),
2428 op.trim().to_string(),
2429 ),
2430 partial,
2431 );
2432 }
2433 }
2434 }
2435 }
2436 }
2437
2438 if trimmed.to_uppercase().ends_with(" AND")
2440 || trimmed.to_uppercase().ends_with(" OR")
2441 || trimmed.to_uppercase().ends_with(" AND ")
2442 || trimmed.to_uppercase().ends_with(" OR ")
2443 {
2444 } else {
2446 if let Some(dot_pos) = trimmed.rfind('.') {
2448 let before_dot = safe_slice_to(trimmed, dot_pos);
2450 let after_dot_start = dot_pos + 1;
2451 let after_dot = if after_dot_start < trimmed.len() {
2452 &trimmed[after_dot_start..]
2453 } else {
2454 ""
2455 };
2456
2457 if !after_dot.contains('(') {
2460 let col_name = if before_dot.ends_with('"') {
2462 let bytes = before_dot.as_bytes();
2464 let mut pos = before_dot.len() - 1; let mut found_start = None;
2466
2467 if pos > 0 {
2469 pos -= 1;
2470 while pos > 0 {
2471 if bytes[pos] == b'"' {
2472 if pos == 0 || bytes[pos - 1] != b'\\' {
2474 found_start = Some(pos);
2475 break;
2476 }
2477 }
2478 pos -= 1;
2479 }
2480 if found_start.is_none() && bytes[0] == b'"' {
2482 found_start = Some(0);
2483 }
2484 }
2485
2486 if let Some(start) = found_start {
2487 Some(safe_slice_from(before_dot, start))
2489 } else {
2490 None
2491 }
2492 } else {
2493 before_dot
2496 .split_whitespace()
2497 .last()
2498 .map(|word| word.trim_start_matches('('))
2499 };
2500
2501 if let Some(col_name) = col_name {
2502 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2504 true
2506 } else {
2507 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2509 };
2510
2511 if is_valid {
2512 let partial_method = if after_dot.is_empty() {
2515 None
2516 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2517 Some(after_dot.to_string())
2518 } else {
2519 None
2520 };
2521
2522 let col_name_for_context = if col_name.starts_with('"')
2524 && col_name.ends_with('"')
2525 && col_name.len() > 2
2526 {
2527 col_name[1..col_name.len() - 1].to_string()
2528 } else {
2529 col_name.to_string()
2530 };
2531
2532 return (
2533 CursorContext::AfterColumn(col_name_for_context),
2534 partial_method,
2535 );
2536 }
2537 }
2538 }
2539 }
2540 }
2541
2542 if let Some(where_clause) = &stmt.where_clause {
2544 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2546 let op = if trimmed.to_uppercase().ends_with(" AND") {
2547 LogicalOp::And
2548 } else {
2549 LogicalOp::Or
2550 };
2551 return (CursorContext::AfterLogicalOp(op), None);
2552 }
2553
2554 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2556 let after_and = safe_slice_from(query, and_pos + 5);
2557 let partial = extract_partial_at_end(after_and);
2558 if partial.is_some() {
2559 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2560 }
2561 }
2562
2563 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2564 let after_or = safe_slice_from(query, or_pos + 4);
2565 let partial = extract_partial_at_end(after_or);
2566 if partial.is_some() {
2567 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2568 }
2569 }
2570
2571 if let Some(last_condition) = where_clause.conditions.last() {
2572 if let Some(connector) = &last_condition.connector {
2573 return (
2575 CursorContext::AfterLogicalOp(connector.clone()),
2576 extract_partial_at_end(query),
2577 );
2578 }
2579 }
2580 return (CursorContext::WhereClause, extract_partial_at_end(query));
2582 }
2583
2584 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2586 return (CursorContext::OrderByClause, None);
2587 }
2588
2589 if stmt.order_by.is_some() {
2591 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2592 }
2593
2594 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2595 return (CursorContext::FromClause, extract_partial_at_end(query));
2596 }
2597
2598 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2599 return (CursorContext::SelectClause, extract_partial_at_end(query));
2600 }
2601
2602 (CursorContext::Unknown, None)
2603}
2604
2605fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2606 let upper = query.to_uppercase();
2607
2608 let trimmed = query.trim();
2610
2611 #[cfg(test)]
2612 {
2613 if trimmed.contains("\"Last Name\"") {
2614 eprintln!(
2615 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2616 query, trimmed
2617 );
2618 }
2619 }
2620
2621 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2623 for op in &comparison_ops {
2624 if let Some(op_pos) = query.rfind(op) {
2625 let before_op = safe_slice_to(query, op_pos);
2626 let after_op_start = op_pos + op.len();
2627 let after_op = if after_op_start < query.len() {
2628 &query[after_op_start..]
2629 } else {
2630 ""
2631 };
2632
2633 if let Some(col_name) = before_op.split_whitespace().last() {
2635 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2636 let after_op_trimmed = after_op.trim();
2638 if after_op_trimmed.is_empty()
2639 || (after_op_trimmed
2640 .chars()
2641 .all(|c| c.is_alphanumeric() || c == '_')
2642 && !after_op_trimmed.contains('('))
2643 {
2644 let partial = if after_op_trimmed.is_empty() {
2645 None
2646 } else {
2647 Some(after_op_trimmed.to_string())
2648 };
2649 return (
2650 CursorContext::AfterComparisonOp(
2651 col_name.to_string(),
2652 op.trim().to_string(),
2653 ),
2654 partial,
2655 );
2656 }
2657 }
2658 }
2659 }
2660 }
2661
2662 if let Some(dot_pos) = trimmed.rfind('.') {
2665 #[cfg(test)]
2666 {
2667 if trimmed.contains("\"Last Name\"") {
2668 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2669 }
2670 }
2671 let before_dot = &trimmed[..dot_pos];
2673 let after_dot = &trimmed[dot_pos + 1..];
2674
2675 if !after_dot.contains('(') {
2678 let col_name = if before_dot.ends_with('"') {
2681 let bytes = before_dot.as_bytes();
2683 let mut pos = before_dot.len() - 1; let mut found_start = None;
2685
2686 #[cfg(test)]
2687 {
2688 if trimmed.contains("\"Last Name\"") {
2689 eprintln!(
2690 "DEBUG: before_dot='{}', looking for opening quote",
2691 before_dot
2692 );
2693 }
2694 }
2695
2696 if pos > 0 {
2698 pos -= 1;
2699 while pos > 0 {
2700 if bytes[pos] == b'"' {
2701 if pos == 0 || bytes[pos - 1] != b'\\' {
2703 found_start = Some(pos);
2704 break;
2705 }
2706 }
2707 pos -= 1;
2708 }
2709 if found_start.is_none() && bytes[0] == b'"' {
2711 found_start = Some(0);
2712 }
2713 }
2714
2715 if let Some(start) = found_start {
2716 let result = safe_slice_from(before_dot, start);
2718 #[cfg(test)]
2719 {
2720 if trimmed.contains("\"Last Name\"") {
2721 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2722 }
2723 }
2724 Some(result)
2725 } else {
2726 #[cfg(test)]
2727 {
2728 if trimmed.contains("\"Last Name\"") {
2729 eprintln!("DEBUG: No opening quote found!");
2730 }
2731 }
2732 None
2733 }
2734 } else {
2735 before_dot
2738 .split_whitespace()
2739 .last()
2740 .map(|word| word.trim_start_matches('('))
2741 };
2742
2743 if let Some(col_name) = col_name {
2744 #[cfg(test)]
2745 {
2746 if trimmed.contains("\"Last Name\"") {
2747 eprintln!("DEBUG: col_name = '{}'", col_name);
2748 }
2749 }
2750
2751 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2753 true
2755 } else {
2756 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2758 };
2759
2760 #[cfg(test)]
2761 {
2762 if trimmed.contains("\"Last Name\"") {
2763 eprintln!("DEBUG: is_valid = {}", is_valid);
2764 }
2765 }
2766
2767 if is_valid {
2768 let partial_method = if after_dot.is_empty() {
2771 None
2772 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2773 Some(after_dot.to_string())
2774 } else {
2775 None
2776 };
2777
2778 let col_name_for_context = if col_name.starts_with('"')
2780 && col_name.ends_with('"')
2781 && col_name.len() > 2
2782 {
2783 col_name[1..col_name.len() - 1].to_string()
2784 } else {
2785 col_name.to_string()
2786 };
2787
2788 return (
2789 CursorContext::AfterColumn(col_name_for_context),
2790 partial_method,
2791 );
2792 }
2793 }
2794 }
2795 }
2796
2797 if let Some(and_pos) = upper.rfind(" AND ") {
2799 if cursor_pos >= and_pos + 5 {
2801 let after_and = safe_slice_from(query, and_pos + 5);
2803 let partial = extract_partial_at_end(after_and);
2804 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2805 }
2806 }
2807
2808 if let Some(or_pos) = upper.rfind(" OR ") {
2809 if cursor_pos >= or_pos + 4 {
2811 let after_or = safe_slice_from(query, or_pos + 4);
2813 let partial = extract_partial_at_end(after_or);
2814 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2815 }
2816 }
2817
2818 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2820 let op = if trimmed.to_uppercase().ends_with(" AND") {
2821 LogicalOp::And
2822 } else {
2823 LogicalOp::Or
2824 };
2825 return (CursorContext::AfterLogicalOp(op), None);
2826 }
2827
2828 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2830 {
2831 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2832 }
2833
2834 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2835 return (CursorContext::WhereClause, extract_partial_at_end(query));
2836 }
2837
2838 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2839 return (CursorContext::FromClause, extract_partial_at_end(query));
2840 }
2841
2842 if upper.contains("SELECT") && !upper.contains("FROM") {
2843 return (CursorContext::SelectClause, extract_partial_at_end(query));
2844 }
2845
2846 (CursorContext::Unknown, None)
2847}
2848
2849fn extract_partial_at_end(query: &str) -> Option<String> {
2850 let trimmed = query.trim();
2851
2852 if let Some(last_word) = trimmed.split_whitespace().last() {
2854 if last_word.starts_with('"') && !last_word.ends_with('"') {
2855 return Some(last_word.to_string());
2857 }
2858 }
2859
2860 let last_word = trimmed.split_whitespace().last()?;
2862
2863 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2865 Some(last_word.to_string())
2866 } else {
2867 None
2868 }
2869}
2870
/// Reports whether `word` is one of the SQL keywords this module treats as
/// reserved, compared case-insensitively via `str::to_uppercase`.
///
/// The word is compared as-is: surrounding whitespace is not trimmed.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 12] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.contains(&upper.as_str())
}
2888
2889#[cfg(test)]
2890mod tests {
2891 use super::*;
2892
2893 #[test]
2894 fn test_chained_method_calls() {
2895 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
2897 let mut parser = Parser::new(query);
2898 let result = parser.parse();
2899
2900 assert!(
2901 result.is_ok(),
2902 "Failed to parse chained method calls: {:?}",
2903 result
2904 );
2905
2906 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
2908 let mut parser2 = Parser::new(query2);
2909 let result2 = parser2.parse();
2910
2911 assert!(
2912 result2.is_ok(),
2913 "Failed to parse multiple chained calls: {:?}",
2914 result2
2915 );
2916 }
2917
2918 #[test]
2919 fn test_tokenizer() {
2920 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
2921
2922 assert!(matches!(lexer.next_token(), Token::Select));
2923 assert!(matches!(lexer.next_token(), Token::Star));
2924 assert!(matches!(lexer.next_token(), Token::From));
2925 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
2926 assert!(matches!(lexer.next_token(), Token::Where));
2927 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
2928 assert!(matches!(lexer.next_token(), Token::GreaterThan));
2929 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
2930 }
2931
2932 #[test]
2933 fn test_tokenizer_datetime() {
2934 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
2935
2936 assert!(matches!(lexer.next_token(), Token::Where));
2937 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
2938 assert!(matches!(lexer.next_token(), Token::GreaterThan));
2939 assert!(matches!(lexer.next_token(), Token::DateTime));
2940 assert!(matches!(lexer.next_token(), Token::LeftParen));
2941 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
2942 assert!(matches!(lexer.next_token(), Token::Comma));
2943 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
2944 assert!(matches!(lexer.next_token(), Token::Comma));
2945 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
2946 assert!(matches!(lexer.next_token(), Token::RightParen));
2947 }
2948
2949 #[test]
2950 fn test_parse_simple_select() {
2951 let mut parser = Parser::new("SELECT * FROM trade_deal");
2952 let stmt = parser.parse().unwrap();
2953
2954 assert_eq!(stmt.columns, vec!["*"]);
2955 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
2956 assert!(stmt.where_clause.is_none());
2957 }
2958
2959 #[test]
2960 fn test_parse_where_with_method() {
2961 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
2962 let stmt = parser.parse().unwrap();
2963
2964 assert!(stmt.where_clause.is_some());
2965 let where_clause = stmt.where_clause.unwrap();
2966 assert_eq!(where_clause.conditions.len(), 1);
2967 }
2968
2969 #[test]
2970 fn test_parse_datetime_constructor() {
2971 let mut parser =
2972 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
2973 let stmt = parser.parse().unwrap();
2974
2975 assert!(stmt.where_clause.is_some());
2976 let where_clause = stmt.where_clause.unwrap();
2977 assert_eq!(where_clause.conditions.len(), 1);
2978
2979 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
2981 assert_eq!(op, ">");
2982 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
2983 assert!(matches!(
2984 right.as_ref(),
2985 SqlExpression::DateTimeConstructor {
2986 year: 2025,
2987 month: 10,
2988 day: 20,
2989 hour: None,
2990 minute: None,
2991 second: None
2992 }
2993 ));
2994 } else {
2995 panic!("Expected BinaryOp with DateTime constructor");
2996 }
2997 }
2998
2999 #[test]
3000 fn test_cursor_context_after_and() {
3001 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3002 let (context, partial) = detect_cursor_context(query, query.len());
3003
3004 assert!(matches!(
3005 context,
3006 CursorContext::AfterLogicalOp(LogicalOp::And)
3007 ));
3008 assert_eq!(partial, None);
3009 }
3010
3011 #[test]
3012 fn test_cursor_context_with_partial() {
3013 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3014 let (context, partial) = detect_cursor_context(query, query.len());
3015
3016 assert!(matches!(
3017 context,
3018 CursorContext::AfterLogicalOp(LogicalOp::And)
3019 ));
3020 assert_eq!(partial, Some("p".to_string()));
3021 }
3022
3023 #[test]
3024 fn test_cursor_context_after_datetime_comparison() {
3025 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3026 let (context, partial) = detect_cursor_context(query, query.len());
3027
3028 assert!(
3029 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3030 );
3031 assert_eq!(partial, None);
3032 }
3033
3034 #[test]
3035 fn test_cursor_context_partial_datetime() {
3036 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3037 let (context, partial) = detect_cursor_context(query, query.len());
3038
3039 assert!(
3040 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3041 );
3042 assert_eq!(partial, Some("Date".to_string()));
3043 }
3044
3045 #[test]
3047 fn test_tokenizer_quoted_identifier() {
3048 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3049
3050 assert!(matches!(lexer.next_token(), Token::Select));
3051 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3052 assert!(matches!(lexer.next_token(), Token::Comma));
3053 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3054 assert!(matches!(lexer.next_token(), Token::From));
3055 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3056 }
3057
3058 #[test]
3059 fn test_tokenizer_quoted_vs_string_literal() {
3060 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3062
3063 assert!(matches!(lexer.next_token(), Token::Where));
3064 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3065 assert!(matches!(lexer.next_token(), Token::Equal));
3066 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3067 assert!(matches!(lexer.next_token(), Token::And));
3068 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3069 assert!(matches!(lexer.next_token(), Token::Dot));
3070 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3071 assert!(matches!(lexer.next_token(), Token::LeftParen));
3072 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3073 assert!(matches!(lexer.next_token(), Token::RightParen));
3074 }
3075
3076 #[test]
3077 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3078 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3081
3082 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3083 assert!(matches!(lexer.next_token(), Token::Dot));
3084 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3085 assert!(matches!(lexer.next_token(), Token::LeftParen));
3086
3087 let token = lexer.next_token();
3090 println!("Token for \"Alb\": {:?}", token);
3091 assert!(matches!(lexer.next_token(), Token::RightParen));
3095 }
3096
3097 #[test]
3098 fn test_parse_select_with_quoted_columns() {
3099 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
3100 let stmt = parser.parse().unwrap();
3101
3102 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
3103 assert_eq!(stmt.from_table, Some("customers".to_string()));
3104 }
3105
3106 #[test]
3107 fn test_cursor_context_select_with_partial_quoted() {
3108 let query = r#"SELECT "Cust"#;
3110 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {:?}, Partial: {:?}", context, partial);
3113 assert!(matches!(context, CursorContext::SelectClause));
3114 }
3117
3118 #[test]
3119 fn test_cursor_context_select_after_comma_with_quoted() {
3120 let query = r#"SELECT Company, "Customer "#;
3122 let (context, partial) = detect_cursor_context(query, query.len());
3123
3124 println!("Context: {:?}, Partial: {:?}", context, partial);
3125 assert!(matches!(context, CursorContext::SelectClause));
3126 }
3128
3129 #[test]
3130 fn test_cursor_context_order_by_quoted() {
3131 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3132 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3133
3134 println!("Context: {:?}, Partial: {:?}", context, partial);
3135 assert!(matches!(context, CursorContext::OrderByClause));
3136 }
3138
3139 #[test]
3140 fn test_where_clause_with_quoted_column() {
3141 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3142 let stmt = parser.parse().unwrap();
3143
3144 assert!(stmt.where_clause.is_some());
3145 let where_clause = stmt.where_clause.unwrap();
3146 assert_eq!(where_clause.conditions.len(), 1);
3147
3148 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3149 assert_eq!(op, "=");
3150 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3151 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3152 } else {
3153 panic!("Expected BinaryOp");
3154 }
3155 }
3156
3157 #[test]
3158 fn test_parse_method_with_double_quotes_as_string() {
3159 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3161 let stmt = parser.parse().unwrap();
3162
3163 assert!(stmt.where_clause.is_some());
3164 let where_clause = stmt.where_clause.unwrap();
3165 assert_eq!(where_clause.conditions.len(), 1);
3166
3167 if let SqlExpression::MethodCall {
3168 object,
3169 method,
3170 args,
3171 } = &where_clause.conditions[0].expr
3172 {
3173 assert_eq!(object, "Country");
3174 assert_eq!(method, "Contains");
3175 assert_eq!(args.len(), 1);
3176 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3178 } else {
3179 panic!("Expected MethodCall");
3180 }
3181 }
3182
3183 #[test]
3184 fn test_extract_partial_with_quoted_columns_in_query() {
3185 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3187 let (context, partial) = detect_cursor_context(query, query.len());
3188
3189 assert!(matches!(context, CursorContext::OrderByClause));
3190 assert_eq!(
3191 partial,
3192 Some("coun".to_string()),
3193 "Should extract 'coun' as partial, not everything after the quoted column"
3194 );
3195 }
3196
3197 #[test]
3198 fn test_extract_partial_quoted_identifier_being_typed() {
3199 let query = r#"SELECT "Cust"#;
3201 let partial = extract_partial_at_end(query);
3202 assert_eq!(partial, Some("\"Cust".to_string()));
3203
3204 let query2 = r#"SELECT "Customer Id" FROM"#;
3206 let partial2 = extract_partial_at_end(query2);
3207 assert_eq!(partial2, None); }
3209
3210 #[test]
3212 fn test_complex_where_parentheses_basic() {
3213 let mut parser =
3215 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3216 let stmt = parser.parse().unwrap();
3217
3218 assert!(stmt.where_clause.is_some());
3219 let where_clause = stmt.where_clause.unwrap();
3220 assert_eq!(where_clause.conditions.len(), 1);
3221
3222 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3224 assert_eq!(op, "OR");
3225 } else {
3226 panic!("Expected BinaryOp with OR");
3227 }
3228 }
3229
3230 #[test]
3231 fn test_complex_where_mixed_and_or_with_parens() {
3232 let mut parser = Parser::new(
3234 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3235 );
3236 let stmt = parser.parse().unwrap();
3237
3238 assert!(stmt.where_clause.is_some());
3239 let where_clause = stmt.where_clause.unwrap();
3240 assert_eq!(where_clause.conditions.len(), 2);
3241
3242 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3244 assert_eq!(op, "OR");
3245 } else {
3246 panic!("Expected first condition to be OR expression");
3247 }
3248
3249 assert!(matches!(
3251 where_clause.conditions[0].connector,
3252 Some(LogicalOp::And)
3253 ));
3254
3255 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3257 assert_eq!(op, ">");
3258 } else {
3259 panic!("Expected second condition to be price > 100");
3260 }
3261 }
3262
3263 #[test]
3264 fn test_complex_where_nested_parentheses() {
3265 let mut parser = Parser::new(
3267 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3268 );
3269 let stmt = parser.parse().unwrap();
3270
3271 assert!(stmt.where_clause.is_some());
3272 let where_clause = stmt.where_clause.unwrap();
3273
3274 assert!(where_clause.conditions.len() > 0);
3276 }
3277
3278 #[test]
3279 fn test_complex_where_multiple_or_groups() {
3280 let mut parser = Parser::new(
3282 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3283 );
3284 let stmt = parser.parse().unwrap();
3285
3286 assert!(stmt.where_clause.is_some());
3287 let where_clause = stmt.where_clause.unwrap();
3288 assert_eq!(where_clause.conditions.len(), 2);
3289
3290 assert!(matches!(
3292 where_clause.conditions[0].connector,
3293 Some(LogicalOp::And)
3294 ));
3295 }
3296
3297 #[test]
3298 fn test_complex_where_with_methods_in_parens() {
3299 let mut parser = Parser::new(
3301 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3302 );
3303 let stmt = parser.parse().unwrap();
3304
3305 assert!(stmt.where_clause.is_some());
3306 let where_clause = stmt.where_clause.unwrap();
3307 assert_eq!(where_clause.conditions.len(), 2);
3308
3309 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3311 assert_eq!(op, "OR");
3312 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3313 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3314 } else {
3315 panic!("Expected OR of method calls");
3316 }
3317 }
3318
3319 #[test]
3320 fn test_complex_where_date_comparisons_with_parens() {
3321 let mut parser = Parser::new(
3323 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3324 );
3325 let stmt = parser.parse().unwrap();
3326
3327 assert!(stmt.where_clause.is_some());
3328 let where_clause = stmt.where_clause.unwrap();
3329 assert_eq!(where_clause.conditions.len(), 2);
3330
3331 assert!(matches!(
3333 where_clause.conditions[0].connector,
3334 Some(LogicalOp::And)
3335 ));
3336 }
3337
3338 #[test]
3339 fn test_complex_where_price_volume_filters() {
3340 let mut parser = Parser::new(
3342 r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
3343 );
3344 let stmt = parser.parse().unwrap();
3345
3346 assert!(stmt.where_clause.is_some());
3347 let where_clause = stmt.where_clause.unwrap();
3348
3349 assert!(where_clause.conditions.len() > 0);
3351 }
3352
3353 #[test]
3354 fn test_complex_where_mixed_string_numeric() {
3355 let mut parser = Parser::new(
3357 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3358 );
3359 let stmt = parser.parse().unwrap();
3360
3361 assert!(stmt.where_clause.is_some());
3362 }
3364
3365 #[test]
3366 fn test_complex_where_triple_nested() {
3367 let mut parser = Parser::new(
3369 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3370 );
3371 let stmt = parser.parse().unwrap();
3372
3373 assert!(stmt.where_clause.is_some());
3374 }
3376
3377 #[test]
3378 fn test_complex_where_single_parens_around_and() {
3379 let mut parser = Parser::new(
3381 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3382 );
3383 let stmt = parser.parse().unwrap();
3384
3385 assert!(stmt.where_clause.is_some());
3386 let where_clause = stmt.where_clause.unwrap();
3387
3388 assert!(where_clause.conditions.len() > 0);
3390 }
3391
3392 #[test]
3394 fn test_format_preserves_simple_parentheses() {
3395 let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
3396 let formatted = format_sql_pretty_compact(query, 5);
3397 let formatted_text = formatted.join(" ");
3398
3399 assert!(formatted_text.contains("(status"));
3401 assert!(formatted_text.contains("\"pending\")"));
3402
3403 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3405 let formatted_parens = formatted_text
3406 .chars()
3407 .filter(|c| *c == '(' || *c == ')')
3408 .count();
3409 assert_eq!(
3410 original_parens, formatted_parens,
3411 "Parentheses should be preserved"
3412 );
3413 }
3414
3415 #[test]
3416 fn test_format_preserves_complex_parentheses() {
3417 let query =
3418 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3419 let formatted = format_sql_pretty_compact(query, 5);
3420 let formatted_text = formatted.join(" ");
3421
3422 assert!(formatted_text.contains("(symbol"));
3424 assert!(formatted_text.contains("\"GOOGL\")"));
3425
3426 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3428 let formatted_parens = formatted_text
3429 .chars()
3430 .filter(|c| *c == '(' || *c == ')')
3431 .count();
3432 assert_eq!(original_parens, formatted_parens);
3433 }
3434
3435 #[test]
3436 fn test_format_preserves_nested_parentheses() {
3437 let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
3438 let formatted = format_sql_pretty_compact(query, 5);
3439 let formatted_text = formatted.join(" ");
3440
3441 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3443 let formatted_parens = formatted_text
3444 .chars()
3445 .filter(|c| *c == '(' || *c == ')')
3446 .count();
3447 assert_eq!(
3448 original_parens, formatted_parens,
3449 "Nested parentheses should be preserved"
3450 );
3451 assert_eq!(original_parens, 4, "Should have 4 parentheses total");
3452 }
3453
3454 #[test]
3455 fn test_format_preserves_method_calls_in_parentheses() {
3456 let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
3457 let formatted = format_sql_pretty_compact(query, 5);
3458 let formatted_text = formatted.join(" ");
3459
3460 assert!(formatted_text.contains("(symbol.StartsWith"));
3462 assert!(formatted_text.contains("StartsWith(\"A\")"));
3463 assert!(formatted_text.contains("StartsWith(\"G\")"));
3464
3465 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3467 let formatted_parens = formatted_text
3468 .chars()
3469 .filter(|c| *c == '(' || *c == ')')
3470 .count();
3471 assert_eq!(original_parens, formatted_parens);
3472 assert_eq!(
3473 original_parens, 6,
3474 "Should have 6 parentheses (1 group + 2 method calls)"
3475 );
3476 }
3477
3478 #[test]
3479 fn test_format_preserves_multiple_groups() {
3480 let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
3481 let formatted = format_sql_pretty_compact(query, 5);
3482 let formatted_text = formatted.join(" ");
3483
3484 assert!(formatted_text.contains("(symbol"));
3486 assert!(formatted_text.contains("(price"));
3487
3488 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3489 let formatted_parens = formatted_text
3490 .chars()
3491 .filter(|c| *c == '(' || *c == ')')
3492 .count();
3493 assert_eq!(original_parens, formatted_parens);
3494 assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
3495 }
3496
3497 #[test]
3498 fn test_format_preserves_date_ranges() {
3499 let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
3500 let formatted = format_sql_pretty_compact(query, 5);
3501 let formatted_text = formatted.join(" ");
3502
3503 assert!(formatted_text.contains("(executionDate"));
3505 assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
3506 assert!(formatted_text.contains("DateTime(2024, 12, 31)"));
3507
3508 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3509 let formatted_parens = formatted_text
3510 .chars()
3511 .filter(|c| *c == '(' || *c == ')')
3512 .count();
3513 assert_eq!(original_parens, formatted_parens);
3514 }
3515
3516 #[test]
3517 fn test_format_multiline_layout() {
3518 let query =
3520 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3521 let formatted = format_sql_pretty_compact(query, 5);
3522
3523 assert!(formatted.len() >= 4, "Should have multiple lines");
3525 assert_eq!(formatted[0], "SELECT");
3526 assert!(formatted[1].trim().starts_with("*"));
3527 assert!(formatted[2].starts_with("FROM"));
3528 assert_eq!(formatted[3], "WHERE");
3529
3530 let where_lines: Vec<_> = formatted.iter().skip(4).collect();
3532 assert!(where_lines.iter().any(|l| l.contains("(symbol")));
3533 assert!(where_lines.iter().any(|l| l.trim() == "AND"));
3534 }
3535
3536 #[test]
3537 fn test_between_simple() {
3538 let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3539 let stmt = parser.parse().expect("Should parse simple BETWEEN");
3540
3541 assert!(stmt.where_clause.is_some());
3542 let where_clause = stmt.where_clause.unwrap();
3543 assert_eq!(where_clause.conditions.len(), 1);
3544
3545 let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3547 assert!(!ast.contains("PARSE ERROR"));
3548 assert!(ast.contains("SelectStatement"));
3549 }
3550
3551 #[test]
3552 fn test_between_in_parentheses() {
3553 let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3554 let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");
3555
3556 assert!(stmt.where_clause.is_some());
3557
3558 let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3560 assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
3561 }
3562
3563 #[test]
3564 fn test_between_with_or() {
3565 let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
3566 let mut parser = Parser::new(query);
3567 let stmt = parser.parse().expect("Should parse BETWEEN with OR");
3568
3569 assert!(stmt.where_clause.is_some());
3570 }
3573
3574 #[test]
3575 fn test_between_with_and() {
3576 let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
3577 let mut parser = Parser::new(query);
3578 let stmt = parser.parse().expect("Should parse BETWEEN with AND");
3579
3580 assert!(stmt.where_clause.is_some());
3581 let where_clause = stmt.where_clause.unwrap();
3582 assert_eq!(where_clause.conditions.len(), 2); }
3584
3585 #[test]
3586 fn test_multiple_between() {
3587 let query =
3588 "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
3589 let mut parser = Parser::new(query);
3590 let stmt = parser
3591 .parse()
3592 .expect("Should parse multiple BETWEEN clauses");
3593
3594 assert!(stmt.where_clause.is_some());
3595 }
3596
3597 #[test]
3598 fn test_between_complex_query() {
3599 let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
3601 let mut parser = Parser::new(query);
3602 let stmt = parser
3603 .parse()
3604 .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");
3605
3606 assert!(stmt.where_clause.is_some());
3607 assert!(stmt.order_by.is_some());
3608
3609 let order_by = stmt.order_by.unwrap();
3610 assert_eq!(order_by.len(), 2);
3611 assert_eq!(order_by[0].column, "Category");
3612 assert!(matches!(order_by[0].direction, SortDirection::Asc));
3613 assert_eq!(order_by[1].column, "price");
3614 assert!(matches!(order_by[1].direction, SortDirection::Desc));
3615 }
3616
3617 #[test]
3618 fn test_between_formatting() {
3619 let expr = SqlExpression::Between {
3620 expr: Box::new(SqlExpression::Column("price".to_string())),
3621 lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
3622 upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
3623 };
3624
3625 let formatted = format_expression(&expr);
3626 assert_eq!(formatted, "price BETWEEN 50 AND 100");
3627
3628 let ast_formatted = format_expression_ast(&expr);
3629 assert!(ast_formatted.contains("Between"));
3630 assert!(ast_formatted.contains("50"));
3631 assert!(ast_formatted.contains("100"));
3632 }
3633
3634 #[test]
3635 fn test_utf8_boundary_safety() {
3636 let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";
3638
3639 for pos in 0..query_with_unicode.len() + 1 {
3641 let result =
3643 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));
3644
3645 assert!(
3646 result.is_ok(),
3647 "Panic at position {} in query with Unicode",
3648 pos
3649 );
3650 }
3651
3652 let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
3654 assert!(result.is_ok(), "Panic with position beyond string length");
3655
3656 let pos_in_e = query_with_unicode.find('é').unwrap() + 1; let result =
3659 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
3660 assert!(
3661 result.is_ok(),
3662 "Panic with cursor in middle of UTF-8 character"
3663 );
3664 }
3665}