1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer`].
///
/// Keyword variants are matched case-insensitively. `OrderBy` and `GroupBy`
/// are emitted as a single token when `ORDER`/`GROUP` is followed by `BY`.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords
    Select,
    From,
    Where,
    With,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    OrderBy,
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    Over,
    Partition,
    By,
    Join,
    Inner,
    Left,
    Right,
    Full,
    Outer,
    On,
    Cross,

    // Names and literals (payload is the text without surrounding quotes)
    Identifier(String),
    QuotedIdentifier(String),
    StringLiteral(String),
    NumberLiteral(String),
    Star,

    // Punctuation and comparison operators
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // Arithmetic operators (`*` is reported as `Star`)
    Plus,
    Minus,
    Divide,
    Modulo,

    // End of input
    Eof,
}

/// Character-based tokenizer for the SQL dialect understood by the parser.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// Input split into chars so positions are char offsets, not byte offsets.
    input: Vec<char>,
    /// Index of `current_char` within `input`.
    position: usize,
    /// Character at `position`, or `None` once the input is exhausted.
    current_char: Option<char>,
}

impl Lexer {
    /// Creates a lexer positioned on the first character of `input`.
    #[must_use]
    pub fn new(input: &str) -> Self {
        let chars: Vec<char> = input.chars().collect();
        let current = chars.first().copied();
        Self {
            input: chars,
            position: 0,
            current_char: current,
        }
    }

    /// Moves one character forward.
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character `offset` positions ahead without consuming it.
    fn peek(&self, offset: usize) -> Option<char> {
        self.input.get(self.position + offset).copied()
    }

    /// Consumes `token`'s single remaining character and returns the token.
    fn emit(&mut self, token: Token) -> Token {
        self.advance();
        token
    }

    /// Skips Unicode whitespace.
    fn skip_whitespace(&mut self) {
        while self.current_char.is_some_and(char::is_whitespace) {
            self.advance();
        }
    }

    /// Skips whitespace, `-- line` comments and `/* block */` comments.
    ///
    /// An unterminated block comment swallows the rest of the input.
    fn skip_whitespace_and_comments(&mut self) {
        loop {
            self.skip_whitespace();

            match (self.current_char, self.peek(1)) {
                (Some('-'), Some('-')) => {
                    self.advance();
                    self.advance();
                    // Consume up to and including the newline.
                    while let Some(ch) = self.current_char {
                        self.advance();
                        if ch == '\n' {
                            break;
                        }
                    }
                }
                (Some('/'), Some('*')) => {
                    self.advance();
                    self.advance();
                    while let Some(ch) = self.current_char {
                        if ch == '*' && self.peek(1) == Some('/') {
                            self.advance();
                            self.advance();
                            break;
                        }
                        self.advance();
                    }
                }
                _ => break,
            }
        }
    }

    /// Reads a run of alphanumeric characters and underscores.
    fn read_identifier(&mut self) -> String {
        let mut result = String::new();
        while let Some(ch) = self.current_char {
            if ch.is_alphanumeric() || ch == '_' {
                result.push(ch);
                self.advance();
            } else {
                break;
            }
        }
        result
    }

    /// Reads a quoted run starting at the current quote character and returns
    /// its contents without the delimiters.
    ///
    /// A doubled delimiter (`''` inside `'...'`, `""` inside `"..."`) is the
    /// SQL escape for one literal quote. An unterminated string runs to the
    /// end of the input.
    fn read_string(&mut self) -> String {
        let mut result = String::new();
        let Some(quote_char) = self.current_char else {
            return result;
        };
        self.advance();

        while let Some(ch) = self.current_char {
            if ch == quote_char {
                if self.peek(1) == Some(quote_char) {
                    // Escaped quote: keep one, skip both.
                    result.push(quote_char);
                    self.advance();
                    self.advance();
                    continue;
                }
                self.advance();
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads an unsigned numeric literal: ASCII digits, at most one decimal
    /// point, and an optional exponent (`e`/`E`, optional sign, digits).
    ///
    /// The exponent marker is only consumed when digits actually follow, so
    /// the returned text never ends in a dangling `e` or sign.
    fn read_number(&mut self) -> String {
        let mut result = String::new();
        let mut seen_dot = false;

        while let Some(ch) = self.current_char {
            if ch.is_ascii_digit() {
                result.push(ch);
                self.advance();
            } else if ch == '.' && !seen_dot {
                seen_dot = true;
                result.push(ch);
                self.advance();
            } else if (ch == 'e' || ch == 'E') && !result.is_empty() {
                let exponent_follows = match self.peek(1) {
                    Some(d) if d.is_ascii_digit() => true,
                    Some('+' | '-') => self.peek(2).is_some_and(|d| d.is_ascii_digit()),
                    _ => false,
                };
                if !exponent_follows {
                    break;
                }

                result.push(ch);
                self.advance();
                if let Some(sign @ ('+' | '-')) = self.current_char {
                    result.push(sign);
                    self.advance();
                }
                while let Some(digit) = self.current_char {
                    if !digit.is_ascii_digit() {
                        break;
                    }
                    result.push(digit);
                    self.advance();
                }
                // Nothing may follow the exponent.
                break;
            } else {
                break;
            }
        }
        result
    }

    /// Returns the next token, or [`Token::Eof`] at end of input.
    ///
    /// Unrecognised characters are returned as one-character
    /// [`Token::Identifier`]s so the lexer always makes progress.
    ///
    /// NOTE(review): a `-` directly followed by a digit is always lexed as a
    /// negative literal, so `x-1` yields `Identifier, NumberLiteral("-1")`
    /// rather than a `Minus`. Fixing that needs previous-token context.
    pub fn next_token(&mut self) -> Token {
        // After this call the cursor is never on whitespace or a comment start.
        self.skip_whitespace_and_comments();

        let Some(ch) = self.current_char else {
            return Token::Eof;
        };

        match ch {
            '*' => self.emit(Token::Star),
            '+' => self.emit(Token::Plus),
            '/' => self.emit(Token::Divide),
            '%' => self.emit(Token::Modulo),
            '.' => self.emit(Token::Dot),
            ',' => self.emit(Token::Comma),
            '(' => self.emit(Token::LeftParen),
            ')' => self.emit(Token::RightParen),
            '=' => self.emit(Token::Equal),
            '<' => {
                self.advance();
                match self.current_char {
                    Some('=') => self.emit(Token::LessThanOrEqual),
                    Some('>') => self.emit(Token::NotEqual),
                    _ => Token::LessThan,
                }
            }
            '>' => {
                self.advance();
                match self.current_char {
                    Some('=') => self.emit(Token::GreaterThanOrEqual),
                    _ => Token::GreaterThan,
                }
            }
            '!' if self.peek(1) == Some('=') => {
                self.advance();
                self.emit(Token::NotEqual)
            }
            '"' => Token::QuotedIdentifier(self.read_string()),
            '\'' => Token::StringLiteral(self.read_string()),
            '-' if self.peek(1).is_some_and(|c| c.is_ascii_digit()) => {
                self.advance();
                let digits = self.read_number();
                Token::NumberLiteral(format!("-{digits}"))
            }
            '-' => self.emit(Token::Minus),
            c if c.is_ascii_digit() => Token::NumberLiteral(self.read_number()),
            c if c.is_alphabetic() || c == '_' => {
                let ident = self.read_identifier();
                self.keyword_or_identifier(ident)
            }
            other => self.emit(Token::Identifier(other.to_string())),
        }
    }

    /// Maps a bare word to its keyword token, or wraps it as an identifier.
    ///
    /// `ORDER`/`GROUP` become `OrderBy`/`GroupBy` only when the next word is
    /// `BY`; that `BY` is consumed as part of the combined token.
    fn keyword_or_identifier(&mut self, ident: String) -> Token {
        match ident.to_uppercase().as_str() {
            "SELECT" => Token::Select,
            "FROM" => Token::From,
            "WHERE" => Token::Where,
            "WITH" => Token::With,
            "AND" => Token::And,
            "OR" => Token::Or,
            "IN" => Token::In,
            "NOT" => Token::Not,
            "BETWEEN" => Token::Between,
            "LIKE" => Token::Like,
            "IS" => Token::Is,
            "NULL" => Token::Null,
            "ORDER" if self.peek_keyword("BY") => {
                self.consume_following_word();
                Token::OrderBy
            }
            "GROUP" if self.peek_keyword("BY") => {
                self.consume_following_word();
                Token::GroupBy
            }
            "HAVING" => Token::Having,
            "AS" => Token::As,
            "ASC" => Token::Asc,
            "DESC" => Token::Desc,
            "LIMIT" => Token::Limit,
            "OFFSET" => Token::Offset,
            "DATETIME" => Token::DateTime,
            "CASE" => Token::Case,
            "WHEN" => Token::When,
            "THEN" => Token::Then,
            "ELSE" => Token::Else,
            "END" => Token::End,
            "DISTINCT" => Token::Distinct,
            "OVER" => Token::Over,
            "PARTITION" => Token::Partition,
            "BY" => Token::By,
            "JOIN" => Token::Join,
            "INNER" => Token::Inner,
            "LEFT" => Token::Left,
            "RIGHT" => Token::Right,
            "FULL" => Token::Full,
            "OUTER" => Token::Outer,
            "ON" => Token::On,
            "CROSS" => Token::Cross,
            _ => Token::Identifier(ident),
        }
    }

    /// Discards the next word. It must skip exactly what `peek_keyword` skips
    /// (whitespace and comments); otherwise a comment between the two words
    /// would leave the second keyword in the stream.
    fn consume_following_word(&mut self) {
        self.skip_whitespace_and_comments();
        self.read_identifier();
    }

    /// Reports whether the next word (after whitespace and comments) equals
    /// `keyword`, which must be upper case. The cursor is restored afterwards.
    fn peek_keyword(&mut self, keyword: &str) -> bool {
        let saved_pos = self.position;
        let saved_char = self.current_char;

        self.skip_whitespace_and_comments();
        let next_word = self.read_identifier();
        let matches = next_word.to_uppercase() == keyword;

        self.position = saved_pos;
        self.current_char = saved_char;

        matches
    }

    /// Returns the current char offset into the input.
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.position
    }

    /// Tokenizes the remaining input; the result always ends with `Eof`.
    pub fn tokenize_all(&mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let finished = matches!(token, Token::Eof);
            tokens.push(token);
            if finished {
                break;
            }
        }
        tokens
    }

    /// Tokenizes the remaining input as `(start, end, token)` triples, where
    /// the offsets are char positions and `end` is exclusive. `Eof` is not
    /// included.
    pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
        let mut tokens = Vec::new();
        loop {
            // Skip trivia first so `start_pos` points at the token itself.
            self.skip_whitespace_and_comments();
            let start_pos = self.position;
            let token = self.next_token();
            let end_pos = self.position;

            if matches!(token, Token::Eof) {
                break;
            }
            tokens.push((start_pos, end_pos, token));
        }
        tokens
    }
}
443
/// Expression node of the parsed SQL AST.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// Column reference by name.
    Column(String),
    /// String literal contents.
    StringLiteral(String),
    /// Numeric literal kept as its source text.
    NumberLiteral(String),
    /// Boolean literal.
    BooleanLiteral(bool),
    /// The `NULL` literal; also used as the right operand of `IS [NOT] NULL`.
    Null,
    /// Date/time built from explicit components; the time parts are optional.
    // NOTE(review): presumably produced for a DateTime(...) call - the
    // constructing code is not in this part of the file.
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// Date/time with only optional time components.
    // NOTE(review): presumably "today" plus an optional time - confirm.
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// `object.method(args)` where `object` is a column name.
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `base.method(args)` where `base` is any non-column expression,
    /// including a previous method call.
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// `name(args)` function call.
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
        // NOTE(review): presumably set for `name(DISTINCT ...)` - confirm.
        distinct: bool,
    },
    /// Function call carrying a window specification.
    WindowFunction {
        name: String,
        args: Vec<SqlExpression>,
        window_spec: WindowSpec,
    },
    /// Binary operation; `op` is the operator text. The parser in this file
    /// emits at least "+", "-", "*", "/", "%", "OR", "IS NULL" and
    /// "IS NOT NULL" (the last two with `Null` as `right`).
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// `expr IN (values...)`.
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr NOT IN (values...)`.
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// `expr BETWEEN lower AND upper`.
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// Negation of the wrapped expression.
    Not {
        expr: Box<SqlExpression>,
    },
    /// `CASE WHEN ... THEN ... [ELSE ...] END`.
    CaseExpression {
        when_branches: Vec<WhenBranch>,
        else_branch: Option<Box<SqlExpression>>,
    },
}
511
/// One `WHEN condition THEN result` arm of a [`SqlExpression::CaseExpression`].
#[derive(Debug, Clone)]
pub struct WhenBranch {
    /// Expression tested by the arm.
    pub condition: Box<SqlExpression>,
    /// Value of the arm.
    pub result: Box<SqlExpression>,
}
517
/// Parsed `WHERE` clause: a flat sequence of conditions joined by AND/OR.
#[derive(Debug, Clone)]
pub struct WhereClause {
    /// Conditions in source order; each records the connector that follows it.
    pub conditions: Vec<Condition>,
}
522
/// One entry of a [`WhereClause`].
#[derive(Debug, Clone)]
pub struct Condition {
    /// The condition expression.
    pub expr: SqlExpression,
    /// Connector that follows this condition; `None` on the last condition.
    pub connector: Option<LogicalOp>,
}
528
/// Logical connector between two WHERE conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    And,
    Or,
}
534
/// Sort direction of an ORDER BY column. The parser defaults to `Asc` when
/// no direction keyword is given.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    Asc,
    Desc,
}
540
/// One ORDER BY entry.
#[derive(Debug, Clone)]
pub struct OrderByColumn {
    /// Column name. May be the text of a numeric literal when a known column
    /// has exactly that name.
    pub column: String,
    /// Requested sort direction.
    pub direction: SortDirection,
}
546
/// Window specification: optional `PARTITION BY` and `ORDER BY` lists.
/// Either list is empty when the corresponding clause is absent.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// Column names from `PARTITION BY`.
    pub partition_by: Vec<String>,
    /// Entries from `ORDER BY`.
    pub order_by: Vec<OrderByColumn>,
}
552
/// One item of a SELECT list.
#[derive(Debug, Clone)]
pub enum SelectItem {
    /// Plain column reference whose output name equals the column name.
    Column(String),
    /// Any other expression (or a renamed column). `alias` is the explicit
    /// `AS` name, or a generated `expr_N` name when none was given.
    Expression { expr: SqlExpression, alias: String },
    /// `*`.
    Star,
}
563
/// Parsed `SELECT` statement.
///
/// At most one of `from_table`, `from_subquery` and `from_function` is set;
/// all three are `None` when there is no FROM clause.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// `true` when `SELECT DISTINCT` was used.
    pub distinct: bool,
    /// Output names derived from `select_items`: "*", the column name, or
    /// the expression alias.
    pub columns: Vec<String>,
    /// Full SELECT list.
    pub select_items: Vec<SelectItem>,
    /// Table name when FROM names a table.
    pub from_table: Option<String>,
    /// Subquery when FROM is a parenthesised SELECT.
    pub from_subquery: Option<Box<SelectStatement>>,
    /// Table function when FROM is a function call such as `RANGE(...)`.
    pub from_function: Option<TableFunction>,
    /// Alias of the FROM source (always present for subqueries).
    pub from_alias: Option<String>,
    /// JOIN clauses in source order.
    pub joins: Vec<JoinClause>,
    pub where_clause: Option<WhereClause>,
    pub order_by: Option<Vec<OrderByColumn>>,
    pub group_by: Option<Vec<String>>,
    /// HAVING expression; the parser rejects HAVING without GROUP BY.
    pub having: Option<SqlExpression>,
    pub limit: Option<usize>,
    pub offset: Option<usize>,
    /// Common table expressions from a leading `WITH`; empty otherwise.
    pub ctes: Vec<CTE>,
}
582
/// Table-valued function usable in a FROM clause.
#[derive(Debug, Clone)]
pub enum TableFunction {
    /// `RANGE(start, end[, step])`.
    Range {
        start: SqlExpression,
        end: SqlExpression,
        step: Option<SqlExpression>,
    },
}
592
/// Common table expression: `name [(columns)] AS (query)`.
#[derive(Debug, Clone)]
pub struct CTE {
    /// Name the CTE is referenced by.
    pub name: String,
    /// Optional explicit column list given in parentheses after the name.
    pub column_list: Option<Vec<String>>,
    /// The CTE body.
    pub query: SelectStatement,
}
600
/// Source of rows for a JOIN.
#[derive(Debug, Clone)]
pub enum TableSource {
    /// Table referenced by name.
    Table(String),
    /// Subquery together with its alias.
    DerivedTable {
        query: Box<SelectStatement>,
        alias: String,
    },
}
611
/// Kind of JOIN.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}
621
/// Comparison operator used in a [`JoinCondition`].
#[derive(Debug, Clone)]
pub enum JoinOperator {
    Equal,
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,
}
632
/// Join condition of the form `left_column <operator> right_column`.
// NOTE(review): which side each column belongs to is decided by the join
// parser, which is not visible in this part of the file.
#[derive(Debug, Clone)]
pub struct JoinCondition {
    pub left_column: String,
    pub operator: JoinOperator,
    pub right_column: String,
}
640
/// One JOIN clause of a SELECT statement.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Kind of join.
    pub join_type: JoinType,
    /// Joined table or derived table.
    pub table: TableSource,
    /// Optional alias of the joined source.
    pub alias: Option<String>,
    /// Join condition.
    pub condition: JoinCondition,
}
649
/// Options for [`Parser`]. `Default` yields `case_insensitive: false`.
#[derive(Default)]
pub struct ParserConfig {
    // NOTE(review): stored by the parser but not read anywhere in the visible
    // code - confirm what it is meant to control.
    pub case_insensitive: bool,
}
654
/// Recursive-descent parser that turns SQL text into a [`SelectStatement`].
pub struct Parser {
    /// Token source.
    lexer: Lexer,
    /// One-token lookahead: the token currently being examined.
    current_token: Token,
    // NOTE(review): initialised to `false`; its use is outside the visible code.
    in_method_args: bool,
    /// Known column names; lets ORDER BY accept columns whose name is numeric.
    columns: Vec<String>,
    /// Count of parentheses opened and not yet closed; negative means a
    /// surplus `)` was seen.
    paren_depth: i32,
    #[allow(dead_code)]
    config: ParserConfig,
}
664
665impl Parser {
666 #[must_use]
667 pub fn new(input: &str) -> Self {
668 let mut lexer = Lexer::new(input);
669 let current_token = lexer.next_token();
670 Self {
671 lexer,
672 current_token,
673 in_method_args: false,
674 columns: Vec::new(),
675 paren_depth: 0,
676 config: ParserConfig::default(),
677 }
678 }
679
680 #[must_use]
681 pub fn with_config(input: &str, config: ParserConfig) -> Self {
682 let mut lexer = Lexer::new(input);
683 let current_token = lexer.next_token();
684 Self {
685 lexer,
686 current_token,
687 in_method_args: false,
688 columns: Vec::new(),
689 paren_depth: 0,
690 config,
691 }
692 }
693
694 #[must_use]
695 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
696 self.columns = columns;
697 self
698 }
699
700 fn consume(&mut self, expected: Token) -> Result<(), String> {
701 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
702 match &expected {
704 Token::LeftParen => self.paren_depth += 1,
705 Token::RightParen => {
706 self.paren_depth -= 1;
707 if self.paren_depth < 0 {
709 return Err(
710 "Unexpected closing parenthesis - no matching opening parenthesis"
711 .to_string(),
712 );
713 }
714 }
715 _ => {}
716 }
717
718 self.current_token = self.lexer.next_token();
719 Ok(())
720 } else {
721 let error_msg = match (&expected, &self.current_token) {
723 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
724 format!(
725 "Unclosed parenthesis - missing {} closing parenthes{}",
726 self.paren_depth,
727 if self.paren_depth == 1 { "is" } else { "es" }
728 )
729 }
730 (Token::RightParen, _) if self.paren_depth > 0 => {
731 format!(
732 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
733 self.current_token,
734 self.paren_depth,
735 if self.paren_depth == 1 { "is" } else { "es" }
736 )
737 }
738 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
739 };
740 Err(error_msg)
741 }
742 }
743
744 fn advance(&mut self) {
745 match &self.current_token {
747 Token::LeftParen => self.paren_depth += 1,
748 Token::RightParen => {
749 self.paren_depth -= 1;
750 }
753 _ => {}
754 }
755 self.current_token = self.lexer.next_token();
756 }
757
758 pub fn parse(&mut self) -> Result<SelectStatement, String> {
759 if matches!(self.current_token, Token::With) {
761 self.parse_with_clause()
762 } else {
763 self.parse_select_statement()
764 }
765 }
766
767 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
768 self.consume(Token::With)?;
769
770 let mut ctes = Vec::new();
771
772 loop {
774 let name = match &self.current_token {
776 Token::Identifier(name) => name.clone(),
777 _ => return Err("Expected CTE name after WITH".to_string()),
778 };
779 self.advance();
780
781 let column_list = if matches!(self.current_token, Token::LeftParen) {
783 self.advance();
784 let cols = self.parse_identifier_list()?;
785 self.consume(Token::RightParen)?;
786 Some(cols)
787 } else {
788 None
789 };
790
791 self.consume(Token::As)?;
793
794 self.consume(Token::LeftParen)?;
796
797 let query = self.parse_select_statement_inner()?;
799
800 self.consume(Token::RightParen)?;
802
803 ctes.push(CTE {
804 name,
805 column_list,
806 query,
807 });
808
809 if !matches!(self.current_token, Token::Comma) {
811 break;
812 }
813 self.advance();
814 }
815
816 let mut main_query = self.parse_select_statement()?;
818 main_query.ctes = ctes;
819
820 Ok(main_query)
821 }
822
823 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
824 let result = self.parse_select_statement_inner()?;
825
826 if self.paren_depth > 0 {
828 return Err(format!(
829 "Unclosed parenthesis - missing {} closing parenthes{}",
830 self.paren_depth,
831 if self.paren_depth == 1 { "is" } else { "es" }
832 ));
833 } else if self.paren_depth < 0 {
834 return Err(
835 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
836 );
837 }
838
839 Ok(result)
840 }
841
842 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
843 self.consume(Token::Select)?;
844
845 let distinct = if matches!(self.current_token, Token::Distinct) {
847 self.advance();
848 true
849 } else {
850 false
851 };
852
853 let select_items = self.parse_select_items()?;
855
856 let columns = select_items
858 .iter()
859 .map(|item| match item {
860 SelectItem::Star => "*".to_string(),
861 SelectItem::Column(name) => name.clone(),
862 SelectItem::Expression { alias, .. } => alias.clone(),
863 })
864 .collect();
865
866 let (from_table, from_subquery, from_function, from_alias) =
868 if matches!(self.current_token, Token::From) {
869 self.advance();
870
871 if let Token::Identifier(name) = &self.current_token.clone() {
873 if name.to_uppercase() == "RANGE" {
874 self.advance();
875 self.consume(Token::LeftParen)?;
877
878 let start = self.parse_expression()?;
880 self.consume(Token::Comma)?;
881
882 let end = self.parse_expression()?;
884
885 let step = if matches!(self.current_token, Token::Comma) {
887 self.advance();
888 Some(self.parse_expression()?)
889 } else {
890 None
891 };
892
893 self.consume(Token::RightParen)?;
894
895 let alias = if matches!(self.current_token, Token::As) {
897 self.advance();
898 match &self.current_token {
899 Token::Identifier(name) => {
900 let alias = name.clone();
901 self.advance();
902 Some(alias)
903 }
904 _ => return Err("Expected alias name after AS".to_string()),
905 }
906 } else if let Token::Identifier(name) = &self.current_token {
907 let alias = name.clone();
908 self.advance();
909 Some(alias)
910 } else {
911 None
912 };
913
914 (
915 None,
916 None,
917 Some(TableFunction::Range { start, end, step }),
918 alias,
919 )
920 } else {
921 let table_name = name.clone();
923 self.advance();
924
925 let alias = if matches!(self.current_token, Token::As) {
927 self.advance();
928 match &self.current_token {
929 Token::Identifier(name) => {
930 let alias = name.clone();
931 self.advance();
932 Some(alias)
933 }
934 _ => return Err("Expected alias name after AS".to_string()),
935 }
936 } else if let Token::Identifier(name) = &self.current_token {
937 let alias = name.clone();
939 self.advance();
940 Some(alias)
941 } else {
942 None
943 };
944
945 (Some(table_name), None, None, alias)
946 }
947 } else if matches!(self.current_token, Token::LeftParen) {
948 self.advance();
950
951 let subquery = self.parse_select_statement_inner()?;
953
954 self.consume(Token::RightParen)?;
955
956 let alias = if matches!(self.current_token, Token::As) {
958 self.advance();
959 match &self.current_token {
960 Token::Identifier(name) => {
961 let alias = name.clone();
962 self.advance();
963 alias
964 }
965 _ => return Err("Expected alias name after AS".to_string()),
966 }
967 } else {
968 match &self.current_token {
970 Token::Identifier(name) => {
971 let alias = name.clone();
972 self.advance();
973 alias
974 }
975 _ => {
976 return Err(
977 "Subquery in FROM must have an alias (e.g., AS t)".to_string()
978 )
979 }
980 }
981 };
982
983 (None, Some(Box::new(subquery)), None, Some(alias))
984 } else {
985 match &self.current_token {
987 Token::Identifier(table) => {
988 let table_name = table.clone();
989 self.advance();
990
991 let alias = if matches!(self.current_token, Token::As) {
993 self.advance();
994 match &self.current_token {
995 Token::Identifier(name) => {
996 let alias = name.clone();
997 self.advance();
998 Some(alias)
999 }
1000 _ => return Err("Expected alias name after AS".to_string()),
1001 }
1002 } else if let Token::Identifier(name) = &self.current_token {
1003 let alias = name.clone();
1005 self.advance();
1006 Some(alias)
1007 } else {
1008 None
1009 };
1010
1011 (Some(table_name), None, None, alias)
1012 }
1013 Token::QuotedIdentifier(table) => {
1014 let table_name = table.clone();
1016 self.advance();
1017
1018 let alias = if matches!(self.current_token, Token::As) {
1020 self.advance();
1021 match &self.current_token {
1022 Token::Identifier(name) => {
1023 let alias = name.clone();
1024 self.advance();
1025 Some(alias)
1026 }
1027 _ => return Err("Expected alias name after AS".to_string()),
1028 }
1029 } else if let Token::Identifier(name) = &self.current_token {
1030 let alias = name.clone();
1032 self.advance();
1033 Some(alias)
1034 } else {
1035 None
1036 };
1037
1038 (Some(table_name), None, None, alias)
1039 }
1040 _ => return Err("Expected table name or subquery after FROM".to_string()),
1041 }
1042 }
1043 } else {
1044 (None, None, None, None)
1045 };
1046
1047 let mut joins = Vec::new();
1049 while self.is_join_token() {
1050 joins.push(self.parse_join_clause()?);
1051 }
1052
1053 let where_clause = if matches!(self.current_token, Token::Where) {
1054 self.advance();
1055 Some(self.parse_where_clause()?)
1056 } else {
1057 None
1058 };
1059
1060 let group_by = if matches!(self.current_token, Token::GroupBy) {
1061 self.advance();
1062 Some(self.parse_identifier_list()?)
1063 } else {
1064 None
1065 };
1066
1067 let having = if matches!(self.current_token, Token::Having) {
1069 if group_by.is_none() {
1070 return Err("HAVING clause requires GROUP BY".to_string());
1071 }
1072 self.advance();
1073 Some(self.parse_expression()?)
1074 } else {
1075 None
1076 };
1077
1078 let order_by = if matches!(self.current_token, Token::OrderBy) {
1080 self.advance();
1081 Some(self.parse_order_by_list()?)
1082 } else if let Token::Identifier(s) = &self.current_token {
1083 if s.to_uppercase() == "ORDER" {
1084 self.advance(); if matches!(&self.current_token, Token::Identifier(by_token) if by_token.to_uppercase() == "BY")
1087 {
1088 self.advance(); Some(self.parse_order_by_list()?)
1090 } else {
1091 return Err("Expected BY after ORDER".to_string());
1092 }
1093 } else {
1094 None
1095 }
1096 } else {
1097 None
1098 };
1099
1100 let limit = if matches!(self.current_token, Token::Limit) {
1102 self.advance();
1103 match &self.current_token {
1104 Token::NumberLiteral(num) => {
1105 let limit_val = num
1106 .parse::<usize>()
1107 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
1108 self.advance();
1109 Some(limit_val)
1110 }
1111 _ => return Err("Expected number after LIMIT".to_string()),
1112 }
1113 } else {
1114 None
1115 };
1116
1117 let offset = if matches!(self.current_token, Token::Offset) {
1119 self.advance();
1120 match &self.current_token {
1121 Token::NumberLiteral(num) => {
1122 let offset_val = num
1123 .parse::<usize>()
1124 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
1125 self.advance();
1126 Some(offset_val)
1127 }
1128 _ => return Err("Expected number after OFFSET".to_string()),
1129 }
1130 } else {
1131 None
1132 };
1133
1134 Ok(SelectStatement {
1135 distinct,
1136 columns,
1137 select_items,
1138 from_table,
1139 from_subquery,
1140 from_function,
1141 from_alias,
1142 joins,
1143 where_clause,
1144 order_by,
1145 group_by,
1146 having,
1147 limit,
1148 offset,
1149 ctes: Vec::new(), })
1151 }
1152
1153 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
1154 let mut columns = Vec::new();
1155
1156 if matches!(self.current_token, Token::Star) {
1157 columns.push("*".to_string());
1158 self.advance();
1159 } else {
1160 loop {
1161 match &self.current_token {
1162 Token::Identifier(col) => {
1163 columns.push(col.clone());
1164 self.advance();
1165 }
1166 Token::QuotedIdentifier(col) => {
1167 columns.push(col.clone());
1169 self.advance();
1170 }
1171 _ => return Err("Expected column name".to_string()),
1172 }
1173
1174 if matches!(self.current_token, Token::Comma) {
1175 self.advance();
1176 } else {
1177 break;
1178 }
1179 }
1180 }
1181
1182 Ok(columns)
1183 }
1184
1185 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
1187 let mut items = Vec::new();
1188
1189 loop {
1190 if matches!(self.current_token, Token::Star) {
1193 items.push(SelectItem::Star);
1201 self.advance();
1202 } else {
1203 let expr = self.parse_comparison()?; let alias = if matches!(self.current_token, Token::As) {
1208 self.advance();
1209 match &self.current_token {
1210 Token::Identifier(alias_name) => {
1211 let alias = alias_name.clone();
1212 self.advance();
1213 alias
1214 }
1215 Token::QuotedIdentifier(alias_name) => {
1216 let alias = alias_name.clone();
1217 self.advance();
1218 alias
1219 }
1220 _ => return Err("Expected alias name after AS".to_string()),
1221 }
1222 } else {
1223 match &expr {
1225 SqlExpression::Column(col_name) => col_name.clone(),
1226 _ => format!("expr_{}", items.len() + 1), }
1228 };
1229
1230 let item = match expr {
1232 SqlExpression::Column(col_name) if alias == col_name => {
1233 SelectItem::Column(col_name)
1235 }
1236 _ => {
1237 SelectItem::Expression { expr, alias }
1239 }
1240 };
1241
1242 items.push(item);
1243 }
1244
1245 if matches!(self.current_token, Token::Comma) {
1247 self.advance();
1248 } else {
1249 break;
1250 }
1251 }
1252
1253 Ok(items)
1254 }
1255
1256 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1257 let mut identifiers = Vec::new();
1258
1259 loop {
1260 match &self.current_token {
1261 Token::Identifier(id) => {
1262 let id_upper = id.to_uppercase();
1264 if matches!(
1265 id_upper.as_str(),
1266 "ORDER" | "HAVING" | "LIMIT" | "OFFSET" | "UNION" | "INTERSECT" | "EXCEPT"
1267 ) {
1268 break;
1270 }
1271 identifiers.push(id.clone());
1272 self.advance();
1273 }
1274 Token::QuotedIdentifier(id) => {
1275 identifiers.push(id.clone());
1277 self.advance();
1278 }
1279 _ => {
1280 break;
1282 }
1283 }
1284
1285 if matches!(self.current_token, Token::Comma) {
1286 self.advance();
1287 } else {
1288 break;
1289 }
1290 }
1291
1292 if identifiers.is_empty() {
1293 return Err("Expected at least one identifier".to_string());
1294 }
1295
1296 Ok(identifiers)
1297 }
1298
1299 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1300 let mut partition_by = Vec::new();
1301 let mut order_by = Vec::new();
1302
1303 if matches!(self.current_token, Token::Partition) {
1305 self.advance(); if !matches!(self.current_token, Token::By) {
1307 return Err("Expected BY after PARTITION".to_string());
1308 }
1309 self.advance(); partition_by = self.parse_identifier_list()?;
1313 }
1314
1315 if matches!(self.current_token, Token::OrderBy) {
1317 self.advance(); order_by = self.parse_order_by_list()?;
1319 } else if let Token::Identifier(s) = &self.current_token {
1320 if s.to_uppercase() == "ORDER" {
1321 self.advance(); if !matches!(self.current_token, Token::By) {
1324 return Err("Expected BY after ORDER".to_string());
1325 }
1326 self.advance(); order_by = self.parse_order_by_list()?;
1328 }
1329 }
1330
1331 Ok(WindowSpec {
1332 partition_by,
1333 order_by,
1334 })
1335 }
1336
1337 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
1338 let mut order_columns = Vec::new();
1339
1340 loop {
1341 let column = match &self.current_token {
1342 Token::Identifier(id) => {
1343 let col = id.clone();
1344 self.advance();
1345 col
1346 }
1347 Token::QuotedIdentifier(id) => {
1348 let col = id.clone();
1349 self.advance();
1350 col
1351 }
1352 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
1353 let col = num.clone();
1355 self.advance();
1356 col
1357 }
1358 _ => return Err("Expected column name in ORDER BY".to_string()),
1359 };
1360
1361 let direction = match &self.current_token {
1363 Token::Asc => {
1364 self.advance();
1365 SortDirection::Asc
1366 }
1367 Token::Desc => {
1368 self.advance();
1369 SortDirection::Desc
1370 }
1371 _ => SortDirection::Asc, };
1373
1374 order_columns.push(OrderByColumn { column, direction });
1375
1376 if matches!(self.current_token, Token::Comma) {
1377 self.advance();
1378 } else {
1379 break;
1380 }
1381 }
1382
1383 Ok(order_columns)
1384 }
1385
1386 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1387 let mut conditions = Vec::new();
1388
1389 loop {
1390 let expr = self.parse_expression()?;
1391
1392 let connector = match &self.current_token {
1393 Token::And => {
1394 self.advance();
1395 Some(LogicalOp::And)
1396 }
1397 Token::Or => {
1398 self.advance();
1399 Some(LogicalOp::Or)
1400 }
1401 Token::RightParen if self.paren_depth <= 0 => {
1402 return Err(
1404 "Unexpected closing parenthesis - no matching opening parenthesis"
1405 .to_string(),
1406 );
1407 }
1408 _ => None,
1409 };
1410
1411 conditions.push(Condition {
1412 expr,
1413 connector: connector.clone(),
1414 });
1415
1416 if connector.is_none() {
1417 break;
1418 }
1419 }
1420
1421 Ok(WhereClause { conditions })
1422 }
1423
1424 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1425 let mut left = self.parse_comparison()?;
1426
1427 if let Some(op) = self.get_binary_op() {
1430 self.advance();
1431 let right = self.parse_expression()?;
1432 left = SqlExpression::BinaryOp {
1433 left: Box::new(left),
1434 op,
1435 right: Box::new(right),
1436 };
1437 }
1438
1439 if matches!(self.current_token, Token::In) {
1441 self.advance();
1442 self.consume(Token::LeftParen)?;
1443 let values = self.parse_expression_list()?;
1444 self.consume(Token::RightParen)?;
1445
1446 left = SqlExpression::InList {
1447 expr: Box::new(left),
1448 values,
1449 };
1450 }
1451
1452 Ok(left)
1456 }
1457
1458 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1459 let mut left = self.parse_additive()?;
1460
1461 if matches!(self.current_token, Token::Between) {
1463 self.advance(); let lower = self.parse_primary()?;
1465 self.consume(Token::And)?; let upper = self.parse_primary()?;
1467
1468 return Ok(SqlExpression::Between {
1469 expr: Box::new(left),
1470 lower: Box::new(lower),
1471 upper: Box::new(upper),
1472 });
1473 }
1474
1475 if matches!(self.current_token, Token::Not) {
1477 self.advance(); if matches!(self.current_token, Token::In) {
1479 self.advance(); self.consume(Token::LeftParen)?;
1481 let values = self.parse_expression_list()?;
1482 self.consume(Token::RightParen)?;
1483
1484 return Ok(SqlExpression::NotInList {
1485 expr: Box::new(left),
1486 values,
1487 });
1488 }
1489 return Err("Expected IN after NOT".to_string());
1490 }
1491
1492 if matches!(self.current_token, Token::Is) {
1494 self.advance(); if matches!(self.current_token, Token::Not) {
1496 self.advance(); if matches!(self.current_token, Token::Null) {
1498 self.advance(); left = SqlExpression::BinaryOp {
1500 left: Box::new(left),
1501 op: "IS NOT NULL".to_string(),
1502 right: Box::new(SqlExpression::Null),
1503 };
1504 } else {
1505 return Err("Expected NULL after IS NOT".to_string());
1506 }
1507 } else if matches!(self.current_token, Token::Null) {
1508 self.advance(); left = SqlExpression::BinaryOp {
1510 left: Box::new(left),
1511 op: "IS NULL".to_string(),
1512 right: Box::new(SqlExpression::Null),
1513 };
1514 } else {
1515 return Err("Expected NULL or NOT after IS".to_string());
1516 }
1517 }
1518 else if let Some(op) = self.get_binary_op() {
1520 self.advance();
1521 let right = self.parse_additive()?;
1522 left = SqlExpression::BinaryOp {
1523 left: Box::new(left),
1524 op,
1525 right: Box::new(right),
1526 };
1527 }
1528
1529 Ok(left)
1530 }
1531
1532 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1533 let mut left = self.parse_multiplicative()?;
1534
1535 while matches!(self.current_token, Token::Plus | Token::Minus) {
1536 let op = match self.current_token {
1537 Token::Plus => "+",
1538 Token::Minus => "-",
1539 _ => unreachable!(),
1540 };
1541 self.advance();
1542 let right = self.parse_multiplicative()?;
1543 left = SqlExpression::BinaryOp {
1544 left: Box::new(left),
1545 op: op.to_string(),
1546 right: Box::new(right),
1547 };
1548 }
1549
1550 Ok(left)
1551 }
1552
1553 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1554 let mut left = self.parse_primary()?;
1555
1556 while matches!(self.current_token, Token::Dot) {
1558 self.advance();
1559 if let Token::Identifier(method) = &self.current_token {
1560 let method_name = method.clone();
1561 self.advance();
1562
1563 if matches!(self.current_token, Token::LeftParen) {
1564 self.advance();
1565 let args = self.parse_method_args()?;
1566 self.consume(Token::RightParen)?;
1567
1568 match left {
1570 SqlExpression::Column(obj) => {
1571 left = SqlExpression::MethodCall {
1573 object: obj,
1574 method: method_name,
1575 args,
1576 };
1577 }
1578 SqlExpression::MethodCall { .. }
1579 | SqlExpression::ChainedMethodCall { .. } => {
1580 left = SqlExpression::ChainedMethodCall {
1582 base: Box::new(left),
1583 method: method_name,
1584 args,
1585 };
1586 }
1587 _ => {
1588 left = SqlExpression::ChainedMethodCall {
1590 base: Box::new(left),
1591 method: method_name,
1592 args,
1593 };
1594 }
1595 }
1596 } else {
1597 return Err(format!("Expected '(' after method name '{method_name}'"));
1598 }
1599 } else {
1600 return Err("Expected method name after '.'".to_string());
1601 }
1602 }
1603
1604 while matches!(
1605 self.current_token,
1606 Token::Star | Token::Divide | Token::Modulo
1607 ) {
1608 let op = match self.current_token {
1609 Token::Star => "*",
1610 Token::Divide => "/",
1611 Token::Modulo => "%",
1612 _ => unreachable!(),
1613 };
1614 self.advance();
1615 let right = self.parse_primary()?;
1616 left = SqlExpression::BinaryOp {
1617 left: Box::new(left),
1618 op: op.to_string(),
1619 right: Box::new(right),
1620 };
1621 }
1622
1623 Ok(left)
1624 }
1625
1626 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1627 let mut left = self.parse_logical_and()?;
1628
1629 while matches!(self.current_token, Token::Or) {
1630 self.advance();
1631 let right = self.parse_logical_and()?;
1632 left = SqlExpression::BinaryOp {
1636 left: Box::new(left),
1637 op: "OR".to_string(),
1638 right: Box::new(right),
1639 };
1640 }
1641
1642 Ok(left)
1643 }
1644
1645 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1646 let mut left = self.parse_expression()?;
1647
1648 while matches!(self.current_token, Token::And) {
1649 self.advance();
1650 let right = self.parse_expression()?;
1651 left = SqlExpression::BinaryOp {
1653 left: Box::new(left),
1654 op: "AND".to_string(),
1655 right: Box::new(right),
1656 };
1657 }
1658
1659 Ok(left)
1660 }
1661
1662 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1663 self.consume(Token::Case)?;
1665
1666 let mut when_branches = Vec::new();
1667
1668 while matches!(self.current_token, Token::When) {
1670 self.advance(); let condition = self.parse_expression()?;
1674
1675 self.consume(Token::Then)?;
1677
1678 let result = self.parse_expression()?;
1680
1681 when_branches.push(WhenBranch {
1682 condition: Box::new(condition),
1683 result: Box::new(result),
1684 });
1685 }
1686
1687 if when_branches.is_empty() {
1689 return Err("CASE expression must have at least one WHEN clause".to_string());
1690 }
1691
1692 let else_branch = if matches!(self.current_token, Token::Else) {
1694 self.advance(); Some(Box::new(self.parse_expression()?))
1696 } else {
1697 None
1698 };
1699
1700 self.consume(Token::End)?;
1702
1703 Ok(SqlExpression::CaseExpression {
1704 when_branches,
1705 else_branch,
1706 })
1707 }
1708
1709 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1710 if let Token::NumberLiteral(num_str) = &self.current_token {
1713 if self.columns.iter().any(|col| col == num_str) {
1715 let expr = SqlExpression::Column(num_str.clone());
1716 self.advance();
1717 return Ok(expr);
1718 }
1719 }
1720
1721 match &self.current_token {
1722 Token::Case => {
1723 self.parse_case_expression()
1725 }
1726 Token::DateTime => {
1727 self.advance(); self.consume(Token::LeftParen)?;
1729
1730 if matches!(&self.current_token, Token::RightParen) {
1732 self.advance(); return Ok(SqlExpression::DateTimeToday {
1734 hour: None,
1735 minute: None,
1736 second: None,
1737 });
1738 }
1739
1740 let year = if let Token::NumberLiteral(n) = &self.current_token {
1742 n.parse::<i32>().map_err(|_| "Invalid year")?
1743 } else {
1744 return Err("Expected year in DateTime constructor".to_string());
1745 };
1746 self.advance();
1747 self.consume(Token::Comma)?;
1748
1749 let month = if let Token::NumberLiteral(n) = &self.current_token {
1751 n.parse::<u32>().map_err(|_| "Invalid month")?
1752 } else {
1753 return Err("Expected month in DateTime constructor".to_string());
1754 };
1755 self.advance();
1756 self.consume(Token::Comma)?;
1757
1758 let day = if let Token::NumberLiteral(n) = &self.current_token {
1760 n.parse::<u32>().map_err(|_| "Invalid day")?
1761 } else {
1762 return Err("Expected day in DateTime constructor".to_string());
1763 };
1764 self.advance();
1765
1766 let mut hour = None;
1768 let mut minute = None;
1769 let mut second = None;
1770
1771 if matches!(&self.current_token, Token::Comma) {
1772 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1776 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1777 self.advance();
1778
1779 if matches!(&self.current_token, Token::Comma) {
1781 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1784 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1785 self.advance();
1786
1787 if matches!(&self.current_token, Token::Comma) {
1789 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1792 second =
1793 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1794 self.advance();
1795 }
1796 }
1797 }
1798 }
1799 }
1800 }
1801
1802 self.consume(Token::RightParen)?;
1803 Ok(SqlExpression::DateTimeConstructor {
1804 year,
1805 month,
1806 day,
1807 hour,
1808 minute,
1809 second,
1810 })
1811 }
1812 Token::Identifier(id) => {
1813 let id_upper = id.to_uppercase();
1814 let id_clone = id.clone();
1815
1816 if id_upper == "TRUE" {
1818 self.advance();
1819 return Ok(SqlExpression::BooleanLiteral(true));
1820 } else if id_upper == "FALSE" {
1821 self.advance();
1822 return Ok(SqlExpression::BooleanLiteral(false));
1823 }
1824
1825 self.advance();
1826
1827 if matches!(self.current_token, Token::LeftParen) {
1829 self.advance(); let (args, has_distinct) = self.parse_function_args()?;
1833 self.consume(Token::RightParen)?;
1834
1835 if matches!(self.current_token, Token::Over) {
1837 self.advance(); self.consume(Token::LeftParen)?;
1839 let window_spec = self.parse_window_spec()?;
1840 self.consume(Token::RightParen)?;
1841 return Ok(SqlExpression::WindowFunction {
1842 name: id_upper,
1843 args,
1844 window_spec,
1845 });
1846 }
1847
1848 return Ok(SqlExpression::FunctionCall {
1849 name: id_upper,
1850 args,
1851 distinct: has_distinct,
1852 });
1853 }
1854
1855 Ok(SqlExpression::Column(id_clone))
1857 }
1858 Token::QuotedIdentifier(id) => {
1859 let expr = if self.in_method_args {
1862 SqlExpression::StringLiteral(id.clone())
1863 } else {
1864 SqlExpression::Column(id.clone())
1866 };
1867 self.advance();
1868 Ok(expr)
1869 }
1870 Token::StringLiteral(s) => {
1871 let expr = SqlExpression::StringLiteral(s.clone());
1872 self.advance();
1873 Ok(expr)
1874 }
1875 Token::NumberLiteral(n) => {
1876 let expr = SqlExpression::NumberLiteral(n.clone());
1877 self.advance();
1878 Ok(expr)
1879 }
1880 Token::Null => {
1881 self.advance();
1882 Ok(SqlExpression::Null)
1883 }
1884 Token::LeftParen => {
1885 self.advance();
1886
1887 let expr = self.parse_logical_or()?;
1890
1891 self.consume(Token::RightParen)?;
1892 Ok(expr)
1893 }
1894 Token::Not => {
1895 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1899 if matches!(self.current_token, Token::In) {
1901 self.advance(); self.consume(Token::LeftParen)?;
1903 let values = self.parse_expression_list()?;
1904 self.consume(Token::RightParen)?;
1905
1906 Ok(SqlExpression::NotInList {
1907 expr: Box::new(inner_expr),
1908 values,
1909 })
1910 } else {
1911 Ok(SqlExpression::Not {
1913 expr: Box::new(inner_expr),
1914 })
1915 }
1916 } else {
1917 Err("Expected expression after NOT".to_string())
1918 }
1919 }
1920 Token::Star => {
1921 self.advance();
1923 Ok(SqlExpression::StringLiteral("*".to_string()))
1924 }
1925 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1926 }
1927 }
1928
1929 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1930 let mut args = Vec::new();
1931
1932 self.in_method_args = true;
1934
1935 if !matches!(self.current_token, Token::RightParen) {
1936 loop {
1937 args.push(self.parse_expression()?);
1938
1939 if matches!(self.current_token, Token::Comma) {
1940 self.advance();
1941 } else {
1942 break;
1943 }
1944 }
1945 }
1946
1947 self.in_method_args = false;
1949
1950 Ok(args)
1951 }
1952
1953 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
1954 let mut args = Vec::new();
1955 let mut has_distinct = false;
1956
1957 if !matches!(self.current_token, Token::RightParen) {
1958 if matches!(self.current_token, Token::Distinct) {
1960 self.advance(); has_distinct = true;
1962 }
1963
1964 args.push(self.parse_additive()?);
1966
1967 while matches!(self.current_token, Token::Comma) {
1969 self.advance();
1970 args.push(self.parse_additive()?);
1971 }
1972 }
1973
1974 Ok((args, has_distinct))
1975 }
1976
1977 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1978 let mut expressions = Vec::new();
1979
1980 loop {
1981 expressions.push(self.parse_expression()?);
1982
1983 if matches!(self.current_token, Token::Comma) {
1984 self.advance();
1985 } else {
1986 break;
1987 }
1988 }
1989
1990 Ok(expressions)
1991 }
1992
1993 fn get_binary_op(&self) -> Option<String> {
1994 match &self.current_token {
1995 Token::Equal => Some("=".to_string()),
1996 Token::NotEqual => Some("!=".to_string()),
1997 Token::LessThan => Some("<".to_string()),
1998 Token::GreaterThan => Some(">".to_string()),
1999 Token::LessThanOrEqual => Some("<=".to_string()),
2000 Token::GreaterThanOrEqual => Some(">=".to_string()),
2001 Token::Like => Some("LIKE".to_string()),
2002 _ => None,
2003 }
2004 }
2005
2006 fn get_arithmetic_op(&self) -> Option<String> {
2007 match &self.current_token {
2008 Token::Plus => Some("+".to_string()),
2009 Token::Minus => Some("-".to_string()),
2010 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
2012 Token::Modulo => Some("%".to_string()),
2013 _ => None,
2014 }
2015 }
2016
2017 #[must_use]
2018 pub fn get_position(&self) -> usize {
2019 self.lexer.get_position()
2020 }
2021
2022 fn is_join_token(&self) -> bool {
2024 matches!(
2025 self.current_token,
2026 Token::Join | Token::Inner | Token::Left | Token::Right | Token::Full | Token::Cross
2027 )
2028 }
2029
2030 fn parse_join_clause(&mut self) -> Result<JoinClause, String> {
2032 let join_type = match &self.current_token {
2034 Token::Join => {
2035 self.advance();
2036 JoinType::Inner }
2038 Token::Inner => {
2039 self.advance();
2040 if !matches!(self.current_token, Token::Join) {
2041 return Err("Expected JOIN after INNER".to_string());
2042 }
2043 self.advance();
2044 JoinType::Inner
2045 }
2046 Token::Left => {
2047 self.advance();
2048 if matches!(self.current_token, Token::Outer) {
2050 self.advance();
2051 }
2052 if !matches!(self.current_token, Token::Join) {
2053 return Err("Expected JOIN after LEFT".to_string());
2054 }
2055 self.advance();
2056 JoinType::Left
2057 }
2058 Token::Right => {
2059 self.advance();
2060 if matches!(self.current_token, Token::Outer) {
2062 self.advance();
2063 }
2064 if !matches!(self.current_token, Token::Join) {
2065 return Err("Expected JOIN after RIGHT".to_string());
2066 }
2067 self.advance();
2068 JoinType::Right
2069 }
2070 Token::Full => {
2071 self.advance();
2072 if matches!(self.current_token, Token::Outer) {
2074 self.advance();
2075 }
2076 if !matches!(self.current_token, Token::Join) {
2077 return Err("Expected JOIN after FULL".to_string());
2078 }
2079 self.advance();
2080 JoinType::Full
2081 }
2082 Token::Cross => {
2083 self.advance();
2084 if !matches!(self.current_token, Token::Join) {
2085 return Err("Expected JOIN after CROSS".to_string());
2086 }
2087 self.advance();
2088 JoinType::Cross
2089 }
2090 _ => return Err("Expected JOIN keyword".to_string()),
2091 };
2092
2093 let (table, alias) = self.parse_join_table_source()?;
2095
2096 let condition = if join_type == JoinType::Cross {
2098 JoinCondition {
2100 left_column: String::new(),
2101 operator: JoinOperator::Equal,
2102 right_column: String::new(),
2103 }
2104 } else {
2105 if !matches!(self.current_token, Token::On) {
2106 return Err("Expected ON keyword after JOIN table".to_string());
2107 }
2108 self.advance();
2109 self.parse_join_condition()?
2110 };
2111
2112 Ok(JoinClause {
2113 join_type,
2114 table,
2115 alias,
2116 condition,
2117 })
2118 }
2119
2120 fn parse_join_table_source(&mut self) -> Result<(TableSource, Option<String>), String> {
2121 let table = match &self.current_token {
2122 Token::Identifier(name) => {
2123 let table_name = name.clone();
2124 self.advance();
2125 TableSource::Table(table_name)
2126 }
2127 Token::LeftParen => {
2128 self.advance();
2130 let subquery = self.parse_select_statement_inner()?;
2131 if !matches!(self.current_token, Token::RightParen) {
2132 return Err("Expected ')' after subquery".to_string());
2133 }
2134 self.advance();
2135
2136 let alias = match &self.current_token {
2138 Token::Identifier(alias_name) => {
2139 let alias = alias_name.clone();
2140 self.advance();
2141 alias
2142 }
2143 Token::As => {
2144 self.advance();
2145 match &self.current_token {
2146 Token::Identifier(alias_name) => {
2147 let alias = alias_name.clone();
2148 self.advance();
2149 alias
2150 }
2151 _ => return Err("Expected alias after AS keyword".to_string()),
2152 }
2153 }
2154 _ => return Err("Subqueries must have an alias".to_string()),
2155 };
2156
2157 return Ok((
2158 TableSource::DerivedTable {
2159 query: Box::new(subquery),
2160 alias: alias.clone(),
2161 },
2162 Some(alias),
2163 ));
2164 }
2165 _ => return Err("Expected table name or subquery in JOIN clause".to_string()),
2166 };
2167
2168 let alias = match &self.current_token {
2170 Token::Identifier(alias_name) => {
2171 let alias = alias_name.clone();
2172 self.advance();
2173 Some(alias)
2174 }
2175 Token::As => {
2176 self.advance();
2177 match &self.current_token {
2178 Token::Identifier(alias_name) => {
2179 let alias = alias_name.clone();
2180 self.advance();
2181 Some(alias)
2182 }
2183 _ => return Err("Expected alias after AS keyword".to_string()),
2184 }
2185 }
2186 _ => None,
2187 };
2188
2189 Ok((table, alias))
2190 }
2191
2192 fn parse_join_condition(&mut self) -> Result<JoinCondition, String> {
2193 let left_column = self.parse_column_reference()?;
2195
2196 let operator = match &self.current_token {
2198 Token::Equal => JoinOperator::Equal,
2199 Token::NotEqual => JoinOperator::NotEqual,
2200 Token::LessThan => JoinOperator::LessThan,
2201 Token::LessThanOrEqual => JoinOperator::LessThanOrEqual,
2202 Token::GreaterThan => JoinOperator::GreaterThan,
2203 Token::GreaterThanOrEqual => JoinOperator::GreaterThanOrEqual,
2204 _ => return Err("Expected comparison operator in JOIN condition".to_string()),
2205 };
2206 self.advance();
2207
2208 let right_column = self.parse_column_reference()?;
2210
2211 Ok(JoinCondition {
2212 left_column,
2213 operator,
2214 right_column,
2215 })
2216 }
2217
2218 fn parse_column_reference(&mut self) -> Result<String, String> {
2219 match &self.current_token {
2220 Token::Identifier(name) => {
2221 let mut column_ref = name.clone();
2222 self.advance();
2223
2224 if matches!(self.current_token, Token::Dot) {
2226 self.advance();
2227 match &self.current_token {
2228 Token::Identifier(col_name) => {
2229 column_ref.push('.');
2230 column_ref.push_str(col_name);
2231 self.advance();
2232 }
2233 _ => return Err("Expected column name after '.'".to_string()),
2234 }
2235 }
2236
2237 Ok(column_ref)
2238 }
2239 _ => Err("Expected column reference".to_string()),
2240 }
2241 }
2242}
2243
/// Classification of where the editing cursor sits within a SQL query.
///
/// Returned as the first element of the tuple produced by
/// `detect_cursor_context`.
///
/// NOTE(review): the code that constructs these variants is outside this
/// view; the per-variant notes below are inferred from the names and should
/// be confirmed against `analyze_statement` / `analyze_partial`.
#[derive(Debug, Clone)]
pub enum CursorContext {
    /// Presumably: cursor is inside the SELECT column list.
    SelectClause,
    /// Presumably: cursor is inside the FROM clause.
    FromClause,
    /// Presumably: cursor is inside the WHERE clause.
    WhereClause,
    /// Presumably: cursor is inside the ORDER BY clause.
    OrderByClause,
    /// Presumably: cursor follows a column; the payload looks like the column name — TODO confirm.
    AfterColumn(String),
    /// Presumably: cursor follows the given logical connector (AND / OR).
    AfterLogicalOp(LogicalOp),
    /// Presumably: cursor follows a comparison; payload order (column, operator?) — TODO confirm.
    AfterComparisonOp(String, String),
    /// Presumably: cursor is inside a method call; payload order (object, method?) — TODO confirm.
    InMethodCall(String, String),
    /// Presumably: cursor is inside some other expression.
    InExpression,
    /// No specific context could be determined.
    Unknown,
}
2258
/// Returns the prefix of `s` that ends at byte offset `pos`.
///
/// When `pos` is at or beyond the end, the whole string is returned. When
/// `pos` falls inside a multi-byte character, the cut moves back to the start
/// of that character so the slice never splits a UTF-8 sequence.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot fail.
    let boundary = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..boundary]
}
2273
/// Returns the suffix of `s` that starts at byte offset `pos`.
///
/// When `pos` is at or beyond the end, the empty string is returned. When
/// `pos` falls inside a multi-byte character, the start moves forward to the
/// next character boundary so the slice never splits a UTF-8 sequence.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    // `s.len()` is always a boundary, so the search cannot fail.
    let boundary = (pos..=s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[boundary..]
}
2288
2289#[must_use]
2290pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2291 let truncated = safe_slice_to(query, cursor_pos);
2292 let mut parser = Parser::new(truncated);
2293
2294 if let Ok(stmt) = parser.parse() {
2296 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
2297 #[cfg(test)]
2298 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2299 (ctx, partial)
2300 } else {
2301 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
2303 #[cfg(test)]
2304 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2305 (ctx, partial)
2306 }
2307}
2308
2309#[must_use]
2310pub fn tokenize_query(query: &str) -> Vec<String> {
2311 let mut lexer = Lexer::new(query);
2312 let tokens = lexer.tokenize_all();
2313 tokens.iter().map(|t| format!("{t:?}")).collect()
2314}
2315
2316#[must_use]
2317pub fn format_sql_pretty(query: &str) -> Vec<String> {
2318 format_sql_pretty_compact(query, 5) }
2320
2321#[must_use]
2323pub fn format_ast_tree(query: &str) -> String {
2324 let mut parser = Parser::new(query);
2325 match parser.parse() {
2326 Ok(stmt) => format_select_statement(&stmt, 0),
2327 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
2328 }
2329}
2330
2331fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
2332 let mut result = String::new();
2333 let indent_str = " ".repeat(indent);
2334
2335 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
2336
2337 result.push_str(&format!("{indent_str} columns: ["));
2339 if stmt.columns.is_empty() {
2340 result.push_str("],\n");
2341 } else {
2342 result.push('\n');
2343 for col in &stmt.columns {
2344 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2345 }
2346 result.push_str(&format!("{indent_str} ],\n"));
2347 }
2348
2349 if let Some(table) = &stmt.from_table {
2351 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
2352 }
2353
2354 if let Some(where_clause) = &stmt.where_clause {
2356 result.push_str(&format!("{indent_str} where_clause: {{\n"));
2357 result.push_str(&format_where_clause(where_clause, indent + 2));
2358 result.push_str(&format!("{indent_str} }},\n"));
2359 }
2360
2361 if let Some(order_by) = &stmt.order_by {
2363 result.push_str(&format!("{indent_str} order_by: ["));
2364 if order_by.is_empty() {
2365 result.push_str("],\n");
2366 } else {
2367 result.push('\n');
2368 for col in order_by {
2369 let dir = match col.direction {
2370 SortDirection::Asc => "ASC",
2371 SortDirection::Desc => "DESC",
2372 };
2373 result.push_str(&format!(
2374 "{indent_str} \"{col}\" {dir},\n",
2375 col = col.column
2376 ));
2377 }
2378 result.push_str(&format!("{indent_str} ],\n"));
2379 }
2380 }
2381
2382 if let Some(group_by) = &stmt.group_by {
2384 result.push_str(&format!("{indent_str} group_by: ["));
2385 if group_by.is_empty() {
2386 result.push_str("]\n");
2387 } else {
2388 result.push('\n');
2389 for col in group_by {
2390 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2391 }
2392 result.push_str(&format!("{indent_str} ],\n"));
2393 }
2394 }
2395
2396 result.push_str(&format!("{indent_str}}}"));
2397 result
2398}
2399
2400fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
2401 let mut result = String::new();
2402 let indent_str = " ".repeat(indent);
2403
2404 result.push_str(&format!("{indent_str}conditions: [\n"));
2405
2406 for condition in &clause.conditions {
2407 result.push_str(&format!("{indent_str} {{\n"));
2408 result.push_str(&format!(
2409 "{indent_str} expr: {},\n",
2410 format_expression_ast(&condition.expr)
2411 ));
2412
2413 if let Some(connector) = &condition.connector {
2414 let connector_str = match connector {
2415 LogicalOp::And => "AND",
2416 LogicalOp::Or => "OR",
2417 };
2418 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
2419 }
2420
2421 result.push_str(&format!("{indent_str} }},\n"));
2422 }
2423
2424 result.push_str(&format!("{indent_str}]\n"));
2425 result
2426}
2427
2428fn format_expression_ast(expr: &SqlExpression) -> String {
2429 match expr {
2430 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
2431 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
2432 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
2433 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({value})"),
2434 SqlExpression::Null => "Null".to_string(),
2435 SqlExpression::DateTimeConstructor {
2436 year,
2437 month,
2438 day,
2439 hour,
2440 minute,
2441 second,
2442 } => {
2443 format!(
2444 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
2445 year,
2446 month,
2447 day,
2448 hour.unwrap_or(0),
2449 minute.unwrap_or(0),
2450 second.unwrap_or(0)
2451 )
2452 }
2453 SqlExpression::DateTimeToday {
2454 hour,
2455 minute,
2456 second,
2457 } => {
2458 format!(
2459 "DateTimeToday({:02}:{:02}:{:02})",
2460 hour.unwrap_or(0),
2461 minute.unwrap_or(0),
2462 second.unwrap_or(0)
2463 )
2464 }
2465 SqlExpression::MethodCall {
2466 object,
2467 method,
2468 args,
2469 } => {
2470 let args_str = args
2471 .iter()
2472 .map(format_expression_ast)
2473 .collect::<Vec<_>>()
2474 .join(", ");
2475 format!("MethodCall({object}.{method}({args_str}))")
2476 }
2477 SqlExpression::ChainedMethodCall { base, method, args } => {
2478 let args_str = args
2479 .iter()
2480 .map(format_expression_ast)
2481 .collect::<Vec<_>>()
2482 .join(", ");
2483 format!(
2484 "ChainedMethodCall({}.{}({}))",
2485 format_expression_ast(base),
2486 method,
2487 args_str
2488 )
2489 }
2490 SqlExpression::FunctionCall {
2491 name,
2492 args,
2493 distinct,
2494 } => {
2495 let args_str = args
2496 .iter()
2497 .map(format_expression_ast)
2498 .collect::<Vec<_>>()
2499 .join(", ");
2500 if *distinct {
2501 format!("FunctionCall({name}(DISTINCT {args_str}))")
2502 } else {
2503 format!("FunctionCall({name}({args_str}))")
2504 }
2505 }
2506 SqlExpression::WindowFunction {
2507 name,
2508 args,
2509 window_spec,
2510 } => {
2511 let args_str = args
2512 .iter()
2513 .map(format_expression_ast)
2514 .collect::<Vec<_>>()
2515 .join(", ");
2516 let partition_str = if !window_spec.partition_by.is_empty() {
2517 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2518 } else {
2519 String::new()
2520 };
2521 let order_str = if !window_spec.order_by.is_empty() {
2522 let cols = window_spec
2523 .order_by
2524 .iter()
2525 .map(|col| format!("{} {:?}", col.column, col.direction))
2526 .collect::<Vec<_>>()
2527 .join(", ");
2528 format!(" ORDER BY {}", cols)
2529 } else {
2530 String::new()
2531 };
2532 format!("WindowFunction({name}({args_str}) OVER({partition_str}{order_str}))")
2533 }
2534 SqlExpression::BinaryOp { left, op, right } => {
2535 format!(
2536 "BinaryOp({} {} {})",
2537 format_expression_ast(left),
2538 op,
2539 format_expression_ast(right)
2540 )
2541 }
2542 SqlExpression::InList { expr, values } => {
2543 let list_str = values
2544 .iter()
2545 .map(format_expression_ast)
2546 .collect::<Vec<_>>()
2547 .join(", ");
2548 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
2549 }
2550 SqlExpression::NotInList { expr, values } => {
2551 let list_str = values
2552 .iter()
2553 .map(format_expression_ast)
2554 .collect::<Vec<_>>()
2555 .join(", ");
2556 format!(
2557 "NotInList({} NOT IN [{}])",
2558 format_expression_ast(expr),
2559 list_str
2560 )
2561 }
2562 SqlExpression::Between { expr, lower, upper } => {
2563 format!(
2564 "Between({} BETWEEN {} AND {})",
2565 format_expression_ast(expr),
2566 format_expression_ast(lower),
2567 format_expression_ast(upper)
2568 )
2569 }
2570 SqlExpression::Not { expr } => {
2571 format!("Not({})", format_expression_ast(expr))
2572 }
2573 SqlExpression::CaseExpression {
2574 when_branches,
2575 else_branch,
2576 } => {
2577 let when_strs: Vec<String> = when_branches
2578 .iter()
2579 .map(|branch| {
2580 format!(
2581 "WHEN {} THEN {}",
2582 format_expression_ast(&branch.condition),
2583 format_expression_ast(&branch.result)
2584 )
2585 })
2586 .collect();
2587 let else_str = else_branch
2588 .as_ref()
2589 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
2590 .unwrap_or_default();
2591 format!("CASE {} {} END", when_strs.join(" "), else_str)
2592 }
2593 }
2594}
2595
2596#[must_use]
2598pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
2599 match expr {
2600 SqlExpression::DateTimeConstructor {
2601 year,
2602 month,
2603 day,
2604 hour,
2605 minute,
2606 second,
2607 } => {
2608 let h = hour.unwrap_or(0);
2609 let m = minute.unwrap_or(0);
2610 let s = second.unwrap_or(0);
2611
2612 if let Ok(dt) = NaiveDateTime::parse_from_str(
2614 &format!("{year:04}-{month:02}-{day:02} {h:02}:{m:02}:{s:02}"),
2615 "%Y-%m-%d %H:%M:%S",
2616 ) {
2617 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2618 } else {
2619 None
2620 }
2621 }
2622 SqlExpression::DateTimeToday {
2623 hour,
2624 minute,
2625 second,
2626 } => {
2627 let now = Local::now();
2628 let h = hour.unwrap_or(0);
2629 let m = minute.unwrap_or(0);
2630 let s = second.unwrap_or(0);
2631
2632 if let Ok(dt) = NaiveDateTime::parse_from_str(
2634 &format!(
2635 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
2636 now.year(),
2637 now.month(),
2638 now.day(),
2639 h,
2640 m,
2641 s
2642 ),
2643 "%Y-%m-%d %H:%M:%S",
2644 ) {
2645 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2646 } else {
2647 None
2648 }
2649 }
2650 _ => None,
2651 }
2652}
2653
2654fn format_sql_with_preserved_parens(
2656 query: &str,
2657 cols_per_line: usize,
2658) -> Result<Vec<String>, String> {
2659 let mut lines = Vec::new();
2660 let mut lexer = Lexer::new(query);
2661 let tokens_with_pos = lexer.tokenize_all_with_positions();
2662
2663 if tokens_with_pos.is_empty() {
2664 return Err("No tokens found".to_string());
2665 }
2666
2667 let mut i = 0;
2668 let cols_per_line = cols_per_line.max(1);
2669
2670 while i < tokens_with_pos.len() {
2671 let (start, _end, ref token) = tokens_with_pos[i];
2672
2673 match token {
2674 Token::Select => {
2675 lines.push("SELECT".to_string());
2676 i += 1;
2677
2678 let mut columns = Vec::new();
2680 let mut col_start = i;
2681 while i < tokens_with_pos.len() {
2682 match &tokens_with_pos[i].2 {
2683 Token::From | Token::Eof => break,
2684 Token::Comma => {
2685 if col_start < i {
2687 let col_text = extract_text_between_positions(
2688 query,
2689 tokens_with_pos[col_start].0,
2690 tokens_with_pos[i - 1].1,
2691 );
2692 columns.push(col_text);
2693 }
2694 i += 1;
2695 col_start = i;
2696 }
2697 _ => i += 1,
2698 }
2699 }
2700 if col_start < i && i > 0 {
2702 let col_text = extract_text_between_positions(
2703 query,
2704 tokens_with_pos[col_start].0,
2705 tokens_with_pos[i - 1].1,
2706 );
2707 columns.push(col_text);
2708 }
2709
2710 for chunk in columns.chunks(cols_per_line) {
2712 let mut line = " ".to_string();
2713 for (idx, col) in chunk.iter().enumerate() {
2714 if idx > 0 {
2715 line.push_str(", ");
2716 }
2717 line.push_str(col.trim());
2718 }
2719 let is_last_chunk = chunk.as_ptr() as usize + std::mem::size_of_val(chunk)
2721 >= columns.last().map_or(0, |c| std::ptr::from_ref(c) as usize);
2722 if !is_last_chunk && columns.len() > cols_per_line {
2723 line.push(',');
2724 }
2725 lines.push(line);
2726 }
2727 }
2728 Token::From => {
2729 i += 1;
2730 if i < tokens_with_pos.len() {
2731 let table_start = tokens_with_pos[i].0;
2732 while i < tokens_with_pos.len() {
2734 match &tokens_with_pos[i].2 {
2735 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
2736 _ => i += 1,
2737 }
2738 }
2739 if i > 0 {
2740 let table_text = extract_text_between_positions(
2741 query,
2742 table_start,
2743 tokens_with_pos[i - 1].1,
2744 );
2745 lines.push(format!("FROM {}", table_text.trim()));
2746 }
2747 }
2748 }
2749 Token::Where => {
2750 lines.push("WHERE".to_string());
2751 i += 1;
2752
2753 let where_start = if i < tokens_with_pos.len() {
2755 tokens_with_pos[i].0
2756 } else {
2757 start
2758 };
2759
2760 let mut where_end = query.len();
2762 while i < tokens_with_pos.len() {
2763 match &tokens_with_pos[i].2 {
2764 Token::OrderBy | Token::GroupBy | Token::Eof => {
2765 if i > 0 {
2766 where_end = tokens_with_pos[i - 1].1;
2767 }
2768 break;
2769 }
2770 _ => i += 1,
2771 }
2772 }
2773
2774 let where_text = extract_text_between_positions(query, where_start, where_end);
2776
2777 let formatted_where = format_where_clause_with_parens(&where_text);
2779 for line in formatted_where {
2780 lines.push(format!(" {line}"));
2781 }
2782 }
2783 Token::OrderBy => {
2784 i += 1;
2785 let order_start = if i < tokens_with_pos.len() {
2786 tokens_with_pos[i].0
2787 } else {
2788 start
2789 };
2790
2791 while i < tokens_with_pos.len() {
2793 match &tokens_with_pos[i].2 {
2794 Token::GroupBy | Token::Eof => break,
2795 _ => i += 1,
2796 }
2797 }
2798
2799 if i > 0 {
2800 let order_text = extract_text_between_positions(
2801 query,
2802 order_start,
2803 tokens_with_pos[i - 1].1,
2804 );
2805 lines.push(format!("ORDER BY {}", order_text.trim()));
2806 }
2807 }
2808 Token::GroupBy => {
2809 i += 1;
2810 let group_start = if i < tokens_with_pos.len() {
2811 tokens_with_pos[i].0
2812 } else {
2813 start
2814 };
2815
2816 while i < tokens_with_pos.len() {
2818 match &tokens_with_pos[i].2 {
2819 Token::Having | Token::Eof => break,
2820 _ => i += 1,
2821 }
2822 }
2823
2824 if i > 0 {
2825 let group_text = extract_text_between_positions(
2826 query,
2827 group_start,
2828 tokens_with_pos[i - 1].1,
2829 );
2830 lines.push(format!("GROUP BY {}", group_text.trim()));
2831 }
2832 }
2833 _ => i += 1,
2834 }
2835 }
2836
2837 Ok(lines)
2838}
2839
/// Returns the characters of `query` in the half-open range `start..end`,
/// where both bounds are character (not byte) indices.
///
/// Bounds past the end of the query are clamped. An inverted range
/// (`start > end`) yields an empty string; it previously panicked, which
/// callers could trigger with an empty clause such as `WHERE` immediately
/// followed by `ORDER BY`.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2847
/// Splits the text of a WHERE clause into display lines at every top-level
/// ` AND ` / ` OR ` (case-insensitive, single spaces on both sides).
///
/// Each connector becomes its own upper-case line (`"AND"` / `"OR"`); the
/// text between connectors is trimmed. Connectors inside parentheses or
/// inside single- or double-quoted text are left alone, so grouped conditions
/// and string literals such as `'x AND y'` stay on one line. An unmatched `)`
/// is ignored for depth tracking instead of disabling all later splitting.
///
/// If nothing would be emitted (blank input), the result is a single line
/// holding the trimmed input.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// Returns the connector label and pattern width when a connector
    /// pattern starts at `at`.
    fn connector_at(chars: &[char], at: usize) -> Option<(&'static str, usize)> {
        const CONNECTORS: [(&str, &str); 2] = [(" AND ", "AND"), (" OR ", "OR")];

        CONNECTORS.iter().find_map(|(pattern, label)| {
            // Patterns are ASCII, so byte length equals character count.
            let window = chars.get(at..at + pattern.len())?;
            window
                .iter()
                .zip(pattern.chars())
                .all(|(actual, expected)| actual.eq_ignore_ascii_case(&expected))
                .then_some((*label, pattern.len()))
        })
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth = 0usize;
    // The quote character that opened the literal we are inside, if any.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        if paren_depth == 0 && open_quote.is_none() {
            if let Some((label, width)) = connector_at(&chars, i) {
                let pending = current_line.trim();
                if !pending.is_empty() {
                    lines.push(pending.to_string());
                }
                lines.push(label.to_string());
                current_line.clear();
                i += width;
                continue;
            }
        }

        let c = chars[i];
        match open_quote {
            // A doubled quote ('') closes and immediately reopens the
            // literal, which leaves the state correct.
            Some(quote) => {
                if c == quote {
                    open_quote = None;
                }
            }
            None => match c {
                '\'' | '"' => open_quote = Some(c),
                '(' => paren_depth += 1,
                ')' => paren_depth = paren_depth.saturating_sub(1),
                _ => {}
            },
        }
        current_line.push(c);
        i += 1;
    }

    let remainder = current_line.trim();
    if !remainder.is_empty() {
        lines.push(remainder.to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2913
2914#[must_use]
2915pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2916 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2918 return lines;
2919 }
2920
2921 let mut lines = Vec::new();
2923 let mut parser = Parser::new(query);
2924
2925 let cols_per_line = cols_per_line.max(1);
2927
2928 if let Ok(stmt) = parser.parse() {
2929 if !stmt.columns.is_empty() {
2931 lines.push("SELECT".to_string());
2932
2933 for chunk in stmt.columns.chunks(cols_per_line) {
2935 let mut line = " ".to_string();
2936 for (i, col) in chunk.iter().enumerate() {
2937 if i > 0 {
2938 line.push_str(", ");
2939 }
2940 line.push_str(col);
2941 }
2942 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2944 let current_chunk_idx =
2945 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2946 if current_chunk_idx < last_chunk_idx {
2947 line.push(',');
2948 }
2949 lines.push(line);
2950 }
2951 }
2952
2953 if let Some(table) = &stmt.from_table {
2955 lines.push(format!("FROM {table}"));
2956 }
2957
2958 if let Some(where_clause) = &stmt.where_clause {
2960 lines.push("WHERE".to_string());
2961 for (i, condition) in where_clause.conditions.iter().enumerate() {
2962 if i > 0 {
2963 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2965 if let Some(connector) = &prev_condition.connector {
2966 match connector {
2967 LogicalOp::And => lines.push(" AND".to_string()),
2968 LogicalOp::Or => lines.push(" OR".to_string()),
2969 }
2970 }
2971 }
2972 }
2973 lines.push(format!(" {}", format_expression(&condition.expr)));
2974 }
2975 }
2976
2977 if let Some(order_by) = &stmt.order_by {
2979 let order_str = order_by
2980 .iter()
2981 .map(|col| {
2982 let dir = match col.direction {
2983 SortDirection::Asc => " ASC",
2984 SortDirection::Desc => " DESC",
2985 };
2986 format!("{}{}", col.column, dir)
2987 })
2988 .collect::<Vec<_>>()
2989 .join(", ");
2990 lines.push(format!("ORDER BY {order_str}"));
2991 }
2992
2993 if let Some(group_by) = &stmt.group_by {
2995 let group_str = group_by.join(", ");
2996 lines.push(format!("GROUP BY {group_str}"));
2997 }
2998 } else {
2999 let mut lexer = Lexer::new(query);
3001 let tokens = lexer.tokenize_all();
3002 let mut current_line = String::new();
3003 let mut indent = 0;
3004
3005 for token in tokens {
3006 match &token {
3007 Token::Select | Token::From | Token::Where | Token::OrderBy | Token::GroupBy => {
3008 if !current_line.is_empty() {
3009 lines.push(current_line.trim().to_string());
3010 current_line.clear();
3011 }
3012 lines.push(format!("{token:?}").to_uppercase());
3013 indent = 1;
3014 }
3015 Token::And | Token::Or => {
3016 if !current_line.is_empty() {
3017 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
3018 current_line.clear();
3019 }
3020 lines.push(format!(" {token:?}").to_uppercase());
3021 }
3022 Token::Comma => {
3023 current_line.push(',');
3024 if indent > 0 {
3025 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
3026 current_line.clear();
3027 }
3028 }
3029 Token::Eof => break,
3030 _ => {
3031 if !current_line.is_empty() {
3032 current_line.push(' ');
3033 }
3034 current_line.push_str(&format_token(&token));
3035 }
3036 }
3037 }
3038
3039 if !current_line.is_empty() {
3040 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
3041 }
3042 }
3043
3044 lines
3045}
3046
3047fn format_expression(expr: &SqlExpression) -> String {
3048 match expr {
3049 SqlExpression::Column(name) => name.clone(),
3050 SqlExpression::StringLiteral(s) => format!("'{s}'"),
3051 SqlExpression::NumberLiteral(n) => n.clone(),
3052 SqlExpression::BooleanLiteral(b) => b.to_string(),
3053 SqlExpression::Null => "NULL".to_string(),
3054 SqlExpression::DateTimeConstructor {
3055 year,
3056 month,
3057 day,
3058 hour,
3059 minute,
3060 second,
3061 } => {
3062 let mut result = format!("DateTime({year}, {month}, {day}");
3063 if let Some(h) = hour {
3064 result.push_str(&format!(", {h}"));
3065 if let Some(m) = minute {
3066 result.push_str(&format!(", {m}"));
3067 if let Some(s) = second {
3068 result.push_str(&format!(", {s}"));
3069 }
3070 }
3071 }
3072 result.push(')');
3073 result
3074 }
3075 SqlExpression::DateTimeToday {
3076 hour,
3077 minute,
3078 second,
3079 } => {
3080 let mut result = "DateTime()".to_string();
3081 if let Some(h) = hour {
3082 result = format!("DateTime(TODAY, {h}");
3083 if let Some(m) = minute {
3084 result.push_str(&format!(", {m}"));
3085 if let Some(s) = second {
3086 result.push_str(&format!(", {s}"));
3087 }
3088 }
3089 result.push(')');
3090 }
3091 result
3092 }
3093 SqlExpression::MethodCall {
3094 object,
3095 method,
3096 args,
3097 } => {
3098 let args_str = args
3099 .iter()
3100 .map(format_expression)
3101 .collect::<Vec<_>>()
3102 .join(", ");
3103 format!("{object}.{method}({args_str})")
3104 }
3105 SqlExpression::BinaryOp { left, op, right } => {
3106 if op == "OR" || op == "AND" {
3109 format!(
3112 "({} {} {})",
3113 format_expression(left),
3114 op,
3115 format_expression(right)
3116 )
3117 } else {
3118 format!(
3119 "{} {} {}",
3120 format_expression(left),
3121 op,
3122 format_expression(right)
3123 )
3124 }
3125 }
3126 SqlExpression::InList { expr, values } => {
3127 let values_str = values
3128 .iter()
3129 .map(format_expression)
3130 .collect::<Vec<_>>()
3131 .join(", ");
3132 format!("{} IN ({})", format_expression(expr), values_str)
3133 }
3134 SqlExpression::NotInList { expr, values } => {
3135 let values_str = values
3136 .iter()
3137 .map(format_expression)
3138 .collect::<Vec<_>>()
3139 .join(", ");
3140 format!("{} NOT IN ({})", format_expression(expr), values_str)
3141 }
3142 SqlExpression::Between { expr, lower, upper } => {
3143 format!(
3144 "{} BETWEEN {} AND {}",
3145 format_expression(expr),
3146 format_expression(lower),
3147 format_expression(upper)
3148 )
3149 }
3150 SqlExpression::Not { expr } => {
3151 format!("NOT {}", format_expression(expr))
3152 }
3153 SqlExpression::ChainedMethodCall { base, method, args } => {
3154 let args_str = args
3155 .iter()
3156 .map(format_expression)
3157 .collect::<Vec<_>>()
3158 .join(", ");
3159 format!("{}.{}({})", format_expression(base), method, args_str)
3160 }
3161 SqlExpression::FunctionCall {
3162 name,
3163 args,
3164 distinct,
3165 } => {
3166 let args_str = args
3167 .iter()
3168 .map(format_expression)
3169 .collect::<Vec<_>>()
3170 .join(", ");
3171 if *distinct {
3172 format!("{name}(DISTINCT {args_str})")
3173 } else {
3174 format!("{name}({args_str})")
3175 }
3176 }
3177 SqlExpression::WindowFunction {
3178 name,
3179 args,
3180 window_spec,
3181 } => {
3182 let args_str = args
3183 .iter()
3184 .map(format_expression)
3185 .collect::<Vec<_>>()
3186 .join(", ");
3187 let partition_str = if !window_spec.partition_by.is_empty() {
3188 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
3189 } else {
3190 String::new()
3191 };
3192 let order_str = if !window_spec.order_by.is_empty() {
3193 let cols = window_spec
3194 .order_by
3195 .iter()
3196 .map(|col| {
3197 let dir = match col.direction {
3198 SortDirection::Asc => "ASC",
3199 SortDirection::Desc => "DESC",
3200 };
3201 format!("{} {}", col.column, dir)
3202 })
3203 .collect::<Vec<_>>()
3204 .join(", ");
3205 format!(" ORDER BY {}", cols)
3206 } else {
3207 String::new()
3208 };
3209 format!("{name}({args_str}) OVER({partition_str}{order_str})")
3210 }
3211 SqlExpression::CaseExpression {
3212 when_branches,
3213 else_branch,
3214 } => {
3215 let mut result = String::from("CASE");
3216 for branch in when_branches {
3217 result.push_str(&format!(
3218 " WHEN {} THEN {}",
3219 format_expression(&branch.condition),
3220 format_expression(&branch.result)
3221 ));
3222 }
3223 if let Some(else_expr) = else_branch {
3224 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
3225 }
3226 result.push_str(" END");
3227 result
3228 }
3229 }
3230}
3231
3232fn format_token(token: &Token) -> String {
3233 match token {
3234 Token::Identifier(s) => s.clone(),
3235 Token::QuotedIdentifier(s) => format!("\"{s}\""),
3236 Token::StringLiteral(s) => format!("'{s}'"),
3237 Token::NumberLiteral(n) => n.clone(),
3238 Token::DateTime => "DateTime".to_string(),
3239 Token::Case => "CASE".to_string(),
3240 Token::When => "WHEN".to_string(),
3241 Token::Then => "THEN".to_string(),
3242 Token::Else => "ELSE".to_string(),
3243 Token::End => "END".to_string(),
3244 Token::Distinct => "DISTINCT".to_string(),
3245 Token::Over => "OVER".to_string(),
3246 Token::Partition => "PARTITION".to_string(),
3247 Token::By => "BY".to_string(),
3248 Token::LeftParen => "(".to_string(),
3249 Token::RightParen => ")".to_string(),
3250 Token::Comma => ",".to_string(),
3251 Token::Dot => ".".to_string(),
3252 Token::Equal => "=".to_string(),
3253 Token::NotEqual => "!=".to_string(),
3254 Token::LessThan => "<".to_string(),
3255 Token::GreaterThan => ">".to_string(),
3256 Token::LessThanOrEqual => "<=".to_string(),
3257 Token::GreaterThanOrEqual => ">=".to_string(),
3258 Token::In => "IN".to_string(),
3259 _ => format!("{token:?}").to_uppercase(),
3260 }
3261}
3262
3263fn analyze_statement(
3264 stmt: &SelectStatement,
3265 query: &str,
3266 _cursor_pos: usize,
3267) -> (CursorContext, Option<String>) {
3268 let trimmed = query.trim();
3270
3271 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
3273 for op in &comparison_ops {
3274 if let Some(op_pos) = query.rfind(op) {
3275 let before_op = safe_slice_to(query, op_pos);
3276 let after_op_start = op_pos + op.len();
3277 let after_op = if after_op_start < query.len() {
3278 &query[after_op_start..]
3279 } else {
3280 ""
3281 };
3282
3283 if let Some(col_name) = before_op.split_whitespace().last() {
3285 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
3286 let after_op_trimmed = after_op.trim();
3288 if after_op_trimmed.is_empty()
3289 || (after_op_trimmed
3290 .chars()
3291 .all(|c| c.is_alphanumeric() || c == '_')
3292 && !after_op_trimmed.contains('('))
3293 {
3294 let partial = if after_op_trimmed.is_empty() {
3295 None
3296 } else {
3297 Some(after_op_trimmed.to_string())
3298 };
3299 return (
3300 CursorContext::AfterComparisonOp(
3301 col_name.to_string(),
3302 op.trim().to_string(),
3303 ),
3304 partial,
3305 );
3306 }
3307 }
3308 }
3309 }
3310 }
3311
3312 if trimmed.to_uppercase().ends_with(" AND")
3314 || trimmed.to_uppercase().ends_with(" OR")
3315 || trimmed.to_uppercase().ends_with(" AND ")
3316 || trimmed.to_uppercase().ends_with(" OR ")
3317 {
3318 } else {
3320 if let Some(dot_pos) = trimmed.rfind('.') {
3322 let before_dot = safe_slice_to(trimmed, dot_pos);
3324 let after_dot_start = dot_pos + 1;
3325 let after_dot = if after_dot_start < trimmed.len() {
3326 &trimmed[after_dot_start..]
3327 } else {
3328 ""
3329 };
3330
3331 if !after_dot.contains('(') {
3334 let col_name = if before_dot.ends_with('"') {
3336 let bytes = before_dot.as_bytes();
3338 let mut pos = before_dot.len() - 1; let mut found_start = None;
3340
3341 if pos > 0 {
3343 pos -= 1;
3344 while pos > 0 {
3345 if bytes[pos] == b'"' {
3346 if pos == 0 || bytes[pos - 1] != b'\\' {
3348 found_start = Some(pos);
3349 break;
3350 }
3351 }
3352 pos -= 1;
3353 }
3354 if found_start.is_none() && bytes[0] == b'"' {
3356 found_start = Some(0);
3357 }
3358 }
3359
3360 found_start.map(|start| safe_slice_from(before_dot, start))
3361 } else {
3362 before_dot
3365 .split_whitespace()
3366 .last()
3367 .map(|word| word.trim_start_matches('('))
3368 };
3369
3370 if let Some(col_name) = col_name {
3371 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3373 true
3375 } else {
3376 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3378 };
3379
3380 if is_valid {
3381 let partial_method = if after_dot.is_empty() {
3384 None
3385 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3386 Some(after_dot.to_string())
3387 } else {
3388 None
3389 };
3390
3391 let col_name_for_context = if col_name.starts_with('"')
3393 && col_name.ends_with('"')
3394 && col_name.len() > 2
3395 {
3396 col_name[1..col_name.len() - 1].to_string()
3397 } else {
3398 col_name.to_string()
3399 };
3400
3401 return (
3402 CursorContext::AfterColumn(col_name_for_context),
3403 partial_method,
3404 );
3405 }
3406 }
3407 }
3408 }
3409 }
3410
3411 if let Some(where_clause) = &stmt.where_clause {
3413 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3415 let op = if trimmed.to_uppercase().ends_with(" AND") {
3416 LogicalOp::And
3417 } else {
3418 LogicalOp::Or
3419 };
3420 return (CursorContext::AfterLogicalOp(op), None);
3421 }
3422
3423 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
3425 let after_and = safe_slice_from(query, and_pos + 5);
3426 let partial = extract_partial_at_end(after_and);
3427 if partial.is_some() {
3428 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3429 }
3430 }
3431
3432 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
3433 let after_or = safe_slice_from(query, or_pos + 4);
3434 let partial = extract_partial_at_end(after_or);
3435 if partial.is_some() {
3436 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3437 }
3438 }
3439
3440 if let Some(last_condition) = where_clause.conditions.last() {
3441 if let Some(connector) = &last_condition.connector {
3442 return (
3444 CursorContext::AfterLogicalOp(connector.clone()),
3445 extract_partial_at_end(query),
3446 );
3447 }
3448 }
3449 return (CursorContext::WhereClause, extract_partial_at_end(query));
3451 }
3452
3453 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
3455 return (CursorContext::OrderByClause, None);
3456 }
3457
3458 if stmt.order_by.is_some() {
3460 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3461 }
3462
3463 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
3464 return (CursorContext::FromClause, extract_partial_at_end(query));
3465 }
3466
3467 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
3468 return (CursorContext::SelectClause, extract_partial_at_end(query));
3469 }
3470
3471 (CursorContext::Unknown, None)
3472}
3473
3474fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
3475 let upper = query.to_uppercase();
3476
3477 let trimmed = query.trim();
3479
3480 #[cfg(test)]
3481 {
3482 if trimmed.contains("\"Last Name\"") {
3483 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
3484 }
3485 }
3486
3487 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
3489 for op in &comparison_ops {
3490 if let Some(op_pos) = query.rfind(op) {
3491 let before_op = safe_slice_to(query, op_pos);
3492 let after_op_start = op_pos + op.len();
3493 let after_op = if after_op_start < query.len() {
3494 &query[after_op_start..]
3495 } else {
3496 ""
3497 };
3498
3499 if let Some(col_name) = before_op.split_whitespace().last() {
3501 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
3502 let after_op_trimmed = after_op.trim();
3504 if after_op_trimmed.is_empty()
3505 || (after_op_trimmed
3506 .chars()
3507 .all(|c| c.is_alphanumeric() || c == '_')
3508 && !after_op_trimmed.contains('('))
3509 {
3510 let partial = if after_op_trimmed.is_empty() {
3511 None
3512 } else {
3513 Some(after_op_trimmed.to_string())
3514 };
3515 return (
3516 CursorContext::AfterComparisonOp(
3517 col_name.to_string(),
3518 op.trim().to_string(),
3519 ),
3520 partial,
3521 );
3522 }
3523 }
3524 }
3525 }
3526 }
3527
3528 if let Some(dot_pos) = trimmed.rfind('.') {
3531 #[cfg(test)]
3532 {
3533 if trimmed.contains("\"Last Name\"") {
3534 eprintln!("DEBUG: Found dot at position {dot_pos}");
3535 }
3536 }
3537 let before_dot = &trimmed[..dot_pos];
3539 let after_dot = &trimmed[dot_pos + 1..];
3540
3541 if !after_dot.contains('(') {
3544 let col_name = if before_dot.ends_with('"') {
3547 let bytes = before_dot.as_bytes();
3549 let mut pos = before_dot.len() - 1; let mut found_start = None;
3551
3552 #[cfg(test)]
3553 {
3554 if trimmed.contains("\"Last Name\"") {
3555 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
3556 }
3557 }
3558
3559 if pos > 0 {
3561 pos -= 1;
3562 while pos > 0 {
3563 if bytes[pos] == b'"' {
3564 if pos == 0 || bytes[pos - 1] != b'\\' {
3566 found_start = Some(pos);
3567 break;
3568 }
3569 }
3570 pos -= 1;
3571 }
3572 if found_start.is_none() && bytes[0] == b'"' {
3574 found_start = Some(0);
3575 }
3576 }
3577
3578 if let Some(start) = found_start {
3579 let result = safe_slice_from(before_dot, start);
3581 #[cfg(test)]
3582 {
3583 if trimmed.contains("\"Last Name\"") {
3584 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
3585 }
3586 }
3587 Some(result)
3588 } else {
3589 #[cfg(test)]
3590 {
3591 if trimmed.contains("\"Last Name\"") {
3592 eprintln!("DEBUG: No opening quote found!");
3593 }
3594 }
3595 None
3596 }
3597 } else {
3598 before_dot
3601 .split_whitespace()
3602 .last()
3603 .map(|word| word.trim_start_matches('('))
3604 };
3605
3606 if let Some(col_name) = col_name {
3607 #[cfg(test)]
3608 {
3609 if trimmed.contains("\"Last Name\"") {
3610 eprintln!("DEBUG: col_name = '{col_name}'");
3611 }
3612 }
3613
3614 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3616 true
3618 } else {
3619 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3621 };
3622
3623 #[cfg(test)]
3624 {
3625 if trimmed.contains("\"Last Name\"") {
3626 eprintln!("DEBUG: is_valid = {is_valid}");
3627 }
3628 }
3629
3630 if is_valid {
3631 let partial_method = if after_dot.is_empty() {
3634 None
3635 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3636 Some(after_dot.to_string())
3637 } else {
3638 None
3639 };
3640
3641 let col_name_for_context = if col_name.starts_with('"')
3643 && col_name.ends_with('"')
3644 && col_name.len() > 2
3645 {
3646 col_name[1..col_name.len() - 1].to_string()
3647 } else {
3648 col_name.to_string()
3649 };
3650
3651 return (
3652 CursorContext::AfterColumn(col_name_for_context),
3653 partial_method,
3654 );
3655 }
3656 }
3657 }
3658 }
3659
3660 if let Some(and_pos) = upper.rfind(" AND ") {
3662 if cursor_pos >= and_pos + 5 {
3664 let after_and = safe_slice_from(query, and_pos + 5);
3666 let partial = extract_partial_at_end(after_and);
3667 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3668 }
3669 }
3670
3671 if let Some(or_pos) = upper.rfind(" OR ") {
3672 if cursor_pos >= or_pos + 4 {
3674 let after_or = safe_slice_from(query, or_pos + 4);
3676 let partial = extract_partial_at_end(after_or);
3677 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3678 }
3679 }
3680
3681 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3683 let op = if trimmed.to_uppercase().ends_with(" AND") {
3684 LogicalOp::And
3685 } else {
3686 LogicalOp::Or
3687 };
3688 return (CursorContext::AfterLogicalOp(op), None);
3689 }
3690
3691 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
3693 {
3694 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3695 }
3696
3697 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
3698 return (CursorContext::WhereClause, extract_partial_at_end(query));
3699 }
3700
3701 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
3702 return (CursorContext::FromClause, extract_partial_at_end(query));
3703 }
3704
3705 if upper.contains("SELECT") && !upper.contains("FROM") {
3706 return (CursorContext::SelectClause, extract_partial_at_end(query));
3707 }
3708
3709 (CursorContext::Unknown, None)
3710}
3711
3712fn extract_partial_at_end(query: &str) -> Option<String> {
3713 let trimmed = query.trim();
3714
3715 if let Some(last_word) = trimmed.split_whitespace().last() {
3717 if last_word.starts_with('"') && !last_word.ends_with('"') {
3718 return Some(last_word.to_string());
3720 }
3721 }
3722
3723 let last_word = trimmed.split_whitespace().last()?;
3725
3726 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
3728 Some(last_word.to_string())
3729 } else {
3730 None
3731 }
3732}
3733
/// Reports whether `word` is one of the SQL keywords that must never be offered as a
/// partial identifier. The comparison ignores case.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 13] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.contains(&upper.as_str())
}
3752
3753#[cfg(test)]
3754mod tests {
3755 use super::*;
3756
3757 #[test]
3758 fn test_tokenizer_window_functions() {
3759 let mut lexer = Lexer::new("LAG(value) OVER (PARTITION BY category ORDER BY id)");
3760 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "LAG"));
3761 assert!(matches!(lexer.next_token(), Token::LeftParen));
3762 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "value"));
3763 assert!(matches!(lexer.next_token(), Token::RightParen));
3764
3765 let over_token = lexer.next_token();
3766 println!("Expected OVER, got: {:?}", over_token);
3767 assert!(matches!(over_token, Token::Over));
3768
3769 assert!(matches!(lexer.next_token(), Token::LeftParen));
3770 assert!(matches!(lexer.next_token(), Token::Partition));
3771 assert!(matches!(lexer.next_token(), Token::By));
3772 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "category"));
3773 }
3774
3775 #[test]
3776 fn test_parse_window_function() {
3777 let query = "SELECT LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test";
3778 let mut parser = Parser::new(query);
3779 let result = parser.parse();
3780
3781 assert!(
3782 result.is_ok(),
3783 "Failed to parse window function: {:?}",
3784 result
3785 );
3786 let stmt = result.unwrap();
3787
3788 if let Some(item) = stmt.select_items.get(0) {
3790 match item {
3791 SelectItem::Expression { expr, alias } => {
3792 println!("Parsed expression: {:?}", expr);
3793 assert!(matches!(expr, SqlExpression::WindowFunction { .. }));
3794 assert_eq!(alias, "prev_value");
3795 }
3796 _ => panic!("Expected expression, got: {:?}", item),
3797 }
3798 } else {
3799 panic!("No select items found");
3800 }
3801 }
3802
3803 #[test]
3804 fn test_chained_method_calls() {
3805 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
3807 let mut parser = Parser::new(query);
3808 let result = parser.parse();
3809
3810 assert!(
3811 result.is_ok(),
3812 "Failed to parse chained method calls: {result:?}"
3813 );
3814
3815 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
3817 let mut parser2 = Parser::new(query2);
3818 let result2 = parser2.parse();
3819
3820 assert!(
3821 result2.is_ok(),
3822 "Failed to parse multiple chained calls: {result2:?}"
3823 );
3824 }
3825
3826 #[test]
3827 fn test_tokenizer() {
3828 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
3829
3830 assert!(matches!(lexer.next_token(), Token::Select));
3831 assert!(matches!(lexer.next_token(), Token::Star));
3832 assert!(matches!(lexer.next_token(), Token::From));
3833 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
3834 assert!(matches!(lexer.next_token(), Token::Where));
3835 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
3836 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3837 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
3838 }
3839
3840 #[test]
3841 fn test_tokenizer_datetime() {
3842 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
3843
3844 assert!(matches!(lexer.next_token(), Token::Where));
3845 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
3846 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3847 assert!(matches!(lexer.next_token(), Token::DateTime));
3848 assert!(matches!(lexer.next_token(), Token::LeftParen));
3849 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
3850 assert!(matches!(lexer.next_token(), Token::Comma));
3851 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
3852 assert!(matches!(lexer.next_token(), Token::Comma));
3853 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
3854 assert!(matches!(lexer.next_token(), Token::RightParen));
3855 }
3856
3857 #[test]
3858 fn test_parse_simple_select() {
3859 let mut parser = Parser::new("SELECT * FROM trade_deal");
3860 let stmt = parser.parse().unwrap();
3861
3862 assert_eq!(stmt.columns, vec!["*"]);
3863 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
3864 assert!(stmt.where_clause.is_none());
3865 }
3866
3867 #[test]
3868 fn test_parse_where_with_method() {
3869 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
3870 let stmt = parser.parse().unwrap();
3871
3872 assert!(stmt.where_clause.is_some());
3873 let where_clause = stmt.where_clause.unwrap();
3874 assert_eq!(where_clause.conditions.len(), 1);
3875 }
3876
3877 #[test]
3878 fn test_parse_datetime_constructor() {
3879 let mut parser =
3880 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
3881 let stmt = parser.parse().unwrap();
3882
3883 assert!(stmt.where_clause.is_some());
3884 let where_clause = stmt.where_clause.unwrap();
3885 assert_eq!(where_clause.conditions.len(), 1);
3886
3887 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3889 assert_eq!(op, ">");
3890 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
3891 assert!(matches!(
3892 right.as_ref(),
3893 SqlExpression::DateTimeConstructor {
3894 year: 2025,
3895 month: 10,
3896 day: 20,
3897 hour: None,
3898 minute: None,
3899 second: None
3900 }
3901 ));
3902 } else {
3903 panic!("Expected BinaryOp with DateTime constructor");
3904 }
3905 }
3906
3907 #[test]
3908 fn test_cursor_context_after_and() {
3909 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3910 let (context, partial) = detect_cursor_context(query, query.len());
3911
3912 assert!(matches!(
3913 context,
3914 CursorContext::AfterLogicalOp(LogicalOp::And)
3915 ));
3916 assert_eq!(partial, None);
3917 }
3918
3919 #[test]
3920 fn test_cursor_context_with_partial() {
3921 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3922 let (context, partial) = detect_cursor_context(query, query.len());
3923
3924 assert!(matches!(
3925 context,
3926 CursorContext::AfterLogicalOp(LogicalOp::And)
3927 ));
3928 assert_eq!(partial, Some("p".to_string()));
3929 }
3930
3931 #[test]
3932 fn test_cursor_context_after_datetime_comparison() {
3933 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3934 let (context, partial) = detect_cursor_context(query, query.len());
3935
3936 assert!(
3937 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3938 );
3939 assert_eq!(partial, None);
3940 }
3941
3942 #[test]
3943 fn test_cursor_context_partial_datetime() {
3944 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3945 let (context, partial) = detect_cursor_context(query, query.len());
3946
3947 assert!(
3948 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3949 );
3950 assert_eq!(partial, Some("Date".to_string()));
3951 }
3952
3953 #[test]
3955 fn test_tokenizer_quoted_identifier() {
3956 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3957
3958 assert!(matches!(lexer.next_token(), Token::Select));
3959 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3960 assert!(matches!(lexer.next_token(), Token::Comma));
3961 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3962 assert!(matches!(lexer.next_token(), Token::From));
3963 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3964 }
3965
3966 #[test]
3967 fn test_tokenizer_quoted_vs_string_literal() {
3968 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3970
3971 assert!(matches!(lexer.next_token(), Token::Where));
3972 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3973 assert!(matches!(lexer.next_token(), Token::Equal));
3974 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3975 assert!(matches!(lexer.next_token(), Token::And));
3976 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3977 assert!(matches!(lexer.next_token(), Token::Dot));
3978 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3979 assert!(matches!(lexer.next_token(), Token::LeftParen));
3980 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3981 assert!(matches!(lexer.next_token(), Token::RightParen));
3982 }
3983
3984 #[test]
3985 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3986 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3989
3990 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3991 assert!(matches!(lexer.next_token(), Token::Dot));
3992 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3993 assert!(matches!(lexer.next_token(), Token::LeftParen));
3994
3995 let token = lexer.next_token();
3998 println!("Token for \"Alb\": {token:?}");
3999 assert!(matches!(lexer.next_token(), Token::RightParen));
4003 }
4004
4005 #[test]
4006 fn test_parse_select_with_quoted_columns() {
4007 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
4008 let stmt = parser.parse().unwrap();
4009
4010 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
4011 assert_eq!(stmt.from_table, Some("customers".to_string()));
4012 }
4013
4014 #[test]
4015 fn test_cursor_context_select_with_partial_quoted() {
4016 let query = r#"SELECT "Cust"#;
4018 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {context:?}, Partial: {partial:?}");
4021 assert!(matches!(context, CursorContext::SelectClause));
4022 }
4025
4026 #[test]
4027 fn test_cursor_context_select_after_comma_with_quoted() {
4028 let query = r#"SELECT Company, "Customer "#;
4030 let (context, partial) = detect_cursor_context(query, query.len());
4031
4032 println!("Context: {context:?}, Partial: {partial:?}");
4033 assert!(matches!(context, CursorContext::SelectClause));
4034 }
4036
4037 #[test]
4038 fn test_cursor_context_order_by_quoted() {
4039 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
4040 let (context, partial) = detect_cursor_context(query, query.len() - 1);
4041
4042 println!("Context: {context:?}, Partial: {partial:?}");
4043 assert!(matches!(context, CursorContext::OrderByClause));
4044 }
4046
/// In a WHERE clause, a double-quoted name is a column and a single-quoted
/// value is a string literal, joined by a `=` binary operation.
#[test]
fn test_where_clause_with_quoted_column() {
    let sql = r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 1);

    match &filter.conditions[0].expr {
        SqlExpression::BinaryOp { left, op, right } => {
            assert_eq!(op, "=");
            assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
            assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
        }
        _ => panic!("Expected BinaryOp"),
    }
}
4064
/// A double-quoted method argument (`Country.Contains("USA")`) must be parsed
/// as a string literal argument of a method call on the column.
#[test]
fn test_parse_method_with_double_quotes_as_string() {
    let sql = r#"SELECT * FROM customers WHERE Country.Contains("USA")"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 1);

    match &filter.conditions[0].expr {
        SqlExpression::MethodCall {
            object,
            method,
            args,
        } => {
            assert_eq!(object, "Country");
            assert_eq!(method, "Contains");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
        }
        _ => panic!("Expected MethodCall"),
    }
}
4090
/// A completed quoted column earlier in the query must not leak into the
/// partial word typed after `ORDER BY`.
#[test]
fn test_extract_partial_with_quoted_columns_in_query() {
    let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
    let cursor = query.len();

    let (context, partial) = detect_cursor_context(query, cursor);

    assert!(matches!(context, CursorContext::OrderByClause));
    assert_eq!(
        partial.as_deref(),
        Some("coun"),
        "Should extract 'coun' as partial, not everything after the quoted column"
    );
}
4104
/// An unterminated quoted identifier at the end is returned with its opening
/// quote; a query ending in the keyword after a closed quoted identifier
/// yields no partial.
#[test]
fn test_extract_partial_quoted_identifier_being_typed() {
    // Quote still open: the partial includes the leading `"`.
    let open_quote = r#"SELECT "Cust"#;
    assert_eq!(
        extract_partial_at_end(open_quote).as_deref(),
        Some("\"Cust")
    );

    // Quote already closed and followed by FROM: nothing to complete.
    let closed_quote = r#"SELECT "Customer Id" FROM"#;
    assert_eq!(extract_partial_at_end(closed_quote), None);
}
4117
/// A single parenthesised `OR` group must parse into one condition whose
/// expression is a binary `OR`.
#[test]
fn test_complex_where_parentheses_basic() {
    let sql = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 1);

    match &filter.conditions[0].expr {
        SqlExpression::BinaryOp { op, .. } => assert_eq!(op, "OR"),
        _ => panic!("Expected BinaryOp with OR"),
    }
}
4137
/// `(a OR b) AND c` must yield two conditions: the grouped `OR` expression
/// connected by `AND` to the `>` comparison.
#[test]
fn test_complex_where_mixed_and_or_with_parens() {
    let sql =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 2);

    // First condition: the parenthesised OR group.
    match &filter.conditions[0].expr {
        SqlExpression::BinaryOp { op, .. } => assert_eq!(op, "OR"),
        _ => panic!("Expected first condition to be OR expression"),
    }

    // The first condition links to the second one with AND.
    assert!(matches!(
        filter.conditions[0].connector,
        Some(LogicalOp::And)
    ));

    // Second condition: the numeric comparison.
    match &filter.conditions[1].expr {
        SqlExpression::BinaryOp { op, .. } => assert_eq!(op, ">"),
        _ => panic!("Expected second condition to be price > 100"),
    }
}
4170
/// Doubly nested parentheses followed by a trailing `OR` must parse and
/// produce at least one WHERE condition (the exact shape is not asserted).
#[test]
fn test_complex_where_nested_parentheses() {
    let sql = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();

    assert!(!filter.conditions.is_empty());
}
4185
/// Two parenthesised groups joined by `AND` must yield two conditions, the
/// first carrying an `AND` connector.
#[test]
fn test_complex_where_multiple_or_groups() {
    let sql = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 2);

    assert!(matches!(
        filter.conditions[0].connector,
        Some(LogicalOp::And)
    ));
}
4204
/// Method calls inside a parenthesised `OR` group must survive as
/// `MethodCall` operands on both sides of the `OR`.
#[test]
fn test_complex_where_with_methods_in_parens() {
    let sql = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 2);

    match &filter.conditions[0].expr {
        SqlExpression::BinaryOp { op, left, right } => {
            assert_eq!(op, "OR");
            assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
            assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
        }
        _ => panic!("Expected OR of method calls"),
    }
}
4226
/// A grouped `DateTime(...)` range joined by `AND` to a grouped status check
/// must yield two conditions, the first carrying an `AND` connector.
#[test]
fn test_complex_where_date_comparisons_with_parens() {
    let sql = r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 2);

    assert!(matches!(
        filter.conditions[0].connector,
        Some(LogicalOp::And)
    ));
}
4245
/// Nested numeric range groups combined with `OR` and `AND` must parse and
/// produce at least one WHERE condition (the exact shape is not asserted).
#[test]
fn test_complex_where_price_volume_filters() {
    let sql = r"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000";
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();

    assert!(!filter.conditions.is_empty());
}
4260
/// Smoke test: string comparisons, a numeric comparison and a method call
/// mixed across two groups must parse and produce a WHERE clause.
#[test]
fn test_complex_where_mixed_string_numeric() {
    let sql = r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
}
4272
/// Smoke test: several nested groups combined with `AND`/`OR` and a method
/// call must parse and produce a WHERE clause.
#[test]
fn test_complex_where_triple_nested() {
    let sql = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
}
4284
/// A single pair of parentheses wrapping an `AND` chain must parse and
/// produce at least one WHERE condition (the exact shape is not asserted).
#[test]
fn test_complex_where_single_parens_around_and() {
    let sql =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#;
    let mut parser = Parser::new(sql);
    let stmt = parser.parse().unwrap();

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();

    assert!(!filter.conditions.is_empty());
}
4299
/// Pretty-formatting a query with one parenthesised group must keep the group
/// delimiters attached to their operands and keep the total paren count.
#[test]
fn test_format_preserves_simple_parentheses() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    assert!(formatted_text.contains("(status"));
    assert!(formatted_text.contains("\"pending\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(
        original_parens, formatted_parens,
        "Parentheses should be preserved"
    );
}
4322
/// Pretty-formatting `(a OR b) AND c` must keep the group delimiters attached
/// to their operands and keep the total paren count.
#[test]
fn test_format_preserves_complex_parentheses() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    assert!(formatted_text.contains("(symbol"));
    assert!(formatted_text.contains("\"GOOGL\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(original_parens, formatted_parens);
}
4342
/// Pretty-formatting a query with nested groups must keep all four
/// parenthesis characters.
#[test]
fn test_format_preserves_nested_parentheses() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(
        original_parens, formatted_parens,
        "Nested parentheses should be preserved"
    );
    // Sanity check on the fixture itself.
    assert_eq!(original_parens, 4, "Should have 4 parentheses total");
}
4361
/// Pretty-formatting must keep method-call parentheses intact inside a
/// parenthesised group, and keep the total paren count.
#[test]
fn test_format_preserves_method_calls_in_parentheses() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    assert!(formatted_text.contains("(symbol.StartsWith"));
    assert!(formatted_text.contains("StartsWith(\"A\")"));
    assert!(formatted_text.contains("StartsWith(\"G\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(
        original_parens, 6,
        "Should have 6 parentheses (1 group + 2 method calls)"
    );
}
4385
/// Pretty-formatting two sibling groups must keep both opening delimiters
/// attached to their first operand and keep the total paren count.
#[test]
fn test_format_preserves_multiple_groups() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    assert!(formatted_text.contains("(symbol"));
    assert!(formatted_text.contains("(price"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
}
4404
/// Pretty-formatting must reproduce `DateTime(y, m, d)` calls verbatim inside
/// a parenthesised range and keep the total paren count.
#[test]
fn test_format_preserves_date_ranges() {
    // Counts every '(' and ')' character in the given text.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))";
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    assert!(formatted_text.contains("(executionDate"));
    assert!(formatted_text.contains("DateTime(2024, 1, 1)"));
    assert!(formatted_text.contains("DateTime(2024, 12, 31)"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(formatted_text.as_str());
    assert_eq!(original_parens, formatted_parens);
}
4423
/// The pretty formatter must emit `SELECT`, the column list, `FROM ...` and
/// `WHERE` as the first four lines, followed by the condition lines with
/// `AND` on a line of its own.
#[test]
fn test_format_multiline_layout() {
    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let formatted = format_sql_pretty_compact(query, 5);

    assert!(formatted.len() >= 4, "Should have multiple lines");
    assert_eq!(formatted[0], "SELECT");
    assert!(formatted[1].trim().starts_with('*'));
    assert!(formatted[2].starts_with("FROM"));
    assert_eq!(formatted[3], "WHERE");

    // Everything after the WHERE keyword line belongs to the condition body.
    let has_group_line = formatted.iter().skip(4).any(|line| line.contains("(symbol"));
    let has_bare_and = formatted.iter().skip(4).any(|line| line.trim() == "AND");
    assert!(has_group_line);
    assert!(has_bare_and);
}
4443
/// A bare `BETWEEN ... AND ...` must parse into exactly one condition, and
/// the AST dump of the same query must not report a parse error.
#[test]
fn test_between_simple() {
    let sql = "SELECT * FROM table WHERE price BETWEEN 50 AND 100";

    let mut parser = Parser::new(sql);
    let stmt = parser.parse().expect("Should parse simple BETWEEN");

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 1);

    let ast = format_ast_tree(sql);
    assert!(!ast.contains("PARSE ERROR"));
    assert!(ast.contains("SelectStatement"));
}
4458
/// `BETWEEN` wrapped in parentheses must parse, and the AST dump of the same
/// query must not report a parse error.
#[test]
fn test_between_in_parentheses() {
    let sql = "SELECT * FROM table WHERE (price BETWEEN 50 AND 100)";

    let mut parser = Parser::new(sql);
    let stmt = parser.parse().expect("Should parse BETWEEN in parentheses");
    assert!(stmt.where_clause.is_some());

    let ast = format_ast_tree(sql);
    assert!(!ast.contains("PARSE ERROR"), "Should not have parse error");
}
4470
/// A parenthesised `BETWEEN` combined with another group via `OR` must parse
/// and produce a WHERE clause.
#[test]
fn test_between_with_or() {
    let sql = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";

    let mut parser = Parser::new(sql);
    let stmt = parser.parse().expect("Should parse BETWEEN with OR");

    assert!(stmt.where_clause.is_some());
}
4481
/// The `AND` inside `BETWEEN x AND y` must not be mistaken for a logical
/// connector: an equality plus a BETWEEN yields exactly two conditions.
#[test]
fn test_between_with_and() {
    let sql = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";

    let mut parser = Parser::new(sql);
    let stmt = parser.parse().expect("Should parse BETWEEN with AND");

    assert!(stmt.where_clause.is_some());
    let filter = stmt.where_clause.unwrap();
    assert_eq!(filter.conditions.len(), 2);
}
4492
/// Two parenthesised `BETWEEN` clauses joined by `AND` must parse and produce
/// a WHERE clause.
#[test]
fn test_multiple_between() {
    let sql =
        "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";

    let mut parser = Parser::new(sql);
    let stmt = parser
        .parse()
        .expect("Should parse multiple BETWEEN clauses");

    assert!(stmt.where_clause.is_some());
}
4504
/// `BETWEEN`, a method call and a two-key `ORDER BY` must parse together,
/// with both sort keys and their directions preserved in order.
#[test]
fn test_between_complex_query() {
    let sql = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";

    let mut parser = Parser::new(sql);
    let stmt = parser
        .parse()
        .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

    assert!(stmt.where_clause.is_some());
    assert!(stmt.order_by.is_some());

    let sort_keys = stmt.order_by.unwrap();
    assert_eq!(sort_keys.len(), 2);

    let primary = &sort_keys[0];
    assert_eq!(primary.column, "Category");
    assert!(matches!(primary.direction, SortDirection::Asc));

    let secondary = &sort_keys[1];
    assert_eq!(secondary.column, "price");
    assert!(matches!(secondary.direction, SortDirection::Desc));
}
4524
/// A hand-built `Between` node must format to `price BETWEEN 50 AND 100`, and
/// its AST rendering must mention the node kind and both bounds.
#[test]
fn test_between_formatting() {
    // Builds a boxed numeric literal node from its textual digits.
    let number = |digits: &str| Box::new(SqlExpression::NumberLiteral(digits.to_string()));

    let between = SqlExpression::Between {
        expr: Box::new(SqlExpression::Column("price".to_string())),
        lower: number("50"),
        upper: number("100"),
    };

    assert_eq!(format_expression(&between), "price BETWEEN 50 AND 100");

    let ast_formatted = format_expression_ast(&between);
    assert!(ast_formatted.contains("Between"));
    assert!(ast_formatted.contains("50"));
    assert!(ast_formatted.contains("100"));
}
4541
/// `detect_cursor_context` must never panic, whatever byte offset it is given:
/// every offset up to the query length, an offset far past the end, and an
/// offset that falls inside the multi-byte character 'é'.
#[test]
fn test_utf8_boundary_safety() {
    let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

    // True when context detection at `cursor` finishes without panicking.
    let survives = |cursor: usize| {
        std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, cursor)).is_ok()
    };

    // Sweep every byte offset, including ones that are not char boundaries.
    for pos in 0..=query_with_unicode.len() {
        assert!(
            survives(pos),
            "Panic at position {pos} in query with Unicode"
        );
    }

    assert!(survives(1000), "Panic with position beyond string length");

    // 'é' takes two bytes in UTF-8, so start + 1 lands inside the character.
    let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
    assert!(
        survives(pos_in_e),
        "Panic with cursor in middle of UTF-8 character"
    );
}
4572}