1use chrono::{Datelike, Local, NaiveDateTime};
2
/// Lexical tokens produced by [`Lexer`] for the supported SQL dialect.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // ---- Keywords (matched case-insensitively by the lexer) ----
    Select,
    From,
    Where,
    With,
    And,
    Or,
    In,
    Not,
    Between,
    Like,
    Is,
    Null,
    /// The two-word keyword `ORDER BY`, emitted as a single token.
    OrderBy,
    /// The two-word keyword `GROUP BY`, emitted as a single token.
    GroupBy,
    Having,
    As,
    Asc,
    Desc,
    Limit,
    Offset,
    DateTime,
    Case,
    When,
    Then,
    Else,
    End,
    Distinct,
    Over,
    Partition,
    /// A standalone `BY` (e.g. after `PARTITION`).
    By,

    // ---- Names and literals ----
    /// An unquoted word that is not a keyword; original casing is kept.
    Identifier(String),
    /// The text between a pair of double quotes.
    QuotedIdentifier(String),
    /// The text between a pair of single quotes.
    StringLiteral(String),
    /// The raw text of a numeric literal (sign and exponent included).
    NumberLiteral(String),

    // ---- Punctuation and comparison operators ----
    Star,
    Dot,
    Comma,
    LeftParen,
    RightParen,
    Equal,
    /// Emitted for both `<>` and `!=`.
    NotEqual,
    LessThan,
    GreaterThan,
    LessThanOrEqual,
    GreaterThanOrEqual,

    // ---- Arithmetic operators ----
    Plus,
    Minus,
    Divide,
    Modulo,

    /// End of input.
    Eof,
}
65
/// Character-level scanner that turns SQL text into [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// The whole input, pre-split into `char`s so it can be indexed directly.
    input: Vec<char>,
    /// Index into `input` of the character currently being examined.
    position: usize,
    /// The character at `position`, or `None` once the input is exhausted.
    current_char: Option<char>,
}
72
73impl Lexer {
74 #[must_use]
75 pub fn new(input: &str) -> Self {
76 let chars: Vec<char> = input.chars().collect();
77 let current = chars.first().copied();
78 Self {
79 input: chars,
80 position: 0,
81 current_char: current,
82 }
83 }
84
85 fn advance(&mut self) {
86 self.position += 1;
87 self.current_char = self.input.get(self.position).copied();
88 }
89
90 fn peek(&self, offset: usize) -> Option<char> {
91 self.input.get(self.position + offset).copied()
92 }
93
94 fn skip_whitespace(&mut self) {
95 while let Some(ch) = self.current_char {
96 if ch.is_whitespace() {
97 self.advance();
98 } else {
99 break;
100 }
101 }
102 }
103
104 fn skip_whitespace_and_comments(&mut self) {
105 loop {
106 while let Some(ch) = self.current_char {
108 if ch.is_whitespace() {
109 self.advance();
110 } else {
111 break;
112 }
113 }
114
115 match self.current_char {
117 Some('-') if self.peek(1) == Some('-') => {
118 self.advance(); self.advance(); while let Some(ch) = self.current_char {
122 self.advance();
123 if ch == '\n' {
124 break;
125 }
126 }
127 }
128 Some('/') if self.peek(1) == Some('*') => {
129 self.advance(); self.advance(); while let Some(ch) = self.current_char {
133 if ch == '*' && self.peek(1) == Some('/') {
134 self.advance(); self.advance(); break;
137 }
138 self.advance();
139 }
140 }
141 _ => {
142 break;
144 }
145 }
146 }
147 }
148
149 fn read_identifier(&mut self) -> String {
150 let mut result = String::new();
151 while let Some(ch) = self.current_char {
152 if ch.is_alphanumeric() || ch == '_' {
153 result.push(ch);
154 self.advance();
155 } else {
156 break;
157 }
158 }
159 result
160 }
161
162 fn read_string(&mut self) -> String {
163 let mut result = String::new();
164 let quote_char = self.current_char.unwrap(); self.advance(); while let Some(ch) = self.current_char {
168 if ch == quote_char {
169 self.advance(); break;
171 }
172 result.push(ch);
173 self.advance();
174 }
175 result
176 }
177
178 fn read_number(&mut self) -> String {
179 let mut result = String::new();
180 let mut has_e = false;
181
182 while let Some(ch) = self.current_char {
184 if !has_e && (ch.is_numeric() || ch == '.') {
185 result.push(ch);
186 self.advance();
187 } else if (ch == 'e' || ch == 'E') && !has_e && !result.is_empty() {
188 result.push(ch);
190 self.advance();
191 has_e = true;
192
193 if let Some(sign) = self.current_char {
195 if sign == '+' || sign == '-' {
196 result.push(sign);
197 self.advance();
198 }
199 }
200
201 while let Some(digit) = self.current_char {
203 if digit.is_numeric() {
204 result.push(digit);
205 self.advance();
206 } else {
207 break;
208 }
209 }
210 break; } else {
212 break;
213 }
214 }
215 result
216 }
217
218 pub fn next_token(&mut self) -> Token {
219 self.skip_whitespace_and_comments();
220
221 match self.current_char {
222 None => Token::Eof,
223 Some('*') => {
224 self.advance();
225 Token::Star }
229 Some('+') => {
230 self.advance();
231 Token::Plus
232 }
233 Some('/') => {
234 if self.peek(1) == Some('*') {
236 self.skip_whitespace_and_comments();
239 return self.next_token();
240 }
241 self.advance();
242 Token::Divide
243 }
244 Some('%') => {
245 self.advance();
246 Token::Modulo
247 }
248 Some('.') => {
249 self.advance();
250 Token::Dot
251 }
252 Some(',') => {
253 self.advance();
254 Token::Comma
255 }
256 Some('(') => {
257 self.advance();
258 Token::LeftParen
259 }
260 Some(')') => {
261 self.advance();
262 Token::RightParen
263 }
264 Some('=') => {
265 self.advance();
266 Token::Equal
267 }
268 Some('<') => {
269 self.advance();
270 if self.current_char == Some('=') {
271 self.advance();
272 Token::LessThanOrEqual
273 } else if self.current_char == Some('>') {
274 self.advance();
275 Token::NotEqual
276 } else {
277 Token::LessThan
278 }
279 }
280 Some('>') => {
281 self.advance();
282 if self.current_char == Some('=') {
283 self.advance();
284 Token::GreaterThanOrEqual
285 } else {
286 Token::GreaterThan
287 }
288 }
289 Some('!') if self.peek(1) == Some('=') => {
290 self.advance();
291 self.advance();
292 Token::NotEqual
293 }
294 Some('"') => {
295 let ident_val = self.read_string();
297 Token::QuotedIdentifier(ident_val)
298 }
299 Some('\'') => {
300 let string_val = self.read_string();
302 Token::StringLiteral(string_val)
303 }
304 Some('-') if self.peek(1) == Some('-') => {
305 self.skip_whitespace_and_comments();
307 self.next_token()
308 }
309 Some('-') if self.peek(1).is_some_and(char::is_numeric) => {
310 self.advance(); let num = self.read_number();
313 Token::NumberLiteral(format!("-{num}"))
314 }
315 Some('-') => {
316 self.advance();
318 Token::Minus
319 }
320 Some(ch) if ch.is_numeric() => {
321 let num = self.read_number();
322 Token::NumberLiteral(num)
323 }
324 Some(ch) if ch.is_alphabetic() || ch == '_' => {
325 let ident = self.read_identifier();
326 match ident.to_uppercase().as_str() {
327 "SELECT" => Token::Select,
328 "FROM" => Token::From,
329 "WHERE" => Token::Where,
330 "WITH" => Token::With,
331 "AND" => Token::And,
332 "OR" => Token::Or,
333 "IN" => Token::In,
334 "NOT" => Token::Not,
335 "BETWEEN" => Token::Between,
336 "LIKE" => Token::Like,
337 "IS" => Token::Is,
338 "NULL" => Token::Null,
339 "ORDER" if self.peek_keyword("BY") => {
340 self.skip_whitespace();
341 self.read_identifier(); Token::OrderBy
343 }
344 "GROUP" if self.peek_keyword("BY") => {
345 self.skip_whitespace();
346 self.read_identifier(); Token::GroupBy
348 }
349 "HAVING" => Token::Having,
350 "AS" => Token::As,
351 "ASC" => Token::Asc,
352 "DESC" => Token::Desc,
353 "LIMIT" => Token::Limit,
354 "OFFSET" => Token::Offset,
355 "DATETIME" => Token::DateTime,
356 "CASE" => Token::Case,
357 "WHEN" => Token::When,
358 "THEN" => Token::Then,
359 "ELSE" => Token::Else,
360 "END" => Token::End,
361 "DISTINCT" => Token::Distinct,
362 "OVER" => Token::Over,
363 "PARTITION" => Token::Partition,
364 "BY" => Token::By,
365 _ => Token::Identifier(ident),
366 }
367 }
368 Some(ch) => {
369 self.advance();
370 Token::Identifier(ch.to_string())
371 }
372 }
373 }
374
375 fn peek_keyword(&mut self, keyword: &str) -> bool {
376 let saved_pos = self.position;
377 let saved_char = self.current_char;
378
379 self.skip_whitespace_and_comments();
380 let next_word = self.read_identifier();
381 let matches = next_word.to_uppercase() == keyword;
382
383 self.position = saved_pos;
385 self.current_char = saved_char;
386
387 matches
388 }
389
390 #[must_use]
391 pub fn get_position(&self) -> usize {
392 self.position
393 }
394
395 pub fn tokenize_all(&mut self) -> Vec<Token> {
396 let mut tokens = Vec::new();
397 loop {
398 let token = self.next_token();
399 if matches!(token, Token::Eof) {
400 tokens.push(token);
401 break;
402 }
403 tokens.push(token);
404 }
405 tokens
406 }
407
408 pub fn tokenize_all_with_positions(&mut self) -> Vec<(usize, usize, Token)> {
409 let mut tokens = Vec::new();
410 loop {
411 self.skip_whitespace_and_comments();
412 let start_pos = self.position;
413 let token = self.next_token();
414 let end_pos = self.position;
415
416 if matches!(token, Token::Eof) {
417 break;
418 }
419 tokens.push((start_pos, end_pos, token));
420 }
421 tokens
422 }
423}
424
425#[derive(Debug, Clone)]
427pub enum SqlExpression {
428 Column(String),
429 StringLiteral(String),
430 NumberLiteral(String),
431 BooleanLiteral(bool),
432 Null, DateTimeConstructor {
434 year: i32,
435 month: u32,
436 day: u32,
437 hour: Option<u32>,
438 minute: Option<u32>,
439 second: Option<u32>,
440 },
441 DateTimeToday {
442 hour: Option<u32>,
443 minute: Option<u32>,
444 second: Option<u32>,
445 },
446 MethodCall {
447 object: String,
448 method: String,
449 args: Vec<SqlExpression>,
450 },
451 ChainedMethodCall {
452 base: Box<SqlExpression>,
453 method: String,
454 args: Vec<SqlExpression>,
455 },
456 FunctionCall {
457 name: String,
458 args: Vec<SqlExpression>,
459 },
460 WindowFunction {
461 name: String,
462 args: Vec<SqlExpression>,
463 window_spec: WindowSpec,
464 },
465 BinaryOp {
466 left: Box<SqlExpression>,
467 op: String,
468 right: Box<SqlExpression>,
469 },
470 InList {
471 expr: Box<SqlExpression>,
472 values: Vec<SqlExpression>,
473 },
474 NotInList {
475 expr: Box<SqlExpression>,
476 values: Vec<SqlExpression>,
477 },
478 Between {
479 expr: Box<SqlExpression>,
480 lower: Box<SqlExpression>,
481 upper: Box<SqlExpression>,
482 },
483 Not {
484 expr: Box<SqlExpression>,
485 },
486 CaseExpression {
487 when_branches: Vec<WhenBranch>,
488 else_branch: Option<Box<SqlExpression>>,
489 },
490}
491
492#[derive(Debug, Clone)]
493pub struct WhenBranch {
494 pub condition: Box<SqlExpression>,
495 pub result: Box<SqlExpression>,
496}
497
498#[derive(Debug, Clone)]
499pub struct WhereClause {
500 pub conditions: Vec<Condition>,
501}
502
503#[derive(Debug, Clone)]
504pub struct Condition {
505 pub expr: SqlExpression,
506 pub connector: Option<LogicalOp>, }
508
/// Logical connector between two consecutive `WHERE` conditions.
///
/// Comparable with `==`, like the sibling `SortDirection` enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicalOp {
    /// Both neighbouring conditions must hold.
    And,
    /// Either neighbouring condition may hold.
    Or,
}
514
/// Sort order of one `ORDER BY` column.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending; also what the parser uses when no direction is written.
    Asc,
    /// Descending.
    Desc,
}
520
521#[derive(Debug, Clone)]
522pub struct OrderByColumn {
523 pub column: String,
524 pub direction: SortDirection,
525}
526
527#[derive(Debug, Clone)]
528pub struct WindowSpec {
529 pub partition_by: Vec<String>,
530 pub order_by: Vec<OrderByColumn>,
531}
532
533#[derive(Debug, Clone)]
535pub enum SelectItem {
536 Column(String),
538 Expression { expr: SqlExpression, alias: String },
540 Star,
542}
543
544#[derive(Debug, Clone)]
545pub struct SelectStatement {
546 pub distinct: bool, pub columns: Vec<String>, pub select_items: Vec<SelectItem>, pub from_table: Option<String>,
550 pub from_subquery: Option<Box<SelectStatement>>, pub from_function: Option<TableFunction>, pub from_alias: Option<String>, pub where_clause: Option<WhereClause>,
554 pub order_by: Option<Vec<OrderByColumn>>,
555 pub group_by: Option<Vec<String>>,
556 pub having: Option<SqlExpression>, pub limit: Option<usize>,
558 pub offset: Option<usize>,
559 pub ctes: Vec<CTE>, }
561
562#[derive(Debug, Clone)]
564pub enum TableFunction {
565 Range {
566 start: SqlExpression,
567 end: SqlExpression,
568 step: Option<SqlExpression>,
569 },
570}
571
572#[derive(Debug, Clone)]
574pub struct CTE {
575 pub name: String,
576 pub column_list: Option<Vec<String>>, pub query: SelectStatement,
578}
579
580#[derive(Debug, Clone)]
582pub enum TableSource {
583 Table(String), DerivedTable {
585 query: Box<SelectStatement>,
587 alias: String, },
589}
590
/// Options controlling how a [`Parser`] behaves.
///
/// The default configuration has every option switched off.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParserConfig {
    /// Case-insensitivity switch.
    // NOTE(review): the parser code visible here stores the config but never
    // reads this flag — confirm what it is meant to affect.
    pub case_insensitive: bool,
}
595
596pub struct Parser {
597 lexer: Lexer,
598 current_token: Token,
599 in_method_args: bool, columns: Vec<String>, paren_depth: i32, #[allow(dead_code)]
603 config: ParserConfig, }
605
606impl Parser {
607 #[must_use]
608 pub fn new(input: &str) -> Self {
609 let mut lexer = Lexer::new(input);
610 let current_token = lexer.next_token();
611 Self {
612 lexer,
613 current_token,
614 in_method_args: false,
615 columns: Vec::new(),
616 paren_depth: 0,
617 config: ParserConfig::default(),
618 }
619 }
620
621 #[must_use]
622 pub fn with_config(input: &str, config: ParserConfig) -> Self {
623 let mut lexer = Lexer::new(input);
624 let current_token = lexer.next_token();
625 Self {
626 lexer,
627 current_token,
628 in_method_args: false,
629 columns: Vec::new(),
630 paren_depth: 0,
631 config,
632 }
633 }
634
635 #[must_use]
636 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
637 self.columns = columns;
638 self
639 }
640
641 fn consume(&mut self, expected: Token) -> Result<(), String> {
642 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
643 match &expected {
645 Token::LeftParen => self.paren_depth += 1,
646 Token::RightParen => {
647 self.paren_depth -= 1;
648 if self.paren_depth < 0 {
650 return Err(
651 "Unexpected closing parenthesis - no matching opening parenthesis"
652 .to_string(),
653 );
654 }
655 }
656 _ => {}
657 }
658
659 self.current_token = self.lexer.next_token();
660 Ok(())
661 } else {
662 let error_msg = match (&expected, &self.current_token) {
664 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
665 format!(
666 "Unclosed parenthesis - missing {} closing parenthes{}",
667 self.paren_depth,
668 if self.paren_depth == 1 { "is" } else { "es" }
669 )
670 }
671 (Token::RightParen, _) if self.paren_depth > 0 => {
672 format!(
673 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
674 self.current_token,
675 self.paren_depth,
676 if self.paren_depth == 1 { "is" } else { "es" }
677 )
678 }
679 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
680 };
681 Err(error_msg)
682 }
683 }
684
685 fn advance(&mut self) {
686 match &self.current_token {
688 Token::LeftParen => self.paren_depth += 1,
689 Token::RightParen => {
690 self.paren_depth -= 1;
691 }
694 _ => {}
695 }
696 self.current_token = self.lexer.next_token();
697 }
698
699 pub fn parse(&mut self) -> Result<SelectStatement, String> {
700 if matches!(self.current_token, Token::With) {
702 self.parse_with_clause()
703 } else {
704 self.parse_select_statement()
705 }
706 }
707
708 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
709 self.consume(Token::With)?;
710
711 let mut ctes = Vec::new();
712
713 loop {
715 let name = match &self.current_token {
717 Token::Identifier(name) => name.clone(),
718 _ => return Err("Expected CTE name after WITH".to_string()),
719 };
720 self.advance();
721
722 let column_list = if matches!(self.current_token, Token::LeftParen) {
724 self.advance();
725 let cols = self.parse_identifier_list()?;
726 self.consume(Token::RightParen)?;
727 Some(cols)
728 } else {
729 None
730 };
731
732 self.consume(Token::As)?;
734
735 self.consume(Token::LeftParen)?;
737
738 let query = self.parse_select_statement_inner()?;
740
741 self.consume(Token::RightParen)?;
743
744 ctes.push(CTE {
745 name,
746 column_list,
747 query,
748 });
749
750 if !matches!(self.current_token, Token::Comma) {
752 break;
753 }
754 self.advance();
755 }
756
757 let mut main_query = self.parse_select_statement()?;
759 main_query.ctes = ctes;
760
761 Ok(main_query)
762 }
763
764 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
765 let result = self.parse_select_statement_inner()?;
766
767 if self.paren_depth > 0 {
769 return Err(format!(
770 "Unclosed parenthesis - missing {} closing parenthes{}",
771 self.paren_depth,
772 if self.paren_depth == 1 { "is" } else { "es" }
773 ));
774 } else if self.paren_depth < 0 {
775 return Err(
776 "Extra closing parenthesis found - no matching opening parenthesis".to_string(),
777 );
778 }
779
780 Ok(result)
781 }
782
783 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
784 self.consume(Token::Select)?;
785
786 let distinct = if matches!(self.current_token, Token::Distinct) {
788 self.advance();
789 true
790 } else {
791 false
792 };
793
794 let select_items = self.parse_select_items()?;
796
797 let columns = select_items
799 .iter()
800 .map(|item| match item {
801 SelectItem::Star => "*".to_string(),
802 SelectItem::Column(name) => name.clone(),
803 SelectItem::Expression { alias, .. } => alias.clone(),
804 })
805 .collect();
806
807 let (from_table, from_subquery, from_function, from_alias) =
809 if matches!(self.current_token, Token::From) {
810 self.advance();
811
812 if let Token::Identifier(name) = &self.current_token.clone() {
814 if name.to_uppercase() == "RANGE" {
815 self.advance();
816 self.consume(Token::LeftParen)?;
818
819 let start = self.parse_expression()?;
821 self.consume(Token::Comma)?;
822
823 let end = self.parse_expression()?;
825
826 let step = if matches!(self.current_token, Token::Comma) {
828 self.advance();
829 Some(self.parse_expression()?)
830 } else {
831 None
832 };
833
834 self.consume(Token::RightParen)?;
835
836 let alias = if matches!(self.current_token, Token::As) {
838 self.advance();
839 match &self.current_token {
840 Token::Identifier(name) => {
841 let alias = name.clone();
842 self.advance();
843 Some(alias)
844 }
845 _ => return Err("Expected alias name after AS".to_string()),
846 }
847 } else if let Token::Identifier(name) = &self.current_token {
848 let alias = name.clone();
849 self.advance();
850 Some(alias)
851 } else {
852 None
853 };
854
855 (
856 None,
857 None,
858 Some(TableFunction::Range { start, end, step }),
859 alias,
860 )
861 } else {
862 let table_name = name.clone();
864 self.advance();
865
866 let alias = if matches!(self.current_token, Token::As) {
868 self.advance();
869 match &self.current_token {
870 Token::Identifier(name) => {
871 let alias = name.clone();
872 self.advance();
873 Some(alias)
874 }
875 _ => return Err("Expected alias name after AS".to_string()),
876 }
877 } else if let Token::Identifier(name) = &self.current_token {
878 let alias = name.clone();
880 self.advance();
881 Some(alias)
882 } else {
883 None
884 };
885
886 (Some(table_name), None, None, alias)
887 }
888 } else if matches!(self.current_token, Token::LeftParen) {
889 self.advance();
891
892 let subquery = self.parse_select_statement_inner()?;
894
895 self.consume(Token::RightParen)?;
896
897 let alias = if matches!(self.current_token, Token::As) {
899 self.advance();
900 match &self.current_token {
901 Token::Identifier(name) => {
902 let alias = name.clone();
903 self.advance();
904 alias
905 }
906 _ => return Err("Expected alias name after AS".to_string()),
907 }
908 } else {
909 match &self.current_token {
911 Token::Identifier(name) => {
912 let alias = name.clone();
913 self.advance();
914 alias
915 }
916 _ => {
917 return Err(
918 "Subquery in FROM must have an alias (e.g., AS t)".to_string()
919 )
920 }
921 }
922 };
923
924 (None, Some(Box::new(subquery)), None, Some(alias))
925 } else {
926 match &self.current_token {
928 Token::Identifier(table) => {
929 let table_name = table.clone();
930 self.advance();
931
932 let alias = if matches!(self.current_token, Token::As) {
934 self.advance();
935 match &self.current_token {
936 Token::Identifier(name) => {
937 let alias = name.clone();
938 self.advance();
939 Some(alias)
940 }
941 _ => return Err("Expected alias name after AS".to_string()),
942 }
943 } else if let Token::Identifier(name) = &self.current_token {
944 let alias = name.clone();
946 self.advance();
947 Some(alias)
948 } else {
949 None
950 };
951
952 (Some(table_name), None, None, alias)
953 }
954 Token::QuotedIdentifier(table) => {
955 let table_name = table.clone();
957 self.advance();
958
959 let alias = if matches!(self.current_token, Token::As) {
961 self.advance();
962 match &self.current_token {
963 Token::Identifier(name) => {
964 let alias = name.clone();
965 self.advance();
966 Some(alias)
967 }
968 _ => return Err("Expected alias name after AS".to_string()),
969 }
970 } else if let Token::Identifier(name) = &self.current_token {
971 let alias = name.clone();
973 self.advance();
974 Some(alias)
975 } else {
976 None
977 };
978
979 (Some(table_name), None, None, alias)
980 }
981 _ => return Err("Expected table name or subquery after FROM".to_string()),
982 }
983 }
984 } else {
985 (None, None, None, None)
986 };
987
988 let where_clause = if matches!(self.current_token, Token::Where) {
989 self.advance();
990 Some(self.parse_where_clause()?)
991 } else {
992 None
993 };
994
995 let order_by = if matches!(self.current_token, Token::OrderBy) {
996 self.advance();
997 Some(self.parse_order_by_list()?)
998 } else {
999 None
1000 };
1001
1002 let group_by = if matches!(self.current_token, Token::GroupBy) {
1003 self.advance();
1004 Some(self.parse_identifier_list()?)
1005 } else {
1006 None
1007 };
1008
1009 let having = if matches!(self.current_token, Token::Having) {
1011 if group_by.is_none() {
1012 return Err("HAVING clause requires GROUP BY".to_string());
1013 }
1014 self.advance();
1015 Some(self.parse_expression()?)
1016 } else {
1017 None
1018 };
1019
1020 let limit = if matches!(self.current_token, Token::Limit) {
1022 self.advance();
1023 match &self.current_token {
1024 Token::NumberLiteral(num) => {
1025 let limit_val = num
1026 .parse::<usize>()
1027 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
1028 self.advance();
1029 Some(limit_val)
1030 }
1031 _ => return Err("Expected number after LIMIT".to_string()),
1032 }
1033 } else {
1034 None
1035 };
1036
1037 let offset = if matches!(self.current_token, Token::Offset) {
1039 self.advance();
1040 match &self.current_token {
1041 Token::NumberLiteral(num) => {
1042 let offset_val = num
1043 .parse::<usize>()
1044 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
1045 self.advance();
1046 Some(offset_val)
1047 }
1048 _ => return Err("Expected number after OFFSET".to_string()),
1049 }
1050 } else {
1051 None
1052 };
1053
1054 Ok(SelectStatement {
1055 distinct,
1056 columns,
1057 select_items,
1058 from_table,
1059 from_subquery,
1060 from_function,
1061 from_alias,
1062 where_clause,
1063 order_by,
1064 group_by,
1065 having,
1066 limit,
1067 offset,
1068 ctes: Vec::new(), })
1070 }
1071
1072 fn parse_select_list(&mut self) -> Result<Vec<String>, String> {
1073 let mut columns = Vec::new();
1074
1075 if matches!(self.current_token, Token::Star) {
1076 columns.push("*".to_string());
1077 self.advance();
1078 } else {
1079 loop {
1080 match &self.current_token {
1081 Token::Identifier(col) => {
1082 columns.push(col.clone());
1083 self.advance();
1084 }
1085 Token::QuotedIdentifier(col) => {
1086 columns.push(col.clone());
1088 self.advance();
1089 }
1090 _ => return Err("Expected column name".to_string()),
1091 }
1092
1093 if matches!(self.current_token, Token::Comma) {
1094 self.advance();
1095 } else {
1096 break;
1097 }
1098 }
1099 }
1100
1101 Ok(columns)
1102 }
1103
1104 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
1106 let mut items = Vec::new();
1107
1108 loop {
1109 if matches!(self.current_token, Token::Star) {
1112 items.push(SelectItem::Star);
1120 self.advance();
1121 } else {
1122 let expr = self.parse_comparison()?; let alias = if matches!(self.current_token, Token::As) {
1127 self.advance();
1128 match &self.current_token {
1129 Token::Identifier(alias_name) => {
1130 let alias = alias_name.clone();
1131 self.advance();
1132 alias
1133 }
1134 Token::QuotedIdentifier(alias_name) => {
1135 let alias = alias_name.clone();
1136 self.advance();
1137 alias
1138 }
1139 _ => return Err("Expected alias name after AS".to_string()),
1140 }
1141 } else {
1142 match &expr {
1144 SqlExpression::Column(col_name) => col_name.clone(),
1145 _ => format!("expr_{}", items.len() + 1), }
1147 };
1148
1149 let item = match expr {
1151 SqlExpression::Column(col_name) if alias == col_name => {
1152 SelectItem::Column(col_name)
1154 }
1155 _ => {
1156 SelectItem::Expression { expr, alias }
1158 }
1159 };
1160
1161 items.push(item);
1162 }
1163
1164 if matches!(self.current_token, Token::Comma) {
1166 self.advance();
1167 } else {
1168 break;
1169 }
1170 }
1171
1172 Ok(items)
1173 }
1174
1175 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1176 let mut identifiers = Vec::new();
1177
1178 loop {
1179 match &self.current_token {
1180 Token::Identifier(id) => {
1181 identifiers.push(id.clone());
1182 self.advance();
1183 }
1184 Token::QuotedIdentifier(id) => {
1185 identifiers.push(id.clone());
1187 self.advance();
1188 }
1189 _ => return Err("Expected identifier".to_string()),
1190 }
1191
1192 if matches!(self.current_token, Token::Comma) {
1193 self.advance();
1194 } else {
1195 break;
1196 }
1197 }
1198
1199 Ok(identifiers)
1200 }
1201
1202 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1203 let mut partition_by = Vec::new();
1204 let mut order_by = Vec::new();
1205
1206 if matches!(self.current_token, Token::Partition) {
1208 self.advance(); if !matches!(self.current_token, Token::By) {
1210 return Err("Expected BY after PARTITION".to_string());
1211 }
1212 self.advance(); partition_by = self.parse_identifier_list()?;
1216 }
1217
1218 if matches!(self.current_token, Token::OrderBy) {
1220 self.advance(); order_by = self.parse_order_by_list()?;
1222 } else if let Token::Identifier(s) = &self.current_token {
1223 if s.to_uppercase() == "ORDER" {
1224 self.advance(); if !matches!(self.current_token, Token::By) {
1227 return Err("Expected BY after ORDER".to_string());
1228 }
1229 self.advance(); order_by = self.parse_order_by_list()?;
1231 }
1232 }
1233
1234 Ok(WindowSpec {
1235 partition_by,
1236 order_by,
1237 })
1238 }
1239
1240 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByColumn>, String> {
1241 let mut order_columns = Vec::new();
1242
1243 loop {
1244 let column = match &self.current_token {
1245 Token::Identifier(id) => {
1246 let col = id.clone();
1247 self.advance();
1248 col
1249 }
1250 Token::QuotedIdentifier(id) => {
1251 let col = id.clone();
1252 self.advance();
1253 col
1254 }
1255 Token::NumberLiteral(num) if self.columns.iter().any(|col| col == num) => {
1256 let col = num.clone();
1258 self.advance();
1259 col
1260 }
1261 _ => return Err("Expected column name in ORDER BY".to_string()),
1262 };
1263
1264 let direction = match &self.current_token {
1266 Token::Asc => {
1267 self.advance();
1268 SortDirection::Asc
1269 }
1270 Token::Desc => {
1271 self.advance();
1272 SortDirection::Desc
1273 }
1274 _ => SortDirection::Asc, };
1276
1277 order_columns.push(OrderByColumn { column, direction });
1278
1279 if matches!(self.current_token, Token::Comma) {
1280 self.advance();
1281 } else {
1282 break;
1283 }
1284 }
1285
1286 Ok(order_columns)
1287 }
1288
1289 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1290 let mut conditions = Vec::new();
1291
1292 loop {
1293 let expr = self.parse_expression()?;
1294
1295 let connector = match &self.current_token {
1296 Token::And => {
1297 self.advance();
1298 Some(LogicalOp::And)
1299 }
1300 Token::Or => {
1301 self.advance();
1302 Some(LogicalOp::Or)
1303 }
1304 Token::RightParen if self.paren_depth <= 0 => {
1305 return Err(
1307 "Unexpected closing parenthesis - no matching opening parenthesis"
1308 .to_string(),
1309 );
1310 }
1311 _ => None,
1312 };
1313
1314 conditions.push(Condition {
1315 expr,
1316 connector: connector.clone(),
1317 });
1318
1319 if connector.is_none() {
1320 break;
1321 }
1322 }
1323
1324 Ok(WhereClause { conditions })
1325 }
1326
1327 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1328 let mut left = self.parse_comparison()?;
1329
1330 if let Some(op) = self.get_binary_op() {
1333 self.advance();
1334 let right = self.parse_expression()?;
1335 left = SqlExpression::BinaryOp {
1336 left: Box::new(left),
1337 op,
1338 right: Box::new(right),
1339 };
1340 }
1341
1342 if matches!(self.current_token, Token::In) {
1344 self.advance();
1345 self.consume(Token::LeftParen)?;
1346 let values = self.parse_expression_list()?;
1347 self.consume(Token::RightParen)?;
1348
1349 left = SqlExpression::InList {
1350 expr: Box::new(left),
1351 values,
1352 };
1353 }
1354
1355 Ok(left)
1359 }
1360
1361 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1362 let mut left = self.parse_additive()?;
1363
1364 if matches!(self.current_token, Token::Between) {
1366 self.advance(); let lower = self.parse_primary()?;
1368 self.consume(Token::And)?; let upper = self.parse_primary()?;
1370
1371 return Ok(SqlExpression::Between {
1372 expr: Box::new(left),
1373 lower: Box::new(lower),
1374 upper: Box::new(upper),
1375 });
1376 }
1377
1378 if matches!(self.current_token, Token::Not) {
1380 self.advance(); if matches!(self.current_token, Token::In) {
1382 self.advance(); self.consume(Token::LeftParen)?;
1384 let values = self.parse_expression_list()?;
1385 self.consume(Token::RightParen)?;
1386
1387 return Ok(SqlExpression::NotInList {
1388 expr: Box::new(left),
1389 values,
1390 });
1391 }
1392 return Err("Expected IN after NOT".to_string());
1393 }
1394
1395 if matches!(self.current_token, Token::Is) {
1397 self.advance(); if matches!(self.current_token, Token::Not) {
1399 self.advance(); if matches!(self.current_token, Token::Null) {
1401 self.advance(); left = SqlExpression::BinaryOp {
1403 left: Box::new(left),
1404 op: "IS NOT NULL".to_string(),
1405 right: Box::new(SqlExpression::Null),
1406 };
1407 } else {
1408 return Err("Expected NULL after IS NOT".to_string());
1409 }
1410 } else if matches!(self.current_token, Token::Null) {
1411 self.advance(); left = SqlExpression::BinaryOp {
1413 left: Box::new(left),
1414 op: "IS NULL".to_string(),
1415 right: Box::new(SqlExpression::Null),
1416 };
1417 } else {
1418 return Err("Expected NULL or NOT after IS".to_string());
1419 }
1420 }
1421 else if let Some(op) = self.get_binary_op() {
1423 self.advance();
1424 let right = self.parse_additive()?;
1425 left = SqlExpression::BinaryOp {
1426 left: Box::new(left),
1427 op,
1428 right: Box::new(right),
1429 };
1430 }
1431
1432 Ok(left)
1433 }
1434
1435 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1436 let mut left = self.parse_multiplicative()?;
1437
1438 while matches!(self.current_token, Token::Plus | Token::Minus) {
1439 let op = match self.current_token {
1440 Token::Plus => "+",
1441 Token::Minus => "-",
1442 _ => unreachable!(),
1443 };
1444 self.advance();
1445 let right = self.parse_multiplicative()?;
1446 left = SqlExpression::BinaryOp {
1447 left: Box::new(left),
1448 op: op.to_string(),
1449 right: Box::new(right),
1450 };
1451 }
1452
1453 Ok(left)
1454 }
1455
1456 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1457 let mut left = self.parse_primary()?;
1458
1459 while matches!(self.current_token, Token::Dot) {
1461 self.advance();
1462 if let Token::Identifier(method) = &self.current_token {
1463 let method_name = method.clone();
1464 self.advance();
1465
1466 if matches!(self.current_token, Token::LeftParen) {
1467 self.advance();
1468 let args = self.parse_method_args()?;
1469 self.consume(Token::RightParen)?;
1470
1471 match left {
1473 SqlExpression::Column(obj) => {
1474 left = SqlExpression::MethodCall {
1476 object: obj,
1477 method: method_name,
1478 args,
1479 };
1480 }
1481 SqlExpression::MethodCall { .. }
1482 | SqlExpression::ChainedMethodCall { .. } => {
1483 left = SqlExpression::ChainedMethodCall {
1485 base: Box::new(left),
1486 method: method_name,
1487 args,
1488 };
1489 }
1490 _ => {
1491 left = SqlExpression::ChainedMethodCall {
1493 base: Box::new(left),
1494 method: method_name,
1495 args,
1496 };
1497 }
1498 }
1499 } else {
1500 return Err(format!("Expected '(' after method name '{method_name}'"));
1501 }
1502 } else {
1503 return Err("Expected method name after '.'".to_string());
1504 }
1505 }
1506
1507 while matches!(
1508 self.current_token,
1509 Token::Star | Token::Divide | Token::Modulo
1510 ) {
1511 let op = match self.current_token {
1512 Token::Star => "*",
1513 Token::Divide => "/",
1514 Token::Modulo => "%",
1515 _ => unreachable!(),
1516 };
1517 self.advance();
1518 let right = self.parse_primary()?;
1519 left = SqlExpression::BinaryOp {
1520 left: Box::new(left),
1521 op: op.to_string(),
1522 right: Box::new(right),
1523 };
1524 }
1525
1526 Ok(left)
1527 }
1528
1529 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1530 let mut left = self.parse_logical_and()?;
1531
1532 while matches!(self.current_token, Token::Or) {
1533 self.advance();
1534 let right = self.parse_logical_and()?;
1535 left = SqlExpression::BinaryOp {
1539 left: Box::new(left),
1540 op: "OR".to_string(),
1541 right: Box::new(right),
1542 };
1543 }
1544
1545 Ok(left)
1546 }
1547
1548 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1549 let mut left = self.parse_expression()?;
1550
1551 while matches!(self.current_token, Token::And) {
1552 self.advance();
1553 let right = self.parse_expression()?;
1554 left = SqlExpression::BinaryOp {
1556 left: Box::new(left),
1557 op: "AND".to_string(),
1558 right: Box::new(right),
1559 };
1560 }
1561
1562 Ok(left)
1563 }
1564
1565 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1566 self.consume(Token::Case)?;
1568
1569 let mut when_branches = Vec::new();
1570
1571 while matches!(self.current_token, Token::When) {
1573 self.advance(); let condition = self.parse_expression()?;
1577
1578 self.consume(Token::Then)?;
1580
1581 let result = self.parse_expression()?;
1583
1584 when_branches.push(WhenBranch {
1585 condition: Box::new(condition),
1586 result: Box::new(result),
1587 });
1588 }
1589
1590 if when_branches.is_empty() {
1592 return Err("CASE expression must have at least one WHEN clause".to_string());
1593 }
1594
1595 let else_branch = if matches!(self.current_token, Token::Else) {
1597 self.advance(); Some(Box::new(self.parse_expression()?))
1599 } else {
1600 None
1601 };
1602
1603 self.consume(Token::End)?;
1605
1606 Ok(SqlExpression::CaseExpression {
1607 when_branches,
1608 else_branch,
1609 })
1610 }
1611
/// Parses a primary (atomic) expression starting at the current token.
///
/// Handles, in order: numeric tokens that name a known column, `CASE`
/// expressions, `DateTime(...)` constructors, identifiers (boolean literals,
/// function calls, window functions, plain columns), quoted identifiers,
/// string / number literals, `NULL`, parenthesised expressions, `NOT`, and a
/// bare `*`.
///
/// # Errors
/// Returns a message when the current token cannot start a primary
/// expression or when a nested construct is malformed.
fn parse_primary(&mut self) -> Result<SqlExpression, String> {
    // A numeric token that exactly matches one of the known column names is
    // treated as a column reference rather than a number.
    if let Token::NumberLiteral(num_str) = &self.current_token {
        if self.columns.iter().any(|col| col == num_str) {
            let expr = SqlExpression::Column(num_str.clone());
            self.advance();
            return Ok(expr);
        }
    }

    match &self.current_token {
        Token::Case => {
            self.parse_case_expression()
        }
        Token::DateTime => {
            self.advance();
            self.consume(Token::LeftParen)?;

            // `DateTime()` with no arguments yields DateTimeToday with no time parts.
            if matches!(&self.current_token, Token::RightParen) {
                self.advance();
                return Ok(SqlExpression::DateTimeToday {
                    hour: None,
                    minute: None,
                    second: None,
                });
            }

            // Year, month and day are mandatory numeric literals separated by commas.
            let year = if let Token::NumberLiteral(n) = &self.current_token {
                n.parse::<i32>().map_err(|_| "Invalid year")?
            } else {
                return Err("Expected year in DateTime constructor".to_string());
            };
            self.advance();
            self.consume(Token::Comma)?;

            let month = if let Token::NumberLiteral(n) = &self.current_token {
                n.parse::<u32>().map_err(|_| "Invalid month")?
            } else {
                return Err("Expected month in DateTime constructor".to_string());
            };
            self.advance();
            self.consume(Token::Comma)?;

            let day = if let Token::NumberLiteral(n) = &self.current_token {
                n.parse::<u32>().map_err(|_| "Invalid day")?
            } else {
                return Err("Expected day in DateTime constructor".to_string());
            };
            self.advance();

            // Hour, minute and second are optional and only read in that order;
            // each is attempted only if the previous one was present.
            let mut hour = None;
            let mut minute = None;
            let mut second = None;

            if matches!(&self.current_token, Token::Comma) {
                self.advance();
                if let Token::NumberLiteral(n) = &self.current_token {
                    hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
                    self.advance();

                    if matches!(&self.current_token, Token::Comma) {
                        self.advance();
                        if let Token::NumberLiteral(n) = &self.current_token {
                            minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
                            self.advance();

                            if matches!(&self.current_token, Token::Comma) {
                                self.advance();
                                if let Token::NumberLiteral(n) = &self.current_token {
                                    second =
                                        Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
                                    self.advance();
                                }
                            }
                        }
                    }
                }
            }

            // A comma followed by a non-numeric token falls through to here, so
            // the failure is reported by this `consume`.
            self.consume(Token::RightParen)?;
            Ok(SqlExpression::DateTimeConstructor {
                year,
                month,
                day,
                hour,
                minute,
                second,
            })
        }
        Token::Identifier(id) => {
            let id_upper = id.to_uppercase();
            let id_clone = id.clone();

            // TRUE / FALSE are recognised case-insensitively as boolean literals.
            if id_upper == "TRUE" {
                self.advance();
                return Ok(SqlExpression::BooleanLiteral(true));
            } else if id_upper == "FALSE" {
                self.advance();
                return Ok(SqlExpression::BooleanLiteral(false));
            }

            self.advance();

            // An identifier followed by `(` is a function call; its name is
            // stored upper-cased.
            if matches!(self.current_token, Token::LeftParen) {
                self.advance();
                let args = self.parse_function_args()?;
                self.consume(Token::RightParen)?;

                // `name(args) OVER (...)` becomes a window function.
                if matches!(self.current_token, Token::Over) {
                    self.advance();
                    self.consume(Token::LeftParen)?;
                    let window_spec = self.parse_window_spec()?;
                    self.consume(Token::RightParen)?;
                    return Ok(SqlExpression::WindowFunction {
                        name: id_upper,
                        args,
                        window_spec,
                    });
                }

                return Ok(SqlExpression::FunctionCall {
                    name: id_upper,
                    args,
                });
            }

            // Otherwise it is a plain column reference with its original casing.
            Ok(SqlExpression::Column(id_clone))
        }
        Token::QuotedIdentifier(id) => {
            // Inside method arguments a quoted identifier is read as a string
            // literal; everywhere else it names a column.
            let expr = if self.in_method_args {
                SqlExpression::StringLiteral(id.clone())
            } else {
                SqlExpression::Column(id.clone())
            };
            self.advance();
            Ok(expr)
        }
        Token::StringLiteral(s) => {
            let expr = SqlExpression::StringLiteral(s.clone());
            self.advance();
            Ok(expr)
        }
        Token::NumberLiteral(n) => {
            let expr = SqlExpression::NumberLiteral(n.clone());
            self.advance();
            Ok(expr)
        }
        Token::Null => {
            self.advance();
            Ok(SqlExpression::Null)
        }
        Token::LeftParen => {
            self.advance();

            // A parenthesised group restarts parsing at the logical-OR level.
            let expr = self.parse_logical_or()?;

            self.consume(Token::RightParen)?;
            Ok(expr)
        }
        Token::Not => {
            self.advance();
            // NOTE(review): the specific error from parse_comparison is discarded
            // and replaced with the generic message below.
            if let Ok(inner_expr) = self.parse_comparison() {
                // `NOT <expr> IN (...)` is represented as NotInList.
                if matches!(self.current_token, Token::In) {
                    self.advance();
                    self.consume(Token::LeftParen)?;
                    let values = self.parse_expression_list()?;
                    self.consume(Token::RightParen)?;

                    Ok(SqlExpression::NotInList {
                        expr: Box::new(inner_expr),
                        values,
                    })
                } else {
                    Ok(SqlExpression::Not {
                        expr: Box::new(inner_expr),
                    })
                }
            } else {
                Err("Expected expression after NOT".to_string())
            }
        }
        Token::Star => {
            self.advance();
            // A bare `*` is represented as the string literal "*"
            // (presumably for forms like COUNT(*) — TODO confirm).
            Ok(SqlExpression::StringLiteral("*".to_string()))
        }
        _ => Err(format!("Unexpected token: {:?}", self.current_token)),
    }
}
1830
1831 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1832 let mut args = Vec::new();
1833
1834 self.in_method_args = true;
1836
1837 if !matches!(self.current_token, Token::RightParen) {
1838 loop {
1839 args.push(self.parse_expression()?);
1840
1841 if matches!(self.current_token, Token::Comma) {
1842 self.advance();
1843 } else {
1844 break;
1845 }
1846 }
1847 }
1848
1849 self.in_method_args = false;
1851
1852 Ok(args)
1853 }
1854
1855 fn parse_function_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1856 let mut args = Vec::new();
1857
1858 if !matches!(self.current_token, Token::RightParen) {
1859 if matches!(self.current_token, Token::Distinct) {
1861 self.advance(); let expr = self.parse_additive()?;
1864 args.push(SqlExpression::FunctionCall {
1866 name: "DISTINCT".to_string(),
1867 args: vec![expr],
1868 });
1869 } else {
1870 args.push(self.parse_additive()?);
1872 }
1873
1874 while matches!(self.current_token, Token::Comma) {
1876 self.advance();
1877 args.push(self.parse_additive()?);
1878 }
1879 }
1880
1881 Ok(args)
1882 }
1883
1884 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1885 let mut expressions = Vec::new();
1886
1887 loop {
1888 expressions.push(self.parse_expression()?);
1889
1890 if matches!(self.current_token, Token::Comma) {
1891 self.advance();
1892 } else {
1893 break;
1894 }
1895 }
1896
1897 Ok(expressions)
1898 }
1899
1900 fn get_binary_op(&self) -> Option<String> {
1901 match &self.current_token {
1902 Token::Equal => Some("=".to_string()),
1903 Token::NotEqual => Some("!=".to_string()),
1904 Token::LessThan => Some("<".to_string()),
1905 Token::GreaterThan => Some(">".to_string()),
1906 Token::LessThanOrEqual => Some("<=".to_string()),
1907 Token::GreaterThanOrEqual => Some(">=".to_string()),
1908 Token::Like => Some("LIKE".to_string()),
1909 _ => None,
1910 }
1911 }
1912
1913 fn get_arithmetic_op(&self) -> Option<String> {
1914 match &self.current_token {
1915 Token::Plus => Some("+".to_string()),
1916 Token::Minus => Some("-".to_string()),
1917 Token::Star => Some("*".to_string()), Token::Divide => Some("/".to_string()),
1919 Token::Modulo => Some("%".to_string()),
1920 _ => None,
1921 }
1922 }
1923
/// Returns the position reported by the underlying lexer.
#[must_use]
pub fn get_position(&self) -> usize {
    // Pure delegation; the parser keeps no position of its own here.
    self.lexer.get_position()
}
1928}
1929
/// Classification of where the cursor sits inside a query, as returned by
/// `detect_cursor_context`.
#[derive(Debug, Clone)]
pub enum CursorContext {
    /// Cursor is in the SELECT clause.
    SelectClause,
    /// Cursor is in the FROM clause.
    FromClause,
    /// Cursor is in the WHERE clause.
    WhereClause,
    /// Cursor is in the ORDER BY clause.
    OrderByClause,
    /// Cursor follows a column; presumably the payload is the column name — TODO confirm.
    AfterColumn(String),
    /// Cursor follows a logical connector (`AND` / `OR`).
    AfterLogicalOp(LogicalOp),
    /// Cursor follows a comparison operator.
    /// NOTE(review): the two strings look like (column, operator) — confirm against the producer.
    AfterComparisonOp(String, String),
    /// Cursor is inside a method call.
    /// NOTE(review): the two strings look like (object, method) — confirm against the producer.
    InMethodCall(String, String),
    /// Cursor is inside an expression.
    InExpression,
    /// No more specific context was determined.
    Unknown,
}
1944
/// Returns the prefix of `s` ending at byte offset `pos`.
///
/// When `pos` is at or past the end, the whole string is returned. When `pos`
/// falls inside a multi-byte character, the cut moves back to the start of
/// that character so the slice never panics.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=pos)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0);

    &s[..cut]
}
1959
/// Returns the suffix of `s` starting at byte offset `pos`.
///
/// When `pos` is at or past the end, an empty string is returned. When `pos`
/// falls inside a multi-byte character, the start moves forward to the next
/// character boundary so the slice never panics.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    let total = s.len();
    if pos >= total {
        return "";
    }

    // The end of the string is always a boundary, hence the fallback.
    let begin = (pos..total)
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(total);

    &s[begin..]
}
1974
1975#[must_use]
1976pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
1977 let truncated = safe_slice_to(query, cursor_pos);
1978 let mut parser = Parser::new(truncated);
1979
1980 if let Ok(stmt) = parser.parse() {
1982 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
1983 #[cfg(test)]
1984 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1985 (ctx, partial)
1986 } else {
1987 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
1989 #[cfg(test)]
1990 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
1991 (ctx, partial)
1992 }
1993}
1994
1995#[must_use]
1996pub fn tokenize_query(query: &str) -> Vec<String> {
1997 let mut lexer = Lexer::new(query);
1998 let tokens = lexer.tokenize_all();
1999 tokens.iter().map(|t| format!("{t:?}")).collect()
2000}
2001
2002#[must_use]
2003pub fn format_sql_pretty(query: &str) -> Vec<String> {
2004 format_sql_pretty_compact(query, 5) }
2006
2007#[must_use]
2009pub fn format_ast_tree(query: &str) -> String {
2010 let mut parser = Parser::new(query);
2011 match parser.parse() {
2012 Ok(stmt) => format_select_statement(&stmt, 0),
2013 Err(e) => format!("❌ PARSE ERROR ❌\n{e}\n\n⚠️ The query could not be parsed correctly.\n💡 Check parentheses, operators, and syntax."),
2014 }
2015}
2016
2017fn format_select_statement(stmt: &SelectStatement, indent: usize) -> String {
2018 let mut result = String::new();
2019 let indent_str = " ".repeat(indent);
2020
2021 result.push_str(&format!("{indent_str}SelectStatement {{\n"));
2022
2023 result.push_str(&format!("{indent_str} columns: ["));
2025 if stmt.columns.is_empty() {
2026 result.push_str("],\n");
2027 } else {
2028 result.push('\n');
2029 for col in &stmt.columns {
2030 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2031 }
2032 result.push_str(&format!("{indent_str} ],\n"));
2033 }
2034
2035 if let Some(table) = &stmt.from_table {
2037 result.push_str(&format!("{indent_str} from_table: \"{table}\",\n"));
2038 }
2039
2040 if let Some(where_clause) = &stmt.where_clause {
2042 result.push_str(&format!("{indent_str} where_clause: {{\n"));
2043 result.push_str(&format_where_clause(where_clause, indent + 2));
2044 result.push_str(&format!("{indent_str} }},\n"));
2045 }
2046
2047 if let Some(order_by) = &stmt.order_by {
2049 result.push_str(&format!("{indent_str} order_by: ["));
2050 if order_by.is_empty() {
2051 result.push_str("],\n");
2052 } else {
2053 result.push('\n');
2054 for col in order_by {
2055 let dir = match col.direction {
2056 SortDirection::Asc => "ASC",
2057 SortDirection::Desc => "DESC",
2058 };
2059 result.push_str(&format!(
2060 "{indent_str} \"{col}\" {dir},\n",
2061 col = col.column
2062 ));
2063 }
2064 result.push_str(&format!("{indent_str} ],\n"));
2065 }
2066 }
2067
2068 if let Some(group_by) = &stmt.group_by {
2070 result.push_str(&format!("{indent_str} group_by: ["));
2071 if group_by.is_empty() {
2072 result.push_str("]\n");
2073 } else {
2074 result.push('\n');
2075 for col in group_by {
2076 result.push_str(&format!("{indent_str} \"{col}\",\n"));
2077 }
2078 result.push_str(&format!("{indent_str} ],\n"));
2079 }
2080 }
2081
2082 result.push_str(&format!("{indent_str}}}"));
2083 result
2084}
2085
2086fn format_where_clause(clause: &WhereClause, indent: usize) -> String {
2087 let mut result = String::new();
2088 let indent_str = " ".repeat(indent);
2089
2090 result.push_str(&format!("{indent_str}conditions: [\n"));
2091
2092 for condition in &clause.conditions {
2093 result.push_str(&format!("{indent_str} {{\n"));
2094 result.push_str(&format!(
2095 "{indent_str} expr: {},\n",
2096 format_expression_ast(&condition.expr)
2097 ));
2098
2099 if let Some(connector) = &condition.connector {
2100 let connector_str = match connector {
2101 LogicalOp::And => "AND",
2102 LogicalOp::Or => "OR",
2103 };
2104 result.push_str(&format!("{indent_str} connector: {connector_str},\n"));
2105 }
2106
2107 result.push_str(&format!("{indent_str} }},\n"));
2108 }
2109
2110 result.push_str(&format!("{indent_str}]\n"));
2111 result
2112}
2113
2114fn format_expression_ast(expr: &SqlExpression) -> String {
2115 match expr {
2116 SqlExpression::Column(name) => format!("Column(\"{name}\")"),
2117 SqlExpression::StringLiteral(value) => format!("StringLiteral(\"{value}\")"),
2118 SqlExpression::NumberLiteral(value) => format!("NumberLiteral({value})"),
2119 SqlExpression::BooleanLiteral(value) => format!("BooleanLiteral({value})"),
2120 SqlExpression::Null => "Null".to_string(),
2121 SqlExpression::DateTimeConstructor {
2122 year,
2123 month,
2124 day,
2125 hour,
2126 minute,
2127 second,
2128 } => {
2129 format!(
2130 "DateTime({}-{:02}-{:02} {:02}:{:02}:{:02})",
2131 year,
2132 month,
2133 day,
2134 hour.unwrap_or(0),
2135 minute.unwrap_or(0),
2136 second.unwrap_or(0)
2137 )
2138 }
2139 SqlExpression::DateTimeToday {
2140 hour,
2141 minute,
2142 second,
2143 } => {
2144 format!(
2145 "DateTimeToday({:02}:{:02}:{:02})",
2146 hour.unwrap_or(0),
2147 minute.unwrap_or(0),
2148 second.unwrap_or(0)
2149 )
2150 }
2151 SqlExpression::MethodCall {
2152 object,
2153 method,
2154 args,
2155 } => {
2156 let args_str = args
2157 .iter()
2158 .map(format_expression_ast)
2159 .collect::<Vec<_>>()
2160 .join(", ");
2161 format!("MethodCall({object}.{method}({args_str}))")
2162 }
2163 SqlExpression::ChainedMethodCall { base, method, args } => {
2164 let args_str = args
2165 .iter()
2166 .map(format_expression_ast)
2167 .collect::<Vec<_>>()
2168 .join(", ");
2169 format!(
2170 "ChainedMethodCall({}.{}({}))",
2171 format_expression_ast(base),
2172 method,
2173 args_str
2174 )
2175 }
2176 SqlExpression::FunctionCall { name, args } => {
2177 let args_str = args
2178 .iter()
2179 .map(format_expression_ast)
2180 .collect::<Vec<_>>()
2181 .join(", ");
2182 format!("FunctionCall({name}({args_str}))")
2183 }
2184 SqlExpression::WindowFunction {
2185 name,
2186 args,
2187 window_spec,
2188 } => {
2189 let args_str = args
2190 .iter()
2191 .map(format_expression_ast)
2192 .collect::<Vec<_>>()
2193 .join(", ");
2194 let partition_str = if !window_spec.partition_by.is_empty() {
2195 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2196 } else {
2197 String::new()
2198 };
2199 let order_str = if !window_spec.order_by.is_empty() {
2200 let cols = window_spec
2201 .order_by
2202 .iter()
2203 .map(|col| format!("{} {:?}", col.column, col.direction))
2204 .collect::<Vec<_>>()
2205 .join(", ");
2206 format!(" ORDER BY {}", cols)
2207 } else {
2208 String::new()
2209 };
2210 format!("WindowFunction({name}({args_str}) OVER({partition_str}{order_str}))")
2211 }
2212 SqlExpression::BinaryOp { left, op, right } => {
2213 format!(
2214 "BinaryOp({} {} {})",
2215 format_expression_ast(left),
2216 op,
2217 format_expression_ast(right)
2218 )
2219 }
2220 SqlExpression::InList { expr, values } => {
2221 let list_str = values
2222 .iter()
2223 .map(format_expression_ast)
2224 .collect::<Vec<_>>()
2225 .join(", ");
2226 format!("InList({} IN [{}])", format_expression_ast(expr), list_str)
2227 }
2228 SqlExpression::NotInList { expr, values } => {
2229 let list_str = values
2230 .iter()
2231 .map(format_expression_ast)
2232 .collect::<Vec<_>>()
2233 .join(", ");
2234 format!(
2235 "NotInList({} NOT IN [{}])",
2236 format_expression_ast(expr),
2237 list_str
2238 )
2239 }
2240 SqlExpression::Between { expr, lower, upper } => {
2241 format!(
2242 "Between({} BETWEEN {} AND {})",
2243 format_expression_ast(expr),
2244 format_expression_ast(lower),
2245 format_expression_ast(upper)
2246 )
2247 }
2248 SqlExpression::Not { expr } => {
2249 format!("Not({})", format_expression_ast(expr))
2250 }
2251 SqlExpression::CaseExpression {
2252 when_branches,
2253 else_branch,
2254 } => {
2255 let when_strs: Vec<String> = when_branches
2256 .iter()
2257 .map(|branch| {
2258 format!(
2259 "WHEN {} THEN {}",
2260 format_expression_ast(&branch.condition),
2261 format_expression_ast(&branch.result)
2262 )
2263 })
2264 .collect();
2265 let else_str = else_branch
2266 .as_ref()
2267 .map(|e| format!(" ELSE {}", format_expression_ast(e)))
2268 .unwrap_or_default();
2269 format!("CASE {} {} END", when_strs.join(" "), else_str)
2270 }
2271 }
2272}
2273
2274#[must_use]
2276pub fn datetime_to_iso_string(expr: &SqlExpression) -> Option<String> {
2277 match expr {
2278 SqlExpression::DateTimeConstructor {
2279 year,
2280 month,
2281 day,
2282 hour,
2283 minute,
2284 second,
2285 } => {
2286 let h = hour.unwrap_or(0);
2287 let m = minute.unwrap_or(0);
2288 let s = second.unwrap_or(0);
2289
2290 if let Ok(dt) = NaiveDateTime::parse_from_str(
2292 &format!("{year:04}-{month:02}-{day:02} {h:02}:{m:02}:{s:02}"),
2293 "%Y-%m-%d %H:%M:%S",
2294 ) {
2295 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2296 } else {
2297 None
2298 }
2299 }
2300 SqlExpression::DateTimeToday {
2301 hour,
2302 minute,
2303 second,
2304 } => {
2305 let now = Local::now();
2306 let h = hour.unwrap_or(0);
2307 let m = minute.unwrap_or(0);
2308 let s = second.unwrap_or(0);
2309
2310 if let Ok(dt) = NaiveDateTime::parse_from_str(
2312 &format!(
2313 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
2314 now.year(),
2315 now.month(),
2316 now.day(),
2317 h,
2318 m,
2319 s
2320 ),
2321 "%Y-%m-%d %H:%M:%S",
2322 ) {
2323 Some(dt.format("%Y-%m-%d %H:%M:%S").to_string())
2324 } else {
2325 None
2326 }
2327 }
2328 _ => None,
2329 }
2330}
2331
2332fn format_sql_with_preserved_parens(
2334 query: &str,
2335 cols_per_line: usize,
2336) -> Result<Vec<String>, String> {
2337 let mut lines = Vec::new();
2338 let mut lexer = Lexer::new(query);
2339 let tokens_with_pos = lexer.tokenize_all_with_positions();
2340
2341 if tokens_with_pos.is_empty() {
2342 return Err("No tokens found".to_string());
2343 }
2344
2345 let mut i = 0;
2346 let cols_per_line = cols_per_line.max(1);
2347
2348 while i < tokens_with_pos.len() {
2349 let (start, _end, ref token) = tokens_with_pos[i];
2350
2351 match token {
2352 Token::Select => {
2353 lines.push("SELECT".to_string());
2354 i += 1;
2355
2356 let mut columns = Vec::new();
2358 let mut col_start = i;
2359 while i < tokens_with_pos.len() {
2360 match &tokens_with_pos[i].2 {
2361 Token::From | Token::Eof => break,
2362 Token::Comma => {
2363 if col_start < i {
2365 let col_text = extract_text_between_positions(
2366 query,
2367 tokens_with_pos[col_start].0,
2368 tokens_with_pos[i - 1].1,
2369 );
2370 columns.push(col_text);
2371 }
2372 i += 1;
2373 col_start = i;
2374 }
2375 _ => i += 1,
2376 }
2377 }
2378 if col_start < i && i > 0 {
2380 let col_text = extract_text_between_positions(
2381 query,
2382 tokens_with_pos[col_start].0,
2383 tokens_with_pos[i - 1].1,
2384 );
2385 columns.push(col_text);
2386 }
2387
2388 for chunk in columns.chunks(cols_per_line) {
2390 let mut line = " ".to_string();
2391 for (idx, col) in chunk.iter().enumerate() {
2392 if idx > 0 {
2393 line.push_str(", ");
2394 }
2395 line.push_str(col.trim());
2396 }
2397 let is_last_chunk = chunk.as_ptr() as usize + std::mem::size_of_val(chunk)
2399 >= columns.last().map_or(0, |c| std::ptr::from_ref(c) as usize);
2400 if !is_last_chunk && columns.len() > cols_per_line {
2401 line.push(',');
2402 }
2403 lines.push(line);
2404 }
2405 }
2406 Token::From => {
2407 i += 1;
2408 if i < tokens_with_pos.len() {
2409 let table_start = tokens_with_pos[i].0;
2410 while i < tokens_with_pos.len() {
2412 match &tokens_with_pos[i].2 {
2413 Token::Where | Token::OrderBy | Token::GroupBy | Token::Eof => break,
2414 _ => i += 1,
2415 }
2416 }
2417 if i > 0 {
2418 let table_text = extract_text_between_positions(
2419 query,
2420 table_start,
2421 tokens_with_pos[i - 1].1,
2422 );
2423 lines.push(format!("FROM {}", table_text.trim()));
2424 }
2425 }
2426 }
2427 Token::Where => {
2428 lines.push("WHERE".to_string());
2429 i += 1;
2430
2431 let where_start = if i < tokens_with_pos.len() {
2433 tokens_with_pos[i].0
2434 } else {
2435 start
2436 };
2437
2438 let mut where_end = query.len();
2440 while i < tokens_with_pos.len() {
2441 match &tokens_with_pos[i].2 {
2442 Token::OrderBy | Token::GroupBy | Token::Eof => {
2443 if i > 0 {
2444 where_end = tokens_with_pos[i - 1].1;
2445 }
2446 break;
2447 }
2448 _ => i += 1,
2449 }
2450 }
2451
2452 let where_text = extract_text_between_positions(query, where_start, where_end);
2454
2455 let formatted_where = format_where_clause_with_parens(&where_text);
2457 for line in formatted_where {
2458 lines.push(format!(" {line}"));
2459 }
2460 }
2461 Token::OrderBy => {
2462 i += 1;
2463 let order_start = if i < tokens_with_pos.len() {
2464 tokens_with_pos[i].0
2465 } else {
2466 start
2467 };
2468
2469 while i < tokens_with_pos.len() {
2471 match &tokens_with_pos[i].2 {
2472 Token::GroupBy | Token::Eof => break,
2473 _ => i += 1,
2474 }
2475 }
2476
2477 if i > 0 {
2478 let order_text = extract_text_between_positions(
2479 query,
2480 order_start,
2481 tokens_with_pos[i - 1].1,
2482 );
2483 lines.push(format!("ORDER BY {}", order_text.trim()));
2484 }
2485 }
2486 Token::GroupBy => {
2487 i += 1;
2488 let group_start = if i < tokens_with_pos.len() {
2489 tokens_with_pos[i].0
2490 } else {
2491 start
2492 };
2493
2494 while i < tokens_with_pos.len() {
2496 match &tokens_with_pos[i].2 {
2497 Token::Having | Token::Eof => break,
2498 _ => i += 1,
2499 }
2500 }
2501
2502 if i > 0 {
2503 let group_text = extract_text_between_positions(
2504 query,
2505 group_start,
2506 tokens_with_pos[i - 1].1,
2507 );
2508 lines.push(format!("GROUP BY {}", group_text.trim()));
2509 }
2510 }
2511 _ => i += 1,
2512 }
2513 }
2514
2515 Ok(lines)
2516}
2517
/// Returns the characters of `query` in the half-open range `[start, end)`,
/// where both bounds are character (not byte) indices.
///
/// Bounds past the end of the text are clamped. An inverted range
/// (`start >= end`) yields an empty string instead of panicking, which
/// callers can hit on malformed queries (for example a clause keyword with
/// no body).
fn extract_text_between_positions(query: &str, start: usize, end: usize) -> String {
    // `saturating_sub` turns an inverted range into a zero-length take.
    query
        .chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}
2525
/// Splits WHERE-clause text into lines at top-level ` AND ` / ` OR `
/// connectors (matched case-insensitively), leaving everything else as
/// written.
///
/// Each operand becomes one trimmed line and each connector becomes its own
/// `AND` / `OR` line. Connectors are NOT split when they appear
/// * inside parentheses, or
/// * inside a single- or double-quoted literal (so `'a AND b'` stays intact).
///
/// An unmatched `)` does not push the nesting depth below zero, so later
/// top-level connectors are still recognised. If nothing at all is produced,
/// the trimmed input is returned as the only line.
fn format_where_clause_with_parens(where_text: &str) -> Vec<String> {
    /// True when `window` begins with `keyword`, ignoring ASCII case.
    fn starts_with_keyword(window: &[char], keyword: &str) -> bool {
        let mut remaining = window.iter();
        keyword
            .chars()
            .all(|expected| remaining.next().is_some_and(|c| c.eq_ignore_ascii_case(&expected)))
    }

    let chars: Vec<char> = where_text.chars().collect();
    let mut lines = Vec::new();
    let mut current_line = String::new();
    let mut paren_depth: usize = 0;
    // The quote character of the literal we are currently inside, if any.
    let mut open_quote: Option<char> = None;
    let mut i = 0;

    while i < chars.len() {
        if paren_depth == 0 && open_quote.is_none() {
            let connector = if starts_with_keyword(&chars[i..], " AND ") {
                Some("AND")
            } else if starts_with_keyword(&chars[i..], " OR ") {
                Some("OR")
            } else {
                None
            };

            if let Some(keyword) = connector {
                let pending = current_line.trim();
                if !pending.is_empty() {
                    lines.push(pending.to_string());
                }
                lines.push(keyword.to_string());
                current_line.clear();
                // Skip the keyword plus its surrounding spaces.
                i += keyword.len() + 2;
                continue;
            }
        }

        let c = chars[i];
        match (open_quote, c) {
            // The matching quote closes the literal; a doubled quote ('')
            // simply closes and reopens it.
            (Some(quote), _) if c == quote => open_quote = None,
            (Some(_), _) => {}
            (None, '\'' | '"') => open_quote = Some(c),
            (None, '(') => paren_depth += 1,
            (None, ')') => paren_depth = paren_depth.saturating_sub(1),
            (None, _) => {}
        }
        current_line.push(c);
        i += 1;
    }

    let tail = current_line.trim();
    if !tail.is_empty() {
        lines.push(tail.to_string());
    }

    if lines.is_empty() {
        lines.push(where_text.trim().to_string());
    }

    lines
}
2591
2592#[must_use]
2593pub fn format_sql_pretty_compact(query: &str, cols_per_line: usize) -> Vec<String> {
2594 if let Ok(lines) = format_sql_with_preserved_parens(query, cols_per_line) {
2596 return lines;
2597 }
2598
2599 let mut lines = Vec::new();
2601 let mut parser = Parser::new(query);
2602
2603 let cols_per_line = cols_per_line.max(1);
2605
2606 if let Ok(stmt) = parser.parse() {
2607 if !stmt.columns.is_empty() {
2609 lines.push("SELECT".to_string());
2610
2611 for chunk in stmt.columns.chunks(cols_per_line) {
2613 let mut line = " ".to_string();
2614 for (i, col) in chunk.iter().enumerate() {
2615 if i > 0 {
2616 line.push_str(", ");
2617 }
2618 line.push_str(col);
2619 }
2620 let last_chunk_idx = (stmt.columns.len() - 1) / cols_per_line;
2622 let current_chunk_idx =
2623 stmt.columns.iter().position(|c| c == &chunk[0]).unwrap() / cols_per_line;
2624 if current_chunk_idx < last_chunk_idx {
2625 line.push(',');
2626 }
2627 lines.push(line);
2628 }
2629 }
2630
2631 if let Some(table) = &stmt.from_table {
2633 lines.push(format!("FROM {table}"));
2634 }
2635
2636 if let Some(where_clause) = &stmt.where_clause {
2638 lines.push("WHERE".to_string());
2639 for (i, condition) in where_clause.conditions.iter().enumerate() {
2640 if i > 0 {
2641 if let Some(prev_condition) = where_clause.conditions.get(i - 1) {
2643 if let Some(connector) = &prev_condition.connector {
2644 match connector {
2645 LogicalOp::And => lines.push(" AND".to_string()),
2646 LogicalOp::Or => lines.push(" OR".to_string()),
2647 }
2648 }
2649 }
2650 }
2651 lines.push(format!(" {}", format_expression(&condition.expr)));
2652 }
2653 }
2654
2655 if let Some(order_by) = &stmt.order_by {
2657 let order_str = order_by
2658 .iter()
2659 .map(|col| {
2660 let dir = match col.direction {
2661 SortDirection::Asc => " ASC",
2662 SortDirection::Desc => " DESC",
2663 };
2664 format!("{}{}", col.column, dir)
2665 })
2666 .collect::<Vec<_>>()
2667 .join(", ");
2668 lines.push(format!("ORDER BY {order_str}"));
2669 }
2670
2671 if let Some(group_by) = &stmt.group_by {
2673 let group_str = group_by.join(", ");
2674 lines.push(format!("GROUP BY {group_str}"));
2675 }
2676 } else {
2677 let mut lexer = Lexer::new(query);
2679 let tokens = lexer.tokenize_all();
2680 let mut current_line = String::new();
2681 let mut indent = 0;
2682
2683 for token in tokens {
2684 match &token {
2685 Token::Select | Token::From | Token::Where | Token::OrderBy | Token::GroupBy => {
2686 if !current_line.is_empty() {
2687 lines.push(current_line.trim().to_string());
2688 current_line.clear();
2689 }
2690 lines.push(format!("{token:?}").to_uppercase());
2691 indent = 1;
2692 }
2693 Token::And | Token::Or => {
2694 if !current_line.is_empty() {
2695 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2696 current_line.clear();
2697 }
2698 lines.push(format!(" {token:?}").to_uppercase());
2699 }
2700 Token::Comma => {
2701 current_line.push(',');
2702 if indent > 0 {
2703 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2704 current_line.clear();
2705 }
2706 }
2707 Token::Eof => break,
2708 _ => {
2709 if !current_line.is_empty() {
2710 current_line.push(' ');
2711 }
2712 current_line.push_str(&format_token(&token));
2713 }
2714 }
2715 }
2716
2717 if !current_line.is_empty() {
2718 lines.push(format!("{}{}", " ".repeat(indent), current_line.trim()));
2719 }
2720 }
2721
2722 lines
2723}
2724
2725fn format_expression(expr: &SqlExpression) -> String {
2726 match expr {
2727 SqlExpression::Column(name) => name.clone(),
2728 SqlExpression::StringLiteral(s) => format!("'{s}'"),
2729 SqlExpression::NumberLiteral(n) => n.clone(),
2730 SqlExpression::BooleanLiteral(b) => b.to_string(),
2731 SqlExpression::Null => "NULL".to_string(),
2732 SqlExpression::DateTimeConstructor {
2733 year,
2734 month,
2735 day,
2736 hour,
2737 minute,
2738 second,
2739 } => {
2740 let mut result = format!("DateTime({year}, {month}, {day}");
2741 if let Some(h) = hour {
2742 result.push_str(&format!(", {h}"));
2743 if let Some(m) = minute {
2744 result.push_str(&format!(", {m}"));
2745 if let Some(s) = second {
2746 result.push_str(&format!(", {s}"));
2747 }
2748 }
2749 }
2750 result.push(')');
2751 result
2752 }
2753 SqlExpression::DateTimeToday {
2754 hour,
2755 minute,
2756 second,
2757 } => {
2758 let mut result = "DateTime()".to_string();
2759 if let Some(h) = hour {
2760 result = format!("DateTime(TODAY, {h}");
2761 if let Some(m) = minute {
2762 result.push_str(&format!(", {m}"));
2763 if let Some(s) = second {
2764 result.push_str(&format!(", {s}"));
2765 }
2766 }
2767 result.push(')');
2768 }
2769 result
2770 }
2771 SqlExpression::MethodCall {
2772 object,
2773 method,
2774 args,
2775 } => {
2776 let args_str = args
2777 .iter()
2778 .map(format_expression)
2779 .collect::<Vec<_>>()
2780 .join(", ");
2781 format!("{object}.{method}({args_str})")
2782 }
2783 SqlExpression::BinaryOp { left, op, right } => {
2784 if op == "OR" || op == "AND" {
2787 format!(
2790 "({} {} {})",
2791 format_expression(left),
2792 op,
2793 format_expression(right)
2794 )
2795 } else {
2796 format!(
2797 "{} {} {}",
2798 format_expression(left),
2799 op,
2800 format_expression(right)
2801 )
2802 }
2803 }
2804 SqlExpression::InList { expr, values } => {
2805 let values_str = values
2806 .iter()
2807 .map(format_expression)
2808 .collect::<Vec<_>>()
2809 .join(", ");
2810 format!("{} IN ({})", format_expression(expr), values_str)
2811 }
2812 SqlExpression::NotInList { expr, values } => {
2813 let values_str = values
2814 .iter()
2815 .map(format_expression)
2816 .collect::<Vec<_>>()
2817 .join(", ");
2818 format!("{} NOT IN ({})", format_expression(expr), values_str)
2819 }
2820 SqlExpression::Between { expr, lower, upper } => {
2821 format!(
2822 "{} BETWEEN {} AND {}",
2823 format_expression(expr),
2824 format_expression(lower),
2825 format_expression(upper)
2826 )
2827 }
2828 SqlExpression::Not { expr } => {
2829 format!("NOT {}", format_expression(expr))
2830 }
2831 SqlExpression::ChainedMethodCall { base, method, args } => {
2832 let args_str = args
2833 .iter()
2834 .map(format_expression)
2835 .collect::<Vec<_>>()
2836 .join(", ");
2837 format!("{}.{}({})", format_expression(base), method, args_str)
2838 }
2839 SqlExpression::FunctionCall { name, args } => {
2840 let args_str = args
2841 .iter()
2842 .map(format_expression)
2843 .collect::<Vec<_>>()
2844 .join(", ");
2845 format!("{name}({args_str})")
2846 }
2847 SqlExpression::WindowFunction {
2848 name,
2849 args,
2850 window_spec,
2851 } => {
2852 let args_str = args
2853 .iter()
2854 .map(format_expression)
2855 .collect::<Vec<_>>()
2856 .join(", ");
2857 let partition_str = if !window_spec.partition_by.is_empty() {
2858 format!(" PARTITION BY {}", window_spec.partition_by.join(", "))
2859 } else {
2860 String::new()
2861 };
2862 let order_str = if !window_spec.order_by.is_empty() {
2863 let cols = window_spec
2864 .order_by
2865 .iter()
2866 .map(|col| {
2867 let dir = match col.direction {
2868 SortDirection::Asc => "ASC",
2869 SortDirection::Desc => "DESC",
2870 };
2871 format!("{} {}", col.column, dir)
2872 })
2873 .collect::<Vec<_>>()
2874 .join(", ");
2875 format!(" ORDER BY {}", cols)
2876 } else {
2877 String::new()
2878 };
2879 format!("{name}({args_str}) OVER({partition_str}{order_str})")
2880 }
2881 SqlExpression::CaseExpression {
2882 when_branches,
2883 else_branch,
2884 } => {
2885 let mut result = String::from("CASE");
2886 for branch in when_branches {
2887 result.push_str(&format!(
2888 " WHEN {} THEN {}",
2889 format_expression(&branch.condition),
2890 format_expression(&branch.result)
2891 ));
2892 }
2893 if let Some(else_expr) = else_branch {
2894 result.push_str(&format!(" ELSE {}", format_expression(else_expr)));
2895 }
2896 result.push_str(" END");
2897 result
2898 }
2899 }
2900}
2901
2902fn format_token(token: &Token) -> String {
2903 match token {
2904 Token::Identifier(s) => s.clone(),
2905 Token::QuotedIdentifier(s) => format!("\"{s}\""),
2906 Token::StringLiteral(s) => format!("'{s}'"),
2907 Token::NumberLiteral(n) => n.clone(),
2908 Token::DateTime => "DateTime".to_string(),
2909 Token::Case => "CASE".to_string(),
2910 Token::When => "WHEN".to_string(),
2911 Token::Then => "THEN".to_string(),
2912 Token::Else => "ELSE".to_string(),
2913 Token::End => "END".to_string(),
2914 Token::Distinct => "DISTINCT".to_string(),
2915 Token::Over => "OVER".to_string(),
2916 Token::Partition => "PARTITION".to_string(),
2917 Token::By => "BY".to_string(),
2918 Token::LeftParen => "(".to_string(),
2919 Token::RightParen => ")".to_string(),
2920 Token::Comma => ",".to_string(),
2921 Token::Dot => ".".to_string(),
2922 Token::Equal => "=".to_string(),
2923 Token::NotEqual => "!=".to_string(),
2924 Token::LessThan => "<".to_string(),
2925 Token::GreaterThan => ">".to_string(),
2926 Token::LessThanOrEqual => "<=".to_string(),
2927 Token::GreaterThanOrEqual => ">=".to_string(),
2928 Token::In => "IN".to_string(),
2929 _ => format!("{token:?}").to_uppercase(),
2930 }
2931}
2932
2933fn analyze_statement(
2934 stmt: &SelectStatement,
2935 query: &str,
2936 _cursor_pos: usize,
2937) -> (CursorContext, Option<String>) {
2938 let trimmed = query.trim();
2940
2941 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
2943 for op in &comparison_ops {
2944 if let Some(op_pos) = query.rfind(op) {
2945 let before_op = safe_slice_to(query, op_pos);
2946 let after_op_start = op_pos + op.len();
2947 let after_op = if after_op_start < query.len() {
2948 &query[after_op_start..]
2949 } else {
2950 ""
2951 };
2952
2953 if let Some(col_name) = before_op.split_whitespace().last() {
2955 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2956 let after_op_trimmed = after_op.trim();
2958 if after_op_trimmed.is_empty()
2959 || (after_op_trimmed
2960 .chars()
2961 .all(|c| c.is_alphanumeric() || c == '_')
2962 && !after_op_trimmed.contains('('))
2963 {
2964 let partial = if after_op_trimmed.is_empty() {
2965 None
2966 } else {
2967 Some(after_op_trimmed.to_string())
2968 };
2969 return (
2970 CursorContext::AfterComparisonOp(
2971 col_name.to_string(),
2972 op.trim().to_string(),
2973 ),
2974 partial,
2975 );
2976 }
2977 }
2978 }
2979 }
2980 }
2981
2982 if trimmed.to_uppercase().ends_with(" AND")
2984 || trimmed.to_uppercase().ends_with(" OR")
2985 || trimmed.to_uppercase().ends_with(" AND ")
2986 || trimmed.to_uppercase().ends_with(" OR ")
2987 {
2988 } else {
2990 if let Some(dot_pos) = trimmed.rfind('.') {
2992 let before_dot = safe_slice_to(trimmed, dot_pos);
2994 let after_dot_start = dot_pos + 1;
2995 let after_dot = if after_dot_start < trimmed.len() {
2996 &trimmed[after_dot_start..]
2997 } else {
2998 ""
2999 };
3000
3001 if !after_dot.contains('(') {
3004 let col_name = if before_dot.ends_with('"') {
3006 let bytes = before_dot.as_bytes();
3008 let mut pos = before_dot.len() - 1; let mut found_start = None;
3010
3011 if pos > 0 {
3013 pos -= 1;
3014 while pos > 0 {
3015 if bytes[pos] == b'"' {
3016 if pos == 0 || bytes[pos - 1] != b'\\' {
3018 found_start = Some(pos);
3019 break;
3020 }
3021 }
3022 pos -= 1;
3023 }
3024 if found_start.is_none() && bytes[0] == b'"' {
3026 found_start = Some(0);
3027 }
3028 }
3029
3030 found_start.map(|start| safe_slice_from(before_dot, start))
3031 } else {
3032 before_dot
3035 .split_whitespace()
3036 .last()
3037 .map(|word| word.trim_start_matches('('))
3038 };
3039
3040 if let Some(col_name) = col_name {
3041 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3043 true
3045 } else {
3046 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3048 };
3049
3050 if is_valid {
3051 let partial_method = if after_dot.is_empty() {
3054 None
3055 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3056 Some(after_dot.to_string())
3057 } else {
3058 None
3059 };
3060
3061 let col_name_for_context = if col_name.starts_with('"')
3063 && col_name.ends_with('"')
3064 && col_name.len() > 2
3065 {
3066 col_name[1..col_name.len() - 1].to_string()
3067 } else {
3068 col_name.to_string()
3069 };
3070
3071 return (
3072 CursorContext::AfterColumn(col_name_for_context),
3073 partial_method,
3074 );
3075 }
3076 }
3077 }
3078 }
3079 }
3080
3081 if let Some(where_clause) = &stmt.where_clause {
3083 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3085 let op = if trimmed.to_uppercase().ends_with(" AND") {
3086 LogicalOp::And
3087 } else {
3088 LogicalOp::Or
3089 };
3090 return (CursorContext::AfterLogicalOp(op), None);
3091 }
3092
3093 if let Some(and_pos) = query.to_uppercase().rfind(" AND ") {
3095 let after_and = safe_slice_from(query, and_pos + 5);
3096 let partial = extract_partial_at_end(after_and);
3097 if partial.is_some() {
3098 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3099 }
3100 }
3101
3102 if let Some(or_pos) = query.to_uppercase().rfind(" OR ") {
3103 let after_or = safe_slice_from(query, or_pos + 4);
3104 let partial = extract_partial_at_end(after_or);
3105 if partial.is_some() {
3106 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3107 }
3108 }
3109
3110 if let Some(last_condition) = where_clause.conditions.last() {
3111 if let Some(connector) = &last_condition.connector {
3112 return (
3114 CursorContext::AfterLogicalOp(connector.clone()),
3115 extract_partial_at_end(query),
3116 );
3117 }
3118 }
3119 return (CursorContext::WhereClause, extract_partial_at_end(query));
3121 }
3122
3123 if query.to_uppercase().ends_with(" ORDER BY ") || query.to_uppercase().ends_with(" ORDER BY") {
3125 return (CursorContext::OrderByClause, None);
3126 }
3127
3128 if stmt.order_by.is_some() {
3130 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3131 }
3132
3133 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
3134 return (CursorContext::FromClause, extract_partial_at_end(query));
3135 }
3136
3137 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
3138 return (CursorContext::SelectClause, extract_partial_at_end(query));
3139 }
3140
3141 (CursorContext::Unknown, None)
3142}
3143
3144fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
3145 let upper = query.to_uppercase();
3146
3147 let trimmed = query.trim();
3149
3150 #[cfg(test)]
3151 {
3152 if trimmed.contains("\"Last Name\"") {
3153 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
3154 }
3155 }
3156
3157 let comparison_ops = [" > ", " < ", " >= ", " <= ", " = ", " != "];
3159 for op in &comparison_ops {
3160 if let Some(op_pos) = query.rfind(op) {
3161 let before_op = safe_slice_to(query, op_pos);
3162 let after_op_start = op_pos + op.len();
3163 let after_op = if after_op_start < query.len() {
3164 &query[after_op_start..]
3165 } else {
3166 ""
3167 };
3168
3169 if let Some(col_name) = before_op.split_whitespace().last() {
3171 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
3172 let after_op_trimmed = after_op.trim();
3174 if after_op_trimmed.is_empty()
3175 || (after_op_trimmed
3176 .chars()
3177 .all(|c| c.is_alphanumeric() || c == '_')
3178 && !after_op_trimmed.contains('('))
3179 {
3180 let partial = if after_op_trimmed.is_empty() {
3181 None
3182 } else {
3183 Some(after_op_trimmed.to_string())
3184 };
3185 return (
3186 CursorContext::AfterComparisonOp(
3187 col_name.to_string(),
3188 op.trim().to_string(),
3189 ),
3190 partial,
3191 );
3192 }
3193 }
3194 }
3195 }
3196 }
3197
3198 if let Some(dot_pos) = trimmed.rfind('.') {
3201 #[cfg(test)]
3202 {
3203 if trimmed.contains("\"Last Name\"") {
3204 eprintln!("DEBUG: Found dot at position {dot_pos}");
3205 }
3206 }
3207 let before_dot = &trimmed[..dot_pos];
3209 let after_dot = &trimmed[dot_pos + 1..];
3210
3211 if !after_dot.contains('(') {
3214 let col_name = if before_dot.ends_with('"') {
3217 let bytes = before_dot.as_bytes();
3219 let mut pos = before_dot.len() - 1; let mut found_start = None;
3221
3222 #[cfg(test)]
3223 {
3224 if trimmed.contains("\"Last Name\"") {
3225 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
3226 }
3227 }
3228
3229 if pos > 0 {
3231 pos -= 1;
3232 while pos > 0 {
3233 if bytes[pos] == b'"' {
3234 if pos == 0 || bytes[pos - 1] != b'\\' {
3236 found_start = Some(pos);
3237 break;
3238 }
3239 }
3240 pos -= 1;
3241 }
3242 if found_start.is_none() && bytes[0] == b'"' {
3244 found_start = Some(0);
3245 }
3246 }
3247
3248 if let Some(start) = found_start {
3249 let result = safe_slice_from(before_dot, start);
3251 #[cfg(test)]
3252 {
3253 if trimmed.contains("\"Last Name\"") {
3254 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
3255 }
3256 }
3257 Some(result)
3258 } else {
3259 #[cfg(test)]
3260 {
3261 if trimmed.contains("\"Last Name\"") {
3262 eprintln!("DEBUG: No opening quote found!");
3263 }
3264 }
3265 None
3266 }
3267 } else {
3268 before_dot
3271 .split_whitespace()
3272 .last()
3273 .map(|word| word.trim_start_matches('('))
3274 };
3275
3276 if let Some(col_name) = col_name {
3277 #[cfg(test)]
3278 {
3279 if trimmed.contains("\"Last Name\"") {
3280 eprintln!("DEBUG: col_name = '{col_name}'");
3281 }
3282 }
3283
3284 let is_valid = if col_name.starts_with('"') && col_name.ends_with('"') {
3286 true
3288 } else {
3289 col_name.chars().all(|c| c.is_alphanumeric() || c == '_')
3291 };
3292
3293 #[cfg(test)]
3294 {
3295 if trimmed.contains("\"Last Name\"") {
3296 eprintln!("DEBUG: is_valid = {is_valid}");
3297 }
3298 }
3299
3300 if is_valid {
3301 let partial_method = if after_dot.is_empty() {
3304 None
3305 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
3306 Some(after_dot.to_string())
3307 } else {
3308 None
3309 };
3310
3311 let col_name_for_context = if col_name.starts_with('"')
3313 && col_name.ends_with('"')
3314 && col_name.len() > 2
3315 {
3316 col_name[1..col_name.len() - 1].to_string()
3317 } else {
3318 col_name.to_string()
3319 };
3320
3321 return (
3322 CursorContext::AfterColumn(col_name_for_context),
3323 partial_method,
3324 );
3325 }
3326 }
3327 }
3328 }
3329
3330 if let Some(and_pos) = upper.rfind(" AND ") {
3332 if cursor_pos >= and_pos + 5 {
3334 let after_and = safe_slice_from(query, and_pos + 5);
3336 let partial = extract_partial_at_end(after_and);
3337 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
3338 }
3339 }
3340
3341 if let Some(or_pos) = upper.rfind(" OR ") {
3342 if cursor_pos >= or_pos + 4 {
3344 let after_or = safe_slice_from(query, or_pos + 4);
3346 let partial = extract_partial_at_end(after_or);
3347 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
3348 }
3349 }
3350
3351 if trimmed.to_uppercase().ends_with(" AND") || trimmed.to_uppercase().ends_with(" OR") {
3353 let op = if trimmed.to_uppercase().ends_with(" AND") {
3354 LogicalOp::And
3355 } else {
3356 LogicalOp::Or
3357 };
3358 return (CursorContext::AfterLogicalOp(op), None);
3359 }
3360
3361 if upper.ends_with(" ORDER BY ") || upper.ends_with(" ORDER BY") || upper.contains("ORDER BY ")
3363 {
3364 return (CursorContext::OrderByClause, extract_partial_at_end(query));
3365 }
3366
3367 if upper.contains("WHERE") && !upper.contains("ORDER") && !upper.contains("GROUP") {
3368 return (CursorContext::WhereClause, extract_partial_at_end(query));
3369 }
3370
3371 if upper.contains("FROM") && !upper.contains("WHERE") && !upper.contains("ORDER") {
3372 return (CursorContext::FromClause, extract_partial_at_end(query));
3373 }
3374
3375 if upper.contains("SELECT") && !upper.contains("FROM") {
3376 return (CursorContext::SelectClause, extract_partial_at_end(query));
3377 }
3378
3379 (CursorContext::Unknown, None)
3380}
3381
3382fn extract_partial_at_end(query: &str) -> Option<String> {
3383 let trimmed = query.trim();
3384
3385 if let Some(last_word) = trimmed.split_whitespace().last() {
3387 if last_word.starts_with('"') && !last_word.ends_with('"') {
3388 return Some(last_word.to_string());
3390 }
3391 }
3392
3393 let last_word = trimmed.split_whitespace().last()?;
3395
3396 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') && !is_sql_keyword(last_word) {
3398 Some(last_word.to_string())
3399 } else {
3400 None
3401 }
3402}
3403
/// Reports whether `word` is one of the SQL keywords that must never be offered as a
/// partially typed identifier. The comparison ignores case.
fn is_sql_keyword(word: &str) -> bool {
    const KEYWORDS: [&str; 13] = [
        "SELECT", "FROM", "WHERE", "AND", "OR", "IN", "ORDER", "BY", "GROUP", "HAVING", "ASC",
        "DESC", "DISTINCT",
    ];

    let candidate = word.to_uppercase();
    KEYWORDS.contains(&candidate.as_str())
}
3422
3423#[cfg(test)]
3424mod tests {
3425 use super::*;
3426
3427 #[test]
3428 fn test_tokenizer_window_functions() {
3429 let mut lexer = Lexer::new("LAG(value) OVER (PARTITION BY category ORDER BY id)");
3430 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "LAG"));
3431 assert!(matches!(lexer.next_token(), Token::LeftParen));
3432 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "value"));
3433 assert!(matches!(lexer.next_token(), Token::RightParen));
3434
3435 let over_token = lexer.next_token();
3436 println!("Expected OVER, got: {:?}", over_token);
3437 assert!(matches!(over_token, Token::Over));
3438
3439 assert!(matches!(lexer.next_token(), Token::LeftParen));
3440 assert!(matches!(lexer.next_token(), Token::Partition));
3441 assert!(matches!(lexer.next_token(), Token::By));
3442 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "category"));
3443 }
3444
3445 #[test]
3446 fn test_parse_window_function() {
3447 let query = "SELECT LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test";
3448 let mut parser = Parser::new(query);
3449 let result = parser.parse();
3450
3451 assert!(
3452 result.is_ok(),
3453 "Failed to parse window function: {:?}",
3454 result
3455 );
3456 let stmt = result.unwrap();
3457
3458 if let Some(item) = stmt.select_items.get(0) {
3460 match item {
3461 SelectItem::Expression { expr, alias } => {
3462 println!("Parsed expression: {:?}", expr);
3463 assert!(matches!(expr, SqlExpression::WindowFunction { .. }));
3464 assert_eq!(alias, "prev_value");
3465 }
3466 _ => panic!("Expected expression, got: {:?}", item),
3467 }
3468 } else {
3469 panic!("No select items found");
3470 }
3471 }
3472
3473 #[test]
3474 fn test_chained_method_calls() {
3475 let query = "SELECT * FROM trades WHERE commission.ToString().IndexOf('.') = 1";
3477 let mut parser = Parser::new(query);
3478 let result = parser.parse();
3479
3480 assert!(
3481 result.is_ok(),
3482 "Failed to parse chained method calls: {result:?}"
3483 );
3484
3485 let query2 = "SELECT * FROM data WHERE field.ToUpper().Replace('A', 'B').StartsWith('C')";
3487 let mut parser2 = Parser::new(query2);
3488 let result2 = parser2.parse();
3489
3490 assert!(
3491 result2.is_ok(),
3492 "Failed to parse multiple chained calls: {result2:?}"
3493 );
3494 }
3495
3496 #[test]
3497 fn test_tokenizer() {
3498 let mut lexer = Lexer::new("SELECT * FROM trade_deal WHERE price > 100");
3499
3500 assert!(matches!(lexer.next_token(), Token::Select));
3501 assert!(matches!(lexer.next_token(), Token::Star));
3502 assert!(matches!(lexer.next_token(), Token::From));
3503 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "trade_deal"));
3504 assert!(matches!(lexer.next_token(), Token::Where));
3505 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "price"));
3506 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3507 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "100"));
3508 }
3509
3510 #[test]
3511 fn test_tokenizer_datetime() {
3512 let mut lexer = Lexer::new("WHERE createdDate > DateTime(2025, 10, 20)");
3513
3514 assert!(matches!(lexer.next_token(), Token::Where));
3515 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "createdDate"));
3516 assert!(matches!(lexer.next_token(), Token::GreaterThan));
3517 assert!(matches!(lexer.next_token(), Token::DateTime));
3518 assert!(matches!(lexer.next_token(), Token::LeftParen));
3519 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "2025"));
3520 assert!(matches!(lexer.next_token(), Token::Comma));
3521 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "10"));
3522 assert!(matches!(lexer.next_token(), Token::Comma));
3523 assert!(matches!(lexer.next_token(), Token::NumberLiteral(s) if s == "20"));
3524 assert!(matches!(lexer.next_token(), Token::RightParen));
3525 }
3526
3527 #[test]
3528 fn test_parse_simple_select() {
3529 let mut parser = Parser::new("SELECT * FROM trade_deal");
3530 let stmt = parser.parse().unwrap();
3531
3532 assert_eq!(stmt.columns, vec!["*"]);
3533 assert_eq!(stmt.from_table, Some("trade_deal".to_string()));
3534 assert!(stmt.where_clause.is_none());
3535 }
3536
3537 #[test]
3538 fn test_parse_where_with_method() {
3539 let mut parser = Parser::new("SELECT * FROM trade_deal WHERE name.Contains(\"test\")");
3540 let stmt = parser.parse().unwrap();
3541
3542 assert!(stmt.where_clause.is_some());
3543 let where_clause = stmt.where_clause.unwrap();
3544 assert_eq!(where_clause.conditions.len(), 1);
3545 }
3546
3547 #[test]
3548 fn test_parse_datetime_constructor() {
3549 let mut parser =
3550 Parser::new("SELECT * FROM trade_deal WHERE createdDate > DateTime(2025, 10, 20)");
3551 let stmt = parser.parse().unwrap();
3552
3553 assert!(stmt.where_clause.is_some());
3554 let where_clause = stmt.where_clause.unwrap();
3555 assert_eq!(where_clause.conditions.len(), 1);
3556
3557 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3559 assert_eq!(op, ">");
3560 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "createdDate"));
3561 assert!(matches!(
3562 right.as_ref(),
3563 SqlExpression::DateTimeConstructor {
3564 year: 2025,
3565 month: 10,
3566 day: 20,
3567 hour: None,
3568 minute: None,
3569 second: None
3570 }
3571 ));
3572 } else {
3573 panic!("Expected BinaryOp with DateTime constructor");
3574 }
3575 }
3576
3577 #[test]
3578 fn test_cursor_context_after_and() {
3579 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND ";
3580 let (context, partial) = detect_cursor_context(query, query.len());
3581
3582 assert!(matches!(
3583 context,
3584 CursorContext::AfterLogicalOp(LogicalOp::And)
3585 ));
3586 assert_eq!(partial, None);
3587 }
3588
3589 #[test]
3590 fn test_cursor_context_with_partial() {
3591 let query = "SELECT * FROM trade_deal WHERE status = 'active' AND p";
3592 let (context, partial) = detect_cursor_context(query, query.len());
3593
3594 assert!(matches!(
3595 context,
3596 CursorContext::AfterLogicalOp(LogicalOp::And)
3597 ));
3598 assert_eq!(partial, Some("p".to_string()));
3599 }
3600
3601 #[test]
3602 fn test_cursor_context_after_datetime_comparison() {
3603 let query = "SELECT * FROM trade_deal WHERE createdDate > ";
3604 let (context, partial) = detect_cursor_context(query, query.len());
3605
3606 assert!(
3607 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3608 );
3609 assert_eq!(partial, None);
3610 }
3611
3612 #[test]
3613 fn test_cursor_context_partial_datetime() {
3614 let query = "SELECT * FROM trade_deal WHERE createdDate > Date";
3615 let (context, partial) = detect_cursor_context(query, query.len());
3616
3617 assert!(
3618 matches!(context, CursorContext::AfterComparisonOp(col, op) if col == "createdDate" && op == ">")
3619 );
3620 assert_eq!(partial, Some("Date".to_string()));
3621 }
3622
3623 #[test]
3625 fn test_tokenizer_quoted_identifier() {
3626 let mut lexer = Lexer::new(r#"SELECT "Customer Id", "First Name" FROM customers"#);
3627
3628 assert!(matches!(lexer.next_token(), Token::Select));
3629 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3630 assert!(matches!(lexer.next_token(), Token::Comma));
3631 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "First Name"));
3632 assert!(matches!(lexer.next_token(), Token::From));
3633 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "customers"));
3634 }
3635
3636 #[test]
3637 fn test_tokenizer_quoted_vs_string_literal() {
3638 let mut lexer = Lexer::new(r#"WHERE "Customer Id" = 'John' AND Country.Contains('USA')"#);
3640
3641 assert!(matches!(lexer.next_token(), Token::Where));
3642 assert!(matches!(lexer.next_token(), Token::QuotedIdentifier(s) if s == "Customer Id"));
3643 assert!(matches!(lexer.next_token(), Token::Equal));
3644 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "John"));
3645 assert!(matches!(lexer.next_token(), Token::And));
3646 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3647 assert!(matches!(lexer.next_token(), Token::Dot));
3648 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3649 assert!(matches!(lexer.next_token(), Token::LeftParen));
3650 assert!(matches!(lexer.next_token(), Token::StringLiteral(s) if s == "USA"));
3651 assert!(matches!(lexer.next_token(), Token::RightParen));
3652 }
3653
3654 #[test]
3655 fn test_tokenizer_method_with_double_quotes_should_be_string() {
3656 let mut lexer = Lexer::new(r#"Country.Contains("Alb")"#);
3659
3660 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Country"));
3661 assert!(matches!(lexer.next_token(), Token::Dot));
3662 assert!(matches!(lexer.next_token(), Token::Identifier(s) if s == "Contains"));
3663 assert!(matches!(lexer.next_token(), Token::LeftParen));
3664
3665 let token = lexer.next_token();
3668 println!("Token for \"Alb\": {token:?}");
3669 assert!(matches!(lexer.next_token(), Token::RightParen));
3673 }
3674
3675 #[test]
3676 fn test_parse_select_with_quoted_columns() {
3677 let mut parser = Parser::new(r#"SELECT "Customer Id", Company FROM customers"#);
3678 let stmt = parser.parse().unwrap();
3679
3680 assert_eq!(stmt.columns, vec!["Customer Id", "Company"]);
3681 assert_eq!(stmt.from_table, Some("customers".to_string()));
3682 }
3683
3684 #[test]
3685 fn test_cursor_context_select_with_partial_quoted() {
3686 let query = r#"SELECT "Cust"#;
3688 let (context, partial) = detect_cursor_context(query, query.len() - 1); println!("Context: {context:?}, Partial: {partial:?}");
3691 assert!(matches!(context, CursorContext::SelectClause));
3692 }
3695
3696 #[test]
3697 fn test_cursor_context_select_after_comma_with_quoted() {
3698 let query = r#"SELECT Company, "Customer "#;
3700 let (context, partial) = detect_cursor_context(query, query.len());
3701
3702 println!("Context: {context:?}, Partial: {partial:?}");
3703 assert!(matches!(context, CursorContext::SelectClause));
3704 }
3706
3707 #[test]
3708 fn test_cursor_context_order_by_quoted() {
3709 let query = r#"SELECT * FROM customers ORDER BY "Cust"#;
3710 let (context, partial) = detect_cursor_context(query, query.len() - 1);
3711
3712 println!("Context: {context:?}, Partial: {partial:?}");
3713 assert!(matches!(context, CursorContext::OrderByClause));
3714 }
3716
3717 #[test]
3718 fn test_where_clause_with_quoted_column() {
3719 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE "Customer Id" = 'C123'"#);
3720 let stmt = parser.parse().unwrap();
3721
3722 assert!(stmt.where_clause.is_some());
3723 let where_clause = stmt.where_clause.unwrap();
3724 assert_eq!(where_clause.conditions.len(), 1);
3725
3726 if let SqlExpression::BinaryOp { left, op, right } = &where_clause.conditions[0].expr {
3727 assert_eq!(op, "=");
3728 assert!(matches!(left.as_ref(), SqlExpression::Column(col) if col == "Customer Id"));
3729 assert!(matches!(right.as_ref(), SqlExpression::StringLiteral(s) if s == "C123"));
3730 } else {
3731 panic!("Expected BinaryOp");
3732 }
3733 }
3734
3735 #[test]
3736 fn test_parse_method_with_double_quotes_as_string() {
3737 let mut parser = Parser::new(r#"SELECT * FROM customers WHERE Country.Contains("USA")"#);
3739 let stmt = parser.parse().unwrap();
3740
3741 assert!(stmt.where_clause.is_some());
3742 let where_clause = stmt.where_clause.unwrap();
3743 assert_eq!(where_clause.conditions.len(), 1);
3744
3745 if let SqlExpression::MethodCall {
3746 object,
3747 method,
3748 args,
3749 } = &where_clause.conditions[0].expr
3750 {
3751 assert_eq!(object, "Country");
3752 assert_eq!(method, "Contains");
3753 assert_eq!(args.len(), 1);
3754 assert!(matches!(&args[0], SqlExpression::StringLiteral(s) if s == "USA"));
3756 } else {
3757 panic!("Expected MethodCall");
3758 }
3759 }
3760
3761 #[test]
3762 fn test_extract_partial_with_quoted_columns_in_query() {
3763 let query = r#"SELECT City,Company,Country,"Customer Id" FROM customers ORDER BY coun"#;
3765 let (context, partial) = detect_cursor_context(query, query.len());
3766
3767 assert!(matches!(context, CursorContext::OrderByClause));
3768 assert_eq!(
3769 partial,
3770 Some("coun".to_string()),
3771 "Should extract 'coun' as partial, not everything after the quoted column"
3772 );
3773 }
3774
3775 #[test]
3776 fn test_extract_partial_quoted_identifier_being_typed() {
3777 let query = r#"SELECT "Cust"#;
3779 let partial = extract_partial_at_end(query);
3780 assert_eq!(partial, Some("\"Cust".to_string()));
3781
3782 let query2 = r#"SELECT "Customer Id" FROM"#;
3784 let partial2 = extract_partial_at_end(query2);
3785 assert_eq!(partial2, None); }
3787
3788 #[test]
3790 fn test_complex_where_parentheses_basic() {
3791 let mut parser =
3793 Parser::new(r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#);
3794 let stmt = parser.parse().unwrap();
3795
3796 assert!(stmt.where_clause.is_some());
3797 let where_clause = stmt.where_clause.unwrap();
3798 assert_eq!(where_clause.conditions.len(), 1);
3799
3800 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3802 assert_eq!(op, "OR");
3803 } else {
3804 panic!("Expected BinaryOp with OR");
3805 }
3806 }
3807
3808 #[test]
3809 fn test_complex_where_mixed_and_or_with_parens() {
3810 let mut parser = Parser::new(
3812 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#,
3813 );
3814 let stmt = parser.parse().unwrap();
3815
3816 assert!(stmt.where_clause.is_some());
3817 let where_clause = stmt.where_clause.unwrap();
3818 assert_eq!(where_clause.conditions.len(), 2);
3819
3820 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[0].expr {
3822 assert_eq!(op, "OR");
3823 } else {
3824 panic!("Expected first condition to be OR expression");
3825 }
3826
3827 assert!(matches!(
3829 where_clause.conditions[0].connector,
3830 Some(LogicalOp::And)
3831 ));
3832
3833 if let SqlExpression::BinaryOp { op, .. } = &where_clause.conditions[1].expr {
3835 assert_eq!(op, ">");
3836 } else {
3837 panic!("Expected second condition to be price > 100");
3838 }
3839 }
3840
3841 #[test]
3842 fn test_complex_where_nested_parentheses() {
3843 let mut parser = Parser::new(
3845 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#,
3846 );
3847 let stmt = parser.parse().unwrap();
3848
3849 assert!(stmt.where_clause.is_some());
3850 let where_clause = stmt.where_clause.unwrap();
3851
3852 assert!(!where_clause.conditions.is_empty());
3854 }
3855
3856 #[test]
3857 fn test_complex_where_multiple_or_groups() {
3858 let mut parser = Parser::new(
3860 r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL" OR symbol = "MSFT") AND (price > 100 AND price < 500)"#,
3861 );
3862 let stmt = parser.parse().unwrap();
3863
3864 assert!(stmt.where_clause.is_some());
3865 let where_clause = stmt.where_clause.unwrap();
3866 assert_eq!(where_clause.conditions.len(), 2);
3867
3868 assert!(matches!(
3870 where_clause.conditions[0].connector,
3871 Some(LogicalOp::And)
3872 ));
3873 }
3874
3875 #[test]
3876 fn test_complex_where_with_methods_in_parens() {
3877 let mut parser = Parser::new(
3879 r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#,
3880 );
3881 let stmt = parser.parse().unwrap();
3882
3883 assert!(stmt.where_clause.is_some());
3884 let where_clause = stmt.where_clause.unwrap();
3885 assert_eq!(where_clause.conditions.len(), 2);
3886
3887 if let SqlExpression::BinaryOp { op, left, right } = &where_clause.conditions[0].expr {
3889 assert_eq!(op, "OR");
3890 assert!(matches!(left.as_ref(), SqlExpression::MethodCall { .. }));
3891 assert!(matches!(right.as_ref(), SqlExpression::MethodCall { .. }));
3892 } else {
3893 panic!("Expected OR of method calls");
3894 }
3895 }
3896
3897 #[test]
3898 fn test_complex_where_date_comparisons_with_parens() {
3899 let mut parser = Parser::new(
3901 r#"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31)) AND (status = "filled" OR status = "partial")"#,
3902 );
3903 let stmt = parser.parse().unwrap();
3904
3905 assert!(stmt.where_clause.is_some());
3906 let where_clause = stmt.where_clause.unwrap();
3907 assert_eq!(where_clause.conditions.len(), 2);
3908
3909 assert!(matches!(
3911 where_clause.conditions[0].connector,
3912 Some(LogicalOp::And)
3913 ));
3914 }
3915
3916 #[test]
3917 fn test_complex_where_price_volume_filters() {
3918 let mut parser = Parser::new(
3920 r"SELECT * FROM trades WHERE ((price > 100 AND price < 200) OR (price > 500 AND price < 1000)) AND volume > 10000",
3921 );
3922 let stmt = parser.parse().unwrap();
3923
3924 assert!(stmt.where_clause.is_some());
3925 let where_clause = stmt.where_clause.unwrap();
3926
3927 assert!(!where_clause.conditions.is_empty());
3929 }
3930
3931 #[test]
3932 fn test_complex_where_mixed_string_numeric() {
3933 let mut parser = Parser::new(
3935 r#"SELECT * FROM trades WHERE (exchange = "NYSE" OR exchange = "NASDAQ") AND (volume > 1000000 OR notes.Contains("urgent"))"#,
3936 );
3937 let stmt = parser.parse().unwrap();
3938
3939 assert!(stmt.where_clause.is_some());
3940 }
3942
3943 #[test]
3944 fn test_complex_where_triple_nested() {
3945 let mut parser = Parser::new(
3947 r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 OR volume > 1000000)) OR (status = "cancelled" AND reason.Contains("timeout"))"#,
3948 );
3949 let stmt = parser.parse().unwrap();
3950
3951 assert!(stmt.where_clause.is_some());
3952 }
3954
3955 #[test]
3956 fn test_complex_where_single_parens_around_and() {
3957 let mut parser = Parser::new(
3959 r#"SELECT * FROM trades WHERE (symbol = "AAPL" AND price > 150 AND volume > 100000)"#,
3960 );
3961 let stmt = parser.parse().unwrap();
3962
3963 assert!(stmt.where_clause.is_some());
3964 let where_clause = stmt.where_clause.unwrap();
3965
3966 assert!(!where_clause.conditions.is_empty());
3968 }
3969
/// Pretty-printing a WHERE clause wrapped in one parenthesized OR group must keep the opening
/// and closing parenthesis and leave the total parenthesis count unchanged.
#[test]
fn test_format_preserves_simple_parentheses() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query = r#"SELECT * FROM trades WHERE (status = "active" OR status = "pending")"#;
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    // Both edges of the group must survive formatting.
    assert!(rendered.contains("(status"));
    assert!(rendered.contains("\"pending\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(
        original_parens, formatted_parens,
        "Parentheses should be preserved"
    );
}
3992
/// Pretty-printing a parenthesized OR group followed by an AND condition must keep the group's
/// parentheses and leave the total parenthesis count unchanged.
#[test]
fn test_format_preserves_complex_parentheses() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    // Both edges of the group must survive formatting.
    assert!(rendered.contains("(symbol"));
    assert!(rendered.contains("\"GOOGL\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(original_parens, formatted_parens);
}
4012
/// Pretty-printing a query with one group nested inside another must emit exactly as many
/// parentheses as the input contains (four).
#[test]
fn test_format_preserves_nested_parentheses() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query = r#"SELECT * FROM trades WHERE ((symbol = "AAPL" OR symbol = "GOOGL") AND price > 100) OR status = "cancelled""#;
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(
        original_parens, formatted_parens,
        "Nested parentheses should be preserved"
    );
    // Sanity check on the fixture itself: two groups, two characters each.
    assert_eq!(original_parens, 4, "Should have 4 parentheses total");
}
4031
/// Pretty-printing a parenthesized group whose members are `StartsWith` method calls must keep
/// the group's opening parenthesis, both call argument lists and the overall parenthesis count.
#[test]
fn test_format_preserves_method_calls_in_parentheses() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query = r#"SELECT * FROM trades WHERE (symbol.StartsWith("A") OR symbol.StartsWith("G")) AND volume > 1000000"#;
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    // The group opener and each method call must appear verbatim.
    assert!(rendered.contains("(symbol.StartsWith"));
    assert!(rendered.contains("StartsWith(\"A\")"));
    assert!(rendered.contains("StartsWith(\"G\")"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(
        original_parens, 6,
        "Should have 6 parentheses (1 group + 2 method calls)"
    );
}
4055
/// Pretty-printing two sibling parenthesized groups joined by AND must keep the opening
/// parenthesis of each group and the overall parenthesis count.
#[test]
fn test_format_preserves_multiple_groups() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query = r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND (price > 100 AND price < 500)"#;
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    // Each group must still open with its parenthesis.
    assert!(rendered.contains("(symbol"));
    assert!(rendered.contains("(price"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(original_parens, formatted_parens);
    // Sanity check on the fixture itself.
    assert_eq!(original_parens, 4, "Should have 4 parentheses (2 groups)");
}
4074
/// Pretty-printing a parenthesized date range must keep the group's opening parenthesis, render
/// both `DateTime(...)` constructor calls verbatim and preserve the parenthesis count.
#[test]
fn test_format_preserves_date_ranges() {
    // Counts every '(' and ')' character in a piece of text.
    let count_parens = |text: &str| text.chars().filter(|&ch| ch == '(' || ch == ')').count();

    let query = r"SELECT * FROM trades WHERE (executionDate > DateTime(2024, 1, 1) AND executionDate < DateTime(2024, 12, 31))";
    let rendered = format_sql_pretty_compact(query, 5).join(" ");

    // The group opener and both constructor calls must appear verbatim.
    assert!(rendered.contains("(executionDate"));
    assert!(rendered.contains("DateTime(2024, 1, 1)"));
    assert!(rendered.contains("DateTime(2024, 12, 31)"));

    let original_parens = count_parens(query);
    let formatted_parens = count_parens(&rendered);
    assert_eq!(original_parens, formatted_parens);
}
4093
/// The pretty formatter's output must start with the lines `SELECT`, the star projection, a line
/// starting with `FROM`, and `WHERE`; the lines after those must contain the parenthesized
/// group and a line consisting only of `AND`.
#[test]
fn test_format_multiline_layout() {
    let query =
        r#"SELECT * FROM trades WHERE (symbol = "AAPL" OR symbol = "GOOGL") AND price > 100"#;
    let lines = format_sql_pretty_compact(query, 5);

    let line_count = lines.len();
    assert!(line_count >= 4, "Should have multiple lines");

    // Fixed four-line header.
    assert_eq!(lines[0], "SELECT");
    assert!(lines[1].trim().starts_with('*'));
    assert!(lines[2].starts_with("FROM"));
    assert_eq!(lines[3], "WHERE");

    // Everything past the header belongs to the WHERE body.
    let where_body = || lines.iter().skip(4);
    assert!(where_body().any(|line| line.contains("(symbol")));
    assert!(where_body().any(|line| line.trim() == "AND"));
}
4113
/// A bare `BETWEEN ... AND ...` predicate must parse into exactly one WHERE condition, and the
/// AST dump of the same query must show a `SelectStatement` without a parse error.
#[test]
fn test_between_simple() {
    let query = "SELECT * FROM table WHERE price BETWEEN 50 AND 100";

    let mut parser = Parser::new(query);
    let statement = parser.parse().expect("Should parse simple BETWEEN");

    let parsed_where = statement.where_clause;
    assert!(parsed_where.is_some());
    let parsed_where = parsed_where.unwrap();
    // The AND inside BETWEEN must not split the predicate into two conditions.
    assert_eq!(parsed_where.conditions.len(), 1);

    let tree = format_ast_tree(query);
    assert!(!tree.contains("PARSE ERROR"));
    assert!(tree.contains("SelectStatement"));
}
4128
/// A `BETWEEN` predicate wrapped in parentheses must parse into a WHERE clause, and the AST
/// dump of the same query must not report a parse error.
#[test]
fn test_between_in_parentheses() {
    let query = "SELECT * FROM table WHERE (price BETWEEN 50 AND 100)";

    let mut parser = Parser::new(query);
    let statement = parser.parse().expect("Should parse BETWEEN in parentheses");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);

    let tree = format_ast_tree(query);
    assert!(!tree.contains("PARSE ERROR"), "Should not have parse error");
}
4140
/// A parenthesized `BETWEEN` predicate OR-ed with a parenthesized comparison must parse and
/// yield a WHERE clause.
#[test]
fn test_between_with_or() {
    let sql = "SELECT * FROM test WHERE (Price BETWEEN 50 AND 100) OR (quantity > 5)";
    let mut parser = Parser::new(sql);
    let outcome = parser.parse();
    let statement = outcome.expect("Should parse BETWEEN with OR");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);
}
4151
/// An equality condition AND-ed with a `BETWEEN` predicate must parse into exactly two WHERE
/// conditions.
#[test]
fn test_between_with_and() {
    let sql = "SELECT * FROM table WHERE category = 'Books' AND price BETWEEN 10 AND 50";
    let mut parser = Parser::new(sql);
    let outcome = parser.parse();
    let statement = outcome.expect("Should parse BETWEEN with AND");

    let parsed_where = statement.where_clause;
    assert!(parsed_where.is_some());

    let parsed_where = parsed_where.unwrap();
    // BETWEEN's own AND must not be counted as a third condition.
    assert_eq!(parsed_where.conditions.len(), 2);
}
4162
/// Two parenthesized `BETWEEN` predicates joined by AND must parse and yield a WHERE clause.
#[test]
fn test_multiple_between() {
    let sql =
        "SELECT * FROM table WHERE (price BETWEEN 10 AND 50) AND (quantity BETWEEN 5 AND 20)";
    let mut parser = Parser::new(sql);
    let outcome = parser.parse();
    let statement = outcome.expect("Should parse multiple BETWEEN clauses");

    let has_where = statement.where_clause.is_some();
    assert!(has_where);
}
4174
/// A query combining a `BETWEEN` group, a method-call comparison and a two-column ORDER BY must
/// parse, yielding a WHERE clause and both sort keys in order with their directions.
#[test]
fn test_between_complex_query() {
    let sql = "SELECT * FROM test_sorting WHERE (Price BETWEEN 50 AND 100) OR (Product.Length() > 5) ORDER BY Category ASC, price DESC";
    let mut parser = Parser::new(sql);
    let outcome = parser.parse();
    let statement =
        outcome.expect("Should parse complex query with BETWEEN, method calls, and ORDER BY");

    assert!(statement.where_clause.is_some());
    assert!(statement.order_by.is_some());

    let sort_keys = statement.order_by.unwrap();
    assert_eq!(sort_keys.len(), 2);

    // The length was just checked, so both indices are in range.
    let (primary, secondary) = (&sort_keys[0], &sort_keys[1]);
    assert_eq!(primary.column, "Category");
    assert!(matches!(primary.direction, SortDirection::Asc));
    assert_eq!(secondary.column, "price");
    assert!(matches!(secondary.direction, SortDirection::Desc));
}
4194
/// A hand-built `SqlExpression::Between` node must render as `price BETWEEN 50 AND 100`, and
/// its AST rendering must mention the node kind and both bounds.
#[test]
fn test_between_formatting() {
    let subject = SqlExpression::Column("price".to_string());
    let lower_bound = SqlExpression::NumberLiteral("50".to_string());
    let upper_bound = SqlExpression::NumberLiteral("100".to_string());
    let between = SqlExpression::Between {
        expr: Box::new(subject),
        lower: Box::new(lower_bound),
        upper: Box::new(upper_bound),
    };

    // SQL text rendering.
    let formatted = format_expression(&between);
    assert_eq!(formatted, "price BETWEEN 50 AND 100");

    // AST rendering.
    let ast_text = format_expression_ast(&between);
    assert!(ast_text.contains("Between"));
    assert!(ast_text.contains("50"));
    assert!(ast_text.contains("100"));
}
4211
/// `detect_cursor_context` must not panic for any cursor offset in a query containing a
/// non-ASCII character: every offset from 0 through the string's byte length, an offset far
/// past the end, and an offset one byte past the start of `é`.
#[test]
fn test_utf8_boundary_safety() {
    let query_with_unicode = "SELECT * FROM table WHERE column = 'héllo'";

    // Runs the function under `catch_unwind` and reports whether it returned without panicking.
    let survives = |cursor: usize| {
        std::panic::catch_unwind(|| detect_cursor_context(query_with_unicode, cursor)).is_ok()
    };

    // Sweep every offset up to and including the string's length.
    for pos in 0..=query_with_unicode.len() {
        assert!(
            survives(pos),
            "Panic at position {pos} in query with Unicode"
        );
    }

    // An offset well beyond the end of the input.
    assert!(survives(1000), "Panic with position beyond string length");

    // One byte past the start of 'é'.
    let pos_in_e = query_with_unicode.find('é').unwrap() + 1;
    assert!(
        survives(pos_in_e),
        "Panic with cursor in middle of UTF-8 character"
    );
}
4242}