1pub use super::parser::ast::{
5 CTEType, Comment, Condition, DataFormat, FileCTESpec, FrameBound, FrameUnit, HttpMethod,
6 IntoTable, JoinClause, JoinCondition, JoinOperator, JoinType, LogicalOp, OrderByColumn,
7 OrderByItem, PivotAggregate, SelectItem, SelectStatement, SetOperation, SingleJoinCondition,
8 SortDirection, SqlExpression, TableFunction, TableSource, WebCTESpec, WhenBranch, WhereClause,
9 WindowFrame, WindowSpec, CTE,
10};
11pub use super::parser::legacy::{ParseContext, ParseState, Schema, SqlParser, SqlToken, TableInfo};
12pub use super::parser::lexer::{Lexer, LexerMode, Token};
13pub use super::parser::ParserConfig;
14
15pub use super::parser::formatter::{format_ast_tree, format_sql_pretty, format_sql_pretty_compact};
17
18pub use super::parser::ast_formatter::{format_sql_ast, format_sql_ast_with_config, FormatConfig};
20
21use super::parser::expressions::arithmetic::{
23 parse_additive as parse_additive_expr, parse_multiplicative as parse_multiplicative_expr,
24 ParseArithmetic,
25};
26use super::parser::expressions::case::{parse_case_expression as parse_case_expr, ParseCase};
27use super::parser::expressions::comparison::{
28 parse_comparison as parse_comparison_expr, parse_in_operator, ParseComparison,
29};
30use super::parser::expressions::logical::{
31 parse_logical_and as parse_logical_and_expr, parse_logical_or as parse_logical_or_expr,
32 ParseLogical,
33};
34use super::parser::expressions::primary::{
35 parse_primary as parse_primary_expr, ParsePrimary, PrimaryExpressionContext,
36};
37use super::parser::expressions::ExpressionParser;
38
39use crate::sql::functions::{FunctionCategory, FunctionRegistry};
41use crate::sql::generators::GeneratorRegistry;
42use std::sync::Arc;
43
44use super::parser::file_cte_parser::FileCteParser;
46use super::parser::web_cte_parser::WebCteParser;
47
/// Controls how the parser treats SQL comments.
///
/// `Standard` (the default) lexes with comments skipped; `PreserveComments`
/// keeps comment tokens so they can be attached to the parsed statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ParserMode {
    /// Comments are discarded by the lexer.
    #[default]
    Standard,
    /// Comments are lexed as tokens and collected as leading/trailing comments.
    PreserveComments,
}
62
63pub struct Parser {
64 lexer: Lexer,
65 pub current_token: Token, in_method_args: bool, columns: Vec<String>, paren_depth: i32, paren_depth_stack: Vec<i32>, _config: ParserConfig, debug_trace: bool, trace_depth: usize, function_registry: Arc<FunctionRegistry>, generator_registry: Arc<GeneratorRegistry>, mode: ParserMode, }
77
78impl Parser {
79 #[must_use]
80 pub fn new(input: &str) -> Self {
81 Self::with_mode(input, ParserMode::default())
82 }
83
84 #[must_use]
86 pub fn with_mode(input: &str, mode: ParserMode) -> Self {
87 let lexer_mode = match mode {
89 ParserMode::Standard => LexerMode::SkipComments,
90 ParserMode::PreserveComments => LexerMode::PreserveComments,
91 };
92
93 let mut lexer = Lexer::with_mode(input, lexer_mode);
94 let current_token = lexer.next_token();
95 Self {
96 lexer,
97 current_token,
98 in_method_args: false,
99 columns: Vec::new(),
100 paren_depth: 0,
101 paren_depth_stack: Vec::new(),
102 _config: ParserConfig::default(),
103 debug_trace: false,
104 trace_depth: 0,
105 function_registry: Arc::new(FunctionRegistry::new()),
106 generator_registry: Arc::new(GeneratorRegistry::new()),
107 mode,
108 }
109 }
110
111 #[must_use]
112 pub fn with_config(input: &str, config: ParserConfig) -> Self {
113 let mut lexer = Lexer::new(input);
114 let current_token = lexer.next_token();
115 Self {
116 lexer,
117 current_token,
118 in_method_args: false,
119 columns: Vec::new(),
120 paren_depth: 0,
121 paren_depth_stack: Vec::new(),
122 _config: config,
123 debug_trace: false,
124 trace_depth: 0,
125 function_registry: Arc::new(FunctionRegistry::new()),
126 generator_registry: Arc::new(GeneratorRegistry::new()),
127 mode: ParserMode::default(),
128 }
129 }
130
131 #[must_use]
132 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
133 self.columns = columns;
134 self
135 }
136
137 #[must_use]
138 pub fn with_debug_trace(mut self, enabled: bool) -> Self {
139 self.debug_trace = enabled;
140 self
141 }
142
143 #[must_use]
144 pub fn with_function_registry(mut self, registry: Arc<FunctionRegistry>) -> Self {
145 self.function_registry = registry;
146 self
147 }
148
149 #[must_use]
150 pub fn with_generator_registry(mut self, registry: Arc<GeneratorRegistry>) -> Self {
151 self.generator_registry = registry;
152 self
153 }
154
155 fn trace_enter(&mut self, context: &str) {
156 if self.debug_trace {
157 let indent = " ".repeat(self.trace_depth);
158 eprintln!("{}→ {} | Token: {:?}", indent, context, self.current_token);
159 self.trace_depth += 1;
160 }
161 }
162
163 fn trace_exit(&mut self, context: &str, result: &Result<impl std::fmt::Debug, String>) {
164 if self.debug_trace {
165 self.trace_depth = self.trace_depth.saturating_sub(1);
166 let indent = " ".repeat(self.trace_depth);
167 match result {
168 Ok(val) => eprintln!("{}← {} ✓ | Result: {:?}", indent, context, val),
169 Err(e) => eprintln!("{}← {} ✗ | Error: {}", indent, context, e),
170 }
171 }
172 }
173
174 fn trace_token(&self, action: &str) {
175 if self.debug_trace {
176 let indent = " ".repeat(self.trace_depth);
177 eprintln!("{} {} | Token: {:?}", indent, action, self.current_token);
178 }
179 }
180
181 #[allow(dead_code)]
182 fn peek_token(&self) -> Option<Token> {
183 let mut temp_lexer = self.lexer.clone();
185 let next_token = temp_lexer.next_token();
186 if matches!(next_token, Token::Eof) {
187 None
188 } else {
189 Some(next_token)
190 }
191 }
192
/// Returns `true` if `id` (case-insensitive, via Unicode uppercasing) is a
/// clause keyword that must end an identifier list: `ORDER`, `HAVING`,
/// `LIMIT`, `OFFSET`, `UNION`, `INTERSECT` or `EXCEPT`.
fn is_identifier_reserved(id: &str) -> bool {
    const CLAUSE_KEYWORDS: [&str; 7] =
        ["ORDER", "HAVING", "LIMIT", "OFFSET", "UNION", "INTERSECT", "EXCEPT"];
    let upper = id.to_uppercase();
    CLAUSE_KEYWORDS.contains(&upper.as_str())
}
204
/// Space-padded comparison operators; `'static` is implied for references in constants.
const COMPARISON_OPERATORS: [&str; 6] = [" > ", " < ", " >= ", " <= ", " = ", " != "];
207
208 pub fn consume(&mut self, expected: Token) -> Result<(), String> {
209 self.trace_token(&format!("Consuming expected {:?}", expected));
210 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
211 self.update_paren_depth(&expected)?;
213
214 self.current_token = self.lexer.next_token();
215 Ok(())
216 } else {
217 let error_msg = match (&expected, &self.current_token) {
219 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
220 format!(
221 "Unclosed parenthesis - missing {} closing parenthes{}",
222 self.paren_depth,
223 if self.paren_depth == 1 { "is" } else { "es" }
224 )
225 }
226 (Token::RightParen, _) if self.paren_depth > 0 => {
227 format!(
228 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
229 self.current_token,
230 self.paren_depth,
231 if self.paren_depth == 1 { "is" } else { "es" }
232 )
233 }
234 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
235 };
236 Err(error_msg)
237 }
238 }
239
240 pub fn advance(&mut self) {
241 match &self.current_token {
243 Token::LeftParen => self.paren_depth += 1,
244 Token::RightParen => {
245 self.paren_depth -= 1;
246 }
249 _ => {}
250 }
251 let old_token = self.current_token.clone();
252 self.current_token = self.lexer.next_token();
253 if self.debug_trace {
254 let indent = " ".repeat(self.trace_depth);
255 eprintln!(
256 "{} Advanced: {:?} → {:?}",
257 indent, old_token, self.current_token
258 );
259 }
260 }
261
262 fn collect_leading_comments(&mut self) -> Vec<Comment> {
265 let mut comments = Vec::new();
266 loop {
267 match &self.current_token {
268 Token::LineComment(text) => {
269 comments.push(Comment::line(text.clone()));
270 self.advance();
271 }
272 Token::BlockComment(text) => {
273 comments.push(Comment::block(text.clone()));
274 self.advance();
275 }
276 _ => break,
277 }
278 }
279 comments
280 }
281
282 fn collect_trailing_comment(&mut self) -> Option<Comment> {
285 match &self.current_token {
286 Token::LineComment(text) => {
287 let comment = Some(Comment::line(text.clone()));
288 self.advance();
289 comment
290 }
291 Token::BlockComment(text) => {
292 let comment = Some(Comment::block(text.clone()));
293 self.advance();
294 comment
295 }
296 _ => None,
297 }
298 }
299
300 fn push_paren_depth(&mut self) {
301 self.paren_depth_stack.push(self.paren_depth);
302 self.paren_depth = 0;
303 }
304
305 fn pop_paren_depth(&mut self) {
306 if let Some(depth) = self.paren_depth_stack.pop() {
307 self.paren_depth = depth;
309 }
310 }
311
312 pub fn parse(&mut self) -> Result<SelectStatement, String> {
313 self.trace_enter("parse");
314
315 let leading_comments = if self.mode == ParserMode::PreserveComments {
318 self.collect_leading_comments()
319 } else {
320 vec![]
321 };
322
323 let result = if matches!(self.current_token, Token::With) {
325 let mut stmt = self.parse_with_clause()?;
326 stmt.leading_comments = leading_comments;
328 stmt
329 } else {
330 let stmt = self.parse_select_statement_with_comments_public(leading_comments)?;
332 self.check_balanced_parentheses()?;
333 stmt
334 };
335
336 self.trace_exit("parse", &Ok(&result));
337 Ok(result)
338 }
339
340 fn parse_select_statement_with_comments_public(
342 &mut self,
343 comments: Vec<Comment>,
344 ) -> Result<SelectStatement, String> {
345 self.parse_select_statement_with_comments(comments)
346 }
347
348 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
349 self.consume(Token::With)?;
350 let ctes = self.parse_cte_list()?;
351
352 let mut main_query = self.parse_select_statement_inner_no_comments()?;
354 main_query.ctes = ctes;
355
356 self.check_balanced_parentheses()?;
358
359 Ok(main_query)
360 }
361
362 fn parse_with_clause_inner(&mut self) -> Result<SelectStatement, String> {
363 self.consume(Token::With)?;
364 let ctes = self.parse_cte_list()?;
365
366 let mut main_query = self.parse_select_statement_inner()?;
368 main_query.ctes = ctes;
369
370 Ok(main_query)
371 }
372
373 fn parse_cte_list(&mut self) -> Result<Vec<CTE>, String> {
375 let mut ctes = Vec::new();
376
377 loop {
379 let is_web = if matches!(&self.current_token, Token::Web) {
383 self.trace_token("Found WEB keyword for CTE");
384 self.advance();
385 true
386 } else {
387 false
388 };
389
390 let name = match &self.current_token {
392 Token::Identifier(name) => name.clone(),
393 token => {
394 if let Some(keyword) = token.as_keyword_str() {
396 keyword.to_lowercase()
398 } else {
399 return Err(format!(
400 "Expected CTE name after {}",
401 if is_web { "WEB" } else { "WITH or comma" }
402 ));
403 }
404 }
405 };
406 self.advance();
407
408 let column_list = if matches!(self.current_token, Token::LeftParen) {
410 self.advance();
411 let cols = self.parse_identifier_list()?;
412 self.consume(Token::RightParen)?;
413 Some(cols)
414 } else {
415 None
416 };
417
418 self.consume(Token::As)?;
420
421 let cte_type = if is_web {
422 self.consume(Token::LeftParen)?;
424 let web_spec = WebCteParser::parse(self)?;
426 self.consume(Token::RightParen)?;
428 CTEType::Web(web_spec)
429 } else {
430 self.push_paren_depth();
434 self.consume(Token::LeftParen)?;
435
436 let result = if matches!(&self.current_token, Token::File) {
437 self.trace_token("Found FILE keyword inside CTE parens");
438 self.advance();
439 let file_spec = FileCteParser::parse(self)?;
440 CTEType::File(file_spec)
441 } else {
442 let query = self.parse_select_statement_inner()?;
443 CTEType::Standard(query)
444 };
445
446 self.consume(Token::RightParen)?;
448 self.pop_paren_depth();
450 result
451 };
452
453 ctes.push(CTE {
454 name,
455 column_list,
456 cte_type,
457 });
458
459 if !matches!(self.current_token, Token::Comma) {
461 break;
462 }
463 self.advance();
464 }
465
466 Ok(ctes)
467 }
468
469 fn parse_optional_alias(&mut self) -> Result<Option<String>, String> {
471 if matches!(self.current_token, Token::As) {
472 self.advance();
473 match &self.current_token {
474 Token::Identifier(name) => {
475 let alias = name.clone();
476 self.advance();
477 Ok(Some(alias))
478 }
479 token => {
480 if let Some(keyword) = token.as_keyword_str() {
482 Err(format!(
483 "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
484 keyword,
485 keyword.to_lowercase()
486 ))
487 } else {
488 Err("Expected alias name after AS".to_string())
489 }
490 }
491 }
492 } else if let Token::Identifier(name) = &self.current_token {
493 let alias = name.clone();
495 self.advance();
496 Ok(Some(alias))
497 } else {
498 Ok(None)
499 }
500 }
501
/// Returns `true` if `name` is a usable identifier.
///
/// Two forms are accepted:
/// - a double-quoted name with at least one character between the quotes;
/// - a non-empty run of alphanumeric characters and underscores.
///
/// The empty string, a lone `"` and the empty quoted name `""` are all rejected.
fn is_valid_identifier(name: &str) -> bool {
    // `len() > 2` stops a single `"` from counting as both the opening and
    // closing quote, and rejects an empty quoted name.
    if name.len() > 2 && name.starts_with('"') && name.ends_with('"') {
        return true;
    }
    !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '_')
}
512
513 fn update_paren_depth(&mut self, token: &Token) -> Result<(), String> {
515 match token {
516 Token::LeftParen => self.paren_depth += 1,
517 Token::RightParen => {
518 self.paren_depth -= 1;
519 if self.paren_depth < 0 {
521 return Err(
522 "Unexpected closing parenthesis - no matching opening parenthesis"
523 .to_string(),
524 );
525 }
526 }
527 _ => {}
528 }
529 Ok(())
530 }
531
532 fn parse_argument_list(&mut self) -> Result<Vec<SqlExpression>, String> {
534 let mut args = Vec::new();
535
536 if !matches!(self.current_token, Token::RightParen) {
537 loop {
538 args.push(self.parse_expression()?);
539
540 if matches!(self.current_token, Token::Comma) {
541 self.advance();
542 } else {
543 break;
544 }
545 }
546 }
547
548 Ok(args)
549 }
550
551 fn check_balanced_parentheses(&self) -> Result<(), String> {
553 if self.paren_depth > 0 {
554 Err(format!(
555 "Unclosed parenthesis - missing {} closing parenthes{}",
556 self.paren_depth,
557 if self.paren_depth == 1 { "is" } else { "es" }
558 ))
559 } else if self.paren_depth < 0 {
560 Err("Extra closing parenthesis found - no matching opening parenthesis".to_string())
561 } else {
562 Ok(())
563 }
564 }
565
566 fn contains_aggregate_function(expr: &SqlExpression) -> bool {
569 match expr {
570 SqlExpression::FunctionCall { name, args, .. } => {
571 let upper_name = name.to_uppercase();
573 let is_aggregate = matches!(
574 upper_name.as_str(),
575 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "GROUP_CONCAT" | "STRING_AGG"
576 );
577
578 is_aggregate || args.iter().any(Self::contains_aggregate_function)
581 }
582 SqlExpression::BinaryOp { left, right, .. } => {
584 Self::contains_aggregate_function(left) || Self::contains_aggregate_function(right)
585 }
586 SqlExpression::Not { expr } => Self::contains_aggregate_function(expr),
587 SqlExpression::MethodCall { args, .. } => {
588 args.iter().any(Self::contains_aggregate_function)
589 }
590 SqlExpression::ChainedMethodCall { base, args, .. } => {
591 Self::contains_aggregate_function(base)
592 || args.iter().any(Self::contains_aggregate_function)
593 }
594 SqlExpression::CaseExpression {
595 when_branches,
596 else_branch,
597 } => {
598 when_branches.iter().any(|branch| {
599 Self::contains_aggregate_function(&branch.condition)
600 || Self::contains_aggregate_function(&branch.result)
601 }) || else_branch
602 .as_ref()
603 .map_or(false, |e| Self::contains_aggregate_function(e))
604 }
605 SqlExpression::SimpleCaseExpression {
606 expr,
607 when_branches,
608 else_branch,
609 } => {
610 Self::contains_aggregate_function(expr)
611 || when_branches.iter().any(|branch| {
612 Self::contains_aggregate_function(&branch.value)
613 || Self::contains_aggregate_function(&branch.result)
614 })
615 || else_branch
616 .as_ref()
617 .map_or(false, |e| Self::contains_aggregate_function(e))
618 }
619 SqlExpression::ScalarSubquery { query } => {
620 query
623 .having
624 .as_ref()
625 .map_or(false, |h| Self::contains_aggregate_function(h))
626 }
627 SqlExpression::Column(_)
629 | SqlExpression::StringLiteral(_)
630 | SqlExpression::NumberLiteral(_)
631 | SqlExpression::BooleanLiteral(_)
632 | SqlExpression::Null
633 | SqlExpression::DateTimeConstructor { .. }
634 | SqlExpression::DateTimeToday { .. } => false,
635
636 SqlExpression::WindowFunction { .. } => true,
638
639 SqlExpression::Between { expr, lower, upper } => {
641 Self::contains_aggregate_function(expr)
642 || Self::contains_aggregate_function(lower)
643 || Self::contains_aggregate_function(upper)
644 }
645
646 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
648 Self::contains_aggregate_function(expr)
649 || values.iter().any(Self::contains_aggregate_function)
650 }
651
652 SqlExpression::InSubquery { expr, subquery }
654 | SqlExpression::NotInSubquery { expr, subquery } => {
655 Self::contains_aggregate_function(expr)
656 || subquery
657 .having
658 .as_ref()
659 .map_or(false, |h| Self::contains_aggregate_function(h))
660 }
661
662 SqlExpression::Unnest { column, .. } => Self::contains_aggregate_function(column),
664 }
665 }
666
667 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
668 self.trace_enter("parse_select_statement");
669 let result = self.parse_select_statement_inner()?;
670
671 self.check_balanced_parentheses()?;
673
674 Ok(result)
675 }
676
677 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
678 let leading_comments = if self.mode == ParserMode::PreserveComments {
680 self.collect_leading_comments()
681 } else {
682 vec![]
683 };
684
685 self.parse_select_statement_with_comments(leading_comments)
686 }
687
688 fn parse_select_statement_inner_no_comments(&mut self) -> Result<SelectStatement, String> {
691 self.parse_select_statement_with_comments(vec![])
692 }
693
694 fn parse_select_statement_with_comments(
696 &mut self,
697 leading_comments: Vec<Comment>,
698 ) -> Result<SelectStatement, String> {
699 self.consume(Token::Select)?;
700
701 let distinct = if matches!(self.current_token, Token::Distinct) {
703 self.advance();
704 true
705 } else {
706 false
707 };
708
709 let select_items = self.parse_select_items()?;
711
712 let columns = select_items
714 .iter()
715 .map(|item| match item {
716 SelectItem::Star { .. } => "*".to_string(),
717 SelectItem::StarExclude { .. } => "*".to_string(), SelectItem::Column {
719 column: col_ref, ..
720 } => col_ref.name.clone(),
721 SelectItem::Expression { alias, .. } => alias.clone(),
722 })
723 .collect();
724
725 let into_table = if matches!(self.current_token, Token::Into) {
727 self.advance();
728 Some(self.parse_into_clause()?)
729 } else {
730 None
731 };
732
733 let (from_table, from_subquery, from_function, from_alias) = if matches!(
735 self.current_token,
736 Token::From
737 ) {
738 self.advance();
739
740 let table_or_function_name = match &self.current_token {
743 Token::Identifier(name) => Some(name.clone()),
744 token => {
745 token.as_keyword_str().map(|k| k.to_lowercase())
747 }
748 };
749
750 if let Some(name) = table_or_function_name {
751 let has_paren = self.peek_token() == Some(Token::LeftParen);
755 if self.debug_trace {
756 eprintln!(
757 " Checking {} for table function, has_paren={}",
758 name, has_paren
759 );
760 }
761
762 let is_table_function = if has_paren {
765 if self.debug_trace {
767 eprintln!(" Checking generator registry for {}", name.to_uppercase());
768 }
769 if let Some(_gen) = self.generator_registry.get(&name.to_uppercase()) {
770 if self.debug_trace {
771 eprintln!(" Found {} in generator registry", name);
772 }
773 self.trace_token(&format!("Found generator: {}", name));
774 true
775 } else {
776 if let Some(func) = self.function_registry.get(&name.to_uppercase()) {
778 let sig = func.signature();
779 let is_table_fn = sig.category == FunctionCategory::TableFunction;
780 if self.debug_trace {
781 eprintln!(
782 " Found {} in function registry, is_table_function={}",
783 name, is_table_fn
784 );
785 }
786 if is_table_fn {
787 self.trace_token(&format!(
788 "Found table function in function registry: {}",
789 name
790 ));
791 }
792 is_table_fn
793 } else {
794 if self.debug_trace {
795 eprintln!(" {} not found in either registry", name);
796 self.trace_token(&format!(
797 "Not found as generator or table function: {}",
798 name
799 ));
800 }
801 false
802 }
803 }
804 } else {
805 if self.debug_trace {
806 eprintln!(" No parenthesis after {}, treating as table", name);
807 }
808 false
809 };
810
811 if is_table_function {
812 let function_name = name.clone();
814 self.advance(); self.consume(Token::LeftParen)?;
818 let args = self.parse_argument_list()?;
819 self.consume(Token::RightParen)?;
820
821 let alias = if matches!(self.current_token, Token::As) {
823 self.advance();
824 match &self.current_token {
825 Token::Identifier(name) => {
826 let alias = name.clone();
827 self.advance();
828 Some(alias)
829 }
830 token => {
831 if let Some(keyword) = token.as_keyword_str() {
832 return Err(format!(
833 "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
834 keyword,
835 keyword.to_lowercase()
836 ));
837 } else {
838 return Err("Expected alias name after AS".to_string());
839 }
840 }
841 }
842 } else if let Token::Identifier(name) = &self.current_token {
843 let alias = name.clone();
844 self.advance();
845 Some(alias)
846 } else {
847 None
848 };
849
850 (
851 None,
852 None,
853 Some(TableFunction::Generator {
854 name: function_name,
855 args,
856 }),
857 alias,
858 )
859 } else {
860 let table_name = name.clone();
862 self.advance();
863
864 let alias = self.parse_optional_alias()?;
866
867 (Some(table_name), None, None, alias)
868 }
869 } else if matches!(self.current_token, Token::LeftParen) {
870 self.advance();
872
873 let subquery = if matches!(self.current_token, Token::With) {
875 self.parse_with_clause_inner()?
876 } else {
877 self.parse_select_statement_inner()?
878 };
879
880 self.consume(Token::RightParen)?;
881
882 let alias = if matches!(self.current_token, Token::As) {
884 self.advance();
885 match &self.current_token {
886 Token::Identifier(name) => {
887 let alias = name.clone();
888 self.advance();
889 alias
890 }
891 token => {
892 if let Some(keyword) = token.as_keyword_str() {
893 return Err(format!(
894 "Reserved keyword '{}' cannot be used as subquery alias. Use a different name or quote it with double quotes: \"{}\"",
895 keyword,
896 keyword.to_lowercase()
897 ));
898 } else {
899 return Err("Expected alias name after AS".to_string());
900 }
901 }
902 }
903 } else {
904 match &self.current_token {
906 Token::Identifier(name) => {
907 let alias = name.clone();
908 self.advance();
909 alias
910 }
911 _ => {
912 return Err(
913 "Subquery in FROM must have an alias (e.g., AS t)".to_string()
914 )
915 }
916 }
917 };
918
919 (None, Some(Box::new(subquery)), None, Some(alias))
920 } else {
921 let table_name = match &self.current_token {
923 Token::Identifier(table) => table.clone(),
924 Token::QuotedIdentifier(table) => table.clone(),
925 token => {
926 if let Some(keyword) = token.as_keyword_str() {
928 keyword.to_lowercase()
929 } else {
930 return Err("Expected table name or subquery after FROM".to_string());
931 }
932 }
933 };
934
935 self.advance();
936
937 let alias = self.parse_optional_alias()?;
939
940 (Some(table_name), None, None, alias)
941 }
942 } else {
943 (None, None, None, None)
944 };
945
946 let pivot_source = if matches!(self.current_token, Token::Pivot) {
950 let source = if let Some(ref table_name) = from_table {
952 TableSource::Table(table_name.clone())
953 } else if let Some(ref subquery) = from_subquery {
954 TableSource::DerivedTable {
955 query: subquery.clone(),
956 alias: from_alias.clone().unwrap_or_default(),
957 }
958 } else {
959 return Err("PIVOT requires a table or subquery source".to_string());
960 };
961
962 let pivoted = self.parse_pivot_clause(source)?;
964 Some(pivoted)
965 } else {
966 None
967 };
968
969 let mut joins = Vec::new();
971 while self.is_join_token() {
972 joins.push(self.parse_join_clause()?);
973 }
974
975 let where_clause = if matches!(self.current_token, Token::Where) {
976 self.advance();
977 Some(self.parse_where_clause()?)
978 } else {
979 None
980 };
981
982 let group_by = if matches!(self.current_token, Token::GroupBy) {
983 self.advance();
984 Some(self.parse_expression_list()?)
987 } else {
988 None
989 };
990
991 let having = if matches!(self.current_token, Token::Having) {
993 if group_by.is_none() {
994 return Err("HAVING clause requires GROUP BY".to_string());
995 }
996 self.advance();
997 let having_expr = self.parse_expression()?;
998
999 Some(having_expr)
1004 } else {
1005 None
1006 };
1007
1008 let qualify = if matches!(self.current_token, Token::Qualify) {
1012 self.advance();
1013 let qualify_expr = self.parse_expression()?;
1014
1015 Some(qualify_expr)
1019 } else {
1020 None
1021 };
1022
1023 let order_by = if matches!(self.current_token, Token::OrderBy) {
1025 self.trace_token("Found OrderBy token");
1026 self.advance();
1027 Some(self.parse_order_by_list()?)
1028 } else if let Token::Identifier(s) = &self.current_token {
1029 if Self::is_identifier_reserved(s) && s.to_uppercase() == "ORDER" {
1032 self.trace_token("Warning: ORDER as identifier instead of OrderBy token");
1033 self.advance(); if matches!(&self.current_token, Token::By) {
1035 self.advance(); Some(self.parse_order_by_list()?)
1037 } else {
1038 return Err("Expected BY after ORDER".to_string());
1039 }
1040 } else {
1041 None
1042 }
1043 } else {
1044 None
1045 };
1046
1047 let limit = if matches!(self.current_token, Token::Limit) {
1049 self.advance();
1050 match &self.current_token {
1051 Token::NumberLiteral(num) => {
1052 let limit_val = num
1053 .parse::<usize>()
1054 .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
1055 self.advance();
1056 Some(limit_val)
1057 }
1058 _ => return Err("Expected number after LIMIT".to_string()),
1059 }
1060 } else {
1061 None
1062 };
1063
1064 let offset = if matches!(self.current_token, Token::Offset) {
1066 self.advance();
1067 match &self.current_token {
1068 Token::NumberLiteral(num) => {
1069 let offset_val = num
1070 .parse::<usize>()
1071 .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
1072 self.advance();
1073 Some(offset_val)
1074 }
1075 _ => return Err("Expected number after OFFSET".to_string()),
1076 }
1077 } else {
1078 None
1079 };
1080
1081 let into_table = if into_table.is_none() && matches!(self.current_token, Token::Into) {
1085 self.advance();
1086 Some(self.parse_into_clause()?)
1087 } else {
1088 into_table };
1090
1091 let set_operations = self.parse_set_operations()?;
1093
1094 let trailing_comment = if self.mode == ParserMode::PreserveComments {
1096 self.collect_trailing_comment()
1097 } else {
1098 None
1099 };
1100
1101 let from_source = if let Some(pivot) = pivot_source {
1104 Some(pivot)
1105 } else if let Some(ref table_name) = from_table {
1106 Some(TableSource::Table(table_name.clone()))
1107 } else if let Some(ref subquery) = from_subquery {
1108 Some(TableSource::DerivedTable {
1109 query: subquery.clone(),
1110 alias: from_alias.clone().unwrap_or_default(),
1111 })
1112 } else if let Some(ref _func) = from_function {
1113 None
1116 } else {
1117 None
1118 };
1119
1120 Ok(SelectStatement {
1121 distinct,
1122 columns,
1123 select_items,
1124 from_source,
1125 #[allow(deprecated)]
1126 from_table,
1127 #[allow(deprecated)]
1128 from_subquery,
1129 #[allow(deprecated)]
1130 from_function,
1131 #[allow(deprecated)]
1132 from_alias,
1133 joins,
1134 where_clause,
1135 order_by,
1136 group_by,
1137 having,
1138 qualify,
1139 limit,
1140 offset,
1141 ctes: Vec::new(), into_table,
1143 set_operations,
1144 leading_comments,
1145 trailing_comment,
1146 })
1147 }
1148
1149 fn parse_set_operations(
1152 &mut self,
1153 ) -> Result<Vec<(SetOperation, Box<SelectStatement>)>, String> {
1154 let mut operations = Vec::new();
1155
1156 while matches!(
1157 self.current_token,
1158 Token::Union | Token::Intersect | Token::Except
1159 ) {
1160 let operation = match &self.current_token {
1162 Token::Union => {
1163 self.advance();
1164 if let Token::Identifier(id) = &self.current_token {
1166 if id.to_uppercase() == "ALL" {
1167 self.advance();
1168 SetOperation::UnionAll
1169 } else {
1170 SetOperation::Union
1171 }
1172 } else {
1173 SetOperation::Union
1174 }
1175 }
1176 Token::Intersect => {
1177 self.advance();
1178 SetOperation::Intersect
1179 }
1180 Token::Except => {
1181 self.advance();
1182 SetOperation::Except
1183 }
1184 _ => unreachable!(),
1185 };
1186
1187 let next_select = self.parse_select_statement_inner()?;
1189
1190 operations.push((operation, Box::new(next_select)));
1191 }
1192
1193 Ok(operations)
1194 }
1195
1196 fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
1198 let mut items = Vec::new();
1199
1200 loop {
1201 if let Token::Identifier(name) = &self.current_token.clone() {
1204 let saved_pos = self.lexer.clone();
1206 let saved_token = self.current_token.clone();
1207 let table_name = name.clone();
1208
1209 self.advance();
1210
1211 if matches!(self.current_token, Token::Dot) {
1212 self.advance();
1213 if matches!(self.current_token, Token::Star) {
1214 items.push(SelectItem::Star {
1216 table_prefix: Some(table_name),
1217 leading_comments: vec![],
1218 trailing_comment: None,
1219 });
1220 self.advance();
1221
1222 if matches!(self.current_token, Token::Comma) {
1224 self.advance();
1225 continue;
1226 } else {
1227 break;
1228 }
1229 }
1230 }
1231
1232 self.lexer = saved_pos;
1234 self.current_token = saved_token;
1235 }
1236
1237 if matches!(self.current_token, Token::Star) {
1239 self.advance(); if matches!(self.current_token, Token::Exclude) {
1243 self.advance(); if !matches!(self.current_token, Token::LeftParen) {
1247 return Err("Expected '(' after EXCLUDE".to_string());
1248 }
1249 self.advance(); let mut excluded_columns = Vec::new();
1253 loop {
1254 match &self.current_token {
1255 Token::Identifier(col_name) | Token::QuotedIdentifier(col_name) => {
1256 excluded_columns.push(col_name.clone());
1257 self.advance();
1258 }
1259 _ => return Err("Expected column name in EXCLUDE list".to_string()),
1260 }
1261
1262 if matches!(self.current_token, Token::Comma) {
1264 self.advance();
1265 } else if matches!(self.current_token, Token::RightParen) {
1266 self.advance(); break;
1268 } else {
1269 return Err("Expected ',' or ')' in EXCLUDE list".to_string());
1270 }
1271 }
1272
1273 if excluded_columns.is_empty() {
1274 return Err("EXCLUDE list cannot be empty".to_string());
1275 }
1276
1277 items.push(SelectItem::StarExclude {
1278 table_prefix: None,
1279 excluded_columns,
1280 leading_comments: vec![],
1281 trailing_comment: None,
1282 });
1283 } else {
1284 items.push(SelectItem::Star {
1286 table_prefix: None,
1287 leading_comments: vec![],
1288 trailing_comment: None,
1289 });
1290 }
1291 } else {
1292 let expr = self.parse_comparison()?; let alias = if matches!(self.current_token, Token::As) {
1297 self.advance();
1298 match &self.current_token {
1299 Token::Identifier(alias_name) => {
1300 let alias = alias_name.clone();
1301 self.advance();
1302 alias
1303 }
1304 Token::QuotedIdentifier(alias_name) => {
1305 let alias = alias_name.clone();
1306 self.advance();
1307 alias
1308 }
1309 token => {
1310 if let Some(keyword) = token.as_keyword_str() {
1311 return Err(format!(
1312 "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
1313 keyword,
1314 keyword.to_lowercase()
1315 ));
1316 } else {
1317 return Err("Expected alias name after AS".to_string());
1318 }
1319 }
1320 }
1321 } else {
1322 match &expr {
1324 SqlExpression::Column(col_ref) => col_ref.name.clone(),
1325 _ => format!("expr_{}", items.len() + 1), }
1327 };
1328
1329 let item = match expr {
1331 SqlExpression::Column(col_ref) if alias == col_ref.name => {
1332 SelectItem::Column {
1334 column: col_ref,
1335 leading_comments: vec![],
1336 trailing_comment: None,
1337 }
1338 }
1339 _ => {
1340 SelectItem::Expression {
1342 expr,
1343 alias,
1344 leading_comments: vec![],
1345 trailing_comment: None,
1346 }
1347 }
1348 };
1349
1350 items.push(item);
1351 }
1352
1353 if matches!(self.current_token, Token::Comma) {
1355 self.advance();
1356 } else {
1357 break;
1358 }
1359 }
1360
1361 Ok(items)
1362 }
1363
1364 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1365 let mut identifiers = Vec::new();
1366
1367 loop {
1368 match &self.current_token {
1369 Token::Identifier(id) => {
1370 if Self::is_identifier_reserved(id) {
1372 break;
1374 }
1375 identifiers.push(id.clone());
1376 self.advance();
1377 }
1378 Token::QuotedIdentifier(id) => {
1379 identifiers.push(id.clone());
1381 self.advance();
1382 }
1383 _ => {
1384 break;
1386 }
1387 }
1388
1389 if matches!(self.current_token, Token::Comma) {
1390 self.advance();
1391 } else {
1392 break;
1393 }
1394 }
1395
1396 if identifiers.is_empty() {
1397 return Err("Expected at least one identifier".to_string());
1398 }
1399
1400 Ok(identifiers)
1401 }
1402
1403 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1404 let mut partition_by = Vec::new();
1405 let mut order_by = Vec::new();
1406
1407 if matches!(self.current_token, Token::Partition) {
1409 self.advance(); if !matches!(self.current_token, Token::By) {
1411 return Err("Expected BY after PARTITION".to_string());
1412 }
1413 self.advance(); partition_by = self.parse_identifier_list()?;
1417 }
1418
1419 if matches!(self.current_token, Token::OrderBy) {
1421 self.advance(); order_by = self.parse_order_by_list()?;
1423 } else if let Token::Identifier(s) = &self.current_token {
1424 if Self::is_identifier_reserved(s) && s.to_uppercase() == "ORDER" {
1425 self.advance(); if !matches!(self.current_token, Token::By) {
1428 return Err("Expected BY after ORDER".to_string());
1429 }
1430 self.advance(); order_by = self.parse_order_by_list()?;
1432 }
1433 }
1434
1435 let mut frame = self.parse_window_frame()?;
1437
1438 if !order_by.is_empty() && frame.is_none() {
1442 frame = Some(WindowFrame {
1443 unit: FrameUnit::Range,
1444 start: FrameBound::UnboundedPreceding,
1445 end: Some(FrameBound::CurrentRow),
1446 });
1447 }
1448
1449 Ok(WindowSpec {
1450 partition_by,
1451 order_by,
1452 frame,
1453 })
1454 }
1455
1456 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByItem>, String> {
1457 let mut order_items = Vec::new();
1458
1459 loop {
1460 let expr = self.parse_expression()?;
1468
1469 let direction = match &self.current_token {
1471 Token::Asc => {
1472 self.advance();
1473 SortDirection::Asc
1474 }
1475 Token::Desc => {
1476 self.advance();
1477 SortDirection::Desc
1478 }
1479 _ => SortDirection::Asc, };
1481
1482 order_items.push(OrderByItem { expr, direction });
1483
1484 if matches!(self.current_token, Token::Comma) {
1485 self.advance();
1486 } else {
1487 break;
1488 }
1489 }
1490
1491 Ok(order_items)
1492 }
1493
1494 fn parse_into_clause(&mut self) -> Result<IntoTable, String> {
1497 let name = match &self.current_token {
1499 Token::Identifier(id) if id.starts_with('#') => {
1500 let table_name = id.clone();
1501 self.advance();
1502 table_name
1503 }
1504 Token::Identifier(id) => {
1505 return Err(format!(
1506 "Temporary table name must start with #, got: {}",
1507 id
1508 ));
1509 }
1510 _ => {
1511 return Err(
1512 "Expected temporary table name (starting with #) after INTO".to_string()
1513 );
1514 }
1515 };
1516
1517 Ok(IntoTable { name })
1518 }
1519
1520 fn parse_window_frame(&mut self) -> Result<Option<WindowFrame>, String> {
1521 let unit = match &self.current_token {
1523 Token::Rows => {
1524 self.advance();
1525 FrameUnit::Rows
1526 }
1527 Token::Identifier(id) if id.to_uppercase() == "RANGE" => {
1528 self.advance();
1530 FrameUnit::Range
1531 }
1532 _ => return Ok(None), };
1534
1535 let (start, end) = if let Token::Between = &self.current_token {
1537 self.advance(); let start = self.parse_frame_bound()?;
1540
1541 if !matches!(&self.current_token, Token::And) {
1543 return Err("Expected AND after window frame start bound".to_string());
1544 }
1545 self.advance();
1546
1547 let end = self.parse_frame_bound()?;
1549 (start, Some(end))
1550 } else {
1551 let bound = self.parse_frame_bound()?;
1553 (bound, None)
1554 };
1555
1556 Ok(Some(WindowFrame { unit, start, end }))
1557 }
1558
1559 fn parse_frame_bound(&mut self) -> Result<FrameBound, String> {
1560 match &self.current_token {
1561 Token::Unbounded => {
1562 self.advance();
1563 match &self.current_token {
1564 Token::Preceding => {
1565 self.advance();
1566 Ok(FrameBound::UnboundedPreceding)
1567 }
1568 Token::Following => {
1569 self.advance();
1570 Ok(FrameBound::UnboundedFollowing)
1571 }
1572 _ => Err("Expected PRECEDING or FOLLOWING after UNBOUNDED".to_string()),
1573 }
1574 }
1575 Token::Current => {
1576 self.advance();
1577 if matches!(&self.current_token, Token::Row) {
1578 self.advance();
1579 return Ok(FrameBound::CurrentRow);
1580 }
1581 Err("Expected ROW after CURRENT".to_string())
1582 }
1583 Token::NumberLiteral(num) => {
1584 let n: i64 = num
1585 .parse()
1586 .map_err(|_| "Invalid number in window frame".to_string())?;
1587 self.advance();
1588 match &self.current_token {
1589 Token::Preceding => {
1590 self.advance();
1591 Ok(FrameBound::Preceding(n))
1592 }
1593 Token::Following => {
1594 self.advance();
1595 Ok(FrameBound::Following(n))
1596 }
1597 _ => Err("Expected PRECEDING or FOLLOWING after number".to_string()),
1598 }
1599 }
1600 _ => Err("Invalid window frame bound".to_string()),
1601 }
1602 }
1603
1604 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1605 let expr = self.parse_expression()?;
1608
1609 if matches!(self.current_token, Token::RightParen) && self.paren_depth <= 0 {
1611 return Err(
1612 "Unexpected closing parenthesis - no matching opening parenthesis".to_string(),
1613 );
1614 }
1615
1616 let conditions = vec![Condition {
1618 expr,
1619 connector: None,
1620 }];
1621
1622 Ok(WhereClause { conditions })
1623 }
1624
1625 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1626 self.trace_enter("parse_expression");
1627 let mut left = self.parse_logical_or()?;
1630
1631 left = parse_in_operator(self, left)?;
1634
1635 let result = Ok(left);
1636 self.trace_exit("parse_expression", &result);
1637 result
1638 }
1639
1640 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1641 parse_comparison_expr(self)
1643 }
1644
1645 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1646 parse_additive_expr(self)
1648 }
1649
1650 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1651 parse_multiplicative_expr(self)
1653 }
1654
1655 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1656 parse_logical_or_expr(self)
1658 }
1659
1660 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1661 parse_logical_and_expr(self)
1663 }
1664
1665 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1666 parse_case_expr(self)
1668 }
1669
1670 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1671 let columns = self.columns.clone();
1674 let in_method_args = self.in_method_args;
1675 let ctx = PrimaryExpressionContext {
1676 columns: &columns,
1677 in_method_args,
1678 };
1679 parse_primary_expr(self, &ctx)
1680 }
1681
1682 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1684 self.in_method_args = true;
1686
1687 let args = self.parse_argument_list()?;
1688
1689 self.in_method_args = false;
1691
1692 Ok(args)
1693 }
1694
1695 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
1696 let mut args = Vec::new();
1697 let mut has_distinct = false;
1698
1699 if !matches!(self.current_token, Token::RightParen) {
1700 if matches!(self.current_token, Token::Distinct) {
1702 self.advance(); has_distinct = true;
1704 }
1705
1706 args.push(self.parse_additive()?);
1708
1709 while matches!(self.current_token, Token::Comma) {
1711 self.advance();
1712 args.push(self.parse_additive()?);
1713 }
1714 }
1715
1716 Ok((args, has_distinct))
1717 }
1718
1719 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1720 let mut expressions = Vec::new();
1721
1722 loop {
1723 expressions.push(self.parse_expression()?);
1724
1725 if matches!(self.current_token, Token::Comma) {
1726 self.advance();
1727 } else {
1728 break;
1729 }
1730 }
1731
1732 Ok(expressions)
1733 }
1734
    /// Returns the underlying lexer's current position, as reported by
    /// `Lexer::get_position`.
    #[must_use]
    pub fn get_position(&self) -> usize {
        self.lexer.get_position()
    }
1739
    /// True when the current token can start a join clause
    /// (`JOIN`, `INNER`, `LEFT`, `RIGHT`, `FULL` or `CROSS`), i.e. one of the
    /// tokens `parse_join_clause` accepts first.
    fn is_join_token(&self) -> bool {
        matches!(
            self.current_token,
            Token::Join | Token::Inner | Token::Left | Token::Right | Token::Full | Token::Cross
        )
    }
1747
1748 fn parse_join_clause(&mut self) -> Result<JoinClause, String> {
1750 let join_type = match &self.current_token {
1752 Token::Join => {
1753 self.advance();
1754 JoinType::Inner }
1756 Token::Inner => {
1757 self.advance();
1758 if !matches!(self.current_token, Token::Join) {
1759 return Err("Expected JOIN after INNER".to_string());
1760 }
1761 self.advance();
1762 JoinType::Inner
1763 }
1764 Token::Left => {
1765 self.advance();
1766 if matches!(self.current_token, Token::Outer) {
1768 self.advance();
1769 }
1770 if !matches!(self.current_token, Token::Join) {
1771 return Err("Expected JOIN after LEFT".to_string());
1772 }
1773 self.advance();
1774 JoinType::Left
1775 }
1776 Token::Right => {
1777 self.advance();
1778 if matches!(self.current_token, Token::Outer) {
1780 self.advance();
1781 }
1782 if !matches!(self.current_token, Token::Join) {
1783 return Err("Expected JOIN after RIGHT".to_string());
1784 }
1785 self.advance();
1786 JoinType::Right
1787 }
1788 Token::Full => {
1789 self.advance();
1790 if matches!(self.current_token, Token::Outer) {
1792 self.advance();
1793 }
1794 if !matches!(self.current_token, Token::Join) {
1795 return Err("Expected JOIN after FULL".to_string());
1796 }
1797 self.advance();
1798 JoinType::Full
1799 }
1800 Token::Cross => {
1801 self.advance();
1802 if !matches!(self.current_token, Token::Join) {
1803 return Err("Expected JOIN after CROSS".to_string());
1804 }
1805 self.advance();
1806 JoinType::Cross
1807 }
1808 _ => return Err("Expected JOIN keyword".to_string()),
1809 };
1810
1811 let (table, alias) = self.parse_join_table_source()?;
1813
1814 let condition = if join_type == JoinType::Cross {
1816 JoinCondition { conditions: vec![] }
1818 } else {
1819 if !matches!(self.current_token, Token::On) {
1820 return Err("Expected ON keyword after JOIN table".to_string());
1821 }
1822 self.advance();
1823 self.parse_join_condition()?
1824 };
1825
1826 Ok(JoinClause {
1827 join_type,
1828 table,
1829 alias,
1830 condition,
1831 })
1832 }
1833
1834 fn parse_join_table_source(&mut self) -> Result<(TableSource, Option<String>), String> {
1835 let table = match &self.current_token {
1836 Token::Identifier(name) => {
1837 let table_name = name.clone();
1838 self.advance();
1839 TableSource::Table(table_name)
1840 }
1841 Token::LeftParen => {
1842 self.advance();
1844 let subquery = self.parse_select_statement_inner()?;
1845 if !matches!(self.current_token, Token::RightParen) {
1846 return Err("Expected ')' after subquery".to_string());
1847 }
1848 self.advance();
1849
1850 let alias = match &self.current_token {
1852 Token::Identifier(alias_name) => {
1853 let alias = alias_name.clone();
1854 self.advance();
1855 alias
1856 }
1857 Token::As => {
1858 self.advance();
1859 match &self.current_token {
1860 Token::Identifier(alias_name) => {
1861 let alias = alias_name.clone();
1862 self.advance();
1863 alias
1864 }
1865 _ => return Err("Expected alias after AS keyword".to_string()),
1866 }
1867 }
1868 _ => return Err("Subqueries must have an alias".to_string()),
1869 };
1870
1871 return Ok((
1872 TableSource::DerivedTable {
1873 query: Box::new(subquery),
1874 alias: alias.clone(),
1875 },
1876 Some(alias),
1877 ));
1878 }
1879 _ => return Err("Expected table name or subquery in JOIN clause".to_string()),
1880 };
1881
1882 let alias = match &self.current_token {
1884 Token::Identifier(alias_name) => {
1885 let alias = alias_name.clone();
1886 self.advance();
1887 Some(alias)
1888 }
1889 Token::As => {
1890 self.advance();
1891 match &self.current_token {
1892 Token::Identifier(alias_name) => {
1893 let alias = alias_name.clone();
1894 self.advance();
1895 Some(alias)
1896 }
1897 _ => return Err("Expected alias after AS keyword".to_string()),
1898 }
1899 }
1900 _ => None,
1901 };
1902
1903 Ok((table, alias))
1904 }
1905
1906 fn parse_join_condition(&mut self) -> Result<JoinCondition, String> {
1907 let mut conditions = Vec::new();
1908
1909 conditions.push(self.parse_single_join_condition()?);
1911
1912 while matches!(self.current_token, Token::And) {
1914 self.advance(); conditions.push(self.parse_single_join_condition()?);
1916 }
1917
1918 Ok(JoinCondition { conditions })
1919 }
1920
1921 fn parse_single_join_condition(&mut self) -> Result<SingleJoinCondition, String> {
1922 let left_expr = self.parse_additive()?;
1925
1926 let operator = match &self.current_token {
1928 Token::Equal => JoinOperator::Equal,
1929 Token::NotEqual => JoinOperator::NotEqual,
1930 Token::LessThan => JoinOperator::LessThan,
1931 Token::LessThanOrEqual => JoinOperator::LessThanOrEqual,
1932 Token::GreaterThan => JoinOperator::GreaterThan,
1933 Token::GreaterThanOrEqual => JoinOperator::GreaterThanOrEqual,
1934 _ => return Err("Expected comparison operator in JOIN condition".to_string()),
1935 };
1936 self.advance();
1937
1938 let right_expr = self.parse_additive()?;
1940
1941 Ok(SingleJoinCondition {
1942 left_expr,
1943 operator,
1944 right_expr,
1945 })
1946 }
1947
1948 fn parse_column_reference(&mut self) -> Result<String, String> {
1949 match &self.current_token {
1950 Token::Identifier(name) => {
1951 let mut column_ref = name.clone();
1952 self.advance();
1953
1954 if matches!(self.current_token, Token::Dot) {
1956 self.advance();
1957 match &self.current_token {
1958 Token::Identifier(col_name) => {
1959 column_ref.push('.');
1960 column_ref.push_str(col_name);
1961 self.advance();
1962 }
1963 _ => return Err("Expected column name after '.'".to_string()),
1964 }
1965 }
1966
1967 Ok(column_ref)
1968 }
1969 _ => Err("Expected column reference".to_string()),
1970 }
1971 }
1972
1973 fn parse_pivot_clause(&mut self, source: TableSource) -> Result<TableSource, String> {
1978 self.consume(Token::Pivot)?;
1980
1981 self.consume(Token::LeftParen)?;
1983
1984 let aggregate = self.parse_pivot_aggregate()?;
1986
1987 self.consume(Token::For)?;
1989
1990 let pivot_column = match &self.current_token {
1992 Token::Identifier(col) => {
1993 let column = col.clone();
1994 self.advance();
1995 column
1996 }
1997 Token::QuotedIdentifier(col) => {
1998 let column = col.clone();
1999 self.advance();
2000 column
2001 }
2002 _ => return Err("Expected column name after FOR in PIVOT".to_string()),
2003 };
2004
2005 if !matches!(self.current_token, Token::In) {
2007 return Err("Expected IN keyword in PIVOT clause".to_string());
2008 }
2009 self.advance();
2010
2011 let pivot_values = self.parse_pivot_in_clause()?;
2013
2014 self.consume(Token::RightParen)?;
2016
2017 let alias = self.parse_optional_alias()?;
2019
2020 Ok(TableSource::Pivot {
2021 source: Box::new(source),
2022 aggregate,
2023 pivot_column,
2024 pivot_values,
2025 alias,
2026 })
2027 }
2028
2029 fn parse_pivot_aggregate(&mut self) -> Result<PivotAggregate, String> {
2032 let function = match &self.current_token {
2034 Token::Identifier(name) => {
2035 let func_name = name.to_uppercase();
2036 match func_name.as_str() {
2038 "MAX" | "MIN" | "SUM" | "AVG" | "COUNT" => {
2039 self.advance();
2040 func_name
2041 }
2042 _ => {
2043 return Err(format!(
2044 "Expected aggregate function (MAX, MIN, SUM, AVG, COUNT), got {}",
2045 func_name
2046 ))
2047 }
2048 }
2049 }
2050 _ => return Err("Expected aggregate function in PIVOT".to_string()),
2051 };
2052
2053 self.consume(Token::LeftParen)?;
2055
2056 let column = match &self.current_token {
2058 Token::Identifier(col) => {
2059 let column = col.clone();
2060 self.advance();
2061 column
2062 }
2063 Token::QuotedIdentifier(col) => {
2064 let column = col.clone();
2065 self.advance();
2066 column
2067 }
2068 Token::Star => {
2069 if function == "COUNT" {
2071 self.advance();
2072 "*".to_string()
2073 } else {
2074 return Err(format!("Only COUNT can use *, not {}", function));
2075 }
2076 }
2077 _ => return Err("Expected column name in aggregate function".to_string()),
2078 };
2079
2080 self.consume(Token::RightParen)?;
2082
2083 Ok(PivotAggregate { function, column })
2084 }
2085
2086 fn parse_pivot_in_clause(&mut self) -> Result<Vec<String>, String> {
2090 self.consume(Token::LeftParen)?;
2092
2093 let mut values = Vec::new();
2094
2095 match &self.current_token {
2097 Token::StringLiteral(val) => {
2098 values.push(val.clone());
2099 self.advance();
2100 }
2101 Token::Identifier(val) => {
2102 values.push(val.clone());
2104 self.advance();
2105 }
2106 Token::NumberLiteral(val) => {
2107 values.push(val.clone());
2109 self.advance();
2110 }
2111 _ => return Err("Expected value in PIVOT IN clause".to_string()),
2112 }
2113
2114 while matches!(self.current_token, Token::Comma) {
2116 self.advance(); match &self.current_token {
2119 Token::StringLiteral(val) => {
2120 values.push(val.clone());
2121 self.advance();
2122 }
2123 Token::Identifier(val) => {
2124 values.push(val.clone());
2125 self.advance();
2126 }
2127 Token::NumberLiteral(val) => {
2128 values.push(val.clone());
2129 self.advance();
2130 }
2131 _ => return Err("Expected value after comma in PIVOT IN clause".to_string()),
2132 }
2133 }
2134
2135 self.consume(Token::RightParen)?;
2137
2138 if values.is_empty() {
2139 return Err("PIVOT IN clause must have at least one value".to_string());
2140 }
2141
2142 Ok(values)
2143 }
2144}
2145
2146#[derive(Debug, Clone)]
2148pub enum CursorContext {
2149 SelectClause,
2150 FromClause,
2151 WhereClause,
2152 OrderByClause,
2153 AfterColumn(String),
2154 AfterLogicalOp(LogicalOp),
2155 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
2158 Unknown,
2159}
2160
/// Returns the prefix of `s` ending at byte offset `pos`.
///
/// If `pos` falls inside a multi-byte character, the cut moves back to that
/// character's start, so the slice never panics. A `pos` at or past the end
/// yields all of `s`.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    let end = (0..=pos.min(s.len()))
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}
2175
/// Returns the suffix of `s` starting at byte offset `pos`.
///
/// If `pos` falls inside a multi-byte character, the cut moves forward to the
/// next character start, so the slice never panics. A `pos` at or past the end
/// yields `""`.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    (pos..s.len())
        .find(|&i| s.is_char_boundary(i))
        .map_or("", |start| &s[start..])
}
2190
2191#[must_use]
2192pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2193 let truncated = safe_slice_to(query, cursor_pos);
2194 let mut parser = Parser::new(truncated);
2195
2196 if let Ok(stmt) = parser.parse() {
2198 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
2199 #[cfg(test)]
2200 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2201 (ctx, partial)
2202 } else {
2203 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
2205 #[cfg(test)]
2206 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2207 (ctx, partial)
2208 }
2209}
2210
2211#[must_use]
2212pub fn tokenize_query(query: &str) -> Vec<String> {
2213 let mut lexer = Lexer::new(query);
2214 let tokens = lexer.tokenize_all();
2215 tokens.iter().map(|t| format!("{t:?}")).collect()
2216}
2217
/// Scans backwards from just before byte index `pos` for a double quote that
/// is not preceded by a backslash, and returns its index.
///
/// Callers pass the index of a closing quote to find the opening quote of a
/// quoted identifier. Returns `None` when `pos` is 0 or no such quote exists.
#[must_use]
fn find_quote_start(bytes: &[u8], pos: usize) -> Option<usize> {
    (0..pos)
        .rev()
        .find(|&i| bytes[i] == b'"' && (i == 0 || bytes[i - 1] != b'\\'))
}
2240
2241fn handle_method_call_context(col_name: &str, after_dot: &str) -> (CursorContext, Option<String>) {
2243 let partial_method = if after_dot.is_empty() {
2245 None
2246 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2247 Some(after_dot.to_string())
2248 } else {
2249 None
2250 };
2251
2252 let col_name_for_context =
2254 if col_name.starts_with('"') && col_name.ends_with('"') && col_name.len() > 2 {
2255 col_name[1..col_name.len() - 1].to_string()
2256 } else {
2257 col_name.to_string()
2258 };
2259
2260 (
2261 CursorContext::AfterColumn(col_name_for_context),
2262 partial_method,
2263 )
2264}
2265
2266fn check_after_comparison_operator(query: &str) -> Option<(CursorContext, Option<String>)> {
2268 for op in &Parser::COMPARISON_OPERATORS {
2269 if let Some(op_pos) = query.rfind(op) {
2270 let before_op = safe_slice_to(query, op_pos);
2271 let after_op_start = op_pos + op.len();
2272 let after_op = if after_op_start < query.len() {
2273 &query[after_op_start..]
2274 } else {
2275 ""
2276 };
2277
2278 if let Some(col_name) = before_op.split_whitespace().last() {
2280 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2281 let after_op_trimmed = after_op.trim();
2283 if after_op_trimmed.is_empty()
2284 || (after_op_trimmed
2285 .chars()
2286 .all(|c| c.is_alphanumeric() || c == '_')
2287 && !after_op_trimmed.contains('('))
2288 {
2289 let partial = if after_op_trimmed.is_empty() {
2290 None
2291 } else {
2292 Some(after_op_trimmed.to_string())
2293 };
2294 return Some((
2295 CursorContext::AfterComparisonOp(
2296 col_name.to_string(),
2297 op.trim().to_string(),
2298 ),
2299 partial,
2300 ));
2301 }
2302 }
2303 }
2304 }
2305 }
2306 None
2307}
2308
2309fn analyze_statement(
2310 stmt: &SelectStatement,
2311 query: &str,
2312 _cursor_pos: usize,
2313) -> (CursorContext, Option<String>) {
2314 let trimmed = query.trim();
2316
2317 if let Some(result) = check_after_comparison_operator(query) {
2319 return result;
2320 }
2321
2322 let ends_with_logical_op = |s: &str| -> bool {
2325 let s_upper = s.to_uppercase();
2326 s_upper.ends_with(" AND") || s_upper.ends_with(" OR")
2327 };
2328
2329 if ends_with_logical_op(trimmed) {
2330 } else {
2332 if let Some(dot_pos) = trimmed.rfind('.') {
2334 let before_dot = safe_slice_to(trimmed, dot_pos);
2336 let after_dot_start = dot_pos + 1;
2337 let after_dot = if after_dot_start < trimmed.len() {
2338 &trimmed[after_dot_start..]
2339 } else {
2340 ""
2341 };
2342
2343 if !after_dot.contains('(') {
2346 let col_name = if before_dot.ends_with('"') {
2348 let bytes = before_dot.as_bytes();
2350 let pos = before_dot.len() - 1; find_quote_start(bytes, pos).map(|start| safe_slice_from(before_dot, start))
2353 } else {
2354 before_dot
2357 .split_whitespace()
2358 .last()
2359 .map(|word| word.trim_start_matches('('))
2360 };
2361
2362 if let Some(col_name) = col_name {
2363 let is_valid = Parser::is_valid_identifier(col_name);
2365
2366 if is_valid {
2367 return handle_method_call_context(col_name, after_dot);
2368 }
2369 }
2370 }
2371 }
2372 }
2373
2374 if let Some(where_clause) = &stmt.where_clause {
2376 let trimmed_upper = trimmed.to_uppercase();
2378 if trimmed_upper.ends_with(" AND") || trimmed_upper.ends_with(" OR") {
2379 let op = if trimmed_upper.ends_with(" AND") {
2380 LogicalOp::And
2381 } else {
2382 LogicalOp::Or
2383 };
2384 return (CursorContext::AfterLogicalOp(op), None);
2385 }
2386
2387 let query_upper = query.to_uppercase();
2389 if let Some(and_pos) = query_upper.rfind(" AND ") {
2390 let after_and = safe_slice_from(query, and_pos + 5);
2391 let partial = extract_partial_at_end(after_and);
2392 if partial.is_some() {
2393 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2394 }
2395 }
2396
2397 if let Some(or_pos) = query_upper.rfind(" OR ") {
2398 let after_or = safe_slice_from(query, or_pos + 4);
2399 let partial = extract_partial_at_end(after_or);
2400 if partial.is_some() {
2401 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2402 }
2403 }
2404
2405 if let Some(last_condition) = where_clause.conditions.last() {
2406 if let Some(connector) = &last_condition.connector {
2407 return (
2409 CursorContext::AfterLogicalOp(connector.clone()),
2410 extract_partial_at_end(query),
2411 );
2412 }
2413 }
2414 return (CursorContext::WhereClause, extract_partial_at_end(query));
2416 }
2417
2418 let query_upper = query.to_uppercase();
2420 if query_upper.ends_with(" ORDER BY") {
2421 return (CursorContext::OrderByClause, None);
2422 }
2423
2424 if stmt.order_by.is_some() {
2426 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2427 }
2428
2429 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2430 return (CursorContext::FromClause, extract_partial_at_end(query));
2431 }
2432
2433 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
2434 return (CursorContext::SelectClause, extract_partial_at_end(query));
2435 }
2436
2437 (CursorContext::Unknown, None)
2438}
2439
2440fn find_last_token(tokens: &[(usize, usize, Token)], target: &Token) -> Option<usize> {
2442 tokens
2443 .iter()
2444 .rposition(|(_, _, t)| t == target)
2445 .map(|idx| tokens[idx].0)
2446}
2447
2448fn find_last_matching_token<F>(
2450 tokens: &[(usize, usize, Token)],
2451 predicate: F,
2452) -> Option<(usize, &Token)>
2453where
2454 F: Fn(&Token) -> bool,
2455{
2456 tokens
2457 .iter()
2458 .rposition(|(_, _, t)| predicate(t))
2459 .map(|idx| (tokens[idx].0, &tokens[idx].2))
2460}
2461
2462fn is_in_clause(
2464 tokens: &[(usize, usize, Token)],
2465 clause_token: Token,
2466 exclude_tokens: &[Token],
2467) -> bool {
2468 if let Some(clause_pos) = find_last_token(tokens, &clause_token) {
2470 for (pos, _, token) in tokens.iter() {
2472 if *pos > clause_pos && exclude_tokens.contains(token) {
2473 return false;
2474 }
2475 }
2476 return true;
2477 }
2478 false
2479}
2480
2481fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2482 let mut lexer = Lexer::new(query);
2484 let tokens = lexer.tokenize_all_with_positions();
2485
2486 let trimmed = query.trim();
2487
2488 #[cfg(test)]
2489 {
2490 if trimmed.contains("\"Last Name\"") {
2491 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
2492 }
2493 }
2494
2495 if let Some(result) = check_after_comparison_operator(query) {
2497 return result;
2498 }
2499
2500 if let Some(dot_pos) = trimmed.rfind('.') {
2503 #[cfg(test)]
2504 {
2505 if trimmed.contains("\"Last Name\"") {
2506 eprintln!("DEBUG: Found dot at position {dot_pos}");
2507 }
2508 }
2509 let before_dot = &trimmed[..dot_pos];
2511 let after_dot = &trimmed[dot_pos + 1..];
2512
2513 if !after_dot.contains('(') {
2516 let col_name = if before_dot.ends_with('"') {
2519 let bytes = before_dot.as_bytes();
2521 let pos = before_dot.len() - 1; #[cfg(test)]
2524 {
2525 if trimmed.contains("\"Last Name\"") {
2526 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
2527 }
2528 }
2529
2530 let found_start = find_quote_start(bytes, pos);
2531
2532 if let Some(start) = found_start {
2533 let result = safe_slice_from(before_dot, start);
2535 #[cfg(test)]
2536 {
2537 if trimmed.contains("\"Last Name\"") {
2538 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
2539 }
2540 }
2541 Some(result)
2542 } else {
2543 #[cfg(test)]
2544 {
2545 if trimmed.contains("\"Last Name\"") {
2546 eprintln!("DEBUG: No opening quote found!");
2547 }
2548 }
2549 None
2550 }
2551 } else {
2552 before_dot
2555 .split_whitespace()
2556 .last()
2557 .map(|word| word.trim_start_matches('('))
2558 };
2559
2560 if let Some(col_name) = col_name {
2561 #[cfg(test)]
2562 {
2563 if trimmed.contains("\"Last Name\"") {
2564 eprintln!("DEBUG: col_name = '{col_name}'");
2565 }
2566 }
2567
2568 let is_valid = Parser::is_valid_identifier(col_name);
2570
2571 #[cfg(test)]
2572 {
2573 if trimmed.contains("\"Last Name\"") {
2574 eprintln!("DEBUG: is_valid = {is_valid}");
2575 }
2576 }
2577
2578 if is_valid {
2579 return handle_method_call_context(col_name, after_dot);
2580 }
2581 }
2582 }
2583 }
2584
2585 if let Some((pos, token)) =
2587 find_last_matching_token(&tokens, |t| matches!(t, Token::And | Token::Or))
2588 {
2589 let token_end_pos = if matches!(token, Token::And) {
2591 pos + 3 } else {
2593 pos + 2 };
2595
2596 if cursor_pos > token_end_pos {
2597 let after_op = safe_slice_from(query, token_end_pos + 1); let partial = extract_partial_at_end(after_op);
2600 let op = if matches!(token, Token::And) {
2601 LogicalOp::And
2602 } else {
2603 LogicalOp::Or
2604 };
2605 return (CursorContext::AfterLogicalOp(op), partial);
2606 }
2607 }
2608
2609 if let Some((_, _, last_token)) = tokens.last() {
2611 if matches!(last_token, Token::And | Token::Or) {
2612 let op = if matches!(last_token, Token::And) {
2613 LogicalOp::And
2614 } else {
2615 LogicalOp::Or
2616 };
2617 return (CursorContext::AfterLogicalOp(op), None);
2618 }
2619 }
2620
2621 if let Some(order_pos) = find_last_token(&tokens, &Token::OrderBy) {
2623 let has_by = tokens
2625 .iter()
2626 .any(|(pos, _, t)| *pos > order_pos && matches!(t, Token::By));
2627 if has_by
2628 || tokens
2629 .last()
2630 .map_or(false, |(_, _, t)| matches!(t, Token::OrderBy))
2631 {
2632 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2633 }
2634 }
2635
2636 if is_in_clause(&tokens, Token::Where, &[Token::OrderBy, Token::GroupBy]) {
2638 return (CursorContext::WhereClause, extract_partial_at_end(query));
2639 }
2640
2641 if is_in_clause(
2643 &tokens,
2644 Token::From,
2645 &[Token::Where, Token::OrderBy, Token::GroupBy],
2646 ) {
2647 return (CursorContext::FromClause, extract_partial_at_end(query));
2648 }
2649
2650 if find_last_token(&tokens, &Token::Select).is_some()
2652 && find_last_token(&tokens, &Token::From).is_none()
2653 {
2654 return (CursorContext::SelectClause, extract_partial_at_end(query));
2655 }
2656
2657 (CursorContext::Unknown, None)
2658}
2659
2660fn extract_partial_at_end(query: &str) -> Option<String> {
2661 let trimmed = query.trim();
2662
2663 if let Some(last_word) = trimmed.split_whitespace().last() {
2665 if last_word.starts_with('"') && !last_word.ends_with('"') {
2666 return Some(last_word.to_string());
2668 }
2669 }
2670
2671 let last_word = trimmed.split_whitespace().last()?;
2673
2674 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') {
2677 if !is_sql_keyword(last_word) {
2679 Some(last_word.to_string())
2680 } else {
2681 None
2682 }
2683 } else {
2684 None
2685 }
2686}
2687
2688impl ParsePrimary for Parser {
2690 fn current_token(&self) -> &Token {
2691 &self.current_token
2692 }
2693
2694 fn advance(&mut self) {
2695 self.advance();
2696 }
2697
2698 fn consume(&mut self, expected: Token) -> Result<(), String> {
2699 self.consume(expected)
2700 }
2701
2702 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
2703 self.parse_case_expression()
2704 }
2705
2706 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
2707 self.parse_function_args()
2708 }
2709
2710 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
2711 self.parse_window_spec()
2712 }
2713
2714 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
2715 self.parse_logical_or()
2716 }
2717
2718 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
2719 self.parse_comparison()
2720 }
2721
2722 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2723 self.parse_expression_list()
2724 }
2725
2726 fn parse_subquery(&mut self) -> Result<SelectStatement, String> {
2727 if matches!(self.current_token, Token::With) {
2729 self.parse_with_clause_inner()
2730 } else {
2731 self.parse_select_statement_inner()
2732 }
2733 }
2734}
2735
2736impl ExpressionParser for Parser {
2738 fn current_token(&self) -> &Token {
2739 &self.current_token
2740 }
2741
2742 fn advance(&mut self) {
2743 match &self.current_token {
2745 Token::LeftParen => self.paren_depth += 1,
2746 Token::RightParen => {
2747 self.paren_depth -= 1;
2748 }
2749 _ => {}
2750 }
2751 self.current_token = self.lexer.next_token();
2752 }
2753
2754 fn peek(&self) -> Option<&Token> {
2755 None }
2762
2763 fn is_at_end(&self) -> bool {
2764 matches!(self.current_token, Token::Eof)
2765 }
2766
2767 fn consume(&mut self, expected: Token) -> Result<(), String> {
2768 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
2770 self.update_paren_depth(&expected)?;
2771 self.current_token = self.lexer.next_token();
2772 Ok(())
2773 } else {
2774 Err(format!(
2775 "Expected {:?}, found {:?}",
2776 expected, self.current_token
2777 ))
2778 }
2779 }
2780
2781 fn parse_identifier(&mut self) -> Result<String, String> {
2782 if let Token::Identifier(id) = &self.current_token {
2783 let id = id.clone();
2784 self.advance();
2785 Ok(id)
2786 } else {
2787 Err(format!(
2788 "Expected identifier, found {:?}",
2789 self.current_token
2790 ))
2791 }
2792 }
2793}
2794
2795impl ParseArithmetic for Parser {
2797 fn current_token(&self) -> &Token {
2798 &self.current_token
2799 }
2800
2801 fn advance(&mut self) {
2802 self.advance();
2803 }
2804
2805 fn consume(&mut self, expected: Token) -> Result<(), String> {
2806 self.consume(expected)
2807 }
2808
2809 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
2810 self.parse_primary()
2811 }
2812
2813 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
2814 self.parse_multiplicative()
2815 }
2816
2817 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
2818 self.parse_method_args()
2819 }
2820}
2821
2822impl ParseComparison for Parser {
2824 fn current_token(&self) -> &Token {
2825 &self.current_token
2826 }
2827
2828 fn advance(&mut self) {
2829 self.advance();
2830 }
2831
2832 fn consume(&mut self, expected: Token) -> Result<(), String> {
2833 self.consume(expected)
2834 }
2835
2836 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
2837 self.parse_primary()
2838 }
2839
2840 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
2841 self.parse_additive()
2842 }
2843
2844 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2845 self.parse_expression_list()
2846 }
2847
2848 fn parse_subquery(&mut self) -> Result<SelectStatement, String> {
2849 if matches!(self.current_token, Token::With) {
2851 self.parse_with_clause_inner()
2852 } else {
2853 self.parse_select_statement_inner()
2854 }
2855 }
2856}
2857
2858impl ParseLogical for Parser {
2860 fn current_token(&self) -> &Token {
2861 &self.current_token
2862 }
2863
2864 fn advance(&mut self) {
2865 self.advance();
2866 }
2867
2868 fn consume(&mut self, expected: Token) -> Result<(), String> {
2869 self.consume(expected)
2870 }
2871
2872 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
2873 self.parse_logical_and()
2874 }
2875
2876 fn parse_base_logical_expression(&mut self) -> Result<SqlExpression, String> {
2877 self.parse_comparison()
2880 }
2881
2882 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
2883 self.parse_comparison()
2884 }
2885
2886 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2887 self.parse_expression_list()
2888 }
2889}
2890
2891impl ParseCase for Parser {
2893 fn current_token(&self) -> &Token {
2894 &self.current_token
2895 }
2896
2897 fn advance(&mut self) {
2898 self.advance();
2899 }
2900
2901 fn consume(&mut self, expected: Token) -> Result<(), String> {
2902 self.consume(expected)
2903 }
2904
2905 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
2906 self.parse_expression()
2907 }
2908}
2909
2910fn is_sql_keyword(word: &str) -> bool {
2911 let mut lexer = Lexer::new(word);
2913 let token = lexer.next_token();
2914
2915 !matches!(token, Token::Identifier(_) | Token::Eof)
2917}
2918
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parser_mode_default_is_standard() {
        // `Parser::new` must behave like standard mode: comments are discarded.
        let sql = "-- Leading comment\nSELECT * FROM users";
        let mut parser = Parser::new(sql);
        let stmt = parser.parse().expect("default-mode parse should succeed");

        assert!(stmt.leading_comments.is_empty());
        assert!(stmt.trailing_comment.is_none());
    }

    #[test]
    fn test_parser_mode_preserve_leading_comments() {
        let sql = "-- Important query\n-- Author: Alice\nSELECT id, name FROM users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser
            .parse()
            .expect("parse with leading comments should succeed");

        assert_eq!(stmt.leading_comments.len(), 2);
        assert!(stmt.leading_comments[0].is_line_comment);
        assert!(stmt.leading_comments[0].text.contains("Important query"));
        assert!(stmt.leading_comments[1].text.contains("Author: Alice"));
    }

    #[test]
    fn test_parser_mode_preserve_trailing_comment() {
        let sql = "SELECT * FROM users -- Fetch all users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser
            .parse()
            .expect("parse with trailing comment should succeed");

        let comment = stmt
            .trailing_comment
            .expect("trailing comment should be preserved");
        assert!(comment.is_line_comment);
        assert!(comment.text.contains("Fetch all users"));
    }

    #[test]
    fn test_parser_mode_preserve_block_comments() {
        let sql = "/* Query explanation */\nSELECT * FROM users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser
            .parse()
            .expect("parse with block comment should succeed");

        assert_eq!(stmt.leading_comments.len(), 1);
        // Block comments must be flagged as such, not as `--` line comments.
        assert!(!stmt.leading_comments[0].is_line_comment);
        assert!(stmt.leading_comments[0].text.contains("Query explanation"));
    }

    #[test]
    fn test_parser_mode_preserve_both_comments() {
        let sql = "-- Leading\nSELECT * FROM users -- Trailing";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser
            .parse()
            .expect("parse with leading and trailing comments should succeed");

        assert_eq!(stmt.leading_comments.len(), 1);
        assert!(stmt.leading_comments[0].text.contains("Leading"));
        let trailing = stmt
            .trailing_comment
            .expect("trailing comment should be preserved");
        assert!(trailing.text.contains("Trailing"));
    }

    #[test]
    fn test_parser_mode_standard_ignores_comments() {
        let sql = "-- Comment 1\n/* Comment 2 */\nSELECT * FROM users -- Comment 3";
        let mut parser = Parser::with_mode(sql, ParserMode::Standard);
        let stmt = parser
            .parse()
            .expect("standard-mode parse should succeed");

        assert!(stmt.leading_comments.is_empty());
        assert!(stmt.trailing_comment.is_none());

        // Discarding comments must not disturb the statement itself.
        assert_eq!(stmt.select_items.len(), 1);
        assert_eq!(stmt.from_table, Some("users".to_string()));
    }

    #[test]
    fn test_parser_backward_compatibility() {
        // `Parser::new` and an explicit `ParserMode::Standard` must agree.
        let sql = "SELECT id, name FROM users WHERE active = true";

        let mut parser1 = Parser::new(sql);
        let stmt1 = parser1.parse().expect("Parser::new parse should succeed");

        let mut parser2 = Parser::with_mode(sql, ParserMode::Standard);
        let stmt2 = parser2
            .parse()
            .expect("Standard-mode parse should succeed");

        assert_eq!(stmt1.select_items.len(), stmt2.select_items.len());
        assert_eq!(stmt1.from_table, stmt2.from_table);
        assert_eq!(stmt1.where_clause.is_some(), stmt2.where_clause.is_some());
        assert!(stmt1.leading_comments.is_empty());
        assert!(stmt2.leading_comments.is_empty());
    }

    /// PIVOT is supported: the FROM source must come back as a
    /// `TableSource::Pivot`. This was formerly misnamed
    /// `test_pivot_parsing_not_yet_supported`.
    #[test]
    fn test_pivot_parses_into_pivot_source() {
        let sql = "SELECT * FROM food_eaten PIVOT (MAX(AmountEaten) FOR FoodName IN ('Sammich', 'Pickle', 'Apple'))";
        let mut parser = Parser::new(sql);
        let stmt = parser.parse().expect("PIVOT query should parse");

        assert!(
            matches!(
                stmt.from_source,
                Some(crate::sql::parser::ast::TableSource::Pivot { .. })
            ),
            "Expected from_source to be a Pivot variant"
        );
    }

    #[test]
    fn test_pivot_aggregate_functions() {
        let cases = [
            "SELECT * FROM sales PIVOT (SUM(amount) FOR month IN ('Jan', 'Feb', 'Mar'))",
            "SELECT * FROM sales PIVOT (COUNT(*) FOR month IN ('Jan', 'Feb'))",
            "SELECT * FROM sales PIVOT (AVG(price) FOR category IN ('A', 'B'))",
        ];

        for sql in cases.iter().copied() {
            let mut parser = Parser::new(sql);
            // Report which query failed, and why, instead of a bare `is_ok()` failure.
            if let Err(e) = parser.parse() {
                panic!("failed to parse {:?}: {:?}", sql, e);
            }
        }
    }

    #[test]
    fn test_pivot_with_subquery() {
        let sql = "SELECT * FROM (SELECT * FROM food_eaten WHERE Id > 5) AS t \
                   PIVOT (MAX(AmountEaten) FOR FoodName IN ('Sammich', 'Pickle'))";
        let mut parser = Parser::new(sql);
        let stmt = parser
            .parse()
            .expect("PIVOT over a derived table should parse");

        assert!(stmt.from_source.is_some());
    }

    #[test]
    fn test_pivot_with_alias() {
        let sql =
            "SELECT * FROM sales PIVOT (SUM(amount) FOR month IN ('Jan', 'Feb')) AS pivot_table";
        let mut parser = Parser::new(sql);
        let stmt = parser.parse().expect("aliased PIVOT should parse");

        assert!(stmt.from_source.is_some());
    }
}