1use chrono::{Datelike, Local, NaiveDateTime};
2
/// A lexical token produced by [`Lexer`].
///
/// Keywords are recognised case-insensitively. `OrderBy` and `GroupBy` are
/// compound tokens, emitted only when `ORDER` / `GROUP` is followed by `BY`.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Keywords ---
    Select,
    From,
    Where,
    With,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// `ORDER BY` (both words consumed).
    OrderBy,
    /// `GROUP BY` (both words consumed).
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    Over,
    Partition,
    /// A standalone `BY` (e.g. after `PARTITION`).
    By,

    // --- Names and literals ---
    /// A bare word that is not a keyword, or a single character with no other meaning.
    Identifier(String),
    /// A `"double-quoted"` identifier with the quotes removed.
    QuotedIdentifier(String),
    /// A `'single-quoted'` string literal with the quotes removed.
    StringLiteral(String),
    /// A numeric literal kept as source text; may start with `-` and contain an exponent.
    NumberLiteral(String),

    // --- Punctuation and operators ---
    Star,
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// `<>` or `!=`.
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}

/// Character-level tokenizer for the SQL dialect consumed by the parser.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The input, pre-split into chars so positions are char indices.
    input: Vec<char>,
    /// Index of `current_char` within `input`.
    position: usize,
    /// The char at `position`, or `None` once the input is exhausted.
    current_char: Option<char>,
}

impl Lexer {
    /// Creates a lexer positioned at the first character of `input`.
    #[must_use]
    pub fn new(input: &str) -> Self {
        let chars: Vec<char> = input.chars().collect();
        let current = chars.first().copied();
        Self {
            input: chars,
            position: 0,
            current_char: current,
        }
    }

    /// Moves the cursor one character forward.
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character `offset` positions past the cursor without consuming anything.
    fn peek(&self, offset: usize) -> Option<char> {
        self.input.get(self.position + offset).copied()
    }

    /// Skips whitespace only (comments are left in place).
    fn skip_whitespace(&mut self) {
        while self.current_char.is_some_and(char::is_whitespace) {
            self.advance();
        }
    }

    /// Skips any interleaving of whitespace, `-- line` comments and
    /// `/* block */` comments.
    ///
    /// A line comment is consumed through its terminating newline; an
    /// unterminated block comment consumes the rest of the input.
    fn skip_whitespace_and_comments(&mut self) {
        loop {
            self.skip_whitespace();
            match (self.current_char, self.peek(1)) {
                (Some('-'), Some('-')) => {
                    self.advance();
                    self.advance();
                    while let Some(ch) = self.current_char {
                        self.advance();
                        if ch == '\n' {
                            break;
                        }
                    }
                }
                (Some('/'), Some('*')) => {
                    self.advance();
                    self.advance();
                    while let Some(ch) = self.current_char {
                        if ch == '*' && self.peek(1) == Some('/') {
                            self.advance();
                            self.advance();
                            break;
                        }
                        self.advance();
                    }
                }
                _ => break,
            }
        }
    }

    /// Reads a run of alphanumeric characters and underscores.
    fn read_identifier(&mut self) -> String {
        let mut result = String::new();
        while let Some(ch) = self.current_char {
            if !(ch.is_alphanumeric() || ch == '_') {
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads a quoted run that starts at the current (opening-quote) character
    /// and returns its contents without the surrounding quotes.
    ///
    /// A doubled quote character inside the run (`''` in `'it''s'`, `""` in a
    /// quoted identifier) is the SQL escape for a literal quote and yields a
    /// single quote character. An unterminated run consumes the rest of the input.
    fn read_string(&mut self) -> String {
        let mut result = String::new();
        let Some(quote_char) = self.current_char else {
            return result;
        };
        self.advance(); // opening quote
        while let Some(ch) = self.current_char {
            if ch == quote_char {
                if self.peek(1) == Some(quote_char) {
                    // Escaped quote: keep one, skip both.
                    result.push(quote_char);
                    self.advance();
                    self.advance();
                    continue;
                }
                self.advance(); // closing quote
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads an unsigned numeric literal as text.
    ///
    /// Accepts digits and dots, optionally followed by a single exponent
    /// marker (`e`/`E`), an optional sign and exponent digits. Multiple dots
    /// are not rejected here.
    fn read_number(&mut self) -> String {
        let mut result = String::new();
        let mut has_e = false;

        while let Some(ch) = self.current_char {
            if !has_e && (ch.is_numeric() || ch == '.') {
                result.push(ch);
                self.advance();
            } else if (ch == 'e' || ch == 'E') && !has_e && !result.is_empty() {
                result.push(ch);
                self.advance();
                has_e = true;

                if let Some(sign) = self.current_char {
                    if sign == '+' || sign == '-' {
                        result.push(sign);
                        self.advance();
                    }
                }

                while let Some(digit) = self.current_char {
                    if digit.is_numeric() {
                        result.push(digit);
                        self.advance();
                    } else {
                        break;
                    }
                }
                break;
            } else {
                break;
            }
        }
        result
    }

    /// Returns the next token, skipping whitespace and comments first.
    ///
    /// Returns [`Token::Eof`] at the end of input (and on every later call).
    /// Keywords are recognised case-insensitively; other words become
    /// [`Token::Identifier`] with their original spelling. A `-` immediately
    /// followed by a digit always starts a negative number literal, even after
    /// an operand (so `a-1` lexes as `a`, `-1`). Any character with no other
    /// meaning is returned as a one-character identifier.
    pub fn next_token(&mut self) -> Token {
        // After this, the cursor never sits on `--` or `/*`, so those prefixes
        // need no handling in the match below.
        self.skip_whitespace_and_comments();

        match self.current_char {
            None => Token::Eof,
            Some('*') => {
                self.advance();
                Token::Star
            }
            Some('+') => {
                self.advance();
                Token::Plus
            }
            Some('/') => {
                self.advance();
                Token::Divide
            }
            Some('.') => {
                self.advance();
                Token::Dot
            }
            Some(',') => {
                self.advance();
                Token::Comma
            }
            Some('(') => {
                self.advance();
                Token::LeftParen
            }
            Some(')') => {
                self.advance();
                Token::RightParen
            }
            Some('=') => {
                self.advance();
                Token::Equal
            }
            Some('<') => {
                self.advance();
                match self.current_char {
                    Some('=') => {
                        self.advance();
                        Token::LessThanOrEqual
                    }
                    Some('>') => {
                        self.advance();
                        Token::NotEqual
                    }
                    _ => Token::LessThan,
                }
            }
            Some('>') => {
                self.advance();
                if self.current_char == Some('=') {
                    self.advance();
                    Token::GreaterThanOrEqual
                } else {
                    Token::GreaterThan
                }
            }
            Some('!') if self.peek(1) == Some('=') => {
                self.advance();
                self.advance();
                Token::NotEqual
            }
            Some('"') => Token::QuotedIdentifier(self.read_string()),
            Some('\'') => Token::StringLiteral(self.read_string()),
            Some('-') if self.peek(1).is_some_and(char::is_numeric) => {
                self.advance();
                let num = self.read_number();
                Token::NumberLiteral(format!("-{num}"))
            }
            Some('-') => {
                self.advance();
                Token::Minus
            }
            Some(ch) if ch.is_numeric() => Token::NumberLiteral(self.read_number()),
            Some(ch) if ch.is_alphabetic() || ch == '_' => {
                let ident = self.read_identifier();
                match ident.to_uppercase().as_str() {
                    "SELECT" => Token::Select,
                    "FROM" => Token::From,
                    "WHERE" => Token::Where,
                    "WITH" => Token::With,
                    "AND" => Token::And,
                    "OR" => Token::Or,
                    "IN" => Token::In,
                    "NOT" => Token::Not,
                    "BETWEEN" => Token::Between,
                    "LIKE" => Token::Like,
                    "IS" => Token::Is,
                    "NULL" => Token::Null,
                    "ORDER" if self.peek_keyword("BY") => {
                        // Skip exactly what `peek_keyword` skipped (including
                        // comments) so the BY itself is consumed here.
                        self.skip_whitespace_and_comments();
                        self.read_identifier();
                        Token::OrderBy
                    }
                    "GROUP" if self.peek_keyword("BY") => {
                        self.skip_whitespace_and_comments();
                        self.read_identifier();
                        Token::GroupBy
                    }
                    "HAVING" => Token::Having,
                    "AS" => Token::As,
                    "ASC" => Token::Asc,
                    "DESC" => Token::Desc,
                    "LIMIT" => Token::Limit,
                    "OFFSET" => Token::Offset,
                    "DATETIME" => Token::DateTime,
                    "CASE" => Token::Case,
                    "WHEN" => Token::When,
                    "THEN" => Token::Then,
                    "ELSE" => Token::Else,
                    "END" => Token::End,
                    "DISTINCT" => Token::Distinct,
                    "OVER" => Token::Over,
                    "PARTITION" => Token::Partition,
                    "BY" => Token::By,
                    _ => Token::Identifier(ident),
                }
            }
            Some(ch) => {
                self.advance();
                Token::Identifier(ch.to_string())
            }
        }
    }

    /// Reports whether the next word (after whitespace and comments) equals
    /// `keyword` case-insensitively, without consuming any input.
    ///
    /// `keyword` must be given in upper case.
    fn peek_keyword(&mut self, keyword: &str) -> bool {
        let saved_pos = self.position;
        let saved_char = self.current_char;

        self.skip_whitespace_and_comments();
        let next_word = self.read_identifier();
        let matches = next_word.to_uppercase() == keyword;

        self.position = saved_pos;
        self.current_char = saved_char;

        matches
    }

    /// Current cursor position as a char index into the input.
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.position
    }

    /// Tokenizes the remaining input, including the trailing [`Token::Eof`].
    pub fn tokenize_all(&mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let finished = matches!(token, Token::Eof);
            tokens.push(token);
            if finished {
                break;
            }
        }
        tokens
    }

    /// Tokenizes the remaining input as `(start, end, token)` triples, where
    /// `start..end` is the token's char-index span. `Eof` is not included.
    pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
        let mut tokens = Vec::new();
        loop {
            self.skip_whitespace_and_comments();
            let start_pos = self.position;
            let token = self.next_token();
            let end_pos = self.position;

            if matches!(token, Token::Eof) {
                break;
            }
            tokens.push((start_pos, end_pos, token));
        }
        tokens
    }
}
419
420#[derive(Debug, Clone)]
422pub enum SqlExpression {
423 Column(String),
424 StringLiteral(String),
425 NumberLiteral(String),
426 BooleanLiteral(bool),
427 Null, DateTimeConstructor {
429 year: i32,
430 month: u32,
431 day: u32,
432 hour: Option<u32>,
433 minute: Option<u32>,
434 second: Option<u32>,
435 },
436 DateTimeToday {
437 hour: Option<u32>,
438 minute: Option<u32>,
439 second: Option<u32>,
440 },
441 MethodCall {
442 object: String,
443 method: String,
444 args: Vec<SqlExpression>,
445 },
446 ChainedMethodCall {
447 base: Box<SqlExpression>,
448 method: String,
449 args: Vec<SqlExpression>,
450 },
451 FunctionCall {
452 name: String,
453 args: Vec<SqlExpression>,
454 },
455 WindowFunction {
456 name: String,
457 args: Vec<SqlExpression>,
458 window_spec: WindowSpec,
459 },
460 BinaryOp {
461 left: Box<SqlExpression>,
462 op: String,
463 right: Box<SqlExpression>,
464 },
465 InList {
466 expr: Box<SqlExpression>,
467 values: Vec<SqlExpression>,
468 },
469 NotInList {
470 expr: Box<SqlExpression>,
471 values: Vec<SqlExpression>,
472 },
473 Between {
474 expr: Box<SqlExpression>,
475 lower: Box<SqlExpression>,
476 upper: Box<SqlExpression>,
477 },
478 Not {
479 expr: Box<SqlExpression>,
480 },
481 CaseExpression {
482 when_branches: Vec<WhenBranch>,
483 else_branch: Option<Box<SqlExpression>>,
484 },
485}
486
487#[derive(Debug, Clone)]
488pub struct WhenBranch {
489 pub condition: Box<SqlExpression>,
490 pub result: Box<SqlExpression>,
491}
492
493#[derive(Debug, Clone)]
494pub struct WhereClause {
495 pub conditions: Vec<Condition>,
496}
497
498#[derive(Debug, Clone)]
499pub struct Condition {
500 pub expr: SqlExpression,
501 pub connector: Option<LogicalOp>, }
503
/// Connector between consecutive `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
509
/// Sort order of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
515
516#[derive(Debug, Clone)]
517pub struct OrderByColumn {
518 pub column: String,
519 pub direction: SortDirection,
520}
521
522#[derive(Debug, Clone)]
523pub struct WindowSpec {
524 pub partition_by: Vec<String>,
525 pub order_by: Vec<OrderByColumn>,
526}
527
528#[derive(Debug, Clone)]
530pub enum SelectItem {
531 Column(String),
533 Expression { expr: SqlExpression, alias: String },
535 Star,
537}
538
539#[derive(Debug, Clone)]
540pub struct SelectStatement {
541 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
544 pub from_subquery: Option<Box<SelectStatement>>, pub from_alias: Option<String>, pub where_clause: Option<WhereClause>,
547 pub order_by: Option<Vec<OrderByColumn>>,
548 pub group_by: Option<Vec<String>>,
549 pub having: Option<SqlExpression>, pub limit: Option<usize>,
551 pub offset: Option<usize>,
552 pub ctes: Vec<CTE>, }
554
555#[derive(Debug, Clone)]
557pub struct CTE {
558 pub name: String,
559 pub column_list: Option<Vec<String>>, pub query: SelectStatement,
561}
562
563#[derive(Debug, Clone)]
565pub enum TableSource {
566 Table(String), DerivedTable {
568 query: Box<SelectStatement>,
570 alias: String, },
572}
573
/// Options for [`Parser`]; `Default` turns every option off.
#[derive(Default)]
pub struct ParserConfig {
    /// Requests case-insensitive handling.
    /// NOTE(review): not read by any parser code visible here — confirm where it is honoured.
    pub case_insensitive: bool,
}
578
579pub struct Parser {
580 lexer: Lexer,
581 current_token: Token,
582 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
586 config: ParserConfig, }
588
589impl Parser {
590 #[must_use]
591 pub fn new(input: &str) -> Self {
592 let mut lexer = Lexer::new(input);
593 let current_token = lexer.next_token();
594 Self {
595 lexer,
596 current_token,
597 in_method_args: false,
598 columns: Vec::new(),
599 paren_depth: 0,
600 config: ParserConfig::default(),
601 }
602 }
603
604 #[must_use]
605 pub fn with_config(input: &str, config: ParserConfig) -> Self {
606 let mut lexer = Lexer::new(input);
607 let current_token = lexer.next_token();
608 Self {
609 lexer,
610 current_token,
611 in_method_args: false,
612 columns: Vec::new(),
613 paren_depth: 0,
614 config,
615 }
616 }
617
618 #[must_use]
619 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
620 self.columns = columns;
621 self
622 }
623
624 fn consume(&mut self, expected: Token) -> Result<(), String> {
625 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
626 match &expected {
628 Token::LeftParen => self.paren_depth += 1,
629 Token::RightParen => {
630 self.paren_depth -= 1;
631 if self.paren_depth < 0 {
633 return Err(
634 "Unexpected closing parenthesis - no matching opening parenthesis"
635 .to_string(),
636 );
637 }
638 }
639 _ => {}
640 }
641
642 self.current_token = self.lexer.next_token();
643 Ok(())
644 } else {
645 let error_msg = match (&expected, &self.current_token) {
647 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
648 format!(
649 "Unclosed parenthesis - missing {} closing parenthes{}",
650 self.paren_depth,
651 if self.paren_depth == 1 { "is" } else { "es" }
652 )
653 }
654 (Token::RightParen, _) if self.paren_depth > 0 => {
655 format!(
656 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
657 self.current_token,
658 self.paren_depth,
659 if self.paren_depth == 1 { "is" } else { "es" }
660 )
661 }
662 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
663 };
664 Err(error_msg)
665 }
666 }
667
668 fn advance(&mut self) {
669 match &self.current_token {
671 Token::LeftParen => self.paren_depth += 1,
672 Token::RightParen => {
673 self.paren_depth -= 1;
674 }
677 _ => {}
678 }
679 self.current_token = self.lexer.next_token();
680 }
681
682 pub fn parse(&mut self) -> Result<SelectStatement, String> {
683 if matches!(self.current_token, Token::With) {
685 self.parse_with_clause()
686 } else {
687 self.parse_select_statement()
688 }
689 }
690
691 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
692 self.consume(Token::With)?;
693
694 let mut ctes = Vec::new();
695
696 loop {
698 let name = match &self.current_token {
700 Token::Identifier(name) => name.clone(),
701 _ => return Err("Expected CTE name after WITH".to_string()),
702 };
703 self.advance();
704
705 let column_list = if matches!(self.current_token, Token::LeftParen) {
707 self.advance();
708 let cols = self.parse_identifier_list()?;
709 self.consume(Token::RightParen)?;
710 Some(cols)
711 } else {
712 None
713 };
714
715 self.consume(Token::As)?;
717
718 self.consume(Token::LeftParen)?;
720
721 let query = self.parse_select_statement_inner()?;
723
724 self.consume(Token::RightParen)?;
726
727 ctes.push(CTE {
728 name,
729 column_list,
730 query,
731 });
732
733 if !matches!(self.current_token, Token::Comma) {
735 break;
736 }
737 self.advance();
738 }
739
740 let mut main_query = self.parse_select_statement()?;
742 main_query.ctes = ctes;
743
744 Ok(main_query)
745 }
746
747 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
748 let result = self.parse_select_statement_inner()?;
749
750 if self.paren_depth > 0 {
752 return Err(format!(
753 "Unclosed parenthesis - missing {} closing parenthes{}",
754 self.paren_depth,
755 if self.paren_depth == 1 { "is" } else { "es" }
756 ));
757 } else if self.paren_depth < 0 {
758 return Err(
759 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
760 );
761 }
762
763 Ok(result)
764 }
765
766 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
767 self.consume(Token::Select)?;
768
769 let select_items = self.parse_select_items()?;
771
772 let columns = select_items
774 .iter()
775 .map(|item| match item {
776 SelectItem::Star => "*".to_string(),
777 SelectItem::Column(name) => name.clone(),
778 SelectItem::Expression { alias, .. } => alias.clone(),
779 })
780 .collect();
781
782 let (from_table, from_subquery, from_alias) = if matches!(self.current_token, Token::From) {
784 self.advance();
785
786 if matches!(self.current_token, Token::LeftParen) {
788 self.advance();
789
790 let subquery = self.parse_select_statement_inner()?;
792
793 self.consume(Token::RightParen)?;
794
795 let alias = if matches!(self.current_token, Token::As) {
797 self.advance();
798 match &self.current_token {
799 Token::Identifier(name) => {
800 let alias = name.clone();
801 self.advance();
802 alias
803 }
804 _ => return Err("Expected alias name after AS".to_string()),
805 }
806 } else {
807 match &self.current_token {
809 Token::Identifier(name) => {
810 let alias = name.clone();
811 self.advance();
812 alias
813 }
814 _ => {
815 return Err(
816 "Subquery in FROM must have an alias (e.g., AS t)".to_string()
817 )
818 }
819 }
820 };
821
822 (None, Some(Box::new(subquery)), Some(alias))
823 } else {
824 match &self.current_token {
826 Token::Identifier(table) => {
827 let table_name = table.clone();
828 self.advance();
829
830 let alias = if matches!(self.current_token, Token::As) {
832 self.advance();
833 match &self.current_token {
834 Token::Identifier(name) => {
835 let alias = name.clone();
836 self.advance();
837 Some(alias)
838 }
839 _ => return Err("Expected alias name after AS".to_string()),
840 }
841 } else if let Token::Identifier(name) = &self.current_token {
842 let alias = name.clone();
844 self.advance();
845 Some(alias)
846 } else {
847 None
848 };
849
850 (Some(table_name), None, alias)
851 }
852 Token::QuotedIdentifier(table) => {
853 let table_name = table.clone();
855 self.advance();
856
857 let alias = if matches!(self.current_token, Token::As) {
859 self.advance();
860 match &self.current_token {
861 Token::Identifier(name) => {
862 let alias = name.clone();
863 self.advance();
864 Some(alias)
865 }
866 _ => return Err("Expected alias name after AS".to_string()),
867 }
868 } else if let Token::Identifier(name) = &self.current_token {
869 let alias = name.clone();
871 self.advance();
872 Some(alias)
873 } else {
874 None
875 };
876
877 (Some(table_name), None, alias)
878 }
879 _ => return Err("Expected table name or subquery after FROM".to_string()),
880 }
881 }
882 } else {
883 (None, None, None)
884 };
885
886 let where_clause = if matches!(self.current_token, Token::Where) {
887 self.advance();
888 Some(self.parse_where_clause()?)
889 } else {
890 None
891 };
892
893 let order_by = if matches!(self.current_token, Token::OrderBy) {
894 self.advance();
895 Some(self.parse_order_by_list()?)
896 } else {
897 None
898 };
899
900 let group_by = if matches!(self.current_token, Token::GroupBy) {
901 self.advance();
902 Some(self.parse_identifier_list()?)
903 } else {
904 None
905 };
906
907 let having = if matches!(self.current_token, Token::Having) {
909 if group_by.is_none() {
910 return Err("HAVING clause requires GROUP BY".to_string());
911 }
912 self.advance();
913 Some(self.parse_expression()?)
914 } else {
915 None
916 };
917
918 let limit = if matches!(self.current_token, Token::Limit) {
920 self.advance();
921 match &self.current_token {
922 Token::NumberLiteral(num) => {
923 let limit_val = num
924 .parse::<usize>()
925 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
926 self.advance();
927 Some(limit_val)
928 }
929 _ => return Err("Expected number after LIMIT".to_string()),
930 }
931 } else {
932 None
933 };
934
935 let offset = if matches!(self.current_token, Token::Offset) {
937 self.advance();
938 match &self.current_token {
939 Token::NumberLiteral(num) => {
940 let offset_val = num
941 .parse::<usize>()
942 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
943 self.advance();
944 Some(offset_val)
945 }
946 _ => return Err("Expected number after OFFSET".to_string()),
947 }
948 } else {
949 None
950 };
951
952 Ok(SelectStatement {
953 columns,
954 select_items,
955 from_table,
956 from_subquery,
957 from_alias,
958 where_clause,
959 order_by,
960 group_by,
961 having,
962 limit,
963 offset,
964 ctes: Vec::new(), })
966 }
967
968 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
969 let mut columns = Vec::new();
970
971 if matches!(self.current_token, Token::Star) {
972 columns.push("*".to_string());
973 self.advance();
974 } else {
975 loop {
976 match &self.current_token {
977 Token::Identifier(col) => {
978 columns.push(col.clone());
979 self.advance();
980 }
981 Token::QuotedIdentifier(col) => {
982 columns.push(col.clone());
984 self.advance();
985 }
986 _ => return Err("Expected column name".to_string()),
987 }
988
989 if matches!(self.current_token, Token::Comma) {
990 self.advance();
991 } else {
992 break;
993 }
994 }
995 }
996
997 Ok(columns)
998 }
999
1000 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
1002 let mut items = Vec::new();
1003
1004 loop {
1005 if matches!(self.current_token, Token::Star) {
1008 items.push(SelectItem::Star);
1016 self.advance();
1017 } else {
1018 let expr = self.parse_comparison()?; let alias = if matches!(self.current_token, Token::As) {
1023 self.advance();
1024 match &self.current_token {
1025 Token::Identifier(alias_name) => {
1026 let alias = alias_name.clone();
1027 self.advance();
1028 alias
1029 }
1030 Token::QuotedIdentifier(alias_name) => {
1031 let alias = alias_name.clone();
1032 self.advance();
1033 alias
1034 }
1035 _ => return Err("Expected alias name after AS".to_string()),
1036 }
1037 } else {
1038 match &expr {
1040 SqlExpression::Column(col_name) => col_name.clone(),
1041 _ => format!("expr_{}", items.len() + 1), }
1043 };
1044
1045 let item = match expr {
1047 SqlExpression::Column(col_name) if alias == col_name => {
1048 SelectItem::Column(col_name)
1050 }
1051 _ => {
1052 SelectItem::Expression { expr, alias }
1054 }
1055 };
1056
1057 items.push(item);
1058 }
1059
1060 if matches!(self.current_token, Token::Comma) {
1062 self.advance();
1063 } else {
1064 break;
1065 }
1066 }
1067
1068 Ok(items)
1069 }
1070
1071 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1072 let mut identifiers = Vec::new();
1073
1074 loop {
1075 match &self.current_token {
1076 Token::Identifier(id) => {
1077 identifiers.push(id.clone());
1078 self.advance();
1079 }
1080 Token::QuotedIdentifier(id) => {
1081 identifiers.push(id.clone());
1083 self.advance();
1084 }
1085 _ => return Err("Expected identifier".to_string()),
1086 }
1087
1088 if matches!(self.current_token, Token::Comma) {
1089 self.advance();
1090 } else {
1091 break;
1092 }
1093 }
1094
1095 Ok(identifiers)
1096 }
1097
1098 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1099 let mut partition_by = Vec::new();
1100 let mut order_by = Vec::new();
1101
1102 if matches!(self.current_token, Token::Partition) {
1104 self.advance(); if !matches!(self.current_token, Token::By) {
1106 return Err("Expected BY after PARTITION".to_string());
1107 }
1108 self.advance(); partition_by = self.parse_identifier_list()?;
1112 }
1113
1114 if matches!(self.current_token, Token::OrderBy) {
1116 self.advance(); order_by = self.parse_order_by_list()?;
1118 } else if let Token::Identifier(s) = &self.current_token {
1119 if s.to_uppercase() == "ORDER" {
1120 self.advance(); if !matches!(self.current_token, Token::By) {
1123 return Err("Expected BY after ORDER".to_string());
1124 }
1125 self.advance(); order_by = self.parse_order_by_list()?;
1127 }
1128 }
1129
1130 Ok(WindowSpec {
1131 partition_by,
1132 order_by,
1133 })
1134 }
1135
1136 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
1137 let mut order_columns = Vec::new();
1138
1139 loop {
1140 let column = match &self.current_token {
1141 Token::Identifier(id) => {
1142 let col = id.clone();
1143 self.advance();
1144 col
1145 }
1146 Token::QuotedIdentifier(id) => {
1147 let col = id.clone();
1148 self.advance();
1149 col
1150 }
1151 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
1152 let col = num.clone();
1154 self.advance();
1155 col
1156 }
1157 _ => return Err("Expected column name in ORDER BY".to_string()),
1158 };
1159
1160 let direction = match &self.current_token {
1162 Token::Asc => {
1163 self.advance();
1164 SortDirection::Asc
1165 }
1166 Token::Desc => {
1167 self.advance();
1168 SortDirection::Desc
1169 }
1170 _ => SortDirection::Asc, };
1172
1173 order_columns.push(OrderByColumn { column, direction });
1174
1175 if matches!(self.current_token, Token::Comma) {
1176 self.advance();
1177 } else {
1178 break;
1179 }
1180 }
1181
1182 Ok(order_columns)
1183 }
1184
1185 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1186 let mut conditions = Vec::new();
1187
1188 loop {
1189 let expr = self.parse_expression()?;
1190
1191 let connector = match &self.current_token {
1192 Token::And => {
1193 self.advance();
1194 Some(LogicalOp::And)
1195 }
1196 Token::Or => {
1197 self.advance();
1198 Some(LogicalOp::Or)
1199 }
1200 Token::RightParen if self.paren_depth <= 0 => {
1201 return Err(
1203 "Unexpected closing parenthesis - no matching opening parenthesis"
1204 .to_string(),
1205 );
1206 }
1207 _ => None,
1208 };
1209
1210 conditions.push(Condition {
1211 expr,
1212 connector: connector.clone(),
1213 });
1214
1215 if connector.is_none() {
1216 break;
1217 }
1218 }
1219
1220 Ok(WhereClause { conditions })
1221 }
1222
1223 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1224 let mut left = self.parse_comparison()?;
1225
1226 if let Some(op) = self.get_binary_op() {
1229 self.advance();
1230 let right = self.parse_expression()?;
1231 left = SqlExpression::BinaryOp {
1232 left: Box::new(left),
1233 op,
1234 right: Box::new(right),
1235 };
1236 }
1237
1238 if matches!(self.current_token, Token::In) {
1240 self.advance();
1241 self.consume(Token::LeftParen)?;
1242 let values = self.parse_expression_list()?;
1243 self.consume(Token::RightParen)?;
1244
1245 left = SqlExpression::InList {
1246 expr: Box::new(left),
1247 values,
1248 };
1249 }
1250
1251 Ok(left)
1255 }
1256
1257 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1258 let mut left = self.parse_additive()?;
1259
1260 if matches!(self.current_token, Token::Between) {
1262 self.advance(); let lower = self.parse_primary()?;
1264 self.consume(Token::And)?; let upper = self.parse_primary()?;
1266
1267 return Ok(SqlExpression::Between {
1268 expr: Box::new(left),
1269 lower: Box::new(lower),
1270 upper: Box::new(upper),
1271 });
1272 }
1273
1274 if matches!(self.current_token, Token::Not) {
1276 self.advance(); if matches!(self.current_token, Token::In) {
1278 self.advance(); self.consume(Token::LeftParen)?;
1280 let values = self.parse_expression_list()?;
1281 self.consume(Token::RightParen)?;
1282
1283 return Ok(SqlExpression::NotInList {
1284 expr: Box::new(left),
1285 values,
1286 });
1287 }
1288 return Err("Expected IN after NOT".to_string());
1289 }
1290
1291 if matches!(self.current_token, Token::Is) {
1293 self.advance(); if matches!(self.current_token, Token::Not) {
1295 self.advance(); if matches!(self.current_token, Token::Null) {
1297 self.advance(); left = SqlExpression::BinaryOp {
1299 left: Box::new(left),
1300 op: "IS NOT NULL".to_string(),
1301 right: Box::new(SqlExpression::Null),
1302 };
1303 } else {
1304 return Err("Expected NULL after IS NOT".to_string());
1305 }
1306 } else if matches!(self.current_token, Token::Null) {
1307 self.advance(); left = SqlExpression::BinaryOp {
1309 left: Box::new(left),
1310 op: "IS NULL".to_string(),
1311 right: Box::new(SqlExpression::Null),
1312 };
1313 } else {
1314 return Err("Expected NULL or NOT after IS".to_string());
1315 }
1316 }
1317 else if let Some(op) = self.get_binary_op() {
1319 self.advance();
1320 let right = self.parse_additive()?;
1321 left = SqlExpression::BinaryOp {
1322 left: Box::new(left),
1323 op,
1324 right: Box::new(right),
1325 };
1326 }
1327
1328 Ok(left)
1329 }
1330
1331 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1332 let mut left = self.parse_multiplicative()?;
1333
1334 while matches!(self.current_token, Token::Plus | Token::Minus) {
1335 let op = match self.current_token {
1336 Token::Plus => "+",
1337 Token::Minus => "-",
1338 _ => unreachable!(),
1339 };
1340 self.advance();
1341 let right = self.parse_multiplicative()?;
1342 left = SqlExpression::BinaryOp {
1343 left: Box::new(left),
1344 op: op.to_string(),
1345 right: Box::new(right),
1346 };
1347 }
1348
1349 Ok(left)
1350 }
1351
1352 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1353 let mut left = self.parse_primary()?;
1354
1355 while matches!(self.current_token, Token::Dot) {
1357 self.advance();
1358 if let Token::Identifier(method) = &self.current_token {
1359 let method_name = method.clone();
1360 self.advance();
1361
1362 if matches!(self.current_token, Token::LeftParen) {
1363 self.advance();
1364 let args = self.parse_method_args()?;
1365 self.consume(Token::RightParen)?;
1366
1367 match left {
1369 SqlExpression::Column(obj) => {
1370 left = SqlExpression::MethodCall {
1372 object: obj,
1373 method: method_name,
1374 args,
1375 };
1376 }
1377 SqlExpression::MethodCall { .. }
1378 | SqlExpression::ChainedMethodCall { .. } => {
1379 left = SqlExpression::ChainedMethodCall {
1381 base: Box::new(left),
1382 method: method_name,
1383 args,
1384 };
1385 }
1386 _ => {
1387 left = SqlExpression::ChainedMethodCall {
1389 base: Box::new(left),
1390 method: method_name,
1391 args,
1392 };
1393 }
1394 }
1395 } else {
1396 return Err(format!("Expected '(' after method name '{method_name}'"));
1397 }
1398 } else {
1399 return Err("Expected method name after '.'".to_string());
1400 }
1401 }
1402
1403 while matches!(self.current_token, Token::Star | Token::Divide) {
1404 let op = match self.current_token {
1405 Token::Star => "*",
1406 Token::Divide => "/",
1407 _ => unreachable!(),
1408 };
1409 self.advance();
1410 let right = self.parse_primary()?;
1411 left = SqlExpression::BinaryOp {
1412 left: Box::new(left),
1413 op: op.to_string(),
1414 right: Box::new(right),
1415 };
1416 }
1417
1418 Ok(left)
1419 }
1420
1421 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1422 let mut left = self.parse_logical_and()?;
1423
1424 while matches!(self.current_token, Token::Or) {
1425 self.advance();
1426 let right = self.parse_logical_and()?;
1427 left = SqlExpression::BinaryOp {
1431 left: Box::new(left),
1432 op: "OR".to_string(),
1433 right: Box::new(right),
1434 };
1435 }
1436
1437 Ok(left)
1438 }
1439
1440 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1441 let mut left = self.parse_expression()?;
1442
1443 while matches!(self.current_token, Token::And) {
1444 self.advance();
1445 let right = self.parse_expression()?;
1446 left = SqlExpression::BinaryOp {
1448 left: Box::new(left),
1449 op: "AND".to_string(),
1450 right: Box::new(right),
1451 };
1452 }
1453
1454 Ok(left)
1455 }
1456
1457 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1458 self.consume(Token::Case)?;
1460
1461 let mut when_branches = Vec::new();
1462
1463 while matches!(self.current_token, Token::When) {
1465 self.advance(); let condition = self.parse_expression()?;
1469
1470 self.consume(Token::Then)?;
1472
1473 let result = self.parse_expression()?;
1475
1476 when_branches.push(WhenBranch {
1477 condition: Box::new(condition),
1478 result: Box::new(result),
1479 });
1480 }
1481
1482 if when_branches.is_empty() {
1484 return Err("CASE expression must have at least one WHEN clause".to_string());
1485 }
1486
1487 let else_branch = if matches!(self.current_token, Token::Else) {
1489 self.advance(); Some(Box::new(self.parse_expression()?))
1491 } else {
1492 None
1493 };
1494
1495 self.consume(Token::End)?;
1497
1498 Ok(SqlExpression::CaseExpression {
1499 when_branches,
1500 else_branch,
1501 })
1502 }
1503
    /// Parses a primary (atomic) expression.
    ///
    /// Handles the following forms:
    /// - number and string literals, `NULL`, `TRUE`/`FALSE`;
    /// - column references (plain or quoted);
    /// - function calls and `OVER (...)` window functions;
    /// - `CASE ... END` and `DateTime(...)`;
    /// - parenthesized sub-expressions;
    /// - `NOT expr` / `NOT expr IN (...)`;
    /// - a bare `*`.
    ///
    /// # Errors
    /// Returns an error string when the current token cannot start a primary
    /// expression, or when a nested construct is malformed.
    fn parse_primary(&mut self) -> Result<SqlExpression, String> {
        // A numeric token that exactly matches one of the known column names is
        // treated as a column reference instead of a number literal.
        if let Token::NumberLiteral(num_str) = &self.current_token {
            if self.columns.iter().any(|col| col == num_str) {
                let expr = SqlExpression::Column(num_str.clone());
                self.advance();
                return Ok(expr);
            }
        }

        match &self.current_token {
            Token::Case => {
                self.parse_case_expression()
            }
            Token::DateTime => {
                // Accepted forms:
                //   DateTime()                                  -> DateTimeToday (no time parts)
                //   DateTime(year, month, day[, hour[, minute[, second]]])
                self.advance();
                self.consume(Token::LeftParen)?;

                if matches!(&self.current_token, Token::RightParen) {
                    self.advance();
                    return Ok(SqlExpression::DateTimeToday {
                        hour: None,
                        minute: None,
                        second: None,
                    });
                }

                let year = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<i32>().map_err(|_| "Invalid year")?
                } else {
                    return Err("Expected year in DateTime constructor".to_string());
                };
                self.advance();
                self.consume(Token::Comma)?;

                let month = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<u32>().map_err(|_| "Invalid month")?
                } else {
                    return Err("Expected month in DateTime constructor".to_string());
                };
                self.advance();
                self.consume(Token::Comma)?;

                let day = if let Token::NumberLiteral(n) = &self.current_token {
                    n.parse::<u32>().map_err(|_| "Invalid day")?
                } else {
                    return Err("Expected day in DateTime constructor".to_string());
                };
                self.advance();

                // Optional time parts. Each is only read if the previous one
                // was present. If a comma is followed by a non-number, the comma
                // is consumed and the closing-paren check below reports the
                // error.
                let mut hour = None;
                let mut minute = None;
                let mut second = None;

                if matches!(&self.current_token, Token::Comma) {
                    self.advance();
                    if let Token::NumberLiteral(n) = &self.current_token {
                        hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
                        self.advance();

                        if matches!(&self.current_token, Token::Comma) {
                            self.advance();
                            if let Token::NumberLiteral(n) = &self.current_token {
                                minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
                                self.advance();

                                if matches!(&self.current_token, Token::Comma) {
                                    self.advance();
                                    if let Token::NumberLiteral(n) = &self.current_token {
                                        second =
                                            Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
                                        self.advance();
                                    }
                                }
                            }
                        }
                    }
                }

                self.consume(Token::RightParen)?;
                Ok(SqlExpression::DateTimeConstructor {
                    year,
                    month,
                    day,
                    hour,
                    minute,
                    second,
                })
            }
            Token::Identifier(id) => {
                // Keyword-like identifiers are compared case-insensitively.
                let id_upper = id.to_uppercase();
                let id_clone = id.clone();

                if id_upper == "TRUE" {
                    self.advance();
                    return Ok(SqlExpression::BooleanLiteral(true));
                } else if id_upper == "FALSE" {
                    self.advance();
                    return Ok(SqlExpression::BooleanLiteral(false));
                }

                self.advance();

                // `name(` starts a function call. The function name is stored
                // upper-cased, while plain column names keep their original case.
                if matches!(self.current_token, Token::LeftParen) {
                    self.advance();
                    let args = self.parse_function_args()?;
                    self.consume(Token::RightParen)?;

                    // `name(...) OVER (...)` turns the call into a window function.
                    if matches!(self.current_token, Token::Over) {
                        self.advance();
                        self.consume(Token::LeftParen)?;
                        let window_spec = self.parse_window_spec()?;
                        self.consume(Token::RightParen)?;
                        return Ok(SqlExpression::WindowFunction {
                            name: id_upper,
                            args,
                            window_spec,
                        });
                    }

                    return Ok(SqlExpression::FunctionCall {
                        name: id_upper,
                        args,
                    });
                }

                Ok(SqlExpression::Column(id_clone))
            }
            Token::QuotedIdentifier(id) => {
                // Inside method-call arguments (`in_method_args` is set by
                // `parse_method_args`), a double-quoted token is a string value,
                // not a column.
                let expr = if self.in_method_args {
                    SqlExpression::StringLiteral(id.clone())
                } else {
                    SqlExpression::Column(id.clone())
                };
                self.advance();
                Ok(expr)
            }
            Token::StringLiteral(s) => {
                let expr = SqlExpression::StringLiteral(s.clone());
                self.advance();
                Ok(expr)
            }
            Token::NumberLiteral(n) => {
                let expr = SqlExpression::NumberLiteral(n.clone());
                self.advance();
                Ok(expr)
            }
            Token::Null => {
                self.advance();
                Ok(SqlExpression::Null)
            }
            Token::LeftParen => {
                // Parenthesized sub-expression: restart at the lowest precedence.
                self.advance();

                let expr = self.parse_logical_or()?;

                self.consume(Token::RightParen)?;
                Ok(expr)
            }
            Token::Not => {
                self.advance();
                // NOTE(review): any error from `parse_comparison` is discarded
                // and replaced by the generic message below, which hides the
                // underlying cause.
                if let Ok(inner_expr) = self.parse_comparison() {
                    if matches!(self.current_token, Token::In) {
                        self.advance();
                        self.consume(Token::LeftParen)?;
                        let values = self.parse_expression_list()?;
                        self.consume(Token::RightParen)?;

                        Ok(SqlExpression::NotInList {
                            expr: Box::new(inner_expr),
                            values,
                        })
                    } else {
                        Ok(SqlExpression::Not {
                            expr: Box::new(inner_expr),
                        })
                    }
                } else {
                    Err("Expected expression after NOT".to_string())
                }
            }
            Token::Star => {
                // A bare `*` is represented as the string literal "*".
                // Presumably this is for `COUNT(*)`-style arguments; confirm
                // with callers.
                self.advance();
                Ok(SqlExpression::StringLiteral("*".to_string()))
            }
            _ => Err(format!("Unexpected token: {:?}", self.current_token)),
        }
    }
1722
1723 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1724 let mut args = Vec::new();
1725
1726 self.in_method_args = true;
1728
1729 if !matches!(self.current_token, Token::RightParen) {
1730 loop {
1731 args.push(self.parse_expression()?);
1732
1733 if matches!(self.current_token, Token::Comma) {
1734 self.advance();
1735 } else {
1736 break;
1737 }
1738 }
1739 }
1740
1741 self.in_method_args = false;
1743
1744 Ok(args)
1745 }
1746
1747 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1748 let mut args = Vec::new();
1749
1750 if !matches!(self.current_token, Token::RightParen) {
1751 if matches!(self.current_token, Token::Distinct) {
1753 self.advance(); let expr = self.parse_additive()?;
1756 args.push(SqlExpression::FunctionCall {
1758 name: "DISTINCT".to_string(),
1759 args: vec![expr],
1760 });
1761 } else {
1762 args.push(self.parse_additive()?);
1764 }
1765
1766 while matches!(self.current_token, Token::Comma) {
1768 self.advance();
1769 args.push(self.parse_additive()?);
1770 }
1771 }
1772
1773 Ok(args)
1774 }
1775
1776 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1777 let mut expressions = Vec::new();
1778
1779 loop {
1780 expressions.push(self.parse_expression()?);
1781
1782 if matches!(self.current_token, Token::Comma) {
1783 self.advance();
1784 } else {
1785 break;
1786 }
1787 }
1788
1789 Ok(expressions)
1790 }
1791
1792 fn get_binary_op(&self) -> Option<String> {
1793 match &self.current_token {
1794 Token::Equal => Some("=".to_string()),
1795 Token::NotEqual => Some("!=".to_string()),
1796 Token::LessThan => Some("<".to_string()),
1797 Token::GreaterThan => Some(">".to_string()),
1798 Token::LessThanOrEqual => Some("<=".to_string()),
1799 Token::GreaterThanOrEqual => Some(">=".to_string()),
1800 Token::Like => Some("LIKE".to_string()),
1801 _ => None,
1802 }
1803 }
1804
1805 fn get_arithmetic_op(&self) -> Option<String> {
1806 match &self.current_token {
1807 Token::Plus => Some("+".to_string()),
1808 Token::Minus => Some("-".to_string()),
1809 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1811 _ => None,
1812 }
1813 }
1814
    /// Returns the current position of the parser's underlying lexer, as
    /// reported by `Lexer::get_position`.
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.lexer.get_position()
    }
1819}
1820
/// Describes where the cursor sits inside a (possibly incomplete) query.
///
/// Produced by `detect_cursor_context` and presumably used to drive completion
/// suggestions; confirm with callers.
#[derive(Debug, Clone)]
pub enum CursorContext {
    SelectClause,
    FromClause,
    WhereClause,
    OrderByClause,
    /// Presumably: the cursor follows the named column. TODO confirm with the producer.
    AfterColumn(String),
    /// Presumably: the cursor follows a logical connector (`AND` / `OR`). TODO confirm.
    AfterLogicalOp(LogicalOp),
    /// The cursor follows a comparison operator.
    /// Fields are `(column_name, operator)`, e.g. `("age", ">")`.
    AfterComparisonOp(String, String),
    /// The cursor is inside a method call. Field meaning is not visible here;
    /// presumably `(object, method)`. TODO confirm.
    InMethodCall(String, String),
    InExpression,
    Unknown,
}
1835
/// Returns the prefix of `s` that ends at byte offset `pos`.
///
/// If `pos` falls inside a multi-byte character, it is moved back to the
/// nearest UTF-8 character boundary, so the slice can never panic. When `pos`
/// is at or past the end of `s`, the whole string is returned.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }
    // Offset 0 is always a char boundary, so `find` always succeeds.
    let boundary = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..boundary]
}
1850
/// Returns the suffix of `s` that starts at byte offset `pos`.
///
/// If `pos` falls inside a multi-byte character, it is moved forward to the
/// next UTF-8 character boundary. When `pos` is at or past the end of `s`, or
/// no boundary follows it, an empty string is returned.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    (pos..s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .map_or("", |idx| &s[idx..])
}
1865
1866#[must_use]
1867pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1868 let truncated = safe_slice_to(query, cursor_pos);
1869 let mut parser = Parser::new(truncated);
1870
1871 if let Ok(stmt) = parser.parse() {
1873 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1874 #[cfg(test)]
1875 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1876 (ctx, partial)
1877 } else {
1878 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1880 #[cfg(test)]
1881 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1882 (ctx, partial)
1883 }
1884}
1885
1886#[must_use]
1887pub fn tokenize_query(query: &str) -> Vec<String> {
1888 let mut lexer = Lexer::new(query);
1889 let tokens = lexer.tokenize_all();
1890 tokens.iter().map(|t| format!("{t:?}")).collect()
1891}
1892
1893#[must_use]
1894pub fn format_sql_pretty(query: &str) -> Vec<String> {
1895 format_sql_pretty_compact(query, 5) }
1897
1898#[must_use]
1900pub fn format_ast_tree(query: &str) -> String {
1901 let mut parser = Parser::new(query);
1902 match parser.parse() {
1903 Ok(stmt) => format_select_statement(&stmt, 0),
1904 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
1905 }
1906}
1907
1908fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1909 let mut result = String::new();
1910 let indent_str = " ".repeat(indent);
1911
1912 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1913
1914 result.push_str(&format!("{indent_str} columns: ["));
1916 if stmt.columns.is_empty() {
1917 result.push_str("],\n");
1918 } else {
1919 result.push('\n');
1920 for col in &stmt.columns {
1921 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1922 }
1923 result.push_str(&format!("{indent_str} ],\n"));
1924 }
1925
1926 if let Some(table) = &stmt.from_table {
1928 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1929 }
1930
1931 if let Some(where_clause) = &stmt.where_clause {
1933 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1934 result.push_str(&format_where_clause(where_clause, indent + 2));
1935 result.push_str(&format!("{indent_str} }},\n"));
1936 }
1937
1938 if let Some(order_by) = &stmt.order_by {
1940 result.push_str(&format!("{indent_str} order_by: ["));
1941 if order_by.is_empty() {
1942 result.push_str("],\n");
1943 } else {
1944 result.push('\n');
1945 for col in order_by {
1946 let dir = match col.direction {
1947 SortDirection::Asc => "ASC",
1948 SortDirection::Desc => "DESC",
1949 };
1950 result.push_str(&format!(
1951 "{indent_str} \"{col}\" {dir},\n",
1952 col = col.column
1953 ));
1954 }
1955 result.push_str(&format!("{indent_str} ],\n"));
1956 }
1957 }
1958
1959 if let Some(group_by) = &stmt.group_by {
1961 result.push_str(&format!("{indent_str} group_by: ["));
1962 if group_by.is_empty() {
1963 result.push_str("]\n");
1964 } else {
1965 result.push('\n');
1966 for col in group_by {
1967 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1968 }
1969 result.push_str(&format!("{indent_str} ],\n"));
1970 }
1971 }
1972
1973 result.push_str(&format!("{indent_str}}}"));
1974 result
1975}
1976
1977fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1978 let mut result = String::new();
1979 let indent_str = " ".repeat(indent);
1980
1981 result.push_str(&format!("{indent_str}conditions: [\n"));
1982
1983 for condition in &clause.conditions {
1984 result.push_str(&format!("{indent_str} {{\n"));
1985 result.push_str(&format!(
1986 "{indent_str} expr: {},\n",
1987 format_expression_ast(&condition.expr)
1988 ));
1989
1990 if let Some(connector) = &condition.connector {
1991 let connector_str = match connector {
1992 LogicalOp::And => "AND",
1993 LogicalOp::Or => "OR",
1994 };
1995 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1996 }
1997
1998 result.push_str(&format!("{indent_str} }},\n"));
1999 }
2000
2001 result.push_str(&format!("{indent_str}]\n"));
2002 result
2003}
2004
2005fn format_expression_ast(expr: &SqlExpression) -> String {
2006 match expr {
2007 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
2008 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
2009 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
2010 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({value})"),
2011 SqlExpression::Null => "Null".to_string(),
2012 SqlExpression::DateTimeConstructor {
2013 year,
2014 month,
2015 day,
2016 hour,
2017 minute,
2018 second,
2019 } => {
2020 format!(
2021 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
2022 year,
2023 month,
2024 day,
2025 hour.unwrap_or(0),
2026 minute.unwrap_or(0),
2027 second.unwrap_or(0)
2028 )
2029 }
2030 SqlExpression::DateTimeToday {
2031 hour,
2032 minute,
2033 second,
2034 } => {
2035 format!(
2036 "DateTimeToday({:02}:{:02}:{:02})",
2037 hour.unwrap_or(0),
2038 minute.unwrap_or(0),
2039 second.unwrap_or(0)
2040 )
2041 }
2042 SqlExpression::MethodCall {
2043 object,
2044 method,
2045 args,
2046 } => {
2047 let args_str = args
2048 .iter()
2049 .map(format_expression_ast)
2050 .collect::<Vec<_>>()
2051 .join(", ");
2052 format!("MethodCall({object}.{method}({args_str}))")
2053 }
2054 SqlExpression::ChainedMethodCall { base, method, args } => {
2055 let args_str = args
2056 .iter()
2057 .map(format_expression_ast)
2058 .collect::<Vec<_>>()
2059 .join(", ");
2060 format!(
2061 "ChainedMethodCall({}.{}({}))",
2062 format_expression_ast(base),
2063 method,
2064 args_str
2065 )
2066 }
2067 SqlExpression::FunctionCall { name, args } => {
2068 let args_str = args
2069 .iter()
2070 .map(format_expression_ast)
2071 .collect::<Vec<_>>()
2072 .join(", ");
2073 format!("FunctionCall({name}({args_str}))")
2074 }
2075 SqlExpression::WindowFunction {
2076 name,
2077 args,
2078 window_spec,
2079 } => {
2080 let args_str = args
2081 .iter()
2082 .map(format_expression_ast)
2083 .collect::<Vec<_>>()
2084 .join(", ");
2085 let partition_str = if !window_spec.partition_by.is_empty() {
2086 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2087 } else {
2088 String::new()
2089 };
2090 let order_str = if !window_spec.order_by.is_empty() {
2091 let cols = window_spec
2092 .order_by
2093 .iter()
2094 .map(|col| format!("{} {:?}", col.column, col.direction))
2095 .collect::<Vec<_>>()
2096 .join(", ");
2097 format!(" ORDER BY {}", cols)
2098 } else {
2099 String::new()
2100 };
2101 format!("WindowFunction({name}({args_str}) OVER({partition_str}{order_str}))")
2102 }
2103 SqlExpression::BinaryOp { left, op, right } => {
2104 format!(
2105 "BinaryOp({} {} {})",
2106 format_expression_ast(left),
2107 op,
2108 format_expression_ast(right)
2109 )
2110 }
2111 SqlExpression::InList { expr, values } => {
2112 let list_str = values
2113 .iter()
2114 .map(format_expression_ast)
2115 .collect::<Vec<_>>()
2116 .join(", ");
2117 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
2118 }
2119 SqlExpression::NotInList { expr, values } => {
2120 let list_str = values
2121 .iter()
2122 .map(format_expression_ast)
2123 .collect::<Vec<_>>()
2124 .join(", ");
2125 format!(
2126 "NotInList({} NOT IN [{}])",
2127 format_expression_ast(expr),
2128 list_str
2129 )
2130 }
2131 SqlExpression::Between { expr, lower, upper } => {
2132 format!(
2133 "Between({} BETWEEN {} AND {})",
2134 format_expression_ast(expr),
2135 format_expression_ast(lower),
2136 format_expression_ast(upper)
2137 )
2138 }
2139 SqlExpression::Not { expr } => {
2140 format!("Not({})", format_expression_ast(expr))
2141 }
2142 SqlExpression::CaseExpression {
2143 when_branches,
2144 else_branch,
2145 } => {
2146 let when_strs: Vec<String> = when_branches
2147 .iter()
2148 .map(|branch| {
2149 format!(
2150 "WHEN {} THEN {}",
2151 format_expression_ast(&branch.condition),
2152 format_expression_ast(&branch.result)
2153 )
2154 })
2155 .collect();
2156 let else_str = else_branch
2157 .as_ref()
2158 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
2159 .unwrap_or_default();
2160 format!("CASE {} {} END", when_strs.join(" "), else_str)
2161 }
2162 }
2163}
2164
2165#[must_use]
2167pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
2168 match expr {
2169 SqlExpression::DateTimeConstructor {
2170 year,
2171 month,
2172 day,
2173 hour,
2174 minute,
2175 second,
2176 } => {
2177 let h = hour.unwrap_or(0);
2178 let m = minute.unwrap_or(0);
2179 let s = second.unwrap_or(0);
2180
2181 if let Ok(dt) = NaiveDateTime::parse_from_str(
2183 &format!("{year:04}-{month:02}-{day:02} {h:02}:{m:02}:{s:02}"),
2184 "%Y-%m-%d %H:%M:%S",
2185 ) {
2186 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2187 } else {
2188 None
2189 }
2190 }
2191 SqlExpression::DateTimeToday {
2192 hour,
2193 minute,
2194 second,
2195 } => {
2196 let now = Local::now();
2197 let h = hour.unwrap_or(0);
2198 let m = minute.unwrap_or(0);
2199 let s = second.unwrap_or(0);
2200
2201 if let Ok(dt) = NaiveDateTime::parse_from_str(
2203 &format!(
2204 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
2205 now.year(),
2206 now.month(),
2207 now.day(),
2208 h,
2209 m,
2210 s
2211 ),
2212 "%Y-%m-%d %H:%M:%S",
2213 ) {
2214 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2215 } else {
2216 None
2217 }
2218 }
2219 _ => None,
2220 }
2221}
2222
2223fn format_sql_with_preserved_parens(
2225 query: &str,
2226 cols_per_line: usize,
2227) -> Result<Vec<String>, String> {
2228 let mut lines = Vec::new();
2229 let mut lexer = Lexer::new(query);
2230 let tokens_with_pos = lexer.tokenize_all_with_positions();
2231
2232 if tokens_with_pos.is_empty() {
2233 return Err("No tokens found".to_string());
2234 }
2235
2236 let mut i = 0;
2237 let cols_per_line = cols_per_line.max(1);
2238
2239 while i < tokens_with_pos.len() {
2240 let (start, _end, ref token) = tokens_with_pos[i];
2241
2242 match token {
2243 Token::Select => {
2244 lines.push("SELECT".to_string());
2245 i += 1;
2246
2247 let mut columns = Vec::new();
2249 let mut col_start = i;
2250 while i < tokens_with_pos.len() {
2251 match &tokens_with_pos[i].2 {
2252 Token::From | Token::Eof => break,
2253 Token::Comma => {
2254 if col_start < i {
2256 let col_text = extract_text_between_positions(
2257 query,
2258 tokens_with_pos[col_start].0,
2259 tokens_with_pos[i - 1].1,
2260 );
2261 columns.push(col_text);
2262 }
2263 i += 1;
2264 col_start = i;
2265 }
2266 _ => i += 1,
2267 }
2268 }
2269 if col_start < i && i > 0 {
2271 let col_text = extract_text_between_positions(
2272 query,
2273 tokens_with_pos[col_start].0,
2274 tokens_with_pos[i - 1].1,
2275 );
2276 columns.push(col_text);
2277 }
2278
2279 for chunk in columns.chunks(cols_per_line) {
2281 let mut line = " ".to_string();
2282 for (idx, col) in chunk.iter().enumerate() {
2283 if idx > 0 {
2284 line.push_str(", ");
2285 }
2286 line.push_str(col.trim());
2287 }
2288 let is_last_chunk = chunk.as_ptr() as usize + std::mem::size_of_val(chunk)
2290 >= columns.last().map_or(0, |c| std::ptr::from_ref(c) as usize);
2291 if !is_last_chunk && columns.len() > cols_per_line {
2292 line.push(',');
2293 }
2294 lines.push(line);
2295 }
2296 }
2297 Token::From => {
2298 i += 1;
2299 if i < tokens_with_pos.len() {
2300 let table_start = tokens_with_pos[i].0;
2301 while i < tokens_with_pos.len() {
2303 match &tokens_with_pos[i].2 {
2304 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
2305 _ => i += 1,
2306 }
2307 }
2308 if i > 0 {
2309 let table_text = extract_text_between_positions(
2310 query,
2311 table_start,
2312 tokens_with_pos[i - 1].1,
2313 );
2314 lines.push(format!("FROM {}", table_text.trim()));
2315 }
2316 }
2317 }
2318 Token::Where => {
2319 lines.push("WHERE".to_string());
2320 i += 1;
2321
2322 let where_start = if i < tokens_with_pos.len() {
2324 tokens_with_pos[i].0
2325 } else {
2326 start
2327 };
2328
2329 let mut where_end = query.len();
2331 while i < tokens_with_pos.len() {
2332 match &tokens_with_pos[i].2 {
2333 Token::OrderBy | Token::GroupBy | Token::Eof => {
2334 if i > 0 {
2335 where_end = tokens_with_pos[i - 1].1;
2336 }
2337 break;
2338 }
2339 _ => i += 1,
2340 }
2341 }
2342
2343 let where_text = extract_text_between_positions(query, where_start, where_end);
2345
2346 let formatted_where = format_where_clause_with_parens(&where_text);
2348 for line in formatted_where {
2349 lines.push(format!(" {line}"));
2350 }
2351 }
2352 Token::OrderBy => {
2353 i += 1;
2354 let order_start = if i < tokens_with_pos.len() {
2355 tokens_with_pos[i].0
2356 } else {
2357 start
2358 };
2359
2360 while i < tokens_with_pos.len() {
2362 match &tokens_with_pos[i].2 {
2363 Token::GroupBy | Token::Eof => break,
2364 _ => i += 1,
2365 }
2366 }
2367
2368 if i > 0 {
2369 let order_text = extract_text_between_positions(
2370 query,
2371 order_start,
2372 tokens_with_pos[i - 1].1,
2373 );
2374 lines.push(format!("ORDER BY {}", order_text.trim()));
2375 }
2376 }
2377 Token::GroupBy => {
2378 i += 1;
2379 let group_start = if i < tokens_with_pos.len() {
2380 tokens_with_pos[i].0
2381 } else {
2382 start
2383 };
2384
2385 while i < tokens_with_pos.len() {
2387 match &tokens_with_pos[i].2 {
2388 Token::Having | Token::Eof => break,
2389 _ => i += 1,
2390 }
2391 }
2392
2393 if i > 0 {
2394 let group_text = extract_text_between_positions(
2395 query,
2396 group_start,
2397 tokens_with_pos[i - 1].1,
2398 );
2399 lines.push(format!("GROUP BY {}", group_text.trim()));
2400 }
2401 }
2402 _ => i += 1,
2403 }
2404 }
2405
2406 Ok(lines)
2407}
2408
/// Returns the characters of `query` in the half-open range `[start, end)`.
///
/// Both positions are **character** indices, not byte offsets; presumably they
/// are the lexer's indices into its `Vec<char>` input. Positions past the end
/// are clamped to the end of the query. An empty or inverted range
/// (`start >= end`) yields an empty string instead of panicking.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    if start >= end {
        return String::new();
    }
    // Streaming over the chars avoids materialising the whole query per call.
    query.chars().skip(start).take(end - start).collect()
}
2416
/// Splits a WHERE-clause body into display lines at top-level ` AND ` / ` OR `.
///
/// A connector is recognised only when all of these hold:
/// - it is matched ASCII case-insensitively, surrounded by single spaces;
/// - it is outside any parentheses;
/// - it is outside single- or double-quoted text, so `'Tom and Jerry'` is never
///   split and a `(` inside quotes is not counted.
///
/// Each connector becomes its own `"AND"` / `"OR"` line. The pieces between
/// connectors are trimmed, and empty pieces are dropped. An unbalanced `)` is
/// ignored rather than disabling splitting. If nothing remains, the trimmed
/// input is returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    // True when `rest` begins with `needle`, compared ASCII case-insensitively.
    fn starts_with_ignore_case(rest: &[char], needle: &str) -> bool {
        let mut rest = rest.iter();
        needle.chars().all(|expected| {
            matches!(rest.next(), Some(actual) if actual.eq_ignore_ascii_case(&expected))
        })
    }

    const CONNECTORS: [(&str, &str); 2] = [(" AND ", "AND"), (" OR ", "OR")];

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    'scan: while i < chars.len() {
        let ch = chars[i];

        if let Some(quote) = open_quote {
            // Inside quoted text: copy verbatim until the matching quote. A
            // doubled quote ('') closes and immediately reopens, which is correct.
            current_line.push(ch);
            if ch == quote {
                open_quote = None;
            }
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            for &(pattern, keyword) in &CONNECTORS {
                if starts_with_ignore_case(&chars[i..], pattern) {
                    if !current_line.trim().is_empty() {
                        lines.push(current_line.trim().to_string());
                    }
                    lines.push(keyword.to_string());
                    current_line.clear();
                    // Patterns are ASCII, so byte length equals char count.
                    i += pattern.len();
                    continue 'scan;
                }
            }
        }

        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => paren_depth += 1,
            // Saturate so a stray `)` cannot push the depth below zero forever.
            ')' => paren_depth = paren_depth.saturating_sub(1),
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    if !current_line.trim().is_empty() {
        lines.push(current_line.trim().to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2482
2483#[must_use]
2484pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2485 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2487 return lines;
2488 }
2489
2490 let mut lines = Vec::new();
2492 let mut parser = Parser::new(query);
2493
2494 let cols_per_line = cols_per_line.max(1);
2496
2497 if let Ok(stmt) = parser.parse() {
2498 if !stmt.columns.is_empty() {
2500 lines.push("SELECT".to_string());
2501
2502 for chunk in stmt.columns.chunks(cols_per_line) {
2504 let mut line = " ".to_string();
2505 for (i, col) in chunk.iter().enumerate() {
2506 if i > 0 {
2507 line.push_str(", ");
2508 }
2509 line.push_str(col);
2510 }
2511 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2513 let current_chunk_idx =
2514 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2515 if current_chunk_idx < last_chunk_idx {
2516 line.push(',');
2517 }
2518 lines.push(line);
2519 }
2520 }
2521
2522 if let Some(table) = &stmt.from_table {
2524 lines.push(format!("FROM {table}"));
2525 }
2526
2527 if let Some(where_clause) = &stmt.where_clause {
2529 lines.push("WHERE".to_string());
2530 for (i, condition) in where_clause.conditions.iter().enumerate() {
2531 if i > 0 {
2532 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2534 if let Some(connector) = &prev_condition.connector {
2535 match connector {
2536 LogicalOp::And => lines.push(" AND".to_string()),
2537 LogicalOp::Or => lines.push(" OR".to_string()),
2538 }
2539 }
2540 }
2541 }
2542 lines.push(format!(" {}", format_expression(&condition.expr)));
2543 }
2544 }
2545
2546 if let Some(order_by) = &stmt.order_by {
2548 let order_str = order_by
2549 .iter()
2550 .map(|col| {
2551 let dir = match col.direction {
2552 SortDirection::Asc => " ASC",
2553 SortDirection::Desc => " DESC",
2554 };
2555 format!("{}{}", col.column, dir)
2556 })
2557 .collect::<Vec<_>>()
2558 .join(", ");
2559 lines.push(format!("ORDER BY {order_str}"));
2560 }
2561
2562 if let Some(group_by) = &stmt.group_by {
2564 let group_str = group_by.join(", ");
2565 lines.push(format!("GROUP BY {group_str}"));
2566 }
2567 } else {
2568 let mut lexer = Lexer::new(query);
2570 let tokens = lexer.tokenize_all();
2571 let mut current_line = String::new();
2572 let mut indent = 0;
2573
2574 for token in tokens {
2575 match &token {
2576 Token::Select | Token::From | Token::Where | Token::OrderBy | Token::GroupBy => {
2577 if !current_line.is_empty() {
2578 lines.push(current_line.trim().to_string());
2579 current_line.clear();
2580 }
2581 lines.push(format!("{token:?}").to_uppercase());
2582 indent = 1;
2583 }
2584 Token::And | Token::Or => {
2585 if !current_line.is_empty() {
2586 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2587 current_line.clear();
2588 }
2589 lines.push(format!(" {token:?}").to_uppercase());
2590 }
2591 Token::Comma => {
2592 current_line.push(',');
2593 if indent > 0 {
2594 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2595 current_line.clear();
2596 }
2597 }
2598 Token::Eof => break,
2599 _ => {
2600 if !current_line.is_empty() {
2601 current_line.push(' ');
2602 }
2603 current_line.push_str(&format_token(&token));
2604 }
2605 }
2606 }
2607
2608 if !current_line.is_empty() {
2609 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2610 }
2611 }
2612
2613 lines
2614}
2615
2616fn format_expression(expr: &SqlExpression) -> String {
2617 match expr {
2618 SqlExpression::Column(name) => name.clone(),
2619 SqlExpression::StringLiteral(s) => format!("'{s}'"),
2620 SqlExpression::NumberLiteral(n) => n.clone(),
2621 SqlExpression::BooleanLiteral(b) => b.to_string(),
2622 SqlExpression::Null => "NULL".to_string(),
2623 SqlExpression::DateTimeConstructor {
2624 year,
2625 month,
2626 day,
2627 hour,
2628 minute,
2629 second,
2630 } => {
2631 let mut result = format!("DateTime({year}, {month}, {day}");
2632 if let Some(h) = hour {
2633 result.push_str(&format!(", {h}"));
2634 if let Some(m) = minute {
2635 result.push_str(&format!(", {m}"));
2636 if let Some(s) = second {
2637 result.push_str(&format!(", {s}"));
2638 }
2639 }
2640 }
2641 result.push(')');
2642 result
2643 }
2644 SqlExpression::DateTimeToday {
2645 hour,
2646 minute,
2647 second,
2648 } => {
2649 let mut result = "DateTime()".to_string();
2650 if let Some(h) = hour {
2651 result = format!("DateTime(TODAY, {h}");
2652 if let Some(m) = minute {
2653 result.push_str(&format!(", {m}"));
2654 if let Some(s) = second {
2655 result.push_str(&format!(", {s}"));
2656 }
2657 }
2658 result.push(')');
2659 }
2660 result
2661 }
2662 SqlExpression::MethodCall {
2663 object,
2664 method,
2665 args,
2666 } => {
2667 let args_str = args
2668 .iter()
2669 .map(format_expression)
2670 .collect::<Vec<_>>()
2671 .join(", ");
2672 format!("{object}.{method}({args_str})")
2673 }
2674 SqlExpression::BinaryOp { left, op, right } => {
2675 if op == "OR" || op == "AND" {
2678 format!(
2681 "({} {} {})",
2682 format_expression(left),
2683 op,
2684 format_expression(right)
2685 )
2686 } else {
2687 format!(
2688 "{} {} {}",
2689 format_expression(left),
2690 op,
2691 format_expression(right)
2692 )
2693 }
2694 }
2695 SqlExpression::InList { expr, values } => {
2696 let values_str = values
2697 .iter()
2698 .map(format_expression)
2699 .collect::<Vec<_>>()
2700 .join(", ");
2701 format!("{} IN ({})", format_expression(expr), values_str)
2702 }
2703 SqlExpression::NotInList { expr, values } => {
2704 let values_str = values
2705 .iter()
2706 .map(format_expression)
2707 .collect::<Vec<_>>()
2708 .join(", ");
2709 format!("{} NOT IN ({})", format_expression(expr), values_str)
2710 }
2711 SqlExpression::Between { expr, lower, upper } => {
2712 format!(
2713 "{} BETWEEN {} AND {}",
2714 format_expression(expr),
2715 format_expression(lower),
2716 format_expression(upper)
2717 )
2718 }
2719 SqlExpression::Not { expr } => {
2720 format!("NOT {}", format_expression(expr))
2721 }
2722 SqlExpression::ChainedMethodCall { base, method, args } => {
2723 let args_str = args
2724 .iter()
2725 .map(format_expression)
2726 .collect::<Vec<_>>()
2727 .join(", ");
2728 format!("{}.{}({})", format_expression(base), method, args_str)
2729 }
2730 SqlExpression::FunctionCall { name, args } => {
2731 let args_str = args
2732 .iter()
2733 .map(format_expression)
2734 .collect::<Vec<_>>()
2735 .join(", ");
2736 format!("{name}({args_str})")
2737 }
2738 SqlExpression::WindowFunction {
2739 name,
2740 args,
2741 window_spec,
2742 } => {
2743 let args_str = args
2744 .iter()
2745 .map(format_expression)
2746 .collect::<Vec<_>>()
2747 .join(", ");
2748 let partition_str = if !window_spec.partition_by.is_empty() {
2749 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2750 } else {
2751 String::new()
2752 };
2753 let order_str = if !window_spec.order_by.is_empty() {
2754 let cols = window_spec
2755 .order_by
2756 .iter()
2757 .map(|col| {
2758 let dir = match col.direction {
2759 SortDirection::Asc => "ASC",
2760 SortDirection::Desc => "DESC",
2761 };
2762 format!("{} {}", col.column, dir)
2763 })
2764 .collect::<Vec<_>>()
2765 .join(", ");
2766 format!(" ORDER BY {}", cols)
2767 } else {
2768 String::new()
2769 };
2770 format!("{name}({args_str}) OVER({partition_str}{order_str})")
2771 }
2772 SqlExpression::CaseExpression {
2773 when_branches,
2774 else_branch,
2775 } => {
2776 let mut result = String::from("CASE");
2777 for branch in when_branches {
2778 result.push_str(&format!(
2779 " WHEN {} THEN {}",
2780 format_expression(&branch.condition),
2781 format_expression(&branch.result)
2782 ));
2783 }
2784 if let Some(else_expr) = else_branch {
2785 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2786 }
2787 result.push_str(" END");
2788 result
2789 }
2790 }
2791}
2792
2793fn format_token(token: &Token) -> String {
2794 match token {
2795 Token::Identifier(s) => s.clone(),
2796 Token::QuotedIdentifier(s) => format!("\"{s}\""),
2797 Token::StringLiteral(s) => format!("'{s}'"),
2798 Token::NumberLiteral(n) => n.clone(),
2799 Token::DateTime => "DateTime".to_string(),
2800 Token::Case => "CASE".to_string(),
2801 Token::When => "WHEN".to_string(),
2802 Token::Then => "THEN".to_string(),
2803 Token::Else => "ELSE".to_string(),
2804 Token::End => "END".to_string(),
2805 Token::Distinct => "DISTINCT".to_string(),
2806 Token::Over => "OVER".to_string(),
2807 Token::Partition => "PARTITION".to_string(),
2808 Token::By => "BY".to_string(),
2809 Token::LeftParen => "(".to_string(),
2810 Token::RightParen => ")".to_string(),
2811 Token::Comma => ",".to_string(),
2812 Token::Dot => ".".to_string(),
2813 Token::Equal => "=".to_string(),
2814 Token::NotEqual => "!=".to_string(),
2815 Token::LessThan => "<".to_string(),
2816 Token::GreaterThan => ">".to_string(),
2817 Token::LessThanOrEqual => "<=".to_string(),
2818 Token::GreaterThanOrEqual => ">=".to_string(),
2819 Token::In => "IN".to_string(),
2820 _ => format!("{token:?}").to_uppercase(),
2821 }
2822}
2823
2824fn analyze_statement(
2825 stmt: &SelectStatement,
2826 query: &str,
2827 _cursor_pos: usize,
2828) -> (CursorContext, Option<String>) {
2829 let trimmed = query.trim();
2831
2832 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2834 for op in &comparison_ops {
2835 if let Some(op_pos) = query.rfind(op) {
2836 let before_op = safe_slice_to(query, op_pos);
2837 let after_op_start = op_pos + op.len();
2838 let after_op = if after_op_start < query.len() {
2839 &query[after_op_start..]
2840 } else {
2841 ""
2842 };
2843
2844 if let Some(col_name) = before_op.split_whitespace().last() {
2846 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2847 let after_op_trimmed = after_op.trim();
2849 if after_op_trimmed.is_empty()
2850 || (after_op_trimmed
2851 .chars()
2852 .all(|c| c.is_alphanumeric() || c == '_')
2853 && !after_op_trimmed.contains('('))
2854 {
2855 let partial = if after_op_trimmed.is_empty() {
2856 None
2857 } else {
2858 Some(after_op_trimmed.to_string())
2859 };
2860 return (
2861 CursorContext::AfterComparisonOp(
2862 col_name.to_string(),
2863 op.trim().to_string(),
2864 ),
2865 partial,
2866 );
2867 }
2868 }
2869 }
2870 }
2871 }
2872
2873 if trimmed.to_uppercase().ends_with(" AND")
2875 || trimmed.to_uppercase().ends_with(" OR")
2876 || trimmed.to_uppercase().ends_with(" AND ")
2877 || trimmed.to_uppercase().ends_with(" OR ")
2878 {
2879 } else {
2881 if let Some(dot_pos) = trimmed.rfind('.') {
2883 let before_dot = safe_slice_to(trimmed, dot_pos);
2885 let after_dot_start = dot_pos + 1;
2886 let after_dot = if after_dot_start < trimmed.len() {
2887 &trimmed[after_dot_start..]
2888 } else {
2889 ""
2890 };
2891
2892 if !after_dot.contains('(') {
2895 let col_name = if before_dot.ends_with('"') {
2897 let bytes = before_dot.as_bytes();
2899 let mut pos = before_dot.len() - 1; let mut found_start = None;
2901
2902 if pos > 0 {
2904 pos -= 1;
2905 while pos > 0 {
2906 if bytes[pos] == b'"' {
2907 if pos == 0 || bytes[pos - 1] != b'\\' {
2909 found_start = Some(pos);
2910 break;
2911 }
2912 }
2913 pos -= 1;
2914 }
2915 if found_start.is_none() && bytes[0] == b'"' {
2917 found_start = Some(0);
2918 }
2919 }
2920
2921 found_start.map(|start| safe_slice_from(before_dot, start))
2922 } else {
2923 before_dot
2926 .split_whitespace()
2927 .last()
2928 .map(|word| word.trim_start_matches('('))
2929 };
2930
2931 if let Some(col_name) = col_name {
2932 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2934 true
2936 } else {
2937 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2939 };
2940
2941 if is_valid {
2942 let partial_method = if after_dot.is_empty() {
2945 None
2946 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2947 Some(after_dot.to_string())
2948 } else {
2949 None
2950 };
2951
2952 let col_name_for_context = if col_name.starts_with('"')
2954 && col_name.ends_with('"')
2955 && col_name.len() > 2
2956 {
2957 col_name[1..col_name.len() - 1].to_string()
2958 } else {
2959 col_name.to_string()
2960 };
2961
2962 return (
2963 CursorContext::AfterColumn(col_name_for_context),
2964 partial_method,
2965 );
2966 }
2967 }
2968 }
2969 }
2970 }
2971
2972 if let Some(where_clause) = &stmt.where_clause {
2974 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2976 let op = if trimmed.to_uppercase().ends_with(" AND") {
2977 LogicalOp::And
2978 } else {
2979 LogicalOp::Or
2980 };
2981 return (CursorContext::AfterLogicalOp(op), None);
2982 }
2983
2984 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2986 let after_and = safe_slice_from(query, and_pos + 5);
2987 let partial = extract_partial_at_end(after_and);
2988 if partial.is_some() {
2989 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2990 }
2991 }
2992
2993 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2994 let after_or = safe_slice_from(query, or_pos + 4);
2995 let partial = extract_partial_at_end(after_or);
2996 if partial.is_some() {
2997 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2998 }
2999 }
3000
3001 if let Some(last_condition) = where_clause.conditions.last() {
3002 if let Some(connector) = &last_condition.connector {
3003 return (
3005 CursorContext::AfterLogicalOp(connector.clone()),
3006 extract_partial_at_end(query),
3007 );
3008 }
3009 }
3010 return (CursorContext::WhereClause, extract_partial_at_end(query));
3012 }
3013
3014 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
3016 return (CursorContext::OrderByClause, None);
3017 }
3018
3019 if stmt.order_by.is_some() {
3021 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3022 }
3023
3024 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
3025 return (CursorContext::FromClause, extract_partial_at_end(query));
3026 }
3027
3028 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
3029 return (CursorContext::SelectClause, extract_partial_at_end(query));
3030 }
3031
3032 (CursorContext::Unknown, None)
3033}
3034
3035fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
3036 let upper = query.to_uppercase();
3037
3038 let trimmed = query.trim();
3040
3041 #[cfg(test)]
3042 {
3043 if trimmed.contains("\"Last Name\"") {
3044 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
3045 }
3046 }
3047
3048 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
3050 for op in &comparison_ops {
3051 if let Some(op_pos) = query.rfind(op) {
3052 let before_op = safe_slice_to(query, op_pos);
3053 let after_op_start = op_pos + op.len();
3054 let after_op = if after_op_start < query.len() {
3055 &query[after_op_start..]
3056 } else {
3057 ""
3058 };
3059
3060 if let Some(col_name) = before_op.split_whitespace().last() {
3062 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
3063 let after_op_trimmed = after_op.trim();
3065 if after_op_trimmed.is_empty()
3066 || (after_op_trimmed
3067 .chars()
3068 .all(|c| c.is_alphanumeric() || c == '_')
3069 && !after_op_trimmed.contains('('))
3070 {
3071 let partial = if after_op_trimmed.is_empty() {
3072 None
3073 } else {
3074 Some(after_op_trimmed.to_string())
3075 };
3076 return (
3077 CursorContext::AfterComparisonOp(
3078 col_name.to_string(),
3079 op.trim().to_string(),
3080 ),
3081 partial,
3082 );
3083 }
3084 }
3085 }
3086 }
3087 }
3088
3089 if let Some(dot_pos) = trimmed.rfind('.') {
3092 #[cfg(test)]
3093 {
3094 if trimmed.contains("\"Last Name\"") {
3095 eprintln!("DEBUG: Found dot at position {dot_pos}");
3096 }
3097 }
3098 let before_dot = &trimmed[..dot_pos];
3100 let after_dot = &trimmed[dot_pos + 1..];
3101
3102 if !after_dot.contains('(') {
3105 let col_name = if before_dot.ends_with('"') {
3108 let bytes = before_dot.as_bytes();
3110 let mut pos = before_dot.len() - 1; let mut found_start = None;
3112
3113 #[cfg(test)]
3114 {
3115 if trimmed.contains("\"Last Name\"") {
3116 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
3117 }
3118 }
3119
3120 if pos > 0 {
3122 pos -= 1;
3123 while pos > 0 {
3124 if bytes[pos] == b'"' {
3125 if pos == 0 || bytes[pos - 1] != b'\\' {
3127 found_start = Some(pos);
3128 break;
3129 }
3130 }
3131 pos -= 1;
3132 }
3133 if found_start.is_none() && bytes[0] == b'"' {
3135 found_start = Some(0);
3136 }
3137 }
3138
3139 if let Some(start) = found_start {
3140 let result = safe_slice_from(before_dot, start);
3142 #[cfg(test)]
3143 {
3144 if trimmed.contains("\"Last Name\"") {
3145 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
3146 }
3147 }
3148 Some(result)
3149 } else {
3150 #[cfg(test)]
3151 {
3152 if trimmed.contains("\"Last Name\"") {
3153 eprintln!("DEBUG: No opening quote found!");
3154 }
3155 }
3156 None
3157 }
3158 } else {
3159 before_dot
3162 .split_whitespace()
3163 .last()
3164 .map(|word| word.trim_start_matches('('))
3165 };
3166
3167 if let Some(col_name) = col_name {
3168 #[cfg(test)]
3169 {
3170 if trimmed.contains("\"Last Name\"") {
3171 eprintln!("DEBUG: col_name = '{col_name}'");
3172 }
3173 }
3174
3175 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3177 true
3179 } else {
3180 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3182 };
3183
3184 #[cfg(test)]
3185 {
3186 if trimmed.contains("\"Last Name\"") {
3187 eprintln!("DEBUG: is_valid = {is_valid}");
3188 }
3189 }
3190
3191 if is_valid {
3192 let partial_method = if after_dot.is_empty() {
3195 None
3196 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3197 Some(after_dot.to_string())
3198 } else {
3199 None
3200 };
3201
3202 let col_name_for_context = if col_name.starts_with('"')
3204 && col_name.ends_with('"')
3205 && col_name.len() > 2
3206 {
3207 col_name[1..col_name.len() - 1].to_string()
3208 } else {
3209 col_name.to_string()
3210 };
3211
3212 return (
3213 CursorContext::AfterColumn(col_name_for_context),
3214 partial_method,
3215 );
3216 }
3217 }
3218 }
3219 }
3220
3221 if let Some(and_pos) = upper.rfind(" AND ") {
3223 if cursor_pos >= and_pos + 5 {
3225 let after_and = safe_slice_from(query, and_pos + 5);
3227 let partial = extract_partial_at_end(after_and);
3228 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3229 }
3230 }
3231
3232 if let Some(or_pos) = upper.rfind(" OR ") {
3233 if cursor_pos >= or_pos + 4 {
3235 let after_or = safe_slice_from(query, or_pos + 4);
3237 let partial = extract_partial_at_end(after_or);
3238 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3239 }
3240 }
3241
3242 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3244 let op = if trimmed.to_uppercase().ends_with(" AND") {
3245 LogicalOp::And
3246 } else {
3247 LogicalOp::Or
3248 };
3249 return (CursorContext::AfterLogicalOp(op), None);
3250 }
3251
3252 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
3254 {
3255 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3256 }
3257
3258 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
3259 return (CursorContext::WhereClause, extract_partial_at_end(query));
3260 }
3261
3262 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
3263 return (CursorContext::FromClause, extract_partial_at_end(query));
3264 }
3265
3266 if upper.contains("SELECT") && !upper.contains("FROM") {
3267 return (CursorContext::SelectClause, extract_partial_at_end(query));
3268 }
3269
3270 (CursorContext::Unknown, None)
3271}
3272
3273fn extract_partial_at_end(query: &str) -> Option<String> {
3274 let trimmed = query.trim();
3275
3276 if let Some(last_word) = trimmed.split_whitespace().last() {
3278 if last_word.starts_with('"') && !last_word.ends_with('"') {
3279 return Some(last_word.to_string());
3281 }
3282 }
3283
3284 let last_word = trimmed.split_whitespace().last()?;
3286
3287 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
3289 Some(last_word.to_string())
3290 } else {
3291 None
3292 }
3293}
3294
/// Reports whether `word` (compared case-insensitively) is one of the SQL
/// keywords that `extract_partial_at_end` refuses to treat as a partial word.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 13] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let normalized = word.to_uppercase();
    KEYWORDS.contains(&normalized.as_str())
}
3313
3314#[cfg(test)]
3315mod tests {
3316 use super::*;
3317
3318 #[test]
3319 fn test_tokenizer_window_functions() {
3320 let mut lexer = Lexer::new("LAG(value) OVER (PARTITION BY category ORDER BY id)");
3321 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "LAG"));
3322 assert!(matches!(lexer.next_token(), Token::LeftParen));
3323 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "value"));
3324 assert!(matches!(lexer.next_token(), Token::RightParen));
3325
3326 let over_token = lexer.next_token();
3327 println!("Expected OVER, got: {:?}", over_token);
3328 assert!(matches!(over_token, Token::Over));
3329
3330 assert!(matches!(lexer.next_token(), Token::LeftParen));
3331 assert!(matches!(lexer.next_token(), Token::Partition));
3332 assert!(matches!(lexer.next_token(), Token::By));
3333 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "category"));
3334 }
3335
3336 #[test]
3337 fn test_parse_window_function() {
3338 let query = "SELECT LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test";
3339 let mut parser = Parser::new(query);
3340 let result = parser.parse();
3341
3342 assert!(
3343 result.is_ok(),
3344 "Failed to parse window function: {:?}",
3345 result
3346 );
3347 let stmt = result.unwrap();
3348
3349 if let Some(item) = stmt.select_items.get(0) {
3351 match item {
3352 SelectItem::Expression { expr, alias } => {
3353 println!("Parsed expression: {:?}", expr);
3354 assert!(matches!(expr, SqlExpression::WindowFunction { .. }));
3355 assert_eq!(alias, "prev_value");
3356 }
3357 _ => panic!("Expected expression, got: {:?}", item),
3358 }
3359 } else {
3360 panic!("No select items found");
3361 }
3362 }
3363
3364 #[test]
3365 fn test_chained_method_calls() {
3366 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
3368 let mut parser = Parser::new(query);
3369 let result = parser.parse();
3370
3371 assert!(
3372 result.is_ok(),
3373 "Failed to parse chained method calls: {result:?}"
3374 );
3375
3376 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
3378 let mut parser2 = Parser::new(query2);
3379 let result2 = parser2.parse();
3380
3381 assert!(
3382 result2.is_ok(),
3383 "Failed to parse multiple chained calls: {result2:?}"
3384 );
3385 }
3386
3387 #[test]
3388 fn test_tokenizer() {
3389 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
3390
3391 assert!(matches!(lexer.next_token(), Token::Select));
3392 assert!(matches!(lexer.next_token(), Token::Star));
3393 assert!(matches!(lexer.next_token(), Token::From));
3394 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
3395 assert!(matches!(lexer.next_token(), Token::Where));
3396 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
3397 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3398 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
3399 }
3400
3401 #[test]
3402 fn test_tokenizer_datetime() {
3403 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
3404
3405 assert!(matches!(lexer.next_token(), Token::Where));
3406 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
3407 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3408 assert!(matches!(lexer.next_token(), Token::DateTime));
3409 assert!(matches!(lexer.next_token(), Token::LeftParen));
3410 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
3411 assert!(matches!(lexer.next_token(), Token::Comma));
3412 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
3413 assert!(matches!(lexer.next_token(), Token::Comma));
3414 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
3415 assert!(matches!(lexer.next_token(), Token::RightParen));
3416 }
3417
3418 #[test]
3419 fn test_parse_simple_select() {
3420 let mut parser = Parser::new("SELECT * FROM trade_deal");
3421 let stmt = parser.parse().unwrap();
3422
3423 assert_eq!(stmt.columns, vec!["*"]);
3424 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
3425 assert!(stmt.where_clause.is_none());
3426 }
3427
3428 #[test]
3429 fn test_parse_where_with_method() {
3430 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
3431 let stmt = parser.parse().unwrap();
3432
3433 assert!(stmt.where_clause.is_some());
3434 let where_clause = stmt.where_clause.unwrap();
3435 assert_eq!(where_clause.conditions.len(), 1);
3436 }
3437
3438 #[test]
3439 fn test_parse_datetime_constructor() {
3440 let mut parser =
3441 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
3442 let stmt = parser.parse().unwrap();
3443
3444 assert!(stmt.where_clause.is_some());
3445 let where_clause = stmt.where_clause.unwrap();
3446 assert_eq!(where_clause.conditions.len(), 1);
3447
3448 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3450 assert_eq!(op, ">");
3451 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
3452 assert!(matches!(
3453 right.as_ref(),
3454 SqlExpression::DateTimeConstructor {
3455 year: 2025,
3456 month: 10,
3457 day: 20,
3458 hour: None,
3459 minute: None,
3460 second: None
3461 }
3462 ));
3463 } else {
3464 panic!("Expected BinaryOp with DateTime constructor");
3465 }
3466 }
3467
3468 #[test]
3469 fn test_cursor_context_after_and() {
3470 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3471 let (context, partial) = detect_cursor_context(query, query.len());
3472
3473 assert!(matches!(
3474 context,
3475 CursorContext::AfterLogicalOp(LogicalOp::And)
3476 ));
3477 assert_eq!(partial, None);
3478 }
3479
3480 #[test]
3481 fn test_cursor_context_with_partial() {
3482 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3483 let (context, partial) = detect_cursor_context(query, query.len());
3484
3485 assert!(matches!(
3486 context,
3487 CursorContext::AfterLogicalOp(LogicalOp::And)
3488 ));
3489 assert_eq!(partial, Some("p".to_string()));
3490 }
3491
3492 #[test]
3493 fn test_cursor_context_after_datetime_comparison() {
3494 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3495 let (context, partial) = detect_cursor_context(query, query.len());
3496
3497 assert!(
3498 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3499 );
3500 assert_eq!(partial, None);
3501 }
3502
3503 #[test]
3504 fn test_cursor_context_partial_datetime() {
3505 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3506 let (context, partial) = detect_cursor_context(query, query.len());
3507
3508 assert!(
3509 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3510 );
3511 assert_eq!(partial, Some("Date".to_string()));
3512 }
3513
3514 #[test]
3516 fn test_tokenizer_quoted_identifier() {
3517 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3518
3519 assert!(matches!(lexer.next_token(), Token::Select));
3520 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3521 assert!(matches!(lexer.next_token(), Token::Comma));
3522 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3523 assert!(matches!(lexer.next_token(), Token::From));
3524 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3525 }
3526
3527 #[test]
3528 fn test_tokenizer_quoted_vs_string_literal() {
3529 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3531
3532 assert!(matches!(lexer.next_token(), Token::Where));
3533 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3534 assert!(matches!(lexer.next_token(), Token::Equal));
3535 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3536 assert!(matches!(lexer.next_token(), Token::And));
3537 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3538 assert!(matches!(lexer.next_token(), Token::Dot));
3539 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3540 assert!(matches!(lexer.next_token(), Token::LeftParen));
3541 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3542 assert!(matches!(lexer.next_token(), Token::RightParen));
3543 }
3544
3545 #[test]
3546 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3547 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3550
3551 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3552 assert!(matches!(lexer.next_token(), Token::Dot));
3553 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3554 assert!(matches!(lexer.next_token(), Token::LeftParen));
3555
3556 let token = lexer.next_token();
3559 println!("Token for \"Alb\": {token:?}");
3560 assert!(matches!(lexer.next_token(), Token::RightParen));
3564 }
3565
3566 #[test]
3567 fn test_parse_select_with_quoted_columns() {
3568 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
3569 let stmt = parser.parse().unwrap();
3570
3571 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
3572 assert_eq!(stmt.from_table, Some("customers".to_string()));
3573 }
3574
3575 #[test]
3576 fn test_cursor_context_select_with_partial_quoted() {
3577 let query = r#"SELECT "Cust"#;
3579 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {context:?}, Partial: {partial:?}");
3582 assert!(matches!(context, CursorContext::SelectClause));
3583 }
3586
3587 #[test]
3588 fn test_cursor_context_select_after_comma_with_quoted() {
3589 let query = r#"SELECT Company, "Customer "#;
3591 let (context, partial) = detect_cursor_context(query, query.len());
3592
3593 println!("Context: {context:?}, Partial: {partial:?}");
3594 assert!(matches!(context, CursorContext::SelectClause));
3595 }
3597
3598 #[test]
3599 fn test_cursor_context_order_by_quoted() {
3600 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3601 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3602
3603 println!("Context: {context:?}, Partial: {partial:?}");
3604 assert!(matches!(context, CursorContext::OrderByClause));
3605 }
3607
3608 #[test]
3609 fn test_where_clause_with_quoted_column() {
3610 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3611 let stmt = parser.parse().unwrap();
3612
3613 assert!(stmt.where_clause.is_some());
3614 let where_clause = stmt.where_clause.unwrap();
3615 assert_eq!(where_clause.conditions.len(), 1);
3616
3617 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3618 assert_eq!(op, "=");
3619 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3620 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3621 } else {
3622 panic!("Expected BinaryOp");
3623 }
3624 }
3625
3626 #[test]
3627 fn test_parse_method_with_double_quotes_as_string() {
3628 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3630 let stmt = parser.parse().unwrap();
3631
3632 assert!(stmt.where_clause.is_some());
3633 let where_clause = stmt.where_clause.unwrap();
3634 assert_eq!(where_clause.conditions.len(), 1);
3635
3636 if let SqlExpression::MethodCall {
3637 object,
3638 method,
3639 args,
3640 } = &where_clause.conditions[0].expr
3641 {
3642 assert_eq!(object, "Country");
3643 assert_eq!(method, "Contains");
3644 assert_eq!(args.len(), 1);
3645 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3647 } else {
3648 panic!("Expected MethodCall");
3649 }
3650 }
3651
3652 #[test]
3653 fn test_extract_partial_with_quoted_columns_in_query() {
3654 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3656 let (context, partial) = detect_cursor_context(query, query.len());
3657
3658 assert!(matches!(context, CursorContext::OrderByClause));
3659 assert_eq!(
3660 partial,
3661 Some("coun".to_string()),
3662 "Should extract 'coun' as partial, not everything after the quoted column"
3663 );
3664 }
3665
3666 #[test]
3667 fn test_extract_partial_quoted_identifier_being_typed() {
3668 let query = r#"SELECT "Cust"#;
3670 let partial = extract_partial_at_end(query);
3671 assert_eq!(partial, Some("\"Cust".to_string()));
3672
3673 let query2 = r#"SELECT "Customer Id" FROM"#;
3675 let partial2 = extract_partial_at_end(query2);
3676 assert_eq!(partial2, None); }
3678
3679 #[test]
3681 fn test_complex_where_parentheses_basic() {
3682 let mut parser =
3684 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3685 let stmt = parser.parse().unwrap();
3686
3687 assert!(stmt.where_clause.is_some());
3688 let where_clause = stmt.where_clause.unwrap();
3689 assert_eq!(where_clause.conditions.len(), 1);
3690
3691 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3693 assert_eq!(op, "OR");
3694 } else {
3695 panic!("Expected BinaryOp with OR");
3696 }
3697 }
3698
3699 #[test]
3700 fn test_complex_where_mixed_and_or_with_parens() {
3701 let mut parser = Parser::new(
3703 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3704 );
3705 let stmt = parser.parse().unwrap();
3706
3707 assert!(stmt.where_clause.is_some());
3708 let where_clause = stmt.where_clause.unwrap();
3709 assert_eq!(where_clause.conditions.len(), 2);
3710
3711 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3713 assert_eq!(op, "OR");
3714 } else {
3715 panic!("Expected first condition to be OR expression");
3716 }
3717
3718 assert!(matches!(
3720 where_clause.conditions[0].connector,
3721 Some(LogicalOp::And)
3722 ));
3723
3724 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3726 assert_eq!(op, ">");
3727 } else {
3728 panic!("Expected second condition to be price > 100");
3729 }
3730 }
3731
3732 #[test]
3733 fn test_complex_where_nested_parentheses() {
3734 let mut parser = Parser::new(
3736 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3737 );
3738 let stmt = parser.parse().unwrap();
3739
3740 assert!(stmt.where_clause.is_some());
3741 let where_clause = stmt.where_clause.unwrap();
3742
3743 assert!(!where_clause.conditions.is_empty());
3745 }
3746
3747 #[test]
3748 fn test_complex_where_multiple_or_groups() {
3749 let mut parser = Parser::new(
3751 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3752 );
3753 let stmt = parser.parse().unwrap();
3754
3755 assert!(stmt.where_clause.is_some());
3756 let where_clause = stmt.where_clause.unwrap();
3757 assert_eq!(where_clause.conditions.len(), 2);
3758
3759 assert!(matches!(
3761 where_clause.conditions[0].connector,
3762 Some(LogicalOp::And)
3763 ));
3764 }
3765
3766 #[test]
3767 fn test_complex_where_with_methods_in_parens() {
3768 let mut parser = Parser::new(
3770 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3771 );
3772 let stmt = parser.parse().unwrap();
3773
3774 assert!(stmt.where_clause.is_some());
3775 let where_clause = stmt.where_clause.unwrap();
3776 assert_eq!(where_clause.conditions.len(), 2);
3777
3778 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3780 assert_eq!(op, "OR");
3781 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3782 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3783 } else {
3784 panic!("Expected OR of method calls");
3785 }
3786 }
3787
3788 #[test]
3789 fn test_complex_where_date_comparisons_with_parens() {
3790 let mut parser = Parser::new(
3792 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3793 );
3794 let stmt = parser.parse().unwrap();
3795
3796 assert!(stmt.where_clause.is_some());
3797 let where_clause = stmt.where_clause.unwrap();
3798 assert_eq!(where_clause.conditions.len(), 2);
3799
3800 assert!(matches!(
3802 where_clause.conditions[0].connector,
3803 Some(LogicalOp::And)
3804 ));
3805 }
3806
3807 #[test]
3808 fn test_complex_where_price_volume_filters() {
3809 let mut parser = Parser::new(
3811 r"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000",
3812 );
3813 let stmt = parser.parse().unwrap();
3814
3815 assert!(stmt.where_clause.is_some());
3816 let where_clause = stmt.where_clause.unwrap();
3817
3818 assert!(!where_clause.conditions.is_empty());
3820 }
3821
3822 #[test]
3823 fn test_complex_where_mixed_string_numeric() {
3824 let mut parser = Parser::new(
3826 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3827 );
3828 let stmt = parser.parse().unwrap();
3829
3830 assert!(stmt.where_clause.is_some());
3831 }
3833
3834 #[test]
3835 fn test_complex_where_triple_nested() {
3836 let mut parser = Parser::new(
3838 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3839 );
3840 let stmt = parser.parse().unwrap();
3841
3842 assert!(stmt.where_clause.is_some());
3843 }
3845
3846 #[test]
3847 fn test_complex_where_single_parens_around_and() {
3848 let mut parser = Parser::new(
3850 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3851 );
3852 let stmt = parser.parse().unwrap();
3853
3854 assert!(stmt.where_clause.is_some());
3855 let where_clause = stmt.where_clause.unwrap();
3856
3857 assert!(!where_clause.conditions.is_empty());
3859 }
3860
3861 #[test]
3863 fn test_format_preserves_simple_parentheses() {
3864 let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
3865 let formatted = format_sql_pretty_compact(query, 5);
3866 let formatted_text = formatted.join(" ");
3867
3868 assert!(formatted_text.contains("(status"));
3870 assert!(formatted_text.contains("\"pending\")"));
3871
3872 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3874 let formatted_parens = formatted_text
3875 .chars()
3876 .filter(|c| *c == '(' || *c == ')')
3877 .count();
3878 assert_eq!(
3879 original_parens, formatted_parens,
3880 "Parentheses should be preserved"
3881 );
3882 }
3883
3884 #[test]
3885 fn test_format_preserves_complex_parentheses() {
3886 let query =
3887 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3888 let formatted = format_sql_pretty_compact(query, 5);
3889 let formatted_text = formatted.join(" ");
3890
3891 assert!(formatted_text.contains("(symbol"));
3893 assert!(formatted_text.contains("\"GOOGL\")"));
3894
3895 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3897 let formatted_parens = formatted_text
3898 .chars()
3899 .filter(|c| *c == '(' || *c == ')')
3900 .count();
3901 assert_eq!(original_parens, formatted_parens);
3902 }
3903
3904 #[test]
3905 fn test_format_preserves_nested_parentheses() {
3906 let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
3907 let formatted = format_sql_pretty_compact(query, 5);
3908 let formatted_text = formatted.join(" ");
3909
3910 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3912 let formatted_parens = formatted_text
3913 .chars()
3914 .filter(|c| *c == '(' || *c == ')')
3915 .count();
3916 assert_eq!(
3917 original_parens, formatted_parens,
3918 "Nested parentheses should be preserved"
3919 );
3920 assert_eq!(original_parens, 4, "Should have 4 parentheses total");
3921 }
3922
#[test]
fn test_format_preserves_method_calls_in_parentheses() {
    let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
    let text = format_sql_pretty_compact(query, 5).join(" ");

    // The grouping paren and each method call's argument list must survive intact.
    for fragment in ["(symbol.StartsWith", "StartsWith(\"A\")", "StartsWith(\"G\")"] {
        assert!(text.contains(fragment));
    }

    let paren_count = |s: &str| s.chars().filter(|&c| c == '(' || c == ')').count();
    let in_query = paren_count(query);
    assert_eq!(in_query, paren_count(&text));
    assert_eq!(
        in_query, 6,
        "Should have 6 parentheses (1 group + 2 method calls)"
    );
}
3946
#[test]
fn test_format_preserves_multiple_groups() {
    let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
    let joined = format_sql_pretty_compact(query, 5).join(" ");

    // Each of the two groups keeps its opening paren attached to its first operand.
    assert!(joined.contains("(symbol"));
    assert!(joined.contains("(price"));

    // Count every '(' and ')' occurrence in a string.
    let count_parens = |s: &str| s.matches(|c: char| c == '(' || c == ')').count();
    let before = count_parens(query);
    let after = count_parens(&joined);
    assert_eq!(before, after);
    assert_eq!(before, 4, "Should have 4 parentheses (2 groups)");
}
3965
#[test]
fn test_format_preserves_date_ranges() {
    let query = r"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))";
    let formatted = format_sql_pretty_compact(query, 5);
    let formatted_text = formatted.join(" ");

    // The outer group and both DateTime(...) argument lists must come through verbatim.
    assert!(formatted_text.contains("(executionDate"));
    assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
    assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

    let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
    let formatted_parens = formatted_text
        .chars()
        .filter(|c| *c == '(' || *c == ')')
        .count();
    assert_eq!(original_parens, formatted_parens);
    // Pin the fixture's paren count so the equality above cannot pass vacuously
    // if the query is later edited.
    assert_eq!(
        original_parens, 6,
        "Should have 6 parentheses (1 group + 2 DateTime calls)"
    );
}
3984
#[test]
fn test_format_multiline_layout() {
    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let lines = format_sql_pretty_compact(query, 5);

    assert!(lines.len() >= 4, "Should have multiple lines");

    // Expected header shape: SELECT alone, the select list, the FROM line, WHERE alone.
    assert_eq!(lines[0], "SELECT");
    assert!(lines[1].trim().starts_with('*'));
    assert!(lines[2].starts_with("FROM"));
    assert_eq!(lines[3], "WHERE");

    // Everything after the WHERE keyword is the condition body; the length check above
    // guarantees this slice is in bounds (possibly empty).
    let where_body = &lines[4..];
    assert!(where_body.iter().any(|line| line.contains("(symbol")));
    assert!(where_body.iter().any(|line| line.trim() == "AND"));
}
4004
#[test]
fn test_between_simple() {
    // One binding, so the parser and the AST pretty-printer are checked against the
    // exact same query text.
    let query = "SELECT * FROM table WHERE price BETWEEN 50 AND 100";

    let mut parser = Parser::new(query);
    let stmt = parser.parse().expect("Should parse simple BETWEEN");

    // A lone BETWEEN predicate yields exactly one condition, even though it contains AND.
    let where_clause = stmt
        .where_clause
        .expect("WHERE clause should be present");
    assert_eq!(where_clause.conditions.len(), 1);

    let ast = format_ast_tree(query);
    assert!(!ast.contains("PARSE ERROR"));
    assert!(ast.contains("SelectStatement"));
}
4019
#[test]
fn test_between_in_parentheses() {
    // One binding, so the parser and the AST pretty-printer see the same query text.
    let query = "SELECT * FROM table WHERE (price BETWEEN 50 AND 100)";

    let mut parser = Parser::new(query);
    let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");
    assert!(stmt.where_clause.is_some());

    let ast = format_ast_tree(query);
    assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
}
4031
#[test]
fn test_between_with_or() {
    // A parenthesised BETWEEN (with its inner AND) joined to another group by a top-level OR.
    let mut parser =
        Parser::new("SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)");
    let parsed = parser.parse().expect("Should parse BETWEEN with OR");
    assert!(parsed.where_clause.is_some());
}
4042
#[test]
fn test_between_with_and() {
    let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
    let mut parser = Parser::new(query);
    let stmt = parser.parse().expect("Should parse BETWEEN with AND");

    // Two conditions are expected: `category = 'Books'` and the BETWEEN predicate.
    // The AND inside BETWEEN does not start a third condition.
    let where_clause = stmt
        .where_clause
        .expect("WHERE clause should be present");
    assert_eq!(where_clause.conditions.len(), 2);
}
4053
#[test]
fn test_multiple_between() {
    // Two parenthesised BETWEEN predicates joined by a top-level AND.
    let sql =
        "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
    let mut p = Parser::new(sql);
    let statement = p.parse().expect("Should parse multiple BETWEEN clauses");
    assert!(statement.where_clause.is_some());
}
4065
#[test]
fn test_between_complex_query() {
    let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
    let mut parser = Parser::new(query);
    let stmt = parser
        .parse()
        .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

    assert!(stmt.where_clause.is_some());

    // ORDER BY keeps both sort keys in source order, with the column names as written
    // and each key's own direction.
    let order_by = stmt.order_by.expect("ORDER BY clause should be present");
    assert_eq!(order_by.len(), 2);
    assert_eq!(order_by[0].column, "Category");
    assert!(matches!(order_by[0].direction, SortDirection::Asc));
    assert_eq!(order_by[1].column, "price");
    assert!(matches!(order_by[1].direction, SortDirection::Desc));
}
4085
#[test]
fn test_between_formatting() {
    // Build `price BETWEEN 50 AND 100` directly as an AST node.
    let number = |text: &str| Box::new(SqlExpression::NumberLiteral(text.to_string()));
    let expr = SqlExpression::Between {
        expr: Box::new(SqlExpression::Column("price".to_string())),
        lower: number("50"),
        upper: number("100"),
    };

    // SQL rendering reproduces the BETWEEN syntax.
    assert_eq!(format_expression(&expr), "price BETWEEN 50 AND 100");

    // The AST rendering mentions the node kind and both bounds.
    let ast_text = format_expression_ast(&expr);
    for needle in ["Between", "50", "100"] {
        assert!(ast_text.contains(needle));
    }
}
4102
#[test]
fn test_utf8_boundary_safety() {
    // 'é' encodes to two bytes in UTF-8, so some byte offsets in this query fall inside
    // a character.
    let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

    // Every byte offset from 0 through the end-of-string offset must be handled without
    // panicking, including offsets that split a multi-byte character.
    for pos in 0..=query_with_unicode.len() {
        let result =
            std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));

        assert!(
            result.is_ok(),
            "Panic at position {pos} in query with Unicode"
        );
    }

    // Offsets far past the end must also be tolerated.
    let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
    assert!(result.is_ok(), "Panic with position beyond string length");

    // Explicitly target the byte after the first byte of 'é'. Guard the fixture so this
    // check cannot silently become a char-boundary check if the query text changes.
    let pos_in_e = query_with_unicode
        .find('é')
        .expect("fixture query should contain 'é'")
        + 1;
    assert!(
        !query_with_unicode.is_char_boundary(pos_in_e),
        "pos_in_e should fall inside a multi-byte character"
    );
    let result =
        std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
    assert!(
        result.is_ok(),
        "Panic with cursor in middle of UTF-8 character"
    );
}
4133}