1use chrono::{Datelike, Local, NaiveDateTime};
2
/// A lexical token produced by [`Lexer`].
///
/// Keywords are recognised case-insensitively by the lexer; payload-carrying
/// variants hold the token text exactly as read from the input.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // ----- keywords -----
    /// `SELECT`
    Select,
    /// `FROM`
    From,
    /// `WHERE`
    Where,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `IN`
    In,
    /// `NOT`
    Not,
    /// `BETWEEN`
    Between,
    /// `LIKE`
    Like,
    /// `IS`
    Is,
    /// `NULL`
    Null,
    /// The two-word keyword `ORDER BY`, emitted as a single token.
    OrderBy,
    /// The two-word keyword `GROUP BY`, emitted as a single token.
    GroupBy,
    /// `HAVING`
    Having,
    /// `AS`
    As,
    /// `ASC`
    Asc,
    /// `DESC`
    Desc,
    /// `LIMIT`
    Limit,
    /// `OFFSET`
    Offset,
    /// The `DATETIME` keyword that introduces a `DateTime(...)` constructor.
    DateTime,

    // ----- identifiers and literals -----
    /// A bare word that is not a keyword (also used for unrecognised single characters).
    Identifier(String),
    /// Text enclosed in double quotes, without the quotes.
    QuotedIdentifier(String),
    /// Text enclosed in single quotes, without the quotes.
    StringLiteral(String),
    /// Numeric text, optionally with a leading `-`; kept as a string, not yet parsed.
    NumberLiteral(String),
    /// `*`
    Star,

    // ----- punctuation and comparison operators -----
    /// `.`
    Dot,
    /// `,`
    Comma,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `=`
    Equal,
    /// `<>` or `!=`
    NotEqual,
    /// `<`
    LessThan,
    /// `>`
    GreaterThan,
    /// `<=`
    LessThanOrEqual,
    /// `>=`
    GreaterThanOrEqual,

    // ----- arithmetic operators -----
    /// `+`
    Plus,
    /// `-` when it is not immediately followed by a numeric character.
    Minus,
    /// `/`
    Divide,

    /// End of input.
    Eof,
}
54
/// Character-level tokenizer for the SQL dialect understood by this file.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input split into `char`s, so positions are character indices.
    input: Vec<char>,
    /// Index into `input` of the character currently under the cursor.
    position: usize,
    /// The character at `position`, or `None` once the cursor is past the end.
    current_char: Option<char>,
}
61
62impl Lexer {
63 pub fn new(input: &str) -> Self {
64 let chars: Vec<char> = input.chars().collect();
65 let current = chars.get(0).copied();
66 Self {
67 input: chars,
68 position: 0,
69 current_char: current,
70 }
71 }
72
73 fn advance(&mut self) {
74 self.position += 1;
75 self.current_char = self.input.get(self.position).copied();
76 }
77
78 fn peek(&self, offset: usize) -> Option<char> {
79 self.input.get(self.position + offset).copied()
80 }
81
82 fn skip_whitespace(&mut self) {
83 while let Some(ch) = self.current_char {
84 if ch.is_whitespace() {
85 self.advance();
86 } else {
87 break;
88 }
89 }
90 }
91
92 fn read_identifier(&mut self) -> String {
93 let mut result = String::new();
94 while let Some(ch) = self.current_char {
95 if ch.is_alphanumeric() || ch == '_' {
96 result.push(ch);
97 self.advance();
98 } else {
99 break;
100 }
101 }
102 result
103 }
104
105 fn read_string(&mut self) -> String {
106 let mut result = String::new();
107 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
111 if ch == quote_char {
112 self.advance(); break;
114 }
115 result.push(ch);
116 self.advance();
117 }
118 result
119 }
120
121 fn read_number(&mut self) -> String {
122 let mut result = String::new();
123 while let Some(ch) = self.current_char {
124 if ch.is_numeric() || ch == '.' {
125 result.push(ch);
126 self.advance();
127 } else {
128 break;
129 }
130 }
131 result
132 }
133
134 pub fn next_token(&mut self) -> Token {
135 self.skip_whitespace();
136
137 match self.current_char {
138 None => Token::Eof,
139 Some('*') => {
140 self.advance();
141 Token::Star }
145 Some('+') => {
146 self.advance();
147 Token::Plus
148 }
149 Some('/') => {
150 self.advance();
151 Token::Divide
152 }
153 Some('.') => {
154 self.advance();
155 Token::Dot
156 }
157 Some(',') => {
158 self.advance();
159 Token::Comma
160 }
161 Some('(') => {
162 self.advance();
163 Token::LeftParen
164 }
165 Some(')') => {
166 self.advance();
167 Token::RightParen
168 }
169 Some('=') => {
170 self.advance();
171 Token::Equal
172 }
173 Some('<') => {
174 self.advance();
175 if self.current_char == Some('=') {
176 self.advance();
177 Token::LessThanOrEqual
178 } else if self.current_char == Some('>') {
179 self.advance();
180 Token::NotEqual
181 } else {
182 Token::LessThan
183 }
184 }
185 Some('>') => {
186 self.advance();
187 if self.current_char == Some('=') {
188 self.advance();
189 Token::GreaterThanOrEqual
190 } else {
191 Token::GreaterThan
192 }
193 }
194 Some('!') if self.peek(1) == Some('=') => {
195 self.advance();
196 self.advance();
197 Token::NotEqual
198 }
199 Some('"') => {
200 let ident_val = self.read_string();
202 Token::QuotedIdentifier(ident_val)
203 }
204 Some('\'') => {
205 let string_val = self.read_string();
207 Token::StringLiteral(string_val)
208 }
209 Some('-') if self.peek(1).map_or(false, |c| c.is_numeric()) => {
210 self.advance(); let num = self.read_number();
213 Token::NumberLiteral(format!("-{}", num))
214 }
215 Some('-') => {
216 self.advance();
218 Token::Minus
219 }
220 Some(ch) if ch.is_numeric() => {
221 let num = self.read_number();
222 Token::NumberLiteral(num)
223 }
224 Some(ch) if ch.is_alphabetic() || ch == '_' => {
225 let ident = self.read_identifier();
226 match ident.to_uppercase().as_str() {
227 "SELECT" => Token::Select,
228 "FROM" => Token::From,
229 "WHERE" => Token::Where,
230 "AND" => Token::And,
231 "OR" => Token::Or,
232 "IN" => Token::In,
233 "NOT" => Token::Not,
234 "BETWEEN" => Token::Between,
235 "LIKE" => Token::Like,
236 "IS" => Token::Is,
237 "NULL" => Token::Null,
238 "ORDER" if self.peek_keyword("BY") => {
239 self.skip_whitespace();
240 self.read_identifier(); Token::OrderBy
242 }
243 "GROUP" if self.peek_keyword("BY") => {
244 self.skip_whitespace();
245 self.read_identifier(); Token::GroupBy
247 }
248 "HAVING" => Token::Having,
249 "AS" => Token::As,
250 "ASC" => Token::Asc,
251 "DESC" => Token::Desc,
252 "LIMIT" => Token::Limit,
253 "OFFSET" => Token::Offset,
254 "DATETIME" => Token::DateTime,
255 _ => Token::Identifier(ident),
256 }
257 }
258 Some(ch) => {
259 self.advance();
260 Token::Identifier(ch.to_string())
261 }
262 }
263 }
264
265 fn peek_keyword(&mut self, keyword: &str) -> bool {
266 let saved_pos = self.position;
267 let saved_char = self.current_char;
268
269 self.skip_whitespace();
270 let next_word = self.read_identifier();
271 let matches = next_word.to_uppercase() == keyword;
272
273 self.position = saved_pos;
275 self.current_char = saved_char;
276
277 matches
278 }
279
280 pub fn get_position(&self) -> usize {
281 self.position
282 }
283
284 pub fn tokenize_all(&mut self) -> Vec<Token> {
285 let mut tokens = Vec::new();
286 loop {
287 let token = self.next_token();
288 if matches!(token, Token::Eof) {
289 tokens.push(token);
290 break;
291 }
292 tokens.push(token);
293 }
294 tokens
295 }
296
297 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
298 let mut tokens = Vec::new();
299 loop {
300 self.skip_whitespace();
301 let start_pos = self.position;
302 let token = self.next_token();
303 let end_pos = self.position;
304
305 if matches!(token, Token::Eof) {
306 break;
307 }
308 tokens.push((start_pos, end_pos, token));
309 }
310 tokens
311 }
312}
313
/// Expression node of the parsed SQL AST.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// Reference to a column by name.
    Column(String),
    /// String literal (contents without quotes).
    StringLiteral(String),
    /// Numeric literal, kept as its source text.
    NumberLiteral(String),
    /// `DateTime(year, month, day[, hour[, minute[, second]]])`.
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `DateTime()` with no date; the parser builds it with every time part `None`.
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `object.method(args)` where `object` is a plain column name.
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `base.method(args)` where `base` is any other expression, e.g. a previous method call.
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// Call of a built-in function; the parser stores `name` upper-cased.
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
    },
    /// Binary operation; `op` is the operator text (e.g. `"+"`, `"="`, `"LIKE"`, `"AND"`).
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// `expr IN (values...)`.
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr NOT IN (values...)`.
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr BETWEEN lower AND upper`.
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// `NOT expr`.
    Not {
        expr: Box<SqlExpression>,
    },
}
369
370#[derive(Debug, Clone)]
371pub struct WhereClause {
372 pub conditions: Vec<Condition>,
373}
374
375#[derive(Debug, Clone)]
376pub struct Condition {
377 pub expr: SqlExpression,
378 pub connector: Option<LogicalOp>, }
380
/// Logical connector between two `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
386
/// Sort direction of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
392
393#[derive(Debug, Clone)]
394pub struct OrderByColumn {
395 pub column: String,
396 pub direction: SortDirection,
397}
398
399#[derive(Debug, Clone)]
401pub enum SelectItem {
402 Column(String),
404 Expression { expr: SqlExpression, alias: String },
406 Star,
408}
409
410#[derive(Debug, Clone)]
411pub struct SelectStatement {
412 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
415 pub where_clause: Option<WhereClause>,
416 pub order_by: Option<Vec<OrderByColumn>>,
417 pub group_by: Option<Vec<String>>,
418 pub limit: Option<usize>,
419 pub offset: Option<usize>,
420}
421
/// Options that tune [`Parser`] behavior.
///
/// `Default` yields `case_insensitive: false`, the same value the previous
/// hand-written implementation produced.
#[derive(Debug, Clone, Default)]
pub struct ParserConfig {
    /// Case-insensitivity switch. The parser code visible in this file stores
    /// the value but does not read it — NOTE(review): confirm where it is consumed.
    pub case_insensitive: bool,
}
433
434pub struct Parser {
435 lexer: Lexer,
436 current_token: Token,
437 in_method_args: bool, columns: Vec<String>, paren_depth: i32, config: ParserConfig, }
442
443impl Parser {
444 pub fn new(input: &str) -> Self {
445 let mut lexer = Lexer::new(input);
446 let current_token = lexer.next_token();
447 Self {
448 lexer,
449 current_token,
450 in_method_args: false,
451 columns: Vec::new(),
452 paren_depth: 0,
453 config: ParserConfig::default(),
454 }
455 }
456
457 pub fn with_config(input: &str, config: ParserConfig) -> Self {
458 let mut lexer = Lexer::new(input);
459 let current_token = lexer.next_token();
460 Self {
461 lexer,
462 current_token,
463 in_method_args: false,
464 columns: Vec::new(),
465 paren_depth: 0,
466 config,
467 }
468 }
469
470 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
471 self.columns = columns;
472 self
473 }
474
475 fn consume(&mut self, expected: Token) -> Result<(), String> {
476 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
477 match &expected {
479 Token::LeftParen => self.paren_depth += 1,
480 Token::RightParen => {
481 self.paren_depth -= 1;
482 if self.paren_depth < 0 {
484 return Err(
485 "Unexpected closing parenthesis - no matching opening parenthesis"
486 .to_string(),
487 );
488 }
489 }
490 _ => {}
491 }
492
493 self.current_token = self.lexer.next_token();
494 Ok(())
495 } else {
496 let error_msg = match (&expected, &self.current_token) {
498 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
499 format!(
500 "Unclosed parenthesis - missing {} closing parenthes{}",
501 self.paren_depth,
502 if self.paren_depth == 1 { "is" } else { "es" }
503 )
504 }
505 (Token::RightParen, _) if self.paren_depth > 0 => {
506 format!(
507 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
508 self.current_token,
509 self.paren_depth,
510 if self.paren_depth == 1 { "is" } else { "es" }
511 )
512 }
513 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
514 };
515 Err(error_msg)
516 }
517 }
518
519 fn advance(&mut self) {
520 match &self.current_token {
522 Token::LeftParen => self.paren_depth += 1,
523 Token::RightParen => {
524 self.paren_depth -= 1;
525 }
528 _ => {}
529 }
530 self.current_token = self.lexer.next_token();
531 }
532
533 pub fn parse(&mut self) -> Result<SelectStatement, String> {
534 self.parse_select_statement()
535 }
536
537 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
538 self.consume(Token::Select)?;
539
540 let select_items = self.parse_select_items()?;
542
543 let columns = select_items
545 .iter()
546 .map(|item| match item {
547 SelectItem::Star => "*".to_string(),
548 SelectItem::Column(name) => name.clone(),
549 SelectItem::Expression { alias, .. } => alias.clone(),
550 })
551 .collect();
552
553 let from_table = if matches!(self.current_token, Token::From) {
554 self.advance();
555 match &self.current_token {
556 Token::Identifier(table) => {
557 let table_name = table.clone();
558 self.advance();
559 Some(table_name)
560 }
561 Token::QuotedIdentifier(table) => {
562 let table_name = table.clone();
564 self.advance();
565 Some(table_name)
566 }
567 _ => return Err("Expected table name after FROM".to_string()),
568 }
569 } else {
570 None
571 };
572
573 let where_clause = if matches!(self.current_token, Token::Where) {
574 self.advance();
575 Some(self.parse_where_clause()?)
576 } else {
577 None
578 };
579
580 let order_by = if matches!(self.current_token, Token::OrderBy) {
581 self.advance();
582 Some(self.parse_order_by_list()?)
583 } else {
584 None
585 };
586
587 let group_by = if matches!(self.current_token, Token::GroupBy) {
588 self.advance();
589 Some(self.parse_identifier_list()?)
590 } else {
591 None
592 };
593
594 let limit = if matches!(self.current_token, Token::Limit) {
596 self.advance();
597 match &self.current_token {
598 Token::NumberLiteral(num) => {
599 let limit_val = num
600 .parse::<usize>()
601 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
602 self.advance();
603 Some(limit_val)
604 }
605 _ => return Err("Expected number after LIMIT".to_string()),
606 }
607 } else {
608 None
609 };
610
611 let offset = if matches!(self.current_token, Token::Offset) {
613 self.advance();
614 match &self.current_token {
615 Token::NumberLiteral(num) => {
616 let offset_val = num
617 .parse::<usize>()
618 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
619 self.advance();
620 Some(offset_val)
621 }
622 _ => return Err("Expected number after OFFSET".to_string()),
623 }
624 } else {
625 None
626 };
627
628 if self.paren_depth > 0 {
630 return Err(format!(
631 "Unclosed parenthesis - missing {} closing parenthes{}",
632 self.paren_depth,
633 if self.paren_depth == 1 { "is" } else { "es" }
634 ));
635 } else if self.paren_depth < 0 {
636 return Err(
637 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
638 );
639 }
640
641 Ok(SelectStatement {
642 columns,
643 select_items,
644 from_table,
645 where_clause,
646 order_by,
647 group_by,
648 limit,
649 offset,
650 })
651 }
652
653 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
654 let mut columns = Vec::new();
655
656 if matches!(self.current_token, Token::Star) {
657 columns.push("*".to_string());
658 self.advance();
659 } else {
660 loop {
661 match &self.current_token {
662 Token::Identifier(col) => {
663 columns.push(col.clone());
664 self.advance();
665 }
666 Token::QuotedIdentifier(col) => {
667 columns.push(col.clone());
669 self.advance();
670 }
671 _ => return Err("Expected column name".to_string()),
672 }
673
674 if matches!(self.current_token, Token::Comma) {
675 self.advance();
676 } else {
677 break;
678 }
679 }
680 }
681
682 Ok(columns)
683 }
684
685 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
687 let mut items = Vec::new();
688
689 loop {
690 if matches!(self.current_token, Token::Star) {
693 items.push(SelectItem::Star);
701 self.advance();
702 } else {
703 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
708 self.advance();
709 match &self.current_token {
710 Token::Identifier(alias_name) => {
711 let alias = alias_name.clone();
712 self.advance();
713 alias
714 }
715 Token::QuotedIdentifier(alias_name) => {
716 let alias = alias_name.clone();
717 self.advance();
718 alias
719 }
720 _ => return Err("Expected alias name after AS".to_string()),
721 }
722 } else {
723 match &expr {
725 SqlExpression::Column(col_name) => col_name.clone(),
726 _ => format!("expr_{}", items.len() + 1), }
728 };
729
730 let item = match expr {
732 SqlExpression::Column(col_name) if alias == col_name => {
733 SelectItem::Column(col_name)
735 }
736 _ => {
737 SelectItem::Expression { expr, alias }
739 }
740 };
741
742 items.push(item);
743 }
744
745 if matches!(self.current_token, Token::Comma) {
747 self.advance();
748 } else {
749 break;
750 }
751 }
752
753 Ok(items)
754 }
755
756 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
757 let mut identifiers = Vec::new();
758
759 loop {
760 match &self.current_token {
761 Token::Identifier(id) => {
762 identifiers.push(id.clone());
763 self.advance();
764 }
765 Token::QuotedIdentifier(id) => {
766 identifiers.push(id.clone());
768 self.advance();
769 }
770 _ => return Err("Expected identifier".to_string()),
771 }
772
773 if matches!(self.current_token, Token::Comma) {
774 self.advance();
775 } else {
776 break;
777 }
778 }
779
780 Ok(identifiers)
781 }
782
783 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
784 let mut order_columns = Vec::new();
785
786 loop {
787 let column = match &self.current_token {
788 Token::Identifier(id) => {
789 let col = id.clone();
790 self.advance();
791 col
792 }
793 Token::QuotedIdentifier(id) => {
794 let col = id.clone();
795 self.advance();
796 col
797 }
798 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
799 let col = num.clone();
801 self.advance();
802 col
803 }
804 _ => return Err("Expected column name in ORDER BY".to_string()),
805 };
806
807 let direction = match &self.current_token {
809 Token::Asc => {
810 self.advance();
811 SortDirection::Asc
812 }
813 Token::Desc => {
814 self.advance();
815 SortDirection::Desc
816 }
817 _ => SortDirection::Asc, };
819
820 order_columns.push(OrderByColumn { column, direction });
821
822 if matches!(self.current_token, Token::Comma) {
823 self.advance();
824 } else {
825 break;
826 }
827 }
828
829 Ok(order_columns)
830 }
831
832 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
833 let mut conditions = Vec::new();
834
835 loop {
836 let expr = self.parse_expression()?;
837
838 let connector = match &self.current_token {
839 Token::And => {
840 self.advance();
841 Some(LogicalOp::And)
842 }
843 Token::Or => {
844 self.advance();
845 Some(LogicalOp::Or)
846 }
847 Token::RightParen if self.paren_depth <= 0 => {
848 return Err(
850 "Unexpected closing parenthesis - no matching opening parenthesis"
851 .to_string(),
852 );
853 }
854 _ => None,
855 };
856
857 conditions.push(Condition {
858 expr,
859 connector: connector.clone(),
860 });
861
862 if connector.is_none() {
863 break;
864 }
865 }
866
867 Ok(WhereClause { conditions })
868 }
869
870 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
871 let mut left = self.parse_comparison()?;
872
873 if let Some(op) = self.get_binary_op() {
876 self.advance();
877 let right = self.parse_expression()?;
878 left = SqlExpression::BinaryOp {
879 left: Box::new(left),
880 op,
881 right: Box::new(right),
882 };
883 }
884
885 if matches!(self.current_token, Token::In) {
887 self.advance();
888 self.consume(Token::LeftParen)?;
889 let values = self.parse_expression_list()?;
890 self.consume(Token::RightParen)?;
891
892 left = SqlExpression::InList {
893 expr: Box::new(left),
894 values,
895 };
896 }
897
898 Ok(left)
902 }
903
904 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
905 let mut left = self.parse_additive()?;
906
907 if matches!(self.current_token, Token::Between) {
909 self.advance(); let lower = self.parse_primary()?;
911 self.consume(Token::And)?; let upper = self.parse_primary()?;
913
914 return Ok(SqlExpression::Between {
915 expr: Box::new(left),
916 lower: Box::new(lower),
917 upper: Box::new(upper),
918 });
919 }
920
921 if matches!(self.current_token, Token::Not) {
923 self.advance(); if matches!(self.current_token, Token::In) {
925 self.advance(); self.consume(Token::LeftParen)?;
927 let values = self.parse_expression_list()?;
928 self.consume(Token::RightParen)?;
929
930 return Ok(SqlExpression::NotInList {
931 expr: Box::new(left),
932 values,
933 });
934 } else {
935 return Err("Expected IN after NOT".to_string());
936 }
937 }
938
939 if let Some(op) = self.get_binary_op() {
941 self.advance();
942 let right = self.parse_additive()?;
943 left = SqlExpression::BinaryOp {
944 left: Box::new(left),
945 op,
946 right: Box::new(right),
947 };
948 }
949
950 Ok(left)
951 }
952
953 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
954 let mut left = self.parse_multiplicative()?;
955
956 while matches!(self.current_token, Token::Plus | Token::Minus) {
957 let op = match self.current_token {
958 Token::Plus => "+",
959 Token::Minus => "-",
960 _ => unreachable!(),
961 };
962 self.advance();
963 let right = self.parse_multiplicative()?;
964 left = SqlExpression::BinaryOp {
965 left: Box::new(left),
966 op: op.to_string(),
967 right: Box::new(right),
968 };
969 }
970
971 Ok(left)
972 }
973
974 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
975 let mut left = self.parse_primary()?;
976
977 while matches!(self.current_token, Token::Dot) {
979 self.advance();
980 if let Token::Identifier(method) = &self.current_token {
981 let method_name = method.clone();
982 self.advance();
983
984 if matches!(self.current_token, Token::LeftParen) {
985 self.advance();
986 let args = self.parse_method_args()?;
987 self.consume(Token::RightParen)?;
988
989 match left {
991 SqlExpression::Column(obj) => {
992 left = SqlExpression::MethodCall {
994 object: obj,
995 method: method_name,
996 args,
997 };
998 }
999 SqlExpression::MethodCall { .. }
1000 | SqlExpression::ChainedMethodCall { .. } => {
1001 left = SqlExpression::ChainedMethodCall {
1003 base: Box::new(left),
1004 method: method_name,
1005 args,
1006 };
1007 }
1008 _ => {
1009 left = SqlExpression::ChainedMethodCall {
1011 base: Box::new(left),
1012 method: method_name,
1013 args,
1014 };
1015 }
1016 }
1017 } else {
1018 return Err(format!("Expected '(' after method name '{}'", method_name));
1019 }
1020 } else {
1021 return Err("Expected method name after '.'".to_string());
1022 }
1023 }
1024
1025 while matches!(self.current_token, Token::Star | Token::Divide) {
1026 let op = match self.current_token {
1027 Token::Star => "*",
1028 Token::Divide => "/",
1029 _ => unreachable!(),
1030 };
1031 self.advance();
1032 let right = self.parse_primary()?;
1033 left = SqlExpression::BinaryOp {
1034 left: Box::new(left),
1035 op: op.to_string(),
1036 right: Box::new(right),
1037 };
1038 }
1039
1040 Ok(left)
1041 }
1042
1043 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1044 let mut left = self.parse_logical_and()?;
1045
1046 while matches!(self.current_token, Token::Or) {
1047 self.advance();
1048 let right = self.parse_logical_and()?;
1049 left = SqlExpression::BinaryOp {
1053 left: Box::new(left),
1054 op: "OR".to_string(),
1055 right: Box::new(right),
1056 };
1057 }
1058
1059 Ok(left)
1060 }
1061
1062 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1063 let mut left = self.parse_expression()?;
1064
1065 while matches!(self.current_token, Token::And) {
1066 self.advance();
1067 let right = self.parse_expression()?;
1068 left = SqlExpression::BinaryOp {
1070 left: Box::new(left),
1071 op: "AND".to_string(),
1072 right: Box::new(right),
1073 };
1074 }
1075
1076 Ok(left)
1077 }
1078
1079 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1080 if let Token::NumberLiteral(num_str) = &self.current_token {
1083 if self.columns.iter().any(|col| col == num_str) {
1085 let expr = SqlExpression::Column(num_str.clone());
1086 self.advance();
1087 return Ok(expr);
1088 }
1089 }
1090
1091 match &self.current_token {
1092 Token::DateTime => {
1093 self.advance(); self.consume(Token::LeftParen)?;
1095
1096 if matches!(&self.current_token, Token::RightParen) {
1098 self.advance(); return Ok(SqlExpression::DateTimeToday {
1100 hour: None,
1101 minute: None,
1102 second: None,
1103 });
1104 }
1105
1106 let year = if let Token::NumberLiteral(n) = &self.current_token {
1108 n.parse::<i32>().map_err(|_| "Invalid year")?
1109 } else {
1110 return Err("Expected year in DateTime constructor".to_string());
1111 };
1112 self.advance();
1113 self.consume(Token::Comma)?;
1114
1115 let month = if let Token::NumberLiteral(n) = &self.current_token {
1117 n.parse::<u32>().map_err(|_| "Invalid month")?
1118 } else {
1119 return Err("Expected month in DateTime constructor".to_string());
1120 };
1121 self.advance();
1122 self.consume(Token::Comma)?;
1123
1124 let day = if let Token::NumberLiteral(n) = &self.current_token {
1126 n.parse::<u32>().map_err(|_| "Invalid day")?
1127 } else {
1128 return Err("Expected day in DateTime constructor".to_string());
1129 };
1130 self.advance();
1131
1132 let mut hour = None;
1134 let mut minute = None;
1135 let mut second = None;
1136
1137 if matches!(&self.current_token, Token::Comma) {
1138 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1142 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1143 self.advance();
1144
1145 if matches!(&self.current_token, Token::Comma) {
1147 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1150 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1151 self.advance();
1152
1153 if matches!(&self.current_token, Token::Comma) {
1155 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1158 second =
1159 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1160 self.advance();
1161 }
1162 }
1163 }
1164 }
1165 }
1166 }
1167
1168 self.consume(Token::RightParen)?;
1169 Ok(SqlExpression::DateTimeConstructor {
1170 year,
1171 month,
1172 day,
1173 hour,
1174 minute,
1175 second,
1176 })
1177 }
1178 Token::Identifier(id) => {
1179 let id_upper = id.to_uppercase();
1180 let id_clone = id.clone();
1181 self.advance();
1182
1183 if matches!(self.current_token, Token::LeftParen) {
1185 if matches!(
1187 id_upper.as_str(),
1188 "ROUND"
1189 | "ABS"
1190 | "FLOOR"
1191 | "CEILING"
1192 | "CEIL"
1193 | "MOD"
1194 | "QUOTIENT"
1195 | "POWER"
1196 | "POW"
1197 | "SQRT"
1198 | "EXP"
1199 | "LN"
1200 | "LOG"
1201 | "LOG10"
1202 | "PI"
1203 | "TEXTJOIN"
1204 | "DATEDIFF"
1205 | "DATEADD"
1206 | "NOW"
1207 | "TODAY"
1208 ) {
1209 self.advance(); let args = self.parse_function_args()?;
1211 self.consume(Token::RightParen)?;
1212 return Ok(SqlExpression::FunctionCall {
1213 name: id_upper,
1214 args,
1215 });
1216 }
1217 }
1218
1219 Ok(SqlExpression::Column(id_clone))
1221 }
1222 Token::QuotedIdentifier(id) => {
1223 let expr = if self.in_method_args {
1226 SqlExpression::StringLiteral(id.clone())
1227 } else {
1228 SqlExpression::Column(id.clone())
1230 };
1231 self.advance();
1232 Ok(expr)
1233 }
1234 Token::StringLiteral(s) => {
1235 let expr = SqlExpression::StringLiteral(s.clone());
1236 self.advance();
1237 Ok(expr)
1238 }
1239 Token::NumberLiteral(n) => {
1240 let expr = SqlExpression::NumberLiteral(n.clone());
1241 self.advance();
1242 Ok(expr)
1243 }
1244 Token::LeftParen => {
1245 self.advance();
1246
1247 let expr = self.parse_logical_or()?;
1250
1251 self.consume(Token::RightParen)?;
1252 Ok(expr)
1253 }
1254 Token::Not => {
1255 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1259 if matches!(self.current_token, Token::In) {
1261 self.advance(); self.consume(Token::LeftParen)?;
1263 let values = self.parse_expression_list()?;
1264 self.consume(Token::RightParen)?;
1265
1266 return Ok(SqlExpression::NotInList {
1267 expr: Box::new(inner_expr),
1268 values,
1269 });
1270 } else {
1271 return Ok(SqlExpression::Not {
1273 expr: Box::new(inner_expr),
1274 });
1275 }
1276 } else {
1277 return Err("Expected expression after NOT".to_string());
1278 }
1279 }
1280 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1281 }
1282 }
1283
1284 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1285 let mut args = Vec::new();
1286
1287 self.in_method_args = true;
1289
1290 if !matches!(self.current_token, Token::RightParen) {
1291 loop {
1292 args.push(self.parse_expression()?);
1293
1294 if matches!(self.current_token, Token::Comma) {
1295 self.advance();
1296 } else {
1297 break;
1298 }
1299 }
1300 }
1301
1302 self.in_method_args = false;
1304
1305 Ok(args)
1306 }
1307
1308 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1309 let mut args = Vec::new();
1310
1311 if !matches!(self.current_token, Token::RightParen) {
1312 loop {
1313 args.push(self.parse_additive()?);
1315
1316 if matches!(self.current_token, Token::Comma) {
1317 self.advance();
1318 } else {
1319 break;
1320 }
1321 }
1322 }
1323
1324 Ok(args)
1325 }
1326
1327 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1328 let mut expressions = Vec::new();
1329
1330 loop {
1331 expressions.push(self.parse_expression()?);
1332
1333 if matches!(self.current_token, Token::Comma) {
1334 self.advance();
1335 } else {
1336 break;
1337 }
1338 }
1339
1340 Ok(expressions)
1341 }
1342
1343 fn get_binary_op(&self) -> Option<String> {
1344 match &self.current_token {
1345 Token::Equal => Some("=".to_string()),
1346 Token::NotEqual => Some("!=".to_string()),
1347 Token::LessThan => Some("<".to_string()),
1348 Token::GreaterThan => Some(">".to_string()),
1349 Token::LessThanOrEqual => Some("<=".to_string()),
1350 Token::GreaterThanOrEqual => Some(">=".to_string()),
1351 Token::Like => Some("LIKE".to_string()),
1352 _ => None,
1353 }
1354 }
1355
1356 fn get_arithmetic_op(&self) -> Option<String> {
1357 match &self.current_token {
1358 Token::Plus => Some("+".to_string()),
1359 Token::Minus => Some("-".to_string()),
1360 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1362 _ => None,
1363 }
1364 }
1365
1366 pub fn get_position(&self) -> usize {
1367 self.lexer.get_position()
1368 }
1369}
1370
1371#[derive(Debug, Clone)]
1373pub enum CursorContext {
1374 SelectClause,
1375 FromClause,
1376 WhereClause,
1377 OrderByClause,
1378 AfterColumn(String),
1379 AfterLogicalOp(LogicalOp),
1380 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
1383 Unknown,
1384}
1385
/// Returns the prefix of `s` ending at byte offset `pos`.
///
/// If `pos` is at or past the end, the whole string is returned. If `pos`
/// falls inside a multi-byte character, the cut moves back to the start of
/// that character, so the slice never panics.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a character boundary, so the search cannot fail.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
1400
/// Returns the suffix of `s` starting at byte offset `pos`.
///
/// If `pos` is at or past the end, an empty string is returned. If `pos`
/// falls inside a multi-byte character, the start moves forward to the next
/// character boundary, so the slice never panics.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is itself a boundary, hence the fallback.
    let start = (pos..s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[start..]
}
1415
1416pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1417 let truncated = safe_slice_to(query, cursor_pos);
1418 let mut parser = Parser::new(truncated);
1419
1420 match parser.parse() {
1422 Ok(stmt) => {
1423 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1424 #[cfg(test)]
1425 println!(
1426 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1427 ctx, partial, truncated
1428 );
1429 (ctx, partial)
1430 }
1431 Err(_) => {
1432 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1434 #[cfg(test)]
1435 println!(
1436 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1437 ctx, partial, truncated
1438 );
1439 (ctx, partial)
1440 }
1441 }
1442}
1443
1444pub fn tokenize_query(query: &str) -> Vec<String> {
1445 let mut lexer = Lexer::new(query);
1446 let tokens = lexer.tokenize_all();
1447 tokens.iter().map(|t| format!("{:?}", t)).collect()
1448}
1449
1450pub fn format_sql_pretty(query: &str) -> Vec<String> {
1451 format_sql_pretty_compact(query, 5) }
1453
1454pub fn format_ast_tree(query: &str) -> String {
1456 let mut parser = Parser::new(query);
1457 match parser.parse() {
1458 Ok(stmt) => format_select_statement(&stmt, 0),
1459 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1460 }
1461}
1462
1463fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1464 let mut result = String::new();
1465 let indent_str = " ".repeat(indent);
1466
1467 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1468
1469 result.push_str(&format!("{indent_str} columns: ["));
1471 if !stmt.columns.is_empty() {
1472 result.push('\n');
1473 for col in &stmt.columns {
1474 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1475 }
1476 result.push_str(&format!("{indent_str} ],\n"));
1477 } else {
1478 result.push_str("],\n");
1479 }
1480
1481 if let Some(table) = &stmt.from_table {
1483 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1484 }
1485
1486 if let Some(where_clause) = &stmt.where_clause {
1488 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1489 result.push_str(&format_where_clause(where_clause, indent + 2));
1490 result.push_str(&format!("{indent_str} }},\n"));
1491 }
1492
1493 if let Some(order_by) = &stmt.order_by {
1495 result.push_str(&format!("{indent_str} order_by: ["));
1496 if !order_by.is_empty() {
1497 result.push('\n');
1498 for col in order_by {
1499 let dir = match col.direction {
1500 SortDirection::Asc => "ASC",
1501 SortDirection::Desc => "DESC",
1502 };
1503 result.push_str(&format!(
1504 "{indent_str} \"{col}\" {dir},\n",
1505 col = col.column
1506 ));
1507 }
1508 result.push_str(&format!("{indent_str} ],\n"));
1509 } else {
1510 result.push_str("],\n");
1511 }
1512 }
1513
1514 if let Some(group_by) = &stmt.group_by {
1516 result.push_str(&format!("{indent_str} group_by: ["));
1517 if !group_by.is_empty() {
1518 result.push('\n');
1519 for col in group_by {
1520 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1521 }
1522 result.push_str(&format!("{indent_str} ],\n"));
1523 } else {
1524 result.push_str("]\n");
1525 }
1526 }
1527
1528 result.push_str(&format!("{indent_str}}}"));
1529 result
1530}
1531
1532fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1533 let mut result = String::new();
1534 let indent_str = " ".repeat(indent);
1535
1536 result.push_str(&format!("{indent_str}conditions: [\n"));
1537
1538 for condition in &clause.conditions {
1539 result.push_str(&format!("{indent_str} {{\n"));
1540 result.push_str(&format!(
1541 "{indent_str} expr: {},\n",
1542 format_expression_ast(&condition.expr)
1543 ));
1544
1545 if let Some(connector) = &condition.connector {
1546 let connector_str = match connector {
1547 LogicalOp::And => "AND",
1548 LogicalOp::Or => "OR",
1549 };
1550 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1551 }
1552
1553 result.push_str(&format!("{indent_str} }},\n"));
1554 }
1555
1556 result.push_str(&format!("{indent_str}]\n"));
1557 result
1558}
1559
1560fn format_expression_ast(expr: &SqlExpression) -> String {
1561 match expr {
1562 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1563 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1564 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1565 SqlExpression::DateTimeConstructor {
1566 year,
1567 month,
1568 day,
1569 hour,
1570 minute,
1571 second,
1572 } => {
1573 format!(
1574 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1575 year,
1576 month,
1577 day,
1578 hour.unwrap_or(0),
1579 minute.unwrap_or(0),
1580 second.unwrap_or(0)
1581 )
1582 }
1583 SqlExpression::DateTimeToday {
1584 hour,
1585 minute,
1586 second,
1587 } => {
1588 format!(
1589 "DateTimeToday({:02}:{:02}:{:02})",
1590 hour.unwrap_or(0),
1591 minute.unwrap_or(0),
1592 second.unwrap_or(0)
1593 )
1594 }
1595 SqlExpression::MethodCall {
1596 object,
1597 method,
1598 args,
1599 } => {
1600 let args_str = args
1601 .iter()
1602 .map(|a| format_expression_ast(a))
1603 .collect::<Vec<_>>()
1604 .join(", ");
1605 format!("MethodCall({}.{}({}))", object, method, args_str)
1606 }
1607 SqlExpression::ChainedMethodCall { base, method, args } => {
1608 let args_str = args
1609 .iter()
1610 .map(|a| format_expression_ast(a))
1611 .collect::<Vec<_>>()
1612 .join(", ");
1613 format!(
1614 "ChainedMethodCall({}.{}({}))",
1615 format_expression_ast(base),
1616 method,
1617 args_str
1618 )
1619 }
1620 SqlExpression::FunctionCall { name, args } => {
1621 let args_str = args
1622 .iter()
1623 .map(|a| format_expression_ast(a))
1624 .collect::<Vec<_>>()
1625 .join(", ");
1626 format!("FunctionCall({}({}))", name, args_str)
1627 }
1628 SqlExpression::BinaryOp { left, op, right } => {
1629 format!(
1630 "BinaryOp({} {} {})",
1631 format_expression_ast(left),
1632 op,
1633 format_expression_ast(right)
1634 )
1635 }
1636 SqlExpression::InList { expr, values } => {
1637 let list_str = values
1638 .iter()
1639 .map(|e| format_expression_ast(e))
1640 .collect::<Vec<_>>()
1641 .join(", ");
1642 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1643 }
1644 SqlExpression::NotInList { expr, values } => {
1645 let list_str = values
1646 .iter()
1647 .map(|e| format_expression_ast(e))
1648 .collect::<Vec<_>>()
1649 .join(", ");
1650 format!(
1651 "NotInList({} NOT IN [{}])",
1652 format_expression_ast(expr),
1653 list_str
1654 )
1655 }
1656 SqlExpression::Between { expr, lower, upper } => {
1657 format!(
1658 "Between({} BETWEEN {} AND {})",
1659 format_expression_ast(expr),
1660 format_expression_ast(lower),
1661 format_expression_ast(upper)
1662 )
1663 }
1664 SqlExpression::Not { expr } => {
1665 format!("Not({})", format_expression_ast(expr))
1666 }
1667 }
1668}
1669
1670pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1672 match expr {
1673 SqlExpression::DateTimeConstructor {
1674 year,
1675 month,
1676 day,
1677 hour,
1678 minute,
1679 second,
1680 } => {
1681 let h = hour.unwrap_or(0);
1682 let m = minute.unwrap_or(0);
1683 let s = second.unwrap_or(0);
1684
1685 if let Ok(dt) = NaiveDateTime::parse_from_str(
1687 &format!(
1688 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1689 year, month, day, h, m, s
1690 ),
1691 "%Y-%m-%d %H:%M:%S",
1692 ) {
1693 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1694 } else {
1695 None
1696 }
1697 }
1698 SqlExpression::DateTimeToday {
1699 hour,
1700 minute,
1701 second,
1702 } => {
1703 let now = Local::now();
1704 let h = hour.unwrap_or(0);
1705 let m = minute.unwrap_or(0);
1706 let s = second.unwrap_or(0);
1707
1708 if let Ok(dt) = NaiveDateTime::parse_from_str(
1710 &format!(
1711 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1712 now.year(),
1713 now.month(),
1714 now.day(),
1715 h,
1716 m,
1717 s
1718 ),
1719 "%Y-%m-%d %H:%M:%S",
1720 ) {
1721 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1722 } else {
1723 None
1724 }
1725 }
1726 _ => None,
1727 }
1728}
1729
1730fn format_sql_with_preserved_parens(
1732 query: &str,
1733 cols_per_line: usize,
1734) -> Result<Vec<String>, String> {
1735 let mut lines = Vec::new();
1736 let mut lexer = Lexer::new(query);
1737 let tokens_with_pos = lexer.tokenize_all_with_positions();
1738
1739 if tokens_with_pos.is_empty() {
1740 return Err("No tokens found".to_string());
1741 }
1742
1743 let mut i = 0;
1744 let cols_per_line = cols_per_line.max(1);
1745
1746 while i < tokens_with_pos.len() {
1747 let (start, _end, ref token) = tokens_with_pos[i];
1748
1749 match token {
1750 Token::Select => {
1751 lines.push("SELECT".to_string());
1752 i += 1;
1753
1754 let mut columns = Vec::new();
1756 let mut col_start = i;
1757 while i < tokens_with_pos.len() {
1758 match &tokens_with_pos[i].2 {
1759 Token::From | Token::Eof => break,
1760 Token::Comma => {
1761 if col_start < i {
1763 let col_text = extract_text_between_positions(
1764 query,
1765 tokens_with_pos[col_start].0,
1766 tokens_with_pos[i - 1].1,
1767 );
1768 columns.push(col_text);
1769 }
1770 i += 1;
1771 col_start = i;
1772 }
1773 _ => i += 1,
1774 }
1775 }
1776 if col_start < i && i > 0 {
1778 let col_text = extract_text_between_positions(
1779 query,
1780 tokens_with_pos[col_start].0,
1781 tokens_with_pos[i - 1].1,
1782 );
1783 columns.push(col_text);
1784 }
1785
1786 for chunk in columns.chunks(cols_per_line) {
1788 let mut line = " ".to_string();
1789 for (idx, col) in chunk.iter().enumerate() {
1790 if idx > 0 {
1791 line.push_str(", ");
1792 }
1793 line.push_str(col.trim());
1794 }
1795 let is_last_chunk = chunk.as_ptr() as usize
1797 + chunk.len() * std::mem::size_of::<String>()
1798 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1799 if !is_last_chunk && columns.len() > cols_per_line {
1800 line.push(',');
1801 }
1802 lines.push(line);
1803 }
1804 }
1805 Token::From => {
1806 i += 1;
1807 if i < tokens_with_pos.len() {
1808 let table_start = tokens_with_pos[i].0;
1809 while i < tokens_with_pos.len() {
1811 match &tokens_with_pos[i].2 {
1812 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1813 _ => i += 1,
1814 }
1815 }
1816 if i > 0 {
1817 let table_text = extract_text_between_positions(
1818 query,
1819 table_start,
1820 tokens_with_pos[i - 1].1,
1821 );
1822 lines.push(format!("FROM {}", table_text.trim()));
1823 }
1824 }
1825 }
1826 Token::Where => {
1827 lines.push("WHERE".to_string());
1828 i += 1;
1829
1830 let where_start = if i < tokens_with_pos.len() {
1832 tokens_with_pos[i].0
1833 } else {
1834 start
1835 };
1836
1837 let mut where_end = query.len();
1839 while i < tokens_with_pos.len() {
1840 match &tokens_with_pos[i].2 {
1841 Token::OrderBy | Token::GroupBy | Token::Eof => {
1842 if i > 0 {
1843 where_end = tokens_with_pos[i - 1].1;
1844 }
1845 break;
1846 }
1847 _ => i += 1,
1848 }
1849 }
1850
1851 let where_text = extract_text_between_positions(query, where_start, where_end);
1853
1854 let formatted_where = format_where_clause_with_parens(&where_text);
1856 for line in formatted_where {
1857 lines.push(format!(" {}", line));
1858 }
1859 }
1860 Token::OrderBy => {
1861 i += 1;
1862 let order_start = if i < tokens_with_pos.len() {
1863 tokens_with_pos[i].0
1864 } else {
1865 start
1866 };
1867
1868 while i < tokens_with_pos.len() {
1870 match &tokens_with_pos[i].2 {
1871 Token::GroupBy | Token::Eof => break,
1872 _ => i += 1,
1873 }
1874 }
1875
1876 if i > 0 {
1877 let order_text = extract_text_between_positions(
1878 query,
1879 order_start,
1880 tokens_with_pos[i - 1].1,
1881 );
1882 lines.push(format!("ORDER BY {}", order_text.trim()));
1883 }
1884 }
1885 Token::GroupBy => {
1886 i += 1;
1887 let group_start = if i < tokens_with_pos.len() {
1888 tokens_with_pos[i].0
1889 } else {
1890 start
1891 };
1892
1893 while i < tokens_with_pos.len() {
1895 match &tokens_with_pos[i].2 {
1896 Token::Having | Token::Eof => break,
1897 _ => i += 1,
1898 }
1899 }
1900
1901 if i > 0 {
1902 let group_text = extract_text_between_positions(
1903 query,
1904 group_start,
1905 tokens_with_pos[i - 1].1,
1906 );
1907 lines.push(format!("GROUP BY {}", group_text.trim()));
1908 }
1909 }
1910 _ => i += 1,
1911 }
1912 }
1913
1914 Ok(lines)
1915}
1916
/// Returns the characters of `query` in the half-open range `start..end`.
///
/// Both positions are counted in `char`s, not bytes. Positions past the end of
/// the query are clamped, and an inverted range (`start > end`) yields an
/// empty string instead of panicking. Inverted ranges do occur for incomplete
/// input, e.g. when a clause keyword is immediately followed by the next
/// clause keyword or by the end of the query.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
1924
/// Splits the text of a WHERE clause into display lines at top-level
/// `AND` / `OR` connectors.
///
/// A connector is recognised only when it is surrounded by single spaces
/// (case-insensitively), sits outside any parentheses, and is not inside a
/// quoted literal (`'...'` or `"..."`). Each connector becomes its own line
/// (`"AND"` / `"OR"`); the text between connectors is trimmed and emitted
/// unchanged, so parenthesised groups stay intact.
///
/// An unmatched `)` does not push the nesting depth below zero, so later
/// top-level connectors are still split. If nothing was emitted (for example
/// for empty input) the trimmed input is returned as the single line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    // ASCII case-insensitive "does `haystack` start with `needle`" on chars.
    fn starts_with_ignore_case(haystack: &[char], needle: &str) -> bool {
        haystack.len() >= needle.len()
            && haystack
                .iter()
                .zip(needle.chars())
                .all(|(have, want)| have.eq_ignore_ascii_case(&want))
    }

    // (pattern as it appears in the text, keyword emitted on its own line)
    const CONNECTORS: [(&str, &str); 2] = [(" AND ", "AND"), (" OR ", "OR")];

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    'scan: while i < chars.len() {
        let ch = chars[i];

        // Inside a quoted literal everything is copied verbatim until the
        // matching quote character closes it.
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            current_line.push(ch);
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            for (pattern, keyword) in CONNECTORS.iter() {
                if starts_with_ignore_case(&chars[i..], pattern) {
                    if !current_line.trim().is_empty() {
                        lines.push(current_line.trim().to_string());
                    }
                    lines.push((*keyword).to_string());
                    current_line.clear();
                    i += pattern.len();
                    continue 'scan;
                }
            }
        }

        match ch {
            '(' => paren_depth += 1,
            // Saturate so a stray ')' cannot disable splitting for the rest
            // of the clause.
            ')' => paren_depth = paren_depth.saturating_sub(1),
            '\'' | '"' => open_quote = Some(ch),
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    if !current_line.trim().is_empty() {
        lines.push(current_line.trim().to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
1990
1991pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
1992 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
1994 return lines;
1995 }
1996
1997 let mut lines = Vec::new();
1999 let mut parser = Parser::new(query);
2000
2001 let cols_per_line = cols_per_line.max(1);
2003
2004 match parser.parse() {
2005 Ok(stmt) => {
2006 if !stmt.columns.is_empty() {
2008 lines.push("SELECT".to_string());
2009
2010 for chunk in stmt.columns.chunks(cols_per_line) {
2012 let mut line = " ".to_string();
2013 for (i, col) in chunk.iter().enumerate() {
2014 if i > 0 {
2015 line.push_str(", ");
2016 }
2017 line.push_str(col);
2018 }
2019 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2021 let current_chunk_idx =
2022 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2023 if current_chunk_idx < last_chunk_idx {
2024 line.push(',');
2025 }
2026 lines.push(line);
2027 }
2028 }
2029
2030 if let Some(table) = &stmt.from_table {
2032 lines.push(format!("FROM {}", table));
2033 }
2034
2035 if let Some(where_clause) = &stmt.where_clause {
2037 lines.push("WHERE".to_string());
2038 for (i, condition) in where_clause.conditions.iter().enumerate() {
2039 if i > 0 {
2040 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2042 if let Some(connector) = &prev_condition.connector {
2043 match connector {
2044 LogicalOp::And => lines.push(" AND".to_string()),
2045 LogicalOp::Or => lines.push(" OR".to_string()),
2046 }
2047 }
2048 }
2049 }
2050 lines.push(format!(" {}", format_expression(&condition.expr)));
2051 }
2052 }
2053
2054 if let Some(order_by) = &stmt.order_by {
2056 let order_str = order_by
2057 .iter()
2058 .map(|col| {
2059 let dir = match col.direction {
2060 SortDirection::Asc => " ASC",
2061 SortDirection::Desc => " DESC",
2062 };
2063 format!("{}{}", col.column, dir)
2064 })
2065 .collect::<Vec<_>>()
2066 .join(", ");
2067 lines.push(format!("ORDER BY {}", order_str));
2068 }
2069
2070 if let Some(group_by) = &stmt.group_by {
2072 let group_str = group_by.join(", ");
2073 lines.push(format!("GROUP BY {}", group_str));
2074 }
2075 }
2076 Err(_) => {
2077 let mut lexer = Lexer::new(query);
2079 let tokens = lexer.tokenize_all();
2080 let mut current_line = String::new();
2081 let mut indent = 0;
2082
2083 for token in tokens {
2084 match &token {
2085 Token::Select
2086 | Token::From
2087 | Token::Where
2088 | Token::OrderBy
2089 | Token::GroupBy => {
2090 if !current_line.is_empty() {
2091 lines.push(current_line.trim().to_string());
2092 current_line.clear();
2093 }
2094 lines.push(format!("{:?}", token).to_uppercase());
2095 indent = 1;
2096 }
2097 Token::And | Token::Or => {
2098 if !current_line.is_empty() {
2099 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2100 current_line.clear();
2101 }
2102 lines.push(format!(" {:?}", token).to_uppercase());
2103 }
2104 Token::Comma => {
2105 current_line.push(',');
2106 if indent > 0 {
2107 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2108 current_line.clear();
2109 }
2110 }
2111 Token::Eof => break,
2112 _ => {
2113 if !current_line.is_empty() {
2114 current_line.push(' ');
2115 }
2116 current_line.push_str(&format_token(&token));
2117 }
2118 }
2119 }
2120
2121 if !current_line.is_empty() {
2122 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2123 }
2124 }
2125 }
2126
2127 lines
2128}
2129
2130fn format_expression(expr: &SqlExpression) -> String {
2131 match expr {
2132 SqlExpression::Column(name) => name.clone(),
2133 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2134 SqlExpression::NumberLiteral(n) => n.clone(),
2135 SqlExpression::DateTimeConstructor {
2136 year,
2137 month,
2138 day,
2139 hour,
2140 minute,
2141 second,
2142 } => {
2143 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2144 if let Some(h) = hour {
2145 result.push_str(&format!(", {}", h));
2146 if let Some(m) = minute {
2147 result.push_str(&format!(", {}", m));
2148 if let Some(s) = second {
2149 result.push_str(&format!(", {}", s));
2150 }
2151 }
2152 }
2153 result.push(')');
2154 result
2155 }
2156 SqlExpression::DateTimeToday {
2157 hour,
2158 minute,
2159 second,
2160 } => {
2161 let mut result = "DateTime()".to_string();
2162 if let Some(h) = hour {
2163 result = format!("DateTime(TODAY, {}", h);
2164 if let Some(m) = minute {
2165 result.push_str(&format!(", {}", m));
2166 if let Some(s) = second {
2167 result.push_str(&format!(", {}", s));
2168 }
2169 }
2170 result.push(')');
2171 }
2172 result
2173 }
2174 SqlExpression::MethodCall {
2175 object,
2176 method,
2177 args,
2178 } => {
2179 let args_str = args
2180 .iter()
2181 .map(|arg| format_expression(arg))
2182 .collect::<Vec<_>>()
2183 .join(", ");
2184 format!("{}.{}({})", object, method, args_str)
2185 }
2186 SqlExpression::BinaryOp { left, op, right } => {
2187 if op == "OR" || op == "AND" {
2190 format!(
2193 "({} {} {})",
2194 format_expression(left),
2195 op,
2196 format_expression(right)
2197 )
2198 } else {
2199 format!(
2200 "{} {} {}",
2201 format_expression(left),
2202 op,
2203 format_expression(right)
2204 )
2205 }
2206 }
2207 SqlExpression::InList { expr, values } => {
2208 let values_str = values
2209 .iter()
2210 .map(|v| format_expression(v))
2211 .collect::<Vec<_>>()
2212 .join(", ");
2213 format!("{} IN ({})", format_expression(expr), values_str)
2214 }
2215 SqlExpression::NotInList { expr, values } => {
2216 let values_str = values
2217 .iter()
2218 .map(|v| format_expression(v))
2219 .collect::<Vec<_>>()
2220 .join(", ");
2221 format!("{} NOT IN ({})", format_expression(expr), values_str)
2222 }
2223 SqlExpression::Between { expr, lower, upper } => {
2224 format!(
2225 "{} BETWEEN {} AND {}",
2226 format_expression(expr),
2227 format_expression(lower),
2228 format_expression(upper)
2229 )
2230 }
2231 SqlExpression::Not { expr } => {
2232 format!("NOT {}", format_expression(expr))
2233 }
2234 SqlExpression::ChainedMethodCall { base, method, args } => {
2235 let args_str = args
2236 .iter()
2237 .map(|arg| format_expression(arg))
2238 .collect::<Vec<_>>()
2239 .join(", ");
2240 format!("{}.{}({})", format_expression(base), method, args_str)
2241 }
2242 SqlExpression::FunctionCall { name, args } => {
2243 let args_str = args
2244 .iter()
2245 .map(|arg| format_expression(arg))
2246 .collect::<Vec<_>>()
2247 .join(", ");
2248 format!("{}({})", name, args_str)
2249 }
2250 }
2251}
2252
2253fn format_token(token: &Token) -> String {
2254 match token {
2255 Token::Identifier(s) => s.clone(),
2256 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2257 Token::StringLiteral(s) => format!("'{}'", s),
2258 Token::NumberLiteral(n) => n.clone(),
2259 Token::DateTime => "DateTime".to_string(),
2260 Token::LeftParen => "(".to_string(),
2261 Token::RightParen => ")".to_string(),
2262 Token::Comma => ",".to_string(),
2263 Token::Dot => ".".to_string(),
2264 Token::Equal => "=".to_string(),
2265 Token::NotEqual => "!=".to_string(),
2266 Token::LessThan => "<".to_string(),
2267 Token::GreaterThan => ">".to_string(),
2268 Token::LessThanOrEqual => "<=".to_string(),
2269 Token::GreaterThanOrEqual => ">=".to_string(),
2270 Token::In => "IN".to_string(),
2271 _ => format!("{:?}", token).to_uppercase(),
2272 }
2273}
2274
2275fn analyze_statement(
2276 stmt: &SelectStatement,
2277 query: &str,
2278 _cursor_pos: usize,
2279) -> (CursorContext, Option<String>) {
2280 let trimmed = query.trim();
2282
2283 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2285 for op in &comparison_ops {
2286 if let Some(op_pos) = query.rfind(op) {
2287 let before_op = safe_slice_to(query, op_pos);
2288 let after_op_start = op_pos + op.len();
2289 let after_op = if after_op_start < query.len() {
2290 &query[after_op_start..]
2291 } else {
2292 ""
2293 };
2294
2295 if let Some(col_name) = before_op.split_whitespace().last() {
2297 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2298 let after_op_trimmed = after_op.trim();
2300 if after_op_trimmed.is_empty()
2301 || (after_op_trimmed
2302 .chars()
2303 .all(|c| c.is_alphanumeric() || c == '_')
2304 && !after_op_trimmed.contains('('))
2305 {
2306 let partial = if after_op_trimmed.is_empty() {
2307 None
2308 } else {
2309 Some(after_op_trimmed.to_string())
2310 };
2311 return (
2312 CursorContext::AfterComparisonOp(
2313 col_name.to_string(),
2314 op.trim().to_string(),
2315 ),
2316 partial,
2317 );
2318 }
2319 }
2320 }
2321 }
2322 }
2323
2324 if trimmed.to_uppercase().ends_with(" AND")
2326 || trimmed.to_uppercase().ends_with(" OR")
2327 || trimmed.to_uppercase().ends_with(" AND ")
2328 || trimmed.to_uppercase().ends_with(" OR ")
2329 {
2330 } else {
2332 if let Some(dot_pos) = trimmed.rfind('.') {
2334 let before_dot = safe_slice_to(trimmed, dot_pos);
2336 let after_dot_start = dot_pos + 1;
2337 let after_dot = if after_dot_start < trimmed.len() {
2338 &trimmed[after_dot_start..]
2339 } else {
2340 ""
2341 };
2342
2343 if !after_dot.contains('(') {
2346 let col_name = if before_dot.ends_with('"') {
2348 let bytes = before_dot.as_bytes();
2350 let mut pos = before_dot.len() - 1; let mut found_start = None;
2352
2353 if pos > 0 {
2355 pos -= 1;
2356 while pos > 0 {
2357 if bytes[pos] == b'"' {
2358 if pos == 0 || bytes[pos - 1] != b'\\' {
2360 found_start = Some(pos);
2361 break;
2362 }
2363 }
2364 pos -= 1;
2365 }
2366 if found_start.is_none() && bytes[0] == b'"' {
2368 found_start = Some(0);
2369 }
2370 }
2371
2372 if let Some(start) = found_start {
2373 Some(safe_slice_from(before_dot, start))
2375 } else {
2376 None
2377 }
2378 } else {
2379 before_dot
2382 .split_whitespace()
2383 .last()
2384 .map(|word| word.trim_start_matches('('))
2385 };
2386
2387 if let Some(col_name) = col_name {
2388 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2390 true
2392 } else {
2393 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2395 };
2396
2397 if is_valid {
2398 let partial_method = if after_dot.is_empty() {
2401 None
2402 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2403 Some(after_dot.to_string())
2404 } else {
2405 None
2406 };
2407
2408 let col_name_for_context = if col_name.starts_with('"')
2410 && col_name.ends_with('"')
2411 && col_name.len() > 2
2412 {
2413 col_name[1..col_name.len() - 1].to_string()
2414 } else {
2415 col_name.to_string()
2416 };
2417
2418 return (
2419 CursorContext::AfterColumn(col_name_for_context),
2420 partial_method,
2421 );
2422 }
2423 }
2424 }
2425 }
2426 }
2427
2428 if let Some(where_clause) = &stmt.where_clause {
2430 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2432 let op = if trimmed.to_uppercase().ends_with(" AND") {
2433 LogicalOp::And
2434 } else {
2435 LogicalOp::Or
2436 };
2437 return (CursorContext::AfterLogicalOp(op), None);
2438 }
2439
2440 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2442 let after_and = safe_slice_from(query, and_pos + 5);
2443 let partial = extract_partial_at_end(after_and);
2444 if partial.is_some() {
2445 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2446 }
2447 }
2448
2449 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2450 let after_or = safe_slice_from(query, or_pos + 4);
2451 let partial = extract_partial_at_end(after_or);
2452 if partial.is_some() {
2453 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2454 }
2455 }
2456
2457 if let Some(last_condition) = where_clause.conditions.last() {
2458 if let Some(connector) = &last_condition.connector {
2459 return (
2461 CursorContext::AfterLogicalOp(connector.clone()),
2462 extract_partial_at_end(query),
2463 );
2464 }
2465 }
2466 return (CursorContext::WhereClause, extract_partial_at_end(query));
2468 }
2469
2470 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2472 return (CursorContext::OrderByClause, None);
2473 }
2474
2475 if stmt.order_by.is_some() {
2477 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2478 }
2479
2480 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2481 return (CursorContext::FromClause, extract_partial_at_end(query));
2482 }
2483
2484 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2485 return (CursorContext::SelectClause, extract_partial_at_end(query));
2486 }
2487
2488 (CursorContext::Unknown, None)
2489}
2490
2491fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2492 let upper = query.to_uppercase();
2493
2494 let trimmed = query.trim();
2496
2497 #[cfg(test)]
2498 {
2499 if trimmed.contains("\"Last Name\"") {
2500 eprintln!(
2501 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2502 query, trimmed
2503 );
2504 }
2505 }
2506
2507 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2509 for op in &comparison_ops {
2510 if let Some(op_pos) = query.rfind(op) {
2511 let before_op = safe_slice_to(query, op_pos);
2512 let after_op_start = op_pos + op.len();
2513 let after_op = if after_op_start < query.len() {
2514 &query[after_op_start..]
2515 } else {
2516 ""
2517 };
2518
2519 if let Some(col_name) = before_op.split_whitespace().last() {
2521 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2522 let after_op_trimmed = after_op.trim();
2524 if after_op_trimmed.is_empty()
2525 || (after_op_trimmed
2526 .chars()
2527 .all(|c| c.is_alphanumeric() || c == '_')
2528 && !after_op_trimmed.contains('('))
2529 {
2530 let partial = if after_op_trimmed.is_empty() {
2531 None
2532 } else {
2533 Some(after_op_trimmed.to_string())
2534 };
2535 return (
2536 CursorContext::AfterComparisonOp(
2537 col_name.to_string(),
2538 op.trim().to_string(),
2539 ),
2540 partial,
2541 );
2542 }
2543 }
2544 }
2545 }
2546 }
2547
2548 if let Some(dot_pos) = trimmed.rfind('.') {
2551 #[cfg(test)]
2552 {
2553 if trimmed.contains("\"Last Name\"") {
2554 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2555 }
2556 }
2557 let before_dot = &trimmed[..dot_pos];
2559 let after_dot = &trimmed[dot_pos + 1..];
2560
2561 if !after_dot.contains('(') {
2564 let col_name = if before_dot.ends_with('"') {
2567 let bytes = before_dot.as_bytes();
2569 let mut pos = before_dot.len() - 1; let mut found_start = None;
2571
2572 #[cfg(test)]
2573 {
2574 if trimmed.contains("\"Last Name\"") {
2575 eprintln!(
2576 "DEBUG: before_dot='{}', looking for opening quote",
2577 before_dot
2578 );
2579 }
2580 }
2581
2582 if pos > 0 {
2584 pos -= 1;
2585 while pos > 0 {
2586 if bytes[pos] == b'"' {
2587 if pos == 0 || bytes[pos - 1] != b'\\' {
2589 found_start = Some(pos);
2590 break;
2591 }
2592 }
2593 pos -= 1;
2594 }
2595 if found_start.is_none() && bytes[0] == b'"' {
2597 found_start = Some(0);
2598 }
2599 }
2600
2601 if let Some(start) = found_start {
2602 let result = safe_slice_from(before_dot, start);
2604 #[cfg(test)]
2605 {
2606 if trimmed.contains("\"Last Name\"") {
2607 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2608 }
2609 }
2610 Some(result)
2611 } else {
2612 #[cfg(test)]
2613 {
2614 if trimmed.contains("\"Last Name\"") {
2615 eprintln!("DEBUG: No opening quote found!");
2616 }
2617 }
2618 None
2619 }
2620 } else {
2621 before_dot
2624 .split_whitespace()
2625 .last()
2626 .map(|word| word.trim_start_matches('('))
2627 };
2628
2629 if let Some(col_name) = col_name {
2630 #[cfg(test)]
2631 {
2632 if trimmed.contains("\"Last Name\"") {
2633 eprintln!("DEBUG: col_name = '{}'", col_name);
2634 }
2635 }
2636
2637 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2639 true
2641 } else {
2642 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2644 };
2645
2646 #[cfg(test)]
2647 {
2648 if trimmed.contains("\"Last Name\"") {
2649 eprintln!("DEBUG: is_valid = {}", is_valid);
2650 }
2651 }
2652
2653 if is_valid {
2654 let partial_method = if after_dot.is_empty() {
2657 None
2658 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2659 Some(after_dot.to_string())
2660 } else {
2661 None
2662 };
2663
2664 let col_name_for_context = if col_name.starts_with('"')
2666 && col_name.ends_with('"')
2667 && col_name.len() > 2
2668 {
2669 col_name[1..col_name.len() - 1].to_string()
2670 } else {
2671 col_name.to_string()
2672 };
2673
2674 return (
2675 CursorContext::AfterColumn(col_name_for_context),
2676 partial_method,
2677 );
2678 }
2679 }
2680 }
2681 }
2682
2683 if let Some(and_pos) = upper.rfind(" AND ") {
2685 if cursor_pos >= and_pos + 5 {
2687 let after_and = safe_slice_from(query, and_pos + 5);
2689 let partial = extract_partial_at_end(after_and);
2690 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2691 }
2692 }
2693
2694 if let Some(or_pos) = upper.rfind(" OR ") {
2695 if cursor_pos >= or_pos + 4 {
2697 let after_or = safe_slice_from(query, or_pos + 4);
2699 let partial = extract_partial_at_end(after_or);
2700 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2701 }
2702 }
2703
2704 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2706 let op = if trimmed.to_uppercase().ends_with(" AND") {
2707 LogicalOp::And
2708 } else {
2709 LogicalOp::Or
2710 };
2711 return (CursorContext::AfterLogicalOp(op), None);
2712 }
2713
2714 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2716 {
2717 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2718 }
2719
2720 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2721 return (CursorContext::WhereClause, extract_partial_at_end(query));
2722 }
2723
2724 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2725 return (CursorContext::FromClause, extract_partial_at_end(query));
2726 }
2727
2728 if upper.contains("SELECT") && !upper.contains("FROM") {
2729 return (CursorContext::SelectClause, extract_partial_at_end(query));
2730 }
2731
2732 (CursorContext::Unknown, None)
2733}
2734
2735fn extract_partial_at_end(query: &str) -> Option<String> {
2736 let trimmed = query.trim();
2737
2738 if let Some(last_word) = trimmed.split_whitespace().last() {
2740 if last_word.starts_with('"') && !last_word.ends_with('"') {
2741 return Some(last_word.to_string());
2743 }
2744 }
2745
2746 let last_word = trimmed.split_whitespace().last()?;
2748
2749 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2751 Some(last_word.to_string())
2752 } else {
2753 None
2754 }
2755}
2756
/// Returns `true` when `word` is an SQL keyword, compared case-insensitively.
///
/// The list mirrors the keyword variants of the lexer's `Token` enum, so a
/// keyword the user has just typed (e.g. `BETWEEN`, `NOT`, `LIMIT`) is never
/// mistaken for the prefix of a column name. `ORDER`, `GROUP` and `BY` are
/// listed as separate words because callers test one whitespace-separated
/// word at a time.
fn is_sql_keyword(word: &str) -> bool {
    matches!(
        word.to_uppercase().as_str(),
        "SELECT"
            | "FROM"
            | "WHERE"
            | "AND"
            | "OR"
            | "IN"
            | "NOT"
            | "BETWEEN"
            | "LIKE"
            | "IS"
            | "NULL"
            | "ORDER"
            | "BY"
            | "GROUP"
            | "HAVING"
            | "AS"
            | "ASC"
            | "DESC"
            | "LIMIT"
            | "OFFSET"
    )
}
2774
/// Unit tests for the lexer, the parser, cursor-context detection and the
/// SQL formatting helpers defined in the parent module.
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Method-call chains
    // ------------------------------------------------------------------

    /// Chained method calls in a WHERE clause must parse, both as the left
    /// side of a comparison and as a stand-alone condition.
    #[test]
    fn test_chained_method_calls() {
        let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
        let mut parser = Parser::new(query);
        let result = parser.parse();

        assert!(
            result.is_ok(),
            "Failed to parse chained method calls: {:?}",
            result
        );

        let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
        let mut parser2 = Parser::new(query2);
        let result2 = parser2.parse();

        assert!(
            result2.is_ok(),
            "Failed to parse multiple chained calls: {:?}",
            result2
        );
    }

    // ------------------------------------------------------------------
    // Lexer basics
    // ------------------------------------------------------------------

    /// Keywords, `*`, identifiers, an operator and a number come out of the
    /// lexer in source order.
    #[test]
    fn test_tokenizer() {
        let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::Star));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
    }

    /// `DateTime` is lexed as its own token, followed by the parenthesised,
    /// comma-separated numeric arguments.
    #[test]
    fn test_tokenizer_datetime() {
        let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::DateTime));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    // ------------------------------------------------------------------
    // Parser basics
    // ------------------------------------------------------------------

    /// `SELECT * FROM <table>` yields a single `*` column, the table name and
    /// no WHERE clause.
    #[test]
    fn test_parse_simple_select() {
        let mut parser = Parser::new("SELECT * FROM trade_deal");
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["*"]);
        assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
        assert!(stmt.where_clause.is_none());
    }

    /// A single method-call predicate produces exactly one WHERE condition.
    #[test]
    fn test_parse_where_with_method() {
        let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);
    }

    /// `DateTime(y, m, d)` on the right of a comparison becomes a
    /// `DateTimeConstructor` whose time-of-day fields are `None`.
    #[test]
    fn test_parse_datetime_constructor() {
        let mut parser =
            Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::DateTimeConstructor {
                    year: 2025,
                    month: 10,
                    day: 20,
                    hour: None,
                    minute: None,
                    second: None
                }
            ));
        } else {
            panic!("Expected BinaryOp with DateTime constructor");
        }
    }

    // ------------------------------------------------------------------
    // Cursor-context detection
    // ------------------------------------------------------------------

    /// With the cursor right after a trailing `AND `, the context is
    /// "after logical AND" and there is no partial word.
    #[test]
    fn test_cursor_context_after_and() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, None);
    }

    /// A word being typed after `AND` is reported as the partial while the
    /// context stays "after logical AND".
    #[test]
    fn test_cursor_context_with_partial() {
        let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(
            context,
            CursorContext::AfterLogicalOp(LogicalOp::And)
        ));
        assert_eq!(partial, Some("p".to_string()));
    }

    /// After `column > ` the context carries both the column name and the
    /// comparison operator.
    #[test]
    fn test_cursor_context_after_datetime_comparison() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > ";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, None);
    }

    /// Same as above, but with a partially typed value (`Date`) after the
    /// operator.
    #[test]
    fn test_cursor_context_partial_datetime() {
        let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(
            matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
        );
        assert_eq!(partial, Some("Date".to_string()));
    }

    // ------------------------------------------------------------------
    // Quoted identifiers
    // ------------------------------------------------------------------

    /// Double-quoted names in a SELECT list are lexed as `QuotedIdentifier`
    /// tokens with the quotes removed and inner spaces kept.
    #[test]
    fn test_tokenizer_quoted_identifier() {
        let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
        assert!(matches!(lexer.next_token(), Token::From));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
    }

    /// Double quotes produce `QuotedIdentifier`, single quotes produce
    /// `StringLiteral`, within the same input.
    #[test]
    fn test_tokenizer_quoted_vs_string_literal() {
        let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
        assert!(matches!(lexer.next_token(), Token::Equal));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
        assert!(matches!(lexer.next_token(), Token::And));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// Tokenizes a method call whose argument is written with double quotes.
    ///
    /// The token *kind* is deliberately left open here.
    /// NOTE(review): presumably the lexer has no context to tell a quoted
    /// identifier from a string argument and the parser resolves it (see
    /// `test_parse_method_with_double_quotes_as_string`) - confirm.
    /// Whatever the kind, the quoted text itself must survive intact.
    #[test]
    fn test_tokenizer_method_with_double_quotes_should_be_string() {
        let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);

        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
        assert!(matches!(lexer.next_token(), Token::Dot));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
        assert!(matches!(lexer.next_token(), Token::LeftParen));

        let token = lexer.next_token();
        println!("Token for \"Alb\": {:?}", token);
        // Previously this token was only printed, so the test could not fail
        // on it; at least pin the text it carries.
        assert!(
            matches!(
                &token,
                Token::QuotedIdentifier(s) | Token::StringLiteral(s) if s == "Alb"
            ),
            "Expected the double-quoted argument to carry the text Alb, got {:?}",
            token
        );
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }

    /// Quoted and bare column names both end up in `columns` without quotes.
    #[test]
    fn test_parse_select_with_quoted_columns() {
        let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
        assert_eq!(stmt.from_table, Some("customers".to_string()));
    }

    /// An unterminated quoted column in the SELECT list still reports the
    /// SELECT-clause context. The cursor is placed one byte before the end
    /// of the query.
    #[test]
    fn test_cursor_context_select_with_partial_quoted() {
        let query = r#"SELECT "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    /// An unterminated quoted column after a comma is still inside the
    /// SELECT clause.
    #[test]
    fn test_cursor_context_select_after_comma_with_quoted() {
        let query = r#"SELECT Company, "Customer "#;
        let (context, partial) = detect_cursor_context(query, query.len());

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::SelectClause));
    }

    /// An unterminated quoted column after ORDER BY reports the ORDER BY
    /// context. The cursor is placed one byte before the end of the query.
    #[test]
    fn test_cursor_context_order_by_quoted() {
        let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
        let (context, partial) = detect_cursor_context(query, query.len() - 1);

        println!("Context: {:?}, Partial: {:?}", context, partial);
        assert!(matches!(context, CursorContext::OrderByClause));
    }

    /// A quoted column on the left of `=` parses to `Column` with the
    /// unquoted name; the single-quoted value parses to `StringLiteral`.
    #[test]
    fn test_where_clause_with_quoted_column() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "=");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
            assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    /// At parser level a double-quoted method argument is a string literal,
    /// not a column reference.
    #[test]
    fn test_parse_method_with_double_quotes_as_string() {
        let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::MethodCall {
            object,
            method,
            args,
        } = &where_clause.conditions[0].expr
        {
            assert_eq!(object, "Country");
            assert_eq!(method, "Contains");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
        } else {
            panic!("Expected MethodCall");
        }
    }

    /// A complete quoted column earlier in the query must not leak into the
    /// partial word extracted at the cursor.
    #[test]
    fn test_extract_partial_with_quoted_columns_in_query() {
        let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
        let (context, partial) = detect_cursor_context(query, query.len());

        assert!(matches!(context, CursorContext::OrderByClause));
        assert_eq!(
            partial,
            Some("coun".to_string()),
            "Should extract 'coun' as partial, not everything after the quoted column"
        );
    }

    /// An unterminated quoted identifier is returned including its opening
    /// quote; a query ending in the keyword `FROM` has no partial.
    #[test]
    fn test_extract_partial_quoted_identifier_being_typed() {
        let query = r#"SELECT "Cust"#;
        let partial = extract_partial_at_end(query);
        assert_eq!(partial, Some("\"Cust".to_string()));

        let query2 = r#"SELECT "Customer Id" FROM"#;
        let partial2 = extract_partial_at_end(query2);
        assert_eq!(partial2, None);
    }

    // ------------------------------------------------------------------
    // Parenthesised WHERE expressions
    // ------------------------------------------------------------------

    /// A fully parenthesised OR is one condition whose expression is the OR.
    #[test]
    fn test_complex_where_parentheses_basic() {
        let mut parser =
            Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected BinaryOp with OR");
        }
    }

    /// `(a OR b) AND c` gives two conditions: the OR group, connected by AND
    /// to the comparison.
    #[test]
    fn test_complex_where_mixed_and_or_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
        } else {
            panic!("Expected first condition to be OR expression");
        }

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));

        if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
            assert_eq!(op, ">");
        } else {
            panic!("Expected second condition to be price > 100");
        }
    }

    /// Doubly nested parentheses parse and yield at least one condition.
    #[test]
    fn test_complex_where_nested_parentheses() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// Two parenthesised groups joined by AND give two conditions.
    #[test]
    fn test_complex_where_multiple_or_groups() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Method calls inside a parenthesised OR stay `MethodCall` nodes on both
    /// sides of the OR.
    #[test]
    fn test_complex_where_with_methods_in_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
            assert_eq!(op, "OR");
            assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
            assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
        } else {
            panic!("Expected OR of method calls");
        }
    }

    /// A parenthesised date range ANDed with a parenthesised status check
    /// gives two conditions connected by AND.
    #[test]
    fn test_complex_where_date_comparisons_with_parens() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);

        assert!(matches!(
            where_clause.conditions[0].connector,
            Some(LogicalOp::And)
        ));
    }

    /// Nested price ranges combined with a volume filter parse and yield at
    /// least one condition.
    #[test]
    fn test_complex_where_price_volume_filters() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    /// String comparisons, numeric comparisons and a method call can be mixed
    /// across two parenthesised groups.
    #[test]
    fn test_complex_where_mixed_string_numeric() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// Groups nested inside a group, ORed with another group, parse.
    #[test]
    fn test_complex_where_triple_nested() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
    }

    /// A single pair of parentheses around an AND chain parses and yields at
    /// least one condition.
    #[test]
    fn test_complex_where_single_parens_around_and() {
        let mut parser = Parser::new(
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
        );
        let stmt = parser.parse().unwrap();

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();

        assert!(!where_clause.conditions.is_empty());
    }

    // ------------------------------------------------------------------
    // Pretty-printer: parentheses must survive formatting
    // ------------------------------------------------------------------

    /// Formatting keeps a simple parenthesised group and the total number of
    /// parenthesis characters.
    #[test]
    fn test_format_preserves_simple_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(status"));
        assert!(formatted_text.contains("\"pending\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(
            original_parens, formatted_parens,
            "Parentheses should be preserved"
        );
    }

    /// Formatting keeps a group that is followed by a further AND condition.
    #[test]
    fn test_format_preserves_complex_parentheses() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("\"GOOGL\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
    }

    /// Formatting keeps all four parenthesis characters of a nested group.
    #[test]
    fn test_format_preserves_nested_parentheses() {
        let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(
            original_parens, formatted_parens,
            "Nested parentheses should be preserved"
        );
        assert_eq!(original_parens, 4, "Should have 4 parentheses total");
    }

    /// Formatting keeps both the grouping parentheses and the method-call
    /// parentheses inside the group.
    #[test]
    fn test_format_preserves_method_calls_in_parentheses() {
        let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol.StartsWith"));
        assert!(formatted_text.contains("StartsWith(\"A\")"));
        assert!(formatted_text.contains("StartsWith(\"G\")"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(
            original_parens, 6,
            "Should have 6 parentheses (1 group + 2 method calls)"
        );
    }

    /// Formatting keeps two independent parenthesised groups.
    #[test]
    fn test_format_preserves_multiple_groups() {
        let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(symbol"));
        assert!(formatted_text.contains("(price"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
        assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
    }

    /// Formatting keeps `DateTime(...)` constructors verbatim inside a
    /// parenthesised date range.
    #[test]
    fn test_format_preserves_date_ranges() {
        let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
        let formatted = format_sql_pretty_compact(query, 5);
        let formatted_text = formatted.join(" ");

        assert!(formatted_text.contains("(executionDate"));
        assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
        assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

        let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
        let formatted_parens = formatted_text
            .chars()
            .filter(|c| *c == '(' || *c == ')')
            .count();
        assert_eq!(original_parens, formatted_parens);
    }

    /// The formatter emits `SELECT`, the column list, `FROM ...` and `WHERE`
    /// on the first four lines, followed by the WHERE body with `AND` on a
    /// line of its own.
    #[test]
    fn test_format_multiline_layout() {
        let query =
            r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
        let formatted = format_sql_pretty_compact(query, 5);

        assert!(formatted.len() >= 4, "Should have multiple lines");
        assert_eq!(formatted[0], "SELECT");
        assert!(formatted[1].trim().starts_with("*"));
        assert!(formatted[2].starts_with("FROM"));
        assert_eq!(formatted[3], "WHERE");

        // Everything after the `WHERE` line belongs to the WHERE body.
        let where_lines: Vec<_> = formatted.iter().skip(4).collect();
        assert!(where_lines.iter().any(|l| l.contains("(symbol")));
        assert!(where_lines.iter().any(|l| l.trim() == "AND"));
    }

    // ------------------------------------------------------------------
    // BETWEEN
    // ------------------------------------------------------------------

    /// A bare `BETWEEN lo AND hi` is a single condition (its `AND` is not a
    /// logical connector) and the AST dump reports no parse error.
    #[test]
    fn test_between_simple() {
        let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        let stmt = parser.parse().expect("Should parse simple BETWEEN");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 1);

        let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
        assert!(!ast.contains("PARSE ERROR"));
        assert!(ast.contains("SelectStatement"));
    }

    /// BETWEEN wrapped in parentheses parses and dumps without a parse error.
    #[test]
    fn test_between_in_parentheses() {
        let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");

        assert!(stmt.where_clause.is_some());

        let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
        assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
    }

    /// A parenthesised BETWEEN can be ORed with another parenthesised
    /// comparison.
    #[test]
    fn test_between_with_or() {
        let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with OR");

        assert!(stmt.where_clause.is_some());
    }

    /// `a = x AND b BETWEEN lo AND hi` is two conditions: only the first
    /// `AND` is a connector, the second belongs to BETWEEN.
    #[test]
    fn test_between_with_and() {
        let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
        let mut parser = Parser::new(query);
        let stmt = parser.parse().expect("Should parse BETWEEN with AND");

        assert!(stmt.where_clause.is_some());
        let where_clause = stmt.where_clause.unwrap();
        assert_eq!(where_clause.conditions.len(), 2);
    }

    /// Two parenthesised BETWEEN clauses joined by AND parse.
    #[test]
    fn test_multiple_between() {
        let query =
            "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse multiple BETWEEN clauses");

        assert!(stmt.where_clause.is_some());
    }

    /// BETWEEN, a method call and a two-column ORDER BY in one query; the
    /// ORDER BY columns and directions are checked in order.
    #[test]
    fn test_between_complex_query() {
        let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
        let mut parser = Parser::new(query);
        let stmt = parser
            .parse()
            .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

        assert!(stmt.where_clause.is_some());
        assert!(stmt.order_by.is_some());

        let order_by = stmt.order_by.unwrap();
        assert_eq!(order_by.len(), 2);
        assert_eq!(order_by[0].column, "Category");
        assert!(matches!(order_by[0].direction, SortDirection::Asc));
        assert_eq!(order_by[1].column, "price");
        assert!(matches!(order_by[1].direction, SortDirection::Desc));
    }

    /// A hand-built `Between` node formats back to SQL text, and its AST
    /// rendering mentions the node kind and both bounds.
    #[test]
    fn test_between_formatting() {
        let expr = SqlExpression::Between {
            expr: Box::new(SqlExpression::Column("price".to_string())),
            lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
            upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
        };

        let formatted = format_expression(&expr);
        assert_eq!(formatted, "price BETWEEN 50 AND 100");

        let ast_formatted = format_expression_ast(&expr);
        assert!(ast_formatted.contains("Between"));
        assert!(ast_formatted.contains("50"));
        assert!(ast_formatted.contains("100"));
    }

    // ------------------------------------------------------------------
    // UTF-8 safety
    // ------------------------------------------------------------------

    /// `detect_cursor_context` must not panic for any byte offset, including
    /// offsets past the end of the query and offsets that fall inside a
    /// multi-byte character.
    #[test]
    fn test_utf8_boundary_safety() {
        let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

        // Every byte offset from 0 up to and including the length.
        for pos in 0..=query_with_unicode.len() {
            let result =
                std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));

            assert!(
                result.is_ok(),
                "Panic at position {} in query with Unicode",
                pos
            );
        }

        let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
        assert!(result.is_ok(), "Panic with position beyond string length");

        // 'é' is two bytes in UTF-8, so one byte past its start is not a
        // character boundary.
        let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
        let result =
            std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
        assert!(
            result.is_ok(),
            "Panic with cursor in middle of UTF-8 character"
        );
    }
}