1use chrono::{Datelike, Local, NaiveDateTime};
2
/// A lexical token produced by [`Lexer`] from SQL-like query text.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords. The lexer matches them case-insensitively.
    Select,
    From,
    Where,
    With,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// `ORDER BY`, emitted as one token.
    OrderBy,
    /// `GROUP BY`, emitted as one token.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    Over,
    Partition,
    By,

    /// Bare word that is not a keyword (original casing preserved). Any
    /// character the lexer does not otherwise recognise is also emitted as a
    /// one-character identifier.
    Identifier(String),
    /// Contents of a double-quoted identifier, without the quotes.
    QuotedIdentifier(String),
    /// Contents of a single-quoted string, without the quotes.
    StringLiteral(String),
    /// Numeric literal kept as source text; carries a leading `-` when a minus
    /// sign was immediately followed by a numeric character.
    NumberLiteral(String),

    // Punctuation and comparison operators.
    Star,
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// Produced for both `<>` and `!=`.
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // Arithmetic operators (`*` is reported as [`Token::Star`]).
    Plus,
    Minus,
    Divide,
    Modulo,

    /// End of input.
    Eof,
}

/// Character-based tokenizer for SQL-like query text.
///
/// `position` indexes into `input` (counted in `char`s, not bytes) and
/// `current_char` always mirrors `input[position]`, or is `None` at the end.
#[derive(Debug, Clone)]
pub struct Lexer {
    input: Vec<char>,
    position: usize,
    current_char: Option<char>,
}

impl Lexer {
    /// Creates a lexer positioned on the first character of `input`.
    #[must_use]
    pub fn new(input: &str) -> Self {
        let input: Vec<char> = input.chars().collect();
        let current_char = input.first().copied();
        Self {
            input,
            position: 0,
            current_char,
        }
    }

    /// Moves to the next character.
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character `offset` places ahead without consuming anything.
    fn peek(&self, offset: usize) -> Option<char> {
        self.input.get(self.position + offset).copied()
    }

    /// Consumes the current character and returns `token`.
    fn single(&mut self, token: Token) -> Token {
        self.advance();
        token
    }

    /// Skips a run of whitespace characters.
    fn skip_whitespace(&mut self) {
        while self.current_char.is_some_and(char::is_whitespace) {
            self.advance();
        }
    }

    /// Skips whitespace, `-- line` comments and `/* block */` comments until
    /// the next significant character (or end of input) is reached.
    fn skip_whitespace_and_comments(&mut self) {
        loop {
            self.skip_whitespace();

            match (self.current_char, self.peek(1)) {
                (Some('-'), Some('-')) => {
                    // Line comment: consume up to and including the newline.
                    while let Some(ch) = self.current_char {
                        self.advance();
                        if ch == '\n' {
                            break;
                        }
                    }
                }
                (Some('/'), Some('*')) => {
                    self.advance();
                    self.advance();
                    // Block comment: an unterminated comment runs to end of input.
                    while let Some(ch) = self.current_char {
                        if ch == '*' && self.peek(1) == Some('/') {
                            self.advance();
                            self.advance();
                            break;
                        }
                        self.advance();
                    }
                }
                _ => break,
            }
        }
    }

    /// Reads a run of alphanumeric / `_` characters.
    fn read_identifier(&mut self) -> String {
        let mut result = String::new();
        while let Some(ch) = self.current_char {
            if !(ch.is_alphanumeric() || ch == '_') {
                break;
            }
            result.push(ch);
            self.advance();
        }
        result
    }

    /// Reads a quoted run starting at the current quote character and returns
    /// its contents without the delimiters.
    ///
    /// A doubled delimiter inside the literal (`''` or `""`) is the standard
    /// SQL escape and yields a single quote character. An unterminated
    /// literal extends to the end of input.
    fn read_string(&mut self) -> String {
        let mut result = String::new();
        let Some(quote_char) = self.current_char else {
            // Callers only invoke this on a quote; nothing to read otherwise.
            return result;
        };
        self.advance();

        while let Some(ch) = self.current_char {
            self.advance();
            if ch != quote_char {
                result.push(ch);
                continue;
            }
            if self.current_char == Some(quote_char) {
                // Doubled quote: keep one and continue inside the literal.
                result.push(quote_char);
                self.advance();
            } else {
                break;
            }
        }
        result
    }

    /// Reads digits and `.` characters, optionally followed by one exponent
    /// part (`e`/`E`, optional sign, digits). The text is returned unvalidated.
    fn read_number(&mut self) -> String {
        let mut result = String::new();

        while let Some(ch) = self.current_char {
            if ch.is_numeric() || ch == '.' {
                result.push(ch);
                self.advance();
            } else if matches!(ch, 'e' | 'E') && !result.is_empty() {
                result.push(ch);
                self.advance();

                if let Some(sign @ ('+' | '-')) = self.current_char {
                    result.push(sign);
                    self.advance();
                }
                while let Some(digit) = self.current_char.filter(|c| c.is_numeric()) {
                    result.push(digit);
                    self.advance();
                }
                // The exponent always terminates the literal.
                break;
            } else {
                break;
            }
        }
        result
    }

    /// Returns the next token, skipping leading whitespace and comments.
    /// Yields [`Token::Eof`] once the input is exhausted.
    pub fn next_token(&mut self) -> Token {
        self.skip_whitespace_and_comments();

        let Some(ch) = self.current_char else {
            return Token::Eof;
        };

        match ch {
            '*' => self.single(Token::Star),
            '+' => self.single(Token::Plus),
            // Comment openers were already skipped above, so `/` and a lone
            // `-` here are always operators.
            '/' => self.single(Token::Divide),
            '%' => self.single(Token::Modulo),
            '.' => self.single(Token::Dot),
            ',' => self.single(Token::Comma),
            '(' => self.single(Token::LeftParen),
            ')' => self.single(Token::RightParen),
            '=' => self.single(Token::Equal),
            '<' => {
                self.advance();
                match self.current_char {
                    Some('=') => self.single(Token::LessThanOrEqual),
                    Some('>') => self.single(Token::NotEqual),
                    _ => Token::LessThan,
                }
            }
            '>' => {
                self.advance();
                match self.current_char {
                    Some('=') => self.single(Token::GreaterThanOrEqual),
                    _ => Token::GreaterThan,
                }
            }
            '!' if self.peek(1) == Some('=') => {
                self.advance();
                self.single(Token::NotEqual)
            }
            '"' => Token::QuotedIdentifier(self.read_string()),
            '\'' => Token::StringLiteral(self.read_string()),
            '-' if self.peek(1).is_some_and(char::is_numeric) => {
                self.advance();
                let num = self.read_number();
                Token::NumberLiteral(format!("-{num}"))
            }
            '-' => self.single(Token::Minus),
            c if c.is_numeric() => Token::NumberLiteral(self.read_number()),
            c if c.is_alphabetic() || c == '_' => self.keyword_or_identifier(),
            // Anything unrecognised is passed on as a one-character identifier.
            other => self.single(Token::Identifier(other.to_string())),
        }
    }

    /// Reads a word and maps it to a keyword token (case-insensitively) or an
    /// [`Token::Identifier`] carrying the original text.
    fn keyword_or_identifier(&mut self) -> Token {
        let ident = self.read_identifier();
        match ident.to_uppercase().as_str() {
            "SELECT" => Token::Select,
            "FROM" => Token::From,
            "WHERE" => Token::Where,
            "WITH" => Token::With,
            "AND" => Token::And,
            "OR" => Token::Or,
            "IN" => Token::In,
            "NOT" => Token::Not,
            "BETWEEN" => Token::Between,
            "LIKE" => Token::Like,
            "IS" => Token::Is,
            "NULL" => Token::Null,
            "ORDER" if self.peek_keyword("BY") => {
                self.consume_following_word();
                Token::OrderBy
            }
            "GROUP" if self.peek_keyword("BY") => {
                self.consume_following_word();
                Token::GroupBy
            }
            "HAVING" => Token::Having,
            "AS" => Token::As,
            "ASC" => Token::Asc,
            "DESC" => Token::Desc,
            "LIMIT" => Token::Limit,
            "OFFSET" => Token::Offset,
            "DATETIME" => Token::DateTime,
            "CASE" => Token::Case,
            "WHEN" => Token::When,
            "THEN" => Token::Then,
            "ELSE" => Token::Else,
            "END" => Token::End,
            "DISTINCT" => Token::Distinct,
            "OVER" => Token::Over,
            "PARTITION" => Token::Partition,
            "BY" => Token::By,
            _ => Token::Identifier(ident),
        }
    }

    /// Consumes the word that `peek_keyword` just matched.
    ///
    /// Uses the same skipping rules as `peek_keyword` (whitespace *and*
    /// comments) so that `ORDER /* note */ BY` swallows the `BY` instead of
    /// leaving it behind as a separate token.
    fn consume_following_word(&mut self) {
        self.skip_whitespace_and_comments();
        self.read_identifier();
    }

    /// Reports whether the next word (after whitespace/comments) equals
    /// `keyword`, compared in upper case. The lexer state is restored before
    /// returning, so nothing is consumed.
    fn peek_keyword(&mut self, keyword: &str) -> bool {
        let saved_pos = self.position;
        let saved_char = self.current_char;

        self.skip_whitespace_and_comments();
        let matches = self.read_identifier().to_uppercase() == keyword;

        self.position = saved_pos;
        self.current_char = saved_char;
        matches
    }

    /// Current position in the input, counted in characters.
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.position
    }

    /// Tokenizes the remaining input; the result always ends with
    /// [`Token::Eof`].
    pub fn tokenize_all(&mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let finished = token == Token::Eof;
            tokens.push(token);
            if finished {
                return tokens;
            }
        }
    }

    /// Tokenizes the remaining input as `(start, end, token)` triples, where
    /// `start..end` is the character range of the token. The final
    /// [`Token::Eof`] is not included.
    pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
        let mut tokens = Vec::new();
        loop {
            // Skip first so `start` points at the token, not at leading trivia.
            self.skip_whitespace_and_comments();
            let start = self.position;
            let token = self.next_token();
            if token == Token::Eof {
                return tokens;
            }
            tokens.push((start, self.position, token));
        }
    }
}
424
425#[derive(Debug, Clone)]
427pub enum SqlExpression {
428 Column(String),
429 StringLiteral(String),
430 NumberLiteral(String),
431 BooleanLiteral(bool),
432 Null, DateTimeConstructor {
434 year: i32,
435 month: u32,
436 day: u32,
437 hour: Option<u32>,
438 minute: Option<u32>,
439 second: Option<u32>,
440 },
441 DateTimeToday {
442 hour: Option<u32>,
443 minute: Option<u32>,
444 second: Option<u32>,
445 },
446 MethodCall {
447 object: String,
448 method: String,
449 args: Vec<SqlExpression>,
450 },
451 ChainedMethodCall {
452 base: Box<SqlExpression>,
453 method: String,
454 args: Vec<SqlExpression>,
455 },
456 FunctionCall {
457 name: String,
458 args: Vec<SqlExpression>,
459 distinct: bool, },
461 WindowFunction {
462 name: String,
463 args: Vec<SqlExpression>,
464 window_spec: WindowSpec,
465 },
466 BinaryOp {
467 left: Box<SqlExpression>,
468 op: String,
469 right: Box<SqlExpression>,
470 },
471 InList {
472 expr: Box<SqlExpression>,
473 values: Vec<SqlExpression>,
474 },
475 NotInList {
476 expr: Box<SqlExpression>,
477 values: Vec<SqlExpression>,
478 },
479 Between {
480 expr: Box<SqlExpression>,
481 lower: Box<SqlExpression>,
482 upper: Box<SqlExpression>,
483 },
484 Not {
485 expr: Box<SqlExpression>,
486 },
487 CaseExpression {
488 when_branches: Vec<WhenBranch>,
489 else_branch: Option<Box<SqlExpression>>,
490 },
491}
492
493#[derive(Debug, Clone)]
494pub struct WhenBranch {
495 pub condition: Box<SqlExpression>,
496 pub result: Box<SqlExpression>,
497}
498
499#[derive(Debug, Clone)]
500pub struct WhereClause {
501 pub conditions: Vec<Condition>,
502}
503
504#[derive(Debug, Clone)]
505pub struct Condition {
506 pub expr: SqlExpression,
507 pub connector: Option<LogicalOp>, }
509
/// Logical connector between two consecutive `WHERE` conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
515
/// Sort order requested for an `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
521
522#[derive(Debug, Clone)]
523pub struct OrderByColumn {
524 pub column: String,
525 pub direction: SortDirection,
526}
527
528#[derive(Debug, Clone)]
529pub struct WindowSpec {
530 pub partition_by: Vec<String>,
531 pub order_by: Vec<OrderByColumn>,
532}
533
534#[derive(Debug, Clone)]
536pub enum SelectItem {
537 Column(String),
539 Expression { expr: SqlExpression, alias: String },
541 Star,
543}
544
545#[derive(Debug, Clone)]
546pub struct SelectStatement {
547 pub distinct: bool, pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
551 pub from_subquery: Option<Box<SelectStatement>>, pub from_function: Option<TableFunction>, pub from_alias: Option<String>, pub where_clause: Option<WhereClause>,
555 pub order_by: Option<Vec<OrderByColumn>>,
556 pub group_by: Option<Vec<String>>,
557 pub having: Option<SqlExpression>, pub limit: Option<usize>,
559 pub offset: Option<usize>,
560 pub ctes: Vec<CTE>, }
562
563#[derive(Debug, Clone)]
565pub enum TableFunction {
566 Range {
567 start: SqlExpression,
568 end: SqlExpression,
569 step: Option<SqlExpression>,
570 },
571}
572
573#[derive(Debug, Clone)]
575pub struct CTE {
576 pub name: String,
577 pub column_list: Option<Vec<String>>, pub query: SelectStatement,
579}
580
581#[derive(Debug, Clone)]
583pub enum TableSource {
584 Table(String), DerivedTable {
586 query: Box<SelectStatement>,
588 alias: String, },
590}
591
/// Options controlling parser behaviour.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused
/// across several parsers.
#[derive(Debug, Clone, Default)]
pub struct ParserConfig {
    /// Case-insensitivity switch; defaults to `false`.
    /// (NOTE(review): what exactly is compared case-insensitively is not
    /// visible from this definition; confirm against the code that reads it.)
    pub case_insensitive: bool,
}
596
597pub struct Parser {
598 lexer: Lexer,
599 current_token: Token,
600 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
604 config: ParserConfig, }
606
607impl Parser {
608 #[must_use]
609 pub fn new(input: &str) -> Self {
610 let mut lexer = Lexer::new(input);
611 let current_token = lexer.next_token();
612 Self {
613 lexer,
614 current_token,
615 in_method_args: false,
616 columns: Vec::new(),
617 paren_depth: 0,
618 config: ParserConfig::default(),
619 }
620 }
621
622 #[must_use]
623 pub fn with_config(input: &str, config: ParserConfig) -> Self {
624 let mut lexer = Lexer::new(input);
625 let current_token = lexer.next_token();
626 Self {
627 lexer,
628 current_token,
629 in_method_args: false,
630 columns: Vec::new(),
631 paren_depth: 0,
632 config,
633 }
634 }
635
636 #[must_use]
637 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
638 self.columns = columns;
639 self
640 }
641
642 fn consume(&mut self, expected: Token) -> Result<(), String> {
643 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
644 match &expected {
646 Token::LeftParen => self.paren_depth += 1,
647 Token::RightParen => {
648 self.paren_depth -= 1;
649 if self.paren_depth < 0 {
651 return Err(
652 "Unexpected closing parenthesis - no matching opening parenthesis"
653 .to_string(),
654 );
655 }
656 }
657 _ => {}
658 }
659
660 self.current_token = self.lexer.next_token();
661 Ok(())
662 } else {
663 let error_msg = match (&expected, &self.current_token) {
665 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
666 format!(
667 "Unclosed parenthesis - missing {} closing parenthes{}",
668 self.paren_depth,
669 if self.paren_depth == 1 { "is" } else { "es" }
670 )
671 }
672 (Token::RightParen, _) if self.paren_depth > 0 => {
673 format!(
674 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
675 self.current_token,
676 self.paren_depth,
677 if self.paren_depth == 1 { "is" } else { "es" }
678 )
679 }
680 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
681 };
682 Err(error_msg)
683 }
684 }
685
686 fn advance(&mut self) {
687 match &self.current_token {
689 Token::LeftParen => self.paren_depth += 1,
690 Token::RightParen => {
691 self.paren_depth -= 1;
692 }
695 _ => {}
696 }
697 self.current_token = self.lexer.next_token();
698 }
699
700 pub fn parse(&mut self) -> Result<SelectStatement, String> {
701 if matches!(self.current_token, Token::With) {
703 self.parse_with_clause()
704 } else {
705 self.parse_select_statement()
706 }
707 }
708
709 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
710 self.consume(Token::With)?;
711
712 let mut ctes = Vec::new();
713
714 loop {
716 let name = match &self.current_token {
718 Token::Identifier(name) => name.clone(),
719 _ => return Err("Expected CTE name after WITH".to_string()),
720 };
721 self.advance();
722
723 let column_list = if matches!(self.current_token, Token::LeftParen) {
725 self.advance();
726 let cols = self.parse_identifier_list()?;
727 self.consume(Token::RightParen)?;
728 Some(cols)
729 } else {
730 None
731 };
732
733 self.consume(Token::As)?;
735
736 self.consume(Token::LeftParen)?;
738
739 let query = self.parse_select_statement_inner()?;
741
742 self.consume(Token::RightParen)?;
744
745 ctes.push(CTE {
746 name,
747 column_list,
748 query,
749 });
750
751 if !matches!(self.current_token, Token::Comma) {
753 break;
754 }
755 self.advance();
756 }
757
758 let mut main_query = self.parse_select_statement()?;
760 main_query.ctes = ctes;
761
762 Ok(main_query)
763 }
764
765 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
766 let result = self.parse_select_statement_inner()?;
767
768 if self.paren_depth > 0 {
770 return Err(format!(
771 "Unclosed parenthesis - missing {} closing parenthes{}",
772 self.paren_depth,
773 if self.paren_depth == 1 { "is" } else { "es" }
774 ));
775 } else if self.paren_depth < 0 {
776 return Err(
777 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
778 );
779 }
780
781 Ok(result)
782 }
783
784 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
785 self.consume(Token::Select)?;
786
787 let distinct = if matches!(self.current_token, Token::Distinct) {
789 self.advance();
790 true
791 } else {
792 false
793 };
794
795 let select_items = self.parse_select_items()?;
797
798 let columns = select_items
800 .iter()
801 .map(|item| match item {
802 SelectItem::Star => "*".to_string(),
803 SelectItem::Column(name) => name.clone(),
804 SelectItem::Expression { alias, .. } => alias.clone(),
805 })
806 .collect();
807
808 let (from_table, from_subquery, from_function, from_alias) =
810 if matches!(self.current_token, Token::From) {
811 self.advance();
812
813 if let Token::Identifier(name) = &self.current_token.clone() {
815 if name.to_uppercase() == "RANGE" {
816 self.advance();
817 self.consume(Token::LeftParen)?;
819
820 let start = self.parse_expression()?;
822 self.consume(Token::Comma)?;
823
824 let end = self.parse_expression()?;
826
827 let step = if matches!(self.current_token, Token::Comma) {
829 self.advance();
830 Some(self.parse_expression()?)
831 } else {
832 None
833 };
834
835 self.consume(Token::RightParen)?;
836
837 let alias = if matches!(self.current_token, Token::As) {
839 self.advance();
840 match &self.current_token {
841 Token::Identifier(name) => {
842 let alias = name.clone();
843 self.advance();
844 Some(alias)
845 }
846 _ => return Err("Expected alias name after AS".to_string()),
847 }
848 } else if let Token::Identifier(name) = &self.current_token {
849 let alias = name.clone();
850 self.advance();
851 Some(alias)
852 } else {
853 None
854 };
855
856 (
857 None,
858 None,
859 Some(TableFunction::Range { start, end, step }),
860 alias,
861 )
862 } else {
863 let table_name = name.clone();
865 self.advance();
866
867 let alias = if matches!(self.current_token, Token::As) {
869 self.advance();
870 match &self.current_token {
871 Token::Identifier(name) => {
872 let alias = name.clone();
873 self.advance();
874 Some(alias)
875 }
876 _ => return Err("Expected alias name after AS".to_string()),
877 }
878 } else if let Token::Identifier(name) = &self.current_token {
879 let alias = name.clone();
881 self.advance();
882 Some(alias)
883 } else {
884 None
885 };
886
887 (Some(table_name), None, None, alias)
888 }
889 } else if matches!(self.current_token, Token::LeftParen) {
890 self.advance();
892
893 let subquery = self.parse_select_statement_inner()?;
895
896 self.consume(Token::RightParen)?;
897
898 let alias = if matches!(self.current_token, Token::As) {
900 self.advance();
901 match &self.current_token {
902 Token::Identifier(name) => {
903 let alias = name.clone();
904 self.advance();
905 alias
906 }
907 _ => return Err("Expected alias name after AS".to_string()),
908 }
909 } else {
910 match &self.current_token {
912 Token::Identifier(name) => {
913 let alias = name.clone();
914 self.advance();
915 alias
916 }
917 _ => {
918 return Err(
919 "Subquery in FROM must have an alias (e.g., AS t)".to_string()
920 )
921 }
922 }
923 };
924
925 (None, Some(Box::new(subquery)), None, Some(alias))
926 } else {
927 match &self.current_token {
929 Token::Identifier(table) => {
930 let table_name = table.clone();
931 self.advance();
932
933 let alias = if matches!(self.current_token, Token::As) {
935 self.advance();
936 match &self.current_token {
937 Token::Identifier(name) => {
938 let alias = name.clone();
939 self.advance();
940 Some(alias)
941 }
942 _ => return Err("Expected alias name after AS".to_string()),
943 }
944 } else if let Token::Identifier(name) = &self.current_token {
945 let alias = name.clone();
947 self.advance();
948 Some(alias)
949 } else {
950 None
951 };
952
953 (Some(table_name), None, None, alias)
954 }
955 Token::QuotedIdentifier(table) => {
956 let table_name = table.clone();
958 self.advance();
959
960 let alias = if matches!(self.current_token, Token::As) {
962 self.advance();
963 match &self.current_token {
964 Token::Identifier(name) => {
965 let alias = name.clone();
966 self.advance();
967 Some(alias)
968 }
969 _ => return Err("Expected alias name after AS".to_string()),
970 }
971 } else if let Token::Identifier(name) = &self.current_token {
972 let alias = name.clone();
974 self.advance();
975 Some(alias)
976 } else {
977 None
978 };
979
980 (Some(table_name), None, None, alias)
981 }
982 _ => return Err("Expected table name or subquery after FROM".to_string()),
983 }
984 }
985 } else {
986 (None, None, None, None)
987 };
988
989 let where_clause = if matches!(self.current_token, Token::Where) {
990 self.advance();
991 Some(self.parse_where_clause()?)
992 } else {
993 None
994 };
995
996 let order_by = if matches!(self.current_token, Token::OrderBy) {
997 self.advance();
998 Some(self.parse_order_by_list()?)
999 } else {
1000 None
1001 };
1002
1003 let group_by = if matches!(self.current_token, Token::GroupBy) {
1004 self.advance();
1005 Some(self.parse_identifier_list()?)
1006 } else {
1007 None
1008 };
1009
1010 let having = if matches!(self.current_token, Token::Having) {
1012 if group_by.is_none() {
1013 return Err("HAVING clause requires GROUP BY".to_string());
1014 }
1015 self.advance();
1016 Some(self.parse_expression()?)
1017 } else {
1018 None
1019 };
1020
1021 let limit = if matches!(self.current_token, Token::Limit) {
1023 self.advance();
1024 match &self.current_token {
1025 Token::NumberLiteral(num) => {
1026 let limit_val = num
1027 .parse::<usize>()
1028 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
1029 self.advance();
1030 Some(limit_val)
1031 }
1032 _ => return Err("Expected number after LIMIT".to_string()),
1033 }
1034 } else {
1035 None
1036 };
1037
1038 let offset = if matches!(self.current_token, Token::Offset) {
1040 self.advance();
1041 match &self.current_token {
1042 Token::NumberLiteral(num) => {
1043 let offset_val = num
1044 .parse::<usize>()
1045 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
1046 self.advance();
1047 Some(offset_val)
1048 }
1049 _ => return Err("Expected number after OFFSET".to_string()),
1050 }
1051 } else {
1052 None
1053 };
1054
1055 Ok(SelectStatement {
1056 distinct,
1057 columns,
1058 select_items,
1059 from_table,
1060 from_subquery,
1061 from_function,
1062 from_alias,
1063 where_clause,
1064 order_by,
1065 group_by,
1066 having,
1067 limit,
1068 offset,
1069 ctes: Vec::new(), })
1071 }
1072
1073 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
1074 let mut columns = Vec::new();
1075
1076 if matches!(self.current_token, Token::Star) {
1077 columns.push("*".to_string());
1078 self.advance();
1079 } else {
1080 loop {
1081 match &self.current_token {
1082 Token::Identifier(col) => {
1083 columns.push(col.clone());
1084 self.advance();
1085 }
1086 Token::QuotedIdentifier(col) => {
1087 columns.push(col.clone());
1089 self.advance();
1090 }
1091 _ => return Err("Expected column name".to_string()),
1092 }
1093
1094 if matches!(self.current_token, Token::Comma) {
1095 self.advance();
1096 } else {
1097 break;
1098 }
1099 }
1100 }
1101
1102 Ok(columns)
1103 }
1104
1105 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
1107 let mut items = Vec::new();
1108
1109 loop {
1110 if matches!(self.current_token, Token::Star) {
1113 items.push(SelectItem::Star);
1121 self.advance();
1122 } else {
1123 let expr = self.parse_comparison()?; let alias = if matches!(self.current_token, Token::As) {
1128 self.advance();
1129 match &self.current_token {
1130 Token::Identifier(alias_name) => {
1131 let alias = alias_name.clone();
1132 self.advance();
1133 alias
1134 }
1135 Token::QuotedIdentifier(alias_name) => {
1136 let alias = alias_name.clone();
1137 self.advance();
1138 alias
1139 }
1140 _ => return Err("Expected alias name after AS".to_string()),
1141 }
1142 } else {
1143 match &expr {
1145 SqlExpression::Column(col_name) => col_name.clone(),
1146 _ => format!("expr_{}", items.len() + 1), }
1148 };
1149
1150 let item = match expr {
1152 SqlExpression::Column(col_name) if alias == col_name => {
1153 SelectItem::Column(col_name)
1155 }
1156 _ => {
1157 SelectItem::Expression { expr, alias }
1159 }
1160 };
1161
1162 items.push(item);
1163 }
1164
1165 if matches!(self.current_token, Token::Comma) {
1167 self.advance();
1168 } else {
1169 break;
1170 }
1171 }
1172
1173 Ok(items)
1174 }
1175
1176 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1177 let mut identifiers = Vec::new();
1178
1179 loop {
1180 match &self.current_token {
1181 Token::Identifier(id) => {
1182 identifiers.push(id.clone());
1183 self.advance();
1184 }
1185 Token::QuotedIdentifier(id) => {
1186 identifiers.push(id.clone());
1188 self.advance();
1189 }
1190 _ => return Err("Expected identifier".to_string()),
1191 }
1192
1193 if matches!(self.current_token, Token::Comma) {
1194 self.advance();
1195 } else {
1196 break;
1197 }
1198 }
1199
1200 Ok(identifiers)
1201 }
1202
1203 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1204 let mut partition_by = Vec::new();
1205 let mut order_by = Vec::new();
1206
1207 if matches!(self.current_token, Token::Partition) {
1209 self.advance(); if !matches!(self.current_token, Token::By) {
1211 return Err("Expected BY after PARTITION".to_string());
1212 }
1213 self.advance(); partition_by = self.parse_identifier_list()?;
1217 }
1218
1219 if matches!(self.current_token, Token::OrderBy) {
1221 self.advance(); order_by = self.parse_order_by_list()?;
1223 } else if let Token::Identifier(s) = &self.current_token {
1224 if s.to_uppercase() == "ORDER" {
1225 self.advance(); if !matches!(self.current_token, Token::By) {
1228 return Err("Expected BY after ORDER".to_string());
1229 }
1230 self.advance(); order_by = self.parse_order_by_list()?;
1232 }
1233 }
1234
1235 Ok(WindowSpec {
1236 partition_by,
1237 order_by,
1238 })
1239 }
1240
1241 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
1242 let mut order_columns = Vec::new();
1243
1244 loop {
1245 let column = match &self.current_token {
1246 Token::Identifier(id) => {
1247 let col = id.clone();
1248 self.advance();
1249 col
1250 }
1251 Token::QuotedIdentifier(id) => {
1252 let col = id.clone();
1253 self.advance();
1254 col
1255 }
1256 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
1257 let col = num.clone();
1259 self.advance();
1260 col
1261 }
1262 _ => return Err("Expected column name in ORDER BY".to_string()),
1263 };
1264
1265 let direction = match &self.current_token {
1267 Token::Asc => {
1268 self.advance();
1269 SortDirection::Asc
1270 }
1271 Token::Desc => {
1272 self.advance();
1273 SortDirection::Desc
1274 }
1275 _ => SortDirection::Asc, };
1277
1278 order_columns.push(OrderByColumn { column, direction });
1279
1280 if matches!(self.current_token, Token::Comma) {
1281 self.advance();
1282 } else {
1283 break;
1284 }
1285 }
1286
1287 Ok(order_columns)
1288 }
1289
1290 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1291 let mut conditions = Vec::new();
1292
1293 loop {
1294 let expr = self.parse_expression()?;
1295
1296 let connector = match &self.current_token {
1297 Token::And => {
1298 self.advance();
1299 Some(LogicalOp::And)
1300 }
1301 Token::Or => {
1302 self.advance();
1303 Some(LogicalOp::Or)
1304 }
1305 Token::RightParen if self.paren_depth <= 0 => {
1306 return Err(
1308 "Unexpected closing parenthesis - no matching opening parenthesis"
1309 .to_string(),
1310 );
1311 }
1312 _ => None,
1313 };
1314
1315 conditions.push(Condition {
1316 expr,
1317 connector: connector.clone(),
1318 });
1319
1320 if connector.is_none() {
1321 break;
1322 }
1323 }
1324
1325 Ok(WhereClause { conditions })
1326 }
1327
1328 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1329 let mut left = self.parse_comparison()?;
1330
1331 if let Some(op) = self.get_binary_op() {
1334 self.advance();
1335 let right = self.parse_expression()?;
1336 left = SqlExpression::BinaryOp {
1337 left: Box::new(left),
1338 op,
1339 right: Box::new(right),
1340 };
1341 }
1342
1343 if matches!(self.current_token, Token::In) {
1345 self.advance();
1346 self.consume(Token::LeftParen)?;
1347 let values = self.parse_expression_list()?;
1348 self.consume(Token::RightParen)?;
1349
1350 left = SqlExpression::InList {
1351 expr: Box::new(left),
1352 values,
1353 };
1354 }
1355
1356 Ok(left)
1360 }
1361
1362 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1363 let mut left = self.parse_additive()?;
1364
1365 if matches!(self.current_token, Token::Between) {
1367 self.advance(); let lower = self.parse_primary()?;
1369 self.consume(Token::And)?; let upper = self.parse_primary()?;
1371
1372 return Ok(SqlExpression::Between {
1373 expr: Box::new(left),
1374 lower: Box::new(lower),
1375 upper: Box::new(upper),
1376 });
1377 }
1378
1379 if matches!(self.current_token, Token::Not) {
1381 self.advance(); if matches!(self.current_token, Token::In) {
1383 self.advance(); self.consume(Token::LeftParen)?;
1385 let values = self.parse_expression_list()?;
1386 self.consume(Token::RightParen)?;
1387
1388 return Ok(SqlExpression::NotInList {
1389 expr: Box::new(left),
1390 values,
1391 });
1392 }
1393 return Err("Expected IN after NOT".to_string());
1394 }
1395
1396 if matches!(self.current_token, Token::Is) {
1398 self.advance(); if matches!(self.current_token, Token::Not) {
1400 self.advance(); if matches!(self.current_token, Token::Null) {
1402 self.advance(); left = SqlExpression::BinaryOp {
1404 left: Box::new(left),
1405 op: "IS NOT NULL".to_string(),
1406 right: Box::new(SqlExpression::Null),
1407 };
1408 } else {
1409 return Err("Expected NULL after IS NOT".to_string());
1410 }
1411 } else if matches!(self.current_token, Token::Null) {
1412 self.advance(); left = SqlExpression::BinaryOp {
1414 left: Box::new(left),
1415 op: "IS NULL".to_string(),
1416 right: Box::new(SqlExpression::Null),
1417 };
1418 } else {
1419 return Err("Expected NULL or NOT after IS".to_string());
1420 }
1421 }
1422 else if let Some(op) = self.get_binary_op() {
1424 self.advance();
1425 let right = self.parse_additive()?;
1426 left = SqlExpression::BinaryOp {
1427 left: Box::new(left),
1428 op,
1429 right: Box::new(right),
1430 };
1431 }
1432
1433 Ok(left)
1434 }
1435
1436 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1437 let mut left = self.parse_multiplicative()?;
1438
1439 while matches!(self.current_token, Token::Plus | Token::Minus) {
1440 let op = match self.current_token {
1441 Token::Plus => "+",
1442 Token::Minus => "-",
1443 _ => unreachable!(),
1444 };
1445 self.advance();
1446 let right = self.parse_multiplicative()?;
1447 left = SqlExpression::BinaryOp {
1448 left: Box::new(left),
1449 op: op.to_string(),
1450 right: Box::new(right),
1451 };
1452 }
1453
1454 Ok(left)
1455 }
1456
1457 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1458 let mut left = self.parse_primary()?;
1459
1460 while matches!(self.current_token, Token::Dot) {
1462 self.advance();
1463 if let Token::Identifier(method) = &self.current_token {
1464 let method_name = method.clone();
1465 self.advance();
1466
1467 if matches!(self.current_token, Token::LeftParen) {
1468 self.advance();
1469 let args = self.parse_method_args()?;
1470 self.consume(Token::RightParen)?;
1471
1472 match left {
1474 SqlExpression::Column(obj) => {
1475 left = SqlExpression::MethodCall {
1477 object: obj,
1478 method: method_name,
1479 args,
1480 };
1481 }
1482 SqlExpression::MethodCall { .. }
1483 | SqlExpression::ChainedMethodCall { .. } => {
1484 left = SqlExpression::ChainedMethodCall {
1486 base: Box::new(left),
1487 method: method_name,
1488 args,
1489 };
1490 }
1491 _ => {
1492 left = SqlExpression::ChainedMethodCall {
1494 base: Box::new(left),
1495 method: method_name,
1496 args,
1497 };
1498 }
1499 }
1500 } else {
1501 return Err(format!("Expected '(' after method name '{method_name}'"));
1502 }
1503 } else {
1504 return Err("Expected method name after '.'".to_string());
1505 }
1506 }
1507
1508 while matches!(
1509 self.current_token,
1510 Token::Star | Token::Divide | Token::Modulo
1511 ) {
1512 let op = match self.current_token {
1513 Token::Star => "*",
1514 Token::Divide => "/",
1515 Token::Modulo => "%",
1516 _ => unreachable!(),
1517 };
1518 self.advance();
1519 let right = self.parse_primary()?;
1520 left = SqlExpression::BinaryOp {
1521 left: Box::new(left),
1522 op: op.to_string(),
1523 right: Box::new(right),
1524 };
1525 }
1526
1527 Ok(left)
1528 }
1529
1530 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1531 let mut left = self.parse_logical_and()?;
1532
1533 while matches!(self.current_token, Token::Or) {
1534 self.advance();
1535 let right = self.parse_logical_and()?;
1536 left = SqlExpression::BinaryOp {
1540 left: Box::new(left),
1541 op: "OR".to_string(),
1542 right: Box::new(right),
1543 };
1544 }
1545
1546 Ok(left)
1547 }
1548
1549 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1550 let mut left = self.parse_expression()?;
1551
1552 while matches!(self.current_token, Token::And) {
1553 self.advance();
1554 let right = self.parse_expression()?;
1555 left = SqlExpression::BinaryOp {
1557 left: Box::new(left),
1558 op: "AND".to_string(),
1559 right: Box::new(right),
1560 };
1561 }
1562
1563 Ok(left)
1564 }
1565
1566 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1567 self.consume(Token::Case)?;
1569
1570 let mut when_branches = Vec::new();
1571
1572 while matches!(self.current_token, Token::When) {
1574 self.advance(); let condition = self.parse_expression()?;
1578
1579 self.consume(Token::Then)?;
1581
1582 let result = self.parse_expression()?;
1584
1585 when_branches.push(WhenBranch {
1586 condition: Box::new(condition),
1587 result: Box::new(result),
1588 });
1589 }
1590
1591 if when_branches.is_empty() {
1593 return Err("CASE expression must have at least one WHEN clause".to_string());
1594 }
1595
1596 let else_branch = if matches!(self.current_token, Token::Else) {
1598 self.advance(); Some(Box::new(self.parse_expression()?))
1600 } else {
1601 None
1602 };
1603
1604 self.consume(Token::End)?;
1606
1607 Ok(SqlExpression::CaseExpression {
1608 when_branches,
1609 else_branch,
1610 })
1611 }
1612
1613 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1614 if let Token::NumberLiteral(num_str) = &self.current_token {
1617 if self.columns.iter().any(|col| col == num_str) {
1619 let expr = SqlExpression::Column(num_str.clone());
1620 self.advance();
1621 return Ok(expr);
1622 }
1623 }
1624
1625 match &self.current_token {
1626 Token::Case => {
1627 self.parse_case_expression()
1629 }
1630 Token::DateTime => {
1631 self.advance(); self.consume(Token::LeftParen)?;
1633
1634 if matches!(&self.current_token, Token::RightParen) {
1636 self.advance(); return Ok(SqlExpression::DateTimeToday {
1638 hour: None,
1639 minute: None,
1640 second: None,
1641 });
1642 }
1643
1644 let year = if let Token::NumberLiteral(n) = &self.current_token {
1646 n.parse::<i32>().map_err(|_| "Invalid year")?
1647 } else {
1648 return Err("Expected year in DateTime constructor".to_string());
1649 };
1650 self.advance();
1651 self.consume(Token::Comma)?;
1652
1653 let month = if let Token::NumberLiteral(n) = &self.current_token {
1655 n.parse::<u32>().map_err(|_| "Invalid month")?
1656 } else {
1657 return Err("Expected month in DateTime constructor".to_string());
1658 };
1659 self.advance();
1660 self.consume(Token::Comma)?;
1661
1662 let day = if let Token::NumberLiteral(n) = &self.current_token {
1664 n.parse::<u32>().map_err(|_| "Invalid day")?
1665 } else {
1666 return Err("Expected day in DateTime constructor".to_string());
1667 };
1668 self.advance();
1669
1670 let mut hour = None;
1672 let mut minute = None;
1673 let mut second = None;
1674
1675 if matches!(&self.current_token, Token::Comma) {
1676 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1680 hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
1681 self.advance();
1682
1683 if matches!(&self.current_token, Token::Comma) {
1685 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1688 minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
1689 self.advance();
1690
1691 if matches!(&self.current_token, Token::Comma) {
1693 self.advance(); if let Token::NumberLiteral(n) = &self.current_token {
1696 second =
1697 Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
1698 self.advance();
1699 }
1700 }
1701 }
1702 }
1703 }
1704 }
1705
1706 self.consume(Token::RightParen)?;
1707 Ok(SqlExpression::DateTimeConstructor {
1708 year,
1709 month,
1710 day,
1711 hour,
1712 minute,
1713 second,
1714 })
1715 }
1716 Token::Identifier(id) => {
1717 let id_upper = id.to_uppercase();
1718 let id_clone = id.clone();
1719
1720 if id_upper == "TRUE" {
1722 self.advance();
1723 return Ok(SqlExpression::BooleanLiteral(true));
1724 } else if id_upper == "FALSE" {
1725 self.advance();
1726 return Ok(SqlExpression::BooleanLiteral(false));
1727 }
1728
1729 self.advance();
1730
1731 if matches!(self.current_token, Token::LeftParen) {
1733 self.advance(); let (args, has_distinct) = self.parse_function_args()?;
1737 self.consume(Token::RightParen)?;
1738
1739 if matches!(self.current_token, Token::Over) {
1741 self.advance(); self.consume(Token::LeftParen)?;
1743 let window_spec = self.parse_window_spec()?;
1744 self.consume(Token::RightParen)?;
1745 return Ok(SqlExpression::WindowFunction {
1746 name: id_upper,
1747 args,
1748 window_spec,
1749 });
1750 }
1751
1752 return Ok(SqlExpression::FunctionCall {
1753 name: id_upper,
1754 args,
1755 distinct: has_distinct,
1756 });
1757 }
1758
1759 Ok(SqlExpression::Column(id_clone))
1761 }
1762 Token::QuotedIdentifier(id) => {
1763 let expr = if self.in_method_args {
1766 SqlExpression::StringLiteral(id.clone())
1767 } else {
1768 SqlExpression::Column(id.clone())
1770 };
1771 self.advance();
1772 Ok(expr)
1773 }
1774 Token::StringLiteral(s) => {
1775 let expr = SqlExpression::StringLiteral(s.clone());
1776 self.advance();
1777 Ok(expr)
1778 }
1779 Token::NumberLiteral(n) => {
1780 let expr = SqlExpression::NumberLiteral(n.clone());
1781 self.advance();
1782 Ok(expr)
1783 }
1784 Token::Null => {
1785 self.advance();
1786 Ok(SqlExpression::Null)
1787 }
1788 Token::LeftParen => {
1789 self.advance();
1790
1791 let expr = self.parse_logical_or()?;
1794
1795 self.consume(Token::RightParen)?;
1796 Ok(expr)
1797 }
1798 Token::Not => {
1799 self.advance(); if let Ok(inner_expr) = self.parse_comparison() {
1803 if matches!(self.current_token, Token::In) {
1805 self.advance(); self.consume(Token::LeftParen)?;
1807 let values = self.parse_expression_list()?;
1808 self.consume(Token::RightParen)?;
1809
1810 Ok(SqlExpression::NotInList {
1811 expr: Box::new(inner_expr),
1812 values,
1813 })
1814 } else {
1815 Ok(SqlExpression::Not {
1817 expr: Box::new(inner_expr),
1818 })
1819 }
1820 } else {
1821 Err("Expected expression after NOT".to_string())
1822 }
1823 }
1824 Token::Star => {
1825 self.advance();
1827 Ok(SqlExpression::StringLiteral("*".to_string()))
1828 }
1829 _ => Err(format!("Unexpected token: {:?}", self.current_token)),
1830 }
1831 }
1832
1833 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1834 let mut args = Vec::new();
1835
1836 self.in_method_args = true;
1838
1839 if !matches!(self.current_token, Token::RightParen) {
1840 loop {
1841 args.push(self.parse_expression()?);
1842
1843 if matches!(self.current_token, Token::Comma) {
1844 self.advance();
1845 } else {
1846 break;
1847 }
1848 }
1849 }
1850
1851 self.in_method_args = false;
1853
1854 Ok(args)
1855 }
1856
1857 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
1858 let mut args = Vec::new();
1859 let mut has_distinct = false;
1860
1861 if !matches!(self.current_token, Token::RightParen) {
1862 if matches!(self.current_token, Token::Distinct) {
1864 self.advance(); has_distinct = true;
1866 }
1867
1868 args.push(self.parse_additive()?);
1870
1871 while matches!(self.current_token, Token::Comma) {
1873 self.advance();
1874 args.push(self.parse_additive()?);
1875 }
1876 }
1877
1878 Ok((args, has_distinct))
1879 }
1880
1881 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1882 let mut expressions = Vec::new();
1883
1884 loop {
1885 expressions.push(self.parse_expression()?);
1886
1887 if matches!(self.current_token, Token::Comma) {
1888 self.advance();
1889 } else {
1890 break;
1891 }
1892 }
1893
1894 Ok(expressions)
1895 }
1896
1897 fn get_binary_op(&self) -> Option<String> {
1898 match &self.current_token {
1899 Token::Equal => Some("=".to_string()),
1900 Token::NotEqual => Some("!=".to_string()),
1901 Token::LessThan => Some("<".to_string()),
1902 Token::GreaterThan => Some(">".to_string()),
1903 Token::LessThanOrEqual => Some("<=".to_string()),
1904 Token::GreaterThanOrEqual => Some(">=".to_string()),
1905 Token::Like => Some("LIKE".to_string()),
1906 _ => None,
1907 }
1908 }
1909
1910 fn get_arithmetic_op(&self) -> Option<String> {
1911 match &self.current_token {
1912 Token::Plus => Some("+".to_string()),
1913 Token::Minus => Some("-".to_string()),
1914 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1916 Token::Modulo => Some("%".to_string()),
1917 _ => None,
1918 }
1919 }
1920
1921 #[must_use]
1922 pub fn get_position(&self) -> usize {
1923 self.lexer.get_position()
1924 }
1925}
1926
/// Classification of where the cursor sits in a query, as returned by
/// `detect_cursor_context`.
///
/// NOTE(review): the code that constructs these variants is outside this
/// excerpt, so the per-variant descriptions below are inferred from the names
/// and should be confirmed against `analyze_statement` / `analyze_partial`.
#[derive(Debug, Clone)]
pub enum CursorContext {
    /// Presumably: the cursor is inside the SELECT list.
    SelectClause,
    /// Presumably: the cursor is inside the FROM clause.
    FromClause,
    /// Presumably: the cursor is inside the WHERE clause.
    WhereClause,
    /// Presumably: the cursor is inside the ORDER BY clause.
    OrderByClause,
    /// Presumably: the cursor follows a column; the payload looks like the
    /// column name — TODO confirm.
    AfterColumn(String),
    /// Presumably: the cursor follows a logical connector, carried as the payload.
    AfterLogicalOp(LogicalOp),
    /// Presumably: the cursor follows a comparison operator; the two strings
    /// look like (column, operator) — TODO confirm order and meaning.
    AfterComparisonOp(String, String),
    /// Presumably: the cursor is inside a method call; the two strings look
    /// like (object, method) — TODO confirm order and meaning.
    InMethodCall(String, String),
    /// Presumably: the cursor is inside an expression not covered above.
    InExpression,
    /// No more specific context could be determined.
    Unknown,
}
1941
/// Returns the prefix of `s` ending at byte offset `pos`, moving the cut
/// backwards to the nearest UTF-8 character boundary when `pos` falls inside
/// a multi-byte character. A `pos` at or past the end yields all of `s`.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..cut]
}
1956
/// Returns the suffix of `s` starting at byte offset `pos`, moving the start
/// forwards to the nearest UTF-8 character boundary when `pos` falls inside a
/// multi-byte character. A `pos` at or past the end yields the empty string.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    let total = s.len();
    if pos >= total {
        return "";
    }

    // `total` itself is a boundary, so it serves as the fallback start.
    let begin = (pos..total)
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(total);

    &s[begin..]
}
1971
1972#[must_use]
1973pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1974 let truncated = safe_slice_to(query, cursor_pos);
1975 let mut parser = Parser::new(truncated);
1976
1977 if let Ok(stmt) = parser.parse() {
1979 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1980 #[cfg(test)]
1981 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1982 (ctx, partial)
1983 } else {
1984 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1986 #[cfg(test)]
1987 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1988 (ctx, partial)
1989 }
1990}
1991
1992#[must_use]
1993pub fn tokenize_query(query: &str) -> Vec<String> {
1994 let mut lexer = Lexer::new(query);
1995 let tokens = lexer.tokenize_all();
1996 tokens.iter().map(|t| format!("{t:?}")).collect()
1997}
1998
1999#[must_use]
2000pub fn format_sql_pretty(query: &str) -> Vec<String> {
2001 format_sql_pretty_compact(query, 5) }
2003
2004#[must_use]
2006pub fn format_ast_tree(query: &str) -> String {
2007 let mut parser = Parser::new(query);
2008 match parser.parse() {
2009 Ok(stmt) => format_select_statement(&stmt, 0),
2010 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
2011 }
2012}
2013
2014fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
2015 let mut result = String::new();
2016 let indent_str = " ".repeat(indent);
2017
2018 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
2019
2020 result.push_str(&format!("{indent_str} columns: ["));
2022 if stmt.columns.is_empty() {
2023 result.push_str("],\n");
2024 } else {
2025 result.push('\n');
2026 for col in &stmt.columns {
2027 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2028 }
2029 result.push_str(&format!("{indent_str} ],\n"));
2030 }
2031
2032 if let Some(table) = &stmt.from_table {
2034 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
2035 }
2036
2037 if let Some(where_clause) = &stmt.where_clause {
2039 result.push_str(&format!("{indent_str} where_clause: {{\n"));
2040 result.push_str(&format_where_clause(where_clause, indent + 2));
2041 result.push_str(&format!("{indent_str} }},\n"));
2042 }
2043
2044 if let Some(order_by) = &stmt.order_by {
2046 result.push_str(&format!("{indent_str} order_by: ["));
2047 if order_by.is_empty() {
2048 result.push_str("],\n");
2049 } else {
2050 result.push('\n');
2051 for col in order_by {
2052 let dir = match col.direction {
2053 SortDirection::Asc => "ASC",
2054 SortDirection::Desc => "DESC",
2055 };
2056 result.push_str(&format!(
2057 "{indent_str} \"{col}\" {dir},\n",
2058 col = col.column
2059 ));
2060 }
2061 result.push_str(&format!("{indent_str} ],\n"));
2062 }
2063 }
2064
2065 if let Some(group_by) = &stmt.group_by {
2067 result.push_str(&format!("{indent_str} group_by: ["));
2068 if group_by.is_empty() {
2069 result.push_str("]\n");
2070 } else {
2071 result.push('\n');
2072 for col in group_by {
2073 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2074 }
2075 result.push_str(&format!("{indent_str} ],\n"));
2076 }
2077 }
2078
2079 result.push_str(&format!("{indent_str}}}"));
2080 result
2081}
2082
2083fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
2084 let mut result = String::new();
2085 let indent_str = " ".repeat(indent);
2086
2087 result.push_str(&format!("{indent_str}conditions: [\n"));
2088
2089 for condition in &clause.conditions {
2090 result.push_str(&format!("{indent_str} {{\n"));
2091 result.push_str(&format!(
2092 "{indent_str} expr: {},\n",
2093 format_expression_ast(&condition.expr)
2094 ));
2095
2096 if let Some(connector) = &condition.connector {
2097 let connector_str = match connector {
2098 LogicalOp::And => "AND",
2099 LogicalOp::Or => "OR",
2100 };
2101 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
2102 }
2103
2104 result.push_str(&format!("{indent_str} }},\n"));
2105 }
2106
2107 result.push_str(&format!("{indent_str}]\n"));
2108 result
2109}
2110
2111fn format_expression_ast(expr: &SqlExpression) -> String {
2112 match expr {
2113 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
2114 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
2115 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
2116 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({value})"),
2117 SqlExpression::Null => "Null".to_string(),
2118 SqlExpression::DateTimeConstructor {
2119 year,
2120 month,
2121 day,
2122 hour,
2123 minute,
2124 second,
2125 } => {
2126 format!(
2127 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
2128 year,
2129 month,
2130 day,
2131 hour.unwrap_or(0),
2132 minute.unwrap_or(0),
2133 second.unwrap_or(0)
2134 )
2135 }
2136 SqlExpression::DateTimeToday {
2137 hour,
2138 minute,
2139 second,
2140 } => {
2141 format!(
2142 "DateTimeToday({:02}:{:02}:{:02})",
2143 hour.unwrap_or(0),
2144 minute.unwrap_or(0),
2145 second.unwrap_or(0)
2146 )
2147 }
2148 SqlExpression::MethodCall {
2149 object,
2150 method,
2151 args,
2152 } => {
2153 let args_str = args
2154 .iter()
2155 .map(format_expression_ast)
2156 .collect::<Vec<_>>()
2157 .join(", ");
2158 format!("MethodCall({object}.{method}({args_str}))")
2159 }
2160 SqlExpression::ChainedMethodCall { base, method, args } => {
2161 let args_str = args
2162 .iter()
2163 .map(format_expression_ast)
2164 .collect::<Vec<_>>()
2165 .join(", ");
2166 format!(
2167 "ChainedMethodCall({}.{}({}))",
2168 format_expression_ast(base),
2169 method,
2170 args_str
2171 )
2172 }
2173 SqlExpression::FunctionCall {
2174 name,
2175 args,
2176 distinct,
2177 } => {
2178 let args_str = args
2179 .iter()
2180 .map(format_expression_ast)
2181 .collect::<Vec<_>>()
2182 .join(", ");
2183 if *distinct {
2184 format!("FunctionCall({name}(DISTINCT {args_str}))")
2185 } else {
2186 format!("FunctionCall({name}({args_str}))")
2187 }
2188 }
2189 SqlExpression::WindowFunction {
2190 name,
2191 args,
2192 window_spec,
2193 } => {
2194 let args_str = args
2195 .iter()
2196 .map(format_expression_ast)
2197 .collect::<Vec<_>>()
2198 .join(", ");
2199 let partition_str = if !window_spec.partition_by.is_empty() {
2200 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2201 } else {
2202 String::new()
2203 };
2204 let order_str = if !window_spec.order_by.is_empty() {
2205 let cols = window_spec
2206 .order_by
2207 .iter()
2208 .map(|col| format!("{} {:?}", col.column, col.direction))
2209 .collect::<Vec<_>>()
2210 .join(", ");
2211 format!(" ORDER BY {}", cols)
2212 } else {
2213 String::new()
2214 };
2215 format!("WindowFunction({name}({args_str}) OVER({partition_str}{order_str}))")
2216 }
2217 SqlExpression::BinaryOp { left, op, right } => {
2218 format!(
2219 "BinaryOp({} {} {})",
2220 format_expression_ast(left),
2221 op,
2222 format_expression_ast(right)
2223 )
2224 }
2225 SqlExpression::InList { expr, values } => {
2226 let list_str = values
2227 .iter()
2228 .map(format_expression_ast)
2229 .collect::<Vec<_>>()
2230 .join(", ");
2231 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
2232 }
2233 SqlExpression::NotInList { expr, values } => {
2234 let list_str = values
2235 .iter()
2236 .map(format_expression_ast)
2237 .collect::<Vec<_>>()
2238 .join(", ");
2239 format!(
2240 "NotInList({} NOT IN [{}])",
2241 format_expression_ast(expr),
2242 list_str
2243 )
2244 }
2245 SqlExpression::Between { expr, lower, upper } => {
2246 format!(
2247 "Between({} BETWEEN {} AND {})",
2248 format_expression_ast(expr),
2249 format_expression_ast(lower),
2250 format_expression_ast(upper)
2251 )
2252 }
2253 SqlExpression::Not { expr } => {
2254 format!("Not({})", format_expression_ast(expr))
2255 }
2256 SqlExpression::CaseExpression {
2257 when_branches,
2258 else_branch,
2259 } => {
2260 let when_strs: Vec<String> = when_branches
2261 .iter()
2262 .map(|branch| {
2263 format!(
2264 "WHEN {} THEN {}",
2265 format_expression_ast(&branch.condition),
2266 format_expression_ast(&branch.result)
2267 )
2268 })
2269 .collect();
2270 let else_str = else_branch
2271 .as_ref()
2272 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
2273 .unwrap_or_default();
2274 format!("CASE {} {} END", when_strs.join(" "), else_str)
2275 }
2276 }
2277}
2278
2279#[must_use]
2281pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
2282 match expr {
2283 SqlExpression::DateTimeConstructor {
2284 year,
2285 month,
2286 day,
2287 hour,
2288 minute,
2289 second,
2290 } => {
2291 let h = hour.unwrap_or(0);
2292 let m = minute.unwrap_or(0);
2293 let s = second.unwrap_or(0);
2294
2295 if let Ok(dt) = NaiveDateTime::parse_from_str(
2297 &format!("{year:04}-{month:02}-{day:02} {h:02}:{m:02}:{s:02}"),
2298 "%Y-%m-%d %H:%M:%S",
2299 ) {
2300 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2301 } else {
2302 None
2303 }
2304 }
2305 SqlExpression::DateTimeToday {
2306 hour,
2307 minute,
2308 second,
2309 } => {
2310 let now = Local::now();
2311 let h = hour.unwrap_or(0);
2312 let m = minute.unwrap_or(0);
2313 let s = second.unwrap_or(0);
2314
2315 if let Ok(dt) = NaiveDateTime::parse_from_str(
2317 &format!(
2318 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
2319 now.year(),
2320 now.month(),
2321 now.day(),
2322 h,
2323 m,
2324 s
2325 ),
2326 "%Y-%m-%d %H:%M:%S",
2327 ) {
2328 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2329 } else {
2330 None
2331 }
2332 }
2333 _ => None,
2334 }
2335}
2336
2337fn format_sql_with_preserved_parens(
2339 query: &str,
2340 cols_per_line: usize,
2341) -> Result<Vec<String>, String> {
2342 let mut lines = Vec::new();
2343 let mut lexer = Lexer::new(query);
2344 let tokens_with_pos = lexer.tokenize_all_with_positions();
2345
2346 if tokens_with_pos.is_empty() {
2347 return Err("No tokens found".to_string());
2348 }
2349
2350 let mut i = 0;
2351 let cols_per_line = cols_per_line.max(1);
2352
2353 while i < tokens_with_pos.len() {
2354 let (start, _end, ref token) = tokens_with_pos[i];
2355
2356 match token {
2357 Token::Select => {
2358 lines.push("SELECT".to_string());
2359 i += 1;
2360
2361 let mut columns = Vec::new();
2363 let mut col_start = i;
2364 while i < tokens_with_pos.len() {
2365 match &tokens_with_pos[i].2 {
2366 Token::From | Token::Eof => break,
2367 Token::Comma => {
2368 if col_start < i {
2370 let col_text = extract_text_between_positions(
2371 query,
2372 tokens_with_pos[col_start].0,
2373 tokens_with_pos[i - 1].1,
2374 );
2375 columns.push(col_text);
2376 }
2377 i += 1;
2378 col_start = i;
2379 }
2380 _ => i += 1,
2381 }
2382 }
2383 if col_start < i && i > 0 {
2385 let col_text = extract_text_between_positions(
2386 query,
2387 tokens_with_pos[col_start].0,
2388 tokens_with_pos[i - 1].1,
2389 );
2390 columns.push(col_text);
2391 }
2392
2393 for chunk in columns.chunks(cols_per_line) {
2395 let mut line = " ".to_string();
2396 for (idx, col) in chunk.iter().enumerate() {
2397 if idx > 0 {
2398 line.push_str(", ");
2399 }
2400 line.push_str(col.trim());
2401 }
2402 let is_last_chunk = chunk.as_ptr() as usize + std::mem::size_of_val(chunk)
2404 >= columns.last().map_or(0, |c| std::ptr::from_ref(c) as usize);
2405 if !is_last_chunk && columns.len() > cols_per_line {
2406 line.push(',');
2407 }
2408 lines.push(line);
2409 }
2410 }
2411 Token::From => {
2412 i += 1;
2413 if i < tokens_with_pos.len() {
2414 let table_start = tokens_with_pos[i].0;
2415 while i < tokens_with_pos.len() {
2417 match &tokens_with_pos[i].2 {
2418 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
2419 _ => i += 1,
2420 }
2421 }
2422 if i > 0 {
2423 let table_text = extract_text_between_positions(
2424 query,
2425 table_start,
2426 tokens_with_pos[i - 1].1,
2427 );
2428 lines.push(format!("FROM {}", table_text.trim()));
2429 }
2430 }
2431 }
2432 Token::Where => {
2433 lines.push("WHERE".to_string());
2434 i += 1;
2435
2436 let where_start = if i < tokens_with_pos.len() {
2438 tokens_with_pos[i].0
2439 } else {
2440 start
2441 };
2442
2443 let mut where_end = query.len();
2445 while i < tokens_with_pos.len() {
2446 match &tokens_with_pos[i].2 {
2447 Token::OrderBy | Token::GroupBy | Token::Eof => {
2448 if i > 0 {
2449 where_end = tokens_with_pos[i - 1].1;
2450 }
2451 break;
2452 }
2453 _ => i += 1,
2454 }
2455 }
2456
2457 let where_text = extract_text_between_positions(query, where_start, where_end);
2459
2460 let formatted_where = format_where_clause_with_parens(&where_text);
2462 for line in formatted_where {
2463 lines.push(format!(" {line}"));
2464 }
2465 }
2466 Token::OrderBy => {
2467 i += 1;
2468 let order_start = if i < tokens_with_pos.len() {
2469 tokens_with_pos[i].0
2470 } else {
2471 start
2472 };
2473
2474 while i < tokens_with_pos.len() {
2476 match &tokens_with_pos[i].2 {
2477 Token::GroupBy | Token::Eof => break,
2478 _ => i += 1,
2479 }
2480 }
2481
2482 if i > 0 {
2483 let order_text = extract_text_between_positions(
2484 query,
2485 order_start,
2486 tokens_with_pos[i - 1].1,
2487 );
2488 lines.push(format!("ORDER BY {}", order_text.trim()));
2489 }
2490 }
2491 Token::GroupBy => {
2492 i += 1;
2493 let group_start = if i < tokens_with_pos.len() {
2494 tokens_with_pos[i].0
2495 } else {
2496 start
2497 };
2498
2499 while i < tokens_with_pos.len() {
2501 match &tokens_with_pos[i].2 {
2502 Token::Having | Token::Eof => break,
2503 _ => i += 1,
2504 }
2505 }
2506
2507 if i > 0 {
2508 let group_text = extract_text_between_positions(
2509 query,
2510 group_start,
2511 tokens_with_pos[i - 1].1,
2512 );
2513 lines.push(format!("GROUP BY {}", group_text.trim()));
2514 }
2515 }
2516 _ => i += 1,
2517 }
2518 }
2519
2520 Ok(lines)
2521}
2522
/// Returns the characters of `query` in the half-open range `start..end`.
///
/// Both positions count characters, not bytes. Positions past the end are
/// clamped, and an inverted range (`start > end`) yields an empty string
/// instead of panicking.
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2530
/// Splits the text of a WHERE clause into display lines at its top-level
/// `AND` / `OR` connectors.
///
/// Each condition becomes one trimmed line and each connector gets a line of
/// its own (`"AND"` or `"OR"`). A connector is recognised only when it is
/// written with a single space on each side, and only when it is
/// * outside parentheses,
/// * outside single- or double-quoted text, so `name = 'a AND b'` is not torn
///   apart, and
/// * not the `AND` that belongs to a preceding top-level `BETWEEN`.
///
/// When nothing is produced (e.g. empty input) the result is the trimmed
/// input as a single line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// True when `pattern` occurs in `chars` at index `at`, ignoring ASCII case.
    fn matches_at(chars: &[char], at: usize, pattern: &str) -> bool {
        let mut rest = chars[at..].iter();
        pattern.chars().all(|expected| {
            rest.next()
                .is_some_and(|actual| actual.eq_ignore_ascii_case(&expected))
        })
    }

    /// Moves the pending text (if not blank) into `lines` and clears it.
    fn flush(lines: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            lines.push(trimmed.to_string());
        }
        current.clear();
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: i32 = 0;
    // The quote character that opened the literal we are inside, if any.
    let mut open_quote: Option<char> = None;
    // Set after a top-level BETWEEN until its AND has been seen.
    let mut between_pending = false;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];

        // Inside quotes everything is copied verbatim. A doubled quote ('')
        // closes and immediately reopens the literal, which has the same effect.
        if let Some(quote) = open_quote {
            current_line.push(ch);
            if ch == quote {
                open_quote = None;
            }
            i += 1;
            continue;
        }

        if paren_depth == 0 {
            if matches_at(&chars, i, " BETWEEN ") {
                between_pending = true;
            } else if matches_at(&chars, i, " AND ") {
                if between_pending {
                    // This AND separates the BETWEEN bounds; keep it inline.
                    between_pending = false;
                } else {
                    flush(&mut lines, &mut current_line);
                    lines.push("AND".to_string());
                    i += 5;
                    continue;
                }
            } else if matches_at(&chars, i, " OR ") {
                between_pending = false;
                flush(&mut lines, &mut current_line);
                lines.push("OR".to_string());
                i += 4;
                continue;
            }
        }

        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => paren_depth += 1,
            ')' => paren_depth -= 1,
            _ => {}
        }
        current_line.push(ch);
        i += 1;
    }

    flush(&mut lines, &mut current_line);

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2596
2597#[must_use]
2598pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2599 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2601 return lines;
2602 }
2603
2604 let mut lines = Vec::new();
2606 let mut parser = Parser::new(query);
2607
2608 let cols_per_line = cols_per_line.max(1);
2610
2611 if let Ok(stmt) = parser.parse() {
2612 if !stmt.columns.is_empty() {
2614 lines.push("SELECT".to_string());
2615
2616 for chunk in stmt.columns.chunks(cols_per_line) {
2618 let mut line = " ".to_string();
2619 for (i, col) in chunk.iter().enumerate() {
2620 if i > 0 {
2621 line.push_str(", ");
2622 }
2623 line.push_str(col);
2624 }
2625 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2627 let current_chunk_idx =
2628 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2629 if current_chunk_idx < last_chunk_idx {
2630 line.push(',');
2631 }
2632 lines.push(line);
2633 }
2634 }
2635
2636 if let Some(table) = &stmt.from_table {
2638 lines.push(format!("FROM {table}"));
2639 }
2640
2641 if let Some(where_clause) = &stmt.where_clause {
2643 lines.push("WHERE".to_string());
2644 for (i, condition) in where_clause.conditions.iter().enumerate() {
2645 if i > 0 {
2646 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2648 if let Some(connector) = &prev_condition.connector {
2649 match connector {
2650 LogicalOp::And => lines.push(" AND".to_string()),
2651 LogicalOp::Or => lines.push(" OR".to_string()),
2652 }
2653 }
2654 }
2655 }
2656 lines.push(format!(" {}", format_expression(&condition.expr)));
2657 }
2658 }
2659
2660 if let Some(order_by) = &stmt.order_by {
2662 let order_str = order_by
2663 .iter()
2664 .map(|col| {
2665 let dir = match col.direction {
2666 SortDirection::Asc => " ASC",
2667 SortDirection::Desc => " DESC",
2668 };
2669 format!("{}{}", col.column, dir)
2670 })
2671 .collect::<Vec<_>>()
2672 .join(", ");
2673 lines.push(format!("ORDER BY {order_str}"));
2674 }
2675
2676 if let Some(group_by) = &stmt.group_by {
2678 let group_str = group_by.join(", ");
2679 lines.push(format!("GROUP BY {group_str}"));
2680 }
2681 } else {
2682 let mut lexer = Lexer::new(query);
2684 let tokens = lexer.tokenize_all();
2685 let mut current_line = String::new();
2686 let mut indent = 0;
2687
2688 for token in tokens {
2689 match &token {
2690 Token::Select | Token::From | Token::Where | Token::OrderBy | Token::GroupBy => {
2691 if !current_line.is_empty() {
2692 lines.push(current_line.trim().to_string());
2693 current_line.clear();
2694 }
2695 lines.push(format!("{token:?}").to_uppercase());
2696 indent = 1;
2697 }
2698 Token::And | Token::Or => {
2699 if !current_line.is_empty() {
2700 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2701 current_line.clear();
2702 }
2703 lines.push(format!(" {token:?}").to_uppercase());
2704 }
2705 Token::Comma => {
2706 current_line.push(',');
2707 if indent > 0 {
2708 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2709 current_line.clear();
2710 }
2711 }
2712 Token::Eof => break,
2713 _ => {
2714 if !current_line.is_empty() {
2715 current_line.push(' ');
2716 }
2717 current_line.push_str(&format_token(&token));
2718 }
2719 }
2720 }
2721
2722 if !current_line.is_empty() {
2723 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2724 }
2725 }
2726
2727 lines
2728}
2729
2730fn format_expression(expr: &SqlExpression) -> String {
2731 match expr {
2732 SqlExpression::Column(name) => name.clone(),
2733 SqlExpression::StringLiteral(s) => format!("'{s}'"),
2734 SqlExpression::NumberLiteral(n) => n.clone(),
2735 SqlExpression::BooleanLiteral(b) => b.to_string(),
2736 SqlExpression::Null => "NULL".to_string(),
2737 SqlExpression::DateTimeConstructor {
2738 year,
2739 month,
2740 day,
2741 hour,
2742 minute,
2743 second,
2744 } => {
2745 let mut result = format!("DateTime({year}, {month}, {day}");
2746 if let Some(h) = hour {
2747 result.push_str(&format!(", {h}"));
2748 if let Some(m) = minute {
2749 result.push_str(&format!(", {m}"));
2750 if let Some(s) = second {
2751 result.push_str(&format!(", {s}"));
2752 }
2753 }
2754 }
2755 result.push(')');
2756 result
2757 }
2758 SqlExpression::DateTimeToday {
2759 hour,
2760 minute,
2761 second,
2762 } => {
2763 let mut result = "DateTime()".to_string();
2764 if let Some(h) = hour {
2765 result = format!("DateTime(TODAY, {h}");
2766 if let Some(m) = minute {
2767 result.push_str(&format!(", {m}"));
2768 if let Some(s) = second {
2769 result.push_str(&format!(", {s}"));
2770 }
2771 }
2772 result.push(')');
2773 }
2774 result
2775 }
2776 SqlExpression::MethodCall {
2777 object,
2778 method,
2779 args,
2780 } => {
2781 let args_str = args
2782 .iter()
2783 .map(format_expression)
2784 .collect::<Vec<_>>()
2785 .join(", ");
2786 format!("{object}.{method}({args_str})")
2787 }
2788 SqlExpression::BinaryOp { left, op, right } => {
2789 if op == "OR" || op == "AND" {
2792 format!(
2795 "({} {} {})",
2796 format_expression(left),
2797 op,
2798 format_expression(right)
2799 )
2800 } else {
2801 format!(
2802 "{} {} {}",
2803 format_expression(left),
2804 op,
2805 format_expression(right)
2806 )
2807 }
2808 }
2809 SqlExpression::InList { expr, values } => {
2810 let values_str = values
2811 .iter()
2812 .map(format_expression)
2813 .collect::<Vec<_>>()
2814 .join(", ");
2815 format!("{} IN ({})", format_expression(expr), values_str)
2816 }
2817 SqlExpression::NotInList { expr, values } => {
2818 let values_str = values
2819 .iter()
2820 .map(format_expression)
2821 .collect::<Vec<_>>()
2822 .join(", ");
2823 format!("{} NOT IN ({})", format_expression(expr), values_str)
2824 }
2825 SqlExpression::Between { expr, lower, upper } => {
2826 format!(
2827 "{} BETWEEN {} AND {}",
2828 format_expression(expr),
2829 format_expression(lower),
2830 format_expression(upper)
2831 )
2832 }
2833 SqlExpression::Not { expr } => {
2834 format!("NOT {}", format_expression(expr))
2835 }
2836 SqlExpression::ChainedMethodCall { base, method, args } => {
2837 let args_str = args
2838 .iter()
2839 .map(format_expression)
2840 .collect::<Vec<_>>()
2841 .join(", ");
2842 format!("{}.{}({})", format_expression(base), method, args_str)
2843 }
2844 SqlExpression::FunctionCall {
2845 name,
2846 args,
2847 distinct,
2848 } => {
2849 let args_str = args
2850 .iter()
2851 .map(format_expression)
2852 .collect::<Vec<_>>()
2853 .join(", ");
2854 if *distinct {
2855 format!("{name}(DISTINCT {args_str})")
2856 } else {
2857 format!("{name}({args_str})")
2858 }
2859 }
2860 SqlExpression::WindowFunction {
2861 name,
2862 args,
2863 window_spec,
2864 } => {
2865 let args_str = args
2866 .iter()
2867 .map(format_expression)
2868 .collect::<Vec<_>>()
2869 .join(", ");
2870 let partition_str = if !window_spec.partition_by.is_empty() {
2871 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2872 } else {
2873 String::new()
2874 };
2875 let order_str = if !window_spec.order_by.is_empty() {
2876 let cols = window_spec
2877 .order_by
2878 .iter()
2879 .map(|col| {
2880 let dir = match col.direction {
2881 SortDirection::Asc => "ASC",
2882 SortDirection::Desc => "DESC",
2883 };
2884 format!("{} {}", col.column, dir)
2885 })
2886 .collect::<Vec<_>>()
2887 .join(", ");
2888 format!(" ORDER BY {}", cols)
2889 } else {
2890 String::new()
2891 };
2892 format!("{name}({args_str}) OVER({partition_str}{order_str})")
2893 }
2894 SqlExpression::CaseExpression {
2895 when_branches,
2896 else_branch,
2897 } => {
2898 let mut result = String::from("CASE");
2899 for branch in when_branches {
2900 result.push_str(&format!(
2901 " WHEN {} THEN {}",
2902 format_expression(&branch.condition),
2903 format_expression(&branch.result)
2904 ));
2905 }
2906 if let Some(else_expr) = else_branch {
2907 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2908 }
2909 result.push_str(" END");
2910 result
2911 }
2912 }
2913}
2914
2915fn format_token(token: &Token) -> String {
2916 match token {
2917 Token::Identifier(s) => s.clone(),
2918 Token::QuotedIdentifier(s) => format!("\"{s}\""),
2919 Token::StringLiteral(s) => format!("'{s}'"),
2920 Token::NumberLiteral(n) => n.clone(),
2921 Token::DateTime => "DateTime".to_string(),
2922 Token::Case => "CASE".to_string(),
2923 Token::When => "WHEN".to_string(),
2924 Token::Then => "THEN".to_string(),
2925 Token::Else => "ELSE".to_string(),
2926 Token::End => "END".to_string(),
2927 Token::Distinct => "DISTINCT".to_string(),
2928 Token::Over => "OVER".to_string(),
2929 Token::Partition => "PARTITION".to_string(),
2930 Token::By => "BY".to_string(),
2931 Token::LeftParen => "(".to_string(),
2932 Token::RightParen => ")".to_string(),
2933 Token::Comma => ",".to_string(),
2934 Token::Dot => ".".to_string(),
2935 Token::Equal => "=".to_string(),
2936 Token::NotEqual => "!=".to_string(),
2937 Token::LessThan => "<".to_string(),
2938 Token::GreaterThan => ">".to_string(),
2939 Token::LessThanOrEqual => "<=".to_string(),
2940 Token::GreaterThanOrEqual => ">=".to_string(),
2941 Token::In => "IN".to_string(),
2942 _ => format!("{token:?}").to_uppercase(),
2943 }
2944}
2945
2946fn analyze_statement(
2947 stmt: &SelectStatement,
2948 query: &str,
2949 _cursor_pos: usize,
2950) -> (CursorContext, Option<String>) {
2951 let trimmed = query.trim();
2953
2954 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2956 for op in &comparison_ops {
2957 if let Some(op_pos) = query.rfind(op) {
2958 let before_op = safe_slice_to(query, op_pos);
2959 let after_op_start = op_pos + op.len();
2960 let after_op = if after_op_start < query.len() {
2961 &query[after_op_start..]
2962 } else {
2963 ""
2964 };
2965
2966 if let Some(col_name) = before_op.split_whitespace().last() {
2968 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2969 let after_op_trimmed = after_op.trim();
2971 if after_op_trimmed.is_empty()
2972 || (after_op_trimmed
2973 .chars()
2974 .all(|c| c.is_alphanumeric() || c == '_')
2975 && !after_op_trimmed.contains('('))
2976 {
2977 let partial = if after_op_trimmed.is_empty() {
2978 None
2979 } else {
2980 Some(after_op_trimmed.to_string())
2981 };
2982 return (
2983 CursorContext::AfterComparisonOp(
2984 col_name.to_string(),
2985 op.trim().to_string(),
2986 ),
2987 partial,
2988 );
2989 }
2990 }
2991 }
2992 }
2993 }
2994
2995 if trimmed.to_uppercase().ends_with(" AND")
2997 || trimmed.to_uppercase().ends_with(" OR")
2998 || trimmed.to_uppercase().ends_with(" AND ")
2999 || trimmed.to_uppercase().ends_with(" OR ")
3000 {
3001 } else {
3003 if let Some(dot_pos) = trimmed.rfind('.') {
3005 let before_dot = safe_slice_to(trimmed, dot_pos);
3007 let after_dot_start = dot_pos + 1;
3008 let after_dot = if after_dot_start < trimmed.len() {
3009 &trimmed[after_dot_start..]
3010 } else {
3011 ""
3012 };
3013
3014 if !after_dot.contains('(') {
3017 let col_name = if before_dot.ends_with('"') {
3019 let bytes = before_dot.as_bytes();
3021 let mut pos = before_dot.len() - 1; let mut found_start = None;
3023
3024 if pos > 0 {
3026 pos -= 1;
3027 while pos > 0 {
3028 if bytes[pos] == b'"' {
3029 if pos == 0 || bytes[pos - 1] != b'\\' {
3031 found_start = Some(pos);
3032 break;
3033 }
3034 }
3035 pos -= 1;
3036 }
3037 if found_start.is_none() && bytes[0] == b'"' {
3039 found_start = Some(0);
3040 }
3041 }
3042
3043 found_start.map(|start| safe_slice_from(before_dot, start))
3044 } else {
3045 before_dot
3048 .split_whitespace()
3049 .last()
3050 .map(|word| word.trim_start_matches('('))
3051 };
3052
3053 if let Some(col_name) = col_name {
3054 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3056 true
3058 } else {
3059 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3061 };
3062
3063 if is_valid {
3064 let partial_method = if after_dot.is_empty() {
3067 None
3068 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3069 Some(after_dot.to_string())
3070 } else {
3071 None
3072 };
3073
3074 let col_name_for_context = if col_name.starts_with('"')
3076 && col_name.ends_with('"')
3077 && col_name.len() > 2
3078 {
3079 col_name[1..col_name.len() - 1].to_string()
3080 } else {
3081 col_name.to_string()
3082 };
3083
3084 return (
3085 CursorContext::AfterColumn(col_name_for_context),
3086 partial_method,
3087 );
3088 }
3089 }
3090 }
3091 }
3092 }
3093
3094 if let Some(where_clause) = &stmt.where_clause {
3096 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3098 let op = if trimmed.to_uppercase().ends_with(" AND") {
3099 LogicalOp::And
3100 } else {
3101 LogicalOp::Or
3102 };
3103 return (CursorContext::AfterLogicalOp(op), None);
3104 }
3105
3106 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
3108 let after_and = safe_slice_from(query, and_pos + 5);
3109 let partial = extract_partial_at_end(after_and);
3110 if partial.is_some() {
3111 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3112 }
3113 }
3114
3115 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
3116 let after_or = safe_slice_from(query, or_pos + 4);
3117 let partial = extract_partial_at_end(after_or);
3118 if partial.is_some() {
3119 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3120 }
3121 }
3122
3123 if let Some(last_condition) = where_clause.conditions.last() {
3124 if let Some(connector) = &last_condition.connector {
3125 return (
3127 CursorContext::AfterLogicalOp(connector.clone()),
3128 extract_partial_at_end(query),
3129 );
3130 }
3131 }
3132 return (CursorContext::WhereClause, extract_partial_at_end(query));
3134 }
3135
3136 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
3138 return (CursorContext::OrderByClause, None);
3139 }
3140
3141 if stmt.order_by.is_some() {
3143 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3144 }
3145
3146 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
3147 return (CursorContext::FromClause, extract_partial_at_end(query));
3148 }
3149
3150 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
3151 return (CursorContext::SelectClause, extract_partial_at_end(query));
3152 }
3153
3154 (CursorContext::Unknown, None)
3155}
3156
3157fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
3158 let upper = query.to_uppercase();
3159
3160 let trimmed = query.trim();
3162
3163 #[cfg(test)]
3164 {
3165 if trimmed.contains("\"Last Name\"") {
3166 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
3167 }
3168 }
3169
3170 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
3172 for op in &comparison_ops {
3173 if let Some(op_pos) = query.rfind(op) {
3174 let before_op = safe_slice_to(query, op_pos);
3175 let after_op_start = op_pos + op.len();
3176 let after_op = if after_op_start < query.len() {
3177 &query[after_op_start..]
3178 } else {
3179 ""
3180 };
3181
3182 if let Some(col_name) = before_op.split_whitespace().last() {
3184 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
3185 let after_op_trimmed = after_op.trim();
3187 if after_op_trimmed.is_empty()
3188 || (after_op_trimmed
3189 .chars()
3190 .all(|c| c.is_alphanumeric() || c == '_')
3191 && !after_op_trimmed.contains('('))
3192 {
3193 let partial = if after_op_trimmed.is_empty() {
3194 None
3195 } else {
3196 Some(after_op_trimmed.to_string())
3197 };
3198 return (
3199 CursorContext::AfterComparisonOp(
3200 col_name.to_string(),
3201 op.trim().to_string(),
3202 ),
3203 partial,
3204 );
3205 }
3206 }
3207 }
3208 }
3209 }
3210
3211 if let Some(dot_pos) = trimmed.rfind('.') {
3214 #[cfg(test)]
3215 {
3216 if trimmed.contains("\"Last Name\"") {
3217 eprintln!("DEBUG: Found dot at position {dot_pos}");
3218 }
3219 }
3220 let before_dot = &trimmed[..dot_pos];
3222 let after_dot = &trimmed[dot_pos + 1..];
3223
3224 if !after_dot.contains('(') {
3227 let col_name = if before_dot.ends_with('"') {
3230 let bytes = before_dot.as_bytes();
3232 let mut pos = before_dot.len() - 1; let mut found_start = None;
3234
3235 #[cfg(test)]
3236 {
3237 if trimmed.contains("\"Last Name\"") {
3238 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
3239 }
3240 }
3241
3242 if pos > 0 {
3244 pos -= 1;
3245 while pos > 0 {
3246 if bytes[pos] == b'"' {
3247 if pos == 0 || bytes[pos - 1] != b'\\' {
3249 found_start = Some(pos);
3250 break;
3251 }
3252 }
3253 pos -= 1;
3254 }
3255 if found_start.is_none() && bytes[0] == b'"' {
3257 found_start = Some(0);
3258 }
3259 }
3260
3261 if let Some(start) = found_start {
3262 let result = safe_slice_from(before_dot, start);
3264 #[cfg(test)]
3265 {
3266 if trimmed.contains("\"Last Name\"") {
3267 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
3268 }
3269 }
3270 Some(result)
3271 } else {
3272 #[cfg(test)]
3273 {
3274 if trimmed.contains("\"Last Name\"") {
3275 eprintln!("DEBUG: No opening quote found!");
3276 }
3277 }
3278 None
3279 }
3280 } else {
3281 before_dot
3284 .split_whitespace()
3285 .last()
3286 .map(|word| word.trim_start_matches('('))
3287 };
3288
3289 if let Some(col_name) = col_name {
3290 #[cfg(test)]
3291 {
3292 if trimmed.contains("\"Last Name\"") {
3293 eprintln!("DEBUG: col_name = '{col_name}'");
3294 }
3295 }
3296
3297 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3299 true
3301 } else {
3302 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3304 };
3305
3306 #[cfg(test)]
3307 {
3308 if trimmed.contains("\"Last Name\"") {
3309 eprintln!("DEBUG: is_valid = {is_valid}");
3310 }
3311 }
3312
3313 if is_valid {
3314 let partial_method = if after_dot.is_empty() {
3317 None
3318 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3319 Some(after_dot.to_string())
3320 } else {
3321 None
3322 };
3323
3324 let col_name_for_context = if col_name.starts_with('"')
3326 && col_name.ends_with('"')
3327 && col_name.len() > 2
3328 {
3329 col_name[1..col_name.len() - 1].to_string()
3330 } else {
3331 col_name.to_string()
3332 };
3333
3334 return (
3335 CursorContext::AfterColumn(col_name_for_context),
3336 partial_method,
3337 );
3338 }
3339 }
3340 }
3341 }
3342
3343 if let Some(and_pos) = upper.rfind(" AND ") {
3345 if cursor_pos >= and_pos + 5 {
3347 let after_and = safe_slice_from(query, and_pos + 5);
3349 let partial = extract_partial_at_end(after_and);
3350 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3351 }
3352 }
3353
3354 if let Some(or_pos) = upper.rfind(" OR ") {
3355 if cursor_pos >= or_pos + 4 {
3357 let after_or = safe_slice_from(query, or_pos + 4);
3359 let partial = extract_partial_at_end(after_or);
3360 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3361 }
3362 }
3363
3364 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3366 let op = if trimmed.to_uppercase().ends_with(" AND") {
3367 LogicalOp::And
3368 } else {
3369 LogicalOp::Or
3370 };
3371 return (CursorContext::AfterLogicalOp(op), None);
3372 }
3373
3374 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
3376 {
3377 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3378 }
3379
3380 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
3381 return (CursorContext::WhereClause, extract_partial_at_end(query));
3382 }
3383
3384 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
3385 return (CursorContext::FromClause, extract_partial_at_end(query));
3386 }
3387
3388 if upper.contains("SELECT") && !upper.contains("FROM") {
3389 return (CursorContext::SelectClause, extract_partial_at_end(query));
3390 }
3391
3392 (CursorContext::Unknown, None)
3393}
3394
3395fn extract_partial_at_end(query: &str) -> Option<String> {
3396 let trimmed = query.trim();
3397
3398 if let Some(last_word) = trimmed.split_whitespace().last() {
3400 if last_word.starts_with('"') && !last_word.ends_with('"') {
3401 return Some(last_word.to_string());
3403 }
3404 }
3405
3406 let last_word = trimmed.split_whitespace().last()?;
3408
3409 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
3411 Some(last_word.to_string())
3412 } else {
3413 None
3414 }
3415}
3416
/// Reports whether `word`, compared case-insensitively, is one of the SQL
/// keywords that must never be offered as a partially typed identifier.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 13] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.iter().any(|keyword| *keyword == upper)
}
3435
3436#[cfg(test)]
3437mod tests {
3438 use super::*;
3439
3440 #[test]
3441 fn test_tokenizer_window_functions() {
3442 let mut lexer = Lexer::new("LAG(value) OVER (PARTITION BY category ORDER BY id)");
3443 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "LAG"));
3444 assert!(matches!(lexer.next_token(), Token::LeftParen));
3445 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "value"));
3446 assert!(matches!(lexer.next_token(), Token::RightParen));
3447
3448 let over_token = lexer.next_token();
3449 println!("Expected OVER, got: {:?}", over_token);
3450 assert!(matches!(over_token, Token::Over));
3451
3452 assert!(matches!(lexer.next_token(), Token::LeftParen));
3453 assert!(matches!(lexer.next_token(), Token::Partition));
3454 assert!(matches!(lexer.next_token(), Token::By));
3455 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "category"));
3456 }
3457
3458 #[test]
3459 fn test_parse_window_function() {
3460 let query = "SELECT LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test";
3461 let mut parser = Parser::new(query);
3462 let result = parser.parse();
3463
3464 assert!(
3465 result.is_ok(),
3466 "Failed to parse window function: {:?}",
3467 result
3468 );
3469 let stmt = result.unwrap();
3470
3471 if let Some(item) = stmt.select_items.get(0) {
3473 match item {
3474 SelectItem::Expression { expr, alias } => {
3475 println!("Parsed expression: {:?}", expr);
3476 assert!(matches!(expr, SqlExpression::WindowFunction { .. }));
3477 assert_eq!(alias, "prev_value");
3478 }
3479 _ => panic!("Expected expression, got: {:?}", item),
3480 }
3481 } else {
3482 panic!("No select items found");
3483 }
3484 }
3485
3486 #[test]
3487 fn test_chained_method_calls() {
3488 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
3490 let mut parser = Parser::new(query);
3491 let result = parser.parse();
3492
3493 assert!(
3494 result.is_ok(),
3495 "Failed to parse chained method calls: {result:?}"
3496 );
3497
3498 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
3500 let mut parser2 = Parser::new(query2);
3501 let result2 = parser2.parse();
3502
3503 assert!(
3504 result2.is_ok(),
3505 "Failed to parse multiple chained calls: {result2:?}"
3506 );
3507 }
3508
3509 #[test]
3510 fn test_tokenizer() {
3511 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
3512
3513 assert!(matches!(lexer.next_token(), Token::Select));
3514 assert!(matches!(lexer.next_token(), Token::Star));
3515 assert!(matches!(lexer.next_token(), Token::From));
3516 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
3517 assert!(matches!(lexer.next_token(), Token::Where));
3518 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
3519 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3520 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
3521 }
3522
3523 #[test]
3524 fn test_tokenizer_datetime() {
3525 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
3526
3527 assert!(matches!(lexer.next_token(), Token::Where));
3528 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
3529 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3530 assert!(matches!(lexer.next_token(), Token::DateTime));
3531 assert!(matches!(lexer.next_token(), Token::LeftParen));
3532 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
3533 assert!(matches!(lexer.next_token(), Token::Comma));
3534 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
3535 assert!(matches!(lexer.next_token(), Token::Comma));
3536 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
3537 assert!(matches!(lexer.next_token(), Token::RightParen));
3538 }
3539
3540 #[test]
3541 fn test_parse_simple_select() {
3542 let mut parser = Parser::new("SELECT * FROM trade_deal");
3543 let stmt = parser.parse().unwrap();
3544
3545 assert_eq!(stmt.columns, vec!["*"]);
3546 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
3547 assert!(stmt.where_clause.is_none());
3548 }
3549
3550 #[test]
3551 fn test_parse_where_with_method() {
3552 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
3553 let stmt = parser.parse().unwrap();
3554
3555 assert!(stmt.where_clause.is_some());
3556 let where_clause = stmt.where_clause.unwrap();
3557 assert_eq!(where_clause.conditions.len(), 1);
3558 }
3559
3560 #[test]
3561 fn test_parse_datetime_constructor() {
3562 let mut parser =
3563 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
3564 let stmt = parser.parse().unwrap();
3565
3566 assert!(stmt.where_clause.is_some());
3567 let where_clause = stmt.where_clause.unwrap();
3568 assert_eq!(where_clause.conditions.len(), 1);
3569
3570 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3572 assert_eq!(op, ">");
3573 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
3574 assert!(matches!(
3575 right.as_ref(),
3576 SqlExpression::DateTimeConstructor {
3577 year: 2025,
3578 month: 10,
3579 day: 20,
3580 hour: None,
3581 minute: None,
3582 second: None
3583 }
3584 ));
3585 } else {
3586 panic!("Expected BinaryOp with DateTime constructor");
3587 }
3588 }
3589
3590 #[test]
3591 fn test_cursor_context_after_and() {
3592 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3593 let (context, partial) = detect_cursor_context(query, query.len());
3594
3595 assert!(matches!(
3596 context,
3597 CursorContext::AfterLogicalOp(LogicalOp::And)
3598 ));
3599 assert_eq!(partial, None);
3600 }
3601
3602 #[test]
3603 fn test_cursor_context_with_partial() {
3604 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3605 let (context, partial) = detect_cursor_context(query, query.len());
3606
3607 assert!(matches!(
3608 context,
3609 CursorContext::AfterLogicalOp(LogicalOp::And)
3610 ));
3611 assert_eq!(partial, Some("p".to_string()));
3612 }
3613
3614 #[test]
3615 fn test_cursor_context_after_datetime_comparison() {
3616 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3617 let (context, partial) = detect_cursor_context(query, query.len());
3618
3619 assert!(
3620 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3621 );
3622 assert_eq!(partial, None);
3623 }
3624
3625 #[test]
3626 fn test_cursor_context_partial_datetime() {
3627 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3628 let (context, partial) = detect_cursor_context(query, query.len());
3629
3630 assert!(
3631 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3632 );
3633 assert_eq!(partial, Some("Date".to_string()));
3634 }
3635
3636 #[test]
3638 fn test_tokenizer_quoted_identifier() {
3639 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3640
3641 assert!(matches!(lexer.next_token(), Token::Select));
3642 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3643 assert!(matches!(lexer.next_token(), Token::Comma));
3644 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3645 assert!(matches!(lexer.next_token(), Token::From));
3646 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3647 }
3648
3649 #[test]
3650 fn test_tokenizer_quoted_vs_string_literal() {
3651 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3653
3654 assert!(matches!(lexer.next_token(), Token::Where));
3655 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3656 assert!(matches!(lexer.next_token(), Token::Equal));
3657 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3658 assert!(matches!(lexer.next_token(), Token::And));
3659 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3660 assert!(matches!(lexer.next_token(), Token::Dot));
3661 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3662 assert!(matches!(lexer.next_token(), Token::LeftParen));
3663 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3664 assert!(matches!(lexer.next_token(), Token::RightParen));
3665 }
3666
3667 #[test]
3668 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3669 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3672
3673 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3674 assert!(matches!(lexer.next_token(), Token::Dot));
3675 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3676 assert!(matches!(lexer.next_token(), Token::LeftParen));
3677
3678 let token = lexer.next_token();
3681 println!("Token for \"Alb\": {token:?}");
3682 assert!(matches!(lexer.next_token(), Token::RightParen));
3686 }
3687
3688 #[test]
3689 fn test_parse_select_with_quoted_columns() {
3690 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
3691 let stmt = parser.parse().unwrap();
3692
3693 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
3694 assert_eq!(stmt.from_table, Some("customers".to_string()));
3695 }
3696
3697 #[test]
3698 fn test_cursor_context_select_with_partial_quoted() {
3699 let query = r#"SELECT "Cust"#;
3701 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {context:?}, Partial: {partial:?}");
3704 assert!(matches!(context, CursorContext::SelectClause));
3705 }
3708
3709 #[test]
3710 fn test_cursor_context_select_after_comma_with_quoted() {
3711 let query = r#"SELECT Company, "Customer "#;
3713 let (context, partial) = detect_cursor_context(query, query.len());
3714
3715 println!("Context: {context:?}, Partial: {partial:?}");
3716 assert!(matches!(context, CursorContext::SelectClause));
3717 }
3719
3720 #[test]
3721 fn test_cursor_context_order_by_quoted() {
3722 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3723 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3724
3725 println!("Context: {context:?}, Partial: {partial:?}");
3726 assert!(matches!(context, CursorContext::OrderByClause));
3727 }
3729
3730 #[test]
3731 fn test_where_clause_with_quoted_column() {
3732 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3733 let stmt = parser.parse().unwrap();
3734
3735 assert!(stmt.where_clause.is_some());
3736 let where_clause = stmt.where_clause.unwrap();
3737 assert_eq!(where_clause.conditions.len(), 1);
3738
3739 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3740 assert_eq!(op, "=");
3741 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3742 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3743 } else {
3744 panic!("Expected BinaryOp");
3745 }
3746 }
3747
3748 #[test]
3749 fn test_parse_method_with_double_quotes_as_string() {
3750 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3752 let stmt = parser.parse().unwrap();
3753
3754 assert!(stmt.where_clause.is_some());
3755 let where_clause = stmt.where_clause.unwrap();
3756 assert_eq!(where_clause.conditions.len(), 1);
3757
3758 if let SqlExpression::MethodCall {
3759 object,
3760 method,
3761 args,
3762 } = &where_clause.conditions[0].expr
3763 {
3764 assert_eq!(object, "Country");
3765 assert_eq!(method, "Contains");
3766 assert_eq!(args.len(), 1);
3767 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3769 } else {
3770 panic!("Expected MethodCall");
3771 }
3772 }
3773
3774 #[test]
3775 fn test_extract_partial_with_quoted_columns_in_query() {
3776 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3778 let (context, partial) = detect_cursor_context(query, query.len());
3779
3780 assert!(matches!(context, CursorContext::OrderByClause));
3781 assert_eq!(
3782 partial,
3783 Some("coun".to_string()),
3784 "Should extract 'coun' as partial, not everything after the quoted column"
3785 );
3786 }
3787
3788 #[test]
3789 fn test_extract_partial_quoted_identifier_being_typed() {
3790 let query = r#"SELECT "Cust"#;
3792 let partial = extract_partial_at_end(query);
3793 assert_eq!(partial, Some("\"Cust".to_string()));
3794
3795 let query2 = r#"SELECT "Customer Id" FROM"#;
3797 let partial2 = extract_partial_at_end(query2);
3798 assert_eq!(partial2, None); }
3800
3801 #[test]
3803 fn test_complex_where_parentheses_basic() {
3804 let mut parser =
3806 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3807 let stmt = parser.parse().unwrap();
3808
3809 assert!(stmt.where_clause.is_some());
3810 let where_clause = stmt.where_clause.unwrap();
3811 assert_eq!(where_clause.conditions.len(), 1);
3812
3813 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3815 assert_eq!(op, "OR");
3816 } else {
3817 panic!("Expected BinaryOp with OR");
3818 }
3819 }
3820
3821 #[test]
3822 fn test_complex_where_mixed_and_or_with_parens() {
3823 let mut parser = Parser::new(
3825 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3826 );
3827 let stmt = parser.parse().unwrap();
3828
3829 assert!(stmt.where_clause.is_some());
3830 let where_clause = stmt.where_clause.unwrap();
3831 assert_eq!(where_clause.conditions.len(), 2);
3832
3833 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3835 assert_eq!(op, "OR");
3836 } else {
3837 panic!("Expected first condition to be OR expression");
3838 }
3839
3840 assert!(matches!(
3842 where_clause.conditions[0].connector,
3843 Some(LogicalOp::And)
3844 ));
3845
3846 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3848 assert_eq!(op, ">");
3849 } else {
3850 panic!("Expected second condition to be price > 100");
3851 }
3852 }
3853
3854 #[test]
3855 fn test_complex_where_nested_parentheses() {
3856 let mut parser = Parser::new(
3858 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3859 );
3860 let stmt = parser.parse().unwrap();
3861
3862 assert!(stmt.where_clause.is_some());
3863 let where_clause = stmt.where_clause.unwrap();
3864
3865 assert!(!where_clause.conditions.is_empty());
3867 }
3868
3869 #[test]
3870 fn test_complex_where_multiple_or_groups() {
3871 let mut parser = Parser::new(
3873 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3874 );
3875 let stmt = parser.parse().unwrap();
3876
3877 assert!(stmt.where_clause.is_some());
3878 let where_clause = stmt.where_clause.unwrap();
3879 assert_eq!(where_clause.conditions.len(), 2);
3880
3881 assert!(matches!(
3883 where_clause.conditions[0].connector,
3884 Some(LogicalOp::And)
3885 ));
3886 }
3887
3888 #[test]
3889 fn test_complex_where_with_methods_in_parens() {
3890 let mut parser = Parser::new(
3892 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3893 );
3894 let stmt = parser.parse().unwrap();
3895
3896 assert!(stmt.where_clause.is_some());
3897 let where_clause = stmt.where_clause.unwrap();
3898 assert_eq!(where_clause.conditions.len(), 2);
3899
3900 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3902 assert_eq!(op, "OR");
3903 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3904 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3905 } else {
3906 panic!("Expected OR of method calls");
3907 }
3908 }
3909
3910 #[test]
3911 fn test_complex_where_date_comparisons_with_parens() {
3912 let mut parser = Parser::new(
3914 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3915 );
3916 let stmt = parser.parse().unwrap();
3917
3918 assert!(stmt.where_clause.is_some());
3919 let where_clause = stmt.where_clause.unwrap();
3920 assert_eq!(where_clause.conditions.len(), 2);
3921
3922 assert!(matches!(
3924 where_clause.conditions[0].connector,
3925 Some(LogicalOp::And)
3926 ));
3927 }
3928
3929 #[test]
3930 fn test_complex_where_price_volume_filters() {
3931 let mut parser = Parser::new(
3933 r"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000",
3934 );
3935 let stmt = parser.parse().unwrap();
3936
3937 assert!(stmt.where_clause.is_some());
3938 let where_clause = stmt.where_clause.unwrap();
3939
3940 assert!(!where_clause.conditions.is_empty());
3942 }
3943
3944 #[test]
3945 fn test_complex_where_mixed_string_numeric() {
3946 let mut parser = Parser::new(
3948 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3949 );
3950 let stmt = parser.parse().unwrap();
3951
3952 assert!(stmt.where_clause.is_some());
3953 }
3955
3956 #[test]
3957 fn test_complex_where_triple_nested() {
3958 let mut parser = Parser::new(
3960 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3961 );
3962 let stmt = parser.parse().unwrap();
3963
3964 assert!(stmt.where_clause.is_some());
3965 }
3967
3968 #[test]
3969 fn test_complex_where_single_parens_around_and() {
3970 let mut parser = Parser::new(
3972 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3973 );
3974 let stmt = parser.parse().unwrap();
3975
3976 assert!(stmt.where_clause.is_some());
3977 let where_clause = stmt.where_clause.unwrap();
3978
3979 assert!(!where_clause.conditions.is_empty());
3981 }
3982
/// Pretty-printing a query with one parenthesised OR group must keep both the
/// opening and closing fragment and the total number of parenthesis characters.
#[test]
fn test_format_preserves_simple_parentheses() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    for fragment in ["(status", "\"pending\")"] {
        assert!(formatted_text.contains(fragment));
    }

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(
        original_parens, formatted_parens,
        "Parentheses should be preserved"
    );
}
4005
/// A parenthesised OR group followed by an AND-ed predicate must keep its
/// group delimiters and parenthesis count after pretty-printing.
#[test]
fn test_format_preserves_complex_parentheses() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    for fragment in ["(symbol", "\"GOOGL\")"] {
        assert!(formatted_text.contains(fragment));
    }

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(original_parens, formatted_parens);
}
4025
/// Pretty-printing a query with a group nested inside another group must not
/// add or drop parentheses; the input itself contains four of them.
#[test]
fn test_format_preserves_nested_parentheses() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(
        original_parens, formatted_parens,
        "Nested parentheses should be preserved"
    );
    // Sanity check on the fixture itself: two opening and two closing.
    assert_eq!(original_parens, 4, "Should have 4 parentheses total");
}
4044
/// Method-call parentheses inside a grouping pair must survive pretty-printing
/// together with the group's own delimiters.
#[test]
fn test_format_preserves_method_calls_in_parentheses() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    let expected_fragments = [
        "(symbol.StartsWith",
        "StartsWith(\"A\")",
        "StartsWith(\"G\")",
    ];
    for fragment in expected_fragments {
        assert!(formatted_text.contains(fragment));
    }

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(
        original_parens, 6,
        "Should have 6 parentheses (1 group + 2 method calls)"
    );
}
4068
/// Two independent parenthesised groups joined by AND must both still open
/// with their parenthesis after pretty-printing, with the count unchanged.
#[test]
fn test_format_preserves_multiple_groups() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    for group_start in ["(symbol", "(price"] {
        assert!(formatted_text.contains(group_start));
    }

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
}
4087
/// `DateTime(...)` constructor calls inside a parenthesised range check must
/// be emitted verbatim, and the parenthesis count must be unchanged.
#[test]
fn test_format_preserves_date_ranges() {
    // Counts every '(' and ')' character in `text`.
    let count_parens = |text: &str| text.chars().filter(|&ch| matches!(ch, '(' | ')')).count();

    let query = r"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))";
    let formatted_text = format_sql_pretty_compact(query, 5).join(" ");

    let expected_fragments = [
        "(executionDate",
        "DateTime(2024, 1, 1)",
        "DateTime(2024, 12, 31)",
    ];
    for fragment in expected_fragments {
        assert!(formatted_text.contains(fragment));
    }

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&formatted_text);
    assert_eq!(original_parens, formatted_parens);
}
4106
/// Checks the line layout of the pretty-printer: `SELECT`, the projection,
/// a `FROM` line and `WHERE` come first, and the remaining lines contain the
/// parenthesised group plus a line holding only `AND`.
#[test]
fn test_format_multiline_layout() {
    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let lines = format_sql_pretty_compact(query, 5);

    assert!(lines.len() >= 4, "Should have multiple lines");
    assert_eq!(lines[0], "SELECT");
    assert!(lines[1].trim().starts_with('*'));
    assert!(lines[2].starts_with("FROM"));
    assert_eq!(lines[3], "WHERE");

    // Everything after the first four lines is the WHERE body.
    let where_body = || lines.iter().skip(4);
    assert!(where_body().any(|line| line.contains("(symbol")));
    assert!(where_body().any(|line| line.trim() == "AND"));
}
4126
/// A bare `BETWEEN ... AND ...` predicate must parse into exactly one WHERE
/// condition, and the AST dump of the same query must report no parse error.
#[test]
fn test_between_simple() {
    let query = "SELECT * FROM table WHERE price BETWEEN 50 AND 100";

    let mut sql_parser = Parser::new(query);
    let statement = sql_parser.parse().expect("Should parse simple BETWEEN");

    assert!(statement.where_clause.is_some());
    let condition_count = statement.where_clause.unwrap().conditions.len();
    assert_eq!(condition_count, 1);

    let ast_dump = format_ast_tree(query);
    let reports_error = ast_dump.contains("PARSE ERROR");
    assert!(!reports_error);
    assert!(ast_dump.contains("SelectStatement"));
}
4141
/// A `BETWEEN` predicate wrapped in parentheses must parse, and the AST dump
/// of the same query must not contain a parse-error marker.
#[test]
fn test_between_in_parentheses() {
    let query = "SELECT * FROM table WHERE (price BETWEEN 50 AND 100)";

    let mut sql_parser = Parser::new(query);
    let statement = sql_parser
        .parse()
        .expect("Should parse BETWEEN in parentheses");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);

    let ast_dump = format_ast_tree(query);
    let reports_error = ast_dump.contains("PARSE ERROR");
    assert!(!reports_error, "Should not have parse error");
}
4153
/// A parenthesised `BETWEEN` OR-ed with another parenthesised comparison must
/// parse and produce a WHERE clause.
#[test]
fn test_between_with_or() {
    let mut sql_parser =
        Parser::new("SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)");
    let statement = sql_parser.parse().expect("Should parse BETWEEN with OR");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);
}
4164
/// The `AND` inside `BETWEEN 10 AND 50` must not be taken for a logical
/// connector: the query is expected to yield exactly two conditions.
#[test]
fn test_between_with_and() {
    let mut sql_parser =
        Parser::new("SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50");
    let statement = sql_parser.parse().expect("Should parse BETWEEN with AND");

    assert!(statement.where_clause.is_some());
    let condition_count = statement.where_clause.unwrap().conditions.len();
    assert_eq!(condition_count, 2);
}
4175
/// Two parenthesised `BETWEEN` predicates joined by AND must parse and
/// produce a WHERE clause.
#[test]
fn test_multiple_between() {
    let mut sql_parser = Parser::new(
        "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)",
    );
    let statement = sql_parser
        .parse()
        .expect("Should parse multiple BETWEEN clauses");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);
}
4187
/// Combines `BETWEEN`, a method call and a two-column `ORDER BY`, then checks
/// that both sort keys come back with the expected column and direction.
#[test]
fn test_between_complex_query() {
    let mut sql_parser = Parser::new(
        "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC",
    );
    let statement = sql_parser
        .parse()
        .expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

    assert!(statement.where_clause.is_some());
    assert!(statement.order_by.is_some());

    let sort_keys = statement.order_by.unwrap();
    assert_eq!(sort_keys.len(), 2);

    // The length check above makes both index accesses safe.
    let (primary, secondary) = (&sort_keys[0], &sort_keys[1]);
    assert_eq!(primary.column, "Category");
    assert!(matches!(primary.direction, SortDirection::Asc));
    assert_eq!(secondary.column, "price");
    assert!(matches!(secondary.direction, SortDirection::Desc));
}
4207
/// Builds a `Between` expression by hand and checks both renderings: the SQL
/// text form and the AST debug form.
#[test]
fn test_between_formatting() {
    let subject = SqlExpression::Column(String::from("price"));
    let low = SqlExpression::NumberLiteral(String::from("50"));
    let high = SqlExpression::NumberLiteral(String::from("100"));
    let between = SqlExpression::Between {
        expr: Box::new(subject),
        lower: Box::new(low),
        upper: Box::new(high),
    };

    let sql_text = format_expression(&between);
    assert_eq!(sql_text, "price BETWEEN 50 AND 100");

    // The AST form only needs to mention the node kind and both bounds.
    let ast_text = format_expression_ast(&between);
    for fragment in ["Between", "50", "100"] {
        assert!(ast_text.contains(fragment));
    }
}
4224
/// Probes `detect_cursor_context` at every byte offset of a query containing a
/// multi-byte character, at an offset far past the end, and at an offset that
/// lands inside the multi-byte character, asserting that no call panics.
#[test]
fn test_utf8_boundary_safety() {
    let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

    // `true` when the call returns normally instead of unwinding.
    let survives = |cursor: usize| {
        std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, cursor)).is_ok()
    };

    // `len()` is in bytes, so this range also visits mid-character offsets.
    for pos in 0..=query_with_unicode.len() {
        assert!(
            survives(pos),
            "Panic at position {pos} in query with Unicode"
        );
    }

    assert!(survives(1000), "Panic with position beyond string length");

    // One byte past the start of 'é' falls inside its UTF-8 encoding.
    let inside_e_acute = query_with_unicode.find('é').unwrap() + 1;
    assert!(
        survives(inside_e_acute),
        "Panic with cursor in middle of UTF-8 character"
    );
}
4255}