1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by `Lexer::next_token` for the SQL-like dialect.
///
/// Keyword variants are matched case-insensitively by the lexer; payload
/// variants carry the text exactly as it appeared in the input (without the
/// surrounding quotes for quoted forms).
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // ----- Keywords -----
    Select,
    From,
    Where,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// The two-word keyword `ORDER BY`, emitted as a single token.
    OrderBy,
    /// The two-word keyword `GROUP BY`, emitted as a single token.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,

    // ----- Names and literals -----
    /// A bare word that is not a keyword (also used for unrecognised single characters).
    Identifier(String),
    /// Text enclosed in double quotes.
    QuotedIdentifier(String),
    /// Text enclosed in single quotes.
    StringLiteral(String),
    /// Numeric text, kept as a string; may start with `-`.
    NumberLiteral(String),
    /// `*`
    Star,

    // ----- Punctuation and comparison operators -----
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// `<>` or `!=`
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // ----- Arithmetic operators -----
    Plus,
    Minus,
    Divide,

    /// End of input.
    Eof,
}
60
/// Character-level scanner that turns a query string into `Token`s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input, split into `char`s so it can be indexed by position.
    input: Vec<char>,
    /// Index into `input` of the character currently being examined.
    position: usize,
    /// The character at `position`, or `None` once the input is exhausted.
    current_char: Option<char>,
}
67
68impl Lexer {
69 pub fn new(input: &str) -> Self {
70 let chars: Vec<char> = input.chars().collect();
71 let current = chars.get(0).copied();
72 Self {
73 input: chars,
74 position: 0,
75 current_char: current,
76 }
77 }
78
79 fn advance(&mut self) {
80 self.position += 1;
81 self.current_char = self.input.get(self.position).copied();
82 }
83
84 fn peek(&self, offset: usize) -> Option<char> {
85 self.input.get(self.position + offset).copied()
86 }
87
88 fn skip_whitespace(&mut self) {
89 while let Some(ch) = self.current_char {
90 if ch.is_whitespace() {
91 self.advance();
92 } else {
93 break;
94 }
95 }
96 }
97
98 fn read_identifier(&mut self) -> String {
99 let mut result = String::new();
100 while let Some(ch) = self.current_char {
101 if ch.is_alphanumeric() || ch == '_' {
102 result.push(ch);
103 self.advance();
104 } else {
105 break;
106 }
107 }
108 result
109 }
110
111 fn read_string(&mut self) -> String {
112 let mut result = String::new();
113 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
117 if ch == quote_char {
118 self.advance(); break;
120 }
121 result.push(ch);
122 self.advance();
123 }
124 result
125 }
126
127 fn read_number(&mut self) -> String {
128 let mut result = String::new();
129 let mut has_e = false;
130
131 while let Some(ch) = self.current_char {
133 if !has_e && (ch.is_numeric() || ch == '.') {
134 result.push(ch);
135 self.advance();
136 } else if (ch == 'e' || ch == 'E') && !has_e && !result.is_empty() {
137 result.push(ch);
139 self.advance();
140 has_e = true;
141
142 if let Some(sign) = self.current_char {
144 if sign == '+' || sign == '-' {
145 result.push(sign);
146 self.advance();
147 }
148 }
149
150 while let Some(digit) = self.current_char {
152 if digit.is_numeric() {
153 result.push(digit);
154 self.advance();
155 } else {
156 break;
157 }
158 }
159 break; } else {
161 break;
162 }
163 }
164 result
165 }
166
167 pub fn next_token(&mut self) -> Token {
168 self.skip_whitespace();
169
170 match self.current_char {
171 None => Token::Eof,
172 Some('*') => {
173 self.advance();
174 Token::Star }
178 Some('+') => {
179 self.advance();
180 Token::Plus
181 }
182 Some('/') => {
183 self.advance();
184 Token::Divide
185 }
186 Some('.') => {
187 self.advance();
188 Token::Dot
189 }
190 Some(',') => {
191 self.advance();
192 Token::Comma
193 }
194 Some('(') => {
195 self.advance();
196 Token::LeftParen
197 }
198 Some(')') => {
199 self.advance();
200 Token::RightParen
201 }
202 Some('=') => {
203 self.advance();
204 Token::Equal
205 }
206 Some('<') => {
207 self.advance();
208 if self.current_char == Some('=') {
209 self.advance();
210 Token::LessThanOrEqual
211 } else if self.current_char == Some('>') {
212 self.advance();
213 Token::NotEqual
214 } else {
215 Token::LessThan
216 }
217 }
218 Some('>') => {
219 self.advance();
220 if self.current_char == Some('=') {
221 self.advance();
222 Token::GreaterThanOrEqual
223 } else {
224 Token::GreaterThan
225 }
226 }
227 Some('!') if self.peek(1) == Some('=') => {
228 self.advance();
229 self.advance();
230 Token::NotEqual
231 }
232 Some('"') => {
233 let ident_val = self.read_string();
235 Token::QuotedIdentifier(ident_val)
236 }
237 Some('\'') => {
238 let string_val = self.read_string();
240 Token::StringLiteral(string_val)
241 }
242 Some('-') if self.peek(1).map_or(false, |c| c.is_numeric()) => {
243 self.advance(); let num = self.read_number();
246 Token::NumberLiteral(format!("-{}", num))
247 }
248 Some('-') => {
249 self.advance();
251 Token::Minus
252 }
253 Some(ch) if ch.is_numeric() => {
254 let num = self.read_number();
255 Token::NumberLiteral(num)
256 }
257 Some(ch) if ch.is_alphabetic() || ch == '_' => {
258 let ident = self.read_identifier();
259 match ident.to_uppercase().as_str() {
260 "SELECT" => Token::Select,
261 "FROM" => Token::From,
262 "WHERE" => Token::Where,
263 "AND" => Token::And,
264 "OR" => Token::Or,
265 "IN" => Token::In,
266 "NOT" => Token::Not,
267 "BETWEEN" => Token::Between,
268 "LIKE" => Token::Like,
269 "IS" => Token::Is,
270 "NULL" => Token::Null,
271 "ORDER" if self.peek_keyword("BY") => {
272 self.skip_whitespace();
273 self.read_identifier(); Token::OrderBy
275 }
276 "GROUP" if self.peek_keyword("BY") => {
277 self.skip_whitespace();
278 self.read_identifier(); Token::GroupBy
280 }
281 "HAVING" => Token::Having,
282 "AS" => Token::As,
283 "ASC" => Token::Asc,
284 "DESC" => Token::Desc,
285 "LIMIT" => Token::Limit,
286 "OFFSET" => Token::Offset,
287 "DATETIME" => Token::DateTime,
288 "CASE" => Token::Case,
289 "WHEN" => Token::When,
290 "THEN" => Token::Then,
291 "ELSE" => Token::Else,
292 "END" => Token::End,
293 "DISTINCT" => Token::Distinct,
294 _ => Token::Identifier(ident),
295 }
296 }
297 Some(ch) => {
298 self.advance();
299 Token::Identifier(ch.to_string())
300 }
301 }
302 }
303
304 fn peek_keyword(&mut self, keyword: &str) -> bool {
305 let saved_pos = self.position;
306 let saved_char = self.current_char;
307
308 self.skip_whitespace();
309 let next_word = self.read_identifier();
310 let matches = next_word.to_uppercase() == keyword;
311
312 self.position = saved_pos;
314 self.current_char = saved_char;
315
316 matches
317 }
318
319 pub fn get_position(&self) -> usize {
320 self.position
321 }
322
323 pub fn tokenize_all(&mut self) -> Vec<Token> {
324 let mut tokens = Vec::new();
325 loop {
326 let token = self.next_token();
327 if matches!(token, Token::Eof) {
328 tokens.push(token);
329 break;
330 }
331 tokens.push(token);
332 }
333 tokens
334 }
335
336 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
337 let mut tokens = Vec::new();
338 loop {
339 self.skip_whitespace();
340 let start_pos = self.position;
341 let token = self.next_token();
342 let end_pos = self.position;
343
344 if matches!(token, Token::Eof) {
345 break;
346 }
347 tokens.push((start_pos, end_pos, token));
348 }
349 tokens
350 }
351}
352
353#[derive(Debug, Clone)]
355pub enum SqlExpression {
356 Column(String),
357 StringLiteral(String),
358 NumberLiteral(String),
359 DateTimeConstructor {
360 year: i32,
361 month: u32,
362 day: u32,
363 hour: Option<u32>,
364 minute: Option<u32>,
365 second: Option<u32>,
366 },
367 DateTimeToday {
368 hour: Option<u32>,
369 minute: Option<u32>,
370 second: Option<u32>,
371 },
372 MethodCall {
373 object: String,
374 method: String,
375 args: Vec<SqlExpression>,
376 },
377 ChainedMethodCall {
378 base: Box<SqlExpression>,
379 method: String,
380 args: Vec<SqlExpression>,
381 },
382 FunctionCall {
383 name: String,
384 args: Vec<SqlExpression>,
385 },
386 BinaryOp {
387 left: Box<SqlExpression>,
388 op: String,
389 right: Box<SqlExpression>,
390 },
391 InList {
392 expr: Box<SqlExpression>,
393 values: Vec<SqlExpression>,
394 },
395 NotInList {
396 expr: Box<SqlExpression>,
397 values: Vec<SqlExpression>,
398 },
399 Between {
400 expr: Box<SqlExpression>,
401 lower: Box<SqlExpression>,
402 upper: Box<SqlExpression>,
403 },
404 Not {
405 expr: Box<SqlExpression>,
406 },
407 CaseExpression {
408 when_branches: Vec<WhenBranch>,
409 else_branch: Option<Box<SqlExpression>>,
410 },
411}
412
413#[derive(Debug, Clone)]
414pub struct WhenBranch {
415 pub condition: Box<SqlExpression>,
416 pub result: Box<SqlExpression>,
417}
418
419#[derive(Debug, Clone)]
420pub struct WhereClause {
421 pub conditions: Vec<Condition>,
422}
423
424#[derive(Debug, Clone)]
425pub struct Condition {
426 pub expr: SqlExpression,
427 pub connector: Option<LogicalOp>, }
429
/// Logical connector joining two `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
435
/// Sort direction of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (`ASC`); also what the parser uses when no direction is given.
    Asc,
    /// Descending (`DESC`).
    Desc,
}
441
442#[derive(Debug, Clone)]
443pub struct OrderByColumn {
444 pub column: String,
445 pub direction: SortDirection,
446}
447
448#[derive(Debug, Clone)]
450pub enum SelectItem {
451 Column(String),
453 Expression { expr: SqlExpression, alias: String },
455 Star,
457}
458
459#[derive(Debug, Clone)]
460pub struct SelectStatement {
461 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
464 pub where_clause: Option<WhereClause>,
465 pub order_by: Option<Vec<OrderByColumn>>,
466 pub group_by: Option<Vec<String>>,
467 pub limit: Option<usize>,
468 pub offset: Option<usize>,
469}
470
/// Options controlling parser behaviour.
///
/// `Default` is derived: `case_insensitive` defaults to `false`, exactly as
/// the previous hand-written implementation did.
#[derive(Default)]
pub struct ParserConfig {
    /// NOTE(review): the parser code visible here stores this flag but never
    /// reads it; what it is meant to make case-insensitive should be confirmed
    /// against its consumers.
    pub case_insensitive: bool,
}
482
483pub struct Parser {
484 lexer: Lexer,
485 current_token: Token,
486 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
490 config: ParserConfig, }
492
493impl Parser {
494 pub fn new(input: &str) -> Self {
495 let mut lexer = Lexer::new(input);
496 let current_token = lexer.next_token();
497 Self {
498 lexer,
499 current_token,
500 in_method_args: false,
501 columns: Vec::new(),
502 paren_depth: 0,
503 config: ParserConfig::default(),
504 }
505 }
506
507 pub fn with_config(input: &str, config: ParserConfig) -> Self {
508 let mut lexer = Lexer::new(input);
509 let current_token = lexer.next_token();
510 Self {
511 lexer,
512 current_token,
513 in_method_args: false,
514 columns: Vec::new(),
515 paren_depth: 0,
516 config,
517 }
518 }
519
520 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
521 self.columns = columns;
522 self
523 }
524
525 fn consume(&mut self, expected: Token) -> Result<(), String> {
526 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
527 match &expected {
529 Token::LeftParen => self.paren_depth += 1,
530 Token::RightParen => {
531 self.paren_depth -= 1;
532 if self.paren_depth < 0 {
534 return Err(
535 "Unexpected closing parenthesis - no matching opening parenthesis"
536 .to_string(),
537 );
538 }
539 }
540 _ => {}
541 }
542
543 self.current_token = self.lexer.next_token();
544 Ok(())
545 } else {
546 let error_msg = match (&expected, &self.current_token) {
548 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
549 format!(
550 "Unclosed parenthesis - missing {} closing parenthes{}",
551 self.paren_depth,
552 if self.paren_depth == 1 { "is" } else { "es" }
553 )
554 }
555 (Token::RightParen, _) if self.paren_depth > 0 => {
556 format!(
557 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
558 self.current_token,
559 self.paren_depth,
560 if self.paren_depth == 1 { "is" } else { "es" }
561 )
562 }
563 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
564 };
565 Err(error_msg)
566 }
567 }
568
569 fn advance(&mut self) {
570 match &self.current_token {
572 Token::LeftParen => self.paren_depth += 1,
573 Token::RightParen => {
574 self.paren_depth -= 1;
575 }
578 _ => {}
579 }
580 self.current_token = self.lexer.next_token();
581 }
582
583 pub fn parse(&mut self) -> Result<SelectStatement, String> {
584 self.parse_select_statement()
585 }
586
587 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
588 self.consume(Token::Select)?;
589
590 let select_items = self.parse_select_items()?;
592
593 let columns = select_items
595 .iter()
596 .map(|item| match item {
597 SelectItem::Star => "*".to_string(),
598 SelectItem::Column(name) => name.clone(),
599 SelectItem::Expression { alias, .. } => alias.clone(),
600 })
601 .collect();
602
603 let from_table = if matches!(self.current_token, Token::From) {
604 self.advance();
605 match &self.current_token {
606 Token::Identifier(table) => {
607 let table_name = table.clone();
608 self.advance();
609 Some(table_name)
610 }
611 Token::QuotedIdentifier(table) => {
612 let table_name = table.clone();
614 self.advance();
615 Some(table_name)
616 }
617 _ => return Err("Expected table name after FROM".to_string()),
618 }
619 } else {
620 None
621 };
622
623 let where_clause = if matches!(self.current_token, Token::Where) {
624 self.advance();
625 Some(self.parse_where_clause()?)
626 } else {
627 None
628 };
629
630 let order_by = if matches!(self.current_token, Token::OrderBy) {
631 self.advance();
632 Some(self.parse_order_by_list()?)
633 } else {
634 None
635 };
636
637 let group_by = if matches!(self.current_token, Token::GroupBy) {
638 self.advance();
639 Some(self.parse_identifier_list()?)
640 } else {
641 None
642 };
643
644 let limit = if matches!(self.current_token, Token::Limit) {
646 self.advance();
647 match &self.current_token {
648 Token::NumberLiteral(num) => {
649 let limit_val = num
650 .parse::<usize>()
651 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
652 self.advance();
653 Some(limit_val)
654 }
655 _ => return Err("Expected number after LIMIT".to_string()),
656 }
657 } else {
658 None
659 };
660
661 let offset = if matches!(self.current_token, Token::Offset) {
663 self.advance();
664 match &self.current_token {
665 Token::NumberLiteral(num) => {
666 let offset_val = num
667 .parse::<usize>()
668 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
669 self.advance();
670 Some(offset_val)
671 }
672 _ => return Err("Expected number after OFFSET".to_string()),
673 }
674 } else {
675 None
676 };
677
678 if self.paren_depth > 0 {
680 return Err(format!(
681 "Unclosed parenthesis - missing {} closing parenthes{}",
682 self.paren_depth,
683 if self.paren_depth == 1 { "is" } else { "es" }
684 ));
685 } else if self.paren_depth < 0 {
686 return Err(
687 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
688 );
689 }
690
691 Ok(SelectStatement {
692 columns,
693 select_items,
694 from_table,
695 where_clause,
696 order_by,
697 group_by,
698 limit,
699 offset,
700 })
701 }
702
703 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
704 let mut columns = Vec::new();
705
706 if matches!(self.current_token, Token::Star) {
707 columns.push("*".to_string());
708 self.advance();
709 } else {
710 loop {
711 match &self.current_token {
712 Token::Identifier(col) => {
713 columns.push(col.clone());
714 self.advance();
715 }
716 Token::QuotedIdentifier(col) => {
717 columns.push(col.clone());
719 self.advance();
720 }
721 _ => return Err("Expected column name".to_string()),
722 }
723
724 if matches!(self.current_token, Token::Comma) {
725 self.advance();
726 } else {
727 break;
728 }
729 }
730 }
731
732 Ok(columns)
733 }
734
735 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
737 let mut items = Vec::new();
738
739 loop {
740 if matches!(self.current_token, Token::Star) {
743 items.push(SelectItem::Star);
751 self.advance();
752 } else {
753 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
758 self.advance();
759 match &self.current_token {
760 Token::Identifier(alias_name) => {
761 let alias = alias_name.clone();
762 self.advance();
763 alias
764 }
765 Token::QuotedIdentifier(alias_name) => {
766 let alias = alias_name.clone();
767 self.advance();
768 alias
769 }
770 _ => return Err("Expected alias name after AS".to_string()),
771 }
772 } else {
773 match &expr {
775 SqlExpression::Column(col_name) => col_name.clone(),
776 _ => format!("expr_{}", items.len() + 1), }
778 };
779
780 let item = match expr {
782 SqlExpression::Column(col_name) if alias == col_name => {
783 SelectItem::Column(col_name)
785 }
786 _ => {
787 SelectItem::Expression { expr, alias }
789 }
790 };
791
792 items.push(item);
793 }
794
795 if matches!(self.current_token, Token::Comma) {
797 self.advance();
798 } else {
799 break;
800 }
801 }
802
803 Ok(items)
804 }
805
806 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
807 let mut identifiers = Vec::new();
808
809 loop {
810 match &self.current_token {
811 Token::Identifier(id) => {
812 identifiers.push(id.clone());
813 self.advance();
814 }
815 Token::QuotedIdentifier(id) => {
816 identifiers.push(id.clone());
818 self.advance();
819 }
820 _ => return Err("Expected identifier".to_string()),
821 }
822
823 if matches!(self.current_token, Token::Comma) {
824 self.advance();
825 } else {
826 break;
827 }
828 }
829
830 Ok(identifiers)
831 }
832
833 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
834 let mut order_columns = Vec::new();
835
836 loop {
837 let column = match &self.current_token {
838 Token::Identifier(id) => {
839 let col = id.clone();
840 self.advance();
841 col
842 }
843 Token::QuotedIdentifier(id) => {
844 let col = id.clone();
845 self.advance();
846 col
847 }
848 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
849 let col = num.clone();
851 self.advance();
852 col
853 }
854 _ => return Err("Expected column name in ORDER BY".to_string()),
855 };
856
857 let direction = match &self.current_token {
859 Token::Asc => {
860 self.advance();
861 SortDirection::Asc
862 }
863 Token::Desc => {
864 self.advance();
865 SortDirection::Desc
866 }
867 _ => SortDirection::Asc, };
869
870 order_columns.push(OrderByColumn { column, direction });
871
872 if matches!(self.current_token, Token::Comma) {
873 self.advance();
874 } else {
875 break;
876 }
877 }
878
879 Ok(order_columns)
880 }
881
882 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
883 let mut conditions = Vec::new();
884
885 loop {
886 let expr = self.parse_expression()?;
887
888 let connector = match &self.current_token {
889 Token::And => {
890 self.advance();
891 Some(LogicalOp::And)
892 }
893 Token::Or => {
894 self.advance();
895 Some(LogicalOp::Or)
896 }
897 Token::RightParen if self.paren_depth <= 0 => {
898 return Err(
900 "Unexpected closing parenthesis - no matching opening parenthesis"
901 .to_string(),
902 );
903 }
904 _ => None,
905 };
906
907 conditions.push(Condition {
908 expr,
909 connector: connector.clone(),
910 });
911
912 if connector.is_none() {
913 break;
914 }
915 }
916
917 Ok(WhereClause { conditions })
918 }
919
920 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
921 let mut left = self.parse_comparison()?;
922
923 if let Some(op) = self.get_binary_op() {
926 self.advance();
927 let right = self.parse_expression()?;
928 left = SqlExpression::BinaryOp {
929 left: Box::new(left),
930 op,
931 right: Box::new(right),
932 };
933 }
934
935 if matches!(self.current_token, Token::In) {
937 self.advance();
938 self.consume(Token::LeftParen)?;
939 let values = self.parse_expression_list()?;
940 self.consume(Token::RightParen)?;
941
942 left = SqlExpression::InList {
943 expr: Box::new(left),
944 values,
945 };
946 }
947
948 Ok(left)
952 }
953
954 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
955 let mut left = self.parse_additive()?;
956
957 if matches!(self.current_token, Token::Between) {
959 self.advance(); let lower = self.parse_primary()?;
961 self.consume(Token::And)?; let upper = self.parse_primary()?;
963
964 return Ok(SqlExpression::Between {
965 expr: Box::new(left),
966 lower: Box::new(lower),
967 upper: Box::new(upper),
968 });
969 }
970
971 if matches!(self.current_token, Token::Not) {
973 self.advance(); if matches!(self.current_token, Token::In) {
975 self.advance(); self.consume(Token::LeftParen)?;
977 let values = self.parse_expression_list()?;
978 self.consume(Token::RightParen)?;
979
980 return Ok(SqlExpression::NotInList {
981 expr: Box::new(left),
982 values,
983 });
984 } else {
985 return Err("Expected IN after NOT".to_string());
986 }
987 }
988
989 if let Some(op) = self.get_binary_op() {
991 self.advance();
992 let right = self.parse_additive()?;
993 left = SqlExpression::BinaryOp {
994 left: Box::new(left),
995 op,
996 right: Box::new(right),
997 };
998 }
999
1000 Ok(left)
1001 }
1002
1003 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1004 let mut left = self.parse_multiplicative()?;
1005
1006 while matches!(self.current_token, Token::Plus | Token::Minus) {
1007 let op = match self.current_token {
1008 Token::Plus => "+",
1009 Token::Minus => "-",
1010 _ => unreachable!(),
1011 };
1012 self.advance();
1013 let right = self.parse_multiplicative()?;
1014 left = SqlExpression::BinaryOp {
1015 left: Box::new(left),
1016 op: op.to_string(),
1017 right: Box::new(right),
1018 };
1019 }
1020
1021 Ok(left)
1022 }
1023
1024 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1025 let mut left = self.parse_primary()?;
1026
1027 while matches!(self.current_token, Token::Dot) {
1029 self.advance();
1030 if let Token::Identifier(method) = &self.current_token {
1031 let method_name = method.clone();
1032 self.advance();
1033
1034 if matches!(self.current_token, Token::LeftParen) {
1035 self.advance();
1036 let args = self.parse_method_args()?;
1037 self.consume(Token::RightParen)?;
1038
1039 match left {
1041 SqlExpression::Column(obj) => {
1042 left = SqlExpression::MethodCall {
1044 object: obj,
1045 method: method_name,
1046 args,
1047 };
1048 }
1049 SqlExpression::MethodCall { .. }
1050 | SqlExpression::ChainedMethodCall { .. } => {
1051 left = SqlExpression::ChainedMethodCall {
1053 base: Box::new(left),
1054 method: method_name,
1055 args,
1056 };
1057 }
1058 _ => {
1059 left = SqlExpression::ChainedMethodCall {
1061 base: Box::new(left),
1062 method: method_name,
1063 args,
1064 };
1065 }
1066 }
1067 } else {
1068 return Err(format!("Expected '(' after method name '{}'", method_name));
1069 }
1070 } else {
1071 return Err("Expected method name after '.'".to_string());
1072 }
1073 }
1074
1075 while matches!(self.current_token, Token::Star | Token::Divide) {
1076 let op = match self.current_token {
1077 Token::Star => "*",
1078 Token::Divide => "/",
1079 _ => unreachable!(),
1080 };
1081 self.advance();
1082 let right = self.parse_primary()?;
1083 left = SqlExpression::BinaryOp {
1084 left: Box::new(left),
1085 op: op.to_string(),
1086 right: Box::new(right),
1087 };
1088 }
1089
1090 Ok(left)
1091 }
1092
1093 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1094 let mut left = self.parse_logical_and()?;
1095
1096 while matches!(self.current_token, Token::Or) {
1097 self.advance();
1098 let right = self.parse_logical_and()?;
1099 left = SqlExpression::BinaryOp {
1103 left: Box::new(left),
1104 op: "OR".to_string(),
1105 right: Box::new(right),
1106 };
1107 }
1108
1109 Ok(left)
1110 }
1111
1112 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1113 let mut left = self.parse_expression()?;
1114
1115 while matches!(self.current_token, Token::And) {
1116 self.advance();
1117 let right = self.parse_expression()?;
1118 left = SqlExpression::BinaryOp {
1120 left: Box::new(left),
1121 op: "AND".to_string(),
1122 right: Box::new(right),
1123 };
1124 }
1125
1126 Ok(left)
1127 }
1128
1129 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1130 self.consume(Token::Case)?;
1132
1133 let mut when_branches = Vec::new();
1134
1135 while matches!(self.current_token, Token::When) {
1137 self.advance(); let condition = self.parse_expression()?;
1141
1142 self.consume(Token::Then)?;
1144
1145 let result = self.parse_expression()?;
1147
1148 when_branches.push(WhenBranch {
1149 condition: Box::new(condition),
1150 result: Box::new(result),
1151 });
1152 }
1153
1154 if when_branches.is_empty() {
1156 return Err("CASE expression must have at least one WHEN clause".to_string());
1157 }
1158
1159 let else_branch = if matches!(self.current_token, Token::Else) {
1161 self.advance(); Some(Box::new(self.parse_expression()?))
1163 } else {
1164 None
1165 };
1166
1167 self.consume(Token::End)?;
1169
1170 Ok(SqlExpression::CaseExpression {
1171 when_branches,
1172 else_branch,
1173 })
1174 }
1175
1176 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1177 if let Token::NumberLiteral(num_str) = &self.current_token {
1180 if self.columns.iter().any(|col| col == num_str) {
1182 let expr = SqlExpression::Column(num_str.clone());
1183 self.advance();
1184 return Ok(expr);
1185 }
1186 }
1187
1188 match &self.current_token {
1189 Token::Case => {
1190 self.parse_case_expression()
1192 }
1193 Token::DateTime => {
1194 self.advance(); self.consume(Token::LeftParen)?;
1196
1197 if matches!(&self.current_token, Token::RightParen) {
1199 self.advance(); return Ok(SqlExpression::DateTimeToday {
1201 hour: None,
1202 minute: None,
1203 second: None,
1204 });
1205 }
1206
1207 let year = if let Token::NumberLiteral(n) = &self.current_token {
1209 n.parse::<i32>().map_err(|_| "Invalid year")?
1210 } else {
1211 return Err("Expected year in DateTime constructor".to_string());
1212 };
1213 self.advance();
1214 self.consume(Token::Comma)?;
1215
1216 let month = if let Token::NumberLiteral(n) = &self.current_token {
1218 n.parse::<u32>().map_err(|_| "Invalid month")?
1219 } else {
1220 return Err("Expected month in DateTime constructor".to_string());
1221 };
1222 self.advance();
1223 self.consume(Token::Comma)?;
1224
1225 let day = if let Token::NumberLiteral(n) = &self.current_token {
1227 n.parse::<u32>().map_err(|_| "Invalid day")?
1228 } else {
1229 return Err("Expected day in DateTime constructor".to_string());
1230 };
1231 self.advance();
1232
1233 let mut hour = None;
1235 let mut minute = None;
1236 let mut second = None;
1237
1238 if matches!(&self.current_token, Token::Comma) {
1239 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1243 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1244 self.advance();
1245
1246 if matches!(&self.current_token, Token::Comma) {
1248 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1251 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1252 self.advance();
1253
1254 if matches!(&self.current_token, Token::Comma) {
1256 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1259 second =
1260 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1261 self.advance();
1262 }
1263 }
1264 }
1265 }
1266 }
1267 }
1268
1269 self.consume(Token::RightParen)?;
1270 Ok(SqlExpression::DateTimeConstructor {
1271 year,
1272 month,
1273 day,
1274 hour,
1275 minute,
1276 second,
1277 })
1278 }
1279 Token::Identifier(id) => {
1280 let id_upper = id.to_uppercase();
1281 let id_clone = id.clone();
1282 self.advance();
1283
1284 if matches!(self.current_token, Token::LeftParen) {
1286 self.advance(); let args = self.parse_function_args()?;
1290 self.consume(Token::RightParen)?;
1291 return Ok(SqlExpression::FunctionCall {
1292 name: id_upper,
1293 args,
1294 });
1295 }
1296
1297 Ok(SqlExpression::Column(id_clone))
1299 }
1300 Token::QuotedIdentifier(id) => {
1301 let expr = if self.in_method_args {
1304 SqlExpression::StringLiteral(id.clone())
1305 } else {
1306 SqlExpression::Column(id.clone())
1308 };
1309 self.advance();
1310 Ok(expr)
1311 }
1312 Token::StringLiteral(s) => {
1313 let expr = SqlExpression::StringLiteral(s.clone());
1314 self.advance();
1315 Ok(expr)
1316 }
1317 Token::NumberLiteral(n) => {
1318 let expr = SqlExpression::NumberLiteral(n.clone());
1319 self.advance();
1320 Ok(expr)
1321 }
1322 Token::LeftParen => {
1323 self.advance();
1324
1325 let expr = self.parse_logical_or()?;
1328
1329 self.consume(Token::RightParen)?;
1330 Ok(expr)
1331 }
1332 Token::Not => {
1333 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1337 if matches!(self.current_token, Token::In) {
1339 self.advance(); self.consume(Token::LeftParen)?;
1341 let values = self.parse_expression_list()?;
1342 self.consume(Token::RightParen)?;
1343
1344 return Ok(SqlExpression::NotInList {
1345 expr: Box::new(inner_expr),
1346 values,
1347 });
1348 } else {
1349 return Ok(SqlExpression::Not {
1351 expr: Box::new(inner_expr),
1352 });
1353 }
1354 } else {
1355 return Err("Expected expression after NOT".to_string());
1356 }
1357 }
1358 Token::Star => {
1359 self.advance();
1361 Ok(SqlExpression::StringLiteral("*".to_string()))
1362 }
1363 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1364 }
1365 }
1366
1367 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1368 let mut args = Vec::new();
1369
1370 self.in_method_args = true;
1372
1373 if !matches!(self.current_token, Token::RightParen) {
1374 loop {
1375 args.push(self.parse_expression()?);
1376
1377 if matches!(self.current_token, Token::Comma) {
1378 self.advance();
1379 } else {
1380 break;
1381 }
1382 }
1383 }
1384
1385 self.in_method_args = false;
1387
1388 Ok(args)
1389 }
1390
1391 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1392 let mut args = Vec::new();
1393
1394 if !matches!(self.current_token, Token::RightParen) {
1395 if matches!(self.current_token, Token::Distinct) {
1397 self.advance(); let expr = self.parse_additive()?;
1400 args.push(SqlExpression::FunctionCall {
1402 name: "DISTINCT".to_string(),
1403 args: vec![expr],
1404 });
1405 } else {
1406 args.push(self.parse_additive()?);
1408 }
1409
1410 while matches!(self.current_token, Token::Comma) {
1412 self.advance();
1413 args.push(self.parse_additive()?);
1414 }
1415 }
1416
1417 Ok(args)
1418 }
1419
1420 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1421 let mut expressions = Vec::new();
1422
1423 loop {
1424 expressions.push(self.parse_expression()?);
1425
1426 if matches!(self.current_token, Token::Comma) {
1427 self.advance();
1428 } else {
1429 break;
1430 }
1431 }
1432
1433 Ok(expressions)
1434 }
1435
1436 fn get_binary_op(&self) -> Option<String> {
1437 match &self.current_token {
1438 Token::Equal => Some("=".to_string()),
1439 Token::NotEqual => Some("!=".to_string()),
1440 Token::LessThan => Some("<".to_string()),
1441 Token::GreaterThan => Some(">".to_string()),
1442 Token::LessThanOrEqual => Some("<=".to_string()),
1443 Token::GreaterThanOrEqual => Some(">=".to_string()),
1444 Token::Like => Some("LIKE".to_string()),
1445 _ => None,
1446 }
1447 }
1448
1449 fn get_arithmetic_op(&self) -> Option<String> {
1450 match &self.current_token {
1451 Token::Plus => Some("+".to_string()),
1452 Token::Minus => Some("-".to_string()),
1453 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1455 _ => None,
1456 }
1457 }
1458
1459 pub fn get_position(&self) -> usize {
1460 self.lexer.get_position()
1461 }
1462}
1463
1464#[derive(Debug, Clone)]
1466pub enum CursorContext {
1467 SelectClause,
1468 FromClause,
1469 WhereClause,
1470 OrderByClause,
1471 AfterColumn(String),
1472 AfterLogicalOp(LogicalOp),
1473 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
1476 Unknown,
1477}
1478
/// Returns the prefix of `s` ending at byte offset `pos`, without ever
/// splitting a UTF-8 character.
///
/// If `pos` is at or beyond the end, the whole string is returned. If `pos`
/// falls inside a multi-byte character, the cut moves back to the start of
/// that character.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot fail.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
1493
/// Returns the suffix of `s` starting at byte offset `pos`, without ever
/// splitting a UTF-8 character.
///
/// If `pos` is at or beyond the end, the empty string is returned. If `pos`
/// falls inside a multi-byte character, the start moves forward to the next
/// character.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is always a boundary, so the search cannot fail.
    let start = (pos..=s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[start..]
}
1508
1509pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1510 let truncated = safe_slice_to(query, cursor_pos);
1511 let mut parser = Parser::new(truncated);
1512
1513 match parser.parse() {
1515 Ok(stmt) => {
1516 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1517 #[cfg(test)]
1518 println!(
1519 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1520 ctx, partial, truncated
1521 );
1522 (ctx, partial)
1523 }
1524 Err(_) => {
1525 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1527 #[cfg(test)]
1528 println!(
1529 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1530 ctx, partial, truncated
1531 );
1532 (ctx, partial)
1533 }
1534 }
1535}
1536
1537pub fn tokenize_query(query: &str) -> Vec<String> {
1538 let mut lexer = Lexer::new(query);
1539 let tokens = lexer.tokenize_all();
1540 tokens.iter().map(|t| format!("{:?}", t)).collect()
1541}
1542
1543pub fn format_sql_pretty(query: &str) -> Vec<String> {
1544 format_sql_pretty_compact(query, 5) }
1546
1547pub fn format_ast_tree(query: &str) -> String {
1549 let mut parser = Parser::new(query);
1550 match parser.parse() {
1551 Ok(stmt) => format_select_statement(&stmt, 0),
1552 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1553 }
1554}
1555
1556fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1557 let mut result = String::new();
1558 let indent_str = " ".repeat(indent);
1559
1560 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1561
1562 result.push_str(&format!("{indent_str} columns: ["));
1564 if !stmt.columns.is_empty() {
1565 result.push('\n');
1566 for col in &stmt.columns {
1567 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1568 }
1569 result.push_str(&format!("{indent_str} ],\n"));
1570 } else {
1571 result.push_str("],\n");
1572 }
1573
1574 if let Some(table) = &stmt.from_table {
1576 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1577 }
1578
1579 if let Some(where_clause) = &stmt.where_clause {
1581 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1582 result.push_str(&format_where_clause(where_clause, indent + 2));
1583 result.push_str(&format!("{indent_str} }},\n"));
1584 }
1585
1586 if let Some(order_by) = &stmt.order_by {
1588 result.push_str(&format!("{indent_str} order_by: ["));
1589 if !order_by.is_empty() {
1590 result.push('\n');
1591 for col in order_by {
1592 let dir = match col.direction {
1593 SortDirection::Asc => "ASC",
1594 SortDirection::Desc => "DESC",
1595 };
1596 result.push_str(&format!(
1597 "{indent_str} \"{col}\" {dir},\n",
1598 col = col.column
1599 ));
1600 }
1601 result.push_str(&format!("{indent_str} ],\n"));
1602 } else {
1603 result.push_str("],\n");
1604 }
1605 }
1606
1607 if let Some(group_by) = &stmt.group_by {
1609 result.push_str(&format!("{indent_str} group_by: ["));
1610 if !group_by.is_empty() {
1611 result.push('\n');
1612 for col in group_by {
1613 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1614 }
1615 result.push_str(&format!("{indent_str} ],\n"));
1616 } else {
1617 result.push_str("]\n");
1618 }
1619 }
1620
1621 result.push_str(&format!("{indent_str}}}"));
1622 result
1623}
1624
1625fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1626 let mut result = String::new();
1627 let indent_str = " ".repeat(indent);
1628
1629 result.push_str(&format!("{indent_str}conditions: [\n"));
1630
1631 for condition in &clause.conditions {
1632 result.push_str(&format!("{indent_str} {{\n"));
1633 result.push_str(&format!(
1634 "{indent_str} expr: {},\n",
1635 format_expression_ast(&condition.expr)
1636 ));
1637
1638 if let Some(connector) = &condition.connector {
1639 let connector_str = match connector {
1640 LogicalOp::And => "AND",
1641 LogicalOp::Or => "OR",
1642 };
1643 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1644 }
1645
1646 result.push_str(&format!("{indent_str} }},\n"));
1647 }
1648
1649 result.push_str(&format!("{indent_str}]\n"));
1650 result
1651}
1652
1653fn format_expression_ast(expr: &SqlExpression) -> String {
1654 match expr {
1655 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1656 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1657 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1658 SqlExpression::DateTimeConstructor {
1659 year,
1660 month,
1661 day,
1662 hour,
1663 minute,
1664 second,
1665 } => {
1666 format!(
1667 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1668 year,
1669 month,
1670 day,
1671 hour.unwrap_or(0),
1672 minute.unwrap_or(0),
1673 second.unwrap_or(0)
1674 )
1675 }
1676 SqlExpression::DateTimeToday {
1677 hour,
1678 minute,
1679 second,
1680 } => {
1681 format!(
1682 "DateTimeToday({:02}:{:02}:{:02})",
1683 hour.unwrap_or(0),
1684 minute.unwrap_or(0),
1685 second.unwrap_or(0)
1686 )
1687 }
1688 SqlExpression::MethodCall {
1689 object,
1690 method,
1691 args,
1692 } => {
1693 let args_str = args
1694 .iter()
1695 .map(|a| format_expression_ast(a))
1696 .collect::<Vec<_>>()
1697 .join(", ");
1698 format!("MethodCall({}.{}({}))", object, method, args_str)
1699 }
1700 SqlExpression::ChainedMethodCall { base, method, args } => {
1701 let args_str = args
1702 .iter()
1703 .map(|a| format_expression_ast(a))
1704 .collect::<Vec<_>>()
1705 .join(", ");
1706 format!(
1707 "ChainedMethodCall({}.{}({}))",
1708 format_expression_ast(base),
1709 method,
1710 args_str
1711 )
1712 }
1713 SqlExpression::FunctionCall { name, args } => {
1714 let args_str = args
1715 .iter()
1716 .map(|a| format_expression_ast(a))
1717 .collect::<Vec<_>>()
1718 .join(", ");
1719 format!("FunctionCall({}({}))", name, args_str)
1720 }
1721 SqlExpression::BinaryOp { left, op, right } => {
1722 format!(
1723 "BinaryOp({} {} {})",
1724 format_expression_ast(left),
1725 op,
1726 format_expression_ast(right)
1727 )
1728 }
1729 SqlExpression::InList { expr, values } => {
1730 let list_str = values
1731 .iter()
1732 .map(|e| format_expression_ast(e))
1733 .collect::<Vec<_>>()
1734 .join(", ");
1735 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1736 }
1737 SqlExpression::NotInList { expr, values } => {
1738 let list_str = values
1739 .iter()
1740 .map(|e| format_expression_ast(e))
1741 .collect::<Vec<_>>()
1742 .join(", ");
1743 format!(
1744 "NotInList({} NOT IN [{}])",
1745 format_expression_ast(expr),
1746 list_str
1747 )
1748 }
1749 SqlExpression::Between { expr, lower, upper } => {
1750 format!(
1751 "Between({} BETWEEN {} AND {})",
1752 format_expression_ast(expr),
1753 format_expression_ast(lower),
1754 format_expression_ast(upper)
1755 )
1756 }
1757 SqlExpression::Not { expr } => {
1758 format!("Not({})", format_expression_ast(expr))
1759 }
1760 SqlExpression::CaseExpression {
1761 when_branches,
1762 else_branch,
1763 } => {
1764 let when_strs: Vec<String> = when_branches
1765 .iter()
1766 .map(|branch| {
1767 format!(
1768 "WHEN {} THEN {}",
1769 format_expression_ast(&branch.condition),
1770 format_expression_ast(&branch.result)
1771 )
1772 })
1773 .collect();
1774 let else_str = else_branch
1775 .as_ref()
1776 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
1777 .unwrap_or_default();
1778 format!("CASE {} {} END", when_strs.join(" "), else_str)
1779 }
1780 }
1781}
1782
1783pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1785 match expr {
1786 SqlExpression::DateTimeConstructor {
1787 year,
1788 month,
1789 day,
1790 hour,
1791 minute,
1792 second,
1793 } => {
1794 let h = hour.unwrap_or(0);
1795 let m = minute.unwrap_or(0);
1796 let s = second.unwrap_or(0);
1797
1798 if let Ok(dt) = NaiveDateTime::parse_from_str(
1800 &format!(
1801 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1802 year, month, day, h, m, s
1803 ),
1804 "%Y-%m-%d %H:%M:%S",
1805 ) {
1806 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1807 } else {
1808 None
1809 }
1810 }
1811 SqlExpression::DateTimeToday {
1812 hour,
1813 minute,
1814 second,
1815 } => {
1816 let now = Local::now();
1817 let h = hour.unwrap_or(0);
1818 let m = minute.unwrap_or(0);
1819 let s = second.unwrap_or(0);
1820
1821 if let Ok(dt) = NaiveDateTime::parse_from_str(
1823 &format!(
1824 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1825 now.year(),
1826 now.month(),
1827 now.day(),
1828 h,
1829 m,
1830 s
1831 ),
1832 "%Y-%m-%d %H:%M:%S",
1833 ) {
1834 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1835 } else {
1836 None
1837 }
1838 }
1839 _ => None,
1840 }
1841}
1842
1843fn format_sql_with_preserved_parens(
1845 query: &str,
1846 cols_per_line: usize,
1847) -> Result<Vec<String>, String> {
1848 let mut lines = Vec::new();
1849 let mut lexer = Lexer::new(query);
1850 let tokens_with_pos = lexer.tokenize_all_with_positions();
1851
1852 if tokens_with_pos.is_empty() {
1853 return Err("No tokens found".to_string());
1854 }
1855
1856 let mut i = 0;
1857 let cols_per_line = cols_per_line.max(1);
1858
1859 while i < tokens_with_pos.len() {
1860 let (start, _end, ref token) = tokens_with_pos[i];
1861
1862 match token {
1863 Token::Select => {
1864 lines.push("SELECT".to_string());
1865 i += 1;
1866
1867 let mut columns = Vec::new();
1869 let mut col_start = i;
1870 while i < tokens_with_pos.len() {
1871 match &tokens_with_pos[i].2 {
1872 Token::From | Token::Eof => break,
1873 Token::Comma => {
1874 if col_start < i {
1876 let col_text = extract_text_between_positions(
1877 query,
1878 tokens_with_pos[col_start].0,
1879 tokens_with_pos[i - 1].1,
1880 );
1881 columns.push(col_text);
1882 }
1883 i += 1;
1884 col_start = i;
1885 }
1886 _ => i += 1,
1887 }
1888 }
1889 if col_start < i && i > 0 {
1891 let col_text = extract_text_between_positions(
1892 query,
1893 tokens_with_pos[col_start].0,
1894 tokens_with_pos[i - 1].1,
1895 );
1896 columns.push(col_text);
1897 }
1898
1899 for chunk in columns.chunks(cols_per_line) {
1901 let mut line = " ".to_string();
1902 for (idx, col) in chunk.iter().enumerate() {
1903 if idx > 0 {
1904 line.push_str(", ");
1905 }
1906 line.push_str(col.trim());
1907 }
1908 let is_last_chunk = chunk.as_ptr() as usize
1910 + chunk.len() * std::mem::size_of::<String>()
1911 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1912 if !is_last_chunk && columns.len() > cols_per_line {
1913 line.push(',');
1914 }
1915 lines.push(line);
1916 }
1917 }
1918 Token::From => {
1919 i += 1;
1920 if i < tokens_with_pos.len() {
1921 let table_start = tokens_with_pos[i].0;
1922 while i < tokens_with_pos.len() {
1924 match &tokens_with_pos[i].2 {
1925 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1926 _ => i += 1,
1927 }
1928 }
1929 if i > 0 {
1930 let table_text = extract_text_between_positions(
1931 query,
1932 table_start,
1933 tokens_with_pos[i - 1].1,
1934 );
1935 lines.push(format!("FROM {}", table_text.trim()));
1936 }
1937 }
1938 }
1939 Token::Where => {
1940 lines.push("WHERE".to_string());
1941 i += 1;
1942
1943 let where_start = if i < tokens_with_pos.len() {
1945 tokens_with_pos[i].0
1946 } else {
1947 start
1948 };
1949
1950 let mut where_end = query.len();
1952 while i < tokens_with_pos.len() {
1953 match &tokens_with_pos[i].2 {
1954 Token::OrderBy | Token::GroupBy | Token::Eof => {
1955 if i > 0 {
1956 where_end = tokens_with_pos[i - 1].1;
1957 }
1958 break;
1959 }
1960 _ => i += 1,
1961 }
1962 }
1963
1964 let where_text = extract_text_between_positions(query, where_start, where_end);
1966
1967 let formatted_where = format_where_clause_with_parens(&where_text);
1969 for line in formatted_where {
1970 lines.push(format!(" {}", line));
1971 }
1972 }
1973 Token::OrderBy => {
1974 i += 1;
1975 let order_start = if i < tokens_with_pos.len() {
1976 tokens_with_pos[i].0
1977 } else {
1978 start
1979 };
1980
1981 while i < tokens_with_pos.len() {
1983 match &tokens_with_pos[i].2 {
1984 Token::GroupBy | Token::Eof => break,
1985 _ => i += 1,
1986 }
1987 }
1988
1989 if i > 0 {
1990 let order_text = extract_text_between_positions(
1991 query,
1992 order_start,
1993 tokens_with_pos[i - 1].1,
1994 );
1995 lines.push(format!("ORDER BY {}", order_text.trim()));
1996 }
1997 }
1998 Token::GroupBy => {
1999 i += 1;
2000 let group_start = if i < tokens_with_pos.len() {
2001 tokens_with_pos[i].0
2002 } else {
2003 start
2004 };
2005
2006 while i < tokens_with_pos.len() {
2008 match &tokens_with_pos[i].2 {
2009 Token::Having | Token::Eof => break,
2010 _ => i += 1,
2011 }
2012 }
2013
2014 if i > 0 {
2015 let group_text = extract_text_between_positions(
2016 query,
2017 group_start,
2018 tokens_with_pos[i - 1].1,
2019 );
2020 lines.push(format!("GROUP BY {}", group_text.trim()));
2021 }
2022 }
2023 _ => i += 1,
2024 }
2025 }
2026
2027 Ok(lines)
2028}
2029
/// Returns the characters of `query` in the half-open range `start..end`,
/// where both bounds are character (not byte) indices.
///
/// Bounds past the end of the text are clamped. An inverted range
/// (`start > end`) yields an empty string instead of panicking; this arises
/// when a caller passes the start of one token together with the end of an
/// earlier one. No intermediate `Vec<char>` is built.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2037
/// Splits the text of a WHERE clause into display lines at top-level
/// ` AND ` / ` OR ` connectors (matched case-insensitively), leaving
/// parenthesised groups intact on one line.
///
/// Each connector becomes its own line (`AND` / `OR`). Text inside single-
/// or double-quoted literals is copied verbatim: connectors and parentheses
/// found there are not treated as structure. An unmatched `)` no longer
/// disables splitting for the rest of the clause. If nothing is produced,
/// the trimmed input is returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// True when `chars` begins with `keyword`, ignoring ASCII case.
    fn starts_with_ignore_case(chars: &[char], keyword: &str) -> bool {
        let mut remaining = chars.iter();
        keyword.chars().all(|expected| {
            remaining
                .next()
                .map_or(false, |actual| actual.eq_ignore_ascii_case(&expected))
        })
    }

    const CONNECTORS: [(&str, &str); 2] = [(" AND ", "AND"), (" OR ", "OR")];

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    // Quote character of the literal currently being copied, if any.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];

        // Inside a quoted literal everything is opaque until the closing quote.
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            current_line.push(ch);
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            let connector = CONNECTORS
                .iter()
                .find(|(pattern, _)| starts_with_ignore_case(&chars[i..], pattern));
            if let Some((pattern, label)) = connector {
                if !current_line.trim().is_empty() {
                    lines.push(current_line.trim().to_string());
                }
                lines.push((*label).to_string());
                current_line.clear();
                // Patterns are ASCII, so byte length equals char count.
                i += pattern.len();
                continue;
            }
        }

        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => paren_depth += 1,
            // Saturate so a stray ')' cannot push the depth below zero.
            ')' => paren_depth = paren_depth.saturating_sub(1),
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    if !current_line.trim().is_empty() {
        lines.push(current_line.trim().to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2103
2104pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2105 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2107 return lines;
2108 }
2109
2110 let mut lines = Vec::new();
2112 let mut parser = Parser::new(query);
2113
2114 let cols_per_line = cols_per_line.max(1);
2116
2117 match parser.parse() {
2118 Ok(stmt) => {
2119 if !stmt.columns.is_empty() {
2121 lines.push("SELECT".to_string());
2122
2123 for chunk in stmt.columns.chunks(cols_per_line) {
2125 let mut line = " ".to_string();
2126 for (i, col) in chunk.iter().enumerate() {
2127 if i > 0 {
2128 line.push_str(", ");
2129 }
2130 line.push_str(col);
2131 }
2132 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2134 let current_chunk_idx =
2135 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2136 if current_chunk_idx < last_chunk_idx {
2137 line.push(',');
2138 }
2139 lines.push(line);
2140 }
2141 }
2142
2143 if let Some(table) = &stmt.from_table {
2145 lines.push(format!("FROM {}", table));
2146 }
2147
2148 if let Some(where_clause) = &stmt.where_clause {
2150 lines.push("WHERE".to_string());
2151 for (i, condition) in where_clause.conditions.iter().enumerate() {
2152 if i > 0 {
2153 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2155 if let Some(connector) = &prev_condition.connector {
2156 match connector {
2157 LogicalOp::And => lines.push(" AND".to_string()),
2158 LogicalOp::Or => lines.push(" OR".to_string()),
2159 }
2160 }
2161 }
2162 }
2163 lines.push(format!(" {}", format_expression(&condition.expr)));
2164 }
2165 }
2166
2167 if let Some(order_by) = &stmt.order_by {
2169 let order_str = order_by
2170 .iter()
2171 .map(|col| {
2172 let dir = match col.direction {
2173 SortDirection::Asc => " ASC",
2174 SortDirection::Desc => " DESC",
2175 };
2176 format!("{}{}", col.column, dir)
2177 })
2178 .collect::<Vec<_>>()
2179 .join(", ");
2180 lines.push(format!("ORDER BY {}", order_str));
2181 }
2182
2183 if let Some(group_by) = &stmt.group_by {
2185 let group_str = group_by.join(", ");
2186 lines.push(format!("GROUP BY {}", group_str));
2187 }
2188 }
2189 Err(_) => {
2190 let mut lexer = Lexer::new(query);
2192 let tokens = lexer.tokenize_all();
2193 let mut current_line = String::new();
2194 let mut indent = 0;
2195
2196 for token in tokens {
2197 match &token {
2198 Token::Select
2199 | Token::From
2200 | Token::Where
2201 | Token::OrderBy
2202 | Token::GroupBy => {
2203 if !current_line.is_empty() {
2204 lines.push(current_line.trim().to_string());
2205 current_line.clear();
2206 }
2207 lines.push(format!("{:?}", token).to_uppercase());
2208 indent = 1;
2209 }
2210 Token::And | Token::Or => {
2211 if !current_line.is_empty() {
2212 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2213 current_line.clear();
2214 }
2215 lines.push(format!(" {:?}", token).to_uppercase());
2216 }
2217 Token::Comma => {
2218 current_line.push(',');
2219 if indent > 0 {
2220 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2221 current_line.clear();
2222 }
2223 }
2224 Token::Eof => break,
2225 _ => {
2226 if !current_line.is_empty() {
2227 current_line.push(' ');
2228 }
2229 current_line.push_str(&format_token(&token));
2230 }
2231 }
2232 }
2233
2234 if !current_line.is_empty() {
2235 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2236 }
2237 }
2238 }
2239
2240 lines
2241}
2242
2243fn format_expression(expr: &SqlExpression) -> String {
2244 match expr {
2245 SqlExpression::Column(name) => name.clone(),
2246 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2247 SqlExpression::NumberLiteral(n) => n.clone(),
2248 SqlExpression::DateTimeConstructor {
2249 year,
2250 month,
2251 day,
2252 hour,
2253 minute,
2254 second,
2255 } => {
2256 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2257 if let Some(h) = hour {
2258 result.push_str(&format!(", {}", h));
2259 if let Some(m) = minute {
2260 result.push_str(&format!(", {}", m));
2261 if let Some(s) = second {
2262 result.push_str(&format!(", {}", s));
2263 }
2264 }
2265 }
2266 result.push(')');
2267 result
2268 }
2269 SqlExpression::DateTimeToday {
2270 hour,
2271 minute,
2272 second,
2273 } => {
2274 let mut result = "DateTime()".to_string();
2275 if let Some(h) = hour {
2276 result = format!("DateTime(TODAY, {}", h);
2277 if let Some(m) = minute {
2278 result.push_str(&format!(", {}", m));
2279 if let Some(s) = second {
2280 result.push_str(&format!(", {}", s));
2281 }
2282 }
2283 result.push(')');
2284 }
2285 result
2286 }
2287 SqlExpression::MethodCall {
2288 object,
2289 method,
2290 args,
2291 } => {
2292 let args_str = args
2293 .iter()
2294 .map(|arg| format_expression(arg))
2295 .collect::<Vec<_>>()
2296 .join(", ");
2297 format!("{}.{}({})", object, method, args_str)
2298 }
2299 SqlExpression::BinaryOp { left, op, right } => {
2300 if op == "OR" || op == "AND" {
2303 format!(
2306 "({} {} {})",
2307 format_expression(left),
2308 op,
2309 format_expression(right)
2310 )
2311 } else {
2312 format!(
2313 "{} {} {}",
2314 format_expression(left),
2315 op,
2316 format_expression(right)
2317 )
2318 }
2319 }
2320 SqlExpression::InList { expr, values } => {
2321 let values_str = values
2322 .iter()
2323 .map(|v| format_expression(v))
2324 .collect::<Vec<_>>()
2325 .join(", ");
2326 format!("{} IN ({})", format_expression(expr), values_str)
2327 }
2328 SqlExpression::NotInList { expr, values } => {
2329 let values_str = values
2330 .iter()
2331 .map(|v| format_expression(v))
2332 .collect::<Vec<_>>()
2333 .join(", ");
2334 format!("{} NOT IN ({})", format_expression(expr), values_str)
2335 }
2336 SqlExpression::Between { expr, lower, upper } => {
2337 format!(
2338 "{} BETWEEN {} AND {}",
2339 format_expression(expr),
2340 format_expression(lower),
2341 format_expression(upper)
2342 )
2343 }
2344 SqlExpression::Not { expr } => {
2345 format!("NOT {}", format_expression(expr))
2346 }
2347 SqlExpression::ChainedMethodCall { base, method, args } => {
2348 let args_str = args
2349 .iter()
2350 .map(|arg| format_expression(arg))
2351 .collect::<Vec<_>>()
2352 .join(", ");
2353 format!("{}.{}({})", format_expression(base), method, args_str)
2354 }
2355 SqlExpression::FunctionCall { name, args } => {
2356 let args_str = args
2357 .iter()
2358 .map(|arg| format_expression(arg))
2359 .collect::<Vec<_>>()
2360 .join(", ");
2361 format!("{}({})", name, args_str)
2362 }
2363 SqlExpression::CaseExpression {
2364 when_branches,
2365 else_branch,
2366 } => {
2367 let mut result = String::from("CASE");
2368 for branch in when_branches {
2369 result.push_str(&format!(
2370 " WHEN {} THEN {}",
2371 format_expression(&branch.condition),
2372 format_expression(&branch.result)
2373 ));
2374 }
2375 if let Some(else_expr) = else_branch {
2376 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2377 }
2378 result.push_str(" END");
2379 result
2380 }
2381 }
2382}
2383
2384fn format_token(token: &Token) -> String {
2385 match token {
2386 Token::Identifier(s) => s.clone(),
2387 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2388 Token::StringLiteral(s) => format!("'{}'", s),
2389 Token::NumberLiteral(n) => n.clone(),
2390 Token::DateTime => "DateTime".to_string(),
2391 Token::Case => "CASE".to_string(),
2392 Token::When => "WHEN".to_string(),
2393 Token::Then => "THEN".to_string(),
2394 Token::Else => "ELSE".to_string(),
2395 Token::End => "END".to_string(),
2396 Token::Distinct => "DISTINCT".to_string(),
2397 Token::LeftParen => "(".to_string(),
2398 Token::RightParen => ")".to_string(),
2399 Token::Comma => ",".to_string(),
2400 Token::Dot => ".".to_string(),
2401 Token::Equal => "=".to_string(),
2402 Token::NotEqual => "!=".to_string(),
2403 Token::LessThan => "<".to_string(),
2404 Token::GreaterThan => ">".to_string(),
2405 Token::LessThanOrEqual => "<=".to_string(),
2406 Token::GreaterThanOrEqual => ">=".to_string(),
2407 Token::In => "IN".to_string(),
2408 _ => format!("{:?}", token).to_uppercase(),
2409 }
2410}
2411
2412fn analyze_statement(
2413 stmt: &SelectStatement,
2414 query: &str,
2415 _cursor_pos: usize,
2416) -> (CursorContext, Option<String>) {
2417 let trimmed = query.trim();
2419
2420 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2422 for op in &comparison_ops {
2423 if let Some(op_pos) = query.rfind(op) {
2424 let before_op = safe_slice_to(query, op_pos);
2425 let after_op_start = op_pos + op.len();
2426 let after_op = if after_op_start < query.len() {
2427 &query[after_op_start..]
2428 } else {
2429 ""
2430 };
2431
2432 if let Some(col_name) = before_op.split_whitespace().last() {
2434 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2435 let after_op_trimmed = after_op.trim();
2437 if after_op_trimmed.is_empty()
2438 || (after_op_trimmed
2439 .chars()
2440 .all(|c| c.is_alphanumeric() || c == '_')
2441 && !after_op_trimmed.contains('('))
2442 {
2443 let partial = if after_op_trimmed.is_empty() {
2444 None
2445 } else {
2446 Some(after_op_trimmed.to_string())
2447 };
2448 return (
2449 CursorContext::AfterComparisonOp(
2450 col_name.to_string(),
2451 op.trim().to_string(),
2452 ),
2453 partial,
2454 );
2455 }
2456 }
2457 }
2458 }
2459 }
2460
2461 if trimmed.to_uppercase().ends_with(" AND")
2463 || trimmed.to_uppercase().ends_with(" OR")
2464 || trimmed.to_uppercase().ends_with(" AND ")
2465 || trimmed.to_uppercase().ends_with(" OR ")
2466 {
2467 } else {
2469 if let Some(dot_pos) = trimmed.rfind('.') {
2471 let before_dot = safe_slice_to(trimmed, dot_pos);
2473 let after_dot_start = dot_pos + 1;
2474 let after_dot = if after_dot_start < trimmed.len() {
2475 &trimmed[after_dot_start..]
2476 } else {
2477 ""
2478 };
2479
2480 if !after_dot.contains('(') {
2483 let col_name = if before_dot.ends_with('"') {
2485 let bytes = before_dot.as_bytes();
2487 let mut pos = before_dot.len() - 1; let mut found_start = None;
2489
2490 if pos > 0 {
2492 pos -= 1;
2493 while pos > 0 {
2494 if bytes[pos] == b'"' {
2495 if pos == 0 || bytes[pos - 1] != b'\\' {
2497 found_start = Some(pos);
2498 break;
2499 }
2500 }
2501 pos -= 1;
2502 }
2503 if found_start.is_none() && bytes[0] == b'"' {
2505 found_start = Some(0);
2506 }
2507 }
2508
2509 if let Some(start) = found_start {
2510 Some(safe_slice_from(before_dot, start))
2512 } else {
2513 None
2514 }
2515 } else {
2516 before_dot
2519 .split_whitespace()
2520 .last()
2521 .map(|word| word.trim_start_matches('('))
2522 };
2523
2524 if let Some(col_name) = col_name {
2525 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2527 true
2529 } else {
2530 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2532 };
2533
2534 if is_valid {
2535 let partial_method = if after_dot.is_empty() {
2538 None
2539 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2540 Some(after_dot.to_string())
2541 } else {
2542 None
2543 };
2544
2545 let col_name_for_context = if col_name.starts_with('"')
2547 && col_name.ends_with('"')
2548 && col_name.len() > 2
2549 {
2550 col_name[1..col_name.len() - 1].to_string()
2551 } else {
2552 col_name.to_string()
2553 };
2554
2555 return (
2556 CursorContext::AfterColumn(col_name_for_context),
2557 partial_method,
2558 );
2559 }
2560 }
2561 }
2562 }
2563 }
2564
2565 if let Some(where_clause) = &stmt.where_clause {
2567 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2569 let op = if trimmed.to_uppercase().ends_with(" AND") {
2570 LogicalOp::And
2571 } else {
2572 LogicalOp::Or
2573 };
2574 return (CursorContext::AfterLogicalOp(op), None);
2575 }
2576
2577 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2579 let after_and = safe_slice_from(query, and_pos + 5);
2580 let partial = extract_partial_at_end(after_and);
2581 if partial.is_some() {
2582 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2583 }
2584 }
2585
2586 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2587 let after_or = safe_slice_from(query, or_pos + 4);
2588 let partial = extract_partial_at_end(after_or);
2589 if partial.is_some() {
2590 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2591 }
2592 }
2593
2594 if let Some(last_condition) = where_clause.conditions.last() {
2595 if let Some(connector) = &last_condition.connector {
2596 return (
2598 CursorContext::AfterLogicalOp(connector.clone()),
2599 extract_partial_at_end(query),
2600 );
2601 }
2602 }
2603 return (CursorContext::WhereClause, extract_partial_at_end(query));
2605 }
2606
2607 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2609 return (CursorContext::OrderByClause, None);
2610 }
2611
2612 if stmt.order_by.is_some() {
2614 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2615 }
2616
2617 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2618 return (CursorContext::FromClause, extract_partial_at_end(query));
2619 }
2620
2621 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2622 return (CursorContext::SelectClause, extract_partial_at_end(query));
2623 }
2624
2625 (CursorContext::Unknown, None)
2626}
2627
2628fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2629 let upper = query.to_uppercase();
2630
2631 let trimmed = query.trim();
2633
2634 #[cfg(test)]
2635 {
2636 if trimmed.contains("\"Last Name\"") {
2637 eprintln!(
2638 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2639 query, trimmed
2640 );
2641 }
2642 }
2643
2644 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2646 for op in &comparison_ops {
2647 if let Some(op_pos) = query.rfind(op) {
2648 let before_op = safe_slice_to(query, op_pos);
2649 let after_op_start = op_pos + op.len();
2650 let after_op = if after_op_start < query.len() {
2651 &query[after_op_start..]
2652 } else {
2653 ""
2654 };
2655
2656 if let Some(col_name) = before_op.split_whitespace().last() {
2658 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2659 let after_op_trimmed = after_op.trim();
2661 if after_op_trimmed.is_empty()
2662 || (after_op_trimmed
2663 .chars()
2664 .all(|c| c.is_alphanumeric() || c == '_')
2665 && !after_op_trimmed.contains('('))
2666 {
2667 let partial = if after_op_trimmed.is_empty() {
2668 None
2669 } else {
2670 Some(after_op_trimmed.to_string())
2671 };
2672 return (
2673 CursorContext::AfterComparisonOp(
2674 col_name.to_string(),
2675 op.trim().to_string(),
2676 ),
2677 partial,
2678 );
2679 }
2680 }
2681 }
2682 }
2683 }
2684
2685 if let Some(dot_pos) = trimmed.rfind('.') {
2688 #[cfg(test)]
2689 {
2690 if trimmed.contains("\"Last Name\"") {
2691 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2692 }
2693 }
2694 let before_dot = &trimmed[..dot_pos];
2696 let after_dot = &trimmed[dot_pos + 1..];
2697
2698 if !after_dot.contains('(') {
2701 let col_name = if before_dot.ends_with('"') {
2704 let bytes = before_dot.as_bytes();
2706 let mut pos = before_dot.len() - 1; let mut found_start = None;
2708
2709 #[cfg(test)]
2710 {
2711 if trimmed.contains("\"Last Name\"") {
2712 eprintln!(
2713 "DEBUG: before_dot='{}', looking for opening quote",
2714 before_dot
2715 );
2716 }
2717 }
2718
2719 if pos > 0 {
2721 pos -= 1;
2722 while pos > 0 {
2723 if bytes[pos] == b'"' {
2724 if pos == 0 || bytes[pos - 1] != b'\\' {
2726 found_start = Some(pos);
2727 break;
2728 }
2729 }
2730 pos -= 1;
2731 }
2732 if found_start.is_none() && bytes[0] == b'"' {
2734 found_start = Some(0);
2735 }
2736 }
2737
2738 if let Some(start) = found_start {
2739 let result = safe_slice_from(before_dot, start);
2741 #[cfg(test)]
2742 {
2743 if trimmed.contains("\"Last Name\"") {
2744 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2745 }
2746 }
2747 Some(result)
2748 } else {
2749 #[cfg(test)]
2750 {
2751 if trimmed.contains("\"Last Name\"") {
2752 eprintln!("DEBUG: No opening quote found!");
2753 }
2754 }
2755 None
2756 }
2757 } else {
2758 before_dot
2761 .split_whitespace()
2762 .last()
2763 .map(|word| word.trim_start_matches('('))
2764 };
2765
2766 if let Some(col_name) = col_name {
2767 #[cfg(test)]
2768 {
2769 if trimmed.contains("\"Last Name\"") {
2770 eprintln!("DEBUG: col_name = '{}'", col_name);
2771 }
2772 }
2773
2774 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2776 true
2778 } else {
2779 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2781 };
2782
2783 #[cfg(test)]
2784 {
2785 if trimmed.contains("\"Last Name\"") {
2786 eprintln!("DEBUG: is_valid = {}", is_valid);
2787 }
2788 }
2789
2790 if is_valid {
2791 let partial_method = if after_dot.is_empty() {
2794 None
2795 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2796 Some(after_dot.to_string())
2797 } else {
2798 None
2799 };
2800
2801 let col_name_for_context = if col_name.starts_with('"')
2803 && col_name.ends_with('"')
2804 && col_name.len() > 2
2805 {
2806 col_name[1..col_name.len() - 1].to_string()
2807 } else {
2808 col_name.to_string()
2809 };
2810
2811 return (
2812 CursorContext::AfterColumn(col_name_for_context),
2813 partial_method,
2814 );
2815 }
2816 }
2817 }
2818 }
2819
2820 if let Some(and_pos) = upper.rfind(" AND ") {
2822 if cursor_pos >= and_pos + 5 {
2824 let after_and = safe_slice_from(query, and_pos + 5);
2826 let partial = extract_partial_at_end(after_and);
2827 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2828 }
2829 }
2830
2831 if let Some(or_pos) = upper.rfind(" OR ") {
2832 if cursor_pos >= or_pos + 4 {
2834 let after_or = safe_slice_from(query, or_pos + 4);
2836 let partial = extract_partial_at_end(after_or);
2837 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2838 }
2839 }
2840
2841 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2843 let op = if trimmed.to_uppercase().ends_with(" AND") {
2844 LogicalOp::And
2845 } else {
2846 LogicalOp::Or
2847 };
2848 return (CursorContext::AfterLogicalOp(op), None);
2849 }
2850
2851 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2853 {
2854 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2855 }
2856
2857 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2858 return (CursorContext::WhereClause, extract_partial_at_end(query));
2859 }
2860
2861 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2862 return (CursorContext::FromClause, extract_partial_at_end(query));
2863 }
2864
2865 if upper.contains("SELECT") && !upper.contains("FROM") {
2866 return (CursorContext::SelectClause, extract_partial_at_end(query));
2867 }
2868
2869 (CursorContext::Unknown, None)
2870}
2871
2872fn extract_partial_at_end(query: &str) -> Option<String> {
2873 let trimmed = query.trim();
2874
2875 if let Some(last_word) = trimmed.split_whitespace().last() {
2877 if last_word.starts_with('"') && !last_word.ends_with('"') {
2878 return Some(last_word.to_string());
2880 }
2881 }
2882
2883 let last_word = trimmed.split_whitespace().last()?;
2885
2886 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2888 Some(last_word.to_string())
2889 } else {
2890 None
2891 }
2892}
2893
/// Reports whether `word`, compared case-insensitively, is one of the SQL
/// keywords that must not be offered as a partial identifier.
///
/// The comparison upper-cases `word` first, so `select`, `Select` and
/// `SELECT` are all treated alike.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let normalized = word.to_uppercase();
    KEYWORDS.iter().any(|&keyword| normalized == keyword)
}
2912
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts every `(` and `)` character in `text`.
    ///
    /// The formatter tests compare this count before and after formatting to
    /// check that no parenthesis was dropped or introduced.
    fn count_parens(text: &str) -> usize {
        text.chars().filter(|c| matches!(*c, '(' | ')')).count()
    }

    // ---------------------------------------------------------------------
    // Lexer and basic parser behaviour
    // ---------------------------------------------------------------------

    /// Chained method calls in a WHERE clause must parse without error.
    #[test]
    fn test_chained_method_calls() {
        let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
        let mut parser = Parser::new(query);
        let result = parser.parse();

        assert!(
            result.is_ok(),
            "Failed to parse chained method calls: {:?}",
            result
        );

        let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
        let mut parser2 = Parser::new(query2);
        let result2 = parser2.parse();

        assert!(
            result2.is_ok(),
            "Failed to parse multiple chained calls: {:?}",
            result2
        );
    }

    /// A simple SELECT is split into the expected token sequence.
    #[test]
    fn test_tokenizer() {
        let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::Star));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
    }

    /// `DateTime(...)` is lexed as a dedicated token followed by its arguments.
    #[test]
    fn test_tokenizer_datetime() {
        let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::DateTime));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// `SELECT * FROM t` yields a star column, the table name and no WHERE.
    #[test]
    fn test_parse_simple_select() {
        let mut parser = Parser::new("SELECT * FROM trade_deal");
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["*"]);
        assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
        assert!(stmt.where_clause.is_none());
    }

    /// A single method-call predicate produces exactly one WHERE condition.
    #[test]
    fn test_parse_where_with_method() {
        let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);
    }

    /// A three-argument `DateTime(...)` parses into a constructor expression
    /// with no time-of-day components.
    #[test]
    fn test_parse_datetime_constructor() {
        let mut parser =
            Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::DateTimeConstructor {
                    year: 2025,
                    month: 10,
                    day: 20,
                    hour: None,
                    minute: None,
                    second: None
                }
            ));
        } else {
            panic!("Expected BinaryOp with DateTime constructor");
        }
    }

    // ---------------------------------------------------------------------
    // Cursor-context detection
    // ---------------------------------------------------------------------

    /// Cursor directly after `AND ` reports the logical operator, no partial.
    #[test]
    fn test_cursor_context_after_and() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, None);
    }

    /// A word typed after `AND ` is reported as the partial.
    #[test]
    fn test_cursor_context_with_partial() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, Some("p".to_string()));
    }

    /// Cursor after `column > ` reports the column and the operator.
    #[test]
    fn test_cursor_context_after_datetime_comparison() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, None);
    }

    /// A partially typed right-hand side after a comparison is the partial.
    #[test]
    fn test_cursor_context_partial_datetime() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, Some("Date".to_string()));
    }

    // ---------------------------------------------------------------------
    // Quoted identifiers
    // ---------------------------------------------------------------------

    /// Double-quoted names in a select list become `QuotedIdentifier` tokens
    /// with the quotes removed.
    #[test]
    fn test_tokenizer_quoted_identifier() {
        let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
    }

    /// Double quotes give identifiers while single quotes give string literals.
    #[test]
    fn test_tokenizer_quoted_vs_string_literal() {
        let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Equal));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
        assert!(matches!(lexer.next_token(), Token::And));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// A double-quoted method argument must survive lexing with its text
    /// intact.
    ///
    /// Which token kind the lexer picks is not pinned here: the parser test
    /// `test_parse_method_with_double_quotes_as_string` checks that the value
    /// ends up as a string literal. This test only requires that the payload
    /// is `Alb` and that the closing parenthesis follows.
    #[test]
    fn test_tokenizer_method_with_double_quotes_should_be_string() {
        let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);

        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));

        let token = lexer.next_token();
        println!("Token for \"Alb\": {:?}", token);
        assert!(
            matches!(
                &token,
                Token::QuotedIdentifier(s) | Token::StringLiteral(s) if s == "Alb"
            ),
            "Expected a quoted token carrying \"Alb\", got {:?}",
            token
        );
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// Quoted and bare column names are both returned without quotes.
    #[test]
    fn test_parse_select_with_quoted_columns() {
        let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
        assert_eq!(stmt.from_table, Some("customers".to_string()));
    }

    /// A half-typed quoted column right after SELECT is still a select context.
    #[test]
    fn test_cursor_context_select_with_partial_quoted() {
        let query = r#"SELECT "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    /// A half-typed quoted column after a comma is still a select context.
    #[test]
    fn test_cursor_context_select_after_comma_with_quoted() {
        let query = r#"SELECT Company, "Customer "#;
        let (context, partial) = detect_cursor_context(query, query.len());

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    /// A half-typed quoted column after ORDER BY is an order-by context.
    #[test]
    fn test_cursor_context_order_by_quoted() {
        let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::OrderByClause));
    }

    /// A quoted column on the left of `=` parses as a plain column reference.
    #[test]
    fn test_where_clause_with_quoted_column() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "=");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
            assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    /// A double-quoted method argument is parsed as a string literal.
    #[test]
    fn test_parse_method_with_double_quotes_as_string() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::MethodCall {
            object,
            method,
            args,
        } = &where_clause.conditions[0].expr
        {
            assert_eq!(object, "Country");
            assert_eq!(method, "Contains");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
        } else {
            panic!("Expected MethodCall");
        }
    }

    /// An earlier, fully quoted column must not leak into the partial.
    #[test]
    fn test_extract_partial_with_quoted_columns_in_query() {
        let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(context, CursorContext::OrderByClause));
        assert_eq!(
            partial,
            Some("coun".to_string()),
            "Should extract 'coun' as partial, not everything after the quoted column"
        );
    }

    /// An unterminated quoted word is returned with its opening quote; a
    /// trailing keyword yields no partial.
    #[test]
    fn test_extract_partial_quoted_identifier_being_typed() {
        let query = r#"SELECT "Cust"#;
        let partial = extract_partial_at_end(query);
        assert_eq!(partial, Some("\"Cust".to_string()));

        let query2 = r#"SELECT "Customer Id" FROM"#;
        let partial2 = extract_partial_at_end(query2);
        assert_eq!(partial2, None);
    }

    // ---------------------------------------------------------------------
    // Complex WHERE clauses with parentheses
    // ---------------------------------------------------------------------

    /// A single parenthesised OR group collapses into one condition.
    #[test]
    fn test_complex_where_parentheses_basic() {
        let mut parser =
            Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected BinaryOp with OR");
        }
    }

    /// `(a OR b) AND c` yields two conditions joined by AND.
    #[test]
    fn test_complex_where_mixed_and_or_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected first condition to be OR expression");
        }

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
            assert_eq!(op, ">");
        } else {
            panic!("Expected second condition to be price > 100");
        }
    }

    /// Doubly nested groups parse and produce at least one condition.
    #[test]
    fn test_complex_where_nested_parentheses() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// Two parenthesised groups joined by AND yield two conditions.
    #[test]
    fn test_complex_where_multiple_or_groups() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Method calls inside a parenthesised OR keep their `MethodCall` shape.
    #[test]
    fn test_complex_where_with_methods_in_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
            assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
            assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
        } else {
            panic!("Expected OR of method calls");
        }
    }

    /// A date-range group AND a status group yield two AND-joined conditions.
    #[test]
    fn test_complex_where_date_comparisons_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Nested price ranges combined with a volume filter parse successfully.
    #[test]
    fn test_complex_where_price_volume_filters() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// String, numeric and method predicates can be mixed across groups.
    #[test]
    fn test_complex_where_mixed_string_numeric() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// Three levels of grouping parse without error.
    #[test]
    fn test_complex_where_triple_nested() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// A single group wrapping an AND chain parses to at least one condition.
    #[test]
    fn test_complex_where_single_parens_around_and() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    // ---------------------------------------------------------------------
    // Pretty-printer: parentheses must survive formatting
    // ---------------------------------------------------------------------

    /// A single parenthesised group keeps both of its parentheses.
    #[test]
    fn test_format_preserves_simple_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(status"));
        assert!(formatted_text.contains("\"pending\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(
            original_parens, formatted_parens,
            "Parentheses should be preserved"
        );
    }

    /// A group followed by an AND predicate keeps its parentheses.
    #[test]
    fn test_format_preserves_complex_parentheses() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("\"GOOGL\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
    }

    /// Nested groups keep all four parentheses.
    #[test]
    fn test_format_preserves_nested_parentheses() {
        let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(
            original_parens, formatted_parens,
            "Nested parentheses should be preserved"
        );
        assert_eq!(original_parens, 4, "Should have 4 parentheses total");
    }

    /// Method-call parentheses inside a group are preserved as well.
    #[test]
    fn test_format_preserves_method_calls_in_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol.StartsWith"));
        assert!(formatted_text.contains("StartsWith(\"A\")"));
        assert!(formatted_text.contains("StartsWith(\"G\")"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(
            original_parens, 6,
            "Should have 6 parentheses (1 group + 2 method calls)"
        );
    }

    /// Two independent groups each keep their parentheses.
    #[test]
    fn test_format_preserves_multiple_groups() {
        let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("(price"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
    }

    /// `DateTime(...)` constructors inside a group are emitted unchanged.
    #[test]
    fn test_format_preserves_date_ranges() {
        let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(executionDate"));
        assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
        assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

        let original_parens = count_parens(query);
        let formatted_parens = count_parens(&formatted_text);
        assert_eq!(original_parens, formatted_parens);
    }

    /// The formatter puts SELECT, the column list, FROM and WHERE on the
    /// first four lines, followed by the WHERE body.
    #[test]
    fn test_format_multiline_layout() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);

        assert!(formatted.len() >= 4, "Should have multiple lines");
        assert_eq!(formatted[0], "SELECT");
        assert!(formatted[1].trim().starts_with("*"));
        assert!(formatted[2].starts_with("FROM"));
        assert_eq!(formatted[3], "WHERE");

        // The length check above guarantees this slice start is in bounds.
        let where_lines = &formatted[4..];
        assert!(where_lines.iter().any(|l| l.contains("(symbol")));
        assert!(where_lines.iter().any(|l| l.trim() == "AND"));
    }

    // ---------------------------------------------------------------------
    // BETWEEN
    // ---------------------------------------------------------------------

    /// A bare BETWEEN is one condition and renders without a parse error.
    #[test]
    fn test_between_simple() {
        let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        let stmt = parser.parse().expect("Should parse simple BETWEEN");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        assert!(!ast.contains("PARSE ERROR"));
        assert!(ast.contains("SelectStatement"));
    }

    /// BETWEEN wrapped in parentheses parses and renders cleanly.
    #[test]
    fn test_between_in_parentheses() {
        let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");

        assert!(stmt.where_clause.is_some());

        let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
    }

    /// A BETWEEN group can be OR-ed with another group.
    #[test]
    fn test_between_with_or() {
        let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with OR");

        assert!(stmt.where_clause.is_some());
    }

    /// The AND inside BETWEEN is not mistaken for a logical connector.
    #[test]
    fn test_between_with_and() {
        let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with AND");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);
    }

    /// Two BETWEEN groups joined by AND parse successfully.
    #[test]
    fn test_multiple_between() {
        let query =
            "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse multiple BETWEEN clauses");

        assert!(stmt.where_clause.is_some());
    }

    /// BETWEEN, a method call and a two-column ORDER BY in one statement.
    #[test]
    fn test_between_complex_query() {
        let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

        assert!(stmt.where_clause.is_some());
        assert!(stmt.order_by.is_some());

        let order_by = stmt.order_by.unwrap();
        assert_eq!(order_by.len(), 2);
        assert_eq!(order_by[0].column, "Category");
        assert!(matches!(order_by[0].direction, SortDirection::Asc));
        assert_eq!(order_by[1].column, "price");
        assert!(matches!(order_by[1].direction, SortDirection::Desc));
    }

    /// A hand-built `Between` node renders as SQL and as an AST description.
    #[test]
    fn test_between_formatting() {
        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column("price".to_string())),
            lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
        };

        let formatted = format_expression(&expr);
        assert_eq!(formatted, "price BETWEEN 50 AND 100");

        let ast_formatted = format_expression_ast(&expr);
        assert!(ast_formatted.contains("Between"));
        assert!(ast_formatted.contains("50"));
        assert!(ast_formatted.contains("100"));
    }

    // ---------------------------------------------------------------------
    // UTF-8 safety
    // ---------------------------------------------------------------------

    /// `detect_cursor_context` must not panic for any byte offset, including
    /// offsets inside a multi-byte character and offsets past the end.
    #[test]
    fn test_utf8_boundary_safety() {
        let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

        // Every byte offset from 0 up to and including the string length.
        for pos in 0..=query_with_unicode.len() {
            let result =
                std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));

            assert!(
                result.is_ok(),
                "Panic at position {} in query with Unicode",
                pos
            );
        }

        let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
        assert!(result.is_ok(), "Panic with position beyond string length");

        // `é` is two bytes in UTF-8, so one past its start is mid-character.
        let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
        let result =
            std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
        assert!(
            result.is_ok(),
            "Panic with cursor in middle of UTF-8 character"
        );
    }
}