1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer::next_token`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Keywords (matched case-insensitively by the lexer) ---
    Select,
    From,
    Where,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// The word `ORDER` immediately followed by the word `BY`.
    OrderBy,
    /// The word `GROUP` immediately followed by the word `BY`.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,

    // --- Names and literals ---
    /// A bare word that is not a keyword; original casing is kept.
    Identifier(String),
    /// Text that was enclosed in double quotes (quotes removed).
    QuotedIdentifier(String),
    /// Text that was enclosed in single quotes (quotes removed).
    StringLiteral(String),
    /// The raw text of a numeric literal; it is not parsed by the lexer.
    NumberLiteral(String),
    /// `*` — used both for "all columns" and for multiplication.
    Star,

    // --- Punctuation and comparison operators ---
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// Either `!=` or `<>`.
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // --- Arithmetic operators ---
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}
60
/// Character-level scanner that turns query text into [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input, split into `char`s so it can be indexed by position.
    input: Vec<char>,
    /// Index into `input` of the character currently under the cursor.
    position: usize,
    /// The character at `position`, or `None` once the cursor is past the end.
    current_char: Option<char>,
}
67
68impl Lexer {
69 pub fn new(input: &str) -> Self {
70 let chars: Vec<char> = input.chars().collect();
71 let current = chars.get(0).copied();
72 Self {
73 input: chars,
74 position: 0,
75 current_char: current,
76 }
77 }
78
79 fn advance(&mut self) {
80 self.position += 1;
81 self.current_char = self.input.get(self.position).copied();
82 }
83
84 fn peek(&self, offset: usize) -> Option<char> {
85 self.input.get(self.position + offset).copied()
86 }
87
88 fn skip_whitespace(&mut self) {
89 while let Some(ch) = self.current_char {
90 if ch.is_whitespace() {
91 self.advance();
92 } else {
93 break;
94 }
95 }
96 }
97
98 fn skip_whitespace_and_comments(&mut self) {
99 loop {
100 while let Some(ch) = self.current_char {
102 if ch.is_whitespace() {
103 self.advance();
104 } else {
105 break;
106 }
107 }
108
109 match self.current_char {
111 Some('-') if self.peek(1) == Some('-') => {
112 self.advance(); self.advance(); while let Some(ch) = self.current_char {
116 self.advance();
117 if ch == '\n' {
118 break;
119 }
120 }
121 }
122 Some('/') if self.peek(1) == Some('*') => {
123 self.advance(); self.advance(); while let Some(ch) = self.current_char {
127 if ch == '*' && self.peek(1) == Some('/') {
128 self.advance(); self.advance(); break;
131 }
132 self.advance();
133 }
134 }
135 _ => {
136 break;
138 }
139 }
140 }
141 }
142
143 fn read_identifier(&mut self) -> String {
144 let mut result = String::new();
145 while let Some(ch) = self.current_char {
146 if ch.is_alphanumeric() || ch == '_' {
147 result.push(ch);
148 self.advance();
149 } else {
150 break;
151 }
152 }
153 result
154 }
155
156 fn read_string(&mut self) -> String {
157 let mut result = String::new();
158 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
162 if ch == quote_char {
163 self.advance(); break;
165 }
166 result.push(ch);
167 self.advance();
168 }
169 result
170 }
171
172 fn read_number(&mut self) -> String {
173 let mut result = String::new();
174 let mut has_e = false;
175
176 while let Some(ch) = self.current_char {
178 if !has_e && (ch.is_numeric() || ch == '.') {
179 result.push(ch);
180 self.advance();
181 } else if (ch == 'e' || ch == 'E') && !has_e && !result.is_empty() {
182 result.push(ch);
184 self.advance();
185 has_e = true;
186
187 if let Some(sign) = self.current_char {
189 if sign == '+' || sign == '-' {
190 result.push(sign);
191 self.advance();
192 }
193 }
194
195 while let Some(digit) = self.current_char {
197 if digit.is_numeric() {
198 result.push(digit);
199 self.advance();
200 } else {
201 break;
202 }
203 }
204 break; } else {
206 break;
207 }
208 }
209 result
210 }
211
212 pub fn next_token(&mut self) -> Token {
213 self.skip_whitespace_and_comments();
214
215 match self.current_char {
216 None => Token::Eof,
217 Some('*') => {
218 self.advance();
219 Token::Star }
223 Some('+') => {
224 self.advance();
225 Token::Plus
226 }
227 Some('/') => {
228 if self.peek(1) == Some('*') {
230 self.skip_whitespace_and_comments();
233 return self.next_token();
234 }
235 self.advance();
236 Token::Divide
237 }
238 Some('.') => {
239 self.advance();
240 Token::Dot
241 }
242 Some(',') => {
243 self.advance();
244 Token::Comma
245 }
246 Some('(') => {
247 self.advance();
248 Token::LeftParen
249 }
250 Some(')') => {
251 self.advance();
252 Token::RightParen
253 }
254 Some('=') => {
255 self.advance();
256 Token::Equal
257 }
258 Some('<') => {
259 self.advance();
260 if self.current_char == Some('=') {
261 self.advance();
262 Token::LessThanOrEqual
263 } else if self.current_char == Some('>') {
264 self.advance();
265 Token::NotEqual
266 } else {
267 Token::LessThan
268 }
269 }
270 Some('>') => {
271 self.advance();
272 if self.current_char == Some('=') {
273 self.advance();
274 Token::GreaterThanOrEqual
275 } else {
276 Token::GreaterThan
277 }
278 }
279 Some('!') if self.peek(1) == Some('=') => {
280 self.advance();
281 self.advance();
282 Token::NotEqual
283 }
284 Some('"') => {
285 let ident_val = self.read_string();
287 Token::QuotedIdentifier(ident_val)
288 }
289 Some('\'') => {
290 let string_val = self.read_string();
292 Token::StringLiteral(string_val)
293 }
294 Some('-') if self.peek(1) == Some('-') => {
295 self.skip_whitespace_and_comments();
297 self.next_token()
298 }
299 Some('-') if self.peek(1).map_or(false, |c| c.is_numeric()) => {
300 self.advance(); let num = self.read_number();
303 Token::NumberLiteral(format!("-{}", num))
304 }
305 Some('-') => {
306 self.advance();
308 Token::Minus
309 }
310 Some(ch) if ch.is_numeric() => {
311 let num = self.read_number();
312 Token::NumberLiteral(num)
313 }
314 Some(ch) if ch.is_alphabetic() || ch == '_' => {
315 let ident = self.read_identifier();
316 match ident.to_uppercase().as_str() {
317 "SELECT" => Token::Select,
318 "FROM" => Token::From,
319 "WHERE" => Token::Where,
320 "AND" => Token::And,
321 "OR" => Token::Or,
322 "IN" => Token::In,
323 "NOT" => Token::Not,
324 "BETWEEN" => Token::Between,
325 "LIKE" => Token::Like,
326 "IS" => Token::Is,
327 "NULL" => Token::Null,
328 "ORDER" if self.peek_keyword("BY") => {
329 self.skip_whitespace();
330 self.read_identifier(); Token::OrderBy
332 }
333 "GROUP" if self.peek_keyword("BY") => {
334 self.skip_whitespace();
335 self.read_identifier(); Token::GroupBy
337 }
338 "HAVING" => Token::Having,
339 "AS" => Token::As,
340 "ASC" => Token::Asc,
341 "DESC" => Token::Desc,
342 "LIMIT" => Token::Limit,
343 "OFFSET" => Token::Offset,
344 "DATETIME" => Token::DateTime,
345 "CASE" => Token::Case,
346 "WHEN" => Token::When,
347 "THEN" => Token::Then,
348 "ELSE" => Token::Else,
349 "END" => Token::End,
350 "DISTINCT" => Token::Distinct,
351 _ => Token::Identifier(ident),
352 }
353 }
354 Some(ch) => {
355 self.advance();
356 Token::Identifier(ch.to_string())
357 }
358 }
359 }
360
361 fn peek_keyword(&mut self, keyword: &str) -> bool {
362 let saved_pos = self.position;
363 let saved_char = self.current_char;
364
365 self.skip_whitespace_and_comments();
366 let next_word = self.read_identifier();
367 let matches = next_word.to_uppercase() == keyword;
368
369 self.position = saved_pos;
371 self.current_char = saved_char;
372
373 matches
374 }
375
376 pub fn get_position(&self) -> usize {
377 self.position
378 }
379
380 pub fn tokenize_all(&mut self) -> Vec<Token> {
381 let mut tokens = Vec::new();
382 loop {
383 let token = self.next_token();
384 if matches!(token, Token::Eof) {
385 tokens.push(token);
386 break;
387 }
388 tokens.push(token);
389 }
390 tokens
391 }
392
393 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
394 let mut tokens = Vec::new();
395 loop {
396 self.skip_whitespace_and_comments();
397 let start_pos = self.position;
398 let token = self.next_token();
399 let end_pos = self.position;
400
401 if matches!(token, Token::Eof) {
402 break;
403 }
404 tokens.push((start_pos, end_pos, token));
405 }
406 tokens
407 }
408}
409
410#[derive(Debug, Clone)]
412pub enum SqlExpression {
413 Column(String),
414 StringLiteral(String),
415 NumberLiteral(String),
416 BooleanLiteral(bool),
417 DateTimeConstructor {
418 year: i32,
419 month: u32,
420 day: u32,
421 hour: Option<u32>,
422 minute: Option<u32>,
423 second: Option<u32>,
424 },
425 DateTimeToday {
426 hour: Option<u32>,
427 minute: Option<u32>,
428 second: Option<u32>,
429 },
430 MethodCall {
431 object: String,
432 method: String,
433 args: Vec<SqlExpression>,
434 },
435 ChainedMethodCall {
436 base: Box<SqlExpression>,
437 method: String,
438 args: Vec<SqlExpression>,
439 },
440 FunctionCall {
441 name: String,
442 args: Vec<SqlExpression>,
443 },
444 BinaryOp {
445 left: Box<SqlExpression>,
446 op: String,
447 right: Box<SqlExpression>,
448 },
449 InList {
450 expr: Box<SqlExpression>,
451 values: Vec<SqlExpression>,
452 },
453 NotInList {
454 expr: Box<SqlExpression>,
455 values: Vec<SqlExpression>,
456 },
457 Between {
458 expr: Box<SqlExpression>,
459 lower: Box<SqlExpression>,
460 upper: Box<SqlExpression>,
461 },
462 Not {
463 expr: Box<SqlExpression>,
464 },
465 CaseExpression {
466 when_branches: Vec<WhenBranch>,
467 else_branch: Option<Box<SqlExpression>>,
468 },
469}
470
471#[derive(Debug, Clone)]
472pub struct WhenBranch {
473 pub condition: Box<SqlExpression>,
474 pub result: Box<SqlExpression>,
475}
476
477#[derive(Debug, Clone)]
478pub struct WhereClause {
479 pub conditions: Vec<Condition>,
480}
481
482#[derive(Debug, Clone)]
483pub struct Condition {
484 pub expr: SqlExpression,
485 pub connector: Option<LogicalOp>, }
487
/// Logical connector between two `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
493
/// Sort direction of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (also used when no direction is written).
    Asc,
    /// Descending.
    Desc,
}
499
500#[derive(Debug, Clone)]
501pub struct OrderByColumn {
502 pub column: String,
503 pub direction: SortDirection,
504}
505
506#[derive(Debug, Clone)]
508pub enum SelectItem {
509 Column(String),
511 Expression { expr: SqlExpression, alias: String },
513 Star,
515}
516
517#[derive(Debug, Clone)]
518pub struct SelectStatement {
519 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
522 pub where_clause: Option<WhereClause>,
523 pub order_by: Option<Vec<OrderByColumn>>,
524 pub group_by: Option<Vec<String>>,
525 pub limit: Option<usize>,
526 pub offset: Option<usize>,
527}
528
/// Options for [`Parser`].
///
/// The default configuration has every flag turned off.
#[derive(Debug, Clone, Default)]
pub struct ParserConfig {
    /// Case-insensitivity switch; defaults to `false`.
    ///
    /// NOTE(review): the parser code visible in this file stores the config
    /// but never reads this flag — confirm where it takes effect.
    pub case_insensitive: bool,
}
540
541pub struct Parser {
542 lexer: Lexer,
543 current_token: Token,
544 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
548 config: ParserConfig, }
550
551impl Parser {
552 pub fn new(input: &str) -> Self {
553 let mut lexer = Lexer::new(input);
554 let current_token = lexer.next_token();
555 Self {
556 lexer,
557 current_token,
558 in_method_args: false,
559 columns: Vec::new(),
560 paren_depth: 0,
561 config: ParserConfig::default(),
562 }
563 }
564
565 pub fn with_config(input: &str, config: ParserConfig) -> Self {
566 let mut lexer = Lexer::new(input);
567 let current_token = lexer.next_token();
568 Self {
569 lexer,
570 current_token,
571 in_method_args: false,
572 columns: Vec::new(),
573 paren_depth: 0,
574 config,
575 }
576 }
577
578 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
579 self.columns = columns;
580 self
581 }
582
583 fn consume(&mut self, expected: Token) -> Result<(), String> {
584 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
585 match &expected {
587 Token::LeftParen => self.paren_depth += 1,
588 Token::RightParen => {
589 self.paren_depth -= 1;
590 if self.paren_depth < 0 {
592 return Err(
593 "Unexpected closing parenthesis - no matching opening parenthesis"
594 .to_string(),
595 );
596 }
597 }
598 _ => {}
599 }
600
601 self.current_token = self.lexer.next_token();
602 Ok(())
603 } else {
604 let error_msg = match (&expected, &self.current_token) {
606 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
607 format!(
608 "Unclosed parenthesis - missing {} closing parenthes{}",
609 self.paren_depth,
610 if self.paren_depth == 1 { "is" } else { "es" }
611 )
612 }
613 (Token::RightParen, _) if self.paren_depth > 0 => {
614 format!(
615 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
616 self.current_token,
617 self.paren_depth,
618 if self.paren_depth == 1 { "is" } else { "es" }
619 )
620 }
621 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
622 };
623 Err(error_msg)
624 }
625 }
626
627 fn advance(&mut self) {
628 match &self.current_token {
630 Token::LeftParen => self.paren_depth += 1,
631 Token::RightParen => {
632 self.paren_depth -= 1;
633 }
636 _ => {}
637 }
638 self.current_token = self.lexer.next_token();
639 }
640
641 pub fn parse(&mut self) -> Result<SelectStatement, String> {
642 self.parse_select_statement()
643 }
644
645 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
646 self.consume(Token::Select)?;
647
648 let select_items = self.parse_select_items()?;
650
651 let columns = select_items
653 .iter()
654 .map(|item| match item {
655 SelectItem::Star => "*".to_string(),
656 SelectItem::Column(name) => name.clone(),
657 SelectItem::Expression { alias, .. } => alias.clone(),
658 })
659 .collect();
660
661 let from_table = if matches!(self.current_token, Token::From) {
662 self.advance();
663 match &self.current_token {
664 Token::Identifier(table) => {
665 let table_name = table.clone();
666 self.advance();
667 Some(table_name)
668 }
669 Token::QuotedIdentifier(table) => {
670 let table_name = table.clone();
672 self.advance();
673 Some(table_name)
674 }
675 _ => return Err("Expected table name after FROM".to_string()),
676 }
677 } else {
678 None
679 };
680
681 let where_clause = if matches!(self.current_token, Token::Where) {
682 self.advance();
683 Some(self.parse_where_clause()?)
684 } else {
685 None
686 };
687
688 let order_by = if matches!(self.current_token, Token::OrderBy) {
689 self.advance();
690 Some(self.parse_order_by_list()?)
691 } else {
692 None
693 };
694
695 let group_by = if matches!(self.current_token, Token::GroupBy) {
696 self.advance();
697 Some(self.parse_identifier_list()?)
698 } else {
699 None
700 };
701
702 let limit = if matches!(self.current_token, Token::Limit) {
704 self.advance();
705 match &self.current_token {
706 Token::NumberLiteral(num) => {
707 let limit_val = num
708 .parse::<usize>()
709 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
710 self.advance();
711 Some(limit_val)
712 }
713 _ => return Err("Expected number after LIMIT".to_string()),
714 }
715 } else {
716 None
717 };
718
719 let offset = if matches!(self.current_token, Token::Offset) {
721 self.advance();
722 match &self.current_token {
723 Token::NumberLiteral(num) => {
724 let offset_val = num
725 .parse::<usize>()
726 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
727 self.advance();
728 Some(offset_val)
729 }
730 _ => return Err("Expected number after OFFSET".to_string()),
731 }
732 } else {
733 None
734 };
735
736 if self.paren_depth > 0 {
738 return Err(format!(
739 "Unclosed parenthesis - missing {} closing parenthes{}",
740 self.paren_depth,
741 if self.paren_depth == 1 { "is" } else { "es" }
742 ));
743 } else if self.paren_depth < 0 {
744 return Err(
745 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
746 );
747 }
748
749 Ok(SelectStatement {
750 columns,
751 select_items,
752 from_table,
753 where_clause,
754 order_by,
755 group_by,
756 limit,
757 offset,
758 })
759 }
760
761 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
762 let mut columns = Vec::new();
763
764 if matches!(self.current_token, Token::Star) {
765 columns.push("*".to_string());
766 self.advance();
767 } else {
768 loop {
769 match &self.current_token {
770 Token::Identifier(col) => {
771 columns.push(col.clone());
772 self.advance();
773 }
774 Token::QuotedIdentifier(col) => {
775 columns.push(col.clone());
777 self.advance();
778 }
779 _ => return Err("Expected column name".to_string()),
780 }
781
782 if matches!(self.current_token, Token::Comma) {
783 self.advance();
784 } else {
785 break;
786 }
787 }
788 }
789
790 Ok(columns)
791 }
792
793 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
795 let mut items = Vec::new();
796
797 loop {
798 if matches!(self.current_token, Token::Star) {
801 items.push(SelectItem::Star);
809 self.advance();
810 } else {
811 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
816 self.advance();
817 match &self.current_token {
818 Token::Identifier(alias_name) => {
819 let alias = alias_name.clone();
820 self.advance();
821 alias
822 }
823 Token::QuotedIdentifier(alias_name) => {
824 let alias = alias_name.clone();
825 self.advance();
826 alias
827 }
828 _ => return Err("Expected alias name after AS".to_string()),
829 }
830 } else {
831 match &expr {
833 SqlExpression::Column(col_name) => col_name.clone(),
834 _ => format!("expr_{}", items.len() + 1), }
836 };
837
838 let item = match expr {
840 SqlExpression::Column(col_name) if alias == col_name => {
841 SelectItem::Column(col_name)
843 }
844 _ => {
845 SelectItem::Expression { expr, alias }
847 }
848 };
849
850 items.push(item);
851 }
852
853 if matches!(self.current_token, Token::Comma) {
855 self.advance();
856 } else {
857 break;
858 }
859 }
860
861 Ok(items)
862 }
863
864 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
865 let mut identifiers = Vec::new();
866
867 loop {
868 match &self.current_token {
869 Token::Identifier(id) => {
870 identifiers.push(id.clone());
871 self.advance();
872 }
873 Token::QuotedIdentifier(id) => {
874 identifiers.push(id.clone());
876 self.advance();
877 }
878 _ => return Err("Expected identifier".to_string()),
879 }
880
881 if matches!(self.current_token, Token::Comma) {
882 self.advance();
883 } else {
884 break;
885 }
886 }
887
888 Ok(identifiers)
889 }
890
891 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
892 let mut order_columns = Vec::new();
893
894 loop {
895 let column = match &self.current_token {
896 Token::Identifier(id) => {
897 let col = id.clone();
898 self.advance();
899 col
900 }
901 Token::QuotedIdentifier(id) => {
902 let col = id.clone();
903 self.advance();
904 col
905 }
906 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
907 let col = num.clone();
909 self.advance();
910 col
911 }
912 _ => return Err("Expected column name in ORDER BY".to_string()),
913 };
914
915 let direction = match &self.current_token {
917 Token::Asc => {
918 self.advance();
919 SortDirection::Asc
920 }
921 Token::Desc => {
922 self.advance();
923 SortDirection::Desc
924 }
925 _ => SortDirection::Asc, };
927
928 order_columns.push(OrderByColumn { column, direction });
929
930 if matches!(self.current_token, Token::Comma) {
931 self.advance();
932 } else {
933 break;
934 }
935 }
936
937 Ok(order_columns)
938 }
939
940 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
941 let mut conditions = Vec::new();
942
943 loop {
944 let expr = self.parse_expression()?;
945
946 let connector = match &self.current_token {
947 Token::And => {
948 self.advance();
949 Some(LogicalOp::And)
950 }
951 Token::Or => {
952 self.advance();
953 Some(LogicalOp::Or)
954 }
955 Token::RightParen if self.paren_depth <= 0 => {
956 return Err(
958 "Unexpected closing parenthesis - no matching opening parenthesis"
959 .to_string(),
960 );
961 }
962 _ => None,
963 };
964
965 conditions.push(Condition {
966 expr,
967 connector: connector.clone(),
968 });
969
970 if connector.is_none() {
971 break;
972 }
973 }
974
975 Ok(WhereClause { conditions })
976 }
977
978 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
979 let mut left = self.parse_comparison()?;
980
981 if let Some(op) = self.get_binary_op() {
984 self.advance();
985 let right = self.parse_expression()?;
986 left = SqlExpression::BinaryOp {
987 left: Box::new(left),
988 op,
989 right: Box::new(right),
990 };
991 }
992
993 if matches!(self.current_token, Token::In) {
995 self.advance();
996 self.consume(Token::LeftParen)?;
997 let values = self.parse_expression_list()?;
998 self.consume(Token::RightParen)?;
999
1000 left = SqlExpression::InList {
1001 expr: Box::new(left),
1002 values,
1003 };
1004 }
1005
1006 Ok(left)
1010 }
1011
1012 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1013 let mut left = self.parse_additive()?;
1014
1015 if matches!(self.current_token, Token::Between) {
1017 self.advance(); let lower = self.parse_primary()?;
1019 self.consume(Token::And)?; let upper = self.parse_primary()?;
1021
1022 return Ok(SqlExpression::Between {
1023 expr: Box::new(left),
1024 lower: Box::new(lower),
1025 upper: Box::new(upper),
1026 });
1027 }
1028
1029 if matches!(self.current_token, Token::Not) {
1031 self.advance(); if matches!(self.current_token, Token::In) {
1033 self.advance(); self.consume(Token::LeftParen)?;
1035 let values = self.parse_expression_list()?;
1036 self.consume(Token::RightParen)?;
1037
1038 return Ok(SqlExpression::NotInList {
1039 expr: Box::new(left),
1040 values,
1041 });
1042 } else {
1043 return Err("Expected IN after NOT".to_string());
1044 }
1045 }
1046
1047 if let Some(op) = self.get_binary_op() {
1049 self.advance();
1050 let right = self.parse_additive()?;
1051 left = SqlExpression::BinaryOp {
1052 left: Box::new(left),
1053 op,
1054 right: Box::new(right),
1055 };
1056 }
1057
1058 Ok(left)
1059 }
1060
1061 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1062 let mut left = self.parse_multiplicative()?;
1063
1064 while matches!(self.current_token, Token::Plus | Token::Minus) {
1065 let op = match self.current_token {
1066 Token::Plus => "+",
1067 Token::Minus => "-",
1068 _ => unreachable!(),
1069 };
1070 self.advance();
1071 let right = self.parse_multiplicative()?;
1072 left = SqlExpression::BinaryOp {
1073 left: Box::new(left),
1074 op: op.to_string(),
1075 right: Box::new(right),
1076 };
1077 }
1078
1079 Ok(left)
1080 }
1081
1082 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1083 let mut left = self.parse_primary()?;
1084
1085 while matches!(self.current_token, Token::Dot) {
1087 self.advance();
1088 if let Token::Identifier(method) = &self.current_token {
1089 let method_name = method.clone();
1090 self.advance();
1091
1092 if matches!(self.current_token, Token::LeftParen) {
1093 self.advance();
1094 let args = self.parse_method_args()?;
1095 self.consume(Token::RightParen)?;
1096
1097 match left {
1099 SqlExpression::Column(obj) => {
1100 left = SqlExpression::MethodCall {
1102 object: obj,
1103 method: method_name,
1104 args,
1105 };
1106 }
1107 SqlExpression::MethodCall { .. }
1108 | SqlExpression::ChainedMethodCall { .. } => {
1109 left = SqlExpression::ChainedMethodCall {
1111 base: Box::new(left),
1112 method: method_name,
1113 args,
1114 };
1115 }
1116 _ => {
1117 left = SqlExpression::ChainedMethodCall {
1119 base: Box::new(left),
1120 method: method_name,
1121 args,
1122 };
1123 }
1124 }
1125 } else {
1126 return Err(format!("Expected '(' after method name '{}'", method_name));
1127 }
1128 } else {
1129 return Err("Expected method name after '.'".to_string());
1130 }
1131 }
1132
1133 while matches!(self.current_token, Token::Star | Token::Divide) {
1134 let op = match self.current_token {
1135 Token::Star => "*",
1136 Token::Divide => "/",
1137 _ => unreachable!(),
1138 };
1139 self.advance();
1140 let right = self.parse_primary()?;
1141 left = SqlExpression::BinaryOp {
1142 left: Box::new(left),
1143 op: op.to_string(),
1144 right: Box::new(right),
1145 };
1146 }
1147
1148 Ok(left)
1149 }
1150
1151 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1152 let mut left = self.parse_logical_and()?;
1153
1154 while matches!(self.current_token, Token::Or) {
1155 self.advance();
1156 let right = self.parse_logical_and()?;
1157 left = SqlExpression::BinaryOp {
1161 left: Box::new(left),
1162 op: "OR".to_string(),
1163 right: Box::new(right),
1164 };
1165 }
1166
1167 Ok(left)
1168 }
1169
1170 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1171 let mut left = self.parse_expression()?;
1172
1173 while matches!(self.current_token, Token::And) {
1174 self.advance();
1175 let right = self.parse_expression()?;
1176 left = SqlExpression::BinaryOp {
1178 left: Box::new(left),
1179 op: "AND".to_string(),
1180 right: Box::new(right),
1181 };
1182 }
1183
1184 Ok(left)
1185 }
1186
1187 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1188 self.consume(Token::Case)?;
1190
1191 let mut when_branches = Vec::new();
1192
1193 while matches!(self.current_token, Token::When) {
1195 self.advance(); let condition = self.parse_expression()?;
1199
1200 self.consume(Token::Then)?;
1202
1203 let result = self.parse_expression()?;
1205
1206 when_branches.push(WhenBranch {
1207 condition: Box::new(condition),
1208 result: Box::new(result),
1209 });
1210 }
1211
1212 if when_branches.is_empty() {
1214 return Err("CASE expression must have at least one WHEN clause".to_string());
1215 }
1216
1217 let else_branch = if matches!(self.current_token, Token::Else) {
1219 self.advance(); Some(Box::new(self.parse_expression()?))
1221 } else {
1222 None
1223 };
1224
1225 self.consume(Token::End)?;
1227
1228 Ok(SqlExpression::CaseExpression {
1229 when_branches,
1230 else_branch,
1231 })
1232 }
1233
1234 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1235 if let Token::NumberLiteral(num_str) = &self.current_token {
1238 if self.columns.iter().any(|col| col == num_str) {
1240 let expr = SqlExpression::Column(num_str.clone());
1241 self.advance();
1242 return Ok(expr);
1243 }
1244 }
1245
1246 match &self.current_token {
1247 Token::Case => {
1248 self.parse_case_expression()
1250 }
1251 Token::DateTime => {
1252 self.advance(); self.consume(Token::LeftParen)?;
1254
1255 if matches!(&self.current_token, Token::RightParen) {
1257 self.advance(); return Ok(SqlExpression::DateTimeToday {
1259 hour: None,
1260 minute: None,
1261 second: None,
1262 });
1263 }
1264
1265 let year = if let Token::NumberLiteral(n) = &self.current_token {
1267 n.parse::<i32>().map_err(|_| "Invalid year")?
1268 } else {
1269 return Err("Expected year in DateTime constructor".to_string());
1270 };
1271 self.advance();
1272 self.consume(Token::Comma)?;
1273
1274 let month = if let Token::NumberLiteral(n) = &self.current_token {
1276 n.parse::<u32>().map_err(|_| "Invalid month")?
1277 } else {
1278 return Err("Expected month in DateTime constructor".to_string());
1279 };
1280 self.advance();
1281 self.consume(Token::Comma)?;
1282
1283 let day = if let Token::NumberLiteral(n) = &self.current_token {
1285 n.parse::<u32>().map_err(|_| "Invalid day")?
1286 } else {
1287 return Err("Expected day in DateTime constructor".to_string());
1288 };
1289 self.advance();
1290
1291 let mut hour = None;
1293 let mut minute = None;
1294 let mut second = None;
1295
1296 if matches!(&self.current_token, Token::Comma) {
1297 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1301 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1302 self.advance();
1303
1304 if matches!(&self.current_token, Token::Comma) {
1306 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1309 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1310 self.advance();
1311
1312 if matches!(&self.current_token, Token::Comma) {
1314 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1317 second =
1318 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1319 self.advance();
1320 }
1321 }
1322 }
1323 }
1324 }
1325 }
1326
1327 self.consume(Token::RightParen)?;
1328 Ok(SqlExpression::DateTimeConstructor {
1329 year,
1330 month,
1331 day,
1332 hour,
1333 minute,
1334 second,
1335 })
1336 }
1337 Token::Identifier(id) => {
1338 let id_upper = id.to_uppercase();
1339 let id_clone = id.clone();
1340
1341 if id_upper == "TRUE" {
1343 self.advance();
1344 return Ok(SqlExpression::BooleanLiteral(true));
1345 } else if id_upper == "FALSE" {
1346 self.advance();
1347 return Ok(SqlExpression::BooleanLiteral(false));
1348 }
1349
1350 self.advance();
1351
1352 if matches!(self.current_token, Token::LeftParen) {
1354 self.advance(); let args = self.parse_function_args()?;
1358 self.consume(Token::RightParen)?;
1359 return Ok(SqlExpression::FunctionCall {
1360 name: id_upper,
1361 args,
1362 });
1363 }
1364
1365 Ok(SqlExpression::Column(id_clone))
1367 }
1368 Token::QuotedIdentifier(id) => {
1369 let expr = if self.in_method_args {
1372 SqlExpression::StringLiteral(id.clone())
1373 } else {
1374 SqlExpression::Column(id.clone())
1376 };
1377 self.advance();
1378 Ok(expr)
1379 }
1380 Token::StringLiteral(s) => {
1381 let expr = SqlExpression::StringLiteral(s.clone());
1382 self.advance();
1383 Ok(expr)
1384 }
1385 Token::NumberLiteral(n) => {
1386 let expr = SqlExpression::NumberLiteral(n.clone());
1387 self.advance();
1388 Ok(expr)
1389 }
1390 Token::LeftParen => {
1391 self.advance();
1392
1393 let expr = self.parse_logical_or()?;
1396
1397 self.consume(Token::RightParen)?;
1398 Ok(expr)
1399 }
1400 Token::Not => {
1401 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1405 if matches!(self.current_token, Token::In) {
1407 self.advance(); self.consume(Token::LeftParen)?;
1409 let values = self.parse_expression_list()?;
1410 self.consume(Token::RightParen)?;
1411
1412 return Ok(SqlExpression::NotInList {
1413 expr: Box::new(inner_expr),
1414 values,
1415 });
1416 } else {
1417 return Ok(SqlExpression::Not {
1419 expr: Box::new(inner_expr),
1420 });
1421 }
1422 } else {
1423 return Err("Expected expression after NOT".to_string());
1424 }
1425 }
1426 Token::Star => {
1427 self.advance();
1429 Ok(SqlExpression::StringLiteral("*".to_string()))
1430 }
1431 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1432 }
1433 }
1434
1435 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1436 let mut args = Vec::new();
1437
1438 self.in_method_args = true;
1440
1441 if !matches!(self.current_token, Token::RightParen) {
1442 loop {
1443 args.push(self.parse_expression()?);
1444
1445 if matches!(self.current_token, Token::Comma) {
1446 self.advance();
1447 } else {
1448 break;
1449 }
1450 }
1451 }
1452
1453 self.in_method_args = false;
1455
1456 Ok(args)
1457 }
1458
1459 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1460 let mut args = Vec::new();
1461
1462 if !matches!(self.current_token, Token::RightParen) {
1463 if matches!(self.current_token, Token::Distinct) {
1465 self.advance(); let expr = self.parse_additive()?;
1468 args.push(SqlExpression::FunctionCall {
1470 name: "DISTINCT".to_string(),
1471 args: vec![expr],
1472 });
1473 } else {
1474 args.push(self.parse_additive()?);
1476 }
1477
1478 while matches!(self.current_token, Token::Comma) {
1480 self.advance();
1481 args.push(self.parse_additive()?);
1482 }
1483 }
1484
1485 Ok(args)
1486 }
1487
1488 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1489 let mut expressions = Vec::new();
1490
1491 loop {
1492 expressions.push(self.parse_expression()?);
1493
1494 if matches!(self.current_token, Token::Comma) {
1495 self.advance();
1496 } else {
1497 break;
1498 }
1499 }
1500
1501 Ok(expressions)
1502 }
1503
1504 fn get_binary_op(&self) -> Option<String> {
1505 match &self.current_token {
1506 Token::Equal => Some("=".to_string()),
1507 Token::NotEqual => Some("!=".to_string()),
1508 Token::LessThan => Some("<".to_string()),
1509 Token::GreaterThan => Some(">".to_string()),
1510 Token::LessThanOrEqual => Some("<=".to_string()),
1511 Token::GreaterThanOrEqual => Some(">=".to_string()),
1512 Token::Like => Some("LIKE".to_string()),
1513 _ => None,
1514 }
1515 }
1516
1517 fn get_arithmetic_op(&self) -> Option<String> {
1518 match &self.current_token {
1519 Token::Plus => Some("+".to_string()),
1520 Token::Minus => Some("-".to_string()),
1521 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1523 _ => None,
1524 }
1525 }
1526
1527 pub fn get_position(&self) -> usize {
1528 self.lexer.get_position()
1529 }
1530}
1531
/// Describes where in a query the cursor sits, as reported by
/// `detect_cursor_context` together with an optional partially typed word.
#[derive(Debug, Clone)]
pub enum CursorContext {
    /// The query has select columns but no `FROM` table yet.
    SelectClause,
    /// A `FROM` table is present and no `WHERE` or `ORDER BY` clause follows.
    FromClause,
    /// The cursor is inside a `WHERE` clause, not directly after `AND`/`OR`.
    WhereClause,
    /// The cursor is after `ORDER BY` or within an existing `ORDER BY` clause.
    OrderByClause,
    /// The cursor follows `column.`; holds the column name with any
    /// surrounding double quotes removed.
    AfterColumn(String),
    /// The cursor follows a logical connective (`AND` / `OR`).
    AfterLogicalOp(LogicalOp),
    /// The cursor follows a comparison; holds `(column name, operator)`.
    AfterComparisonOp(String, String),
    /// NOTE(review): not constructed in the code visible here; presumably
    /// `(object, method)` of the call being typed — confirm against callers.
    InMethodCall(String, String),
    /// NOTE(review): not constructed in the code visible here — confirm usage.
    InExpression,
    /// No more specific context could be determined.
    Unknown,
}
1546
/// Returns the prefix of `s` that ends at byte offset `pos`, never splitting
/// a UTF-8 character.
///
/// When `pos` is at or beyond the end of `s`, the whole string is returned.
/// When `pos` falls inside a multi-byte character, the cut is moved back to
/// the start of that character.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..cut]
}
1561
/// Returns the suffix of `s` that starts at byte offset `pos`, never
/// splitting a UTF-8 character.
///
/// When `pos` is at or beyond the end of `s`, an empty string is returned.
/// When `pos` falls inside a multi-byte character, the start is moved forward
/// to the next character boundary.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is always a boundary, so it serves as the fallback.
    let begin = (pos..s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[begin..]
}
1576
1577pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1578 let truncated = safe_slice_to(query, cursor_pos);
1579 let mut parser = Parser::new(truncated);
1580
1581 match parser.parse() {
1583 Ok(stmt) => {
1584 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1585 #[cfg(test)]
1586 println!(
1587 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1588 ctx, partial, truncated
1589 );
1590 (ctx, partial)
1591 }
1592 Err(_) => {
1593 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1595 #[cfg(test)]
1596 println!(
1597 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1598 ctx, partial, truncated
1599 );
1600 (ctx, partial)
1601 }
1602 }
1603}
1604
1605pub fn tokenize_query(query: &str) -> Vec<String> {
1606 let mut lexer = Lexer::new(query);
1607 let tokens = lexer.tokenize_all();
1608 tokens.iter().map(|t| format!("{:?}", t)).collect()
1609}
1610
1611pub fn format_sql_pretty(query: &str) -> Vec<String> {
1612 format_sql_pretty_compact(query, 5) }
1614
1615pub fn format_ast_tree(query: &str) -> String {
1617 let mut parser = Parser::new(query);
1618 match parser.parse() {
1619 Ok(stmt) => format_select_statement(&stmt, 0),
1620 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1621 }
1622}
1623
1624fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1625 let mut result = String::new();
1626 let indent_str = " ".repeat(indent);
1627
1628 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1629
1630 result.push_str(&format!("{indent_str} columns: ["));
1632 if !stmt.columns.is_empty() {
1633 result.push('\n');
1634 for col in &stmt.columns {
1635 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1636 }
1637 result.push_str(&format!("{indent_str} ],\n"));
1638 } else {
1639 result.push_str("],\n");
1640 }
1641
1642 if let Some(table) = &stmt.from_table {
1644 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1645 }
1646
1647 if let Some(where_clause) = &stmt.where_clause {
1649 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1650 result.push_str(&format_where_clause(where_clause, indent + 2));
1651 result.push_str(&format!("{indent_str} }},\n"));
1652 }
1653
1654 if let Some(order_by) = &stmt.order_by {
1656 result.push_str(&format!("{indent_str} order_by: ["));
1657 if !order_by.is_empty() {
1658 result.push('\n');
1659 for col in order_by {
1660 let dir = match col.direction {
1661 SortDirection::Asc => "ASC",
1662 SortDirection::Desc => "DESC",
1663 };
1664 result.push_str(&format!(
1665 "{indent_str} \"{col}\" {dir},\n",
1666 col = col.column
1667 ));
1668 }
1669 result.push_str(&format!("{indent_str} ],\n"));
1670 } else {
1671 result.push_str("],\n");
1672 }
1673 }
1674
1675 if let Some(group_by) = &stmt.group_by {
1677 result.push_str(&format!("{indent_str} group_by: ["));
1678 if !group_by.is_empty() {
1679 result.push('\n');
1680 for col in group_by {
1681 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1682 }
1683 result.push_str(&format!("{indent_str} ],\n"));
1684 } else {
1685 result.push_str("]\n");
1686 }
1687 }
1688
1689 result.push_str(&format!("{indent_str}}}"));
1690 result
1691}
1692
1693fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1694 let mut result = String::new();
1695 let indent_str = " ".repeat(indent);
1696
1697 result.push_str(&format!("{indent_str}conditions: [\n"));
1698
1699 for condition in &clause.conditions {
1700 result.push_str(&format!("{indent_str} {{\n"));
1701 result.push_str(&format!(
1702 "{indent_str} expr: {},\n",
1703 format_expression_ast(&condition.expr)
1704 ));
1705
1706 if let Some(connector) = &condition.connector {
1707 let connector_str = match connector {
1708 LogicalOp::And => "AND",
1709 LogicalOp::Or => "OR",
1710 };
1711 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1712 }
1713
1714 result.push_str(&format!("{indent_str} }},\n"));
1715 }
1716
1717 result.push_str(&format!("{indent_str}]\n"));
1718 result
1719}
1720
1721fn format_expression_ast(expr: &SqlExpression) -> String {
1722 match expr {
1723 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1724 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1725 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1726 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({})", value),
1727 SqlExpression::DateTimeConstructor {
1728 year,
1729 month,
1730 day,
1731 hour,
1732 minute,
1733 second,
1734 } => {
1735 format!(
1736 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1737 year,
1738 month,
1739 day,
1740 hour.unwrap_or(0),
1741 minute.unwrap_or(0),
1742 second.unwrap_or(0)
1743 )
1744 }
1745 SqlExpression::DateTimeToday {
1746 hour,
1747 minute,
1748 second,
1749 } => {
1750 format!(
1751 "DateTimeToday({:02}:{:02}:{:02})",
1752 hour.unwrap_or(0),
1753 minute.unwrap_or(0),
1754 second.unwrap_or(0)
1755 )
1756 }
1757 SqlExpression::MethodCall {
1758 object,
1759 method,
1760 args,
1761 } => {
1762 let args_str = args
1763 .iter()
1764 .map(|a| format_expression_ast(a))
1765 .collect::<Vec<_>>()
1766 .join(", ");
1767 format!("MethodCall({}.{}({}))", object, method, args_str)
1768 }
1769 SqlExpression::ChainedMethodCall { base, method, args } => {
1770 let args_str = args
1771 .iter()
1772 .map(|a| format_expression_ast(a))
1773 .collect::<Vec<_>>()
1774 .join(", ");
1775 format!(
1776 "ChainedMethodCall({}.{}({}))",
1777 format_expression_ast(base),
1778 method,
1779 args_str
1780 )
1781 }
1782 SqlExpression::FunctionCall { name, args } => {
1783 let args_str = args
1784 .iter()
1785 .map(|a| format_expression_ast(a))
1786 .collect::<Vec<_>>()
1787 .join(", ");
1788 format!("FunctionCall({}({}))", name, args_str)
1789 }
1790 SqlExpression::BinaryOp { left, op, right } => {
1791 format!(
1792 "BinaryOp({} {} {})",
1793 format_expression_ast(left),
1794 op,
1795 format_expression_ast(right)
1796 )
1797 }
1798 SqlExpression::InList { expr, values } => {
1799 let list_str = values
1800 .iter()
1801 .map(|e| format_expression_ast(e))
1802 .collect::<Vec<_>>()
1803 .join(", ");
1804 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1805 }
1806 SqlExpression::NotInList { expr, values } => {
1807 let list_str = values
1808 .iter()
1809 .map(|e| format_expression_ast(e))
1810 .collect::<Vec<_>>()
1811 .join(", ");
1812 format!(
1813 "NotInList({} NOT IN [{}])",
1814 format_expression_ast(expr),
1815 list_str
1816 )
1817 }
1818 SqlExpression::Between { expr, lower, upper } => {
1819 format!(
1820 "Between({} BETWEEN {} AND {})",
1821 format_expression_ast(expr),
1822 format_expression_ast(lower),
1823 format_expression_ast(upper)
1824 )
1825 }
1826 SqlExpression::Not { expr } => {
1827 format!("Not({})", format_expression_ast(expr))
1828 }
1829 SqlExpression::CaseExpression {
1830 when_branches,
1831 else_branch,
1832 } => {
1833 let when_strs: Vec<String> = when_branches
1834 .iter()
1835 .map(|branch| {
1836 format!(
1837 "WHEN {} THEN {}",
1838 format_expression_ast(&branch.condition),
1839 format_expression_ast(&branch.result)
1840 )
1841 })
1842 .collect();
1843 let else_str = else_branch
1844 .as_ref()
1845 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
1846 .unwrap_or_default();
1847 format!("CASE {} {} END", when_strs.join(" "), else_str)
1848 }
1849 }
1850}
1851
1852pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1854 match expr {
1855 SqlExpression::DateTimeConstructor {
1856 year,
1857 month,
1858 day,
1859 hour,
1860 minute,
1861 second,
1862 } => {
1863 let h = hour.unwrap_or(0);
1864 let m = minute.unwrap_or(0);
1865 let s = second.unwrap_or(0);
1866
1867 if let Ok(dt) = NaiveDateTime::parse_from_str(
1869 &format!(
1870 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1871 year, month, day, h, m, s
1872 ),
1873 "%Y-%m-%d %H:%M:%S",
1874 ) {
1875 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1876 } else {
1877 None
1878 }
1879 }
1880 SqlExpression::DateTimeToday {
1881 hour,
1882 minute,
1883 second,
1884 } => {
1885 let now = Local::now();
1886 let h = hour.unwrap_or(0);
1887 let m = minute.unwrap_or(0);
1888 let s = second.unwrap_or(0);
1889
1890 if let Ok(dt) = NaiveDateTime::parse_from_str(
1892 &format!(
1893 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1894 now.year(),
1895 now.month(),
1896 now.day(),
1897 h,
1898 m,
1899 s
1900 ),
1901 "%Y-%m-%d %H:%M:%S",
1902 ) {
1903 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1904 } else {
1905 None
1906 }
1907 }
1908 _ => None,
1909 }
1910}
1911
1912fn format_sql_with_preserved_parens(
1914 query: &str,
1915 cols_per_line: usize,
1916) -> Result<Vec<String>, String> {
1917 let mut lines = Vec::new();
1918 let mut lexer = Lexer::new(query);
1919 let tokens_with_pos = lexer.tokenize_all_with_positions();
1920
1921 if tokens_with_pos.is_empty() {
1922 return Err("No tokens found".to_string());
1923 }
1924
1925 let mut i = 0;
1926 let cols_per_line = cols_per_line.max(1);
1927
1928 while i < tokens_with_pos.len() {
1929 let (start, _end, ref token) = tokens_with_pos[i];
1930
1931 match token {
1932 Token::Select => {
1933 lines.push("SELECT".to_string());
1934 i += 1;
1935
1936 let mut columns = Vec::new();
1938 let mut col_start = i;
1939 while i < tokens_with_pos.len() {
1940 match &tokens_with_pos[i].2 {
1941 Token::From | Token::Eof => break,
1942 Token::Comma => {
1943 if col_start < i {
1945 let col_text = extract_text_between_positions(
1946 query,
1947 tokens_with_pos[col_start].0,
1948 tokens_with_pos[i - 1].1,
1949 );
1950 columns.push(col_text);
1951 }
1952 i += 1;
1953 col_start = i;
1954 }
1955 _ => i += 1,
1956 }
1957 }
1958 if col_start < i && i > 0 {
1960 let col_text = extract_text_between_positions(
1961 query,
1962 tokens_with_pos[col_start].0,
1963 tokens_with_pos[i - 1].1,
1964 );
1965 columns.push(col_text);
1966 }
1967
1968 for chunk in columns.chunks(cols_per_line) {
1970 let mut line = " ".to_string();
1971 for (idx, col) in chunk.iter().enumerate() {
1972 if idx > 0 {
1973 line.push_str(", ");
1974 }
1975 line.push_str(col.trim());
1976 }
1977 let is_last_chunk = chunk.as_ptr() as usize
1979 + chunk.len() * std::mem::size_of::<String>()
1980 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1981 if !is_last_chunk && columns.len() > cols_per_line {
1982 line.push(',');
1983 }
1984 lines.push(line);
1985 }
1986 }
1987 Token::From => {
1988 i += 1;
1989 if i < tokens_with_pos.len() {
1990 let table_start = tokens_with_pos[i].0;
1991 while i < tokens_with_pos.len() {
1993 match &tokens_with_pos[i].2 {
1994 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1995 _ => i += 1,
1996 }
1997 }
1998 if i > 0 {
1999 let table_text = extract_text_between_positions(
2000 query,
2001 table_start,
2002 tokens_with_pos[i - 1].1,
2003 );
2004 lines.push(format!("FROM {}", table_text.trim()));
2005 }
2006 }
2007 }
2008 Token::Where => {
2009 lines.push("WHERE".to_string());
2010 i += 1;
2011
2012 let where_start = if i < tokens_with_pos.len() {
2014 tokens_with_pos[i].0
2015 } else {
2016 start
2017 };
2018
2019 let mut where_end = query.len();
2021 while i < tokens_with_pos.len() {
2022 match &tokens_with_pos[i].2 {
2023 Token::OrderBy | Token::GroupBy | Token::Eof => {
2024 if i > 0 {
2025 where_end = tokens_with_pos[i - 1].1;
2026 }
2027 break;
2028 }
2029 _ => i += 1,
2030 }
2031 }
2032
2033 let where_text = extract_text_between_positions(query, where_start, where_end);
2035
2036 let formatted_where = format_where_clause_with_parens(&where_text);
2038 for line in formatted_where {
2039 lines.push(format!(" {}", line));
2040 }
2041 }
2042 Token::OrderBy => {
2043 i += 1;
2044 let order_start = if i < tokens_with_pos.len() {
2045 tokens_with_pos[i].0
2046 } else {
2047 start
2048 };
2049
2050 while i < tokens_with_pos.len() {
2052 match &tokens_with_pos[i].2 {
2053 Token::GroupBy | Token::Eof => break,
2054 _ => i += 1,
2055 }
2056 }
2057
2058 if i > 0 {
2059 let order_text = extract_text_between_positions(
2060 query,
2061 order_start,
2062 tokens_with_pos[i - 1].1,
2063 );
2064 lines.push(format!("ORDER BY {}", order_text.trim()));
2065 }
2066 }
2067 Token::GroupBy => {
2068 i += 1;
2069 let group_start = if i < tokens_with_pos.len() {
2070 tokens_with_pos[i].0
2071 } else {
2072 start
2073 };
2074
2075 while i < tokens_with_pos.len() {
2077 match &tokens_with_pos[i].2 {
2078 Token::Having | Token::Eof => break,
2079 _ => i += 1,
2080 }
2081 }
2082
2083 if i > 0 {
2084 let group_text = extract_text_between_positions(
2085 query,
2086 group_start,
2087 tokens_with_pos[i - 1].1,
2088 );
2089 lines.push(format!("GROUP BY {}", group_text.trim()));
2090 }
2091 }
2092 _ => i += 1,
2093 }
2094 }
2095
2096 Ok(lines)
2097}
2098
/// Returns the text of `query` between character positions `start`
/// (inclusive) and `end` (exclusive).
///
/// Positions count `char`s, not bytes. Positions beyond the end of `query`
/// are clamped, and an empty or inverted range (`start >= end`) yields an
/// empty string instead of panicking — callers can produce such a range when
/// a clause keyword is immediately followed by another keyword or by the end
/// of input.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2106
/// Splits the text of a WHERE clause into display lines at top-level
/// `AND` / `OR` connectives.
///
/// A connective is recognised (case-insensitively) only when it is
/// surrounded by single spaces, sits outside every pair of parentheses, and
/// is not inside a single- or double-quoted span. Each connective becomes its
/// own upper-case line; the text between connectives is trimmed and copied
/// verbatim. A stray `)` never drives the nesting depth below zero, so it
/// cannot disable splitting for the rest of the clause. If nothing at all is
/// produced, the trimmed input is returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    // (pattern as it appears in the text, keyword emitted on its own line)
    const CONNECTIVES: [(&str, &str); 2] = [(" AND ", "AND"), (" OR ", "OR")];

    // True when `haystack` begins with `needle`, ignoring ASCII case.
    fn starts_with_ignore_case(haystack: &[char], needle: &str) -> bool {
        let mut rest = haystack.iter();
        needle.chars().all(|expected| {
            matches!(rest.next(), Some(actual) if actual.eq_ignore_ascii_case(&expected))
        })
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth = 0usize;
    // Quote character of the span currently open, if any.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    'scan: while i < chars.len() {
        if paren_depth == 0 && open_quote.is_none() {
            for &(pattern, keyword) in CONNECTIVES.iter() {
                if starts_with_ignore_case(&chars[i..], pattern) {
                    let pending = current_line.trim();
                    if !pending.is_empty() {
                        lines.push(pending.to_string());
                    }
                    lines.push(keyword.to_string());
                    current_line.clear();
                    // Patterns are ASCII, so byte length equals char count.
                    i += pattern.len();
                    continue 'scan;
                }
            }
        }

        let c = chars[i];
        match open_quote {
            // Inside a quoted span only the matching quote is significant.
            // A doubled quote ('' or "") closes and immediately reopens.
            Some(quote) => {
                if c == quote {
                    open_quote = None;
                }
            }
            None => match c {
                '\'' | '"' => open_quote = Some(c),
                '(' => paren_depth += 1,
                ')' => paren_depth = paren_depth.saturating_sub(1),
                _ => {}
            },
        }
        current_line.push(c);
        i += 1;
    }

    let remainder = current_line.trim();
    if !remainder.is_empty() {
        lines.push(remainder.to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2172
2173pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2174 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2176 return lines;
2177 }
2178
2179 let mut lines = Vec::new();
2181 let mut parser = Parser::new(query);
2182
2183 let cols_per_line = cols_per_line.max(1);
2185
2186 match parser.parse() {
2187 Ok(stmt) => {
2188 if !stmt.columns.is_empty() {
2190 lines.push("SELECT".to_string());
2191
2192 for chunk in stmt.columns.chunks(cols_per_line) {
2194 let mut line = " ".to_string();
2195 for (i, col) in chunk.iter().enumerate() {
2196 if i > 0 {
2197 line.push_str(", ");
2198 }
2199 line.push_str(col);
2200 }
2201 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2203 let current_chunk_idx =
2204 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2205 if current_chunk_idx < last_chunk_idx {
2206 line.push(',');
2207 }
2208 lines.push(line);
2209 }
2210 }
2211
2212 if let Some(table) = &stmt.from_table {
2214 lines.push(format!("FROM {}", table));
2215 }
2216
2217 if let Some(where_clause) = &stmt.where_clause {
2219 lines.push("WHERE".to_string());
2220 for (i, condition) in where_clause.conditions.iter().enumerate() {
2221 if i > 0 {
2222 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2224 if let Some(connector) = &prev_condition.connector {
2225 match connector {
2226 LogicalOp::And => lines.push(" AND".to_string()),
2227 LogicalOp::Or => lines.push(" OR".to_string()),
2228 }
2229 }
2230 }
2231 }
2232 lines.push(format!(" {}", format_expression(&condition.expr)));
2233 }
2234 }
2235
2236 if let Some(order_by) = &stmt.order_by {
2238 let order_str = order_by
2239 .iter()
2240 .map(|col| {
2241 let dir = match col.direction {
2242 SortDirection::Asc => " ASC",
2243 SortDirection::Desc => " DESC",
2244 };
2245 format!("{}{}", col.column, dir)
2246 })
2247 .collect::<Vec<_>>()
2248 .join(", ");
2249 lines.push(format!("ORDER BY {}", order_str));
2250 }
2251
2252 if let Some(group_by) = &stmt.group_by {
2254 let group_str = group_by.join(", ");
2255 lines.push(format!("GROUP BY {}", group_str));
2256 }
2257 }
2258 Err(_) => {
2259 let mut lexer = Lexer::new(query);
2261 let tokens = lexer.tokenize_all();
2262 let mut current_line = String::new();
2263 let mut indent = 0;
2264
2265 for token in tokens {
2266 match &token {
2267 Token::Select
2268 | Token::From
2269 | Token::Where
2270 | Token::OrderBy
2271 | Token::GroupBy => {
2272 if !current_line.is_empty() {
2273 lines.push(current_line.trim().to_string());
2274 current_line.clear();
2275 }
2276 lines.push(format!("{:?}", token).to_uppercase());
2277 indent = 1;
2278 }
2279 Token::And | Token::Or => {
2280 if !current_line.is_empty() {
2281 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2282 current_line.clear();
2283 }
2284 lines.push(format!(" {:?}", token).to_uppercase());
2285 }
2286 Token::Comma => {
2287 current_line.push(',');
2288 if indent > 0 {
2289 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2290 current_line.clear();
2291 }
2292 }
2293 Token::Eof => break,
2294 _ => {
2295 if !current_line.is_empty() {
2296 current_line.push(' ');
2297 }
2298 current_line.push_str(&format_token(&token));
2299 }
2300 }
2301 }
2302
2303 if !current_line.is_empty() {
2304 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2305 }
2306 }
2307 }
2308
2309 lines
2310}
2311
2312fn format_expression(expr: &SqlExpression) -> String {
2313 match expr {
2314 SqlExpression::Column(name) => name.clone(),
2315 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2316 SqlExpression::NumberLiteral(n) => n.clone(),
2317 SqlExpression::BooleanLiteral(b) => b.to_string(),
2318 SqlExpression::DateTimeConstructor {
2319 year,
2320 month,
2321 day,
2322 hour,
2323 minute,
2324 second,
2325 } => {
2326 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2327 if let Some(h) = hour {
2328 result.push_str(&format!(", {}", h));
2329 if let Some(m) = minute {
2330 result.push_str(&format!(", {}", m));
2331 if let Some(s) = second {
2332 result.push_str(&format!(", {}", s));
2333 }
2334 }
2335 }
2336 result.push(')');
2337 result
2338 }
2339 SqlExpression::DateTimeToday {
2340 hour,
2341 minute,
2342 second,
2343 } => {
2344 let mut result = "DateTime()".to_string();
2345 if let Some(h) = hour {
2346 result = format!("DateTime(TODAY, {}", h);
2347 if let Some(m) = minute {
2348 result.push_str(&format!(", {}", m));
2349 if let Some(s) = second {
2350 result.push_str(&format!(", {}", s));
2351 }
2352 }
2353 result.push(')');
2354 }
2355 result
2356 }
2357 SqlExpression::MethodCall {
2358 object,
2359 method,
2360 args,
2361 } => {
2362 let args_str = args
2363 .iter()
2364 .map(|arg| format_expression(arg))
2365 .collect::<Vec<_>>()
2366 .join(", ");
2367 format!("{}.{}({})", object, method, args_str)
2368 }
2369 SqlExpression::BinaryOp { left, op, right } => {
2370 if op == "OR" || op == "AND" {
2373 format!(
2376 "({} {} {})",
2377 format_expression(left),
2378 op,
2379 format_expression(right)
2380 )
2381 } else {
2382 format!(
2383 "{} {} {}",
2384 format_expression(left),
2385 op,
2386 format_expression(right)
2387 )
2388 }
2389 }
2390 SqlExpression::InList { expr, values } => {
2391 let values_str = values
2392 .iter()
2393 .map(|v| format_expression(v))
2394 .collect::<Vec<_>>()
2395 .join(", ");
2396 format!("{} IN ({})", format_expression(expr), values_str)
2397 }
2398 SqlExpression::NotInList { expr, values } => {
2399 let values_str = values
2400 .iter()
2401 .map(|v| format_expression(v))
2402 .collect::<Vec<_>>()
2403 .join(", ");
2404 format!("{} NOT IN ({})", format_expression(expr), values_str)
2405 }
2406 SqlExpression::Between { expr, lower, upper } => {
2407 format!(
2408 "{} BETWEEN {} AND {}",
2409 format_expression(expr),
2410 format_expression(lower),
2411 format_expression(upper)
2412 )
2413 }
2414 SqlExpression::Not { expr } => {
2415 format!("NOT {}", format_expression(expr))
2416 }
2417 SqlExpression::ChainedMethodCall { base, method, args } => {
2418 let args_str = args
2419 .iter()
2420 .map(|arg| format_expression(arg))
2421 .collect::<Vec<_>>()
2422 .join(", ");
2423 format!("{}.{}({})", format_expression(base), method, args_str)
2424 }
2425 SqlExpression::FunctionCall { name, args } => {
2426 let args_str = args
2427 .iter()
2428 .map(|arg| format_expression(arg))
2429 .collect::<Vec<_>>()
2430 .join(", ");
2431 format!("{}({})", name, args_str)
2432 }
2433 SqlExpression::CaseExpression {
2434 when_branches,
2435 else_branch,
2436 } => {
2437 let mut result = String::from("CASE");
2438 for branch in when_branches {
2439 result.push_str(&format!(
2440 " WHEN {} THEN {}",
2441 format_expression(&branch.condition),
2442 format_expression(&branch.result)
2443 ));
2444 }
2445 if let Some(else_expr) = else_branch {
2446 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2447 }
2448 result.push_str(" END");
2449 result
2450 }
2451 }
2452}
2453
2454fn format_token(token: &Token) -> String {
2455 match token {
2456 Token::Identifier(s) => s.clone(),
2457 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2458 Token::StringLiteral(s) => format!("'{}'", s),
2459 Token::NumberLiteral(n) => n.clone(),
2460 Token::DateTime => "DateTime".to_string(),
2461 Token::Case => "CASE".to_string(),
2462 Token::When => "WHEN".to_string(),
2463 Token::Then => "THEN".to_string(),
2464 Token::Else => "ELSE".to_string(),
2465 Token::End => "END".to_string(),
2466 Token::Distinct => "DISTINCT".to_string(),
2467 Token::LeftParen => "(".to_string(),
2468 Token::RightParen => ")".to_string(),
2469 Token::Comma => ",".to_string(),
2470 Token::Dot => ".".to_string(),
2471 Token::Equal => "=".to_string(),
2472 Token::NotEqual => "!=".to_string(),
2473 Token::LessThan => "<".to_string(),
2474 Token::GreaterThan => ">".to_string(),
2475 Token::LessThanOrEqual => "<=".to_string(),
2476 Token::GreaterThanOrEqual => ">=".to_string(),
2477 Token::In => "IN".to_string(),
2478 _ => format!("{:?}", token).to_uppercase(),
2479 }
2480}
2481
2482fn analyze_statement(
2483 stmt: &SelectStatement,
2484 query: &str,
2485 _cursor_pos: usize,
2486) -> (CursorContext, Option<String>) {
2487 let trimmed = query.trim();
2489
2490 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2492 for op in &comparison_ops {
2493 if let Some(op_pos) = query.rfind(op) {
2494 let before_op = safe_slice_to(query, op_pos);
2495 let after_op_start = op_pos + op.len();
2496 let after_op = if after_op_start < query.len() {
2497 &query[after_op_start..]
2498 } else {
2499 ""
2500 };
2501
2502 if let Some(col_name) = before_op.split_whitespace().last() {
2504 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2505 let after_op_trimmed = after_op.trim();
2507 if after_op_trimmed.is_empty()
2508 || (after_op_trimmed
2509 .chars()
2510 .all(|c| c.is_alphanumeric() || c == '_')
2511 && !after_op_trimmed.contains('('))
2512 {
2513 let partial = if after_op_trimmed.is_empty() {
2514 None
2515 } else {
2516 Some(after_op_trimmed.to_string())
2517 };
2518 return (
2519 CursorContext::AfterComparisonOp(
2520 col_name.to_string(),
2521 op.trim().to_string(),
2522 ),
2523 partial,
2524 );
2525 }
2526 }
2527 }
2528 }
2529 }
2530
2531 if trimmed.to_uppercase().ends_with(" AND")
2533 || trimmed.to_uppercase().ends_with(" OR")
2534 || trimmed.to_uppercase().ends_with(" AND ")
2535 || trimmed.to_uppercase().ends_with(" OR ")
2536 {
2537 } else {
2539 if let Some(dot_pos) = trimmed.rfind('.') {
2541 let before_dot = safe_slice_to(trimmed, dot_pos);
2543 let after_dot_start = dot_pos + 1;
2544 let after_dot = if after_dot_start < trimmed.len() {
2545 &trimmed[after_dot_start..]
2546 } else {
2547 ""
2548 };
2549
2550 if !after_dot.contains('(') {
2553 let col_name = if before_dot.ends_with('"') {
2555 let bytes = before_dot.as_bytes();
2557 let mut pos = before_dot.len() - 1; let mut found_start = None;
2559
2560 if pos > 0 {
2562 pos -= 1;
2563 while pos > 0 {
2564 if bytes[pos] == b'"' {
2565 if pos == 0 || bytes[pos - 1] != b'\\' {
2567 found_start = Some(pos);
2568 break;
2569 }
2570 }
2571 pos -= 1;
2572 }
2573 if found_start.is_none() && bytes[0] == b'"' {
2575 found_start = Some(0);
2576 }
2577 }
2578
2579 if let Some(start) = found_start {
2580 Some(safe_slice_from(before_dot, start))
2582 } else {
2583 None
2584 }
2585 } else {
2586 before_dot
2589 .split_whitespace()
2590 .last()
2591 .map(|word| word.trim_start_matches('('))
2592 };
2593
2594 if let Some(col_name) = col_name {
2595 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2597 true
2599 } else {
2600 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2602 };
2603
2604 if is_valid {
2605 let partial_method = if after_dot.is_empty() {
2608 None
2609 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2610 Some(after_dot.to_string())
2611 } else {
2612 None
2613 };
2614
2615 let col_name_for_context = if col_name.starts_with('"')
2617 && col_name.ends_with('"')
2618 && col_name.len() > 2
2619 {
2620 col_name[1..col_name.len() - 1].to_string()
2621 } else {
2622 col_name.to_string()
2623 };
2624
2625 return (
2626 CursorContext::AfterColumn(col_name_for_context),
2627 partial_method,
2628 );
2629 }
2630 }
2631 }
2632 }
2633 }
2634
2635 if let Some(where_clause) = &stmt.where_clause {
2637 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2639 let op = if trimmed.to_uppercase().ends_with(" AND") {
2640 LogicalOp::And
2641 } else {
2642 LogicalOp::Or
2643 };
2644 return (CursorContext::AfterLogicalOp(op), None);
2645 }
2646
2647 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2649 let after_and = safe_slice_from(query, and_pos + 5);
2650 let partial = extract_partial_at_end(after_and);
2651 if partial.is_some() {
2652 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2653 }
2654 }
2655
2656 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2657 let after_or = safe_slice_from(query, or_pos + 4);
2658 let partial = extract_partial_at_end(after_or);
2659 if partial.is_some() {
2660 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2661 }
2662 }
2663
2664 if let Some(last_condition) = where_clause.conditions.last() {
2665 if let Some(connector) = &last_condition.connector {
2666 return (
2668 CursorContext::AfterLogicalOp(connector.clone()),
2669 extract_partial_at_end(query),
2670 );
2671 }
2672 }
2673 return (CursorContext::WhereClause, extract_partial_at_end(query));
2675 }
2676
2677 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2679 return (CursorContext::OrderByClause, None);
2680 }
2681
2682 if stmt.order_by.is_some() {
2684 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2685 }
2686
2687 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2688 return (CursorContext::FromClause, extract_partial_at_end(query));
2689 }
2690
2691 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2692 return (CursorContext::SelectClause, extract_partial_at_end(query));
2693 }
2694
2695 (CursorContext::Unknown, None)
2696}
2697
2698fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2699 let upper = query.to_uppercase();
2700
2701 let trimmed = query.trim();
2703
2704 #[cfg(test)]
2705 {
2706 if trimmed.contains("\"Last Name\"") {
2707 eprintln!(
2708 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2709 query, trimmed
2710 );
2711 }
2712 }
2713
2714 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2716 for op in &comparison_ops {
2717 if let Some(op_pos) = query.rfind(op) {
2718 let before_op = safe_slice_to(query, op_pos);
2719 let after_op_start = op_pos + op.len();
2720 let after_op = if after_op_start < query.len() {
2721 &query[after_op_start..]
2722 } else {
2723 ""
2724 };
2725
2726 if let Some(col_name) = before_op.split_whitespace().last() {
2728 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2729 let after_op_trimmed = after_op.trim();
2731 if after_op_trimmed.is_empty()
2732 || (after_op_trimmed
2733 .chars()
2734 .all(|c| c.is_alphanumeric() || c == '_')
2735 && !after_op_trimmed.contains('('))
2736 {
2737 let partial = if after_op_trimmed.is_empty() {
2738 None
2739 } else {
2740 Some(after_op_trimmed.to_string())
2741 };
2742 return (
2743 CursorContext::AfterComparisonOp(
2744 col_name.to_string(),
2745 op.trim().to_string(),
2746 ),
2747 partial,
2748 );
2749 }
2750 }
2751 }
2752 }
2753 }
2754
2755 if let Some(dot_pos) = trimmed.rfind('.') {
2758 #[cfg(test)]
2759 {
2760 if trimmed.contains("\"Last Name\"") {
2761 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2762 }
2763 }
2764 let before_dot = &trimmed[..dot_pos];
2766 let after_dot = &trimmed[dot_pos + 1..];
2767
2768 if !after_dot.contains('(') {
2771 let col_name = if before_dot.ends_with('"') {
2774 let bytes = before_dot.as_bytes();
2776 let mut pos = before_dot.len() - 1; let mut found_start = None;
2778
2779 #[cfg(test)]
2780 {
2781 if trimmed.contains("\"Last Name\"") {
2782 eprintln!(
2783 "DEBUG: before_dot='{}', looking for opening quote",
2784 before_dot
2785 );
2786 }
2787 }
2788
2789 if pos > 0 {
2791 pos -= 1;
2792 while pos > 0 {
2793 if bytes[pos] == b'"' {
2794 if pos == 0 || bytes[pos - 1] != b'\\' {
2796 found_start = Some(pos);
2797 break;
2798 }
2799 }
2800 pos -= 1;
2801 }
2802 if found_start.is_none() && bytes[0] == b'"' {
2804 found_start = Some(0);
2805 }
2806 }
2807
2808 if let Some(start) = found_start {
2809 let result = safe_slice_from(before_dot, start);
2811 #[cfg(test)]
2812 {
2813 if trimmed.contains("\"Last Name\"") {
2814 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2815 }
2816 }
2817 Some(result)
2818 } else {
2819 #[cfg(test)]
2820 {
2821 if trimmed.contains("\"Last Name\"") {
2822 eprintln!("DEBUG: No opening quote found!");
2823 }
2824 }
2825 None
2826 }
2827 } else {
2828 before_dot
2831 .split_whitespace()
2832 .last()
2833 .map(|word| word.trim_start_matches('('))
2834 };
2835
2836 if let Some(col_name) = col_name {
2837 #[cfg(test)]
2838 {
2839 if trimmed.contains("\"Last Name\"") {
2840 eprintln!("DEBUG: col_name = '{}'", col_name);
2841 }
2842 }
2843
2844 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2846 true
2848 } else {
2849 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2851 };
2852
2853 #[cfg(test)]
2854 {
2855 if trimmed.contains("\"Last Name\"") {
2856 eprintln!("DEBUG: is_valid = {}", is_valid);
2857 }
2858 }
2859
2860 if is_valid {
2861 let partial_method = if after_dot.is_empty() {
2864 None
2865 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2866 Some(after_dot.to_string())
2867 } else {
2868 None
2869 };
2870
2871 let col_name_for_context = if col_name.starts_with('"')
2873 && col_name.ends_with('"')
2874 && col_name.len() > 2
2875 {
2876 col_name[1..col_name.len() - 1].to_string()
2877 } else {
2878 col_name.to_string()
2879 };
2880
2881 return (
2882 CursorContext::AfterColumn(col_name_for_context),
2883 partial_method,
2884 );
2885 }
2886 }
2887 }
2888 }
2889
2890 if let Some(and_pos) = upper.rfind(" AND ") {
2892 if cursor_pos >= and_pos + 5 {
2894 let after_and = safe_slice_from(query, and_pos + 5);
2896 let partial = extract_partial_at_end(after_and);
2897 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2898 }
2899 }
2900
2901 if let Some(or_pos) = upper.rfind(" OR ") {
2902 if cursor_pos >= or_pos + 4 {
2904 let after_or = safe_slice_from(query, or_pos + 4);
2906 let partial = extract_partial_at_end(after_or);
2907 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2908 }
2909 }
2910
2911 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2913 let op = if trimmed.to_uppercase().ends_with(" AND") {
2914 LogicalOp::And
2915 } else {
2916 LogicalOp::Or
2917 };
2918 return (CursorContext::AfterLogicalOp(op), None);
2919 }
2920
2921 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2923 {
2924 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2925 }
2926
2927 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2928 return (CursorContext::WhereClause, extract_partial_at_end(query));
2929 }
2930
2931 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2932 return (CursorContext::FromClause, extract_partial_at_end(query));
2933 }
2934
2935 if upper.contains("SELECT") && !upper.contains("FROM") {
2936 return (CursorContext::SelectClause, extract_partial_at_end(query));
2937 }
2938
2939 (CursorContext::Unknown, None)
2940}
2941
2942fn extract_partial_at_end(query: &str) -> Option<String> {
2943 let trimmed = query.trim();
2944
2945 if let Some(last_word) = trimmed.split_whitespace().last() {
2947 if last_word.starts_with('"') && !last_word.ends_with('"') {
2948 return Some(last_word.to_string());
2950 }
2951 }
2952
2953 let last_word = trimmed.split_whitespace().last()?;
2955
2956 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2958 Some(last_word.to_string())
2959 } else {
2960 None
2961 }
2962}
2963
/// Reports whether `word` is one of the SQL keywords that must never be
/// offered as a partial identifier by [`extract_partial_at_end`].
///
/// The comparison ignores ASCII case only. SQL keywords are pure ASCII, so
/// this avoids allocating an upper-cased copy of `word` on every call and
/// prevents non-ASCII words whose Unicode upper-case form happens to spell a
/// keyword (e.g. `"ın"`, with a dotless i, upper-cases to `"IN"`) from being
/// misclassified.
///
/// NOTE(review): the lexer also tokenizes NOT, BETWEEN, LIKE, IS, NULL, AS,
/// LIMIT, OFFSET, CASE, WHEN, THEN, ELSE and END, none of which are listed
/// here — confirm whether that omission is intentional.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    KEYWORDS
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(word))
}
2982
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts every `(` and `)` in `text`.
    ///
    /// The formatter tests compare this count before and after formatting to
    /// check that no parenthesis is dropped or invented.
    fn count_parens(text: &str) -> usize {
        text.chars().filter(|c| matches!(c, '(' | ')')).count()
    }

    // ---------------------------------------------------------------------
    // Lexer and basic parser behaviour
    // ---------------------------------------------------------------------

    /// Chained method calls in a WHERE clause must parse without error.
    #[test]
    fn test_chained_method_calls() {
        let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
        let mut parser = Parser::new(query);
        let result = parser.parse();

        assert!(
            result.is_ok(),
            "Failed to parse chained method calls: {:?}",
            result
        );

        let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
        let mut parser2 = Parser::new(query2);
        let result2 = parser2.parse();

        assert!(
            result2.is_ok(),
            "Failed to parse multiple chained calls: {:?}",
            result2
        );
    }

    /// A simple SELECT is tokenized into the expected keyword, identifier,
    /// operator and number tokens, in order.
    #[test]
    fn test_tokenizer() {
        let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::Star));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
    }

    /// `DateTime(...)` is tokenized as a dedicated `DateTime` token followed
    /// by a parenthesised, comma-separated list of number literals.
    #[test]
    fn test_tokenizer_datetime() {
        let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::DateTime));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// `SELECT * FROM t` yields a single `*` column, the table name and no
    /// WHERE clause.
    #[test]
    fn test_parse_simple_select() {
        let mut parser = Parser::new("SELECT * FROM trade_deal");
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["*"]);
        assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
        assert!(stmt.where_clause.is_none());
    }

    /// A method call used as the whole WHERE predicate produces one condition.
    #[test]
    fn test_parse_where_with_method() {
        let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);
    }

    /// A three-argument `DateTime(y, m, d)` on the right of `>` parses into a
    /// `DateTimeConstructor` with no time-of-day components.
    #[test]
    fn test_parse_datetime_constructor() {
        let mut parser =
            Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::DateTimeConstructor {
                    year: 2025,
                    month: 10,
                    day: 20,
                    hour: None,
                    minute: None,
                    second: None
                }
            ));
        } else {
            panic!("Expected BinaryOp with DateTime constructor");
        }
    }

    // ---------------------------------------------------------------------
    // Cursor-context detection
    // ---------------------------------------------------------------------

    /// Cursor right after `AND ` reports the AND context with no partial.
    #[test]
    fn test_cursor_context_after_and() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, None);
    }

    /// A word typed after `AND ` is reported as the partial.
    #[test]
    fn test_cursor_context_with_partial() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, Some("p".to_string()));
    }

    /// Cursor after `column > ` reports the column and operator, no partial.
    #[test]
    fn test_cursor_context_after_datetime_comparison() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, None);
    }

    /// Text typed after a comparison operator is reported as the partial.
    #[test]
    fn test_cursor_context_partial_datetime() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, Some("Date".to_string()));
    }

    // ---------------------------------------------------------------------
    // Quoted identifiers
    // ---------------------------------------------------------------------

    /// Double-quoted names in a select list become `QuotedIdentifier` tokens
    /// with the quotes stripped.
    #[test]
    fn test_tokenizer_quoted_identifier() {
        let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
    }

    /// Double quotes give a quoted identifier while single quotes give a
    /// string literal, within the same statement.
    #[test]
    fn test_tokenizer_quoted_vs_string_literal() {
        let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Equal));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
        assert!(matches!(lexer.next_token(), Token::And));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// A double-quoted method argument must reach the parser as a token that
    /// carries the argument text.
    #[test]
    fn test_tokenizer_method_with_double_quotes_should_be_string() {
        let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);

        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));

        let token = lexer.next_token();
        println!("Token for \"Alb\": {:?}", token);
        // The lexer may classify the argument either way; what matters here is
        // that the text survives so the parser can treat it as a string.
        assert!(
            matches!(
                &token,
                Token::StringLiteral(s) | Token::QuotedIdentifier(s) if s == "Alb"
            ),
            "Expected a string-bearing token for \"Alb\", got {:?}",
            token
        );
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// Quoted and bare column names are both returned without quotes.
    #[test]
    fn test_parse_select_with_quoted_columns() {
        let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
        assert_eq!(stmt.from_table, Some("customers".to_string()));
    }

    /// An unterminated quoted identifier right after SELECT is still detected
    /// as the select-clause context.
    #[test]
    fn test_cursor_context_select_with_partial_quoted() {
        let query = r#"SELECT "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        assert!(
            matches!(context, CursorContext::SelectClause),
            "Context: {:?}, Partial: {:?}",
            context,
            partial
        );
    }

    /// An unterminated quoted identifier after a comma in the select list is
    /// still detected as the select-clause context.
    #[test]
    fn test_cursor_context_select_after_comma_with_quoted() {
        let query = r#"SELECT Company, "Customer "#;
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::SelectClause),
            "Context: {:?}, Partial: {:?}",
            context,
            partial
        );
    }

    /// An unterminated quoted identifier after ORDER BY is detected as the
    /// order-by context.
    #[test]
    fn test_cursor_context_order_by_quoted() {
        let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        assert!(
            matches!(context, CursorContext::OrderByClause),
            "Context: {:?}, Partial: {:?}",
            context,
            partial
        );
    }

    /// A quoted column on the left of `=` parses to a `Column` with the
    /// unquoted name.
    #[test]
    fn test_where_clause_with_quoted_column() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "=");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
            assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    /// A double-quoted method argument is parsed as a string literal.
    #[test]
    fn test_parse_method_with_double_quotes_as_string() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::MethodCall {
            object,
            method,
            args,
        } = &where_clause.conditions[0].expr
        {
            assert_eq!(object, "Country");
            assert_eq!(method, "Contains");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
        } else {
            panic!("Expected MethodCall");
        }
    }

    /// A closed quoted column earlier in the query must not leak into the
    /// partial extracted at the end.
    #[test]
    fn test_extract_partial_with_quoted_columns_in_query() {
        let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(context, CursorContext::OrderByClause));
        assert_eq!(
            partial,
            Some("coun".to_string()),
            "Should extract 'coun' as partial, not everything after the quoted column"
        );
    }

    /// An open quoted identifier is returned with its leading quote; a
    /// trailing keyword yields no partial.
    #[test]
    fn test_extract_partial_quoted_identifier_being_typed() {
        let query = r#"SELECT "Cust"#;
        let partial = extract_partial_at_end(query);
        assert_eq!(partial, Some("\"Cust".to_string()));

        let query2 = r#"SELECT "Customer Id" FROM"#;
        let partial2 = extract_partial_at_end(query2);
        assert_eq!(partial2, None);
    }

    // ---------------------------------------------------------------------
    // Parenthesised WHERE clauses
    // ---------------------------------------------------------------------

    /// A single parenthesised OR group is one condition whose root is OR.
    #[test]
    fn test_complex_where_parentheses_basic() {
        let mut parser =
            Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected BinaryOp with OR");
        }
    }

    /// `(a OR b) AND c` yields two conditions joined by AND, the first being
    /// the OR group.
    #[test]
    fn test_complex_where_mixed_and_or_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected first condition to be OR expression");
        }

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
            assert_eq!(op, ">");
        } else {
            panic!("Expected second condition to be price > 100");
        }
    }

    /// Doubly nested parentheses parse and produce at least one condition.
    #[test]
    fn test_complex_where_nested_parentheses() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// Two parenthesised groups joined by AND yield two conditions.
    #[test]
    fn test_complex_where_multiple_or_groups() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Method calls inside a parenthesised OR group are kept as method-call
    /// operands of the OR.
    #[test]
    fn test_complex_where_with_methods_in_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
            assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
            assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
        } else {
            panic!("Expected OR of method calls");
        }
    }

    /// A parenthesised date range ANDed with a parenthesised status group
    /// yields two conditions joined by AND.
    #[test]
    fn test_complex_where_date_comparisons_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Nested price ranges combined with a volume filter parse and produce at
    /// least one condition.
    #[test]
    fn test_complex_where_price_volume_filters() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// String comparisons, numeric comparisons and method calls can be mixed
    /// across parenthesised groups.
    #[test]
    fn test_complex_where_mixed_string_numeric() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// Three levels of grouping parse successfully.
    #[test]
    fn test_complex_where_triple_nested() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// A single pair of parentheses around an AND chain parses and produces
    /// at least one condition.
    #[test]
    fn test_complex_where_single_parens_around_and() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    // ---------------------------------------------------------------------
    // Pretty-printer: parentheses must survive formatting
    // ---------------------------------------------------------------------

    /// A single parenthesised group keeps both of its parentheses.
    #[test]
    fn test_format_preserves_simple_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(status"));
        assert!(formatted_text.contains("\"pending\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(
            original_parens, formatted_parens,
            "Parentheses should be preserved"
        );
    }

    /// A group followed by a further AND condition keeps its parentheses.
    #[test]
    fn test_format_preserves_complex_parentheses() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("\"GOOGL\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
    }

    /// Nested groups keep all four of their parentheses.
    #[test]
    fn test_format_preserves_nested_parentheses() {
        let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(
            original_parens, formatted_parens,
            "Nested parentheses should be preserved"
        );
        assert_eq!(original_parens, 4, "Should have 4 parentheses total");
    }

    /// Method-call parentheses inside a group are kept alongside the group's
    /// own parentheses.
    #[test]
    fn test_format_preserves_method_calls_in_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol.StartsWith"));
        assert!(formatted_text.contains("StartsWith(\"A\")"));
        assert!(formatted_text.contains("StartsWith(\"G\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(
            original_parens, 6,
            "Should have 6 parentheses (1 group + 2 method calls)"
        );
    }

    /// Two sibling groups each keep their parentheses.
    #[test]
    fn test_format_preserves_multiple_groups() {
        let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("(price"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
    }

    /// `DateTime(...)` constructors inside a group are emitted unchanged.
    #[test]
    fn test_format_preserves_date_ranges() {
        let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(executionDate"));
        assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
        assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
    }

    /// The formatter puts SELECT, the column list, FROM and WHERE on their
    /// own lines, followed by the WHERE body with AND on a line by itself.
    #[test]
    fn test_format_multiline_layout() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);

        assert!(formatted.len() >= 4, "Should have multiple lines");
        assert_eq!(formatted[0], "SELECT");
        assert!(formatted[1].trim().starts_with("*"));
        assert!(formatted[2].starts_with("FROM"));
        assert_eq!(formatted[3], "WHERE");

        // Everything after the four header lines belongs to the WHERE body.
        let where_lines: Vec<_> = formatted.iter().skip(4).collect();
        assert!(where_lines.iter().any(|l| l.contains("(symbol")));
        assert!(where_lines.iter().any(|l| l.trim() == "AND"));
    }

    // ---------------------------------------------------------------------
    // BETWEEN
    // ---------------------------------------------------------------------

    /// A bare BETWEEN is one condition and renders without a parse error in
    /// the AST dump.
    #[test]
    fn test_between_simple() {
        let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        let stmt = parser.parse().expect("Should parse simple BETWEEN");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        assert!(!ast.contains("PARSE ERROR"));
        assert!(ast.contains("SelectStatement"));
    }

    /// BETWEEN wrapped in parentheses parses and renders without error.
    #[test]
    fn test_between_in_parentheses() {
        let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");

        assert!(stmt.where_clause.is_some());

        let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
    }

    /// A parenthesised BETWEEN can be ORed with another group.
    #[test]
    fn test_between_with_or() {
        let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with OR");

        assert!(stmt.where_clause.is_some());
    }

    /// The AND inside BETWEEN is not mistaken for a logical connector: the
    /// clause has exactly two conditions.
    #[test]
    fn test_between_with_and() {
        let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with AND");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);
    }

    /// Two parenthesised BETWEEN clauses joined by AND parse successfully.
    #[test]
    fn test_multiple_between() {
        let query =
            "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse multiple BETWEEN clauses");

        assert!(stmt.where_clause.is_some());
    }

    /// BETWEEN, a method call and a two-column ORDER BY coexist in one query,
    /// and the ORDER BY columns and directions are reported in order.
    #[test]
    fn test_between_complex_query() {
        let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

        assert!(stmt.where_clause.is_some());
        assert!(stmt.order_by.is_some());

        let order_by = stmt.order_by.unwrap();
        assert_eq!(order_by.len(), 2);
        assert_eq!(order_by[0].column, "Category");
        assert!(matches!(order_by[0].direction, SortDirection::Asc));
        assert_eq!(order_by[1].column, "price");
        assert!(matches!(order_by[1].direction, SortDirection::Desc));
    }

    /// A `Between` expression formats back to SQL text and appears, with both
    /// bounds, in the AST rendering.
    #[test]
    fn test_between_formatting() {
        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column("price".to_string())),
            lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
        };

        let formatted = format_expression(&expr);
        assert_eq!(formatted, "price BETWEEN 50 AND 100");

        let ast_formatted = format_expression_ast(&expr);
        assert!(ast_formatted.contains("Between"));
        assert!(ast_formatted.contains("50"));
        assert!(ast_formatted.contains("100"));
    }

    // ---------------------------------------------------------------------
    // UTF-8 safety
    // ---------------------------------------------------------------------

    /// `detect_cursor_context` must not panic for any byte offset, including
    /// offsets inside a multi-byte character and offsets past the end.
    #[test]
    fn test_utf8_boundary_safety() {
        let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

        // Every byte offset from 0 through len (inclusive), whether or not it
        // falls on a character boundary.
        for pos in 0..=query_with_unicode.len() {
            let result =
                std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));

            assert!(
                result.is_ok(),
                "Panic at position {} in query with Unicode",
                pos
            );
        }

        let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
        assert!(result.is_ok(), "Panic with position beyond string length");

        // 'é' is two bytes in UTF-8, so +1 lands in the middle of it.
        let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
        let result =
            std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
        assert!(
            result.is_ok(),
            "Panic with cursor in middle of UTF-8 character"
        );
    }
}