1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer::next_token`].
///
/// Keywords are matched case-insensitively by the lexer; everything that is
/// not a keyword, literal, or operator is reported as an [`Token::Identifier`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // --- Keywords ---
    /// `SELECT`
    Select,
    /// `FROM`
    From,
    /// `WHERE`
    Where,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `IN`
    In,
    /// `NOT`
    Not,
    /// `BETWEEN`
    Between,
    /// `LIKE`
    Like,
    /// `IS`
    Is,
    /// `NULL`
    Null,
    /// The two-word keyword `ORDER BY` (emitted as a single token).
    OrderBy,
    /// The two-word keyword `GROUP BY` (emitted as a single token).
    GroupBy,
    /// `HAVING`
    Having,
    /// `AS`
    As,
    /// `ASC`
    Asc,
    /// `DESC`
    Desc,
    /// `LIMIT`
    Limit,
    /// `OFFSET`
    Offset,
    /// `DATETIME`, the start of a `DateTime(...)` constructor.
    DateTime,

    // --- Literals and names ---
    /// A bare word made of alphanumerics and `_` that is not a keyword.
    /// The lexer also uses this variant for single characters it does not
    /// otherwise recognize.
    Identifier(String),
    /// Text enclosed in double quotes (quotes removed).
    QuotedIdentifier(String),
    /// Text enclosed in single quotes (quotes removed).
    StringLiteral(String),
    /// Numeric text, kept as a string; may start with `-`.
    NumberLiteral(String),
    /// `*`
    Star,

    // --- Punctuation and comparison operators ---
    /// `.`
    Dot,
    /// `,`
    Comma,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `=`
    Equal,
    /// `!=` or `<>`
    NotEqual,
    /// `<`
    LessThan,
    /// `>`
    GreaterThan,
    /// `<=`
    LessThanOrEqual,
    /// `>=`
    GreaterThanOrEqual,

    // --- Arithmetic operators ---
    /// `+`
    Plus,
    /// `-` when not immediately followed by a digit.
    Minus,
    /// `/`
    Divide,

    /// End of input.
    Eof,
}
54
/// Character-level scanner that turns query text into [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input, decoded into `char`s; positions index into this vector.
    input: Vec<char>,
    /// Index in `input` of the character currently being examined.
    position: usize,
    /// The character at `position`, or `None` once the input is exhausted.
    current_char: Option<char>,
}
61
62impl Lexer {
63 pub fn new(input: &str) -> Self {
64 let chars: Vec<char> = input.chars().collect();
65 let current = chars.get(0).copied();
66 Self {
67 input: chars,
68 position: 0,
69 current_char: current,
70 }
71 }
72
73 fn advance(&mut self) {
74 self.position += 1;
75 self.current_char = self.input.get(self.position).copied();
76 }
77
78 fn peek(&self, offset: usize) -> Option<char> {
79 self.input.get(self.position + offset).copied()
80 }
81
82 fn skip_whitespace(&mut self) {
83 while let Some(ch) = self.current_char {
84 if ch.is_whitespace() {
85 self.advance();
86 } else {
87 break;
88 }
89 }
90 }
91
92 fn read_identifier(&mut self) -> String {
93 let mut result = String::new();
94 while let Some(ch) = self.current_char {
95 if ch.is_alphanumeric() || ch == '_' {
96 result.push(ch);
97 self.advance();
98 } else {
99 break;
100 }
101 }
102 result
103 }
104
105 fn read_string(&mut self) -> String {
106 let mut result = String::new();
107 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
111 if ch == quote_char {
112 self.advance(); break;
114 }
115 result.push(ch);
116 self.advance();
117 }
118 result
119 }
120
121 fn read_number(&mut self) -> String {
122 let mut result = String::new();
123 while let Some(ch) = self.current_char {
124 if ch.is_numeric() || ch == '.' {
125 result.push(ch);
126 self.advance();
127 } else {
128 break;
129 }
130 }
131 result
132 }
133
134 pub fn next_token(&mut self) -> Token {
135 self.skip_whitespace();
136
137 match self.current_char {
138 None => Token::Eof,
139 Some('*') => {
140 self.advance();
141 Token::Star }
145 Some('+') => {
146 self.advance();
147 Token::Plus
148 }
149 Some('/') => {
150 self.advance();
151 Token::Divide
152 }
153 Some('.') => {
154 self.advance();
155 Token::Dot
156 }
157 Some(',') => {
158 self.advance();
159 Token::Comma
160 }
161 Some('(') => {
162 self.advance();
163 Token::LeftParen
164 }
165 Some(')') => {
166 self.advance();
167 Token::RightParen
168 }
169 Some('=') => {
170 self.advance();
171 Token::Equal
172 }
173 Some('<') => {
174 self.advance();
175 if self.current_char == Some('=') {
176 self.advance();
177 Token::LessThanOrEqual
178 } else if self.current_char == Some('>') {
179 self.advance();
180 Token::NotEqual
181 } else {
182 Token::LessThan
183 }
184 }
185 Some('>') => {
186 self.advance();
187 if self.current_char == Some('=') {
188 self.advance();
189 Token::GreaterThanOrEqual
190 } else {
191 Token::GreaterThan
192 }
193 }
194 Some('!') if self.peek(1) == Some('=') => {
195 self.advance();
196 self.advance();
197 Token::NotEqual
198 }
199 Some('"') => {
200 let ident_val = self.read_string();
202 Token::QuotedIdentifier(ident_val)
203 }
204 Some('\'') => {
205 let string_val = self.read_string();
207 Token::StringLiteral(string_val)
208 }
209 Some('-') if self.peek(1).map_or(false, |c| c.is_numeric()) => {
210 self.advance(); let num = self.read_number();
213 Token::NumberLiteral(format!("-{}", num))
214 }
215 Some('-') => {
216 self.advance();
218 Token::Minus
219 }
220 Some(ch) if ch.is_numeric() => {
221 let num = self.read_number();
222 Token::NumberLiteral(num)
223 }
224 Some(ch) if ch.is_alphabetic() || ch == '_' => {
225 let ident = self.read_identifier();
226 match ident.to_uppercase().as_str() {
227 "SELECT" => Token::Select,
228 "FROM" => Token::From,
229 "WHERE" => Token::Where,
230 "AND" => Token::And,
231 "OR" => Token::Or,
232 "IN" => Token::In,
233 "NOT" => Token::Not,
234 "BETWEEN" => Token::Between,
235 "LIKE" => Token::Like,
236 "IS" => Token::Is,
237 "NULL" => Token::Null,
238 "ORDER" if self.peek_keyword("BY") => {
239 self.skip_whitespace();
240 self.read_identifier(); Token::OrderBy
242 }
243 "GROUP" if self.peek_keyword("BY") => {
244 self.skip_whitespace();
245 self.read_identifier(); Token::GroupBy
247 }
248 "HAVING" => Token::Having,
249 "AS" => Token::As,
250 "ASC" => Token::Asc,
251 "DESC" => Token::Desc,
252 "LIMIT" => Token::Limit,
253 "OFFSET" => Token::Offset,
254 "DATETIME" => Token::DateTime,
255 _ => Token::Identifier(ident),
256 }
257 }
258 Some(ch) => {
259 self.advance();
260 Token::Identifier(ch.to_string())
261 }
262 }
263 }
264
265 fn peek_keyword(&mut self, keyword: &str) -> bool {
266 let saved_pos = self.position;
267 let saved_char = self.current_char;
268
269 self.skip_whitespace();
270 let next_word = self.read_identifier();
271 let matches = next_word.to_uppercase() == keyword;
272
273 self.position = saved_pos;
275 self.current_char = saved_char;
276
277 matches
278 }
279
280 pub fn get_position(&self) -> usize {
281 self.position
282 }
283
284 pub fn tokenize_all(&mut self) -> Vec<Token> {
285 let mut tokens = Vec::new();
286 loop {
287 let token = self.next_token();
288 if matches!(token, Token::Eof) {
289 tokens.push(token);
290 break;
291 }
292 tokens.push(token);
293 }
294 tokens
295 }
296
297 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
298 let mut tokens = Vec::new();
299 loop {
300 self.skip_whitespace();
301 let start_pos = self.position;
302 let token = self.next_token();
303 let end_pos = self.position;
304
305 if matches!(token, Token::Eof) {
306 break;
307 }
308 tokens.push((start_pos, end_pos, token));
309 }
310 tokens
311 }
312}
313
/// Expression node of the parsed query.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// Reference to a column by name.
    Column(String),
    /// String literal (text without its quotes).
    StringLiteral(String),
    /// Numeric literal, kept as the text that was lexed.
    NumberLiteral(String),
    /// `DateTime(year, month, day[, hour[, minute[, second]]])`.
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `DateTime()` with no arguments; the parser leaves all fields `None`.
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `object.method(args)` where `object` is a column name.
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `base.method(args)` where `base` is itself a method call.
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `NAME(args)` for one of the function names the parser recognizes;
    /// the parser stores `name` upper-cased.
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
    },
    /// `left op right`; `op` is the operator text (e.g. `"="`, `"+"`, `"AND"`).
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// `expr IN (values...)`.
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr NOT IN (values...)`.
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr BETWEEN lower AND upper`.
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// `NOT expr`.
    Not {
        expr: Box<SqlExpression>,
    },
}
369
370#[derive(Debug, Clone)]
371pub struct WhereClause {
372 pub conditions: Vec<Condition>,
373}
374
375#[derive(Debug, Clone)]
376pub struct Condition {
377 pub expr: SqlExpression,
378 pub connector: Option<LogicalOp>, }
380
/// Logical connector between two `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
386
/// Sort direction of an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
392
393#[derive(Debug, Clone)]
394pub struct OrderByColumn {
395 pub column: String,
396 pub direction: SortDirection,
397}
398
399#[derive(Debug, Clone)]
401pub enum SelectItem {
402 Column(String),
404 Expression { expr: SqlExpression, alias: String },
406 Star,
408}
409
410#[derive(Debug, Clone)]
411pub struct SelectStatement {
412 pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
415 pub where_clause: Option<WhereClause>,
416 pub order_by: Option<Vec<OrderByColumn>>,
417 pub group_by: Option<Vec<String>>,
418 pub limit: Option<usize>,
419 pub offset: Option<usize>,
420}
421
/// Options carried by a [`Parser`].
#[derive(Debug, Clone)]
pub struct ParserConfig {
    /// Case-insensitivity switch.
    // NOTE(review): the parser stores this flag but nothing in the visible
    // part of this file reads it — confirm where it takes effect.
    pub case_insensitive: bool,
}
425
426impl Default for ParserConfig {
427 fn default() -> Self {
428 Self {
429 case_insensitive: false,
430 }
431 }
432}
433
434pub struct Parser {
435 lexer: Lexer,
436 current_token: Token,
437 in_method_args: bool, columns: Vec<String>, paren_depth: i32, config: ParserConfig, }
442
443impl Parser {
444 pub fn new(input: &str) -> Self {
445 let mut lexer = Lexer::new(input);
446 let current_token = lexer.next_token();
447 Self {
448 lexer,
449 current_token,
450 in_method_args: false,
451 columns: Vec::new(),
452 paren_depth: 0,
453 config: ParserConfig::default(),
454 }
455 }
456
457 pub fn with_config(input: &str, config: ParserConfig) -> Self {
458 let mut lexer = Lexer::new(input);
459 let current_token = lexer.next_token();
460 Self {
461 lexer,
462 current_token,
463 in_method_args: false,
464 columns: Vec::new(),
465 paren_depth: 0,
466 config,
467 }
468 }
469
470 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
471 self.columns = columns;
472 self
473 }
474
475 fn consume(&mut self, expected: Token) -> Result<(), String> {
476 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
477 match &expected {
479 Token::LeftParen => self.paren_depth += 1,
480 Token::RightParen => {
481 self.paren_depth -= 1;
482 if self.paren_depth < 0 {
484 return Err(
485 "Unexpected closing parenthesis - no matching opening parenthesis"
486 .to_string(),
487 );
488 }
489 }
490 _ => {}
491 }
492
493 self.current_token = self.lexer.next_token();
494 Ok(())
495 } else {
496 let error_msg = match (&expected, &self.current_token) {
498 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
499 format!(
500 "Unclosed parenthesis - missing {} closing parenthes{}",
501 self.paren_depth,
502 if self.paren_depth == 1 { "is" } else { "es" }
503 )
504 }
505 (Token::RightParen, _) if self.paren_depth > 0 => {
506 format!(
507 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
508 self.current_token,
509 self.paren_depth,
510 if self.paren_depth == 1 { "is" } else { "es" }
511 )
512 }
513 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
514 };
515 Err(error_msg)
516 }
517 }
518
519 fn advance(&mut self) {
520 match &self.current_token {
522 Token::LeftParen => self.paren_depth += 1,
523 Token::RightParen => {
524 self.paren_depth -= 1;
525 }
528 _ => {}
529 }
530 self.current_token = self.lexer.next_token();
531 }
532
533 pub fn parse(&mut self) -> Result<SelectStatement, String> {
534 self.parse_select_statement()
535 }
536
537 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
538 self.consume(Token::Select)?;
539
540 let select_items = self.parse_select_items()?;
542
543 let columns = select_items
545 .iter()
546 .map(|item| match item {
547 SelectItem::Star => "*".to_string(),
548 SelectItem::Column(name) => name.clone(),
549 SelectItem::Expression { alias, .. } => alias.clone(),
550 })
551 .collect();
552
553 let from_table = if matches!(self.current_token, Token::From) {
554 self.advance();
555 match &self.current_token {
556 Token::Identifier(table) => {
557 let table_name = table.clone();
558 self.advance();
559 Some(table_name)
560 }
561 Token::QuotedIdentifier(table) => {
562 let table_name = table.clone();
564 self.advance();
565 Some(table_name)
566 }
567 _ => return Err("Expected table name after FROM".to_string()),
568 }
569 } else {
570 None
571 };
572
573 let where_clause = if matches!(self.current_token, Token::Where) {
574 self.advance();
575 Some(self.parse_where_clause()?)
576 } else {
577 None
578 };
579
580 let order_by = if matches!(self.current_token, Token::OrderBy) {
581 self.advance();
582 Some(self.parse_order_by_list()?)
583 } else {
584 None
585 };
586
587 let group_by = if matches!(self.current_token, Token::GroupBy) {
588 self.advance();
589 Some(self.parse_identifier_list()?)
590 } else {
591 None
592 };
593
594 let limit = if matches!(self.current_token, Token::Limit) {
596 self.advance();
597 match &self.current_token {
598 Token::NumberLiteral(num) => {
599 let limit_val = num
600 .parse::<usize>()
601 .map_err(|_| format!("Invalid LIMIT value: {}", num))?;
602 self.advance();
603 Some(limit_val)
604 }
605 _ => return Err("Expected number after LIMIT".to_string()),
606 }
607 } else {
608 None
609 };
610
611 let offset = if matches!(self.current_token, Token::Offset) {
613 self.advance();
614 match &self.current_token {
615 Token::NumberLiteral(num) => {
616 let offset_val = num
617 .parse::<usize>()
618 .map_err(|_| format!("Invalid OFFSET value: {}", num))?;
619 self.advance();
620 Some(offset_val)
621 }
622 _ => return Err("Expected number after OFFSET".to_string()),
623 }
624 } else {
625 None
626 };
627
628 if self.paren_depth > 0 {
630 return Err(format!(
631 "Unclosed parenthesis - missing {} closing parenthes{}",
632 self.paren_depth,
633 if self.paren_depth == 1 { "is" } else { "es" }
634 ));
635 } else if self.paren_depth < 0 {
636 return Err(
637 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
638 );
639 }
640
641 Ok(SelectStatement {
642 columns,
643 select_items,
644 from_table,
645 where_clause,
646 order_by,
647 group_by,
648 limit,
649 offset,
650 })
651 }
652
653 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
654 let mut columns = Vec::new();
655
656 if matches!(self.current_token, Token::Star) {
657 columns.push("*".to_string());
658 self.advance();
659 } else {
660 loop {
661 match &self.current_token {
662 Token::Identifier(col) => {
663 columns.push(col.clone());
664 self.advance();
665 }
666 Token::QuotedIdentifier(col) => {
667 columns.push(col.clone());
669 self.advance();
670 }
671 _ => return Err("Expected column name".to_string()),
672 }
673
674 if matches!(self.current_token, Token::Comma) {
675 self.advance();
676 } else {
677 break;
678 }
679 }
680 }
681
682 Ok(columns)
683 }
684
685 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
687 let mut items = Vec::new();
688
689 loop {
690 if matches!(self.current_token, Token::Star) {
693 items.push(SelectItem::Star);
701 self.advance();
702 } else {
703 let expr = self.parse_additive()?; let alias = if matches!(self.current_token, Token::As) {
708 self.advance();
709 match &self.current_token {
710 Token::Identifier(alias_name) => {
711 let alias = alias_name.clone();
712 self.advance();
713 alias
714 }
715 Token::QuotedIdentifier(alias_name) => {
716 let alias = alias_name.clone();
717 self.advance();
718 alias
719 }
720 _ => return Err("Expected alias name after AS".to_string()),
721 }
722 } else {
723 match &expr {
725 SqlExpression::Column(col_name) => col_name.clone(),
726 _ => format!("expr_{}", items.len() + 1), }
728 };
729
730 let item = match expr {
732 SqlExpression::Column(col_name) if alias == col_name => {
733 SelectItem::Column(col_name)
735 }
736 _ => {
737 SelectItem::Expression { expr, alias }
739 }
740 };
741
742 items.push(item);
743 }
744
745 if matches!(self.current_token, Token::Comma) {
747 self.advance();
748 } else {
749 break;
750 }
751 }
752
753 Ok(items)
754 }
755
756 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
757 let mut identifiers = Vec::new();
758
759 loop {
760 match &self.current_token {
761 Token::Identifier(id) => {
762 identifiers.push(id.clone());
763 self.advance();
764 }
765 Token::QuotedIdentifier(id) => {
766 identifiers.push(id.clone());
768 self.advance();
769 }
770 _ => return Err("Expected identifier".to_string()),
771 }
772
773 if matches!(self.current_token, Token::Comma) {
774 self.advance();
775 } else {
776 break;
777 }
778 }
779
780 Ok(identifiers)
781 }
782
783 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
784 let mut order_columns = Vec::new();
785
786 loop {
787 let column = match &self.current_token {
788 Token::Identifier(id) => {
789 let col = id.clone();
790 self.advance();
791 col
792 }
793 Token::QuotedIdentifier(id) => {
794 let col = id.clone();
795 self.advance();
796 col
797 }
798 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
799 let col = num.clone();
801 self.advance();
802 col
803 }
804 _ => return Err("Expected column name in ORDER BY".to_string()),
805 };
806
807 let direction = match &self.current_token {
809 Token::Asc => {
810 self.advance();
811 SortDirection::Asc
812 }
813 Token::Desc => {
814 self.advance();
815 SortDirection::Desc
816 }
817 _ => SortDirection::Asc, };
819
820 order_columns.push(OrderByColumn { column, direction });
821
822 if matches!(self.current_token, Token::Comma) {
823 self.advance();
824 } else {
825 break;
826 }
827 }
828
829 Ok(order_columns)
830 }
831
832 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
833 let mut conditions = Vec::new();
834
835 loop {
836 let expr = self.parse_expression()?;
837
838 let connector = match &self.current_token {
839 Token::And => {
840 self.advance();
841 Some(LogicalOp::And)
842 }
843 Token::Or => {
844 self.advance();
845 Some(LogicalOp::Or)
846 }
847 Token::RightParen if self.paren_depth <= 0 => {
848 return Err(
850 "Unexpected closing parenthesis - no matching opening parenthesis"
851 .to_string(),
852 );
853 }
854 _ => None,
855 };
856
857 conditions.push(Condition {
858 expr,
859 connector: connector.clone(),
860 });
861
862 if connector.is_none() {
863 break;
864 }
865 }
866
867 Ok(WhereClause { conditions })
868 }
869
870 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
871 let mut left = self.parse_comparison()?;
872
873 if let Some(op) = self.get_binary_op() {
876 self.advance();
877 let right = self.parse_expression()?;
878 left = SqlExpression::BinaryOp {
879 left: Box::new(left),
880 op,
881 right: Box::new(right),
882 };
883 }
884
885 if matches!(self.current_token, Token::In) {
887 self.advance();
888 self.consume(Token::LeftParen)?;
889 let values = self.parse_expression_list()?;
890 self.consume(Token::RightParen)?;
891
892 left = SqlExpression::InList {
893 expr: Box::new(left),
894 values,
895 };
896 }
897
898 Ok(left)
902 }
903
904 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
905 let mut left = self.parse_additive()?;
906
907 while matches!(self.current_token, Token::Dot) {
909 self.advance();
910 if let Token::Identifier(method) = &self.current_token {
911 let method_name = method.clone();
912 self.advance();
913
914 if matches!(self.current_token, Token::LeftParen) {
915 self.advance();
916 let args = self.parse_method_args()?;
917 self.consume(Token::RightParen)?;
918
919 match left {
921 SqlExpression::Column(obj) => {
922 left = SqlExpression::MethodCall {
924 object: obj,
925 method: method_name,
926 args,
927 };
928 }
929 SqlExpression::MethodCall { .. }
930 | SqlExpression::ChainedMethodCall { .. } => {
931 left = SqlExpression::ChainedMethodCall {
933 base: Box::new(left),
934 method: method_name,
935 args,
936 };
937 }
938 _ => {
939 return Err(format!("Cannot call method on {:?}", left));
941 }
942 }
943 } else {
944 break;
947 }
948 } else {
949 break;
950 }
951 }
952
953 if matches!(self.current_token, Token::Between) {
955 self.advance(); let lower = self.parse_primary()?;
957 self.consume(Token::And)?; let upper = self.parse_primary()?;
959
960 return Ok(SqlExpression::Between {
961 expr: Box::new(left),
962 lower: Box::new(lower),
963 upper: Box::new(upper),
964 });
965 }
966
967 if matches!(self.current_token, Token::Not) {
969 self.advance(); if matches!(self.current_token, Token::In) {
971 self.advance(); self.consume(Token::LeftParen)?;
973 let values = self.parse_expression_list()?;
974 self.consume(Token::RightParen)?;
975
976 return Ok(SqlExpression::NotInList {
977 expr: Box::new(left),
978 values,
979 });
980 } else {
981 return Err("Expected IN after NOT".to_string());
982 }
983 }
984
985 if let Some(op) = self.get_binary_op() {
987 self.advance();
988 let right = self.parse_additive()?;
989 left = SqlExpression::BinaryOp {
990 left: Box::new(left),
991 op,
992 right: Box::new(right),
993 };
994 }
995
996 Ok(left)
997 }
998
999 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1000 let mut left = self.parse_multiplicative()?;
1001
1002 while matches!(self.current_token, Token::Plus | Token::Minus) {
1003 let op = match self.current_token {
1004 Token::Plus => "+",
1005 Token::Minus => "-",
1006 _ => unreachable!(),
1007 };
1008 self.advance();
1009 let right = self.parse_multiplicative()?;
1010 left = SqlExpression::BinaryOp {
1011 left: Box::new(left),
1012 op: op.to_string(),
1013 right: Box::new(right),
1014 };
1015 }
1016
1017 Ok(left)
1018 }
1019
1020 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1021 let mut left = self.parse_primary()?;
1022
1023 while matches!(self.current_token, Token::Star | Token::Divide) {
1024 let op = match self.current_token {
1025 Token::Star => "*",
1026 Token::Divide => "/",
1027 _ => unreachable!(),
1028 };
1029 self.advance();
1030 let right = self.parse_primary()?;
1031 left = SqlExpression::BinaryOp {
1032 left: Box::new(left),
1033 op: op.to_string(),
1034 right: Box::new(right),
1035 };
1036 }
1037
1038 Ok(left)
1039 }
1040
1041 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1042 let mut left = self.parse_logical_and()?;
1043
1044 while matches!(self.current_token, Token::Or) {
1045 self.advance();
1046 let right = self.parse_logical_and()?;
1047 left = SqlExpression::BinaryOp {
1051 left: Box::new(left),
1052 op: "OR".to_string(),
1053 right: Box::new(right),
1054 };
1055 }
1056
1057 Ok(left)
1058 }
1059
1060 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1061 let mut left = self.parse_expression()?;
1062
1063 while matches!(self.current_token, Token::And) {
1064 self.advance();
1065 let right = self.parse_expression()?;
1066 left = SqlExpression::BinaryOp {
1068 left: Box::new(left),
1069 op: "AND".to_string(),
1070 right: Box::new(right),
1071 };
1072 }
1073
1074 Ok(left)
1075 }
1076
1077 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1078 if let Token::NumberLiteral(num_str) = &self.current_token {
1081 if self.columns.iter().any(|col| col == num_str) {
1083 let expr = SqlExpression::Column(num_str.clone());
1084 self.advance();
1085 return Ok(expr);
1086 }
1087 }
1088
1089 match &self.current_token {
1090 Token::DateTime => {
1091 self.advance(); self.consume(Token::LeftParen)?;
1093
1094 if matches!(&self.current_token, Token::RightParen) {
1096 self.advance(); return Ok(SqlExpression::DateTimeToday {
1098 hour: None,
1099 minute: None,
1100 second: None,
1101 });
1102 }
1103
1104 let year = if let Token::NumberLiteral(n) = &self.current_token {
1106 n.parse::<i32>().map_err(|_| "Invalid year")?
1107 } else {
1108 return Err("Expected year in DateTime constructor".to_string());
1109 };
1110 self.advance();
1111 self.consume(Token::Comma)?;
1112
1113 let month = if let Token::NumberLiteral(n) = &self.current_token {
1115 n.parse::<u32>().map_err(|_| "Invalid month")?
1116 } else {
1117 return Err("Expected month in DateTime constructor".to_string());
1118 };
1119 self.advance();
1120 self.consume(Token::Comma)?;
1121
1122 let day = if let Token::NumberLiteral(n) = &self.current_token {
1124 n.parse::<u32>().map_err(|_| "Invalid day")?
1125 } else {
1126 return Err("Expected day in DateTime constructor".to_string());
1127 };
1128 self.advance();
1129
1130 let mut hour = None;
1132 let mut minute = None;
1133 let mut second = None;
1134
1135 if matches!(&self.current_token, Token::Comma) {
1136 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1140 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1141 self.advance();
1142
1143 if matches!(&self.current_token, Token::Comma) {
1145 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1148 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1149 self.advance();
1150
1151 if matches!(&self.current_token, Token::Comma) {
1153 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1156 second =
1157 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1158 self.advance();
1159 }
1160 }
1161 }
1162 }
1163 }
1164 }
1165
1166 self.consume(Token::RightParen)?;
1167 Ok(SqlExpression::DateTimeConstructor {
1168 year,
1169 month,
1170 day,
1171 hour,
1172 minute,
1173 second,
1174 })
1175 }
1176 Token::Identifier(id) => {
1177 let id_upper = id.to_uppercase();
1178 let id_clone = id.clone();
1179 self.advance();
1180
1181 if matches!(self.current_token, Token::LeftParen) {
1183 if matches!(
1185 id_upper.as_str(),
1186 "ROUND"
1187 | "ABS"
1188 | "FLOOR"
1189 | "CEILING"
1190 | "CEIL"
1191 | "MOD"
1192 | "QUOTIENT"
1193 | "POWER"
1194 | "POW"
1195 | "SQRT"
1196 | "EXP"
1197 | "LN"
1198 | "LOG"
1199 | "LOG10"
1200 | "PI"
1201 | "TEXTJOIN"
1202 | "DATEDIFF"
1203 | "DATEADD"
1204 | "NOW"
1205 | "TODAY"
1206 ) {
1207 self.advance(); let args = self.parse_function_args()?;
1209 self.consume(Token::RightParen)?;
1210 return Ok(SqlExpression::FunctionCall {
1211 name: id_upper,
1212 args,
1213 });
1214 }
1215 }
1216
1217 Ok(SqlExpression::Column(id_clone))
1219 }
1220 Token::QuotedIdentifier(id) => {
1221 let expr = if self.in_method_args {
1224 SqlExpression::StringLiteral(id.clone())
1225 } else {
1226 SqlExpression::Column(id.clone())
1228 };
1229 self.advance();
1230 Ok(expr)
1231 }
1232 Token::StringLiteral(s) => {
1233 let expr = SqlExpression::StringLiteral(s.clone());
1234 self.advance();
1235 Ok(expr)
1236 }
1237 Token::NumberLiteral(n) => {
1238 let expr = SqlExpression::NumberLiteral(n.clone());
1239 self.advance();
1240 Ok(expr)
1241 }
1242 Token::LeftParen => {
1243 self.advance();
1244
1245 let expr = self.parse_logical_or()?;
1248
1249 self.consume(Token::RightParen)?;
1250 Ok(expr)
1251 }
1252 Token::Not => {
1253 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1257 if matches!(self.current_token, Token::In) {
1259 self.advance(); self.consume(Token::LeftParen)?;
1261 let values = self.parse_expression_list()?;
1262 self.consume(Token::RightParen)?;
1263
1264 return Ok(SqlExpression::NotInList {
1265 expr: Box::new(inner_expr),
1266 values,
1267 });
1268 } else {
1269 return Ok(SqlExpression::Not {
1271 expr: Box::new(inner_expr),
1272 });
1273 }
1274 } else {
1275 return Err("Expected expression after NOT".to_string());
1276 }
1277 }
1278 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1279 }
1280 }
1281
1282 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1283 let mut args = Vec::new();
1284
1285 self.in_method_args = true;
1287
1288 if !matches!(self.current_token, Token::RightParen) {
1289 loop {
1290 args.push(self.parse_expression()?);
1291
1292 if matches!(self.current_token, Token::Comma) {
1293 self.advance();
1294 } else {
1295 break;
1296 }
1297 }
1298 }
1299
1300 self.in_method_args = false;
1302
1303 Ok(args)
1304 }
1305
1306 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1307 let mut args = Vec::new();
1308
1309 if !matches!(self.current_token, Token::RightParen) {
1310 loop {
1311 args.push(self.parse_additive()?);
1313
1314 if matches!(self.current_token, Token::Comma) {
1315 self.advance();
1316 } else {
1317 break;
1318 }
1319 }
1320 }
1321
1322 Ok(args)
1323 }
1324
1325 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1326 let mut expressions = Vec::new();
1327
1328 loop {
1329 expressions.push(self.parse_expression()?);
1330
1331 if matches!(self.current_token, Token::Comma) {
1332 self.advance();
1333 } else {
1334 break;
1335 }
1336 }
1337
1338 Ok(expressions)
1339 }
1340
1341 fn get_binary_op(&self) -> Option<String> {
1342 match &self.current_token {
1343 Token::Equal => Some("=".to_string()),
1344 Token::NotEqual => Some("!=".to_string()),
1345 Token::LessThan => Some("<".to_string()),
1346 Token::GreaterThan => Some(">".to_string()),
1347 Token::LessThanOrEqual => Some("<=".to_string()),
1348 Token::GreaterThanOrEqual => Some(">=".to_string()),
1349 Token::Like => Some("LIKE".to_string()),
1350 _ => None,
1351 }
1352 }
1353
1354 fn get_arithmetic_op(&self) -> Option<String> {
1355 match &self.current_token {
1356 Token::Plus => Some("+".to_string()),
1357 Token::Minus => Some("-".to_string()),
1358 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1360 _ => None,
1361 }
1362 }
1363
1364 pub fn get_position(&self) -> usize {
1365 self.lexer.get_position()
1366 }
1367}
1368
1369#[derive(Debug, Clone)]
1371pub enum CursorContext {
1372 SelectClause,
1373 FromClause,
1374 WhereClause,
1375 OrderByClause,
1376 AfterColumn(String),
1377 AfterLogicalOp(LogicalOp),
1378 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
1381 Unknown,
1382}
1383
/// Returns the prefix of `s` ending at byte offset `pos`, without panicking.
///
/// If `pos` is at or past the end, the whole string is returned. If `pos`
/// falls inside a multi-byte character, the cut moves back to the start of
/// that character.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot come up empty.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
1398
/// Returns the suffix of `s` starting at byte offset `pos`, without panicking.
///
/// If `pos` is at or past the end, an empty string is returned. If `pos`
/// falls inside a multi-byte character, the start moves forward to the next
/// character boundary.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is always a boundary, so the search cannot come up empty.
    let start = (pos..=s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[start..]
}
1413
1414pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1415 let truncated = safe_slice_to(query, cursor_pos);
1416 let mut parser = Parser::new(truncated);
1417
1418 match parser.parse() {
1420 Ok(stmt) => {
1421 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1422 #[cfg(test)]
1423 println!(
1424 "analyze_statement returned: {:?}, {:?} for query: '{}'",
1425 ctx, partial, truncated
1426 );
1427 (ctx, partial)
1428 }
1429 Err(_) => {
1430 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1432 #[cfg(test)]
1433 println!(
1434 "analyze_partial returned: {:?}, {:?} for query: '{}'",
1435 ctx, partial, truncated
1436 );
1437 (ctx, partial)
1438 }
1439 }
1440}
1441
1442pub fn tokenize_query(query: &str) -> Vec<String> {
1443 let mut lexer = Lexer::new(query);
1444 let tokens = lexer.tokenize_all();
1445 tokens.iter().map(|t| format!("{:?}", t)).collect()
1446}
1447
1448pub fn format_sql_pretty(query: &str) -> Vec<String> {
1449 format_sql_pretty_compact(query, 5) }
1451
1452pub fn format_ast_tree(query: &str) -> String {
1454 let mut parser = Parser::new(query);
1455 match parser.parse() {
1456 Ok(stmt) => format_select_statement(&stmt, 0),
1457 Err(e) => format!("❌ PARSE ERROR ❌\n{}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax.", e),
1458 }
1459}
1460
1461fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
1462 let mut result = String::new();
1463 let indent_str = " ".repeat(indent);
1464
1465 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
1466
1467 result.push_str(&format!("{indent_str} columns: ["));
1469 if !stmt.columns.is_empty() {
1470 result.push('\n');
1471 for col in &stmt.columns {
1472 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1473 }
1474 result.push_str(&format!("{indent_str} ],\n"));
1475 } else {
1476 result.push_str("],\n");
1477 }
1478
1479 if let Some(table) = &stmt.from_table {
1481 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
1482 }
1483
1484 if let Some(where_clause) = &stmt.where_clause {
1486 result.push_str(&format!("{indent_str} where_clause: {{\n"));
1487 result.push_str(&format_where_clause(where_clause, indent + 2));
1488 result.push_str(&format!("{indent_str} }},\n"));
1489 }
1490
1491 if let Some(order_by) = &stmt.order_by {
1493 result.push_str(&format!("{indent_str} order_by: ["));
1494 if !order_by.is_empty() {
1495 result.push('\n');
1496 for col in order_by {
1497 let dir = match col.direction {
1498 SortDirection::Asc => "ASC",
1499 SortDirection::Desc => "DESC",
1500 };
1501 result.push_str(&format!(
1502 "{indent_str} \"{col}\" {dir},\n",
1503 col = col.column
1504 ));
1505 }
1506 result.push_str(&format!("{indent_str} ],\n"));
1507 } else {
1508 result.push_str("],\n");
1509 }
1510 }
1511
1512 if let Some(group_by) = &stmt.group_by {
1514 result.push_str(&format!("{indent_str} group_by: ["));
1515 if !group_by.is_empty() {
1516 result.push('\n');
1517 for col in group_by {
1518 result.push_str(&format!("{indent_str} \"{col}\",\n"));
1519 }
1520 result.push_str(&format!("{indent_str} ],\n"));
1521 } else {
1522 result.push_str("]\n");
1523 }
1524 }
1525
1526 result.push_str(&format!("{indent_str}}}"));
1527 result
1528}
1529
1530fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
1531 let mut result = String::new();
1532 let indent_str = " ".repeat(indent);
1533
1534 result.push_str(&format!("{indent_str}conditions: [\n"));
1535
1536 for condition in &clause.conditions {
1537 result.push_str(&format!("{indent_str} {{\n"));
1538 result.push_str(&format!(
1539 "{indent_str} expr: {},\n",
1540 format_expression_ast(&condition.expr)
1541 ));
1542
1543 if let Some(connector) = &condition.connector {
1544 let connector_str = match connector {
1545 LogicalOp::And => "AND",
1546 LogicalOp::Or => "OR",
1547 };
1548 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
1549 }
1550
1551 result.push_str(&format!("{indent_str} }},\n"));
1552 }
1553
1554 result.push_str(&format!("{indent_str}]\n"));
1555 result
1556}
1557
1558fn format_expression_ast(expr: &SqlExpression) -> String {
1559 match expr {
1560 SqlExpression::Column(name) => format!("Column(\"{}\")", name),
1561 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{}\")", value),
1562 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({})", value),
1563 SqlExpression::DateTimeConstructor {
1564 year,
1565 month,
1566 day,
1567 hour,
1568 minute,
1569 second,
1570 } => {
1571 format!(
1572 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
1573 year,
1574 month,
1575 day,
1576 hour.unwrap_or(0),
1577 minute.unwrap_or(0),
1578 second.unwrap_or(0)
1579 )
1580 }
1581 SqlExpression::DateTimeToday {
1582 hour,
1583 minute,
1584 second,
1585 } => {
1586 format!(
1587 "DateTimeToday({:02}:{:02}:{:02})",
1588 hour.unwrap_or(0),
1589 minute.unwrap_or(0),
1590 second.unwrap_or(0)
1591 )
1592 }
1593 SqlExpression::MethodCall {
1594 object,
1595 method,
1596 args,
1597 } => {
1598 let args_str = args
1599 .iter()
1600 .map(|a| format_expression_ast(a))
1601 .collect::<Vec<_>>()
1602 .join(", ");
1603 format!("MethodCall({}.{}({}))", object, method, args_str)
1604 }
1605 SqlExpression::ChainedMethodCall { base, method, args } => {
1606 let args_str = args
1607 .iter()
1608 .map(|a| format_expression_ast(a))
1609 .collect::<Vec<_>>()
1610 .join(", ");
1611 format!(
1612 "ChainedMethodCall({}.{}({}))",
1613 format_expression_ast(base),
1614 method,
1615 args_str
1616 )
1617 }
1618 SqlExpression::FunctionCall { name, args } => {
1619 let args_str = args
1620 .iter()
1621 .map(|a| format_expression_ast(a))
1622 .collect::<Vec<_>>()
1623 .join(", ");
1624 format!("FunctionCall({}({}))", name, args_str)
1625 }
1626 SqlExpression::BinaryOp { left, op, right } => {
1627 format!(
1628 "BinaryOp({} {} {})",
1629 format_expression_ast(left),
1630 op,
1631 format_expression_ast(right)
1632 )
1633 }
1634 SqlExpression::InList { expr, values } => {
1635 let list_str = values
1636 .iter()
1637 .map(|e| format_expression_ast(e))
1638 .collect::<Vec<_>>()
1639 .join(", ");
1640 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
1641 }
1642 SqlExpression::NotInList { expr, values } => {
1643 let list_str = values
1644 .iter()
1645 .map(|e| format_expression_ast(e))
1646 .collect::<Vec<_>>()
1647 .join(", ");
1648 format!(
1649 "NotInList({} NOT IN [{}])",
1650 format_expression_ast(expr),
1651 list_str
1652 )
1653 }
1654 SqlExpression::Between { expr, lower, upper } => {
1655 format!(
1656 "Between({} BETWEEN {} AND {})",
1657 format_expression_ast(expr),
1658 format_expression_ast(lower),
1659 format_expression_ast(upper)
1660 )
1661 }
1662 SqlExpression::Not { expr } => {
1663 format!("Not({})", format_expression_ast(expr))
1664 }
1665 }
1666}
1667
1668pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
1670 match expr {
1671 SqlExpression::DateTimeConstructor {
1672 year,
1673 month,
1674 day,
1675 hour,
1676 minute,
1677 second,
1678 } => {
1679 let h = hour.unwrap_or(0);
1680 let m = minute.unwrap_or(0);
1681 let s = second.unwrap_or(0);
1682
1683 if let Ok(dt) = NaiveDateTime::parse_from_str(
1685 &format!(
1686 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1687 year, month, day, h, m, s
1688 ),
1689 "%Y-%m-%d %H:%M:%S",
1690 ) {
1691 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1692 } else {
1693 None
1694 }
1695 }
1696 SqlExpression::DateTimeToday {
1697 hour,
1698 minute,
1699 second,
1700 } => {
1701 let now = Local::now();
1702 let h = hour.unwrap_or(0);
1703 let m = minute.unwrap_or(0);
1704 let s = second.unwrap_or(0);
1705
1706 if let Ok(dt) = NaiveDateTime::parse_from_str(
1708 &format!(
1709 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
1710 now.year(),
1711 now.month(),
1712 now.day(),
1713 h,
1714 m,
1715 s
1716 ),
1717 "%Y-%m-%d %H:%M:%S",
1718 ) {
1719 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
1720 } else {
1721 None
1722 }
1723 }
1724 _ => None,
1725 }
1726}
1727
1728fn format_sql_with_preserved_parens(
1730 query: &str,
1731 cols_per_line: usize,
1732) -> Result<Vec<String>, String> {
1733 let mut lines = Vec::new();
1734 let mut lexer = Lexer::new(query);
1735 let tokens_with_pos = lexer.tokenize_all_with_positions();
1736
1737 if tokens_with_pos.is_empty() {
1738 return Err("No tokens found".to_string());
1739 }
1740
1741 let mut i = 0;
1742 let cols_per_line = cols_per_line.max(1);
1743
1744 while i < tokens_with_pos.len() {
1745 let (start, _end, ref token) = tokens_with_pos[i];
1746
1747 match token {
1748 Token::Select => {
1749 lines.push("SELECT".to_string());
1750 i += 1;
1751
1752 let mut columns = Vec::new();
1754 let mut col_start = i;
1755 while i < tokens_with_pos.len() {
1756 match &tokens_with_pos[i].2 {
1757 Token::From | Token::Eof => break,
1758 Token::Comma => {
1759 if col_start < i {
1761 let col_text = extract_text_between_positions(
1762 query,
1763 tokens_with_pos[col_start].0,
1764 tokens_with_pos[i - 1].1,
1765 );
1766 columns.push(col_text);
1767 }
1768 i += 1;
1769 col_start = i;
1770 }
1771 _ => i += 1,
1772 }
1773 }
1774 if col_start < i && i > 0 {
1776 let col_text = extract_text_between_positions(
1777 query,
1778 tokens_with_pos[col_start].0,
1779 tokens_with_pos[i - 1].1,
1780 );
1781 columns.push(col_text);
1782 }
1783
1784 for chunk in columns.chunks(cols_per_line) {
1786 let mut line = " ".to_string();
1787 for (idx, col) in chunk.iter().enumerate() {
1788 if idx > 0 {
1789 line.push_str(", ");
1790 }
1791 line.push_str(col.trim());
1792 }
1793 let is_last_chunk = chunk.as_ptr() as usize
1795 + chunk.len() * std::mem::size_of::<String>()
1796 >= columns.last().map(|c| c as *const _ as usize).unwrap_or(0);
1797 if !is_last_chunk && columns.len() > cols_per_line {
1798 line.push(',');
1799 }
1800 lines.push(line);
1801 }
1802 }
1803 Token::From => {
1804 i += 1;
1805 if i < tokens_with_pos.len() {
1806 let table_start = tokens_with_pos[i].0;
1807 while i < tokens_with_pos.len() {
1809 match &tokens_with_pos[i].2 {
1810 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
1811 _ => i += 1,
1812 }
1813 }
1814 if i > 0 {
1815 let table_text = extract_text_between_positions(
1816 query,
1817 table_start,
1818 tokens_with_pos[i - 1].1,
1819 );
1820 lines.push(format!("FROM {}", table_text.trim()));
1821 }
1822 }
1823 }
1824 Token::Where => {
1825 lines.push("WHERE".to_string());
1826 i += 1;
1827
1828 let where_start = if i < tokens_with_pos.len() {
1830 tokens_with_pos[i].0
1831 } else {
1832 start
1833 };
1834
1835 let mut where_end = query.len();
1837 while i < tokens_with_pos.len() {
1838 match &tokens_with_pos[i].2 {
1839 Token::OrderBy | Token::GroupBy | Token::Eof => {
1840 if i > 0 {
1841 where_end = tokens_with_pos[i - 1].1;
1842 }
1843 break;
1844 }
1845 _ => i += 1,
1846 }
1847 }
1848
1849 let where_text = extract_text_between_positions(query, where_start, where_end);
1851
1852 let formatted_where = format_where_clause_with_parens(&where_text);
1854 for line in formatted_where {
1855 lines.push(format!(" {}", line));
1856 }
1857 }
1858 Token::OrderBy => {
1859 i += 1;
1860 let order_start = if i < tokens_with_pos.len() {
1861 tokens_with_pos[i].0
1862 } else {
1863 start
1864 };
1865
1866 while i < tokens_with_pos.len() {
1868 match &tokens_with_pos[i].2 {
1869 Token::GroupBy | Token::Eof => break,
1870 _ => i += 1,
1871 }
1872 }
1873
1874 if i > 0 {
1875 let order_text = extract_text_between_positions(
1876 query,
1877 order_start,
1878 tokens_with_pos[i - 1].1,
1879 );
1880 lines.push(format!("ORDER BY {}", order_text.trim()));
1881 }
1882 }
1883 Token::GroupBy => {
1884 i += 1;
1885 let group_start = if i < tokens_with_pos.len() {
1886 tokens_with_pos[i].0
1887 } else {
1888 start
1889 };
1890
1891 while i < tokens_with_pos.len() {
1893 match &tokens_with_pos[i].2 {
1894 Token::Having | Token::Eof => break,
1895 _ => i += 1,
1896 }
1897 }
1898
1899 if i > 0 {
1900 let group_text = extract_text_between_positions(
1901 query,
1902 group_start,
1903 tokens_with_pos[i - 1].1,
1904 );
1905 lines.push(format!("GROUP BY {}", group_text.trim()));
1906 }
1907 }
1908 _ => i += 1,
1909 }
1910 }
1911
1912 Ok(lines)
1913}
1914
/// Returns the characters of `query` in the half-open range `start..end`,
/// where both bounds are *character* (not byte) offsets.
///
/// Out-of-range bounds are clamped to the end of the text, and an inverted
/// range (`start >= end`) yields an empty string instead of panicking. Callers
/// can produce such a range when a clause keyword is immediately followed by
/// the next clause.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
1922
/// Splits the text of a WHERE clause into display lines at top-level
/// ` AND ` / ` OR ` connectors (matched case-insensitively).
///
/// Each connector is emitted on its own line as `AND` / `OR`; the conditions
/// between them are emitted trimmed. Connectors are left alone when they are
///
/// * nested inside parentheses, or
/// * inside a single- or double-quoted literal (so `name = 'a AND b'` stays
///   on one line).
///
/// A stray closing parenthesis no longer pushes the nesting depth below zero,
/// so it cannot disable splitting for the rest of the clause.
///
/// If nothing was emitted (e.g. blank input) the trimmed input is returned as
/// the single line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// True when `chars` begins with `keyword`, ignoring ASCII case.
    fn starts_with_ignore_case(chars: &[char], keyword: &str) -> bool {
        let mut remaining = chars.iter();
        keyword
            .chars()
            .all(|k| remaining.next().map_or(false, |c| c.eq_ignore_ascii_case(&k)))
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    // The quote character of the literal we are currently inside, if any.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];

        // Inside a quoted literal everything is copied verbatim.
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            current_line.push(ch);
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            let connector = if starts_with_ignore_case(&chars[i..], " AND ") {
                Some(("AND", 5))
            } else if starts_with_ignore_case(&chars[i..], " OR ") {
                Some(("OR", 4))
            } else {
                None
            };

            if let Some((word, width)) = connector {
                if !current_line.trim().is_empty() {
                    lines.push(current_line.trim().to_string());
                }
                lines.push(word.to_string());
                current_line.clear();
                // Skip the connector together with its surrounding spaces.
                i += width;
                continue;
            }
        }

        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => paren_depth += 1,
            ')' => paren_depth = paren_depth.saturating_sub(1),
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    if !current_line.trim().is_empty() {
        lines.push(current_line.trim().to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
1988
1989pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
1990 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
1992 return lines;
1993 }
1994
1995 let mut lines = Vec::new();
1997 let mut parser = Parser::new(query);
1998
1999 let cols_per_line = cols_per_line.max(1);
2001
2002 match parser.parse() {
2003 Ok(stmt) => {
2004 if !stmt.columns.is_empty() {
2006 lines.push("SELECT".to_string());
2007
2008 for chunk in stmt.columns.chunks(cols_per_line) {
2010 let mut line = " ".to_string();
2011 for (i, col) in chunk.iter().enumerate() {
2012 if i > 0 {
2013 line.push_str(", ");
2014 }
2015 line.push_str(col);
2016 }
2017 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2019 let current_chunk_idx =
2020 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2021 if current_chunk_idx < last_chunk_idx {
2022 line.push(',');
2023 }
2024 lines.push(line);
2025 }
2026 }
2027
2028 if let Some(table) = &stmt.from_table {
2030 lines.push(format!("FROM {}", table));
2031 }
2032
2033 if let Some(where_clause) = &stmt.where_clause {
2035 lines.push("WHERE".to_string());
2036 for (i, condition) in where_clause.conditions.iter().enumerate() {
2037 if i > 0 {
2038 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2040 if let Some(connector) = &prev_condition.connector {
2041 match connector {
2042 LogicalOp::And => lines.push(" AND".to_string()),
2043 LogicalOp::Or => lines.push(" OR".to_string()),
2044 }
2045 }
2046 }
2047 }
2048 lines.push(format!(" {}", format_expression(&condition.expr)));
2049 }
2050 }
2051
2052 if let Some(order_by) = &stmt.order_by {
2054 let order_str = order_by
2055 .iter()
2056 .map(|col| {
2057 let dir = match col.direction {
2058 SortDirection::Asc => " ASC",
2059 SortDirection::Desc => " DESC",
2060 };
2061 format!("{}{}", col.column, dir)
2062 })
2063 .collect::<Vec<_>>()
2064 .join(", ");
2065 lines.push(format!("ORDER BY {}", order_str));
2066 }
2067
2068 if let Some(group_by) = &stmt.group_by {
2070 let group_str = group_by.join(", ");
2071 lines.push(format!("GROUP BY {}", group_str));
2072 }
2073 }
2074 Err(_) => {
2075 let mut lexer = Lexer::new(query);
2077 let tokens = lexer.tokenize_all();
2078 let mut current_line = String::new();
2079 let mut indent = 0;
2080
2081 for token in tokens {
2082 match &token {
2083 Token::Select
2084 | Token::From
2085 | Token::Where
2086 | Token::OrderBy
2087 | Token::GroupBy => {
2088 if !current_line.is_empty() {
2089 lines.push(current_line.trim().to_string());
2090 current_line.clear();
2091 }
2092 lines.push(format!("{:?}", token).to_uppercase());
2093 indent = 1;
2094 }
2095 Token::And | Token::Or => {
2096 if !current_line.is_empty() {
2097 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2098 current_line.clear();
2099 }
2100 lines.push(format!(" {:?}", token).to_uppercase());
2101 }
2102 Token::Comma => {
2103 current_line.push(',');
2104 if indent > 0 {
2105 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2106 current_line.clear();
2107 }
2108 }
2109 Token::Eof => break,
2110 _ => {
2111 if !current_line.is_empty() {
2112 current_line.push(' ');
2113 }
2114 current_line.push_str(&format_token(&token));
2115 }
2116 }
2117 }
2118
2119 if !current_line.is_empty() {
2120 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2121 }
2122 }
2123 }
2124
2125 lines
2126}
2127
2128fn format_expression(expr: &SqlExpression) -> String {
2129 match expr {
2130 SqlExpression::Column(name) => name.clone(),
2131 SqlExpression::StringLiteral(s) => format!("'{}'", s),
2132 SqlExpression::NumberLiteral(n) => n.clone(),
2133 SqlExpression::DateTimeConstructor {
2134 year,
2135 month,
2136 day,
2137 hour,
2138 minute,
2139 second,
2140 } => {
2141 let mut result = format!("DateTime({}, {}, {}", year, month, day);
2142 if let Some(h) = hour {
2143 result.push_str(&format!(", {}", h));
2144 if let Some(m) = minute {
2145 result.push_str(&format!(", {}", m));
2146 if let Some(s) = second {
2147 result.push_str(&format!(", {}", s));
2148 }
2149 }
2150 }
2151 result.push(')');
2152 result
2153 }
2154 SqlExpression::DateTimeToday {
2155 hour,
2156 minute,
2157 second,
2158 } => {
2159 let mut result = "DateTime()".to_string();
2160 if let Some(h) = hour {
2161 result = format!("DateTime(TODAY, {}", h);
2162 if let Some(m) = minute {
2163 result.push_str(&format!(", {}", m));
2164 if let Some(s) = second {
2165 result.push_str(&format!(", {}", s));
2166 }
2167 }
2168 result.push(')');
2169 }
2170 result
2171 }
2172 SqlExpression::MethodCall {
2173 object,
2174 method,
2175 args,
2176 } => {
2177 let args_str = args
2178 .iter()
2179 .map(|arg| format_expression(arg))
2180 .collect::<Vec<_>>()
2181 .join(", ");
2182 format!("{}.{}({})", object, method, args_str)
2183 }
2184 SqlExpression::BinaryOp { left, op, right } => {
2185 if op == "OR" || op == "AND" {
2188 format!(
2191 "({} {} {})",
2192 format_expression(left),
2193 op,
2194 format_expression(right)
2195 )
2196 } else {
2197 format!(
2198 "{} {} {}",
2199 format_expression(left),
2200 op,
2201 format_expression(right)
2202 )
2203 }
2204 }
2205 SqlExpression::InList { expr, values } => {
2206 let values_str = values
2207 .iter()
2208 .map(|v| format_expression(v))
2209 .collect::<Vec<_>>()
2210 .join(", ");
2211 format!("{} IN ({})", format_expression(expr), values_str)
2212 }
2213 SqlExpression::NotInList { expr, values } => {
2214 let values_str = values
2215 .iter()
2216 .map(|v| format_expression(v))
2217 .collect::<Vec<_>>()
2218 .join(", ");
2219 format!("{} NOT IN ({})", format_expression(expr), values_str)
2220 }
2221 SqlExpression::Between { expr, lower, upper } => {
2222 format!(
2223 "{} BETWEEN {} AND {}",
2224 format_expression(expr),
2225 format_expression(lower),
2226 format_expression(upper)
2227 )
2228 }
2229 SqlExpression::Not { expr } => {
2230 format!("NOT {}", format_expression(expr))
2231 }
2232 SqlExpression::ChainedMethodCall { base, method, args } => {
2233 let args_str = args
2234 .iter()
2235 .map(|arg| format_expression(arg))
2236 .collect::<Vec<_>>()
2237 .join(", ");
2238 format!("{}.{}({})", format_expression(base), method, args_str)
2239 }
2240 SqlExpression::FunctionCall { name, args } => {
2241 let args_str = args
2242 .iter()
2243 .map(|arg| format_expression(arg))
2244 .collect::<Vec<_>>()
2245 .join(", ");
2246 format!("{}({})", name, args_str)
2247 }
2248 }
2249}
2250
2251fn format_token(token: &Token) -> String {
2252 match token {
2253 Token::Identifier(s) => s.clone(),
2254 Token::QuotedIdentifier(s) => format!("\"{}\"", s),
2255 Token::StringLiteral(s) => format!("'{}'", s),
2256 Token::NumberLiteral(n) => n.clone(),
2257 Token::DateTime => "DateTime".to_string(),
2258 Token::LeftParen => "(".to_string(),
2259 Token::RightParen => ")".to_string(),
2260 Token::Comma => ",".to_string(),
2261 Token::Dot => ".".to_string(),
2262 Token::Equal => "=".to_string(),
2263 Token::NotEqual => "!=".to_string(),
2264 Token::LessThan => "<".to_string(),
2265 Token::GreaterThan => ">".to_string(),
2266 Token::LessThanOrEqual => "<=".to_string(),
2267 Token::GreaterThanOrEqual => ">=".to_string(),
2268 Token::In => "IN".to_string(),
2269 _ => format!("{:?}", token).to_uppercase(),
2270 }
2271}
2272
2273fn analyze_statement(
2274 stmt: &SelectStatement,
2275 query: &str,
2276 _cursor_pos: usize,
2277) -> (CursorContext, Option<String>) {
2278 let trimmed = query.trim();
2280
2281 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2283 for op in &comparison_ops {
2284 if let Some(op_pos) = query.rfind(op) {
2285 let before_op = safe_slice_to(query, op_pos);
2286 let after_op_start = op_pos + op.len();
2287 let after_op = if after_op_start < query.len() {
2288 &query[after_op_start..]
2289 } else {
2290 ""
2291 };
2292
2293 if let Some(col_name) = before_op.split_whitespace().last() {
2295 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2296 let after_op_trimmed = after_op.trim();
2298 if after_op_trimmed.is_empty()
2299 || (after_op_trimmed
2300 .chars()
2301 .all(|c| c.is_alphanumeric() || c == '_')
2302 && !after_op_trimmed.contains('('))
2303 {
2304 let partial = if after_op_trimmed.is_empty() {
2305 None
2306 } else {
2307 Some(after_op_trimmed.to_string())
2308 };
2309 return (
2310 CursorContext::AfterComparisonOp(
2311 col_name.to_string(),
2312 op.trim().to_string(),
2313 ),
2314 partial,
2315 );
2316 }
2317 }
2318 }
2319 }
2320 }
2321
2322 if trimmed.to_uppercase().ends_with(" AND")
2324 || trimmed.to_uppercase().ends_with(" OR")
2325 || trimmed.to_uppercase().ends_with(" AND ")
2326 || trimmed.to_uppercase().ends_with(" OR ")
2327 {
2328 } else {
2330 if let Some(dot_pos) = trimmed.rfind('.') {
2332 let before_dot = safe_slice_to(trimmed, dot_pos);
2334 let after_dot_start = dot_pos + 1;
2335 let after_dot = if after_dot_start < trimmed.len() {
2336 &trimmed[after_dot_start..]
2337 } else {
2338 ""
2339 };
2340
2341 if !after_dot.contains('(') {
2344 let col_name = if before_dot.ends_with('"') {
2346 let bytes = before_dot.as_bytes();
2348 let mut pos = before_dot.len() - 1; let mut found_start = None;
2350
2351 if pos > 0 {
2353 pos -= 1;
2354 while pos > 0 {
2355 if bytes[pos] == b'"' {
2356 if pos == 0 || bytes[pos - 1] != b'\\' {
2358 found_start = Some(pos);
2359 break;
2360 }
2361 }
2362 pos -= 1;
2363 }
2364 if found_start.is_none() && bytes[0] == b'"' {
2366 found_start = Some(0);
2367 }
2368 }
2369
2370 if let Some(start) = found_start {
2371 Some(safe_slice_from(before_dot, start))
2373 } else {
2374 None
2375 }
2376 } else {
2377 before_dot
2380 .split_whitespace()
2381 .last()
2382 .map(|word| word.trim_start_matches('('))
2383 };
2384
2385 if let Some(col_name) = col_name {
2386 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2388 true
2390 } else {
2391 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2393 };
2394
2395 if is_valid {
2396 let partial_method = if after_dot.is_empty() {
2399 None
2400 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2401 Some(after_dot.to_string())
2402 } else {
2403 None
2404 };
2405
2406 let col_name_for_context = if col_name.starts_with('"')
2408 && col_name.ends_with('"')
2409 && col_name.len() > 2
2410 {
2411 col_name[1..col_name.len() - 1].to_string()
2412 } else {
2413 col_name.to_string()
2414 };
2415
2416 return (
2417 CursorContext::AfterColumn(col_name_for_context),
2418 partial_method,
2419 );
2420 }
2421 }
2422 }
2423 }
2424 }
2425
2426 if let Some(where_clause) = &stmt.where_clause {
2428 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2430 let op = if trimmed.to_uppercase().ends_with(" AND") {
2431 LogicalOp::And
2432 } else {
2433 LogicalOp::Or
2434 };
2435 return (CursorContext::AfterLogicalOp(op), None);
2436 }
2437
2438 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
2440 let after_and = safe_slice_from(query, and_pos + 5);
2441 let partial = extract_partial_at_end(after_and);
2442 if partial.is_some() {
2443 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2444 }
2445 }
2446
2447 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
2448 let after_or = safe_slice_from(query, or_pos + 4);
2449 let partial = extract_partial_at_end(after_or);
2450 if partial.is_some() {
2451 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2452 }
2453 }
2454
2455 if let Some(last_condition) = where_clause.conditions.last() {
2456 if let Some(connector) = &last_condition.connector {
2457 return (
2459 CursorContext::AfterLogicalOp(connector.clone()),
2460 extract_partial_at_end(query),
2461 );
2462 }
2463 }
2464 return (CursorContext::WhereClause, extract_partial_at_end(query));
2466 }
2467
2468 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
2470 return (CursorContext::OrderByClause, None);
2471 }
2472
2473 if stmt.order_by.is_some() {
2475 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2476 }
2477
2478 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2479 return (CursorContext::FromClause, extract_partial_at_end(query));
2480 }
2481
2482 if stmt.columns.len() > 0 && stmt.from_table.is_none() {
2483 return (CursorContext::SelectClause, extract_partial_at_end(query));
2484 }
2485
2486 (CursorContext::Unknown, None)
2487}
2488
2489fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2490 let upper = query.to_uppercase();
2491
2492 let trimmed = query.trim();
2494
2495 #[cfg(test)]
2496 {
2497 if trimmed.contains("\"Last Name\"") {
2498 eprintln!(
2499 "DEBUG analyze_partial: query='{}', trimmed='{}'",
2500 query, trimmed
2501 );
2502 }
2503 }
2504
2505 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2507 for op in &comparison_ops {
2508 if let Some(op_pos) = query.rfind(op) {
2509 let before_op = safe_slice_to(query, op_pos);
2510 let after_op_start = op_pos + op.len();
2511 let after_op = if after_op_start < query.len() {
2512 &query[after_op_start..]
2513 } else {
2514 ""
2515 };
2516
2517 if let Some(col_name) = before_op.split_whitespace().last() {
2519 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2520 let after_op_trimmed = after_op.trim();
2522 if after_op_trimmed.is_empty()
2523 || (after_op_trimmed
2524 .chars()
2525 .all(|c| c.is_alphanumeric() || c == '_')
2526 && !after_op_trimmed.contains('('))
2527 {
2528 let partial = if after_op_trimmed.is_empty() {
2529 None
2530 } else {
2531 Some(after_op_trimmed.to_string())
2532 };
2533 return (
2534 CursorContext::AfterComparisonOp(
2535 col_name.to_string(),
2536 op.trim().to_string(),
2537 ),
2538 partial,
2539 );
2540 }
2541 }
2542 }
2543 }
2544 }
2545
2546 if let Some(dot_pos) = trimmed.rfind('.') {
2549 #[cfg(test)]
2550 {
2551 if trimmed.contains("\"Last Name\"") {
2552 eprintln!("DEBUG: Found dot at position {}", dot_pos);
2553 }
2554 }
2555 let before_dot = &trimmed[..dot_pos];
2557 let after_dot = &trimmed[dot_pos + 1..];
2558
2559 if !after_dot.contains('(') {
2562 let col_name = if before_dot.ends_with('"') {
2565 let bytes = before_dot.as_bytes();
2567 let mut pos = before_dot.len() - 1; let mut found_start = None;
2569
2570 #[cfg(test)]
2571 {
2572 if trimmed.contains("\"Last Name\"") {
2573 eprintln!(
2574 "DEBUG: before_dot='{}', looking for opening quote",
2575 before_dot
2576 );
2577 }
2578 }
2579
2580 if pos > 0 {
2582 pos -= 1;
2583 while pos > 0 {
2584 if bytes[pos] == b'"' {
2585 if pos == 0 || bytes[pos - 1] != b'\\' {
2587 found_start = Some(pos);
2588 break;
2589 }
2590 }
2591 pos -= 1;
2592 }
2593 if found_start.is_none() && bytes[0] == b'"' {
2595 found_start = Some(0);
2596 }
2597 }
2598
2599 if let Some(start) = found_start {
2600 let result = safe_slice_from(before_dot, start);
2602 #[cfg(test)]
2603 {
2604 if trimmed.contains("\"Last Name\"") {
2605 eprintln!("DEBUG: Extracted quoted identifier: '{}'", result);
2606 }
2607 }
2608 Some(result)
2609 } else {
2610 #[cfg(test)]
2611 {
2612 if trimmed.contains("\"Last Name\"") {
2613 eprintln!("DEBUG: No opening quote found!");
2614 }
2615 }
2616 None
2617 }
2618 } else {
2619 before_dot
2622 .split_whitespace()
2623 .last()
2624 .map(|word| word.trim_start_matches('('))
2625 };
2626
2627 if let Some(col_name) = col_name {
2628 #[cfg(test)]
2629 {
2630 if trimmed.contains("\"Last Name\"") {
2631 eprintln!("DEBUG: col_name = '{}'", col_name);
2632 }
2633 }
2634
2635 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
2637 true
2639 } else {
2640 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
2642 };
2643
2644 #[cfg(test)]
2645 {
2646 if trimmed.contains("\"Last Name\"") {
2647 eprintln!("DEBUG: is_valid = {}", is_valid);
2648 }
2649 }
2650
2651 if is_valid {
2652 let partial_method = if after_dot.is_empty() {
2655 None
2656 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2657 Some(after_dot.to_string())
2658 } else {
2659 None
2660 };
2661
2662 let col_name_for_context = if col_name.starts_with('"')
2664 && col_name.ends_with('"')
2665 && col_name.len() > 2
2666 {
2667 col_name[1..col_name.len() - 1].to_string()
2668 } else {
2669 col_name.to_string()
2670 };
2671
2672 return (
2673 CursorContext::AfterColumn(col_name_for_context),
2674 partial_method,
2675 );
2676 }
2677 }
2678 }
2679 }
2680
2681 if let Some(and_pos) = upper.rfind(" AND ") {
2683 if cursor_pos >= and_pos + 5 {
2685 let after_and = safe_slice_from(query, and_pos + 5);
2687 let partial = extract_partial_at_end(after_and);
2688 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2689 }
2690 }
2691
2692 if let Some(or_pos) = upper.rfind(" OR ") {
2693 if cursor_pos >= or_pos + 4 {
2695 let after_or = safe_slice_from(query, or_pos + 4);
2697 let partial = extract_partial_at_end(after_or);
2698 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2699 }
2700 }
2701
2702 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
2704 let op = if trimmed.to_uppercase().ends_with(" AND") {
2705 LogicalOp::And
2706 } else {
2707 LogicalOp::Or
2708 };
2709 return (CursorContext::AfterLogicalOp(op), None);
2710 }
2711
2712 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
2714 {
2715 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2716 }
2717
2718 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
2719 return (CursorContext::WhereClause, extract_partial_at_end(query));
2720 }
2721
2722 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
2723 return (CursorContext::FromClause, extract_partial_at_end(query));
2724 }
2725
2726 if upper.contains("SELECT") && !upper.contains("FROM") {
2727 return (CursorContext::SelectClause, extract_partial_at_end(query));
2728 }
2729
2730 (CursorContext::Unknown, None)
2731}
2732
2733fn extract_partial_at_end(query: &str) -> Option<String> {
2734 let trimmed = query.trim();
2735
2736 if let Some(last_word) = trimmed.split_whitespace().last() {
2738 if last_word.starts_with('"') && !last_word.ends_with('"') {
2739 return Some(last_word.to_string());
2741 }
2742 }
2743
2744 let last_word = trimmed.split_whitespace().last()?;
2746
2747 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
2749 Some(last_word.to_string())
2750 } else {
2751 None
2752 }
2753}
2754
/// Reports whether `word` is one of the SQL keywords that should never be
/// offered as a partial identifier. The comparison is case-insensitive.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 12] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.contains(&upper.as_str())
}
2772
2773#[cfg(test)]
2774mod tests {
2775 use super::*;
2776
    /// Checks that WHERE clauses containing chained method calls are accepted
    /// by the parser.
    #[test]
    fn test_chained_method_calls() {
        // Two chained calls followed by a comparison.
        let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
        let mut parser = Parser::new(query);
        let result = parser.parse();

        assert!(
            result.is_ok(),
            "Failed to parse chained method calls: {:?}",
            result
        );

        // Three chained calls, with no comparison operator afterwards.
        let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
        let mut parser2 = Parser::new(query2);
        let result2 = parser2.parse();

        assert!(
            result2.is_ok(),
            "Failed to parse multiple chained calls: {:?}",
            result2
        );
    }
2801
    /// Checks the token sequence the lexer produces for a simple
    /// SELECT ... FROM ... WHERE query, one `next_token` call at a time.
    #[test]
    fn test_tokenizer() {
        let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");

        assert!(matches!(lexer.next_token(), Token::Select));
        assert!(matches!(lexer.next_token(), Token::Star));
        assert!(matches!(lexer.next_token(), Token::From));
        // The underscore is part of the identifier.
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
    }
2815
    /// Checks that `DateTime(...)` lexes as a dedicated `Token::DateTime`
    /// followed by ordinary parenthesis, number and comma tokens.
    #[test]
    fn test_tokenizer_datetime() {
        let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");

        assert!(matches!(lexer.next_token(), Token::Where));
        assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
        assert!(matches!(lexer.next_token(), Token::GreaterThan));
        // `DateTime` is its own token kind rather than an identifier.
        assert!(matches!(lexer.next_token(), Token::DateTime));
        assert!(matches!(lexer.next_token(), Token::LeftParen));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
        assert!(matches!(lexer.next_token(), Token::Comma));
        assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
        assert!(matches!(lexer.next_token(), Token::RightParen));
    }
2832
2833 #[test]
2834 fn test_parse_simple_select() {
2835 let mut parser = Parser::new("SELECT * FROM trade_deal");
2836 let stmt = parser.parse().unwrap();
2837
2838 assert_eq!(stmt.columns, vec!["*"]);
2839 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
2840 assert!(stmt.where_clause.is_none());
2841 }
2842
2843 #[test]
2844 fn test_parse_where_with_method() {
2845 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
2846 let stmt = parser.parse().unwrap();
2847
2848 assert!(stmt.where_clause.is_some());
2849 let where_clause = stmt.where_clause.unwrap();
2850 assert_eq!(where_clause.conditions.len(), 1);
2851 }
2852
2853 #[test]
2854 fn test_parse_datetime_constructor() {
2855 let mut parser =
2856 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
2857 let stmt = parser.parse().unwrap();
2858
2859 assert!(stmt.where_clause.is_some());
2860 let where_clause = stmt.where_clause.unwrap();
2861 assert_eq!(where_clause.conditions.len(), 1);
2862
2863 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
2865 assert_eq!(op, ">");
2866 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
2867 assert!(matches!(
2868 right.as_ref(),
2869 SqlExpression::DateTimeConstructor {
2870 year: 2025,
2871 month: 10,
2872 day: 20,
2873 hour: None,
2874 minute: None,
2875 second: None
2876 }
2877 ));
2878 } else {
2879 panic!("Expected BinaryOp with DateTime constructor");
2880 }
2881 }
2882
2883 #[test]
2884 fn test_cursor_context_after_and() {
2885 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
2886 let (context, partial) = detect_cursor_context(query, query.len());
2887
2888 assert!(matches!(
2889 context,
2890 CursorContext::AfterLogicalOp(LogicalOp::And)
2891 ));
2892 assert_eq!(partial, None);
2893 }
2894
2895 #[test]
2896 fn test_cursor_context_with_partial() {
2897 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
2898 let (context, partial) = detect_cursor_context(query, query.len());
2899
2900 assert!(matches!(
2901 context,
2902 CursorContext::AfterLogicalOp(LogicalOp::And)
2903 ));
2904 assert_eq!(partial, Some("p".to_string()));
2905 }
2906
2907 #[test]
2908 fn test_cursor_context_after_datetime_comparison() {
2909 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
2910 let (context, partial) = detect_cursor_context(query, query.len());
2911
2912 assert!(
2913 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
2914 );
2915 assert_eq!(partial, None);
2916 }
2917
2918 #[test]
2919 fn test_cursor_context_partial_datetime() {
2920 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
2921 let (context, partial) = detect_cursor_context(query, query.len());
2922
2923 assert!(
2924 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
2925 );
2926 assert_eq!(partial, Some("Date".to_string()));
2927 }
2928
2929 #[test]
2931 fn test_tokenizer_quoted_identifier() {
2932 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
2933
2934 assert!(matches!(lexer.next_token(), Token::Select));
2935 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
2936 assert!(matches!(lexer.next_token(), Token::Comma));
2937 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
2938 assert!(matches!(lexer.next_token(), Token::From));
2939 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
2940 }
2941
2942 #[test]
2943 fn test_tokenizer_quoted_vs_string_literal() {
2944 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
2946
2947 assert!(matches!(lexer.next_token(), Token::Where));
2948 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
2949 assert!(matches!(lexer.next_token(), Token::Equal));
2950 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
2951 assert!(matches!(lexer.next_token(), Token::And));
2952 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
2953 assert!(matches!(lexer.next_token(), Token::Dot));
2954 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
2955 assert!(matches!(lexer.next_token(), Token::LeftParen));
2956 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
2957 assert!(matches!(lexer.next_token(), Token::RightParen));
2958 }
2959
2960 #[test]
2961 fn test_tokenizer_method_with_double_quotes_should_be_string() {
2962 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
2965
2966 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
2967 assert!(matches!(lexer.next_token(), Token::Dot));
2968 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
2969 assert!(matches!(lexer.next_token(), Token::LeftParen));
2970
2971 let token = lexer.next_token();
2974 println!("Token for \"Alb\": {:?}", token);
2975 assert!(matches!(lexer.next_token(), Token::RightParen));
2979 }
2980
2981 #[test]
2982 fn test_parse_select_with_quoted_columns() {
2983 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
2984 let stmt = parser.parse().unwrap();
2985
2986 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
2987 assert_eq!(stmt.from_table, Some("customers".to_string()));
2988 }
2989
2990 #[test]
2991 fn test_cursor_context_select_with_partial_quoted() {
2992 let query = r#"SELECT "Cust"#;
2994 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {:?}, Partial: {:?}", context, partial);
2997 assert!(matches!(context, CursorContext::SelectClause));
2998 }
3001
3002 #[test]
3003 fn test_cursor_context_select_after_comma_with_quoted() {
3004 let query = r#"SELECT Company, "Customer "#;
3006 let (context, partial) = detect_cursor_context(query, query.len());
3007
3008 println!("Context: {:?}, Partial: {:?}", context, partial);
3009 assert!(matches!(context, CursorContext::SelectClause));
3010 }
3012
3013 #[test]
3014 fn test_cursor_context_order_by_quoted() {
3015 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3016 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3017
3018 println!("Context: {:?}, Partial: {:?}", context, partial);
3019 assert!(matches!(context, CursorContext::OrderByClause));
3020 }
3022
3023 #[test]
3024 fn test_where_clause_with_quoted_column() {
3025 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3026 let stmt = parser.parse().unwrap();
3027
3028 assert!(stmt.where_clause.is_some());
3029 let where_clause = stmt.where_clause.unwrap();
3030 assert_eq!(where_clause.conditions.len(), 1);
3031
3032 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3033 assert_eq!(op, "=");
3034 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3035 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3036 } else {
3037 panic!("Expected BinaryOp");
3038 }
3039 }
3040
3041 #[test]
3042 fn test_parse_method_with_double_quotes_as_string() {
3043 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3045 let stmt = parser.parse().unwrap();
3046
3047 assert!(stmt.where_clause.is_some());
3048 let where_clause = stmt.where_clause.unwrap();
3049 assert_eq!(where_clause.conditions.len(), 1);
3050
3051 if let SqlExpression::MethodCall {
3052 object,
3053 method,
3054 args,
3055 } = &where_clause.conditions[0].expr
3056 {
3057 assert_eq!(object, "Country");
3058 assert_eq!(method, "Contains");
3059 assert_eq!(args.len(), 1);
3060 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3062 } else {
3063 panic!("Expected MethodCall");
3064 }
3065 }
3066
3067 #[test]
3068 fn test_extract_partial_with_quoted_columns_in_query() {
3069 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3071 let (context, partial) = detect_cursor_context(query, query.len());
3072
3073 assert!(matches!(context, CursorContext::OrderByClause));
3074 assert_eq!(
3075 partial,
3076 Some("coun".to_string()),
3077 "Should extract 'coun' as partial, not everything after the quoted column"
3078 );
3079 }
3080
3081 #[test]
3082 fn test_extract_partial_quoted_identifier_being_typed() {
3083 let query = r#"SELECT "Cust"#;
3085 let partial = extract_partial_at_end(query);
3086 assert_eq!(partial, Some("\"Cust".to_string()));
3087
3088 let query2 = r#"SELECT "Customer Id" FROM"#;
3090 let partial2 = extract_partial_at_end(query2);
3091 assert_eq!(partial2, None); }
3093
3094 #[test]
3096 fn test_complex_where_parentheses_basic() {
3097 let mut parser =
3099 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3100 let stmt = parser.parse().unwrap();
3101
3102 assert!(stmt.where_clause.is_some());
3103 let where_clause = stmt.where_clause.unwrap();
3104 assert_eq!(where_clause.conditions.len(), 1);
3105
3106 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3108 assert_eq!(op, "OR");
3109 } else {
3110 panic!("Expected BinaryOp with OR");
3111 }
3112 }
3113
3114 #[test]
3115 fn test_complex_where_mixed_and_or_with_parens() {
3116 let mut parser = Parser::new(
3118 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3119 );
3120 let stmt = parser.parse().unwrap();
3121
3122 assert!(stmt.where_clause.is_some());
3123 let where_clause = stmt.where_clause.unwrap();
3124 assert_eq!(where_clause.conditions.len(), 2);
3125
3126 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3128 assert_eq!(op, "OR");
3129 } else {
3130 panic!("Expected first condition to be OR expression");
3131 }
3132
3133 assert!(matches!(
3135 where_clause.conditions[0].connector,
3136 Some(LogicalOp::And)
3137 ));
3138
3139 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3141 assert_eq!(op, ">");
3142 } else {
3143 panic!("Expected second condition to be price > 100");
3144 }
3145 }
3146
3147 #[test]
3148 fn test_complex_where_nested_parentheses() {
3149 let mut parser = Parser::new(
3151 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3152 );
3153 let stmt = parser.parse().unwrap();
3154
3155 assert!(stmt.where_clause.is_some());
3156 let where_clause = stmt.where_clause.unwrap();
3157
3158 assert!(where_clause.conditions.len() > 0);
3160 }
3161
3162 #[test]
3163 fn test_complex_where_multiple_or_groups() {
3164 let mut parser = Parser::new(
3166 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3167 );
3168 let stmt = parser.parse().unwrap();
3169
3170 assert!(stmt.where_clause.is_some());
3171 let where_clause = stmt.where_clause.unwrap();
3172 assert_eq!(where_clause.conditions.len(), 2);
3173
3174 assert!(matches!(
3176 where_clause.conditions[0].connector,
3177 Some(LogicalOp::And)
3178 ));
3179 }
3180
3181 #[test]
3182 fn test_complex_where_with_methods_in_parens() {
3183 let mut parser = Parser::new(
3185 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3186 );
3187 let stmt = parser.parse().unwrap();
3188
3189 assert!(stmt.where_clause.is_some());
3190 let where_clause = stmt.where_clause.unwrap();
3191 assert_eq!(where_clause.conditions.len(), 2);
3192
3193 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3195 assert_eq!(op, "OR");
3196 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3197 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3198 } else {
3199 panic!("Expected OR of method calls");
3200 }
3201 }
3202
3203 #[test]
3204 fn test_complex_where_date_comparisons_with_parens() {
3205 let mut parser = Parser::new(
3207 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3208 );
3209 let stmt = parser.parse().unwrap();
3210
3211 assert!(stmt.where_clause.is_some());
3212 let where_clause = stmt.where_clause.unwrap();
3213 assert_eq!(where_clause.conditions.len(), 2);
3214
3215 assert!(matches!(
3217 where_clause.conditions[0].connector,
3218 Some(LogicalOp::And)
3219 ));
3220 }
3221
3222 #[test]
3223 fn test_complex_where_price_volume_filters() {
3224 let mut parser = Parser::new(
3226 r#"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000"#,
3227 );
3228 let stmt = parser.parse().unwrap();
3229
3230 assert!(stmt.where_clause.is_some());
3231 let where_clause = stmt.where_clause.unwrap();
3232
3233 assert!(where_clause.conditions.len() > 0);
3235 }
3236
3237 #[test]
3238 fn test_complex_where_mixed_string_numeric() {
3239 let mut parser = Parser::new(
3241 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3242 );
3243 let stmt = parser.parse().unwrap();
3244
3245 assert!(stmt.where_clause.is_some());
3246 }
3248
3249 #[test]
3250 fn test_complex_where_triple_nested() {
3251 let mut parser = Parser::new(
3253 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3254 );
3255 let stmt = parser.parse().unwrap();
3256
3257 assert!(stmt.where_clause.is_some());
3258 }
3260
3261 #[test]
3262 fn test_complex_where_single_parens_around_and() {
3263 let mut parser = Parser::new(
3265 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3266 );
3267 let stmt = parser.parse().unwrap();
3268
3269 assert!(stmt.where_clause.is_some());
3270 let where_clause = stmt.where_clause.unwrap();
3271
3272 assert!(where_clause.conditions.len() > 0);
3274 }
3275
3276 #[test]
3278 fn test_format_preserves_simple_parentheses() {
3279 let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
3280 let formatted = format_sql_pretty_compact(query, 5);
3281 let formatted_text = formatted.join(" ");
3282
3283 assert!(formatted_text.contains("(status"));
3285 assert!(formatted_text.contains("\"pending\")"));
3286
3287 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3289 let formatted_parens = formatted_text
3290 .chars()
3291 .filter(|c| *c == '(' || *c == ')')
3292 .count();
3293 assert_eq!(
3294 original_parens, formatted_parens,
3295 "Parentheses should be preserved"
3296 );
3297 }
3298
3299 #[test]
3300 fn test_format_preserves_complex_parentheses() {
3301 let query =
3302 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3303 let formatted = format_sql_pretty_compact(query, 5);
3304 let formatted_text = formatted.join(" ");
3305
3306 assert!(formatted_text.contains("(symbol"));
3308 assert!(formatted_text.contains("\"GOOGL\")"));
3309
3310 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3312 let formatted_parens = formatted_text
3313 .chars()
3314 .filter(|c| *c == '(' || *c == ')')
3315 .count();
3316 assert_eq!(original_parens, formatted_parens);
3317 }
3318
3319 #[test]
3320 fn test_format_preserves_nested_parentheses() {
3321 let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
3322 let formatted = format_sql_pretty_compact(query, 5);
3323 let formatted_text = formatted.join(" ");
3324
3325 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3327 let formatted_parens = formatted_text
3328 .chars()
3329 .filter(|c| *c == '(' || *c == ')')
3330 .count();
3331 assert_eq!(
3332 original_parens, formatted_parens,
3333 "Nested parentheses should be preserved"
3334 );
3335 assert_eq!(original_parens, 4, "Should have 4 parentheses total");
3336 }
3337
3338 #[test]
3339 fn test_format_preserves_method_calls_in_parentheses() {
3340 let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
3341 let formatted = format_sql_pretty_compact(query, 5);
3342 let formatted_text = formatted.join(" ");
3343
3344 assert!(formatted_text.contains("(symbol.StartsWith"));
3346 assert!(formatted_text.contains("StartsWith(\"A\")"));
3347 assert!(formatted_text.contains("StartsWith(\"G\")"));
3348
3349 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3351 let formatted_parens = formatted_text
3352 .chars()
3353 .filter(|c| *c == '(' || *c == ')')
3354 .count();
3355 assert_eq!(original_parens, formatted_parens);
3356 assert_eq!(
3357 original_parens, 6,
3358 "Should have 6 parentheses (1 group + 2 method calls)"
3359 );
3360 }
3361
3362 #[test]
3363 fn test_format_preserves_multiple_groups() {
3364 let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
3365 let formatted = format_sql_pretty_compact(query, 5);
3366 let formatted_text = formatted.join(" ");
3367
3368 assert!(formatted_text.contains("(symbol"));
3370 assert!(formatted_text.contains("(price"));
3371
3372 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3373 let formatted_parens = formatted_text
3374 .chars()
3375 .filter(|c| *c == '(' || *c == ')')
3376 .count();
3377 assert_eq!(original_parens, formatted_parens);
3378 assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
3379 }
3380
3381 #[test]
3382 fn test_format_preserves_date_ranges() {
3383 let query = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))"#;
3384 let formatted = format_sql_pretty_compact(query, 5);
3385 let formatted_text = formatted.join(" ");
3386
3387 assert!(formatted_text.contains("(executionDate"));
3389 assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
3390 assert!(formatted_text.contains("DateTime(2024, 12, 31)"));
3391
3392 let original_parens = query.chars().filter(|c| *c == '(' || *c == ')').count();
3393 let formatted_parens = formatted_text
3394 .chars()
3395 .filter(|c| *c == '(' || *c == ')')
3396 .count();
3397 assert_eq!(original_parens, formatted_parens);
3398 }
3399
3400 #[test]
3401 fn test_format_multiline_layout() {
3402 let query =
3404 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
3405 let formatted = format_sql_pretty_compact(query, 5);
3406
3407 assert!(formatted.len() >= 4, "Should have multiple lines");
3409 assert_eq!(formatted[0], "SELECT");
3410 assert!(formatted[1].trim().starts_with("*"));
3411 assert!(formatted[2].starts_with("FROM"));
3412 assert_eq!(formatted[3], "WHERE");
3413
3414 let where_lines: Vec<_> = formatted.iter().skip(4).collect();
3416 assert!(where_lines.iter().any(|l| l.contains("(symbol")));
3417 assert!(where_lines.iter().any(|l| l.trim() == "AND"));
3418 }
3419
3420 #[test]
3421 fn test_between_simple() {
3422 let mut parser = Parser::new("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3423 let stmt = parser.parse().expect("Should parse simple BETWEEN");
3424
3425 assert!(stmt.where_clause.is_some());
3426 let where_clause = stmt.where_clause.unwrap();
3427 assert_eq!(where_clause.conditions.len(), 1);
3428
3429 let ast = format_ast_tree("SELECT * FROM table WHERE price BETWEEN 50 AND 100");
3431 assert!(!ast.contains("PARSE ERROR"));
3432 assert!(ast.contains("SelectStatement"));
3433 }
3434
3435 #[test]
3436 fn test_between_in_parentheses() {
3437 let mut parser = Parser::new("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3438 let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");
3439
3440 assert!(stmt.where_clause.is_some());
3441
3442 let ast = format_ast_tree("SELECT * FROM table WHERE (price BETWEEN 50 AND 100)");
3444 assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
3445 }
3446
3447 #[test]
3448 fn test_between_with_or() {
3449 let query = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
3450 let mut parser = Parser::new(query);
3451 let stmt = parser.parse().expect("Should parse BETWEEN with OR");
3452
3453 assert!(stmt.where_clause.is_some());
3454 }
3457
3458 #[test]
3459 fn test_between_with_and() {
3460 let query = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
3461 let mut parser = Parser::new(query);
3462 let stmt = parser.parse().expect("Should parse BETWEEN with AND");
3463
3464 assert!(stmt.where_clause.is_some());
3465 let where_clause = stmt.where_clause.unwrap();
3466 assert_eq!(where_clause.conditions.len(), 2); }
3468
3469 #[test]
3470 fn test_multiple_between() {
3471 let query =
3472 "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
3473 let mut parser = Parser::new(query);
3474 let stmt = parser
3475 .parse()
3476 .expect("Should parse multiple BETWEEN clauses");
3477
3478 assert!(stmt.where_clause.is_some());
3479 }
3480
3481 #[test]
3482 fn test_between_complex_query() {
3483 let query = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
3485 let mut parser = Parser::new(query);
3486 let stmt = parser
3487 .parse()
3488 .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");
3489
3490 assert!(stmt.where_clause.is_some());
3491 assert!(stmt.order_by.is_some());
3492
3493 let order_by = stmt.order_by.unwrap();
3494 assert_eq!(order_by.len(), 2);
3495 assert_eq!(order_by[0].column, "Category");
3496 assert!(matches!(order_by[0].direction, SortDirection::Asc));
3497 assert_eq!(order_by[1].column, "price");
3498 assert!(matches!(order_by[1].direction, SortDirection::Desc));
3499 }
3500
3501 #[test]
3502 fn test_between_formatting() {
3503 let expr = SqlExpression::Between {
3504 expr: Box::new(SqlExpression::Column("price".to_string())),
3505 lower: Box::new(SqlExpression::NumberLiteral("50".to_string())),
3506 upper: Box::new(SqlExpression::NumberLiteral("100".to_string())),
3507 };
3508
3509 let formatted = format_expression(&expr);
3510 assert_eq!(formatted, "price BETWEEN 50 AND 100");
3511
3512 let ast_formatted = format_expression_ast(&expr);
3513 assert!(ast_formatted.contains("Between"));
3514 assert!(ast_formatted.contains("50"));
3515 assert!(ast_formatted.contains("100"));
3516 }
3517
3518 #[test]
3519 fn test_utf8_boundary_safety() {
3520 let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";
3522
3523 for pos in 0..query_with_unicode.len() + 1 {
3525 let result =
3527 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos));
3528
3529 assert!(
3530 result.is_ok(),
3531 "Panic at position {} in query with Unicode",
3532 pos
3533 );
3534 }
3535
3536 let result = std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, 1000));
3538 assert!(result.is_ok(), "Panic with position beyond string length");
3539
3540 let pos_in_e = query_with_unicode.find('é').unwrap() + 1; let result =
3543 std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, pos_in_e));
3544 assert!(
3545 result.is_ok(),
3546 "Panic with cursor in middle of UTF-8 character"
3547 );
3548 }
3549}