1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer`].
///
/// Keywords are recognised case-insensitively (ASCII case folding only), and
/// `ORDER BY` / `GROUP BY` are each folded into a single token.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords.
    Select,
    From,
    Where,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    OrderBy,
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    Over,
    Partition,
    By,

    /// A bare word that is not a keyword. Also used as the fallback for any
    /// single character the lexer does not otherwise recognise.
    Identifier(String),
    /// The contents of a double-quoted span (quotes removed).
    QuotedIdentifier(String),
    /// The contents of a single-quoted span (quotes removed).
    StringLiteral(String),
    /// The text of a numeric literal, including a leading `-` when negative.
    NumberLiteral(String),
    Star,

    // Punctuation and comparison operators.
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // Arithmetic operators (multiplication is `Star`).
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}

/// Character-level scanner that turns SQL text into [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    input: Vec<char>,
    position: usize,
    current_char: Option<char>,
    /// True when the most recently produced token can end an operand
    /// (identifier, literal, `)`, `NULL`, `END`). A `-` that follows such a
    /// token is the binary minus operator, never the sign of a literal.
    after_operand: bool,
}

impl Lexer {
    /// Creates a lexer positioned at the first character of `input`.
    #[must_use]
    pub fn new(input: &str) -> Self {
        let input: Vec<char> = input.chars().collect();
        let current_char = input.first().copied();
        Self {
            input,
            position: 0,
            current_char,
            after_operand: false,
        }
    }

    /// Moves one character forward; `current_char` becomes `None` at the end.
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character `offset` positions ahead without consuming it.
    fn peek(&self, offset: usize) -> Option<char> {
        self.input.get(self.position + offset).copied()
    }

    /// Consumes the current character and returns `token`.
    fn take(&mut self, token: Token) -> Token {
        self.advance();
        token
    }

    /// Skips a run of whitespace characters.
    fn skip_whitespace(&mut self) {
        while self.current_char.is_some_and(char::is_whitespace) {
            self.advance();
        }
    }

    /// Skips whitespace, `-- line` comments and `/* block */` comments.
    ///
    /// An unterminated block comment swallows the rest of the input.
    fn skip_whitespace_and_comments(&mut self) {
        loop {
            self.skip_whitespace();

            match (self.current_char, self.peek(1)) {
                (Some('-'), Some('-')) => {
                    self.advance();
                    self.advance();
                    // Consume up to and including the end of the line.
                    while let Some(ch) = self.current_char {
                        self.advance();
                        if ch == '\n' {
                            break;
                        }
                    }
                }
                (Some('/'), Some('*')) => {
                    self.advance();
                    self.advance();
                    while let Some(ch) = self.current_char {
                        if ch == '*' && self.peek(1) == Some('/') {
                            self.advance();
                            self.advance();
                            break;
                        }
                        self.advance();
                    }
                }
                _ => break,
            }
        }
    }

    /// Reads a run of alphanumeric / underscore characters.
    fn read_identifier(&mut self) -> String {
        let mut result = String::new();
        while let Some(ch) = self.current_char {
            if !(ch.is_alphanumeric() || ch == '_') {
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads a quoted span whose opening quote is the current character.
    ///
    /// A doubled quote character inside the span (`''` or `""`) is the SQL
    /// escape for one literal quote. An unterminated span runs to the end of
    /// the input; the `Token` type has no way to report that as an error.
    fn read_string(&mut self) -> String {
        let mut result = String::new();
        let Some(quote_char) = self.current_char else {
            return result;
        };
        self.advance();

        while let Some(ch) = self.current_char {
            if ch == quote_char {
                if self.peek(1) == Some(quote_char) {
                    // Escaped quote: keep one, skip both.
                    result.push(quote_char);
                    self.advance();
                    self.advance();
                    continue;
                }
                self.advance();
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads an unsigned numeric literal: ASCII digits and dots, optionally
    /// followed by one exponent (`e`/`E`, optional sign, digits).
    fn read_number(&mut self) -> String {
        let mut result = String::new();

        while let Some(ch) = self.current_char {
            if ch.is_ascii_digit() || ch == '.' {
                result.push(ch);
                self.advance();
            } else if matches!(ch, 'e' | 'E') && !result.is_empty() {
                result.push(ch);
                self.advance();

                if let Some(sign @ ('+' | '-')) = self.current_char {
                    result.push(sign);
                    self.advance();
                }
                while let Some(digit) = self.current_char {
                    if !digit.is_ascii_digit() {
                        break;
                    }
                    result.push(digit);
                    self.advance();
                }
                // The exponent always ends the literal.
                break;
            } else {
                break;
            }
        }
        result
    }

    /// Returns the next token, skipping leading whitespace and comments.
    ///
    /// Returns [`Token::Eof`] once the input is exhausted (and on every call
    /// after that).
    pub fn next_token(&mut self) -> Token {
        let token = self.scan_token();
        self.after_operand = Self::ends_operand(&token);
        token
    }

    /// Whether `token` can be the last token of an operand, which decides if a
    /// following `-` is a binary operator.
    fn ends_operand(token: &Token) -> bool {
        matches!(
            token,
            Token::Identifier(_)
                | Token::QuotedIdentifier(_)
                | Token::StringLiteral(_)
                | Token::NumberLiteral(_)
                | Token::RightParen
                | Token::Null
                | Token::End
        )
    }

    /// Scans one token without updating the operand-tracking state.
    fn scan_token(&mut self) -> Token {
        // After this call the current character can never start a comment, so
        // `/` and `-` below are always operators or literal signs.
        self.skip_whitespace_and_comments();

        match self.current_char {
            None => Token::Eof,
            Some('*') => self.take(Token::Star),
            Some('+') => self.take(Token::Plus),
            Some('/') => self.take(Token::Divide),
            Some('.') => self.take(Token::Dot),
            Some(',') => self.take(Token::Comma),
            Some('(') => self.take(Token::LeftParen),
            Some(')') => self.take(Token::RightParen),
            Some('=') => self.take(Token::Equal),
            Some('<') => {
                self.advance();
                match self.current_char {
                    Some('=') => self.take(Token::LessThanOrEqual),
                    Some('>') => self.take(Token::NotEqual),
                    _ => Token::LessThan,
                }
            }
            Some('>') => {
                self.advance();
                if self.current_char == Some('=') {
                    self.take(Token::GreaterThanOrEqual)
                } else {
                    Token::GreaterThan
                }
            }
            Some('!') if self.peek(1) == Some('=') => {
                self.advance();
                self.take(Token::NotEqual)
            }
            Some('"') => Token::QuotedIdentifier(self.read_string()),
            Some('\'') => Token::StringLiteral(self.read_string()),
            Some('-') => {
                // `-5` is a negative literal only where an operand may start;
                // in `a-5` or `(x)-5` the minus is the subtraction operator.
                let negative_literal =
                    !self.after_operand && self.peek(1).is_some_and(|c| c.is_ascii_digit());
                self.advance();
                if negative_literal {
                    let num = self.read_number();
                    Token::NumberLiteral(format!("-{num}"))
                } else {
                    Token::Minus
                }
            }
            Some(ch) if ch.is_ascii_digit() => Token::NumberLiteral(self.read_number()),
            Some(ch) if ch.is_alphabetic() || ch == '_' => {
                let ident = self.read_identifier();
                // ASCII folding only: keywords are ASCII, and Unicode case
                // mapping would let e.g. a dotless `ı` spell `IN`.
                match ident.to_ascii_uppercase().as_str() {
                    "SELECT" => Token::Select,
                    "FROM" => Token::From,
                    "WHERE" => Token::Where,
                    "AND" => Token::And,
                    "OR" => Token::Or,
                    "IN" => Token::In,
                    "NOT" => Token::Not,
                    "BETWEEN" => Token::Between,
                    "LIKE" => Token::Like,
                    "IS" => Token::Is,
                    "NULL" => Token::Null,
                    "ORDER" if self.peek_keyword("BY") => {
                        self.consume_following_word();
                        Token::OrderBy
                    }
                    "GROUP" if self.peek_keyword("BY") => {
                        self.consume_following_word();
                        Token::GroupBy
                    }
                    "HAVING" => Token::Having,
                    "AS" => Token::As,
                    "ASC" => Token::Asc,
                    "DESC" => Token::Desc,
                    "LIMIT" => Token::Limit,
                    "OFFSET" => Token::Offset,
                    "DATETIME" => Token::DateTime,
                    "CASE" => Token::Case,
                    "WHEN" => Token::When,
                    "THEN" => Token::Then,
                    "ELSE" => Token::Else,
                    "END" => Token::End,
                    "DISTINCT" => Token::Distinct,
                    "OVER" => Token::Over,
                    "PARTITION" => Token::Partition,
                    "BY" => Token::By,
                    _ => Token::Identifier(ident),
                }
            }
            // Anything unrecognised is passed through as a one-character identifier.
            Some(ch) => self.take(Token::Identifier(ch.to_string())),
        }
    }

    /// Consumes the word that `peek_keyword` just matched, skipping exactly
    /// what `peek_keyword` skipped (whitespace *and* comments) so the two
    /// stay in agreement.
    fn consume_following_word(&mut self) {
        self.skip_whitespace_and_comments();
        let _ = self.read_identifier();
    }

    /// Reports whether the next word (after whitespace/comments) equals
    /// `keyword`, ignoring ASCII case. The lexer position is left unchanged.
    fn peek_keyword(&mut self, keyword: &str) -> bool {
        let saved_pos = self.position;
        let saved_char = self.current_char;

        self.skip_whitespace_and_comments();
        let matches = self.read_identifier().eq_ignore_ascii_case(keyword);

        self.position = saved_pos;
        self.current_char = saved_char;
        matches
    }

    /// Current offset into the input, counted in characters (not bytes).
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.position
    }

    /// Tokenizes the remaining input; the final element is always [`Token::Eof`].
    pub fn tokenize_all(&mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let done = matches!(token, Token::Eof);
            tokens.push(token);
            if done {
                break;
            }
        }
        tokens
    }

    /// Tokenizes the remaining input as `(start, end, token)` triples, where
    /// the offsets are character positions and `end` is exclusive.
    ///
    /// Unlike [`Lexer::tokenize_all`], the trailing `Eof` is not included.
    pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
        let mut tokens = Vec::new();
        loop {
            // Skip first so `start` points at the token itself.
            self.skip_whitespace_and_comments();
            let start = self.position;
            let token = self.next_token();
            let end = self.position;

            if matches!(token, Token::Eof) {
                break;
            }
            tokens.push((start, end, token));
        }
        tokens
    }
}
417
418#[derive(Debug, Clone)]
420pub enum SqlExpression {
421 Column(String),
422 StringLiteral(String),
423 NumberLiteral(String),
424 BooleanLiteral(bool),
425 DateTimeConstructor {
426 year: i32,
427 month: u32,
428 day: u32,
429 hour: Option<u32>,
430 minute: Option<u32>,
431 second: Option<u32>,
432 },
433 DateTimeToday {
434 hour: Option<u32>,
435 minute: Option<u32>,
436 second: Option<u32>,
437 },
438 MethodCall {
439 object: String,
440 method: String,
441 args: Vec<SqlExpression>,
442 },
443 ChainedMethodCall {
444 base: Box<SqlExpression>,
445 method: String,
446 args: Vec<SqlExpression>,
447 },
448 FunctionCall {
449 name: String,
450 args: Vec<SqlExpression>,
451 },
452 WindowFunction {
453 name: String,
454 args: Vec<SqlExpression>,
455 window_spec: WindowSpec,
456 },
457 BinaryOp {
458 left: Box<SqlExpression>,
459 op: String,
460 right: Box<SqlExpression>,
461 },
462 InList {
463 expr: Box<SqlExpression>,
464 values: Vec<SqlExpression>,
465 },
466 NotInList {
467 expr: Box<SqlExpression>,
468 values: Vec<SqlExpression>,
469 },
470 Between {
471 expr: Box<SqlExpression>,
472 lower: Box<SqlExpression>,
473 upper: Box<SqlExpression>,
474 },
475 Not {
476 expr: Box<SqlExpression>,
477 },
478 CaseExpression {
479 when_branches: Vec<WhenBranch>,
480 else_branch: Option<Box<SqlExpression>>,
481 },
482}
483
484#[derive(Debug, Clone)]
485pub struct WhenBranch {
486 pub condition: Box<SqlExpression>,
487 pub result: Box<SqlExpression>,
488}
489
490#[derive(Debug, Clone)]
491pub struct WhereClause {
492 pub conditions: Vec<Condition>,
493}
494
495#[derive(Debug, Clone)]
496pub struct Condition {
497 pub expr: SqlExpression,
498 pub connector: Option<LogicalOp>, }
500
/// Logical connector between two WHERE conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
506
/// Sort direction of an ORDER BY column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending; also what the parser uses when no direction is written.
    Asc,
    /// Descending.
    Desc,
}
512
513#[derive(Debug, Clone)]
514pub struct OrderByColumn {
515 pub column: String,
516 pub direction: SortDirection,
517}
518
519#[derive(Debug, Clone)]
520pub struct WindowSpec {
521 pub partition_by: Vec<String>,
522 pub order_by: Vec<OrderByColumn>,
523}
524
525#[derive(Debug, Clone)]
527pub enum SelectItem {
528 Column(String),
530 Expression { expr: SqlExpression, alias: String },
532 Star,
534}
535
536#[derive(Debug, Clone)]
537pub struct SelectStatement {
538 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
541 pub where_clause: Option<WhereClause>,
542 pub order_by: Option<Vec<OrderByColumn>>,
543 pub group_by: Option<Vec<String>>,
544 pub having: Option<SqlExpression>, pub limit: Option<usize>,
546 pub offset: Option<usize>,
547}
548
/// Options passed to `Parser::with_config`.
#[derive(Default)]
pub struct ParserConfig {
    /// Case-insensitivity switch; defaults to `false`.
    // NOTE(review): nothing in the visible parser code reads this flag —
    // confirm where it takes effect.
    pub case_insensitive: bool,
}
553
554pub struct Parser {
555 lexer: Lexer,
556 current_token: Token,
557 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
561 config: ParserConfig, }
563
564impl Parser {
565 #[must_use]
566 pub fn new(input: &str) -> Self {
567 let mut lexer = Lexer::new(input);
568 let current_token = lexer.next_token();
569 Self {
570 lexer,
571 current_token,
572 in_method_args: false,
573 columns: Vec::new(),
574 paren_depth: 0,
575 config: ParserConfig::default(),
576 }
577 }
578
579 #[must_use]
580 pub fn with_config(input: &str, config: ParserConfig) -> Self {
581 let mut lexer = Lexer::new(input);
582 let current_token = lexer.next_token();
583 Self {
584 lexer,
585 current_token,
586 in_method_args: false,
587 columns: Vec::new(),
588 paren_depth: 0,
589 config,
590 }
591 }
592
593 #[must_use]
594 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
595 self.columns = columns;
596 self
597 }
598
599 fn consume(&mut self, expected: Token) -> Result<(), String> {
600 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
601 match &expected {
603 Token::LeftParen => self.paren_depth += 1,
604 Token::RightParen => {
605 self.paren_depth -= 1;
606 if self.paren_depth < 0 {
608 return Err(
609 "Unexpected closing parenthesis - no matching opening parenthesis"
610 .to_string(),
611 );
612 }
613 }
614 _ => {}
615 }
616
617 self.current_token = self.lexer.next_token();
618 Ok(())
619 } else {
620 let error_msg = match (&expected, &self.current_token) {
622 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
623 format!(
624 "Unclosed parenthesis - missing {} closing parenthes{}",
625 self.paren_depth,
626 if self.paren_depth == 1 { "is" } else { "es" }
627 )
628 }
629 (Token::RightParen, _) if self.paren_depth > 0 => {
630 format!(
631 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
632 self.current_token,
633 self.paren_depth,
634 if self.paren_depth == 1 { "is" } else { "es" }
635 )
636 }
637 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
638 };
639 Err(error_msg)
640 }
641 }
642
643 fn advance(&mut self) {
644 match &self.current_token {
646 Token::LeftParen => self.paren_depth += 1,
647 Token::RightParen => {
648 self.paren_depth -= 1;
649 }
652 _ => {}
653 }
654 self.current_token = self.lexer.next_token();
655 }
656
657 pub fn parse(&mut self) -> Result<SelectStatement, String> {
658 self.parse_select_statement()
659 }
660
661 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
662 self.consume(Token::Select)?;
663
664 let select_items = self.parse_select_items()?;
666
667 let columns = select_items
669 .iter()
670 .map(|item| match item {
671 SelectItem::Star => "*".to_string(),
672 SelectItem::Column(name) => name.clone(),
673 SelectItem::Expression { alias, .. } => alias.clone(),
674 })
675 .collect();
676
677 let from_table = if matches!(self.current_token, Token::From) {
678 self.advance();
679 match &self.current_token {
680 Token::Identifier(table) => {
681 let table_name = table.clone();
682 self.advance();
683 Some(table_name)
684 }
685 Token::QuotedIdentifier(table) => {
686 let table_name = table.clone();
688 self.advance();
689 Some(table_name)
690 }
691 _ => return Err("Expected table name after FROM".to_string()),
692 }
693 } else {
694 None
695 };
696
697 let where_clause = if matches!(self.current_token, Token::Where) {
698 self.advance();
699 Some(self.parse_where_clause()?)
700 } else {
701 None
702 };
703
704 let order_by = if matches!(self.current_token, Token::OrderBy) {
705 self.advance();
706 Some(self.parse_order_by_list()?)
707 } else {
708 None
709 };
710
711 let group_by = if matches!(self.current_token, Token::GroupBy) {
712 self.advance();
713 Some(self.parse_identifier_list()?)
714 } else {
715 None
716 };
717
718 let having = if matches!(self.current_token, Token::Having) {
720 if group_by.is_none() {
721 return Err("HAVING clause requires GROUP BY".to_string());
722 }
723 self.advance();
724 Some(self.parse_expression()?)
725 } else {
726 None
727 };
728
729 let limit = if matches!(self.current_token, Token::Limit) {
731 self.advance();
732 match &self.current_token {
733 Token::NumberLiteral(num) => {
734 let limit_val = num
735 .parse::<usize>()
736 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
737 self.advance();
738 Some(limit_val)
739 }
740 _ => return Err("Expected number after LIMIT".to_string()),
741 }
742 } else {
743 None
744 };
745
746 let offset = if matches!(self.current_token, Token::Offset) {
748 self.advance();
749 match &self.current_token {
750 Token::NumberLiteral(num) => {
751 let offset_val = num
752 .parse::<usize>()
753 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
754 self.advance();
755 Some(offset_val)
756 }
757 _ => return Err("Expected number after OFFSET".to_string()),
758 }
759 } else {
760 None
761 };
762
763 if self.paren_depth > 0 {
765 return Err(format!(
766 "Unclosed parenthesis - missing {} closing parenthes{}",
767 self.paren_depth,
768 if self.paren_depth == 1 { "is" } else { "es" }
769 ));
770 } else if self.paren_depth < 0 {
771 return Err(
772 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
773 );
774 }
775
776 Ok(SelectStatement {
777 columns,
778 select_items,
779 from_table,
780 where_clause,
781 order_by,
782 group_by,
783 having,
784 limit,
785 offset,
786 })
787 }
788
789 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
790 let mut columns = Vec::new();
791
792 if matches!(self.current_token, Token::Star) {
793 columns.push("*".to_string());
794 self.advance();
795 } else {
796 loop {
797 match &self.current_token {
798 Token::Identifier(col) => {
799 columns.push(col.clone());
800 self.advance();
801 }
802 Token::QuotedIdentifier(col) => {
803 columns.push(col.clone());
805 self.advance();
806 }
807 _ => return Err("Expected column name".to_string()),
808 }
809
810 if matches!(self.current_token, Token::Comma) {
811 self.advance();
812 } else {
813 break;
814 }
815 }
816 }
817
818 Ok(columns)
819 }
820
821 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
823 let mut items = Vec::new();
824
825 loop {
826 if matches!(self.current_token, Token::Star) {
829 items.push(SelectItem::Star);
837 self.advance();
838 } else {
839 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
844 self.advance();
845 match &self.current_token {
846 Token::Identifier(alias_name) => {
847 let alias = alias_name.clone();
848 self.advance();
849 alias
850 }
851 Token::QuotedIdentifier(alias_name) => {
852 let alias = alias_name.clone();
853 self.advance();
854 alias
855 }
856 _ => return Err("Expected alias name after AS".to_string()),
857 }
858 } else {
859 match &expr {
861 SqlExpression::Column(col_name) => col_name.clone(),
862 _ => format!("expr_{}", items.len() + 1), }
864 };
865
866 let item = match expr {
868 SqlExpression::Column(col_name) if alias == col_name => {
869 SelectItem::Column(col_name)
871 }
872 _ => {
873 SelectItem::Expression { expr, alias }
875 }
876 };
877
878 items.push(item);
879 }
880
881 if matches!(self.current_token, Token::Comma) {
883 self.advance();
884 } else {
885 break;
886 }
887 }
888
889 Ok(items)
890 }
891
892 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
893 let mut identifiers = Vec::new();
894
895 loop {
896 match &self.current_token {
897 Token::Identifier(id) => {
898 identifiers.push(id.clone());
899 self.advance();
900 }
901 Token::QuotedIdentifier(id) => {
902 identifiers.push(id.clone());
904 self.advance();
905 }
906 _ => return Err("Expected identifier".to_string()),
907 }
908
909 if matches!(self.current_token, Token::Comma) {
910 self.advance();
911 } else {
912 break;
913 }
914 }
915
916 Ok(identifiers)
917 }
918
919 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
920 let mut partition_by = Vec::new();
921 let mut order_by = Vec::new();
922
923 if matches!(self.current_token, Token::Partition) {
925 self.advance(); if !matches!(self.current_token, Token::By) {
927 return Err("Expected BY after PARTITION".to_string());
928 }
929 self.advance(); partition_by = self.parse_identifier_list()?;
933 }
934
935 if matches!(self.current_token, Token::OrderBy) {
937 self.advance(); order_by = self.parse_order_by_list()?;
939 } else if let Token::Identifier(s) = &self.current_token {
940 if s.to_uppercase() == "ORDER" {
941 self.advance(); if !matches!(self.current_token, Token::By) {
944 return Err("Expected BY after ORDER".to_string());
945 }
946 self.advance(); order_by = self.parse_order_by_list()?;
948 }
949 }
950
951 Ok(WindowSpec {
952 partition_by,
953 order_by,
954 })
955 }
956
957 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
958 let mut order_columns = Vec::new();
959
960 loop {
961 let column = match &self.current_token {
962 Token::Identifier(id) => {
963 let col = id.clone();
964 self.advance();
965 col
966 }
967 Token::QuotedIdentifier(id) => {
968 let col = id.clone();
969 self.advance();
970 col
971 }
972 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
973 let col = num.clone();
975 self.advance();
976 col
977 }
978 _ => return Err("Expected column name in ORDER BY".to_string()),
979 };
980
981 let direction = match &self.current_token {
983 Token::Asc => {
984 self.advance();
985 SortDirection::Asc
986 }
987 Token::Desc => {
988 self.advance();
989 SortDirection::Desc
990 }
991 _ => SortDirection::Asc, };
993
994 order_columns.push(OrderByColumn { column, direction });
995
996 if matches!(self.current_token, Token::Comma) {
997 self.advance();
998 } else {
999 break;
1000 }
1001 }
1002
1003 Ok(order_columns)
1004 }
1005
1006 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1007 let mut conditions = Vec::new();
1008
1009 loop {
1010 let expr = self.parse_expression()?;
1011
1012 let connector = match &self.current_token {
1013 Token::And => {
1014 self.advance();
1015 Some(LogicalOp::And)
1016 }
1017 Token::Or => {
1018 self.advance();
1019 Some(LogicalOp::Or)
1020 }
1021 Token::RightParen if self.paren_depth <= 0 => {
1022 return Err(
1024 "Unexpected closing parenthesis - no matching opening parenthesis"
1025 .to_string(),
1026 );
1027 }
1028 _ => None,
1029 };
1030
1031 conditions.push(Condition {
1032 expr,
1033 connector: connector.clone(),
1034 });
1035
1036 if connector.is_none() {
1037 break;
1038 }
1039 }
1040
1041 Ok(WhereClause { conditions })
1042 }
1043
1044 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1045 let mut left = self.parse_comparison()?;
1046
1047 if let Some(op) = self.get_binary_op() {
1050 self.advance();
1051 let right = self.parse_expression()?;
1052 left = SqlExpression::BinaryOp {
1053 left: Box::new(left),
1054 op,
1055 right: Box::new(right),
1056 };
1057 }
1058
1059 if matches!(self.current_token, Token::In) {
1061 self.advance();
1062 self.consume(Token::LeftParen)?;
1063 let values = self.parse_expression_list()?;
1064 self.consume(Token::RightParen)?;
1065
1066 left = SqlExpression::InList {
1067 expr: Box::new(left),
1068 values,
1069 };
1070 }
1071
1072 Ok(left)
1076 }
1077
1078 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1079 let mut left = self.parse_additive()?;
1080
1081 if matches!(self.current_token, Token::Between) {
1083 self.advance(); let lower = self.parse_primary()?;
1085 self.consume(Token::And)?; let upper = self.parse_primary()?;
1087
1088 return Ok(SqlExpression::Between {
1089 expr: Box::new(left),
1090 lower: Box::new(lower),
1091 upper: Box::new(upper),
1092 });
1093 }
1094
1095 if matches!(self.current_token, Token::Not) {
1097 self.advance(); if matches!(self.current_token, Token::In) {
1099 self.advance(); self.consume(Token::LeftParen)?;
1101 let values = self.parse_expression_list()?;
1102 self.consume(Token::RightParen)?;
1103
1104 return Ok(SqlExpression::NotInList {
1105 expr: Box::new(left),
1106 values,
1107 });
1108 }
1109 return Err("Expected IN after NOT".to_string());
1110 }
1111
1112 if let Some(op) = self.get_binary_op() {
1114 self.advance();
1115 let right = self.parse_additive()?;
1116 left = SqlExpression::BinaryOp {
1117 left: Box::new(left),
1118 op,
1119 right: Box::new(right),
1120 };
1121 }
1122
1123 Ok(left)
1124 }
1125
1126 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1127 let mut left = self.parse_multiplicative()?;
1128
1129 while matches!(self.current_token, Token::Plus | Token::Minus) {
1130 let op = match self.current_token {
1131 Token::Plus => "+",
1132 Token::Minus => "-",
1133 _ => unreachable!(),
1134 };
1135 self.advance();
1136 let right = self.parse_multiplicative()?;
1137 left = SqlExpression::BinaryOp {
1138 left: Box::new(left),
1139 op: op.to_string(),
1140 right: Box::new(right),
1141 };
1142 }
1143
1144 Ok(left)
1145 }
1146
1147 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1148 let mut left = self.parse_primary()?;
1149
1150 while matches!(self.current_token, Token::Dot) {
1152 self.advance();
1153 if let Token::Identifier(method) = &self.current_token {
1154 let method_name = method.clone();
1155 self.advance();
1156
1157 if matches!(self.current_token, Token::LeftParen) {
1158 self.advance();
1159 let args = self.parse_method_args()?;
1160 self.consume(Token::RightParen)?;
1161
1162 match left {
1164 SqlExpression::Column(obj) => {
1165 left = SqlExpression::MethodCall {
1167 object: obj,
1168 method: method_name,
1169 args,
1170 };
1171 }
1172 SqlExpression::MethodCall { .. }
1173 | SqlExpression::ChainedMethodCall { .. } => {
1174 left = SqlExpression::ChainedMethodCall {
1176 base: Box::new(left),
1177 method: method_name,
1178 args,
1179 };
1180 }
1181 _ => {
1182 left = SqlExpression::ChainedMethodCall {
1184 base: Box::new(left),
1185 method: method_name,
1186 args,
1187 };
1188 }
1189 }
1190 } else {
1191 return Err(format!("Expected '(' after method name '{method_name}'"));
1192 }
1193 } else {
1194 return Err("Expected method name after '.'".to_string());
1195 }
1196 }
1197
1198 while matches!(self.current_token, Token::Star | Token::Divide) {
1199 let op = match self.current_token {
1200 Token::Star => "*",
1201 Token::Divide => "/",
1202 _ => unreachable!(),
1203 };
1204 self.advance();
1205 let right = self.parse_primary()?;
1206 left = SqlExpression::BinaryOp {
1207 left: Box::new(left),
1208 op: op.to_string(),
1209 right: Box::new(right),
1210 };
1211 }
1212
1213 Ok(left)
1214 }
1215
1216 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1217 let mut left = self.parse_logical_and()?;
1218
1219 while matches!(self.current_token, Token::Or) {
1220 self.advance();
1221 let right = self.parse_logical_and()?;
1222 left = SqlExpression::BinaryOp {
1226 left: Box::new(left),
1227 op: "OR".to_string(),
1228 right: Box::new(right),
1229 };
1230 }
1231
1232 Ok(left)
1233 }
1234
1235 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1236 let mut left = self.parse_expression()?;
1237
1238 while matches!(self.current_token, Token::And) {
1239 self.advance();
1240 let right = self.parse_expression()?;
1241 left = SqlExpression::BinaryOp {
1243 left: Box::new(left),
1244 op: "AND".to_string(),
1245 right: Box::new(right),
1246 };
1247 }
1248
1249 Ok(left)
1250 }
1251
1252 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1253 self.consume(Token::Case)?;
1255
1256 let mut when_branches = Vec::new();
1257
1258 while matches!(self.current_token, Token::When) {
1260 self.advance(); let condition = self.parse_expression()?;
1264
1265 self.consume(Token::Then)?;
1267
1268 let result = self.parse_expression()?;
1270
1271 when_branches.push(WhenBranch {
1272 condition: Box::new(condition),
1273 result: Box::new(result),
1274 });
1275 }
1276
1277 if when_branches.is_empty() {
1279 return Err("CASE expression must have at least one WHEN clause".to_string());
1280 }
1281
1282 let else_branch = if matches!(self.current_token, Token::Else) {
1284 self.advance(); Some(Box::new(self.parse_expression()?))
1286 } else {
1287 None
1288 };
1289
1290 self.consume(Token::End)?;
1292
1293 Ok(SqlExpression::CaseExpression {
1294 when_branches,
1295 else_branch,
1296 })
1297 }
1298
1299 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1300 if let Token::NumberLiteral(num_str) = &self.current_token {
1303 if self.columns.iter().any(|col| col == num_str) {
1305 let expr = SqlExpression::Column(num_str.clone());
1306 self.advance();
1307 return Ok(expr);
1308 }
1309 }
1310
1311 match &self.current_token {
1312 Token::Case => {
1313 self.parse_case_expression()
1315 }
1316 Token::DateTime => {
1317 self.advance(); self.consume(Token::LeftParen)?;
1319
1320 if matches!(&self.current_token, Token::RightParen) {
1322 self.advance(); return Ok(SqlExpression::DateTimeToday {
1324 hour: None,
1325 minute: None,
1326 second: None,
1327 });
1328 }
1329
1330 let year = if let Token::NumberLiteral(n) = &self.current_token {
1332 n.parse::<i32>().map_err(|_| "Invalid year")?
1333 } else {
1334 return Err("Expected year in DateTime constructor".to_string());
1335 };
1336 self.advance();
1337 self.consume(Token::Comma)?;
1338
1339 let month = if let Token::NumberLiteral(n) = &self.current_token {
1341 n.parse::<u32>().map_err(|_| "Invalid month")?
1342 } else {
1343 return Err("Expected month in DateTime constructor".to_string());
1344 };
1345 self.advance();
1346 self.consume(Token::Comma)?;
1347
1348 let day = if let Token::NumberLiteral(n) = &self.current_token {
1350 n.parse::<u32>().map_err(|_| "Invalid day")?
1351 } else {
1352 return Err("Expected day in DateTime constructor".to_string());
1353 };
1354 self.advance();
1355
1356 let mut hour = None;
1358 let mut minute = None;
1359 let mut second = None;
1360
1361 if matches!(&self.current_token, Token::Comma) {
1362 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1366 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1367 self.advance();
1368
1369 if matches!(&self.current_token, Token::Comma) {
1371 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1374 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1375 self.advance();
1376
1377 if matches!(&self.current_token, Token::Comma) {
1379 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1382 second =
1383 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1384 self.advance();
1385 }
1386 }
1387 }
1388 }
1389 }
1390 }
1391
1392 self.consume(Token::RightParen)?;
1393 Ok(SqlExpression::DateTimeConstructor {
1394 year,
1395 month,
1396 day,
1397 hour,
1398 minute,
1399 second,
1400 })
1401 }
1402 Token::Identifier(id) => {
1403 let id_upper = id.to_uppercase();
1404 let id_clone = id.clone();
1405
1406 if id_upper == "TRUE" {
1408 self.advance();
1409 return Ok(SqlExpression::BooleanLiteral(true));
1410 } else if id_upper == "FALSE" {
1411 self.advance();
1412 return Ok(SqlExpression::BooleanLiteral(false));
1413 }
1414
1415 self.advance();
1416
1417 if matches!(self.current_token, Token::LeftParen) {
1419 self.advance(); let args = self.parse_function_args()?;
1423 self.consume(Token::RightParen)?;
1424
1425 if matches!(self.current_token, Token::Over) {
1427 self.advance(); self.consume(Token::LeftParen)?;
1429 let window_spec = self.parse_window_spec()?;
1430 self.consume(Token::RightParen)?;
1431 return Ok(SqlExpression::WindowFunction {
1432 name: id_upper,
1433 args,
1434 window_spec,
1435 });
1436 }
1437
1438 return Ok(SqlExpression::FunctionCall {
1439 name: id_upper,
1440 args,
1441 });
1442 }
1443
1444 Ok(SqlExpression::Column(id_clone))
1446 }
1447 Token::QuotedIdentifier(id) => {
1448 let expr = if self.in_method_args {
1451 SqlExpression::StringLiteral(id.clone())
1452 } else {
1453 SqlExpression::Column(id.clone())
1455 };
1456 self.advance();
1457 Ok(expr)
1458 }
1459 Token::StringLiteral(s) => {
1460 let expr = SqlExpression::StringLiteral(s.clone());
1461 self.advance();
1462 Ok(expr)
1463 }
1464 Token::NumberLiteral(n) => {
1465 let expr = SqlExpression::NumberLiteral(n.clone());
1466 self.advance();
1467 Ok(expr)
1468 }
1469 Token::LeftParen => {
1470 self.advance();
1471
1472 let expr = self.parse_logical_or()?;
1475
1476 self.consume(Token::RightParen)?;
1477 Ok(expr)
1478 }
1479 Token::Not => {
1480 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1484 if matches!(self.current_token, Token::In) {
1486 self.advance(); self.consume(Token::LeftParen)?;
1488 let values = self.parse_expression_list()?;
1489 self.consume(Token::RightParen)?;
1490
1491 Ok(SqlExpression::NotInList {
1492 expr: Box::new(inner_expr),
1493 values,
1494 })
1495 } else {
1496 Ok(SqlExpression::Not {
1498 expr: Box::new(inner_expr),
1499 })
1500 }
1501 } else {
1502 Err("Expected expression after NOT".to_string())
1503 }
1504 }
1505 Token::Star => {
1506 self.advance();
1508 Ok(SqlExpression::StringLiteral("*".to_string()))
1509 }
1510 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1511 }
1512 }
1513
1514 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1515 let mut args = Vec::new();
1516
1517 self.in_method_args = true;
1519
1520 if !matches!(self.current_token, Token::RightParen) {
1521 loop {
1522 args.push(self.parse_expression()?);
1523
1524 if matches!(self.current_token, Token::Comma) {
1525 self.advance();
1526 } else {
1527 break;
1528 }
1529 }
1530 }
1531
1532 self.in_method_args = false;
1534
1535 Ok(args)
1536 }
1537
1538 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1539 let mut args = Vec::new();
1540
1541 if !matches!(self.current_token, Token::RightParen) {
1542 if matches!(self.current_token, Token::Distinct) {
1544 self.advance(); let expr = self.parse_additive()?;
1547 args.push(SqlExpression::FunctionCall {
1549 name: "DISTINCT".to_string(),
1550 args: vec![expr],
1551 });
1552 } else {
1553 args.push(self.parse_additive()?);
1555 }
1556
1557 while matches!(self.current_token, Token::Comma) {
1559 self.advance();
1560 args.push(self.parse_additive()?);
1561 }
1562 }
1563
1564 Ok(args)
1565 }
1566
1567 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1568 let mut expressions = Vec::new();
1569
1570 loop {
1571 expressions.push(self.parse_expression()?);
1572
1573 if matches!(self.current_token, Token::Comma) {
1574 self.advance();
1575 } else {
1576 break;
1577 }
1578 }
1579
1580 Ok(expressions)
1581 }
1582
1583 fn get_binary_op(&self) -> Option<String> {
1584 match &self.current_token {
1585 Token::Equal => Some("=".to_string()),
1586 Token::NotEqual => Some("!=".to_string()),
1587 Token::LessThan => Some("<".to_string()),
1588 Token::GreaterThan => Some(">".to_string()),
1589 Token::LessThanOrEqual => Some("<=".to_string()),
1590 Token::GreaterThanOrEqual => Some(">=".to_string()),
1591 Token::Like => Some("LIKE".to_string()),
1592 _ => None,
1593 }
1594 }
1595
1596 fn get_arithmetic_op(&self) -> Option<String> {
1597 match &self.current_token {
1598 Token::Plus => Some("+".to_string()),
1599 Token::Minus => Some("-".to_string()),
1600 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1602 _ => None,
1603 }
1604 }
1605
1606 #[must_use]
1607 pub fn get_position(&self) -> usize {
1608 self.lexer.get_position()
1609 }
1610}
1611
1612#[derive(Debug, Clone)]
1614pub enum CursorContext {
1615 SelectClause,
1616 FromClause,
1617 WhereClause,
1618 OrderByClause,
1619 AfterColumn(String),
1620 AfterLogicalOp(LogicalOp),
1621 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
1624 Unknown,
1625}
1626
/// Returns the prefix of `s` that ends at byte offset `pos`.
///
/// If `pos` is at or beyond the end of `s`, the whole string is returned. If
/// `pos` falls inside a multi-byte character, the cut is moved back to the
/// start of that character so the slice never panics.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a character boundary, so the search cannot fail.
    let cut = (0..=pos)
        .rev()
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(0);

    &s[..cut]
}
1641
/// Returns the suffix of `s` that starts at byte offset `pos`.
///
/// If `pos` is at or beyond the end of `s`, the empty string is returned. If
/// `pos` falls inside a multi-byte character, the start is moved forward to
/// the next character boundary so the slice never panics.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is always a character boundary, so the search cannot fail.
    let begin = (pos..=s.len())
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(s.len());

    &s[begin..]
}
1656
1657#[must_use]
1658pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1659 let truncated = safe_slice_to(query, cursor_pos);
1660 let mut parser = Parser::new(truncated);
1661
1662 if let Ok(stmt) = parser.parse() {
1664 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1665 #[cfg(test)]
1666 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1667 (ctx, partial)
1668 } else {
1669 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1671 #[cfg(test)]
1672 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1673 (ctx, partial)
1674 }
1675}
1676
1677#[must_use]
1678pub fn tokenize_query(query: &str) -> Vec<String> {
1679 let mut lexer = Lexer::new(query);
1680 let tokens = lexer.tokenize_all();
1681 tokens.iter().map(|t| format!("{t:?}")).collect()
1682}
1683
1684#[must_use]
1685pub fn format_sql_pretty(query: &str) -> Vec<String> {
1686 format_sql_pretty_compact(query, 5) }
1688
1689#[must_use]
1691pub fn format_ast_tree(query: &str) -> String {
1692 let mut parser = Parser::new(query);
1693 match parser.parse() {
1694 Ok(stmt) => format_select_statement(&stmt, 0),
1695 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
1696 }
1697}
1698
1699fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1700 let mut result = String::new();
1701 let indent_str = " ".repeat(indent);
1702
1703 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1704
1705 result.push_str(&format!("{indent_str} columns: ["));
1707 if stmt.columns.is_empty() {
1708 result.push_str("],\n");
1709 } else {
1710 result.push('\n');
1711 for col in &stmt.columns {
1712 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1713 }
1714 result.push_str(&format!("{indent_str} ],\n"));
1715 }
1716
1717 if let Some(table) = &stmt.from_table {
1719 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1720 }
1721
1722 if let Some(where_clause) = &stmt.where_clause {
1724 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1725 result.push_str(&format_where_clause(where_clause, indent + 2));
1726 result.push_str(&format!("{indent_str} }},\n"));
1727 }
1728
1729 if let Some(order_by) = &stmt.order_by {
1731 result.push_str(&format!("{indent_str} order_by: ["));
1732 if order_by.is_empty() {
1733 result.push_str("],\n");
1734 } else {
1735 result.push('\n');
1736 for col in order_by {
1737 let dir = match col.direction {
1738 SortDirection::Asc => "ASC",
1739 SortDirection::Desc => "DESC",
1740 };
1741 result.push_str(&format!(
1742 "{indent_str} \"{col}\" {dir},\n",
1743 col = col.column
1744 ));
1745 }
1746 result.push_str(&format!("{indent_str} ],\n"));
1747 }
1748 }
1749
1750 if let Some(group_by) = &stmt.group_by {
1752 result.push_str(&format!("{indent_str} group_by: ["));
1753 if group_by.is_empty() {
1754 result.push_str("]\n");
1755 } else {
1756 result.push('\n');
1757 for col in group_by {
1758 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1759 }
1760 result.push_str(&format!("{indent_str} ],\n"));
1761 }
1762 }
1763
1764 result.push_str(&format!("{indent_str}}}"));
1765 result
1766}
1767
1768fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1769 let mut result = String::new();
1770 let indent_str = " ".repeat(indent);
1771
1772 result.push_str(&format!("{indent_str}conditions: [\n"));
1773
1774 for condition in &clause.conditions {
1775 result.push_str(&format!("{indent_str} {{\n"));
1776 result.push_str(&format!(
1777 "{indent_str} expr: {},\n",
1778 format_expression_ast(&condition.expr)
1779 ));
1780
1781 if let Some(connector) = &condition.connector {
1782 let connector_str = match connector {
1783 LogicalOp::And => "AND",
1784 LogicalOp::Or => "OR",
1785 };
1786 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1787 }
1788
1789 result.push_str(&format!("{indent_str} }},\n"));
1790 }
1791
1792 result.push_str(&format!("{indent_str}]\n"));
1793 result
1794}
1795
1796fn format_expression_ast(expr: &SqlExpression) -> String {
1797 match expr {
1798 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
1799 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
1800 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
1801 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({value})"),
1802 SqlExpression::DateTimeConstructor {
1803 year,
1804 month,
1805 day,
1806 hour,
1807 minute,
1808 second,
1809 } => {
1810 format!(
1811 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1812 year,
1813 month,
1814 day,
1815 hour.unwrap_or(0),
1816 minute.unwrap_or(0),
1817 second.unwrap_or(0)
1818 )
1819 }
1820 SqlExpression::DateTimeToday {
1821 hour,
1822 minute,
1823 second,
1824 } => {
1825 format!(
1826 "DateTimeToday({:02}:{:02}:{:02})",
1827 hour.unwrap_or(0),
1828 minute.unwrap_or(0),
1829 second.unwrap_or(0)
1830 )
1831 }
1832 SqlExpression::MethodCall {
1833 object,
1834 method,
1835 args,
1836 } => {
1837 let args_str = args
1838 .iter()
1839 .map(format_expression_ast)
1840 .collect::<Vec<_>>()
1841 .join(", ");
1842 format!("MethodCall({object}.{method}({args_str}))")
1843 }
1844 SqlExpression::ChainedMethodCall { base, method, args } => {
1845 let args_str = args
1846 .iter()
1847 .map(format_expression_ast)
1848 .collect::<Vec<_>>()
1849 .join(", ");
1850 format!(
1851 "ChainedMethodCall({}.{}({}))",
1852 format_expression_ast(base),
1853 method,
1854 args_str
1855 )
1856 }
1857 SqlExpression::FunctionCall { name, args } => {
1858 let args_str = args
1859 .iter()
1860 .map(format_expression_ast)
1861 .collect::<Vec<_>>()
1862 .join(", ");
1863 format!("FunctionCall({name}({args_str}))")
1864 }
1865 SqlExpression::WindowFunction {
1866 name,
1867 args,
1868 window_spec,
1869 } => {
1870 let args_str = args
1871 .iter()
1872 .map(format_expression_ast)
1873 .collect::<Vec<_>>()
1874 .join(", ");
1875 let partition_str = if !window_spec.partition_by.is_empty() {
1876 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
1877 } else {
1878 String::new()
1879 };
1880 let order_str = if !window_spec.order_by.is_empty() {
1881 let cols = window_spec
1882 .order_by
1883 .iter()
1884 .map(|col| format!("{} {:?}", col.column, col.direction))
1885 .collect::<Vec<_>>()
1886 .join(", ");
1887 format!(" ORDER BY {}", cols)
1888 } else {
1889 String::new()
1890 };
1891 format!("WindowFunction({name}({args_str}) OVER({partition_str}{order_str}))")
1892 }
1893 SqlExpression::BinaryOp { left, op, right } => {
1894 format!(
1895 "BinaryOp({} {} {})",
1896 format_expression_ast(left),
1897 op,
1898 format_expression_ast(right)
1899 )
1900 }
1901 SqlExpression::InList { expr, values } => {
1902 let list_str = values
1903 .iter()
1904 .map(format_expression_ast)
1905 .collect::<Vec<_>>()
1906 .join(", ");
1907 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1908 }
1909 SqlExpression::NotInList { expr, values } => {
1910 let list_str = values
1911 .iter()
1912 .map(format_expression_ast)
1913 .collect::<Vec<_>>()
1914 .join(", ");
1915 format!(
1916 "NotInList({} NOT IN [{}])",
1917 format_expression_ast(expr),
1918 list_str
1919 )
1920 }
1921 SqlExpression::Between { expr, lower, upper } => {
1922 format!(
1923 "Between({} BETWEEN {} AND {})",
1924 format_expression_ast(expr),
1925 format_expression_ast(lower),
1926 format_expression_ast(upper)
1927 )
1928 }
1929 SqlExpression::Not { expr } => {
1930 format!("Not({})", format_expression_ast(expr))
1931 }
1932 SqlExpression::CaseExpression {
1933 when_branches,
1934 else_branch,
1935 } => {
1936 let when_strs: Vec<String> = when_branches
1937 .iter()
1938 .map(|branch| {
1939 format!(
1940 "WHEN {} THEN {}",
1941 format_expression_ast(&branch.condition),
1942 format_expression_ast(&branch.result)
1943 )
1944 })
1945 .collect();
1946 let else_str = else_branch
1947 .as_ref()
1948 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
1949 .unwrap_or_default();
1950 format!("CASE {} {} END", when_strs.join(" "), else_str)
1951 }
1952 }
1953}
1954
1955#[must_use]
1957pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1958 match expr {
1959 SqlExpression::DateTimeConstructor {
1960 year,
1961 month,
1962 day,
1963 hour,
1964 minute,
1965 second,
1966 } => {
1967 let h = hour.unwrap_or(0);
1968 let m = minute.unwrap_or(0);
1969 let s = second.unwrap_or(0);
1970
1971 if let Ok(dt) = NaiveDateTime::parse_from_str(
1973 &format!("{year:04}-{month:02}-{day:02} {h:02}:{m:02}:{s:02}"),
1974 "%Y-%m-%d %H:%M:%S",
1975 ) {
1976 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1977 } else {
1978 None
1979 }
1980 }
1981 SqlExpression::DateTimeToday {
1982 hour,
1983 minute,
1984 second,
1985 } => {
1986 let now = Local::now();
1987 let h = hour.unwrap_or(0);
1988 let m = minute.unwrap_or(0);
1989 let s = second.unwrap_or(0);
1990
1991 if let Ok(dt) = NaiveDateTime::parse_from_str(
1993 &format!(
1994 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1995 now.year(),
1996 now.month(),
1997 now.day(),
1998 h,
1999 m,
2000 s
2001 ),
2002 "%Y-%m-%d %H:%M:%S",
2003 ) {
2004 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2005 } else {
2006 None
2007 }
2008 }
2009 _ => None,
2010 }
2011}
2012
2013fn format_sql_with_preserved_parens(
2015 query: &str,
2016 cols_per_line: usize,
2017) -> Result<Vec<String>, String> {
2018 let mut lines = Vec::new();
2019 let mut lexer = Lexer::new(query);
2020 let tokens_with_pos = lexer.tokenize_all_with_positions();
2021
2022 if tokens_with_pos.is_empty() {
2023 return Err("No tokens found".to_string());
2024 }
2025
2026 let mut i = 0;
2027 let cols_per_line = cols_per_line.max(1);
2028
2029 while i < tokens_with_pos.len() {
2030 let (start, _end, ref token) = tokens_with_pos[i];
2031
2032 match token {
2033 Token::Select => {
2034 lines.push("SELECT".to_string());
2035 i += 1;
2036
2037 let mut columns = Vec::new();
2039 let mut col_start = i;
2040 while i < tokens_with_pos.len() {
2041 match &tokens_with_pos[i].2 {
2042 Token::From | Token::Eof => break,
2043 Token::Comma => {
2044 if col_start < i {
2046 let col_text = extract_text_between_positions(
2047 query,
2048 tokens_with_pos[col_start].0,
2049 tokens_with_pos[i - 1].1,
2050 );
2051 columns.push(col_text);
2052 }
2053 i += 1;
2054 col_start = i;
2055 }
2056 _ => i += 1,
2057 }
2058 }
2059 if col_start < i && i > 0 {
2061 let col_text = extract_text_between_positions(
2062 query,
2063 tokens_with_pos[col_start].0,
2064 tokens_with_pos[i - 1].1,
2065 );
2066 columns.push(col_text);
2067 }
2068
2069 for chunk in columns.chunks(cols_per_line) {
2071 let mut line = " ".to_string();
2072 for (idx, col) in chunk.iter().enumerate() {
2073 if idx > 0 {
2074 line.push_str(", ");
2075 }
2076 line.push_str(col.trim());
2077 }
2078 let is_last_chunk = chunk.as_ptr() as usize + std::mem::size_of_val(chunk)
2080 >= columns.last().map_or(0, |c| std::ptr::from_ref(c) as usize);
2081 if !is_last_chunk && columns.len() > cols_per_line {
2082 line.push(',');
2083 }
2084 lines.push(line);
2085 }
2086 }
2087 Token::From => {
2088 i += 1;
2089 if i < tokens_with_pos.len() {
2090 let table_start = tokens_with_pos[i].0;
2091 while i < tokens_with_pos.len() {
2093 match &tokens_with_pos[i].2 {
2094 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
2095 _ => i += 1,
2096 }
2097 }
2098 if i > 0 {
2099 let table_text = extract_text_between_positions(
2100 query,
2101 table_start,
2102 tokens_with_pos[i - 1].1,
2103 );
2104 lines.push(format!("FROM {}", table_text.trim()));
2105 }
2106 }
2107 }
2108 Token::Where => {
2109 lines.push("WHERE".to_string());
2110 i += 1;
2111
2112 let where_start = if i < tokens_with_pos.len() {
2114 tokens_with_pos[i].0
2115 } else {
2116 start
2117 };
2118
2119 let mut where_end = query.len();
2121 while i < tokens_with_pos.len() {
2122 match &tokens_with_pos[i].2 {
2123 Token::OrderBy | Token::GroupBy | Token::Eof => {
2124 if i > 0 {
2125 where_end = tokens_with_pos[i - 1].1;
2126 }
2127 break;
2128 }
2129 _ => i += 1,
2130 }
2131 }
2132
2133 let where_text = extract_text_between_positions(query, where_start, where_end);
2135
2136 let formatted_where = format_where_clause_with_parens(&where_text);
2138 for line in formatted_where {
2139 lines.push(format!(" {line}"));
2140 }
2141 }
2142 Token::OrderBy => {
2143 i += 1;
2144 let order_start = if i < tokens_with_pos.len() {
2145 tokens_with_pos[i].0
2146 } else {
2147 start
2148 };
2149
2150 while i < tokens_with_pos.len() {
2152 match &tokens_with_pos[i].2 {
2153 Token::GroupBy | Token::Eof => break,
2154 _ => i += 1,
2155 }
2156 }
2157
2158 if i > 0 {
2159 let order_text = extract_text_between_positions(
2160 query,
2161 order_start,
2162 tokens_with_pos[i - 1].1,
2163 );
2164 lines.push(format!("ORDER BY {}", order_text.trim()));
2165 }
2166 }
2167 Token::GroupBy => {
2168 i += 1;
2169 let group_start = if i < tokens_with_pos.len() {
2170 tokens_with_pos[i].0
2171 } else {
2172 start
2173 };
2174
2175 while i < tokens_with_pos.len() {
2177 match &tokens_with_pos[i].2 {
2178 Token::Having | Token::Eof => break,
2179 _ => i += 1,
2180 }
2181 }
2182
2183 if i > 0 {
2184 let group_text = extract_text_between_positions(
2185 query,
2186 group_start,
2187 tokens_with_pos[i - 1].1,
2188 );
2189 lines.push(format!("GROUP BY {}", group_text.trim()));
2190 }
2191 }
2192 _ => i += 1,
2193 }
2194 }
2195
2196 Ok(lines)
2197}
2198
/// Returns the text of `query` between the character positions `start`
/// (inclusive) and `end` (exclusive).
///
/// Positions are counted in `char`s, not bytes. Positions past the end of the
/// query are clamped, and an inverted range (`start >= end`) yields an empty
/// string instead of panicking.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2206
/// Splits the text of a WHERE clause into display lines at top-level
/// ` AND ` / ` OR ` connectors (case-insensitive).
///
/// Each connector becomes its own line (`"AND"` / `"OR"`), and the text
/// between connectors is trimmed. Connectors are not split when they occur
/// inside parentheses or inside single- or double-quoted text. An unmatched
/// `)` does not push the nesting depth below zero, so it cannot disable
/// splitting for the remainder of the clause.
///
/// If nothing is produced (for example for blank input), the trimmed input is
/// returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// Reports whether `keyword` occurs at `chars[at..]`, ignoring ASCII case.
    fn keyword_at(chars: &[char], at: usize, keyword: &str) -> bool {
        let len = keyword.chars().count();
        chars.get(at..at + len).is_some_and(|window| {
            window
                .iter()
                .zip(keyword.chars())
                .all(|(found, expected)| found.eq_ignore_ascii_case(&expected))
        })
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth = 0usize;
    // The quote character that opened the text currently being scanned.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];

        // Inside quotes everything is copied verbatim until the closing quote.
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            current_line.push(ch);
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            let connector = if keyword_at(&chars, i, " AND ") {
                Some("AND")
            } else if keyword_at(&chars, i, " OR ") {
                Some("OR")
            } else {
                None
            };

            if let Some(word) = connector {
                let pending = current_line.trim();
                if !pending.is_empty() {
                    lines.push(pending.to_string());
                }
                lines.push(word.to_string());
                current_line.clear();
                // Skip the keyword and the space on either side of it.
                i += word.len() + 2;
                continue;
            }
        }

        match ch {
            '(' => paren_depth += 1,
            ')' => paren_depth = paren_depth.saturating_sub(1),
            '\'' | '"' => open_quote = Some(ch),
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    let pending = current_line.trim();
    if !pending.is_empty() {
        lines.push(pending.to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2272
2273#[must_use]
2274pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2275 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2277 return lines;
2278 }
2279
2280 let mut lines = Vec::new();
2282 let mut parser = Parser::new(query);
2283
2284 let cols_per_line = cols_per_line.max(1);
2286
2287 if let Ok(stmt) = parser.parse() {
2288 if !stmt.columns.is_empty() {
2290 lines.push("SELECT".to_string());
2291
2292 for chunk in stmt.columns.chunks(cols_per_line) {
2294 let mut line = " ".to_string();
2295 for (i, col) in chunk.iter().enumerate() {
2296 if i > 0 {
2297 line.push_str(", ");
2298 }
2299 line.push_str(col);
2300 }
2301 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2303 let current_chunk_idx =
2304 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2305 if current_chunk_idx < last_chunk_idx {
2306 line.push(',');
2307 }
2308 lines.push(line);
2309 }
2310 }
2311
2312 if let Some(table) = &stmt.from_table {
2314 lines.push(format!("FROM {table}"));
2315 }
2316
2317 if let Some(where_clause) = &stmt.where_clause {
2319 lines.push("WHERE".to_string());
2320 for (i, condition) in where_clause.conditions.iter().enumerate() {
2321 if i > 0 {
2322 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2324 if let Some(connector) = &prev_condition.connector {
2325 match connector {
2326 LogicalOp::And => lines.push(" AND".to_string()),
2327 LogicalOp::Or => lines.push(" OR".to_string()),
2328 }
2329 }
2330 }
2331 }
2332 lines.push(format!(" {}", format_expression(&condition.expr)));
2333 }
2334 }
2335
2336 if let Some(order_by) = &stmt.order_by {
2338 let order_str = order_by
2339 .iter()
2340 .map(|col| {
2341 let dir = match col.direction {
2342 SortDirection::Asc => " ASC",
2343 SortDirection::Desc => " DESC",
2344 };
2345 format!("{}{}", col.column, dir)
2346 })
2347 .collect::<Vec<_>>()
2348 .join(", ");
2349 lines.push(format!("ORDER BY {order_str}"));
2350 }
2351
2352 if let Some(group_by) = &stmt.group_by {
2354 let group_str = group_by.join(", ");
2355 lines.push(format!("GROUP BY {group_str}"));
2356 }
2357 } else {
2358 let mut lexer = Lexer::new(query);
2360 let tokens = lexer.tokenize_all();
2361 let mut current_line = String::new();
2362 let mut indent = 0;
2363
2364 for token in tokens {
2365 match &token {
2366 Token::Select | Token::From | Token::Where | Token::OrderBy | Token::GroupBy => {
2367 if !current_line.is_empty() {
2368 lines.push(current_line.trim().to_string());
2369 current_line.clear();
2370 }
2371 lines.push(format!("{token:?}").to_uppercase());
2372 indent = 1;
2373 }
2374 Token::And | Token::Or => {
2375 if !current_line.is_empty() {
2376 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2377 current_line.clear();
2378 }
2379 lines.push(format!(" {token:?}").to_uppercase());
2380 }
2381 Token::Comma => {
2382 current_line.push(',');
2383 if indent > 0 {
2384 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2385 current_line.clear();
2386 }
2387 }
2388 Token::Eof => break,
2389 _ => {
2390 if !current_line.is_empty() {
2391 current_line.push(' ');
2392 }
2393 current_line.push_str(&format_token(&token));
2394 }
2395 }
2396 }
2397
2398 if !current_line.is_empty() {
2399 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2400 }
2401 }
2402
2403 lines
2404}
2405
2406fn format_expression(expr: &SqlExpression) -> String {
2407 match expr {
2408 SqlExpression::Column(name) => name.clone(),
2409 SqlExpression::StringLiteral(s) => format!("'{s}'"),
2410 SqlExpression::NumberLiteral(n) => n.clone(),
2411 SqlExpression::BooleanLiteral(b) => b.to_string(),
2412 SqlExpression::DateTimeConstructor {
2413 year,
2414 month,
2415 day,
2416 hour,
2417 minute,
2418 second,
2419 } => {
2420 let mut result = format!("DateTime({year}, {month}, {day}");
2421 if let Some(h) = hour {
2422 result.push_str(&format!(", {h}"));
2423 if let Some(m) = minute {
2424 result.push_str(&format!(", {m}"));
2425 if let Some(s) = second {
2426 result.push_str(&format!(", {s}"));
2427 }
2428 }
2429 }
2430 result.push(')');
2431 result
2432 }
2433 SqlExpression::DateTimeToday {
2434 hour,
2435 minute,
2436 second,
2437 } => {
2438 let mut result = "DateTime()".to_string();
2439 if let Some(h) = hour {
2440 result = format!("DateTime(TODAY, {h}");
2441 if let Some(m) = minute {
2442 result.push_str(&format!(", {m}"));
2443 if let Some(s) = second {
2444 result.push_str(&format!(", {s}"));
2445 }
2446 }
2447 result.push(')');
2448 }
2449 result
2450 }
2451 SqlExpression::MethodCall {
2452 object,
2453 method,
2454 args,
2455 } => {
2456 let args_str = args
2457 .iter()
2458 .map(format_expression)
2459 .collect::<Vec<_>>()
2460 .join(", ");
2461 format!("{object}.{method}({args_str})")
2462 }
2463 SqlExpression::BinaryOp { left, op, right } => {
2464 if op == "OR" || op == "AND" {
2467 format!(
2470 "({} {} {})",
2471 format_expression(left),
2472 op,
2473 format_expression(right)
2474 )
2475 } else {
2476 format!(
2477 "{} {} {}",
2478 format_expression(left),
2479 op,
2480 format_expression(right)
2481 )
2482 }
2483 }
2484 SqlExpression::InList { expr, values } => {
2485 let values_str = values
2486 .iter()
2487 .map(format_expression)
2488 .collect::<Vec<_>>()
2489 .join(", ");
2490 format!("{} IN ({})", format_expression(expr), values_str)
2491 }
2492 SqlExpression::NotInList { expr, values } => {
2493 let values_str = values
2494 .iter()
2495 .map(format_expression)
2496 .collect::<Vec<_>>()
2497 .join(", ");
2498 format!("{} NOT IN ({})", format_expression(expr), values_str)
2499 }
2500 SqlExpression::Between { expr, lower, upper } => {
2501 format!(
2502 "{} BETWEEN {} AND {}",
2503 format_expression(expr),
2504 format_expression(lower),
2505 format_expression(upper)
2506 )
2507 }
2508 SqlExpression::Not { expr } => {
2509 format!("NOT {}", format_expression(expr))
2510 }
2511 SqlExpression::ChainedMethodCall { base, method, args } => {
2512 let args_str = args
2513 .iter()
2514 .map(format_expression)
2515 .collect::<Vec<_>>()
2516 .join(", ");
2517 format!("{}.{}({})", format_expression(base), method, args_str)
2518 }
2519 SqlExpression::FunctionCall { name, args } => {
2520 let args_str = args
2521 .iter()
2522 .map(format_expression)
2523 .collect::<Vec<_>>()
2524 .join(", ");
2525 format!("{name}({args_str})")
2526 }
2527 SqlExpression::WindowFunction {
2528 name,
2529 args,
2530 window_spec,
2531 } => {
2532 let args_str = args
2533 .iter()
2534 .map(format_expression)
2535 .collect::<Vec<_>>()
2536 .join(", ");
2537 let partition_str = if !window_spec.partition_by.is_empty() {
2538 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2539 } else {
2540 String::new()
2541 };
2542 let order_str = if !window_spec.order_by.is_empty() {
2543 let cols = window_spec
2544 .order_by
2545 .iter()
2546 .map(|col| {
2547 let dir = match col.direction {
2548 SortDirection::Asc => "ASC",
2549 SortDirection::Desc => "DESC",
2550 };
2551 format!("{} {}", col.column, dir)
2552 })
2553 .collect::<Vec<_>>()
2554 .join(", ");
2555 format!(" ORDER BY {}", cols)
2556 } else {
2557 String::new()
2558 };
2559 format!("{name}({args_str}) OVER({partition_str}{order_str})")
2560 }
2561 SqlExpression::CaseExpression {
2562 when_branches,
2563 else_branch,
2564 } => {
2565 let mut result = String::from("CASE");
2566 for branch in when_branches {
2567 result.push_str(&format!(
2568 " WHEN {} THEN {}",
2569 format_expression(&branch.condition),
2570 format_expression(&branch.result)
2571 ));
2572 }
2573 if let Some(else_expr) = else_branch {
2574 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2575 }
2576 result.push_str(" END");
2577 result
2578 }
2579 }
2580}
2581
2582fn format_token(token: &Token) -> String {
2583 match token {
2584 Token::Identifier(s) => s.clone(),
2585 Token::QuotedIdentifier(s) => format!("\"{s}\""),
2586 Token::StringLiteral(s) => format!("'{s}'"),
2587 Token::NumberLiteral(n) => n.clone(),
2588 Token::DateTime => "DateTime".to_string(),
2589 Token::Case => "CASE".to_string(),
2590 Token::When => "WHEN".to_string(),
2591 Token::Then => "THEN".to_string(),
2592 Token::Else => "ELSE".to_string(),
2593 Token::End => "END".to_string(),
2594 Token::Distinct => "DISTINCT".to_string(),
2595 Token::Over => "OVER".to_string(),
2596 Token::Partition => "PARTITION".to_string(),
2597 Token::By => "BY".to_string(),
2598 Token::LeftParen => "(".to_string(),
2599 Token::RightParen => ")".to_string(),
2600 Token::Comma => ",".to_string(),
2601 Token::Dot => ".".to_string(),
2602 Token::Equal => "=".to_string(),
2603 Token::NotEqual => "!=".to_string(),
2604 Token::LessThan => "<".to_string(),
2605 Token::GreaterThan => ">".to_string(),
2606 Token::LessThanOrEqual => "<=".to_string(),
2607 Token::GreaterThanOrEqual => ">=".to_string(),
2608 Token::In => "IN".to_string(),
2609 _ => format!("{token:?}").to_uppercase(),
2610 }
2611}
2612
2613fn analyze_statement(
2614 stmt: &SelectStatement,
2615 query: &str,
2616 _cursor_pos: usize,
2617) -> (CursorContext, Option<String>) {
2618 let trimmed = query.trim();
2620
2621 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2623 for op in &comparison_ops {
2624 if let Some(op_pos) = query.rfind(op) {
2625 let before_op = safe_slice_to(query, op_pos);
2626 let after_op_start = op_pos + op.len();
2627 let after_op = if after_op_start < query.len() {
2628 &query[after_op_start..]
2629 } else {
2630 ""
2631 };
2632
2633 if let Some(col_name) = before_op.split_whitespace().last() {
2635 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2636 let after_op_trimmed = after_op.trim();
2638 if after_op_trimmed.is_empty()
2639 || (after_op_trimmed
2640 .chars()
2641 .all(|c| c.is_alphanumeric() || c == '_')
2642 && !after_op_trimmed.contains('('))
2643 {
2644 let partial = if after_op_trimmed.is_empty() {
2645 None
2646 } else {
2647 Some(after_op_trimmed.to_string())
2648 };
2649 return (
2650 CursorContext::AfterComparisonOp(
2651 col_name.to_string(),
2652 op.trim().to_string(),
2653 ),
2654 partial,
2655 );
2656 }
2657 }
2658 }
2659 }
2660 }
2661
2662 if trimmed.to_uppercase().ends_with(" AND")
2664 || trimmed.to_uppercase().ends_with(" OR")
2665 || trimmed.to_uppercase().ends_with(" AND ")
2666 || trimmed.to_uppercase().ends_with(" OR ")
2667 {
2668 } else {
2670 if let Some(dot_pos) = trimmed.rfind('.') {
2672 let before_dot = safe_slice_to(trimmed, dot_pos);
2674 let after_dot_start = dot_pos + 1;
2675 let after_dot = if after_dot_start < trimmed.len() {
2676 &trimmed[after_dot_start..]
2677 } else {
2678 ""
2679 };
2680
2681 if !after_dot.contains('(') {
2684 let col_name = if before_dot.ends_with('"') {
2686 let bytes = before_dot.as_bytes();
2688 let mut pos = before_dot.len() - 1; let mut found_start = None;
2690
2691 if pos > 0 {
2693 pos -= 1;
2694 while pos > 0 {
2695 if bytes[pos] == b'"' {
2696 if pos == 0 || bytes[pos - 1] != b'\\' {
2698 found_start = Some(pos);
2699 break;
2700 }
2701 }
2702 pos -= 1;
2703 }
2704 if found_start.is_none() && bytes[0] == b'"' {
2706 found_start = Some(0);
2707 }
2708 }
2709
2710 found_start.map(|start| safe_slice_from(before_dot, start))
2711 } else {
2712 before_dot
2715 .split_whitespace()
2716 .last()
2717 .map(|word| word.trim_start_matches('('))
2718 };
2719
2720 if let Some(col_name) = col_name {
2721 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2723 true
2725 } else {
2726 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2728 };
2729
2730 if is_valid {
2731 let partial_method = if after_dot.is_empty() {
2734 None
2735 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2736 Some(after_dot.to_string())
2737 } else {
2738 None
2739 };
2740
2741 let col_name_for_context = if col_name.starts_with('"')
2743 && col_name.ends_with('"')
2744 && col_name.len() > 2
2745 {
2746 col_name[1..col_name.len() - 1].to_string()
2747 } else {
2748 col_name.to_string()
2749 };
2750
2751 return (
2752 CursorContext::AfterColumn(col_name_for_context),
2753 partial_method,
2754 );
2755 }
2756 }
2757 }
2758 }
2759 }
2760
2761 if let Some(where_clause) = &stmt.where_clause {
2763 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2765 let op = if trimmed.to_uppercase().ends_with(" AND") {
2766 LogicalOp::And
2767 } else {
2768 LogicalOp::Or
2769 };
2770 return (CursorContext::AfterLogicalOp(op), None);
2771 }
2772
2773 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2775 let after_and = safe_slice_from(query, and_pos + 5);
2776 let partial = extract_partial_at_end(after_and);
2777 if partial.is_some() {
2778 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2779 }
2780 }
2781
2782 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2783 let after_or = safe_slice_from(query, or_pos + 4);
2784 let partial = extract_partial_at_end(after_or);
2785 if partial.is_some() {
2786 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2787 }
2788 }
2789
2790 if let Some(last_condition) = where_clause.conditions.last() {
2791 if let Some(connector) = &last_condition.connector {
2792 return (
2794 CursorContext::AfterLogicalOp(connector.clone()),
2795 extract_partial_at_end(query),
2796 );
2797 }
2798 }
2799 return (CursorContext::WhereClause, extract_partial_at_end(query));
2801 }
2802
2803 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2805 return (CursorContext::OrderByClause, None);
2806 }
2807
2808 if stmt.order_by.is_some() {
2810 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2811 }
2812
2813 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2814 return (CursorContext::FromClause, extract_partial_at_end(query));
2815 }
2816
2817 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
2818 return (CursorContext::SelectClause, extract_partial_at_end(query));
2819 }
2820
2821 (CursorContext::Unknown, None)
2822}
2823
2824fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2825 let upper = query.to_uppercase();
2826
2827 let trimmed = query.trim();
2829
2830 #[cfg(test)]
2831 {
2832 if trimmed.contains("\"Last Name\"") {
2833 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
2834 }
2835 }
2836
2837 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2839 for op in &comparison_ops {
2840 if let Some(op_pos) = query.rfind(op) {
2841 let before_op = safe_slice_to(query, op_pos);
2842 let after_op_start = op_pos + op.len();
2843 let after_op = if after_op_start < query.len() {
2844 &query[after_op_start..]
2845 } else {
2846 ""
2847 };
2848
2849 if let Some(col_name) = before_op.split_whitespace().last() {
2851 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2852 let after_op_trimmed = after_op.trim();
2854 if after_op_trimmed.is_empty()
2855 || (after_op_trimmed
2856 .chars()
2857 .all(|c| c.is_alphanumeric() || c == '_')
2858 && !after_op_trimmed.contains('('))
2859 {
2860 let partial = if after_op_trimmed.is_empty() {
2861 None
2862 } else {
2863 Some(after_op_trimmed.to_string())
2864 };
2865 return (
2866 CursorContext::AfterComparisonOp(
2867 col_name.to_string(),
2868 op.trim().to_string(),
2869 ),
2870 partial,
2871 );
2872 }
2873 }
2874 }
2875 }
2876 }
2877
2878 if let Some(dot_pos) = trimmed.rfind('.') {
2881 #[cfg(test)]
2882 {
2883 if trimmed.contains("\"Last Name\"") {
2884 eprintln!("DEBUG: Found dot at position {dot_pos}");
2885 }
2886 }
2887 let before_dot = &trimmed[..dot_pos];
2889 let after_dot = &trimmed[dot_pos + 1..];
2890
2891 if !after_dot.contains('(') {
2894 let col_name = if before_dot.ends_with('"') {
2897 let bytes = before_dot.as_bytes();
2899 let mut pos = before_dot.len() - 1; let mut found_start = None;
2901
2902 #[cfg(test)]
2903 {
2904 if trimmed.contains("\"Last Name\"") {
2905 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
2906 }
2907 }
2908
2909 if pos > 0 {
2911 pos -= 1;
2912 while pos > 0 {
2913 if bytes[pos] == b'"' {
2914 if pos == 0 || bytes[pos - 1] != b'\\' {
2916 found_start = Some(pos);
2917 break;
2918 }
2919 }
2920 pos -= 1;
2921 }
2922 if found_start.is_none() && bytes[0] == b'"' {
2924 found_start = Some(0);
2925 }
2926 }
2927
2928 if let Some(start) = found_start {
2929 let result = safe_slice_from(before_dot, start);
2931 #[cfg(test)]
2932 {
2933 if trimmed.contains("\"Last Name\"") {
2934 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
2935 }
2936 }
2937 Some(result)
2938 } else {
2939 #[cfg(test)]
2940 {
2941 if trimmed.contains("\"Last Name\"") {
2942 eprintln!("DEBUG: No opening quote found!");
2943 }
2944 }
2945 None
2946 }
2947 } else {
2948 before_dot
2951 .split_whitespace()
2952 .last()
2953 .map(|word| word.trim_start_matches('('))
2954 };
2955
2956 if let Some(col_name) = col_name {
2957 #[cfg(test)]
2958 {
2959 if trimmed.contains("\"Last Name\"") {
2960 eprintln!("DEBUG: col_name = '{col_name}'");
2961 }
2962 }
2963
2964 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2966 true
2968 } else {
2969 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2971 };
2972
2973 #[cfg(test)]
2974 {
2975 if trimmed.contains("\"Last Name\"") {
2976 eprintln!("DEBUG: is_valid = {is_valid}");
2977 }
2978 }
2979
2980 if is_valid {
2981 let partial_method = if after_dot.is_empty() {
2984 None
2985 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2986 Some(after_dot.to_string())
2987 } else {
2988 None
2989 };
2990
2991 let col_name_for_context = if col_name.starts_with('"')
2993 && col_name.ends_with('"')
2994 && col_name.len() > 2
2995 {
2996 col_name[1..col_name.len() - 1].to_string()
2997 } else {
2998 col_name.to_string()
2999 };
3000
3001 return (
3002 CursorContext::AfterColumn(col_name_for_context),
3003 partial_method,
3004 );
3005 }
3006 }
3007 }
3008 }
3009
3010 if let Some(and_pos) = upper.rfind(" AND ") {
3012 if cursor_pos >= and_pos + 5 {
3014 let after_and = safe_slice_from(query, and_pos + 5);
3016 let partial = extract_partial_at_end(after_and);
3017 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3018 }
3019 }
3020
3021 if let Some(or_pos) = upper.rfind(" OR ") {
3022 if cursor_pos >= or_pos + 4 {
3024 let after_or = safe_slice_from(query, or_pos + 4);
3026 let partial = extract_partial_at_end(after_or);
3027 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3028 }
3029 }
3030
3031 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3033 let op = if trimmed.to_uppercase().ends_with(" AND") {
3034 LogicalOp::And
3035 } else {
3036 LogicalOp::Or
3037 };
3038 return (CursorContext::AfterLogicalOp(op), None);
3039 }
3040
3041 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
3043 {
3044 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3045 }
3046
3047 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
3048 return (CursorContext::WhereClause, extract_partial_at_end(query));
3049 }
3050
3051 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
3052 return (CursorContext::FromClause, extract_partial_at_end(query));
3053 }
3054
3055 if upper.contains("SELECT") && !upper.contains("FROM") {
3056 return (CursorContext::SelectClause, extract_partial_at_end(query));
3057 }
3058
3059 (CursorContext::Unknown, None)
3060}
3061
3062fn extract_partial_at_end(query: &str) -> Option<String> {
3063 let trimmed = query.trim();
3064
3065 if let Some(last_word) = trimmed.split_whitespace().last() {
3067 if last_word.starts_with('"') && !last_word.ends_with('"') {
3068 return Some(last_word.to_string());
3070 }
3071 }
3072
3073 let last_word = trimmed.split_whitespace().last()?;
3075
3076 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
3078 Some(last_word.to_string())
3079 } else {
3080 None
3081 }
3082}
3083
/// Reports whether `word`, compared case-insensitively, is one of the SQL
/// keywords that must never be offered as a partially typed identifier.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.contains(&upper.as_str())
}
3102
3103#[cfg(test)]
3104mod tests {
3105 use super::*;
3106
3107 #[test]
3108 fn test_tokenizer_window_functions() {
3109 let mut lexer = Lexer::new("LAG(value) OVER (PARTITION BY category ORDER BY id)");
3110 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "LAG"));
3111 assert!(matches!(lexer.next_token(), Token::LeftParen));
3112 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "value"));
3113 assert!(matches!(lexer.next_token(), Token::RightParen));
3114
3115 let over_token = lexer.next_token();
3116 println!("Expected OVER, got: {:?}", over_token);
3117 assert!(matches!(over_token, Token::Over));
3118
3119 assert!(matches!(lexer.next_token(), Token::LeftParen));
3120 assert!(matches!(lexer.next_token(), Token::Partition));
3121 assert!(matches!(lexer.next_token(), Token::By));
3122 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "category"));
3123 }
3124
3125 #[test]
3126 fn test_parse_window_function() {
3127 let query = "SELECT LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test";
3128 let mut parser = Parser::new(query);
3129 let result = parser.parse();
3130
3131 assert!(
3132 result.is_ok(),
3133 "Failed to parse window function: {:?}",
3134 result
3135 );
3136 let stmt = result.unwrap();
3137
3138 if let Some(item) = stmt.select_items.get(0) {
3140 match item {
3141 SelectItem::Expression { expr, alias } => {
3142 println!("Parsed expression: {:?}", expr);
3143 assert!(matches!(expr, SqlExpression::WindowFunction { .. }));
3144 assert_eq!(alias, "prev_value");
3145 }
3146 _ => panic!("Expected expression, got: {:?}", item),
3147 }
3148 } else {
3149 panic!("No select items found");
3150 }
3151 }
3152
3153 #[test]
3154 fn test_chained_method_calls() {
3155 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
3157 let mut parser = Parser::new(query);
3158 let result = parser.parse();
3159
3160 assert!(
3161 result.is_ok(),
3162 "Failed to parse chained method calls: {result:?}"
3163 );
3164
3165 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
3167 let mut parser2 = Parser::new(query2);
3168 let result2 = parser2.parse();
3169
3170 assert!(
3171 result2.is_ok(),
3172 "Failed to parse multiple chained calls: {result2:?}"
3173 );
3174 }
3175
3176 #[test]
3177 fn test_tokenizer() {
3178 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
3179
3180 assert!(matches!(lexer.next_token(), Token::Select));
3181 assert!(matches!(lexer.next_token(), Token::Star));
3182 assert!(matches!(lexer.next_token(), Token::From));
3183 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
3184 assert!(matches!(lexer.next_token(), Token::Where));
3185 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
3186 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3187 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
3188 }
3189
3190 #[test]
3191 fn test_tokenizer_datetime() {
3192 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
3193
3194 assert!(matches!(lexer.next_token(), Token::Where));
3195 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
3196 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3197 assert!(matches!(lexer.next_token(), Token::DateTime));
3198 assert!(matches!(lexer.next_token(), Token::LeftParen));
3199 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
3200 assert!(matches!(lexer.next_token(), Token::Comma));
3201 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
3202 assert!(matches!(lexer.next_token(), Token::Comma));
3203 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
3204 assert!(matches!(lexer.next_token(), Token::RightParen));
3205 }
3206
3207 #[test]
3208 fn test_parse_simple_select() {
3209 let mut parser = Parser::new("SELECT * FROM trade_deal");
3210 let stmt = parser.parse().unwrap();
3211
3212 assert_eq!(stmt.columns, vec!["*"]);
3213 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
3214 assert!(stmt.where_clause.is_none());
3215 }
3216
3217 #[test]
3218 fn test_parse_where_with_method() {
3219 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
3220 let stmt = parser.parse().unwrap();
3221
3222 assert!(stmt.where_clause.is_some());
3223 let where_clause = stmt.where_clause.unwrap();
3224 assert_eq!(where_clause.conditions.len(), 1);
3225 }
3226
3227 #[test]
3228 fn test_parse_datetime_constructor() {
3229 let mut parser =
3230 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
3231 let stmt = parser.parse().unwrap();
3232
3233 assert!(stmt.where_clause.is_some());
3234 let where_clause = stmt.where_clause.unwrap();
3235 assert_eq!(where_clause.conditions.len(), 1);
3236
3237 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3239 assert_eq!(op, ">");
3240 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
3241 assert!(matches!(
3242 right.as_ref(),
3243 SqlExpression::DateTimeConstructor {
3244 year: 2025,
3245 month: 10,
3246 day: 20,
3247 hour: None,
3248 minute: None,
3249 second: None
3250 }
3251 ));
3252 } else {
3253 panic!("Expected BinaryOp with DateTime constructor");
3254 }
3255 }
3256
3257 #[test]
3258 fn test_cursor_context_after_and() {
3259 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3260 let (context, partial) = detect_cursor_context(query, query.len());
3261
3262 assert!(matches!(
3263 context,
3264 CursorContext::AfterLogicalOp(LogicalOp::And)
3265 ));
3266 assert_eq!(partial, None);
3267 }
3268
3269 #[test]
3270 fn test_cursor_context_with_partial() {
3271 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3272 let (context, partial) = detect_cursor_context(query, query.len());
3273
3274 assert!(matches!(
3275 context,
3276 CursorContext::AfterLogicalOp(LogicalOp::And)
3277 ));
3278 assert_eq!(partial, Some("p".to_string()));
3279 }
3280
3281 #[test]
3282 fn test_cursor_context_after_datetime_comparison() {
3283 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3284 let (context, partial) = detect_cursor_context(query, query.len());
3285
3286 assert!(
3287 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3288 );
3289 assert_eq!(partial, None);
3290 }
3291
3292 #[test]
3293 fn test_cursor_context_partial_datetime() {
3294 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3295 let (context, partial) = detect_cursor_context(query, query.len());
3296
3297 assert!(
3298 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3299 );
3300 assert_eq!(partial, Some("Date".to_string()));
3301 }
3302
3303 #[test]
3305 fn test_tokenizer_quoted_identifier() {
3306 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3307
3308 assert!(matches!(lexer.next_token(), Token::Select));
3309 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3310 assert!(matches!(lexer.next_token(), Token::Comma));
3311 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3312 assert!(matches!(lexer.next_token(), Token::From));
3313 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3314 }
3315
3316 #[test]
3317 fn test_tokenizer_quoted_vs_string_literal() {
3318 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3320
3321 assert!(matches!(lexer.next_token(), Token::Where));
3322 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3323 assert!(matches!(lexer.next_token(), Token::Equal));
3324 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3325 assert!(matches!(lexer.next_token(), Token::And));
3326 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3327 assert!(matches!(lexer.next_token(), Token::Dot));
3328 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3329 assert!(matches!(lexer.next_token(), Token::LeftParen));
3330 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3331 assert!(matches!(lexer.next_token(), Token::RightParen));
3332 }
3333
3334 #[test]
3335 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3336 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3339
3340 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3341 assert!(matches!(lexer.next_token(), Token::Dot));
3342 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3343 assert!(matches!(lexer.next_token(), Token::LeftParen));
3344
3345 let token = lexer.next_token();
3348 println!("Token for \"Alb\": {token:?}");
3349 assert!(matches!(lexer.next_token(), Token::RightParen));
3353 }
3354
3355 #[test]
3356 fn test_parse_select_with_quoted_columns() {
3357 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
3358 let stmt = parser.parse().unwrap();
3359
3360 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
3361 assert_eq!(stmt.from_table, Some("customers".to_string()));
3362 }
3363
3364 #[test]
3365 fn test_cursor_context_select_with_partial_quoted() {
3366 let query = r#"SELECT "Cust"#;
3368 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {context:?}, Partial: {partial:?}");
3371 assert!(matches!(context, CursorContext::SelectClause));
3372 }
3375
3376 #[test]
3377 fn test_cursor_context_select_after_comma_with_quoted() {
3378 let query = r#"SELECT Company, "Customer "#;
3380 let (context, partial) = detect_cursor_context(query, query.len());
3381
3382 println!("Context: {context:?}, Partial: {partial:?}");
3383 assert!(matches!(context, CursorContext::SelectClause));
3384 }
3386
3387 #[test]
3388 fn test_cursor_context_order_by_quoted() {
3389 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3390 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3391
3392 println!("Context: {context:?}, Partial: {partial:?}");
3393 assert!(matches!(context, CursorContext::OrderByClause));
3394 }
3396
3397 #[test]
3398 fn test_where_clause_with_quoted_column() {
3399 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3400 let stmt = parser.parse().unwrap();
3401
3402 assert!(stmt.where_clause.is_some());
3403 let where_clause = stmt.where_clause.unwrap();
3404 assert_eq!(where_clause.conditions.len(), 1);
3405
3406 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3407 assert_eq!(op, "=");
3408 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3409 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3410 } else {
3411 panic!("Expected BinaryOp");
3412 }
3413 }
3414
3415 #[test]
3416 fn test_parse_method_with_double_quotes_as_string() {
3417 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3419 let stmt = parser.parse().unwrap();
3420
3421 assert!(stmt.where_clause.is_some());
3422 let where_clause = stmt.where_clause.unwrap();
3423 assert_eq!(where_clause.conditions.len(), 1);
3424
3425 if let SqlExpression::MethodCall {
3426 object,
3427 method,
3428 args,
3429 } = &where_clause.conditions[0].expr
3430 {
3431 assert_eq!(object, "Country");
3432 assert_eq!(method, "Contains");
3433 assert_eq!(args.len(), 1);
3434 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3436 } else {
3437 panic!("Expected MethodCall");
3438 }
3439 }
3440
3441 #[test]
3442 fn test_extract_partial_with_quoted_columns_in_query() {
3443 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3445 let (context, partial) = detect_cursor_context(query, query.len());
3446
3447 assert!(matches!(context, CursorContext::OrderByClause));
3448 assert_eq!(
3449 partial,
3450 Some("coun".to_string()),
3451 "Should extract 'coun' as partial, not everything after the quoted column"
3452 );
3453 }
3454
3455 #[test]
3456 fn test_extract_partial_quoted_identifier_being_typed() {
3457 let query = r#"SELECT "Cust"#;
3459 let partial = extract_partial_at_end(query);
3460 assert_eq!(partial, Some("\"Cust".to_string()));
3461
3462 let query2 = r#"SELECT "Customer Id" FROM"#;
3464 let partial2 = extract_partial_at_end(query2);
3465 assert_eq!(partial2, None); }
3467
3468 #[test]
3470 fn test_complex_where_parentheses_basic() {
3471 let mut parser =
3473 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3474 let stmt = parser.parse().unwrap();
3475
3476 assert!(stmt.where_clause.is_some());
3477 let where_clause = stmt.where_clause.unwrap();
3478 assert_eq!(where_clause.conditions.len(), 1);
3479
3480 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3482 assert_eq!(op, "OR");
3483 } else {
3484 panic!("Expected BinaryOp with OR");
3485 }
3486 }
3487
3488 #[test]
3489 fn test_complex_where_mixed_and_or_with_parens() {
3490 let mut parser = Parser::new(
3492 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3493 );
3494 let stmt = parser.parse().unwrap();
3495
3496 assert!(stmt.where_clause.is_some());
3497 let where_clause = stmt.where_clause.unwrap();
3498 assert_eq!(where_clause.conditions.len(), 2);
3499
3500 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3502 assert_eq!(op, "OR");
3503 } else {
3504 panic!("Expected first condition to be OR expression");
3505 }
3506
3507 assert!(matches!(
3509 where_clause.conditions[0].connector,
3510 Some(LogicalOp::And)
3511 ));
3512
3513 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3515 assert_eq!(op, ">");
3516 } else {
3517 panic!("Expected second condition to be price > 100");
3518 }
3519 }
3520
3521 #[test]
3522 fn test_complex_where_nested_parentheses() {
3523 let mut parser = Parser::new(
3525 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3526 );
3527 let stmt = parser.parse().unwrap();
3528
3529 assert!(stmt.where_clause.is_some());
3530 let where_clause = stmt.where_clause.unwrap();
3531
3532 assert!(!where_clause.conditions.is_empty());
3534 }
3535
3536 #[test]
3537 fn test_complex_where_multiple_or_groups() {
3538 let mut parser = Parser::new(
3540 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3541 );
3542 let stmt = parser.parse().unwrap();
3543
3544 assert!(stmt.where_clause.is_some());
3545 let where_clause = stmt.where_clause.unwrap();
3546 assert_eq!(where_clause.conditions.len(), 2);
3547
3548 assert!(matches!(
3550 where_clause.conditions[0].connector,
3551 Some(LogicalOp::And)
3552 ));
3553 }
3554
3555 #[test]
3556 fn test_complex_where_with_methods_in_parens() {
3557 let mut parser = Parser::new(
3559 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3560 );
3561 let stmt = parser.parse().unwrap();
3562
3563 assert!(stmt.where_clause.is_some());
3564 let where_clause = stmt.where_clause.unwrap();
3565 assert_eq!(where_clause.conditions.len(), 2);
3566
3567 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3569 assert_eq!(op, "OR");
3570 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3571 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3572 } else {
3573 panic!("Expected OR of method calls");
3574 }
3575 }
3576
3577 #[test]
3578 fn test_complex_where_date_comparisons_with_parens() {
3579 let mut parser = Parser::new(
3581 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3582 );
3583 let stmt = parser.parse().unwrap();
3584
3585 assert!(stmt.where_clause.is_some());
3586 let where_clause = stmt.where_clause.unwrap();
3587 assert_eq!(where_clause.conditions.len(), 2);
3588
3589 assert!(matches!(
3591 where_clause.conditions[0].connector,
3592 Some(LogicalOp::And)
3593 ));
3594 }
3595
3596 #[test]
3597 fn test_complex_where_price_volume_filters() {
3598 let mut parser = Parser::new(
3600 r"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000",
3601 );
3602 let stmt = parser.parse().unwrap();
3603
3604 assert!(stmt.where_clause.is_some());
3605 let where_clause = stmt.where_clause.unwrap();
3606
3607 assert!(!where_clause.conditions.is_empty());
3609 }
3610
3611 #[test]
3612 fn test_complex_where_mixed_string_numeric() {
3613 let mut parser = Parser::new(
3615 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3616 );
3617 let stmt = parser.parse().unwrap();
3618
3619 assert!(stmt.where_clause.is_some());
3620 }
3622
3623 #[test]
3624 fn test_complex_where_triple_nested() {
3625 let mut parser = Parser::new(
3627 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3628 );
3629 let stmt = parser.parse().unwrap();
3630
3631 assert!(stmt.where_clause.is_some());
3632 }
3634
3635 #[test]
3636 fn test_complex_where_single_parens_around_and() {
3637 let mut parser = Parser::new(
3639 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3640 );
3641 let stmt = parser.parse().unwrap();
3642
3643 assert!(stmt.where_clause.is_some());
3644 let where_clause = stmt.where_clause.unwrap();
3645
3646 assert!(!where_clause.conditions.is_empty());
3648 }
3649
3650 #[test]
3652 fn test_format_preserves_simple_parentheses() {
3653 let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
3654 let formatted = format_sql_pretty_compact(query, 5);
3655 let formatted_text = formatted.join(" ");
3656
3657 assert!(formatted_text.contains("(status"));
3659 assert!(formatted_text.contains("\"pending\")"));
3660
3661 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3663 let formatted_parens = formatted_text
3664 .chars()
3665 .filter(|c| *c == '(' || *c == ')')
3666 .count();
3667 assert_eq!(
3668 original_parens, formatted_parens,
3669 "Parentheses should be preserved"
3670 );
3671 }
3672
3673 #[test]
3674 fn test_format_preserves_complex_parentheses() {
3675 let query =
3676 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3677 let formatted = format_sql_pretty_compact(query, 5);
3678 let formatted_text = formatted.join(" ");
3679
3680 assert!(formatted_text.contains("(symbol"));
3682 assert!(formatted_text.contains("\"GOOGL\")"));
3683
3684 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3686 let formatted_parens = formatted_text
3687 .chars()
3688 .filter(|c| *c == '(' || *c == ')')
3689 .count();
3690 assert_eq!(original_parens, formatted_parens);
3691 }
3692
3693 #[test]
3694 fn test_format_preserves_nested_parentheses() {
3695 let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
3696 let formatted = format_sql_pretty_compact(query, 5);
3697 let formatted_text = formatted.join(" ");
3698
3699 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3701 let formatted_parens = formatted_text
3702 .chars()
3703 .filter(|c| *c == '(' || *c == ')')
3704 .count();
3705 assert_eq!(
3706 original_parens, formatted_parens,
3707 "Nested parentheses should be preserved"
3708 );
3709 assert_eq!(original_parens, 4, "Should have 4 parentheses total");
3710 }
3711
3712 #[test]
3713 fn test_format_preserves_method_calls_in_parentheses() {
3714 let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
3715 let formatted = format_sql_pretty_compact(query, 5);
3716 let formatted_text = formatted.join(" ");
3717
3718 assert!(formatted_text.contains("(symbol.StartsWith"));
3720 assert!(formatted_text.contains("StartsWith(\"A\")"));
3721 assert!(formatted_text.contains("StartsWith(\"G\")"));
3722
3723 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3725 let formatted_parens = formatted_text
3726 .chars()
3727 .filter(|c| *c == '(' || *c == ')')
3728 .count();
3729 assert_eq!(original_parens, formatted_parens);
3730 assert_eq!(
3731 original_parens, 6,
3732 "Should have 6 parentheses (1 group + 2 method calls)"
3733 );
3734 }
3735
3736 #[test]
3737 fn test_format_preserves_multiple_groups() {
3738 let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
3739 let formatted = format_sql_pretty_compact(query, 5);
3740 let formatted_text = formatted.join(" ");
3741
3742 assert!(formatted_text.contains("(symbol"));
3744 assert!(formatted_text.contains("(price"));
3745
3746 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3747 let formatted_parens = formatted_text
3748 .chars()
3749 .filter(|c| *c == '(' || *c == ')')
3750 .count();
3751 assert_eq!(original_parens, formatted_parens);
3752 assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
3753 }
3754
3755 #[test]
3756 fn test_format_preserves_date_ranges() {
3757 let query = r"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))";
3758 let formatted = format_sql_pretty_compact(query, 5);
3759 let formatted_text = formatted.join(" ");
3760
3761 assert!(formatted_text.contains("(executionDate"));
3763 assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
3764 assert!(formatted_text.contains("DateTime(2024, 12, 31)"));
3765
3766 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3767 let formatted_parens = formatted_text
3768 .chars()
3769 .filter(|c| *c == '(' || *c == ')')
3770 .count();
3771 assert_eq!(original_parens, formatted_parens);
3772 }
3773
3774 #[test]
3775 fn test_format_multiline_layout() {
3776 let query =
3778 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3779 let formatted = format_sql_pretty_compact(query, 5);
3780
3781 assert!(formatted.len() >= 4, "Should have multiple lines");
3783 assert_eq!(formatted[0], "SELECT");
3784 assert!(formatted[1].trim().starts_with('*'));
3785 assert!(formatted[2].starts_with("FROM"));
3786 assert_eq!(formatted[3], "WHERE");
3787
3788 let where_lines: Vec<_> = formatted.iter().skip(4).collect();
3790 assert!(where_lines.iter().any(|l| l.contains("(symbol")));
3791 assert!(where_lines.iter().any(|l| l.trim() == "AND"));
3792 }
3793
3794 #[test]
3795 fn test_between_simple() {
3796 let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3797 let stmt = parser.parse().expect("Should parse simple BETWEEN");
3798
3799 assert!(stmt.where_clause.is_some());
3800 let where_clause = stmt.where_clause.unwrap();
3801 assert_eq!(where_clause.conditions.len(), 1);
3802
3803 let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3805 assert!(!ast.contains("PARSE ERROR"));
3806 assert!(ast.contains("SelectStatement"));
3807 }
3808
3809 #[test]
3810 fn test_between_in_parentheses() {
3811 let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3812 let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");
3813
3814 assert!(stmt.where_clause.is_some());
3815
3816 let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3818 assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
3819 }
3820
3821 #[test]
3822 fn test_between_with_or() {
3823 let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
3824 let mut parser = Parser::new(query);
3825 let stmt = parser.parse().expect("Should parse BETWEEN with OR");
3826
3827 assert!(stmt.where_clause.is_some());
3828 }
3831
3832 #[test]
3833 fn test_between_with_and() {
3834 let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
3835 let mut parser = Parser::new(query);
3836 let stmt = parser.parse().expect("Should parse BETWEEN with AND");
3837
3838 assert!(stmt.where_clause.is_some());
3839 let where_clause = stmt.where_clause.unwrap();
3840 assert_eq!(where_clause.conditions.len(), 2); }
3842
3843 #[test]
3844 fn test_multiple_between() {
3845 let query =
3846 "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
3847 let mut parser = Parser::new(query);
3848 let stmt = parser
3849 .parse()
3850 .expect("Should parse multiple BETWEEN clauses");
3851
3852 assert!(stmt.where_clause.is_some());
3853 }
3854
3855 #[test]
3856 fn test_between_complex_query() {
3857 let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
3859 let mut parser = Parser::new(query);
3860 let stmt = parser
3861 .parse()
3862 .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");
3863
3864 assert!(stmt.where_clause.is_some());
3865 assert!(stmt.order_by.is_some());
3866
3867 let order_by = stmt.order_by.unwrap();
3868 assert_eq!(order_by.len(), 2);
3869 assert_eq!(order_by[0].column, "Category");
3870 assert!(matches!(order_by[0].direction, SortDirection::Asc));
3871 assert_eq!(order_by[1].column, "price");
3872 assert!(matches!(order_by[1].direction, SortDirection::Desc));
3873 }
3874
3875 #[test]
3876 fn test_between_formatting() {
3877 let expr = SqlExpression::Between {
3878 expr: Box::new(SqlExpression::Column("price".to_string())),
3879 lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
3880 upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
3881 };
3882
3883 let formatted = format_expression(&expr);
3884 assert_eq!(formatted, "price BETWEEN 50 AND 100");
3885
3886 let ast_formatted = format_expression_ast(&expr);
3887 assert!(ast_formatted.contains("Between"));
3888 assert!(ast_formatted.contains("50"));
3889 assert!(ast_formatted.contains("100"));
3890 }
3891
3892 #[test]
3893 fn test_utf8_boundary_safety() {
3894 let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";
3896
3897 for pos in 0..=query_with_unicode.len() {
3899 let result =
3901 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));
3902
3903 assert!(
3904 result.is_ok(),
3905 "Panic at position {pos} in query with Unicode"
3906 );
3907 }
3908
3909 let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
3911 assert!(result.is_ok(), "Panic with position beyond string length");
3912
3913 let pos_in_e = query_with_unicode.find('é').unwrap() + 1; let result =
3916 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
3917 assert!(
3918 result.is_ok(),
3919 "Panic with cursor in middle of UTF-8 character"
3920 );
3921 }
3922}