1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by the lexer for the SQL-like query language.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Keywords (the lexer matches them case-insensitively) ---
    Select,
    From,
    Where,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// `ORDER BY`, emitted as one token when `ORDER` is followed by `BY`.
    OrderBy,
    /// `GROUP BY`, emitted as one token when `GROUP` is followed by `BY`.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    /// The `DATETIME` keyword that introduces a date/time constructor.
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    /// A bare word that is not a keyword. The lexer also uses this variant
    /// for any single character it does not otherwise recognise.
    Identifier(String),
    /// Text enclosed in double quotes (quotes removed).
    QuotedIdentifier(String),
    /// Text enclosed in single quotes (quotes removed).
    StringLiteral(String),
    /// The textual form of a number, including a leading `-` when the lexer
    /// folded a minus sign into the literal.
    NumberLiteral(String),
    /// `*`
    Star,

    // --- Punctuation and comparison operators ---
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// `<>` or `!=`
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // --- Arithmetic operators (`*` is `Star` above) ---
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}
60
/// Character-level scanner that turns query text into [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input, pre-split into `char`s so positions are char indices.
    input: Vec<char>,
    /// Index into `input` of the character currently being examined.
    position: usize,
    /// `input[position]`, or `None` once the end of input has been reached.
    current_char: Option<char>,
}
67
68impl Lexer {
69 pub fn new(input: &str) -> Self {
70 let chars: Vec<char> = input.chars().collect();
71 let current = chars.get(0).copied();
72 Self {
73 input: chars,
74 position: 0,
75 current_char: current,
76 }
77 }
78
79 fn advance(&mut self) {
80 self.position += 1;
81 self.current_char = self.input.get(self.position).copied();
82 }
83
84 fn peek(&self, offset: usize) -> Option<char> {
85 self.input.get(self.position + offset).copied()
86 }
87
88 fn skip_whitespace(&mut self) {
89 while let Some(ch) = self.current_char {
90 if ch.is_whitespace() {
91 self.advance();
92 } else {
93 break;
94 }
95 }
96 }
97
98 fn skip_whitespace_and_comments(&mut self) {
99 loop {
100 while let Some(ch) = self.current_char {
102 if ch.is_whitespace() {
103 self.advance();
104 } else {
105 break;
106 }
107 }
108
109 match self.current_char {
111 Some('-') if self.peek(1) == Some('-') => {
112 self.advance(); self.advance(); while let Some(ch) = self.current_char {
116 self.advance();
117 if ch == '\n' {
118 break;
119 }
120 }
121 }
122 Some('/') if self.peek(1) == Some('*') => {
123 self.advance(); self.advance(); while let Some(ch) = self.current_char {
127 if ch == '*' && self.peek(1) == Some('/') {
128 self.advance(); self.advance(); break;
131 }
132 self.advance();
133 }
134 }
135 _ => {
136 break;
138 }
139 }
140 }
141 }
142
143 fn read_identifier(&mut self) -> String {
144 let mut result = String::new();
145 while let Some(ch) = self.current_char {
146 if ch.is_alphanumeric() || ch == '_' {
147 result.push(ch);
148 self.advance();
149 } else {
150 break;
151 }
152 }
153 result
154 }
155
156 fn read_string(&mut self) -> String {
157 let mut result = String::new();
158 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
162 if ch == quote_char {
163 self.advance(); break;
165 }
166 result.push(ch);
167 self.advance();
168 }
169 result
170 }
171
172 fn read_number(&mut self) -> String {
173 let mut result = String::new();
174 let mut has_e = false;
175
176 while let Some(ch) = self.current_char {
178 if !has_e && (ch.is_numeric() || ch == '.') {
179 result.push(ch);
180 self.advance();
181 } else if (ch == 'e' || ch == 'E') && !has_e && !result.is_empty() {
182 result.push(ch);
184 self.advance();
185 has_e = true;
186
187 if let Some(sign) = self.current_char {
189 if sign == '+' || sign == '-' {
190 result.push(sign);
191 self.advance();
192 }
193 }
194
195 while let Some(digit) = self.current_char {
197 if digit.is_numeric() {
198 result.push(digit);
199 self.advance();
200 } else {
201 break;
202 }
203 }
204 break; } else {
206 break;
207 }
208 }
209 result
210 }
211
212 pub fn next_token(&mut self) -> Token {
213 self.skip_whitespace_and_comments();
214
215 match self.current_char {
216 None => Token::Eof,
217 Some('*') => {
218 self.advance();
219 Token::Star }
223 Some('+') => {
224 self.advance();
225 Token::Plus
226 }
227 Some('/') => {
228 if self.peek(1) == Some('*') {
230 self.skip_whitespace_and_comments();
233 return self.next_token();
234 }
235 self.advance();
236 Token::Divide
237 }
238 Some('.') => {
239 self.advance();
240 Token::Dot
241 }
242 Some(',') => {
243 self.advance();
244 Token::Comma
245 }
246 Some('(') => {
247 self.advance();
248 Token::LeftParen
249 }
250 Some(')') => {
251 self.advance();
252 Token::RightParen
253 }
254 Some('=') => {
255 self.advance();
256 Token::Equal
257 }
258 Some('<') => {
259 self.advance();
260 if self.current_char == Some('=') {
261 self.advance();
262 Token::LessThanOrEqual
263 } else if self.current_char == Some('>') {
264 self.advance();
265 Token::NotEqual
266 } else {
267 Token::LessThan
268 }
269 }
270 Some('>') => {
271 self.advance();
272 if self.current_char == Some('=') {
273 self.advance();
274 Token::GreaterThanOrEqual
275 } else {
276 Token::GreaterThan
277 }
278 }
279 Some('!') if self.peek(1) == Some('=') => {
280 self.advance();
281 self.advance();
282 Token::NotEqual
283 }
284 Some('"') => {
285 let ident_val = self.read_string();
287 Token::QuotedIdentifier(ident_val)
288 }
289 Some('\'') => {
290 let string_val = self.read_string();
292 Token::StringLiteral(string_val)
293 }
294 Some('-') if self.peek(1) == Some('-') => {
295 self.skip_whitespace_and_comments();
297 self.next_token()
298 }
299 Some('-') if self.peek(1).map_or(false, |c| c.is_numeric()) => {
300 self.advance(); let num = self.read_number();
303 Token::NumberLiteral(format!("-{}", num))
304 }
305 Some('-') => {
306 self.advance();
308 Token::Minus
309 }
310 Some(ch) if ch.is_numeric() => {
311 let num = self.read_number();
312 Token::NumberLiteral(num)
313 }
314 Some(ch) if ch.is_alphabetic() || ch == '_' => {
315 let ident = self.read_identifier();
316 match ident.to_uppercase().as_str() {
317 "SELECT" => Token::Select,
318 "FROM" => Token::From,
319 "WHERE" => Token::Where,
320 "AND" => Token::And,
321 "OR" => Token::Or,
322 "IN" => Token::In,
323 "NOT" => Token::Not,
324 "BETWEEN" => Token::Between,
325 "LIKE" => Token::Like,
326 "IS" => Token::Is,
327 "NULL" => Token::Null,
328 "ORDER" if self.peek_keyword("BY") => {
329 self.skip_whitespace();
330 self.read_identifier(); Token::OrderBy
332 }
333 "GROUP" if self.peek_keyword("BY") => {
334 self.skip_whitespace();
335 self.read_identifier(); Token::GroupBy
337 }
338 "HAVING" => Token::Having,
339 "AS" => Token::As,
340 "ASC" => Token::Asc,
341 "DESC" => Token::Desc,
342 "LIMIT" => Token::Limit,
343 "OFFSET" => Token::Offset,
344 "DATETIME" => Token::DateTime,
345 "CASE" => Token::Case,
346 "WHEN" => Token::When,
347 "THEN" => Token::Then,
348 "ELSE" => Token::Else,
349 "END" => Token::End,
350 "DISTINCT" => Token::Distinct,
351 _ => Token::Identifier(ident),
352 }
353 }
354 Some(ch) => {
355 self.advance();
356 Token::Identifier(ch.to_string())
357 }
358 }
359 }
360
361 fn peek_keyword(&mut self, keyword: &str) -> bool {
362 let saved_pos = self.position;
363 let saved_char = self.current_char;
364
365 self.skip_whitespace_and_comments();
366 let next_word = self.read_identifier();
367 let matches = next_word.to_uppercase() == keyword;
368
369 self.position = saved_pos;
371 self.current_char = saved_char;
372
373 matches
374 }
375
/// Returns the index of the current character within the input.
///
/// This is an index into the lexer's `char` vector, not a byte offset into
/// the original string.
pub fn get_position(&self) -> usize {
    self.position
}
379
380 pub fn tokenize_all(&mut self) -> Vec<Token> {
381 let mut tokens = Vec::new();
382 loop {
383 let token = self.next_token();
384 if matches!(token, Token::Eof) {
385 tokens.push(token);
386 break;
387 }
388 tokens.push(token);
389 }
390 tokens
391 }
392
393 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
394 let mut tokens = Vec::new();
395 loop {
396 self.skip_whitespace_and_comments();
397 let start_pos = self.position;
398 let token = self.next_token();
399 let end_pos = self.position;
400
401 if matches!(token, Token::Eof) {
402 break;
403 }
404 tokens.push((start_pos, end_pos, token));
405 }
406 tokens
407 }
408}
409
/// Expression tree produced by the parser.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// A column reference by name.
    Column(String),
    /// A string literal (quotes removed).
    StringLiteral(String),
    /// A numeric literal kept in its textual form.
    NumberLiteral(String),
    /// `DateTime(year, month, day[, hour[, minute[, second]]])`.
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `DateTime()` with no arguments; the parser leaves all fields `None`.
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `column.method(args)` where the receiver is a plain column.
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `base.method(args)` where the receiver is any other expression,
    /// typically a previous method call.
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `NAME(args)`; the parser stores `name` upper-cased.
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
    },
    /// `left op right`; `op` is the operator's text (e.g. `=`, `+`, `LIKE`, `AND`).
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// `expr IN (values...)`.
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr NOT IN (values...)`.
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr BETWEEN lower AND upper`.
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// `NOT expr`.
    Not {
        expr: Box<SqlExpression>,
    },
    /// `CASE WHEN ... THEN ... [ELSE ...] END`.
    CaseExpression {
        when_branches: Vec<WhenBranch>,
        else_branch: Option<Box<SqlExpression>>,
    },
}
469
/// One `WHEN condition THEN result` arm of a `CASE` expression.
#[derive(Debug, Clone)]
pub struct WhenBranch {
    /// Expression parsed after `WHEN`.
    pub condition: Box<SqlExpression>,
    /// Expression parsed after `THEN`.
    pub result: Box<SqlExpression>,
}
475
/// A parsed `WHERE` clause: a flat sequence of conditions, each carrying the
/// logical operator that links it to the next one.
#[derive(Debug, Clone)]
pub struct WhereClause {
    pub conditions: Vec<Condition>,
}
480
/// A single condition inside a `WHERE` clause.
#[derive(Debug, Clone)]
pub struct Condition {
    /// The condition's expression.
    pub expr: SqlExpression,
    /// Operator that followed this condition (`AND` / `OR`), or `None` for
    /// the last condition in the clause.
    pub connector: Option<LogicalOp>,
}
486
/// Logical connective between two `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    And,
    Or,
}
492
/// Sort order of an `ORDER BY` column; the parser defaults to `Asc` when no
/// direction keyword is given.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    Asc,
    Desc,
}
498
/// One entry of an `ORDER BY` list.
#[derive(Debug, Clone)]
pub struct OrderByColumn {
    /// Column name as written in the query.
    pub column: String,
    /// Requested sort order.
    pub direction: SortDirection,
}
504
/// One item of the `SELECT` list.
#[derive(Debug, Clone)]
pub enum SelectItem {
    /// A plain column selected under its own name.
    Column(String),
    /// Any other expression (or an aliased column) with its output name.
    /// When no `AS` alias is given the parser generates `expr_N`.
    Expression { expr: SqlExpression, alias: String },
    /// `*`
    Star,
}
515
/// A parsed `SELECT` statement.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// Output names derived from `select_items`: `*`, the column name, or
    /// the expression alias.
    pub columns: Vec<String>,
    /// The select list in full detail.
    pub select_items: Vec<SelectItem>,
    /// Table named after `FROM`, if a `FROM` clause was present.
    pub from_table: Option<String>,
    pub where_clause: Option<WhereClause>,
    pub order_by: Option<Vec<OrderByColumn>>,
    pub group_by: Option<Vec<String>>,
    pub limit: Option<usize>,
    pub offset: Option<usize>,
}
527
/// Options supplied to the parser via `Parser::with_config`.
pub struct ParserConfig {
    // NOTE(review): not read anywhere in the parser code visible here;
    // presumably controls case-insensitive matching elsewhere — confirm.
    pub case_insensitive: bool,
}
531
532impl Default for ParserConfig {
533 fn default() -> Self {
534 Self {
535 case_insensitive: false,
536 }
537 }
538}
539
/// Recursive-descent parser for the SQL-like query language.
pub struct Parser {
    /// Token source.
    lexer: Lexer,
    /// The token currently being examined (one-token lookahead).
    current_token: Token,
    /// True while method-call arguments are being parsed; in that state a
    /// double-quoted token is treated as a string literal, not a column.
    in_method_args: bool,
    /// Known column names; a number literal equal to one of these is
    /// treated as a column reference.
    columns: Vec<String>,
    /// Count of currently open parentheses, used for error reporting.
    paren_depth: i32,
    #[allow(dead_code)]
    config: ParserConfig,
}
549
550impl Parser {
551 pub fn new(input: &str) -> Self {
552 let mut lexer = Lexer::new(input);
553 let current_token = lexer.next_token();
554 Self {
555 lexer,
556 current_token,
557 in_method_args: false,
558 columns: Vec::new(),
559 paren_depth: 0,
560 config: ParserConfig::default(),
561 }
562 }
563
564 pub fn with_config(input: &str, config: ParserConfig) -> Self {
565 let mut lexer = Lexer::new(input);
566 let current_token = lexer.next_token();
567 Self {
568 lexer,
569 current_token,
570 in_method_args: false,
571 columns: Vec::new(),
572 paren_depth: 0,
573 config,
574 }
575 }
576
577 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
578 self.columns = columns;
579 self
580 }
581
582 fn consume(&mut self, expected: Token) -> Result<(), String> {
583 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
584 match &expected {
586 Token::LeftParen => self.paren_depth += 1,
587 Token::RightParen => {
588 self.paren_depth -= 1;
589 if self.paren_depth < 0 {
591 return Err(
592 "Unexpected closing parenthesis - no matching opening parenthesis"
593 .to_string(),
594 );
595 }
596 }
597 _ => {}
598 }
599
600 self.current_token = self.lexer.next_token();
601 Ok(())
602 } else {
603 let error_msg = match (&expected, &self.current_token) {
605 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
606 format!(
607 "Unclosed parenthesis - missing {} closing parenthes{}",
608 self.paren_depth,
609 if self.paren_depth == 1 { "is" } else { "es" }
610 )
611 }
612 (Token::RightParen, _) if self.paren_depth > 0 => {
613 format!(
614 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
615 self.current_token,
616 self.paren_depth,
617 if self.paren_depth == 1 { "is" } else { "es" }
618 )
619 }
620 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
621 };
622 Err(error_msg)
623 }
624 }
625
626 fn advance(&mut self) {
627 match &self.current_token {
629 Token::LeftParen => self.paren_depth += 1,
630 Token::RightParen => {
631 self.paren_depth -= 1;
632 }
635 _ => {}
636 }
637 self.current_token = self.lexer.next_token();
638 }
639
/// Parses the input as a single `SELECT` statement.
///
/// # Errors
/// Returns a human-readable message describing the first problem found.
pub fn parse(&mut self) -> Result<SelectStatement, String> {
    self.parse_select_statement()
}
643
644 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
645 self.consume(Token::Select)?;
646
647 let select_items = self.parse_select_items()?;
649
650 let columns = select_items
652 .iter()
653 .map(|item| match item {
654 SelectItem::Star => "*".to_string(),
655 SelectItem::Column(name) => name.clone(),
656 SelectItem::Expression { alias, .. } => alias.clone(),
657 })
658 .collect();
659
660 let from_table = if matches!(self.current_token, Token::From) {
661 self.advance();
662 match &self.current_token {
663 Token::Identifier(table) => {
664 let table_name = table.clone();
665 self.advance();
666 Some(table_name)
667 }
668 Token::QuotedIdentifier(table) => {
669 let table_name = table.clone();
671 self.advance();
672 Some(table_name)
673 }
674 _ => return Err("Expected table name after FROM".to_string()),
675 }
676 } else {
677 None
678 };
679
680 let where_clause = if matches!(self.current_token, Token::Where) {
681 self.advance();
682 Some(self.parse_where_clause()?)
683 } else {
684 None
685 };
686
687 let order_by = if matches!(self.current_token, Token::OrderBy) {
688 self.advance();
689 Some(self.parse_order_by_list()?)
690 } else {
691 None
692 };
693
694 let group_by = if matches!(self.current_token, Token::GroupBy) {
695 self.advance();
696 Some(self.parse_identifier_list()?)
697 } else {
698 None
699 };
700
701 let limit = if matches!(self.current_token, Token::Limit) {
703 self.advance();
704 match &self.current_token {
705 Token::NumberLiteral(num) => {
706 let limit_val = num
707 .parse::<usize>()
708 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
709 self.advance();
710 Some(limit_val)
711 }
712 _ => return Err("Expected number after LIMIT".to_string()),
713 }
714 } else {
715 None
716 };
717
718 let offset = if matches!(self.current_token, Token::Offset) {
720 self.advance();
721 match &self.current_token {
722 Token::NumberLiteral(num) => {
723 let offset_val = num
724 .parse::<usize>()
725 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
726 self.advance();
727 Some(offset_val)
728 }
729 _ => return Err("Expected number after OFFSET".to_string()),
730 }
731 } else {
732 None
733 };
734
735 if self.paren_depth > 0 {
737 return Err(format!(
738 "Unclosed parenthesis - missing {} closing parenthes{}",
739 self.paren_depth,
740 if self.paren_depth == 1 { "is" } else { "es" }
741 ));
742 } else if self.paren_depth < 0 {
743 return Err(
744 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
745 );
746 }
747
748 Ok(SelectStatement {
749 columns,
750 select_items,
751 from_table,
752 where_clause,
753 order_by,
754 group_by,
755 limit,
756 offset,
757 })
758 }
759
760 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
761 let mut columns = Vec::new();
762
763 if matches!(self.current_token, Token::Star) {
764 columns.push("*".to_string());
765 self.advance();
766 } else {
767 loop {
768 match &self.current_token {
769 Token::Identifier(col) => {
770 columns.push(col.clone());
771 self.advance();
772 }
773 Token::QuotedIdentifier(col) => {
774 columns.push(col.clone());
776 self.advance();
777 }
778 _ => return Err("Expected column name".to_string()),
779 }
780
781 if matches!(self.current_token, Token::Comma) {
782 self.advance();
783 } else {
784 break;
785 }
786 }
787 }
788
789 Ok(columns)
790 }
791
792 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
794 let mut items = Vec::new();
795
796 loop {
797 if matches!(self.current_token, Token::Star) {
800 items.push(SelectItem::Star);
808 self.advance();
809 } else {
810 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
815 self.advance();
816 match &self.current_token {
817 Token::Identifier(alias_name) => {
818 let alias = alias_name.clone();
819 self.advance();
820 alias
821 }
822 Token::QuotedIdentifier(alias_name) => {
823 let alias = alias_name.clone();
824 self.advance();
825 alias
826 }
827 _ => return Err("Expected alias name after AS".to_string()),
828 }
829 } else {
830 match &expr {
832 SqlExpression::Column(col_name) => col_name.clone(),
833 _ => format!("expr_{}", items.len() + 1), }
835 };
836
837 let item = match expr {
839 SqlExpression::Column(col_name) if alias == col_name => {
840 SelectItem::Column(col_name)
842 }
843 _ => {
844 SelectItem::Expression { expr, alias }
846 }
847 };
848
849 items.push(item);
850 }
851
852 if matches!(self.current_token, Token::Comma) {
854 self.advance();
855 } else {
856 break;
857 }
858 }
859
860 Ok(items)
861 }
862
863 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
864 let mut identifiers = Vec::new();
865
866 loop {
867 match &self.current_token {
868 Token::Identifier(id) => {
869 identifiers.push(id.clone());
870 self.advance();
871 }
872 Token::QuotedIdentifier(id) => {
873 identifiers.push(id.clone());
875 self.advance();
876 }
877 _ => return Err("Expected identifier".to_string()),
878 }
879
880 if matches!(self.current_token, Token::Comma) {
881 self.advance();
882 } else {
883 break;
884 }
885 }
886
887 Ok(identifiers)
888 }
889
890 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
891 let mut order_columns = Vec::new();
892
893 loop {
894 let column = match &self.current_token {
895 Token::Identifier(id) => {
896 let col = id.clone();
897 self.advance();
898 col
899 }
900 Token::QuotedIdentifier(id) => {
901 let col = id.clone();
902 self.advance();
903 col
904 }
905 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
906 let col = num.clone();
908 self.advance();
909 col
910 }
911 _ => return Err("Expected column name in ORDER BY".to_string()),
912 };
913
914 let direction = match &self.current_token {
916 Token::Asc => {
917 self.advance();
918 SortDirection::Asc
919 }
920 Token::Desc => {
921 self.advance();
922 SortDirection::Desc
923 }
924 _ => SortDirection::Asc, };
926
927 order_columns.push(OrderByColumn { column, direction });
928
929 if matches!(self.current_token, Token::Comma) {
930 self.advance();
931 } else {
932 break;
933 }
934 }
935
936 Ok(order_columns)
937 }
938
939 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
940 let mut conditions = Vec::new();
941
942 loop {
943 let expr = self.parse_expression()?;
944
945 let connector = match &self.current_token {
946 Token::And => {
947 self.advance();
948 Some(LogicalOp::And)
949 }
950 Token::Or => {
951 self.advance();
952 Some(LogicalOp::Or)
953 }
954 Token::RightParen if self.paren_depth <= 0 => {
955 return Err(
957 "Unexpected closing parenthesis - no matching opening parenthesis"
958 .to_string(),
959 );
960 }
961 _ => None,
962 };
963
964 conditions.push(Condition {
965 expr,
966 connector: connector.clone(),
967 });
968
969 if connector.is_none() {
970 break;
971 }
972 }
973
974 Ok(WhereClause { conditions })
975 }
976
977 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
978 let mut left = self.parse_comparison()?;
979
980 if let Some(op) = self.get_binary_op() {
983 self.advance();
984 let right = self.parse_expression()?;
985 left = SqlExpression::BinaryOp {
986 left: Box::new(left),
987 op,
988 right: Box::new(right),
989 };
990 }
991
992 if matches!(self.current_token, Token::In) {
994 self.advance();
995 self.consume(Token::LeftParen)?;
996 let values = self.parse_expression_list()?;
997 self.consume(Token::RightParen)?;
998
999 left = SqlExpression::InList {
1000 expr: Box::new(left),
1001 values,
1002 };
1003 }
1004
1005 Ok(left)
1009 }
1010
1011 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1012 let mut left = self.parse_additive()?;
1013
1014 if matches!(self.current_token, Token::Between) {
1016 self.advance(); let lower = self.parse_primary()?;
1018 self.consume(Token::And)?; let upper = self.parse_primary()?;
1020
1021 return Ok(SqlExpression::Between {
1022 expr: Box::new(left),
1023 lower: Box::new(lower),
1024 upper: Box::new(upper),
1025 });
1026 }
1027
1028 if matches!(self.current_token, Token::Not) {
1030 self.advance(); if matches!(self.current_token, Token::In) {
1032 self.advance(); self.consume(Token::LeftParen)?;
1034 let values = self.parse_expression_list()?;
1035 self.consume(Token::RightParen)?;
1036
1037 return Ok(SqlExpression::NotInList {
1038 expr: Box::new(left),
1039 values,
1040 });
1041 } else {
1042 return Err("Expected IN after NOT".to_string());
1043 }
1044 }
1045
1046 if let Some(op) = self.get_binary_op() {
1048 self.advance();
1049 let right = self.parse_additive()?;
1050 left = SqlExpression::BinaryOp {
1051 left: Box::new(left),
1052 op,
1053 right: Box::new(right),
1054 };
1055 }
1056
1057 Ok(left)
1058 }
1059
1060 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1061 let mut left = self.parse_multiplicative()?;
1062
1063 while matches!(self.current_token, Token::Plus | Token::Minus) {
1064 let op = match self.current_token {
1065 Token::Plus => "+",
1066 Token::Minus => "-",
1067 _ => unreachable!(),
1068 };
1069 self.advance();
1070 let right = self.parse_multiplicative()?;
1071 left = SqlExpression::BinaryOp {
1072 left: Box::new(left),
1073 op: op.to_string(),
1074 right: Box::new(right),
1075 };
1076 }
1077
1078 Ok(left)
1079 }
1080
1081 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1082 let mut left = self.parse_primary()?;
1083
1084 while matches!(self.current_token, Token::Dot) {
1086 self.advance();
1087 if let Token::Identifier(method) = &self.current_token {
1088 let method_name = method.clone();
1089 self.advance();
1090
1091 if matches!(self.current_token, Token::LeftParen) {
1092 self.advance();
1093 let args = self.parse_method_args()?;
1094 self.consume(Token::RightParen)?;
1095
1096 match left {
1098 SqlExpression::Column(obj) => {
1099 left = SqlExpression::MethodCall {
1101 object: obj,
1102 method: method_name,
1103 args,
1104 };
1105 }
1106 SqlExpression::MethodCall { .. }
1107 | SqlExpression::ChainedMethodCall { .. } => {
1108 left = SqlExpression::ChainedMethodCall {
1110 base: Box::new(left),
1111 method: method_name,
1112 args,
1113 };
1114 }
1115 _ => {
1116 left = SqlExpression::ChainedMethodCall {
1118 base: Box::new(left),
1119 method: method_name,
1120 args,
1121 };
1122 }
1123 }
1124 } else {
1125 return Err(format!("Expected '(' after method name '{}'", method_name));
1126 }
1127 } else {
1128 return Err("Expected method name after '.'".to_string());
1129 }
1130 }
1131
1132 while matches!(self.current_token, Token::Star | Token::Divide) {
1133 let op = match self.current_token {
1134 Token::Star => "*",
1135 Token::Divide => "/",
1136 _ => unreachable!(),
1137 };
1138 self.advance();
1139 let right = self.parse_primary()?;
1140 left = SqlExpression::BinaryOp {
1141 left: Box::new(left),
1142 op: op.to_string(),
1143 right: Box::new(right),
1144 };
1145 }
1146
1147 Ok(left)
1148 }
1149
1150 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1151 let mut left = self.parse_logical_and()?;
1152
1153 while matches!(self.current_token, Token::Or) {
1154 self.advance();
1155 let right = self.parse_logical_and()?;
1156 left = SqlExpression::BinaryOp {
1160 left: Box::new(left),
1161 op: "OR".to_string(),
1162 right: Box::new(right),
1163 };
1164 }
1165
1166 Ok(left)
1167 }
1168
1169 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1170 let mut left = self.parse_expression()?;
1171
1172 while matches!(self.current_token, Token::And) {
1173 self.advance();
1174 let right = self.parse_expression()?;
1175 left = SqlExpression::BinaryOp {
1177 left: Box::new(left),
1178 op: "AND".to_string(),
1179 right: Box::new(right),
1180 };
1181 }
1182
1183 Ok(left)
1184 }
1185
1186 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1187 self.consume(Token::Case)?;
1189
1190 let mut when_branches = Vec::new();
1191
1192 while matches!(self.current_token, Token::When) {
1194 self.advance(); let condition = self.parse_expression()?;
1198
1199 self.consume(Token::Then)?;
1201
1202 let result = self.parse_expression()?;
1204
1205 when_branches.push(WhenBranch {
1206 condition: Box::new(condition),
1207 result: Box::new(result),
1208 });
1209 }
1210
1211 if when_branches.is_empty() {
1213 return Err("CASE expression must have at least one WHEN clause".to_string());
1214 }
1215
1216 let else_branch = if matches!(self.current_token, Token::Else) {
1218 self.advance(); Some(Box::new(self.parse_expression()?))
1220 } else {
1221 None
1222 };
1223
1224 self.consume(Token::End)?;
1226
1227 Ok(SqlExpression::CaseExpression {
1228 when_branches,
1229 else_branch,
1230 })
1231 }
1232
1233 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1234 if let Token::NumberLiteral(num_str) = &self.current_token {
1237 if self.columns.iter().any(|col| col == num_str) {
1239 let expr = SqlExpression::Column(num_str.clone());
1240 self.advance();
1241 return Ok(expr);
1242 }
1243 }
1244
1245 match &self.current_token {
1246 Token::Case => {
1247 self.parse_case_expression()
1249 }
1250 Token::DateTime => {
1251 self.advance(); self.consume(Token::LeftParen)?;
1253
1254 if matches!(&self.current_token, Token::RightParen) {
1256 self.advance(); return Ok(SqlExpression::DateTimeToday {
1258 hour: None,
1259 minute: None,
1260 second: None,
1261 });
1262 }
1263
1264 let year = if let Token::NumberLiteral(n) = &self.current_token {
1266 n.parse::<i32>().map_err(|_| "Invalid year")?
1267 } else {
1268 return Err("Expected year in DateTime constructor".to_string());
1269 };
1270 self.advance();
1271 self.consume(Token::Comma)?;
1272
1273 let month = if let Token::NumberLiteral(n) = &self.current_token {
1275 n.parse::<u32>().map_err(|_| "Invalid month")?
1276 } else {
1277 return Err("Expected month in DateTime constructor".to_string());
1278 };
1279 self.advance();
1280 self.consume(Token::Comma)?;
1281
1282 let day = if let Token::NumberLiteral(n) = &self.current_token {
1284 n.parse::<u32>().map_err(|_| "Invalid day")?
1285 } else {
1286 return Err("Expected day in DateTime constructor".to_string());
1287 };
1288 self.advance();
1289
1290 let mut hour = None;
1292 let mut minute = None;
1293 let mut second = None;
1294
1295 if matches!(&self.current_token, Token::Comma) {
1296 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1300 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1301 self.advance();
1302
1303 if matches!(&self.current_token, Token::Comma) {
1305 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1308 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1309 self.advance();
1310
1311 if matches!(&self.current_token, Token::Comma) {
1313 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1316 second =
1317 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1318 self.advance();
1319 }
1320 }
1321 }
1322 }
1323 }
1324 }
1325
1326 self.consume(Token::RightParen)?;
1327 Ok(SqlExpression::DateTimeConstructor {
1328 year,
1329 month,
1330 day,
1331 hour,
1332 minute,
1333 second,
1334 })
1335 }
1336 Token::Identifier(id) => {
1337 let id_upper = id.to_uppercase();
1338 let id_clone = id.clone();
1339 self.advance();
1340
1341 if matches!(self.current_token, Token::LeftParen) {
1343 self.advance(); let args = self.parse_function_args()?;
1347 self.consume(Token::RightParen)?;
1348 return Ok(SqlExpression::FunctionCall {
1349 name: id_upper,
1350 args,
1351 });
1352 }
1353
1354 Ok(SqlExpression::Column(id_clone))
1356 }
1357 Token::QuotedIdentifier(id) => {
1358 let expr = if self.in_method_args {
1361 SqlExpression::StringLiteral(id.clone())
1362 } else {
1363 SqlExpression::Column(id.clone())
1365 };
1366 self.advance();
1367 Ok(expr)
1368 }
1369 Token::StringLiteral(s) => {
1370 let expr = SqlExpression::StringLiteral(s.clone());
1371 self.advance();
1372 Ok(expr)
1373 }
1374 Token::NumberLiteral(n) => {
1375 let expr = SqlExpression::NumberLiteral(n.clone());
1376 self.advance();
1377 Ok(expr)
1378 }
1379 Token::LeftParen => {
1380 self.advance();
1381
1382 let expr = self.parse_logical_or()?;
1385
1386 self.consume(Token::RightParen)?;
1387 Ok(expr)
1388 }
1389 Token::Not => {
1390 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1394 if matches!(self.current_token, Token::In) {
1396 self.advance(); self.consume(Token::LeftParen)?;
1398 let values = self.parse_expression_list()?;
1399 self.consume(Token::RightParen)?;
1400
1401 return Ok(SqlExpression::NotInList {
1402 expr: Box::new(inner_expr),
1403 values,
1404 });
1405 } else {
1406 return Ok(SqlExpression::Not {
1408 expr: Box::new(inner_expr),
1409 });
1410 }
1411 } else {
1412 return Err("Expected expression after NOT".to_string());
1413 }
1414 }
1415 Token::Star => {
1416 self.advance();
1418 Ok(SqlExpression::StringLiteral("*".to_string()))
1419 }
1420 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1421 }
1422 }
1423
1424 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1425 let mut args = Vec::new();
1426
1427 self.in_method_args = true;
1429
1430 if !matches!(self.current_token, Token::RightParen) {
1431 loop {
1432 args.push(self.parse_expression()?);
1433
1434 if matches!(self.current_token, Token::Comma) {
1435 self.advance();
1436 } else {
1437 break;
1438 }
1439 }
1440 }
1441
1442 self.in_method_args = false;
1444
1445 Ok(args)
1446 }
1447
1448 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1449 let mut args = Vec::new();
1450
1451 if !matches!(self.current_token, Token::RightParen) {
1452 if matches!(self.current_token, Token::Distinct) {
1454 self.advance(); let expr = self.parse_additive()?;
1457 args.push(SqlExpression::FunctionCall {
1459 name: "DISTINCT".to_string(),
1460 args: vec![expr],
1461 });
1462 } else {
1463 args.push(self.parse_additive()?);
1465 }
1466
1467 while matches!(self.current_token, Token::Comma) {
1469 self.advance();
1470 args.push(self.parse_additive()?);
1471 }
1472 }
1473
1474 Ok(args)
1475 }
1476
1477 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1478 let mut expressions = Vec::new();
1479
1480 loop {
1481 expressions.push(self.parse_expression()?);
1482
1483 if matches!(self.current_token, Token::Comma) {
1484 self.advance();
1485 } else {
1486 break;
1487 }
1488 }
1489
1490 Ok(expressions)
1491 }
1492
1493 fn get_binary_op(&self) -> Option<String> {
1494 match &self.current_token {
1495 Token::Equal => Some("=".to_string()),
1496 Token::NotEqual => Some("!=".to_string()),
1497 Token::LessThan => Some("<".to_string()),
1498 Token::GreaterThan => Some(">".to_string()),
1499 Token::LessThanOrEqual => Some("<=".to_string()),
1500 Token::GreaterThanOrEqual => Some(">=".to_string()),
1501 Token::Like => Some("LIKE".to_string()),
1502 _ => None,
1503 }
1504 }
1505
1506 fn get_arithmetic_op(&self) -> Option<String> {
1507 match &self.current_token {
1508 Token::Plus => Some("+".to_string()),
1509 Token::Minus => Some("-".to_string()),
1510 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1512 _ => None,
1513 }
1514 }
1515
/// Returns the underlying lexer's current position.
///
/// Because the parser holds one token of lookahead, this is the position
/// just after the current token was read.
pub fn get_position(&self) -> usize {
    self.lexer.get_position()
}
1519}
1520
/// Describes where in a query the cursor sits, as returned by
/// `detect_cursor_context`.
#[derive(Debug, Clone)]
pub enum CursorContext {
    SelectClause,
    FromClause,
    WhereClause,
    OrderByClause,
    // NOTE(review): the payload is presumably the column name — confirm
    // against the code that constructs this variant.
    AfterColumn(String),
    AfterLogicalOp(LogicalOp),
    // NOTE(review): presumably (column, operator) — confirm.
    AfterComparisonOp(String, String),
    // NOTE(review): presumably (object, method) — confirm.
    InMethodCall(String, String),
    InExpression,
    Unknown,
}
1535
/// Returns the prefix of `s` that ends at byte offset `pos`.
///
/// If `pos` is at or past the end of `s`, the whole string is returned. If
/// `pos` falls inside a multi-byte character, the cut is moved backwards to
/// the nearest character boundary so the slice never panics.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot come up empty.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
1550
/// Returns the suffix of `s` that starts at byte offset `pos`.
///
/// If `pos` is at or past the end of `s`, an empty string is returned. If
/// `pos` falls inside a multi-byte character, the start is moved forwards to
/// the next character boundary so the slice never panics.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    // An empty range (pos >= len) or no boundary before the end both mean
    // there is nothing left to return.
    match (pos..s.len()).find(|&idx| s.is_char_boundary(idx)) {
        Some(start) => &s[start..],
        None => "",
    }
}
1565
1566pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1567 let truncated = safe_slice_to(query, cursor_pos);
1568 let mut parser = Parser::new(truncated);
1569
1570 match parser.parse() {
1572 Ok(stmt) => {
1573 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1574 #[cfg(test)]
1575 println!(
1576 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1577 ctx, partial, truncated
1578 );
1579 (ctx, partial)
1580 }
1581 Err(_) => {
1582 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1584 #[cfg(test)]
1585 println!(
1586 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1587 ctx, partial, truncated
1588 );
1589 (ctx, partial)
1590 }
1591 }
1592}
1593
1594pub fn tokenize_query(query: &str) -> Vec<String> {
1595 let mut lexer = Lexer::new(query);
1596 let tokens = lexer.tokenize_all();
1597 tokens.iter().map(|t| format!("{:?}", t)).collect()
1598}
1599
1600pub fn format_sql_pretty(query: &str) -> Vec<String> {
1601 format_sql_pretty_compact(query, 5) }
1603
1604pub fn format_ast_tree(query: &str) -> String {
1606 let mut parser = Parser::new(query);
1607 match parser.parse() {
1608 Ok(stmt) => format_select_statement(&stmt, 0),
1609 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1610 }
1611}
1612
1613fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1614 let mut result = String::new();
1615 let indent_str = " ".repeat(indent);
1616
1617 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1618
1619 result.push_str(&format!("{indent_str} columns: ["));
1621 if !stmt.columns.is_empty() {
1622 result.push('\n');
1623 for col in &stmt.columns {
1624 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1625 }
1626 result.push_str(&format!("{indent_str} ],\n"));
1627 } else {
1628 result.push_str("],\n");
1629 }
1630
1631 if let Some(table) = &stmt.from_table {
1633 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1634 }
1635
1636 if let Some(where_clause) = &stmt.where_clause {
1638 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1639 result.push_str(&format_where_clause(where_clause, indent + 2));
1640 result.push_str(&format!("{indent_str} }},\n"));
1641 }
1642
1643 if let Some(order_by) = &stmt.order_by {
1645 result.push_str(&format!("{indent_str} order_by: ["));
1646 if !order_by.is_empty() {
1647 result.push('\n');
1648 for col in order_by {
1649 let dir = match col.direction {
1650 SortDirection::Asc => "ASC",
1651 SortDirection::Desc => "DESC",
1652 };
1653 result.push_str(&format!(
1654 "{indent_str} \"{col}\" {dir},\n",
1655 col = col.column
1656 ));
1657 }
1658 result.push_str(&format!("{indent_str} ],\n"));
1659 } else {
1660 result.push_str("],\n");
1661 }
1662 }
1663
1664 if let Some(group_by) = &stmt.group_by {
1666 result.push_str(&format!("{indent_str} group_by: ["));
1667 if !group_by.is_empty() {
1668 result.push('\n');
1669 for col in group_by {
1670 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1671 }
1672 result.push_str(&format!("{indent_str} ],\n"));
1673 } else {
1674 result.push_str("]\n");
1675 }
1676 }
1677
1678 result.push_str(&format!("{indent_str}}}"));
1679 result
1680}
1681
1682fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1683 let mut result = String::new();
1684 let indent_str = " ".repeat(indent);
1685
1686 result.push_str(&format!("{indent_str}conditions: [\n"));
1687
1688 for condition in &clause.conditions {
1689 result.push_str(&format!("{indent_str} {{\n"));
1690 result.push_str(&format!(
1691 "{indent_str} expr: {},\n",
1692 format_expression_ast(&condition.expr)
1693 ));
1694
1695 if let Some(connector) = &condition.connector {
1696 let connector_str = match connector {
1697 LogicalOp::And => "AND",
1698 LogicalOp::Or => "OR",
1699 };
1700 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1701 }
1702
1703 result.push_str(&format!("{indent_str} }},\n"));
1704 }
1705
1706 result.push_str(&format!("{indent_str}]\n"));
1707 result
1708}
1709
1710fn format_expression_ast(expr: &SqlExpression) -> String {
1711 match expr {
1712 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1713 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1714 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1715 SqlExpression::DateTimeConstructor {
1716 year,
1717 month,
1718 day,
1719 hour,
1720 minute,
1721 second,
1722 } => {
1723 format!(
1724 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1725 year,
1726 month,
1727 day,
1728 hour.unwrap_or(0),
1729 minute.unwrap_or(0),
1730 second.unwrap_or(0)
1731 )
1732 }
1733 SqlExpression::DateTimeToday {
1734 hour,
1735 minute,
1736 second,
1737 } => {
1738 format!(
1739 "DateTimeToday({:02}:{:02}:{:02})",
1740 hour.unwrap_or(0),
1741 minute.unwrap_or(0),
1742 second.unwrap_or(0)
1743 )
1744 }
1745 SqlExpression::MethodCall {
1746 object,
1747 method,
1748 args,
1749 } => {
1750 let args_str = args
1751 .iter()
1752 .map(|a| format_expression_ast(a))
1753 .collect::<Vec<_>>()
1754 .join(", ");
1755 format!("MethodCall({}.{}({}))", object, method, args_str)
1756 }
1757 SqlExpression::ChainedMethodCall { base, method, args } => {
1758 let args_str = args
1759 .iter()
1760 .map(|a| format_expression_ast(a))
1761 .collect::<Vec<_>>()
1762 .join(", ");
1763 format!(
1764 "ChainedMethodCall({}.{}({}))",
1765 format_expression_ast(base),
1766 method,
1767 args_str
1768 )
1769 }
1770 SqlExpression::FunctionCall { name, args } => {
1771 let args_str = args
1772 .iter()
1773 .map(|a| format_expression_ast(a))
1774 .collect::<Vec<_>>()
1775 .join(", ");
1776 format!("FunctionCall({}({}))", name, args_str)
1777 }
1778 SqlExpression::BinaryOp { left, op, right } => {
1779 format!(
1780 "BinaryOp({} {} {})",
1781 format_expression_ast(left),
1782 op,
1783 format_expression_ast(right)
1784 )
1785 }
1786 SqlExpression::InList { expr, values } => {
1787 let list_str = values
1788 .iter()
1789 .map(|e| format_expression_ast(e))
1790 .collect::<Vec<_>>()
1791 .join(", ");
1792 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1793 }
1794 SqlExpression::NotInList { expr, values } => {
1795 let list_str = values
1796 .iter()
1797 .map(|e| format_expression_ast(e))
1798 .collect::<Vec<_>>()
1799 .join(", ");
1800 format!(
1801 "NotInList({} NOT IN [{}])",
1802 format_expression_ast(expr),
1803 list_str
1804 )
1805 }
1806 SqlExpression::Between { expr, lower, upper } => {
1807 format!(
1808 "Between({} BETWEEN {} AND {})",
1809 format_expression_ast(expr),
1810 format_expression_ast(lower),
1811 format_expression_ast(upper)
1812 )
1813 }
1814 SqlExpression::Not { expr } => {
1815 format!("Not({})", format_expression_ast(expr))
1816 }
1817 SqlExpression::CaseExpression {
1818 when_branches,
1819 else_branch,
1820 } => {
1821 let when_strs: Vec<String> = when_branches
1822 .iter()
1823 .map(|branch| {
1824 format!(
1825 "WHEN {} THEN {}",
1826 format_expression_ast(&branch.condition),
1827 format_expression_ast(&branch.result)
1828 )
1829 })
1830 .collect();
1831 let else_str = else_branch
1832 .as_ref()
1833 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
1834 .unwrap_or_default();
1835 format!("CASE {} {} END", when_strs.join(" "), else_str)
1836 }
1837 }
1838}
1839
1840pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1842 match expr {
1843 SqlExpression::DateTimeConstructor {
1844 year,
1845 month,
1846 day,
1847 hour,
1848 minute,
1849 second,
1850 } => {
1851 let h = hour.unwrap_or(0);
1852 let m = minute.unwrap_or(0);
1853 let s = second.unwrap_or(0);
1854
1855 if let Ok(dt) = NaiveDateTime::parse_from_str(
1857 &format!(
1858 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1859 year, month, day, h, m, s
1860 ),
1861 "%Y-%m-%d %H:%M:%S",
1862 ) {
1863 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1864 } else {
1865 None
1866 }
1867 }
1868 SqlExpression::DateTimeToday {
1869 hour,
1870 minute,
1871 second,
1872 } => {
1873 let now = Local::now();
1874 let h = hour.unwrap_or(0);
1875 let m = minute.unwrap_or(0);
1876 let s = second.unwrap_or(0);
1877
1878 if let Ok(dt) = NaiveDateTime::parse_from_str(
1880 &format!(
1881 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1882 now.year(),
1883 now.month(),
1884 now.day(),
1885 h,
1886 m,
1887 s
1888 ),
1889 "%Y-%m-%d %H:%M:%S",
1890 ) {
1891 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1892 } else {
1893 None
1894 }
1895 }
1896 _ => None,
1897 }
1898}
1899
1900fn format_sql_with_preserved_parens(
1902 query: &str,
1903 cols_per_line: usize,
1904) -> Result<Vec<String>, String> {
1905 let mut lines = Vec::new();
1906 let mut lexer = Lexer::new(query);
1907 let tokens_with_pos = lexer.tokenize_all_with_positions();
1908
1909 if tokens_with_pos.is_empty() {
1910 return Err("No tokens found".to_string());
1911 }
1912
1913 let mut i = 0;
1914 let cols_per_line = cols_per_line.max(1);
1915
1916 while i < tokens_with_pos.len() {
1917 let (start, _end, ref token) = tokens_with_pos[i];
1918
1919 match token {
1920 Token::Select => {
1921 lines.push("SELECT".to_string());
1922 i += 1;
1923
1924 let mut columns = Vec::new();
1926 let mut col_start = i;
1927 while i < tokens_with_pos.len() {
1928 match &tokens_with_pos[i].2 {
1929 Token::From | Token::Eof => break,
1930 Token::Comma => {
1931 if col_start < i {
1933 let col_text = extract_text_between_positions(
1934 query,
1935 tokens_with_pos[col_start].0,
1936 tokens_with_pos[i - 1].1,
1937 );
1938 columns.push(col_text);
1939 }
1940 i += 1;
1941 col_start = i;
1942 }
1943 _ => i += 1,
1944 }
1945 }
1946 if col_start < i && i > 0 {
1948 let col_text = extract_text_between_positions(
1949 query,
1950 tokens_with_pos[col_start].0,
1951 tokens_with_pos[i - 1].1,
1952 );
1953 columns.push(col_text);
1954 }
1955
1956 for chunk in columns.chunks(cols_per_line) {
1958 let mut line = " ".to_string();
1959 for (idx, col) in chunk.iter().enumerate() {
1960 if idx > 0 {
1961 line.push_str(", ");
1962 }
1963 line.push_str(col.trim());
1964 }
1965 let is_last_chunk = chunk.as_ptr() as usize
1967 + chunk.len() * std::mem::size_of::<String>()
1968 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1969 if !is_last_chunk && columns.len() > cols_per_line {
1970 line.push(',');
1971 }
1972 lines.push(line);
1973 }
1974 }
1975 Token::From => {
1976 i += 1;
1977 if i < tokens_with_pos.len() {
1978 let table_start = tokens_with_pos[i].0;
1979 while i < tokens_with_pos.len() {
1981 match &tokens_with_pos[i].2 {
1982 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1983 _ => i += 1,
1984 }
1985 }
1986 if i > 0 {
1987 let table_text = extract_text_between_positions(
1988 query,
1989 table_start,
1990 tokens_with_pos[i - 1].1,
1991 );
1992 lines.push(format!("FROM {}", table_text.trim()));
1993 }
1994 }
1995 }
1996 Token::Where => {
1997 lines.push("WHERE".to_string());
1998 i += 1;
1999
2000 let where_start = if i < tokens_with_pos.len() {
2002 tokens_with_pos[i].0
2003 } else {
2004 start
2005 };
2006
2007 let mut where_end = query.len();
2009 while i < tokens_with_pos.len() {
2010 match &tokens_with_pos[i].2 {
2011 Token::OrderBy | Token::GroupBy | Token::Eof => {
2012 if i > 0 {
2013 where_end = tokens_with_pos[i - 1].1;
2014 }
2015 break;
2016 }
2017 _ => i += 1,
2018 }
2019 }
2020
2021 let where_text = extract_text_between_positions(query, where_start, where_end);
2023
2024 let formatted_where = format_where_clause_with_parens(&where_text);
2026 for line in formatted_where {
2027 lines.push(format!(" {}", line));
2028 }
2029 }
2030 Token::OrderBy => {
2031 i += 1;
2032 let order_start = if i < tokens_with_pos.len() {
2033 tokens_with_pos[i].0
2034 } else {
2035 start
2036 };
2037
2038 while i < tokens_with_pos.len() {
2040 match &tokens_with_pos[i].2 {
2041 Token::GroupBy | Token::Eof => break,
2042 _ => i += 1,
2043 }
2044 }
2045
2046 if i > 0 {
2047 let order_text = extract_text_between_positions(
2048 query,
2049 order_start,
2050 tokens_with_pos[i - 1].1,
2051 );
2052 lines.push(format!("ORDER BY {}", order_text.trim()));
2053 }
2054 }
2055 Token::GroupBy => {
2056 i += 1;
2057 let group_start = if i < tokens_with_pos.len() {
2058 tokens_with_pos[i].0
2059 } else {
2060 start
2061 };
2062
2063 while i < tokens_with_pos.len() {
2065 match &tokens_with_pos[i].2 {
2066 Token::Having | Token::Eof => break,
2067 _ => i += 1,
2068 }
2069 }
2070
2071 if i > 0 {
2072 let group_text = extract_text_between_positions(
2073 query,
2074 group_start,
2075 tokens_with_pos[i - 1].1,
2076 );
2077 lines.push(format!("GROUP BY {}", group_text.trim()));
2078 }
2079 }
2080 _ => i += 1,
2081 }
2082 }
2083
2084 Ok(lines)
2085}
2086
/// Returns the characters of `query` in the half-open range `start..end`,
/// where both bounds are character (not byte) indices.
///
/// Bounds past the end of `query` are clamped. An empty or inverted range
/// (`start >= end`) yields an empty string instead of panicking; callers can
/// produce such ranges when a clause keyword is followed by nothing.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    if start >= end {
        return String::new();
    }
    query.chars().skip(start).take(end - start).collect()
}
2094
/// Splits the text of a `WHERE` clause into display lines at its top-level
/// logical connectors.
///
/// Each condition becomes one (trimmed) line and each top-level ` AND ` /
/// ` OR ` (matched case-insensitively, surrounded by single spaces) becomes a
/// line of its own, normalised to upper case. The following are never split:
/// * connectors nested inside parentheses,
/// * connectors inside single- or double-quoted text,
/// * the `AND` that belongs to a preceding top-level `BETWEEN`.
///
/// A stray `)` does not push the nesting depth below zero, so later
/// connectors are still recognised. If nothing at all was produced, the
/// trimmed input is returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// True when `keyword` (ASCII) occurs at `chars[at..]`, ignoring ASCII case.
    fn keyword_at(chars: &[char], at: usize, keyword: &str) -> bool {
        match chars.get(at..at + keyword.len()) {
            Some(window) => window
                .iter()
                .zip(keyword.chars())
                .all(|(have, want)| have.eq_ignore_ascii_case(&want)),
            None => false,
        }
    }

    /// Moves the pending text (if non-blank) into `lines` and resets it.
    fn flush(current: &mut String, lines: &mut Vec<String>) {
        let text = current.trim();
        if !text.is_empty() {
            lines.push(text.to_string());
        }
        current.clear();
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize;
    // The quote character that opened the literal we are currently inside.
    let mut quote: Option<char> = None;
    // Set after a top-level BETWEEN until its own AND has been seen.
    let mut between_pending = false;
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];

        // Inside quoted text everything is copied verbatim.
        if let Some(open) = quote {
            if c == open {
                quote = None;
            }
            current.push(c);
            i += 1;
            continue;
        }

        if depth == 0 {
            if keyword_at(&chars, i, " BETWEEN ") {
                between_pending = true;
            } else if keyword_at(&chars, i, " AND ") {
                if between_pending {
                    // This AND separates the BETWEEN bounds; keep it inline.
                    between_pending = false;
                } else {
                    flush(&mut current, &mut lines);
                    lines.push("AND".to_string());
                    i += 5;
                    continue;
                }
            } else if keyword_at(&chars, i, " OR ") {
                between_pending = false;
                flush(&mut current, &mut lines);
                lines.push("OR".to_string());
                i += 4;
                continue;
            }
        }

        match c {
            '\'' | '"' => quote = Some(c),
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
        current.push(c);
        i += 1;
    }

    flush(&mut current, &mut lines);

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2160
2161pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2162 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2164 return lines;
2165 }
2166
2167 let mut lines = Vec::new();
2169 let mut parser = Parser::new(query);
2170
2171 let cols_per_line = cols_per_line.max(1);
2173
2174 match parser.parse() {
2175 Ok(stmt) => {
2176 if !stmt.columns.is_empty() {
2178 lines.push("SELECT".to_string());
2179
2180 for chunk in stmt.columns.chunks(cols_per_line) {
2182 let mut line = " ".to_string();
2183 for (i, col) in chunk.iter().enumerate() {
2184 if i > 0 {
2185 line.push_str(", ");
2186 }
2187 line.push_str(col);
2188 }
2189 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2191 let current_chunk_idx =
2192 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2193 if current_chunk_idx < last_chunk_idx {
2194 line.push(',');
2195 }
2196 lines.push(line);
2197 }
2198 }
2199
2200 if let Some(table) = &stmt.from_table {
2202 lines.push(format!("FROM {}", table));
2203 }
2204
2205 if let Some(where_clause) = &stmt.where_clause {
2207 lines.push("WHERE".to_string());
2208 for (i, condition) in where_clause.conditions.iter().enumerate() {
2209 if i > 0 {
2210 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2212 if let Some(connector) = &prev_condition.connector {
2213 match connector {
2214 LogicalOp::And => lines.push(" AND".to_string()),
2215 LogicalOp::Or => lines.push(" OR".to_string()),
2216 }
2217 }
2218 }
2219 }
2220 lines.push(format!(" {}", format_expression(&condition.expr)));
2221 }
2222 }
2223
2224 if let Some(order_by) = &stmt.order_by {
2226 let order_str = order_by
2227 .iter()
2228 .map(|col| {
2229 let dir = match col.direction {
2230 SortDirection::Asc => " ASC",
2231 SortDirection::Desc => " DESC",
2232 };
2233 format!("{}{}", col.column, dir)
2234 })
2235 .collect::<Vec<_>>()
2236 .join(", ");
2237 lines.push(format!("ORDER BY {}", order_str));
2238 }
2239
2240 if let Some(group_by) = &stmt.group_by {
2242 let group_str = group_by.join(", ");
2243 lines.push(format!("GROUP BY {}", group_str));
2244 }
2245 }
2246 Err(_) => {
2247 let mut lexer = Lexer::new(query);
2249 let tokens = lexer.tokenize_all();
2250 let mut current_line = String::new();
2251 let mut indent = 0;
2252
2253 for token in tokens {
2254 match &token {
2255 Token::Select
2256 | Token::From
2257 | Token::Where
2258 | Token::OrderBy
2259 | Token::GroupBy => {
2260 if !current_line.is_empty() {
2261 lines.push(current_line.trim().to_string());
2262 current_line.clear();
2263 }
2264 lines.push(format!("{:?}", token).to_uppercase());
2265 indent = 1;
2266 }
2267 Token::And | Token::Or => {
2268 if !current_line.is_empty() {
2269 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2270 current_line.clear();
2271 }
2272 lines.push(format!(" {:?}", token).to_uppercase());
2273 }
2274 Token::Comma => {
2275 current_line.push(',');
2276 if indent > 0 {
2277 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2278 current_line.clear();
2279 }
2280 }
2281 Token::Eof => break,
2282 _ => {
2283 if !current_line.is_empty() {
2284 current_line.push(' ');
2285 }
2286 current_line.push_str(&format_token(&token));
2287 }
2288 }
2289 }
2290
2291 if !current_line.is_empty() {
2292 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2293 }
2294 }
2295 }
2296
2297 lines
2298}
2299
2300fn format_expression(expr: &SqlExpression) -> String {
2301 match expr {
2302 SqlExpression::Column(name) => name.clone(),
2303 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2304 SqlExpression::NumberLiteral(n) => n.clone(),
2305 SqlExpression::DateTimeConstructor {
2306 year,
2307 month,
2308 day,
2309 hour,
2310 minute,
2311 second,
2312 } => {
2313 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2314 if let Some(h) = hour {
2315 result.push_str(&format!(", {}", h));
2316 if let Some(m) = minute {
2317 result.push_str(&format!(", {}", m));
2318 if let Some(s) = second {
2319 result.push_str(&format!(", {}", s));
2320 }
2321 }
2322 }
2323 result.push(')');
2324 result
2325 }
2326 SqlExpression::DateTimeToday {
2327 hour,
2328 minute,
2329 second,
2330 } => {
2331 let mut result = "DateTime()".to_string();
2332 if let Some(h) = hour {
2333 result = format!("DateTime(TODAY, {}", h);
2334 if let Some(m) = minute {
2335 result.push_str(&format!(", {}", m));
2336 if let Some(s) = second {
2337 result.push_str(&format!(", {}", s));
2338 }
2339 }
2340 result.push(')');
2341 }
2342 result
2343 }
2344 SqlExpression::MethodCall {
2345 object,
2346 method,
2347 args,
2348 } => {
2349 let args_str = args
2350 .iter()
2351 .map(|arg| format_expression(arg))
2352 .collect::<Vec<_>>()
2353 .join(", ");
2354 format!("{}.{}({})", object, method, args_str)
2355 }
2356 SqlExpression::BinaryOp { left, op, right } => {
2357 if op == "OR" || op == "AND" {
2360 format!(
2363 "({} {} {})",
2364 format_expression(left),
2365 op,
2366 format_expression(right)
2367 )
2368 } else {
2369 format!(
2370 "{} {} {}",
2371 format_expression(left),
2372 op,
2373 format_expression(right)
2374 )
2375 }
2376 }
2377 SqlExpression::InList { expr, values } => {
2378 let values_str = values
2379 .iter()
2380 .map(|v| format_expression(v))
2381 .collect::<Vec<_>>()
2382 .join(", ");
2383 format!("{} IN ({})", format_expression(expr), values_str)
2384 }
2385 SqlExpression::NotInList { expr, values } => {
2386 let values_str = values
2387 .iter()
2388 .map(|v| format_expression(v))
2389 .collect::<Vec<_>>()
2390 .join(", ");
2391 format!("{} NOT IN ({})", format_expression(expr), values_str)
2392 }
2393 SqlExpression::Between { expr, lower, upper } => {
2394 format!(
2395 "{} BETWEEN {} AND {}",
2396 format_expression(expr),
2397 format_expression(lower),
2398 format_expression(upper)
2399 )
2400 }
2401 SqlExpression::Not { expr } => {
2402 format!("NOT {}", format_expression(expr))
2403 }
2404 SqlExpression::ChainedMethodCall { base, method, args } => {
2405 let args_str = args
2406 .iter()
2407 .map(|arg| format_expression(arg))
2408 .collect::<Vec<_>>()
2409 .join(", ");
2410 format!("{}.{}({})", format_expression(base), method, args_str)
2411 }
2412 SqlExpression::FunctionCall { name, args } => {
2413 let args_str = args
2414 .iter()
2415 .map(|arg| format_expression(arg))
2416 .collect::<Vec<_>>()
2417 .join(", ");
2418 format!("{}({})", name, args_str)
2419 }
2420 SqlExpression::CaseExpression {
2421 when_branches,
2422 else_branch,
2423 } => {
2424 let mut result = String::from("CASE");
2425 for branch in when_branches {
2426 result.push_str(&format!(
2427 " WHEN {} THEN {}",
2428 format_expression(&branch.condition),
2429 format_expression(&branch.result)
2430 ));
2431 }
2432 if let Some(else_expr) = else_branch {
2433 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2434 }
2435 result.push_str(" END");
2436 result
2437 }
2438 }
2439}
2440
2441fn format_token(token: &Token) -> String {
2442 match token {
2443 Token::Identifier(s) => s.clone(),
2444 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2445 Token::StringLiteral(s) => format!("'{}'", s),
2446 Token::NumberLiteral(n) => n.clone(),
2447 Token::DateTime => "DateTime".to_string(),
2448 Token::Case => "CASE".to_string(),
2449 Token::When => "WHEN".to_string(),
2450 Token::Then => "THEN".to_string(),
2451 Token::Else => "ELSE".to_string(),
2452 Token::End => "END".to_string(),
2453 Token::Distinct => "DISTINCT".to_string(),
2454 Token::LeftParen => "(".to_string(),
2455 Token::RightParen => ")".to_string(),
2456 Token::Comma => ",".to_string(),
2457 Token::Dot => ".".to_string(),
2458 Token::Equal => "=".to_string(),
2459 Token::NotEqual => "!=".to_string(),
2460 Token::LessThan => "<".to_string(),
2461 Token::GreaterThan => ">".to_string(),
2462 Token::LessThanOrEqual => "<=".to_string(),
2463 Token::GreaterThanOrEqual => ">=".to_string(),
2464 Token::In => "IN".to_string(),
2465 _ => format!("{:?}", token).to_uppercase(),
2466 }
2467}
2468
2469fn analyze_statement(
2470 stmt: &SelectStatement,
2471 query: &str,
2472 _cursor_pos: usize,
2473) -> (CursorContext, Option<String>) {
2474 let trimmed = query.trim();
2476
2477 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2479 for op in &comparison_ops {
2480 if let Some(op_pos) = query.rfind(op) {
2481 let before_op = safe_slice_to(query, op_pos);
2482 let after_op_start = op_pos + op.len();
2483 let after_op = if after_op_start < query.len() {
2484 &query[after_op_start..]
2485 } else {
2486 ""
2487 };
2488
2489 if let Some(col_name) = before_op.split_whitespace().last() {
2491 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2492 let after_op_trimmed = after_op.trim();
2494 if after_op_trimmed.is_empty()
2495 || (after_op_trimmed
2496 .chars()
2497 .all(|c| c.is_alphanumeric() || c == '_')
2498 && !after_op_trimmed.contains('('))
2499 {
2500 let partial = if after_op_trimmed.is_empty() {
2501 None
2502 } else {
2503 Some(after_op_trimmed.to_string())
2504 };
2505 return (
2506 CursorContext::AfterComparisonOp(
2507 col_name.to_string(),
2508 op.trim().to_string(),
2509 ),
2510 partial,
2511 );
2512 }
2513 }
2514 }
2515 }
2516 }
2517
2518 if trimmed.to_uppercase().ends_with(" AND")
2520 || trimmed.to_uppercase().ends_with(" OR")
2521 || trimmed.to_uppercase().ends_with(" AND ")
2522 || trimmed.to_uppercase().ends_with(" OR ")
2523 {
2524 } else {
2526 if let Some(dot_pos) = trimmed.rfind('.') {
2528 let before_dot = safe_slice_to(trimmed, dot_pos);
2530 let after_dot_start = dot_pos + 1;
2531 let after_dot = if after_dot_start < trimmed.len() {
2532 &trimmed[after_dot_start..]
2533 } else {
2534 ""
2535 };
2536
2537 if !after_dot.contains('(') {
2540 let col_name = if before_dot.ends_with('"') {
2542 let bytes = before_dot.as_bytes();
2544 let mut pos = before_dot.len() - 1; let mut found_start = None;
2546
2547 if pos > 0 {
2549 pos -= 1;
2550 while pos > 0 {
2551 if bytes[pos] == b'"' {
2552 if pos == 0 || bytes[pos - 1] != b'\\' {
2554 found_start = Some(pos);
2555 break;
2556 }
2557 }
2558 pos -= 1;
2559 }
2560 if found_start.is_none() && bytes[0] == b'"' {
2562 found_start = Some(0);
2563 }
2564 }
2565
2566 if let Some(start) = found_start {
2567 Some(safe_slice_from(before_dot, start))
2569 } else {
2570 None
2571 }
2572 } else {
2573 before_dot
2576 .split_whitespace()
2577 .last()
2578 .map(|word| word.trim_start_matches('('))
2579 };
2580
2581 if let Some(col_name) = col_name {
2582 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2584 true
2586 } else {
2587 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2589 };
2590
2591 if is_valid {
2592 let partial_method = if after_dot.is_empty() {
2595 None
2596 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2597 Some(after_dot.to_string())
2598 } else {
2599 None
2600 };
2601
2602 let col_name_for_context = if col_name.starts_with('"')
2604 && col_name.ends_with('"')
2605 && col_name.len() > 2
2606 {
2607 col_name[1..col_name.len() - 1].to_string()
2608 } else {
2609 col_name.to_string()
2610 };
2611
2612 return (
2613 CursorContext::AfterColumn(col_name_for_context),
2614 partial_method,
2615 );
2616 }
2617 }
2618 }
2619 }
2620 }
2621
2622 if let Some(where_clause) = &stmt.where_clause {
2624 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2626 let op = if trimmed.to_uppercase().ends_with(" AND") {
2627 LogicalOp::And
2628 } else {
2629 LogicalOp::Or
2630 };
2631 return (CursorContext::AfterLogicalOp(op), None);
2632 }
2633
2634 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2636 let after_and = safe_slice_from(query, and_pos + 5);
2637 let partial = extract_partial_at_end(after_and);
2638 if partial.is_some() {
2639 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2640 }
2641 }
2642
2643 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2644 let after_or = safe_slice_from(query, or_pos + 4);
2645 let partial = extract_partial_at_end(after_or);
2646 if partial.is_some() {
2647 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2648 }
2649 }
2650
2651 if let Some(last_condition) = where_clause.conditions.last() {
2652 if let Some(connector) = &last_condition.connector {
2653 return (
2655 CursorContext::AfterLogicalOp(connector.clone()),
2656 extract_partial_at_end(query),
2657 );
2658 }
2659 }
2660 return (CursorContext::WhereClause, extract_partial_at_end(query));
2662 }
2663
2664 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2666 return (CursorContext::OrderByClause, None);
2667 }
2668
2669 if stmt.order_by.is_some() {
2671 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2672 }
2673
2674 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2675 return (CursorContext::FromClause, extract_partial_at_end(query));
2676 }
2677
2678 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2679 return (CursorContext::SelectClause, extract_partial_at_end(query));
2680 }
2681
2682 (CursorContext::Unknown, None)
2683}
2684
2685fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2686 let upper = query.to_uppercase();
2687
2688 let trimmed = query.trim();
2690
2691 #[cfg(test)]
2692 {
2693 if trimmed.contains("\"Last Name\"") {
2694 eprintln!(
2695 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2696 query, trimmed
2697 );
2698 }
2699 }
2700
2701 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2703 for op in &comparison_ops {
2704 if let Some(op_pos) = query.rfind(op) {
2705 let before_op = safe_slice_to(query, op_pos);
2706 let after_op_start = op_pos + op.len();
2707 let after_op = if after_op_start < query.len() {
2708 &query[after_op_start..]
2709 } else {
2710 ""
2711 };
2712
2713 if let Some(col_name) = before_op.split_whitespace().last() {
2715 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2716 let after_op_trimmed = after_op.trim();
2718 if after_op_trimmed.is_empty()
2719 || (after_op_trimmed
2720 .chars()
2721 .all(|c| c.is_alphanumeric() || c == '_')
2722 && !after_op_trimmed.contains('('))
2723 {
2724 let partial = if after_op_trimmed.is_empty() {
2725 None
2726 } else {
2727 Some(after_op_trimmed.to_string())
2728 };
2729 return (
2730 CursorContext::AfterComparisonOp(
2731 col_name.to_string(),
2732 op.trim().to_string(),
2733 ),
2734 partial,
2735 );
2736 }
2737 }
2738 }
2739 }
2740 }
2741
2742 if let Some(dot_pos) = trimmed.rfind('.') {
2745 #[cfg(test)]
2746 {
2747 if trimmed.contains("\"Last Name\"") {
2748 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2749 }
2750 }
2751 let before_dot = &trimmed[..dot_pos];
2753 let after_dot = &trimmed[dot_pos + 1..];
2754
2755 if !after_dot.contains('(') {
2758 let col_name = if before_dot.ends_with('"') {
2761 let bytes = before_dot.as_bytes();
2763 let mut pos = before_dot.len() - 1; let mut found_start = None;
2765
2766 #[cfg(test)]
2767 {
2768 if trimmed.contains("\"Last Name\"") {
2769 eprintln!(
2770 "DEBUG: before_dot='{}', looking for opening quote",
2771 before_dot
2772 );
2773 }
2774 }
2775
2776 if pos > 0 {
2778 pos -= 1;
2779 while pos > 0 {
2780 if bytes[pos] == b'"' {
2781 if pos == 0 || bytes[pos - 1] != b'\\' {
2783 found_start = Some(pos);
2784 break;
2785 }
2786 }
2787 pos -= 1;
2788 }
2789 if found_start.is_none() && bytes[0] == b'"' {
2791 found_start = Some(0);
2792 }
2793 }
2794
2795 if let Some(start) = found_start {
2796 let result = safe_slice_from(before_dot, start);
2798 #[cfg(test)]
2799 {
2800 if trimmed.contains("\"Last Name\"") {
2801 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2802 }
2803 }
2804 Some(result)
2805 } else {
2806 #[cfg(test)]
2807 {
2808 if trimmed.contains("\"Last Name\"") {
2809 eprintln!("DEBUG: No opening quote found!");
2810 }
2811 }
2812 None
2813 }
2814 } else {
2815 before_dot
2818 .split_whitespace()
2819 .last()
2820 .map(|word| word.trim_start_matches('('))
2821 };
2822
2823 if let Some(col_name) = col_name {
2824 #[cfg(test)]
2825 {
2826 if trimmed.contains("\"Last Name\"") {
2827 eprintln!("DEBUG: col_name = '{}'", col_name);
2828 }
2829 }
2830
2831 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2833 true
2835 } else {
2836 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2838 };
2839
2840 #[cfg(test)]
2841 {
2842 if trimmed.contains("\"Last Name\"") {
2843 eprintln!("DEBUG: is_valid = {}", is_valid);
2844 }
2845 }
2846
2847 if is_valid {
2848 let partial_method = if after_dot.is_empty() {
2851 None
2852 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2853 Some(after_dot.to_string())
2854 } else {
2855 None
2856 };
2857
2858 let col_name_for_context = if col_name.starts_with('"')
2860 && col_name.ends_with('"')
2861 && col_name.len() > 2
2862 {
2863 col_name[1..col_name.len() - 1].to_string()
2864 } else {
2865 col_name.to_string()
2866 };
2867
2868 return (
2869 CursorContext::AfterColumn(col_name_for_context),
2870 partial_method,
2871 );
2872 }
2873 }
2874 }
2875 }
2876
2877 if let Some(and_pos) = upper.rfind(" AND ") {
2879 if cursor_pos >= and_pos + 5 {
2881 let after_and = safe_slice_from(query, and_pos + 5);
2883 let partial = extract_partial_at_end(after_and);
2884 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2885 }
2886 }
2887
2888 if let Some(or_pos) = upper.rfind(" OR ") {
2889 if cursor_pos >= or_pos + 4 {
2891 let after_or = safe_slice_from(query, or_pos + 4);
2893 let partial = extract_partial_at_end(after_or);
2894 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2895 }
2896 }
2897
2898 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2900 let op = if trimmed.to_uppercase().ends_with(" AND") {
2901 LogicalOp::And
2902 } else {
2903 LogicalOp::Or
2904 };
2905 return (CursorContext::AfterLogicalOp(op), None);
2906 }
2907
2908 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2910 {
2911 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2912 }
2913
2914 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2915 return (CursorContext::WhereClause, extract_partial_at_end(query));
2916 }
2917
2918 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2919 return (CursorContext::FromClause, extract_partial_at_end(query));
2920 }
2921
2922 if upper.contains("SELECT") && !upper.contains("FROM") {
2923 return (CursorContext::SelectClause, extract_partial_at_end(query));
2924 }
2925
2926 (CursorContext::Unknown, None)
2927}
2928
2929fn extract_partial_at_end(query: &str) -> Option<String> {
2930 let trimmed = query.trim();
2931
2932 if let Some(last_word) = trimmed.split_whitespace().last() {
2934 if last_word.starts_with('"') && !last_word.ends_with('"') {
2935 return Some(last_word.to_string());
2937 }
2938 }
2939
2940 let last_word = trimmed.split_whitespace().last()?;
2942
2943 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2945 Some(last_word.to_string())
2946 } else {
2947 None
2948 }
2949}
2950
/// Reports whether `word` is one of the SQL keywords that must never be offered
/// back as a partially typed identifier.
///
/// The comparison is case-insensitive: `word` is upper-cased and compared against
/// a fixed table of keywords. Only an exact match of the whole word counts.
fn is_sql_keyword(word: &str) -> bool {
    // Upper-case spellings of every keyword recognised by this check.
    const KEYWORDS: [&str; 13] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let normalized = word.to_uppercase();
    KEYWORDS.iter().any(|&keyword| keyword == normalized.as_str())
}
2969
// Unit tests for the lexer, the parser, the cursor-context detection used for
// completion, and the SQL pretty-printer. Everything under test is pulled in
// through `use super::*`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Method-call chains -------------------------------------------------

    // Two queries whose WHERE clause chains several method calls on a column;
    // the test only requires that parsing returns `Ok`.
    #[test]
    fn test_chained_method_calls() {
        let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
        let mut parser = Parser::new(query);
        let result = parser.parse();

        assert!(
            result.is_ok(),
            "Failed to parse chained method calls: {:?}",
            result
        );

        let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
        let mut parser2 = Parser::new(query2);
        let result2 = parser2.parse();

        assert!(
            result2.is_ok(),
            "Failed to parse multiple chained calls: {:?}",
            result2
        );
    }

    // ---- Lexer --------------------------------------------------------------

    // Checks the exact token sequence for a basic SELECT ... WHERE query.
    // The assertions consume tokens in order, so their sequence matters.
    #[test]
    fn test_tokenizer() {
        let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::Star));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
    }

    // `DateTime` is expected to come out as its own token, followed by the
    // parenthesised, comma-separated numeric arguments.
    #[test]
    fn test_tokenizer_datetime() {
        let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::DateTime));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    // ---- Parser basics ------------------------------------------------------

    // `SELECT *` without a WHERE clause: one "*" column, a table, no filter.
    #[test]
    fn test_parse_simple_select() {
        let mut parser = Parser::new("SELECT * FROM trade_deal");
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["*"]);
        assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
        assert!(stmt.where_clause.is_none());
    }

    // A single method call used as the whole WHERE condition.
    #[test]
    fn test_parse_where_with_method() {
        let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);
    }

    // `DateTime(y, m, d)` on the right-hand side of `>` must become a
    // `DateTimeConstructor` with the time-of-day fields left as `None`.
    #[test]
    fn test_parse_datetime_constructor() {
        let mut parser =
            Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::DateTimeConstructor {
                    year: 2025,
                    month: 10,
                    day: 20,
                    hour: None,
                    minute: None,
                    second: None
                }
            ));
        } else {
            panic!("Expected BinaryOp with DateTime constructor");
        }
    }

    // ---- Cursor context (completion) ---------------------------------------

    // Cursor right after a trailing "AND ": context is AfterLogicalOp(And)
    // and nothing has been typed yet.
    #[test]
    fn test_cursor_context_after_and() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, None);
    }

    // Same as above, but one character ("p") has been typed after the AND.
    #[test]
    fn test_cursor_context_with_partial() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, Some("p".to_string()));
    }

    // Cursor directly after a comparison operator: the context carries both
    // the column name and the operator text.
    #[test]
    fn test_cursor_context_after_datetime_comparison() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, None);
    }

    // As above with a partially typed right-hand side ("Date").
    #[test]
    fn test_cursor_context_partial_datetime() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, Some("Date".to_string()));
    }

    // ---- Quoted identifiers -------------------------------------------------

    // Double-quoted names in a SELECT list are expected as `QuotedIdentifier`
    // tokens whose payload excludes the surrounding quotes.
    #[test]
    fn test_tokenizer_quoted_identifier() {
        let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
    }

    // Double quotes produce a quoted identifier while single quotes produce a
    // string literal, both inside the same WHERE clause.
    #[test]
    fn test_tokenizer_quoted_vs_string_literal() {
        let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Equal));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
        assert!(matches!(lexer.next_token(), Token::And));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    // Tokenizes a method call whose argument is written with double quotes.
    // NOTE(review): the token produced for "Alb" is only printed, never
    // asserted, so this test does not pin whether the lexer yields a
    // StringLiteral or a QuotedIdentifier here — confirm the intended variant.
    #[test]
    fn test_tokenizer_method_with_double_quotes_should_be_string() {
        let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);

        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));

        let token = lexer.next_token();
        println!("Token for \"Alb\": {:?}", token);
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    // Quoted and bare column names both end up in `columns`, with the quotes
    // removed from the quoted one.
    #[test]
    fn test_parse_select_with_quoted_columns() {
        let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
        assert_eq!(stmt.from_table, Some("customers".to_string()));
    }

    // Partially typed quoted column in the SELECT list. The cursor position
    // passed in is one byte before the end of the query (`query.len() - 1`).
    // Only the context is asserted; the partial is printed for inspection.
    #[test]
    fn test_cursor_context_select_with_partial_quoted() {
        let query = r#"SELECT "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);
        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    // Unterminated quoted column after a comma, with a trailing space inside
    // the quotes. Only the context is asserted; the partial is printed.
    #[test]
    fn test_cursor_context_select_after_comma_with_quoted() {
        let query = r#"SELECT Company, "Customer "#;
        let (context, partial) = detect_cursor_context(query, query.len());

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    // Partially typed quoted column after ORDER BY; cursor one byte before the
    // end of the query. Only the context is asserted; the partial is printed.
    #[test]
    fn test_cursor_context_order_by_quoted() {
        let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::OrderByClause));
    }

    // A quoted column on the left of `=` must parse to a plain `Column` whose
    // name has no quotes, compared against a single-quoted string literal.
    #[test]
    fn test_where_clause_with_quoted_column() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "=");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
            assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    // At the parser level a double-quoted method argument must come out as a
    // `StringLiteral`, not as a column reference.
    #[test]
    fn test_parse_method_with_double_quotes_as_string() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::MethodCall {
            object,
            method,
            args,
        } = &where_clause.conditions[0].expr
        {
            assert_eq!(object, "Country");
            assert_eq!(method, "Contains");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
        } else {
            panic!("Expected MethodCall");
        }
    }

    // A closed quoted column earlier in the query must not leak into the
    // partial extracted at the end of the ORDER BY clause.
    #[test]
    fn test_extract_partial_with_quoted_columns_in_query() {
        let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(context, CursorContext::OrderByClause));
        assert_eq!(
            partial,
            Some("coun".to_string()),
            "Should extract 'coun' as partial, not everything after the quoted column"
        );
    }

    // Direct checks of `extract_partial_at_end`: an unterminated quoted
    // identifier is returned with its opening quote, while a trailing keyword
    // (`FROM`) yields no partial.
    #[test]
    fn test_extract_partial_quoted_identifier_being_typed() {
        let query = r#"SELECT "Cust"#;
        let partial = extract_partial_at_end(query);
        assert_eq!(partial, Some("\"Cust".to_string()));

        let query2 = r#"SELECT "Customer Id" FROM"#;
        let partial2 = extract_partial_at_end(query2);
        assert_eq!(partial2, None);
    }

    // ---- Complex WHERE clauses with parentheses ------------------------------

    // A single parenthesised OR group is expected to be one condition whose
    // expression is a BinaryOp with operator "OR".
    #[test]
    fn test_complex_where_parentheses_basic() {
        let mut parser =
            Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected BinaryOp with OR");
        }
    }

    // `(a OR b) AND c`: two conditions, the first being the OR group and
    // carrying the AND connector, the second the `>` comparison.
    #[test]
    fn test_complex_where_mixed_and_or_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected first condition to be OR expression");
        }

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
            assert_eq!(op, ">");
        } else {
            panic!("Expected second condition to be price > 100");
        }
    }

    // Doubly nested group followed by OR; only requires at least one parsed
    // condition, not a specific tree shape.
    #[test]
    fn test_complex_where_nested_parentheses() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(where_clause.conditions.len() > 0);
    }

    // Two parenthesised groups joined by AND: two conditions, AND connector
    // stored on the first.
    #[test]
    fn test_complex_where_multiple_or_groups() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    // Method calls inside a parenthesised OR group: both operands of the OR
    // must be `MethodCall` expressions.
    #[test]
    fn test_complex_where_with_methods_in_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
            assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
            assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
        } else {
            panic!("Expected OR of method calls");
        }
    }

    // A DateTime range group AND-ed with a status group: two conditions with
    // an AND connector on the first.
    #[test]
    fn test_complex_where_date_comparisons_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    // Nested numeric range groups; only requires at least one condition.
    #[test]
    fn test_complex_where_price_volume_filters() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(where_clause.conditions.len() > 0);
    }

    // String comparisons, numeric comparisons and a method call mixed across
    // two groups; only requires that a WHERE clause is produced.
    #[test]
    fn test_complex_where_mixed_string_numeric() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    // Three levels of grouping; only requires that a WHERE clause is produced.
    #[test]
    fn test_complex_where_triple_nested() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    // One pair of parentheses wrapping a pure AND chain; only requires at
    // least one condition.
    #[test]
    fn test_complex_where_single_parens_around_and() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(where_clause.conditions.len() > 0);
    }

    // ---- Pretty-printer: parentheses preservation ----------------------------
    //
    // The tests below format a query with `format_sql_pretty_compact(query, 5)`,
    // join the returned lines with spaces, and compare the number of '(' / ')'
    // characters against the input. The meaning of the second argument (5) is
    // not visible from this module.

    // Simple OR group: opening and closing fragments survive formatting and
    // the paren count is unchanged.
    #[test]
    fn test_format_preserves_simple_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(status"));
        assert!(formatted_text.contains("\"pending\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(
            original_parens, formatted_parens,
            "Parentheses should be preserved"
        );
    }

    // OR group followed by an AND condition.
    #[test]
    fn test_format_preserves_complex_parentheses() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("\"GOOGL\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
    }

    // Nested groups: the count must match and equal 4 (two pairs).
    #[test]
    fn test_format_preserves_nested_parentheses() {
        let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(
            original_parens, formatted_parens,
            "Nested parentheses should be preserved"
        );
        assert_eq!(original_parens, 4, "Should have 4 parentheses total");
    }

    // Method-call parentheses inside a group must be kept along with the
    // grouping parentheses, and the double-quoted arguments stay as written.
    #[test]
    fn test_format_preserves_method_calls_in_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol.StartsWith"));
        assert!(formatted_text.contains("StartsWith(\"A\")"));
        assert!(formatted_text.contains("StartsWith(\"G\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(
            original_parens, 6,
            "Should have 6 parentheses (1 group + 2 method calls)"
        );
    }

    // Two independent groups joined by AND.
    #[test]
    fn test_format_preserves_multiple_groups() {
        let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("(price"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
    }

    // DateTime constructors inside a group must be rendered with the exact
    // "DateTime(y, m, d)" spelling, including the ", " separators.
    #[test]
    fn test_format_preserves_date_ranges() {
        let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(executionDate"));
        assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
        assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
    }

    // Checks the line layout of the formatter output: "SELECT", the column
    // list, a line starting with "FROM", then "WHERE" on its own line; the
    // remaining lines must contain the group and a stand-alone "AND" line.
    #[test]
    fn test_format_multiline_layout() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);

        assert!(formatted.len() >= 4, "Should have multiple lines");
        assert_eq!(formatted[0], "SELECT");
        assert!(formatted[1].trim().starts_with("*"));
        assert!(formatted[2].starts_with("FROM"));
        assert_eq!(formatted[3], "WHERE");

        let where_lines: Vec<_> = formatted.iter().skip(4).collect();
        assert!(where_lines.iter().any(|l| l.contains("(symbol")));
        assert!(where_lines.iter().any(|l| l.trim() == "AND"));
    }

    // ---- BETWEEN --------------------------------------------------------------

    // `BETWEEN a AND b` is a single condition (its AND is not a connector),
    // and the AST dump of the same query reports no parse error.
    #[test]
    fn test_between_simple() {
        let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        let stmt = parser.parse().expect("Should parse simple BETWEEN");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        assert!(!ast.contains("PARSE ERROR"));
        assert!(ast.contains("SelectStatement"));
    }

    // BETWEEN wrapped in parentheses parses and dumps without a parse error.
    #[test]
    fn test_between_in_parentheses() {
        let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");

        assert!(stmt.where_clause.is_some());

        let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
    }

    // A BETWEEN group OR-ed with another parenthesised comparison.
    #[test]
    fn test_between_with_or() {
        let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with OR");

        assert!(stmt.where_clause.is_some());
    }

    // An equality AND-ed with an unparenthesised BETWEEN: exactly two
    // conditions, i.e. the AND inside BETWEEN does not split a third one off.
    #[test]
    fn test_between_with_and() {
        let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with AND");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);
    }

    // Two parenthesised BETWEEN clauses joined by AND.
    #[test]
    fn test_multiple_between() {
        let query =
            "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse multiple BETWEEN clauses");

        assert!(stmt.where_clause.is_some());
    }

    // BETWEEN, a method call and a two-column ORDER BY in one query; the
    // ORDER BY entries must keep their column names and sort directions.
    #[test]
    fn test_between_complex_query() {
        let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

        assert!(stmt.where_clause.is_some());
        assert!(stmt.order_by.is_some());

        let order_by = stmt.order_by.unwrap();
        assert_eq!(order_by.len(), 2);
        assert_eq!(order_by[0].column, "Category");
        assert!(matches!(order_by[0].direction, SortDirection::Asc));
        assert_eq!(order_by[1].column, "price");
        assert!(matches!(order_by[1].direction, SortDirection::Desc));
    }

    // Builds a `Between` expression by hand and checks both renderings:
    // the SQL text form and the AST form (which must mention the node name
    // and both bounds).
    #[test]
    fn test_between_formatting() {
        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column("price".to_string())),
            lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
        };

        let formatted = format_expression(&expr);
        assert_eq!(formatted, "price BETWEEN 50 AND 100");

        let ast_formatted = format_expression_ast(&expr);
        assert!(ast_formatted.contains("Between"));
        assert!(ast_formatted.contains("50"));
        assert!(ast_formatted.contains("100"));
    }

    // ---- UTF-8 safety ---------------------------------------------------------

    // `detect_cursor_context` must not panic for any byte offset into a query
    // containing a non-ASCII character, including offsets past the end.
    #[test]
    fn test_utf8_boundary_safety() {
        let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

        // Every byte offset from 0 up to and including the string length.
        for pos in 0..query_with_unicode.len() + 1 {
            let result =
                std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));

            assert!(
                result.is_ok(),
                "Panic at position {} in query with Unicode",
                pos
            );
        }

        // An offset far beyond the end of the string.
        let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
        assert!(result.is_ok(), "Panic with position beyond string length");

        // One byte past the start of 'é', i.e. an offset that is not on a
        // character boundary when 'é' is stored as a multi-byte sequence.
        let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
        let result =
            std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
        assert!(
            result.is_ok(),
            "Panic with cursor in middle of UTF-8 character"
        );
    }
}