1pub use super::parser::ast::{
5 CTEType, Comment, Condition, DataFormat, FileCTESpec, FrameBound, FrameUnit, HttpMethod,
6 IntoTable, JoinClause, JoinCondition, JoinOperator, JoinType, LogicalOp, OrderByColumn,
7 OrderByItem, PivotAggregate, SelectItem, SelectStatement, SetOperation, SingleJoinCondition,
8 SortDirection, SqlExpression, TableFunction, TableSource, WebCTESpec, WhenBranch, WhereClause,
9 WindowFrame, WindowSpec, CTE,
10};
11pub use super::parser::legacy::{ParseContext, ParseState, Schema, SqlParser, SqlToken, TableInfo};
12pub use super::parser::lexer::{Lexer, LexerMode, Token};
13pub use super::parser::ParserConfig;
14
15pub use super::parser::formatter::{format_ast_tree, format_sql_pretty, format_sql_pretty_compact};
17
18pub use super::parser::ast_formatter::{format_sql_ast, format_sql_ast_with_config, FormatConfig};
20
21use super::parser::expressions::arithmetic::{
23 parse_additive as parse_additive_expr, parse_multiplicative as parse_multiplicative_expr,
24 ParseArithmetic,
25};
26use super::parser::expressions::case::{parse_case_expression as parse_case_expr, ParseCase};
27use super::parser::expressions::comparison::{
28 parse_comparison as parse_comparison_expr, parse_in_operator, ParseComparison,
29};
30use super::parser::expressions::logical::{
31 parse_logical_and as parse_logical_and_expr, parse_logical_or as parse_logical_or_expr,
32 ParseLogical,
33};
34use super::parser::expressions::primary::{
35 parse_primary as parse_primary_expr, ParsePrimary, PrimaryExpressionContext,
36};
37use super::parser::expressions::ExpressionParser;
38
39use crate::sql::functions::{FunctionCategory, FunctionRegistry};
41use crate::sql::generators::GeneratorRegistry;
42use std::sync::Arc;
43
44use super::parser::file_cte_parser::FileCteParser;
46use super::parser::web_cte_parser::WebCteParser;
47
/// Controls how the parser treats SQL comments in its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParserMode {
    /// Comments are skipped by the lexer and never reach the parser.
    Standard,
    /// Comments are kept as tokens so they can be attached to the AST.
    PreserveComments,
}
56
57impl Default for ParserMode {
58 fn default() -> Self {
59 ParserMode::Standard
60 }
61}
62
63pub struct Parser {
64 lexer: Lexer,
65 pub current_token: Token, in_method_args: bool, columns: Vec<String>, paren_depth: i32, paren_depth_stack: Vec<i32>, _config: ParserConfig, debug_trace: bool, trace_depth: usize, function_registry: Arc<FunctionRegistry>, generator_registry: Arc<GeneratorRegistry>, mode: ParserMode, }
77
78impl Parser {
79 #[must_use]
80 pub fn new(input: &str) -> Self {
81 Self::with_mode(input, ParserMode::default())
82 }
83
84 #[must_use]
86 pub fn with_mode(input: &str, mode: ParserMode) -> Self {
87 let lexer_mode = match mode {
89 ParserMode::Standard => LexerMode::SkipComments,
90 ParserMode::PreserveComments => LexerMode::PreserveComments,
91 };
92
93 let mut lexer = Lexer::with_mode(input, lexer_mode);
94 let current_token = lexer.next_token();
95 Self {
96 lexer,
97 current_token,
98 in_method_args: false,
99 columns: Vec::new(),
100 paren_depth: 0,
101 paren_depth_stack: Vec::new(),
102 _config: ParserConfig::default(),
103 debug_trace: false,
104 trace_depth: 0,
105 function_registry: Arc::new(FunctionRegistry::new()),
106 generator_registry: Arc::new(GeneratorRegistry::new()),
107 mode,
108 }
109 }
110
111 #[must_use]
112 pub fn with_config(input: &str, config: ParserConfig) -> Self {
113 let mut lexer = Lexer::new(input);
114 let current_token = lexer.next_token();
115 Self {
116 lexer,
117 current_token,
118 in_method_args: false,
119 columns: Vec::new(),
120 paren_depth: 0,
121 paren_depth_stack: Vec::new(),
122 _config: config,
123 debug_trace: false,
124 trace_depth: 0,
125 function_registry: Arc::new(FunctionRegistry::new()),
126 generator_registry: Arc::new(GeneratorRegistry::new()),
127 mode: ParserMode::default(),
128 }
129 }
130
131 #[must_use]
132 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
133 self.columns = columns;
134 self
135 }
136
137 #[must_use]
138 pub fn with_debug_trace(mut self, enabled: bool) -> Self {
139 self.debug_trace = enabled;
140 self
141 }
142
143 #[must_use]
144 pub fn with_function_registry(mut self, registry: Arc<FunctionRegistry>) -> Self {
145 self.function_registry = registry;
146 self
147 }
148
149 #[must_use]
150 pub fn with_generator_registry(mut self, registry: Arc<GeneratorRegistry>) -> Self {
151 self.generator_registry = registry;
152 self
153 }
154
155 fn trace_enter(&mut self, context: &str) {
156 if self.debug_trace {
157 let indent = " ".repeat(self.trace_depth);
158 eprintln!("{}→ {} | Token: {:?}", indent, context, self.current_token);
159 self.trace_depth += 1;
160 }
161 }
162
163 fn trace_exit(&mut self, context: &str, result: &Result<impl std::fmt::Debug, String>) {
164 if self.debug_trace {
165 self.trace_depth = self.trace_depth.saturating_sub(1);
166 let indent = " ".repeat(self.trace_depth);
167 match result {
168 Ok(val) => eprintln!("{}← {} ✓ | Result: {:?}", indent, context, val),
169 Err(e) => eprintln!("{}← {} ✗ | Error: {}", indent, context, e),
170 }
171 }
172 }
173
174 fn trace_token(&self, action: &str) {
175 if self.debug_trace {
176 let indent = " ".repeat(self.trace_depth);
177 eprintln!("{} {} | Token: {:?}", indent, action, self.current_token);
178 }
179 }
180
181 #[allow(dead_code)]
182 fn peek_token(&self) -> Option<Token> {
183 let mut temp_lexer = self.lexer.clone();
185 let next_token = temp_lexer.next_token();
186 if matches!(next_token, Token::Eof) {
187 None
188 } else {
189 Some(next_token)
190 }
191 }
192
193 fn is_identifier_reserved(id: &str) -> bool {
198 let id_upper = id.to_uppercase();
199 matches!(
200 id_upper.as_str(),
201 "ORDER" | "HAVING" | "LIMIT" | "OFFSET" | "UNION" | "INTERSECT" | "EXCEPT"
202 )
203 }
204
205 const COMPARISON_OPERATORS: [&'static str; 6] = [" > ", " < ", " >= ", " <= ", " = ", " != "];
207
208 pub fn consume(&mut self, expected: Token) -> Result<(), String> {
209 self.trace_token(&format!("Consuming expected {:?}", expected));
210 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
211 self.update_paren_depth(&expected)?;
213
214 self.current_token = self.lexer.next_token();
215 Ok(())
216 } else {
217 let error_msg = match (&expected, &self.current_token) {
219 (Token::RightParen, Token::Eof) if self.paren_depth > 0 => {
220 format!(
221 "Unclosed parenthesis - missing {} closing parenthes{}",
222 self.paren_depth,
223 if self.paren_depth == 1 { "is" } else { "es" }
224 )
225 }
226 (Token::RightParen, _) if self.paren_depth > 0 => {
227 format!(
228 "Expected closing parenthesis but found {:?} (currently {} unclosed parenthes{})",
229 self.current_token,
230 self.paren_depth,
231 if self.paren_depth == 1 { "is" } else { "es" }
232 )
233 }
234 _ => format!("Expected {:?}, found {:?}", expected, self.current_token),
235 };
236 Err(error_msg)
237 }
238 }
239
240 pub fn advance(&mut self) {
241 match &self.current_token {
243 Token::LeftParen => self.paren_depth += 1,
244 Token::RightParen => {
245 self.paren_depth -= 1;
246 }
249 _ => {}
250 }
251 let old_token = self.current_token.clone();
252 self.current_token = self.lexer.next_token();
253 if self.debug_trace {
254 let indent = " ".repeat(self.trace_depth);
255 eprintln!(
256 "{} Advanced: {:?} → {:?}",
257 indent, old_token, self.current_token
258 );
259 }
260 }
261
262 fn collect_leading_comments(&mut self) -> Vec<Comment> {
265 let mut comments = Vec::new();
266 loop {
267 match &self.current_token {
268 Token::LineComment(text) => {
269 comments.push(Comment::line(text.clone()));
270 self.advance();
271 }
272 Token::BlockComment(text) => {
273 comments.push(Comment::block(text.clone()));
274 self.advance();
275 }
276 _ => break,
277 }
278 }
279 comments
280 }
281
282 fn collect_trailing_comment(&mut self) -> Option<Comment> {
285 match &self.current_token {
286 Token::LineComment(text) => {
287 let comment = Some(Comment::line(text.clone()));
288 self.advance();
289 comment
290 }
291 Token::BlockComment(text) => {
292 let comment = Some(Comment::block(text.clone()));
293 self.advance();
294 comment
295 }
296 _ => None,
297 }
298 }
299
300 fn push_paren_depth(&mut self) {
301 self.paren_depth_stack.push(self.paren_depth);
302 self.paren_depth = 0;
303 }
304
305 fn pop_paren_depth(&mut self) {
306 if let Some(depth) = self.paren_depth_stack.pop() {
307 self.paren_depth = depth;
309 }
310 }
311
312 pub fn parse(&mut self) -> Result<SelectStatement, String> {
313 self.trace_enter("parse");
314
315 let leading_comments = if self.mode == ParserMode::PreserveComments {
318 self.collect_leading_comments()
319 } else {
320 vec![]
321 };
322
323 let result = if matches!(self.current_token, Token::With) {
325 let mut stmt = self.parse_with_clause()?;
326 stmt.leading_comments = leading_comments;
328 stmt
329 } else {
330 let stmt = self.parse_select_statement_with_comments_public(leading_comments)?;
332 self.check_balanced_parentheses()?;
333 stmt
334 };
335
336 self.trace_exit("parse", &Ok(&result));
337 Ok(result)
338 }
339
340 fn parse_select_statement_with_comments_public(
342 &mut self,
343 comments: Vec<Comment>,
344 ) -> Result<SelectStatement, String> {
345 self.parse_select_statement_with_comments(comments)
346 }
347
348 fn parse_with_clause(&mut self) -> Result<SelectStatement, String> {
349 self.consume(Token::With)?;
350 let ctes = self.parse_cte_list()?;
351
352 let mut main_query = self.parse_select_statement_inner_no_comments()?;
354 main_query.ctes = ctes;
355
356 self.check_balanced_parentheses()?;
358
359 Ok(main_query)
360 }
361
362 fn parse_with_clause_inner(&mut self) -> Result<SelectStatement, String> {
363 self.consume(Token::With)?;
364 let ctes = self.parse_cte_list()?;
365
366 let mut main_query = self.parse_select_statement_inner()?;
368 main_query.ctes = ctes;
369
370 Ok(main_query)
371 }
372
373 fn parse_cte_list(&mut self) -> Result<Vec<CTE>, String> {
375 let mut ctes = Vec::new();
376
377 loop {
379 let is_web = if matches!(&self.current_token, Token::Web) {
383 self.trace_token("Found WEB keyword for CTE");
384 self.advance();
385 true
386 } else {
387 false
388 };
389
390 let name = match &self.current_token {
392 Token::Identifier(name) => name.clone(),
393 token => {
394 if let Some(keyword) = token.as_keyword_str() {
396 keyword.to_lowercase()
398 } else {
399 return Err(format!(
400 "Expected CTE name after {}",
401 if is_web { "WEB" } else { "WITH or comma" }
402 ));
403 }
404 }
405 };
406 self.advance();
407
408 let column_list = if matches!(self.current_token, Token::LeftParen) {
410 self.advance();
411 let cols = self.parse_identifier_list()?;
412 self.consume(Token::RightParen)?;
413 Some(cols)
414 } else {
415 None
416 };
417
418 self.consume(Token::As)?;
420
421 let cte_type = if is_web {
422 self.consume(Token::LeftParen)?;
424 let web_spec = WebCteParser::parse(self)?;
426 self.consume(Token::RightParen)?;
428 CTEType::Web(web_spec)
429 } else {
430 self.push_paren_depth();
434 self.consume(Token::LeftParen)?;
435
436 let result = if matches!(&self.current_token, Token::File) {
437 self.trace_token("Found FILE keyword inside CTE parens");
438 self.advance();
439 let file_spec = FileCteParser::parse(self)?;
440 CTEType::File(file_spec)
441 } else {
442 let query = self.parse_select_statement_inner()?;
443 CTEType::Standard(query)
444 };
445
446 self.consume(Token::RightParen)?;
448 self.pop_paren_depth();
450 result
451 };
452
453 ctes.push(CTE {
454 name,
455 column_list,
456 cte_type,
457 });
458
459 if !matches!(self.current_token, Token::Comma) {
461 break;
462 }
463 self.advance();
464 }
465
466 Ok(ctes)
467 }
468
469 fn parse_optional_alias(&mut self) -> Result<Option<String>, String> {
471 if matches!(self.current_token, Token::As) {
472 self.advance();
473 match &self.current_token {
474 Token::Identifier(name) => {
475 let alias = name.clone();
476 self.advance();
477 Ok(Some(alias))
478 }
479 token => {
480 if let Some(keyword) = token.as_keyword_str() {
482 Err(format!(
483 "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
484 keyword,
485 keyword.to_lowercase()
486 ))
487 } else {
488 Err("Expected alias name after AS".to_string())
489 }
490 }
491 }
492 } else if let Token::Identifier(name) = &self.current_token {
493 let alias = name.clone();
495 self.advance();
496 Ok(Some(alias))
497 } else {
498 Ok(None)
499 }
500 }
501
502 fn is_valid_identifier(name: &str) -> bool {
504 if name.starts_with('"') && name.ends_with('"') {
505 true
507 } else {
508 name.chars().all(|c| c.is_alphanumeric() || c == '_')
510 }
511 }
512
513 fn update_paren_depth(&mut self, token: &Token) -> Result<(), String> {
515 match token {
516 Token::LeftParen => self.paren_depth += 1,
517 Token::RightParen => {
518 self.paren_depth -= 1;
519 if self.paren_depth < 0 {
521 return Err(
522 "Unexpected closing parenthesis - no matching opening parenthesis"
523 .to_string(),
524 );
525 }
526 }
527 _ => {}
528 }
529 Ok(())
530 }
531
532 fn parse_argument_list(&mut self) -> Result<Vec<SqlExpression>, String> {
534 let mut args = Vec::new();
535
536 if !matches!(self.current_token, Token::RightParen) {
537 loop {
538 args.push(self.parse_expression()?);
539
540 if matches!(self.current_token, Token::Comma) {
541 self.advance();
542 } else {
543 break;
544 }
545 }
546 }
547
548 Ok(args)
549 }
550
551 fn check_balanced_parentheses(&self) -> Result<(), String> {
553 if self.paren_depth > 0 {
554 Err(format!(
555 "Unclosed parenthesis - missing {} closing parenthes{}",
556 self.paren_depth,
557 if self.paren_depth == 1 { "is" } else { "es" }
558 ))
559 } else if self.paren_depth < 0 {
560 Err("Extra closing parenthesis found - no matching opening parenthesis".to_string())
561 } else {
562 Ok(())
563 }
564 }
565
566 fn contains_aggregate_function(expr: &SqlExpression) -> bool {
569 match expr {
570 SqlExpression::FunctionCall { name, args, .. } => {
571 let upper_name = name.to_uppercase();
573 let is_aggregate = matches!(
574 upper_name.as_str(),
575 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "GROUP_CONCAT" | "STRING_AGG"
576 );
577
578 is_aggregate || args.iter().any(Self::contains_aggregate_function)
581 }
582 SqlExpression::BinaryOp { left, right, .. } => {
584 Self::contains_aggregate_function(left) || Self::contains_aggregate_function(right)
585 }
586 SqlExpression::Not { expr } => Self::contains_aggregate_function(expr),
587 SqlExpression::MethodCall { args, .. } => {
588 args.iter().any(Self::contains_aggregate_function)
589 }
590 SqlExpression::ChainedMethodCall { base, args, .. } => {
591 Self::contains_aggregate_function(base)
592 || args.iter().any(Self::contains_aggregate_function)
593 }
594 SqlExpression::CaseExpression {
595 when_branches,
596 else_branch,
597 } => {
598 when_branches.iter().any(|branch| {
599 Self::contains_aggregate_function(&branch.condition)
600 || Self::contains_aggregate_function(&branch.result)
601 }) || else_branch
602 .as_ref()
603 .map_or(false, |e| Self::contains_aggregate_function(e))
604 }
605 SqlExpression::SimpleCaseExpression {
606 expr,
607 when_branches,
608 else_branch,
609 } => {
610 Self::contains_aggregate_function(expr)
611 || when_branches.iter().any(|branch| {
612 Self::contains_aggregate_function(&branch.value)
613 || Self::contains_aggregate_function(&branch.result)
614 })
615 || else_branch
616 .as_ref()
617 .map_or(false, |e| Self::contains_aggregate_function(e))
618 }
619 SqlExpression::ScalarSubquery { query } => {
620 query
623 .having
624 .as_ref()
625 .map_or(false, |h| Self::contains_aggregate_function(h))
626 }
627 SqlExpression::Column(_)
629 | SqlExpression::StringLiteral(_)
630 | SqlExpression::NumberLiteral(_)
631 | SqlExpression::BooleanLiteral(_)
632 | SqlExpression::Null
633 | SqlExpression::DateTimeConstructor { .. }
634 | SqlExpression::DateTimeToday { .. } => false,
635
636 SqlExpression::WindowFunction { .. } => true,
638
639 SqlExpression::Between { expr, lower, upper } => {
641 Self::contains_aggregate_function(expr)
642 || Self::contains_aggregate_function(lower)
643 || Self::contains_aggregate_function(upper)
644 }
645
646 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
648 Self::contains_aggregate_function(expr)
649 || values.iter().any(Self::contains_aggregate_function)
650 }
651
652 SqlExpression::InSubquery { expr, subquery }
654 | SqlExpression::NotInSubquery { expr, subquery } => {
655 Self::contains_aggregate_function(expr)
656 || subquery
657 .having
658 .as_ref()
659 .map_or(false, |h| Self::contains_aggregate_function(h))
660 }
661
662 SqlExpression::Unnest { column, .. } => Self::contains_aggregate_function(column),
664 }
665 }
666
667 fn parse_select_statement(&mut self) -> Result<SelectStatement, String> {
668 self.trace_enter("parse_select_statement");
669 let result = self.parse_select_statement_inner()?;
670
671 self.check_balanced_parentheses()?;
673
674 Ok(result)
675 }
676
677 fn parse_select_statement_inner(&mut self) -> Result<SelectStatement, String> {
678 let leading_comments = if self.mode == ParserMode::PreserveComments {
680 self.collect_leading_comments()
681 } else {
682 vec![]
683 };
684
685 self.parse_select_statement_with_comments(leading_comments)
686 }
687
688 fn parse_select_statement_inner_no_comments(&mut self) -> Result<SelectStatement, String> {
691 self.parse_select_statement_with_comments(vec![])
692 }
693
    /// Parses a single `SELECT ...` statement starting at the `SELECT` keyword.
    ///
    /// Clauses are recognised in this fixed order:
    /// 1. `DISTINCT` and the select list
    /// 2. `INTO`
    /// 3. `FROM`: plain table, table function/generator call, or aliased subquery
    /// 4. `PIVOT` and joins
    /// 5. `WHERE`, `GROUP BY`, `HAVING`, `QUALIFY`
    /// 6. `ORDER BY`, `LIMIT`, `OFFSET`
    /// 7. a second `INTO` position
    /// 8. set operations, then (in `PreserveComments` mode) one trailing comment
    ///
    /// `leading_comments` is stored on the result unchanged. `ctes` is always
    /// empty here; the WITH-clause callers fill it in.
    ///
    /// # Errors
    /// Returns a message on unexpected tokens, including:
    /// - a FROM subquery without an alias
    /// - `HAVING` without `GROUP BY`
    /// - `ORDER` not followed by `BY`
    /// - `LIMIT`/`OFFSET` values that do not parse as `usize`
    fn parse_select_statement_with_comments(
        &mut self,
        leading_comments: Vec<Comment>,
    ) -> Result<SelectStatement, String> {
        self.consume(Token::Select)?;

        let distinct = if matches!(self.current_token, Token::Distinct) {
            self.advance();
            true
        } else {
            false
        };

        let select_items = self.parse_select_items()?;

        // Flat list of output names: "*" for star items, the column name for
        // plain columns, and the (explicit or generated) alias for expressions.
        let columns = select_items
            .iter()
            .map(|item| match item {
                SelectItem::Star { .. } => "*".to_string(),
                SelectItem::StarExclude { .. } => "*".to_string(),
                SelectItem::Column {
                    column: col_ref, ..
                } => col_ref.name.clone(),
                SelectItem::Expression { alias, .. } => alias.clone(),
            })
            .collect();

        // `SELECT ... INTO target FROM ...` form. The position after OFFSET is
        // also accepted further down.
        let into_table = if matches!(self.current_token, Token::Into) {
            self.advance();
            Some(self.parse_into_clause()?)
        } else {
            None
        };

        let (from_table, from_subquery, from_function, from_alias) = if matches!(
            self.current_token,
            Token::From
        ) {
            self.advance();

            // A plain identifier, or a keyword token (lowercased), may name
            // either a table or a table function.
            let table_or_function_name = match &self.current_token {
                Token::Identifier(name) => Some(name.clone()),
                token => {
                    token.as_keyword_str().map(|k| k.to_lowercase())
                }
            };

            if let Some(name) = table_or_function_name {
                // One token of lookahead: only `name(` can be a function call.
                let has_paren = self.peek_token() == Some(Token::LeftParen);
                if self.debug_trace {
                    eprintln!(
                        " Checking {} for table function, has_paren={}",
                        name, has_paren
                    );
                }

                // `name(...)` is a table function only if the upper-cased name
                // is a registered generator, or a registered function whose
                // category is `TableFunction`. Otherwise it is a plain table.
                let is_table_function = if has_paren {
                    if self.debug_trace {
                        eprintln!(" Checking generator registry for {}", name.to_uppercase());
                    }
                    if let Some(_gen) = self.generator_registry.get(&name.to_uppercase()) {
                        if self.debug_trace {
                            eprintln!(" Found {} in generator registry", name);
                        }
                        self.trace_token(&format!("Found generator: {}", name));
                        true
                    } else {
                        if let Some(func) = self.function_registry.get(&name.to_uppercase()) {
                            let sig = func.signature();
                            let is_table_fn = sig.category == FunctionCategory::TableFunction;
                            if self.debug_trace {
                                eprintln!(
                                    " Found {} in function registry, is_table_function={}",
                                    name, is_table_fn
                                );
                            }
                            if is_table_fn {
                                self.trace_token(&format!(
                                    "Found table function in function registry: {}",
                                    name
                                ));
                            }
                            is_table_fn
                        } else {
                            if self.debug_trace {
                                eprintln!(" {} not found in either registry", name);
                                self.trace_token(&format!(
                                    "Not found as generator or table function: {}",
                                    name
                                ));
                            }
                            false
                        }
                    }
                } else {
                    if self.debug_trace {
                        eprintln!(" No parenthesis after {}, treating as table", name);
                    }
                    false
                };

                if is_table_function {
                    let function_name = name.clone();
                    // Skip the function name, then parse `( args )`.
                    self.advance();
                    self.consume(Token::LeftParen)?;
                    let args = self.parse_argument_list()?;
                    self.consume(Token::RightParen)?;

                    // Optional alias: `AS name` or a bare identifier. A keyword
                    // or other token after AS is an error.
                    let alias = if matches!(self.current_token, Token::As) {
                        self.advance();
                        match &self.current_token {
                            Token::Identifier(name) => {
                                let alias = name.clone();
                                self.advance();
                                Some(alias)
                            }
                            token => {
                                if let Some(keyword) = token.as_keyword_str() {
                                    return Err(format!(
                                        "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
                                        keyword,
                                        keyword.to_lowercase()
                                    ));
                                } else {
                                    return Err("Expected alias name after AS".to_string());
                                }
                            }
                        }
                    } else if let Token::Identifier(name) = &self.current_token {
                        let alias = name.clone();
                        self.advance();
                        Some(alias)
                    } else {
                        None
                    };

                    (
                        None,
                        None,
                        Some(TableFunction::Generator {
                            name: function_name,
                            args,
                        }),
                        alias,
                    )
                } else {
                    let table_name = name.clone();
                    self.advance();

                    let alias = self.parse_optional_alias()?;

                    (Some(table_name), None, None, alias)
                }
            } else if matches!(self.current_token, Token::LeftParen) {
                // Derived table: `( [WITH ...] SELECT ... ) [AS] alias`.
                // The alias is mandatory.
                self.advance();

                let subquery = if matches!(self.current_token, Token::With) {
                    self.parse_with_clause_inner()?
                } else {
                    self.parse_select_statement_inner()?
                };

                self.consume(Token::RightParen)?;

                let alias = if matches!(self.current_token, Token::As) {
                    self.advance();
                    match &self.current_token {
                        Token::Identifier(name) => {
                            let alias = name.clone();
                            self.advance();
                            alias
                        }
                        token => {
                            if let Some(keyword) = token.as_keyword_str() {
                                return Err(format!(
                                    "Reserved keyword '{}' cannot be used as subquery alias. Use a different name or quote it with double quotes: \"{}\"",
                                    keyword,
                                    keyword.to_lowercase()
                                ));
                            } else {
                                return Err("Expected alias name after AS".to_string());
                            }
                        }
                    }
                } else {
                    match &self.current_token {
                        Token::Identifier(name) => {
                            let alias = name.clone();
                            self.advance();
                            alias
                        }
                        _ => {
                            return Err(
                                "Subquery in FROM must have an alias (e.g., AS t)".to_string()
                            )
                        }
                    }
                };

                (None, Some(Box::new(subquery)), None, Some(alias))
            } else {
                // Only a QuotedIdentifier (or an error) reaches here in practice.
                // Identifiers and keyword tokens were already taken by the first
                // branch above.
                let table_name = match &self.current_token {
                    Token::Identifier(table) => table.clone(),
                    Token::QuotedIdentifier(table) => table.clone(),
                    token => {
                        if let Some(keyword) = token.as_keyword_str() {
                            keyword.to_lowercase()
                        } else {
                            return Err("Expected table name or subquery after FROM".to_string());
                        }
                    }
                };

                self.advance();

                let alias = self.parse_optional_alias()?;

                (Some(table_name), None, None, alias)
            }
        } else {
            (None, None, None, None)
        };

        // PIVOT wraps the FROM table or subquery. A table function (or a
        // missing FROM) is rejected.
        let pivot_source = if matches!(self.current_token, Token::Pivot) {
            let source = if let Some(ref table_name) = from_table {
                TableSource::Table(table_name.clone())
            } else if let Some(ref subquery) = from_subquery {
                TableSource::DerivedTable {
                    query: subquery.clone(),
                    alias: from_alias.clone().unwrap_or_default(),
                }
            } else {
                return Err("PIVOT requires a table or subquery source".to_string());
            };

            let pivoted = self.parse_pivot_clause(source)?;
            Some(pivoted)
        } else {
            None
        };

        let mut joins = Vec::new();
        while self.is_join_token() {
            joins.push(self.parse_join_clause()?);
        }

        let where_clause = if matches!(self.current_token, Token::Where) {
            self.advance();
            Some(self.parse_where_clause()?)
        } else {
            None
        };

        let group_by = if matches!(self.current_token, Token::GroupBy) {
            self.advance();
            Some(self.parse_expression_list()?)
        } else {
            None
        };

        // HAVING is accepted only after an explicit GROUP BY.
        let having = if matches!(self.current_token, Token::Having) {
            if group_by.is_none() {
                return Err("HAVING clause requires GROUP BY".to_string());
            }
            self.advance();
            let having_expr = self.parse_expression()?;

            Some(having_expr)
        } else {
            None
        };

        let qualify = if matches!(self.current_token, Token::Qualify) {
            self.advance();
            let qualify_expr = self.parse_expression()?;

            Some(qualify_expr)
        } else {
            None
        };

        let order_by = if matches!(self.current_token, Token::OrderBy) {
            self.trace_token("Found OrderBy token");
            self.advance();
            Some(self.parse_order_by_list()?)
        } else if let Token::Identifier(s) = &self.current_token {
            // Fallback for when `ORDER` arrives as a plain identifier instead of
            // an `OrderBy` token. Then `BY` must follow as a separate token.
            if Self::is_identifier_reserved(s) && s.to_uppercase() == "ORDER" {
                self.trace_token("Warning: ORDER as identifier instead of OrderBy token");
                self.advance();
                if matches!(&self.current_token, Token::By) {
                    self.advance();
                    Some(self.parse_order_by_list()?)
                } else {
                    return Err("Expected BY after ORDER".to_string());
                }
            } else {
                None
            }
        } else {
            None
        };

        // LIMIT and OFFSET only accept numeric literals that parse as usize.
        let limit = if matches!(self.current_token, Token::Limit) {
            self.advance();
            match &self.current_token {
                Token::NumberLiteral(num) => {
                    let limit_val = num
                        .parse::<usize>()
                        .map_err(|_| format!("Invalid LIMIT value: {num}"))?;
                    self.advance();
                    Some(limit_val)
                }
                _ => return Err("Expected number after LIMIT".to_string()),
            }
        } else {
            None
        };

        let offset = if matches!(self.current_token, Token::Offset) {
            self.advance();
            match &self.current_token {
                Token::NumberLiteral(num) => {
                    let offset_val = num
                        .parse::<usize>()
                        .map_err(|_| format!("Invalid OFFSET value: {num}"))?;
                    self.advance();
                    Some(offset_val)
                }
                _ => return Err("Expected number after OFFSET".to_string()),
            }
        } else {
            None
        };

        // Trailing `INTO target` form, honoured only if no INTO appeared before FROM.
        let into_table = if into_table.is_none() && matches!(self.current_token, Token::Into) {
            self.advance();
            Some(self.parse_into_clause()?)
        } else {
            into_table
        };

        let set_operations = self.parse_set_operations()?;

        let trailing_comment = if self.mode == ParserMode::PreserveComments {
            self.collect_trailing_comment()
        } else {
            None
        };

        // Build the unified `from_source`. PIVOT takes precedence. A table
        // function yields `None` and is only exposed through the deprecated
        // `from_function` field below.
        let from_source = if let Some(pivot) = pivot_source {
            Some(pivot)
        } else if let Some(ref table_name) = from_table {
            Some(TableSource::Table(table_name.clone()))
        } else if let Some(ref subquery) = from_subquery {
            Some(TableSource::DerivedTable {
                query: subquery.clone(),
                alias: from_alias.clone().unwrap_or_default(),
            })
        } else if let Some(ref _func) = from_function {
            None
        } else {
            None
        };

        Ok(SelectStatement {
            distinct,
            columns,
            select_items,
            from_source,
            #[allow(deprecated)]
            from_table,
            #[allow(deprecated)]
            from_subquery,
            #[allow(deprecated)]
            from_function,
            #[allow(deprecated)]
            from_alias,
            joins,
            where_clause,
            order_by,
            group_by,
            having,
            qualify,
            limit,
            offset,
            // Filled in by the WITH-clause callers when applicable.
            ctes: Vec::new(),
            into_table,
            set_operations,
            leading_comments,
            trailing_comment,
        })
    }
1148
1149 fn parse_set_operations(
1152 &mut self,
1153 ) -> Result<Vec<(SetOperation, Box<SelectStatement>)>, String> {
1154 let mut operations = Vec::new();
1155
1156 while matches!(
1157 self.current_token,
1158 Token::Union | Token::Intersect | Token::Except
1159 ) {
1160 let operation = match &self.current_token {
1162 Token::Union => {
1163 self.advance();
1164 if let Token::Identifier(id) = &self.current_token {
1166 if id.to_uppercase() == "ALL" {
1167 self.advance();
1168 SetOperation::UnionAll
1169 } else {
1170 SetOperation::Union
1171 }
1172 } else {
1173 SetOperation::Union
1174 }
1175 }
1176 Token::Intersect => {
1177 self.advance();
1178 SetOperation::Intersect
1179 }
1180 Token::Except => {
1181 self.advance();
1182 SetOperation::Except
1183 }
1184 _ => unreachable!(),
1185 };
1186
1187 let next_select = self.parse_select_statement_inner()?;
1189
1190 operations.push((operation, Box::new(next_select)));
1191 }
1192
1193 Ok(operations)
1194 }
1195
    /// Parses the comma-separated select list, stopping at the first item not followed by a comma.
    ///
    /// Handles these item forms:
    /// - `table.*`
    /// - `*`
    /// - `* EXCLUDE (col, ...)`
    /// - expressions with an optional `AS alias`
    ///
    /// An expression without `AS` is named after its column if it is a bare
    /// column, and `expr_<n>` otherwise, where n is its 1-based position.
    ///
    /// # Errors
    /// Fails on a malformed EXCLUDE list, or on a keyword/non-name after `AS`.
    fn parse_select_items(&mut self) -> Result<Vec<SelectItem>, String> {
        let mut items = Vec::new();

        loop {
            // Speculatively look for `table.*`. The lexer and current token are
            // snapshotted so the cursor can be rewound if the pattern does not
            // match and the identifier has to be reparsed as an expression.
            if let Token::Identifier(name) = &self.current_token.clone() {
                let saved_pos = self.lexer.clone();
                let saved_token = self.current_token.clone();
                let table_name = name.clone();

                self.advance();

                if matches!(self.current_token, Token::Dot) {
                    self.advance();
                    if matches!(self.current_token, Token::Star) {
                        items.push(SelectItem::Star {
                            table_prefix: Some(table_name),
                            leading_comments: vec![],
                            trailing_comment: None,
                        });
                        self.advance();

                        if matches!(self.current_token, Token::Comma) {
                            self.advance();
                            continue;
                        } else {
                            break;
                        }
                    }
                }

                // Not `table.*`: rewind. paren_depth needs no restore because only
                // an identifier and possibly a dot were passed over.
                self.lexer = saved_pos;
                self.current_token = saved_token;
            }

            if matches!(self.current_token, Token::Star) {
                self.advance();
                if matches!(self.current_token, Token::Exclude) {
                    // `* EXCLUDE (col, ...)`: at least one plain or quoted column name.
                    self.advance();
                    if !matches!(self.current_token, Token::LeftParen) {
                        return Err("Expected '(' after EXCLUDE".to_string());
                    }
                    self.advance();
                    let mut excluded_columns = Vec::new();
                    loop {
                        match &self.current_token {
                            Token::Identifier(col_name) | Token::QuotedIdentifier(col_name) => {
                                excluded_columns.push(col_name.clone());
                                self.advance();
                            }
                            _ => return Err("Expected column name in EXCLUDE list".to_string()),
                        }

                        if matches!(self.current_token, Token::Comma) {
                            self.advance();
                        } else if matches!(self.current_token, Token::RightParen) {
                            self.advance();
                            break;
                        } else {
                            return Err("Expected ',' or ')' in EXCLUDE list".to_string());
                        }
                    }

                    // Defensive: the loop above already rejects an empty list.
                    if excluded_columns.is_empty() {
                        return Err("EXCLUDE list cannot be empty".to_string());
                    }

                    items.push(SelectItem::StarExclude {
                        table_prefix: None,
                        excluded_columns,
                        leading_comments: vec![],
                        trailing_comment: None,
                    });
                } else {
                    items.push(SelectItem::Star {
                        table_prefix: None,
                        leading_comments: vec![],
                        trailing_comment: None,
                    });
                }
            } else {
                // NOTE(review): items are parsed with `parse_comparison`, not
                // `parse_expression`. This presumably means top-level AND/OR is
                // not accepted in a select item; confirm that this is intended.
                let expr = self.parse_comparison()?;
                let alias = if matches!(self.current_token, Token::As) {
                    self.advance();
                    match &self.current_token {
                        Token::Identifier(alias_name) => {
                            let alias = alias_name.clone();
                            self.advance();
                            alias
                        }
                        Token::QuotedIdentifier(alias_name) => {
                            let alias = alias_name.clone();
                            self.advance();
                            alias
                        }
                        token => {
                            if let Some(keyword) = token.as_keyword_str() {
                                return Err(format!(
                                    "Reserved keyword '{}' cannot be used as column alias. Use a different name or quote it with double quotes: \"{}\"",
                                    keyword,
                                    keyword.to_lowercase()
                                ));
                            } else {
                                return Err("Expected alias name after AS".to_string());
                            }
                        }
                    }
                } else {
                    // Implicit name: column name for a bare column, else a
                    // positional `expr_<n>` (1-based).
                    match &expr {
                        SqlExpression::Column(col_ref) => col_ref.name.clone(),
                        _ => format!("expr_{}", items.len() + 1),
                    }
                };

                // A bare column whose name equals its alias stays a `Column` item;
                // everything else (including renamed columns) is an `Expression`.
                let item = match expr {
                    SqlExpression::Column(col_ref) if alias == col_ref.name => {
                        SelectItem::Column {
                            column: col_ref,
                            leading_comments: vec![],
                            trailing_comment: None,
                        }
                    }
                    _ => {
                        SelectItem::Expression {
                            expr,
                            alias,
                            leading_comments: vec![],
                            trailing_comment: None,
                        }
                    }
                };

                items.push(item);
            }

            if matches!(self.current_token, Token::Comma) {
                self.advance();
            } else {
                break;
            }
        }

        Ok(items)
    }
1363
1364 fn parse_identifier_list(&mut self) -> Result<Vec<String>, String> {
1365 let mut identifiers = Vec::new();
1366
1367 loop {
1368 match &self.current_token {
1369 Token::Identifier(id) => {
1370 if Self::is_identifier_reserved(id) {
1372 break;
1374 }
1375 let mut name = id.clone();
1376 self.advance();
1377
1378 if matches!(self.current_token, Token::Dot) {
1380 self.advance(); match &self.current_token {
1382 Token::Identifier(col) => {
1383 name = format!("{}.{}", name, col);
1384 self.advance();
1385 }
1386 Token::QuotedIdentifier(col) => {
1387 name = format!("{}.{}", name, col);
1388 self.advance();
1389 }
1390 _ => {
1391 return Err("Expected identifier after '.'".to_string());
1392 }
1393 }
1394 }
1395
1396 identifiers.push(name);
1397 }
1398 Token::QuotedIdentifier(id) => {
1399 identifiers.push(id.clone());
1401 self.advance();
1402 }
1403 _ => {
1404 break;
1406 }
1407 }
1408
1409 if matches!(self.current_token, Token::Comma) {
1410 self.advance();
1411 } else {
1412 break;
1413 }
1414 }
1415
1416 if identifiers.is_empty() {
1417 return Err("Expected at least one identifier".to_string());
1418 }
1419
1420 Ok(identifiers)
1421 }
1422
1423 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
1424 let mut partition_by = Vec::new();
1425 let mut order_by = Vec::new();
1426
1427 if matches!(self.current_token, Token::Partition) {
1429 self.advance(); if !matches!(self.current_token, Token::By) {
1431 return Err("Expected BY after PARTITION".to_string());
1432 }
1433 self.advance(); partition_by = self.parse_identifier_list()?;
1437 }
1438
1439 if matches!(self.current_token, Token::OrderBy) {
1441 self.advance(); order_by = self.parse_order_by_list()?;
1443 } else if let Token::Identifier(s) = &self.current_token {
1444 if Self::is_identifier_reserved(s) && s.to_uppercase() == "ORDER" {
1445 self.advance(); if !matches!(self.current_token, Token::By) {
1448 return Err("Expected BY after ORDER".to_string());
1449 }
1450 self.advance(); order_by = self.parse_order_by_list()?;
1452 }
1453 }
1454
1455 let mut frame = self.parse_window_frame()?;
1457
1458 if !order_by.is_empty() && frame.is_none() {
1462 frame = Some(WindowFrame {
1463 unit: FrameUnit::Range,
1464 start: FrameBound::UnboundedPreceding,
1465 end: Some(FrameBound::CurrentRow),
1466 });
1467 }
1468
1469 Ok(WindowSpec {
1470 partition_by,
1471 order_by,
1472 frame,
1473 })
1474 }
1475
1476 fn parse_order_by_list(&mut self) -> Result<Vec<OrderByItem>, String> {
1477 let mut order_items = Vec::new();
1478
1479 loop {
1480 let expr = self.parse_expression()?;
1488
1489 let direction = match &self.current_token {
1491 Token::Asc => {
1492 self.advance();
1493 SortDirection::Asc
1494 }
1495 Token::Desc => {
1496 self.advance();
1497 SortDirection::Desc
1498 }
1499 _ => SortDirection::Asc, };
1501
1502 order_items.push(OrderByItem { expr, direction });
1503
1504 if matches!(self.current_token, Token::Comma) {
1505 self.advance();
1506 } else {
1507 break;
1508 }
1509 }
1510
1511 Ok(order_items)
1512 }
1513
1514 fn parse_into_clause(&mut self) -> Result<IntoTable, String> {
1517 let name = match &self.current_token {
1519 Token::Identifier(id) if id.starts_with('#') => {
1520 let table_name = id.clone();
1521 self.advance();
1522 table_name
1523 }
1524 Token::Identifier(id) => {
1525 return Err(format!(
1526 "Temporary table name must start with #, got: {}",
1527 id
1528 ));
1529 }
1530 _ => {
1531 return Err(
1532 "Expected temporary table name (starting with #) after INTO".to_string()
1533 );
1534 }
1535 };
1536
1537 Ok(IntoTable { name })
1538 }
1539
1540 fn parse_window_frame(&mut self) -> Result<Option<WindowFrame>, String> {
1541 let unit = match &self.current_token {
1543 Token::Rows => {
1544 self.advance();
1545 FrameUnit::Rows
1546 }
1547 Token::Identifier(id) if id.to_uppercase() == "RANGE" => {
1548 self.advance();
1550 FrameUnit::Range
1551 }
1552 _ => return Ok(None), };
1554
1555 let (start, end) = if let Token::Between = &self.current_token {
1557 self.advance(); let start = self.parse_frame_bound()?;
1560
1561 if !matches!(&self.current_token, Token::And) {
1563 return Err("Expected AND after window frame start bound".to_string());
1564 }
1565 self.advance();
1566
1567 let end = self.parse_frame_bound()?;
1569 (start, Some(end))
1570 } else {
1571 let bound = self.parse_frame_bound()?;
1573 (bound, None)
1574 };
1575
1576 Ok(Some(WindowFrame { unit, start, end }))
1577 }
1578
1579 fn parse_frame_bound(&mut self) -> Result<FrameBound, String> {
1580 match &self.current_token {
1581 Token::Unbounded => {
1582 self.advance();
1583 match &self.current_token {
1584 Token::Preceding => {
1585 self.advance();
1586 Ok(FrameBound::UnboundedPreceding)
1587 }
1588 Token::Following => {
1589 self.advance();
1590 Ok(FrameBound::UnboundedFollowing)
1591 }
1592 _ => Err("Expected PRECEDING or FOLLOWING after UNBOUNDED".to_string()),
1593 }
1594 }
1595 Token::Current => {
1596 self.advance();
1597 if matches!(&self.current_token, Token::Row) {
1598 self.advance();
1599 return Ok(FrameBound::CurrentRow);
1600 }
1601 Err("Expected ROW after CURRENT".to_string())
1602 }
1603 Token::NumberLiteral(num) => {
1604 let n: i64 = num
1605 .parse()
1606 .map_err(|_| "Invalid number in window frame".to_string())?;
1607 self.advance();
1608 match &self.current_token {
1609 Token::Preceding => {
1610 self.advance();
1611 Ok(FrameBound::Preceding(n))
1612 }
1613 Token::Following => {
1614 self.advance();
1615 Ok(FrameBound::Following(n))
1616 }
1617 _ => Err("Expected PRECEDING or FOLLOWING after number".to_string()),
1618 }
1619 }
1620 _ => Err("Invalid window frame bound".to_string()),
1621 }
1622 }
1623
1624 fn parse_where_clause(&mut self) -> Result<WhereClause, String> {
1625 let expr = self.parse_expression()?;
1628
1629 if matches!(self.current_token, Token::RightParen) && self.paren_depth <= 0 {
1631 return Err(
1632 "Unexpected closing parenthesis - no matching opening parenthesis".to_string(),
1633 );
1634 }
1635
1636 let conditions = vec![Condition {
1638 expr,
1639 connector: None,
1640 }];
1641
1642 Ok(WhereClause { conditions })
1643 }
1644
1645 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
1646 self.trace_enter("parse_expression");
1647 let mut left = self.parse_logical_or()?;
1650
1651 left = parse_in_operator(self, left)?;
1654
1655 let result = Ok(left);
1656 self.trace_exit("parse_expression", &result);
1657 result
1658 }
1659
1660 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
1661 parse_comparison_expr(self)
1663 }
1664
1665 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
1666 parse_additive_expr(self)
1668 }
1669
1670 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
1671 parse_multiplicative_expr(self)
1673 }
1674
1675 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
1676 parse_logical_or_expr(self)
1678 }
1679
1680 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
1681 parse_logical_and_expr(self)
1683 }
1684
1685 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
1686 parse_case_expr(self)
1688 }
1689
1690 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
1691 let columns = self.columns.clone();
1694 let in_method_args = self.in_method_args;
1695 let ctx = PrimaryExpressionContext {
1696 columns: &columns,
1697 in_method_args,
1698 };
1699 parse_primary_expr(self, &ctx)
1700 }
1701
1702 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
1704 self.in_method_args = true;
1706
1707 let args = self.parse_argument_list()?;
1708
1709 self.in_method_args = false;
1711
1712 Ok(args)
1713 }
1714
1715 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
1716 let mut args = Vec::new();
1717 let mut has_distinct = false;
1718
1719 if !matches!(self.current_token, Token::RightParen) {
1720 if matches!(self.current_token, Token::Distinct) {
1722 self.advance(); has_distinct = true;
1724 }
1725
1726 args.push(self.parse_logical_or()?);
1729
1730 while matches!(self.current_token, Token::Comma) {
1732 self.advance();
1733 args.push(self.parse_logical_or()?);
1734 }
1735 }
1736
1737 Ok((args, has_distinct))
1738 }
1739
1740 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
1741 let mut expressions = Vec::new();
1742
1743 loop {
1744 expressions.push(self.parse_expression()?);
1745
1746 if matches!(self.current_token, Token::Comma) {
1747 self.advance();
1748 } else {
1749 break;
1750 }
1751 }
1752
1753 Ok(expressions)
1754 }
1755
1756 #[must_use]
1757 pub fn get_position(&self) -> usize {
1758 self.lexer.get_position()
1759 }
1760
1761 fn is_join_token(&self) -> bool {
1763 matches!(
1764 self.current_token,
1765 Token::Join | Token::Inner | Token::Left | Token::Right | Token::Full | Token::Cross
1766 )
1767 }
1768
1769 fn parse_join_clause(&mut self) -> Result<JoinClause, String> {
1771 let join_type = match &self.current_token {
1773 Token::Join => {
1774 self.advance();
1775 JoinType::Inner }
1777 Token::Inner => {
1778 self.advance();
1779 if !matches!(self.current_token, Token::Join) {
1780 return Err("Expected JOIN after INNER".to_string());
1781 }
1782 self.advance();
1783 JoinType::Inner
1784 }
1785 Token::Left => {
1786 self.advance();
1787 if matches!(self.current_token, Token::Outer) {
1789 self.advance();
1790 }
1791 if !matches!(self.current_token, Token::Join) {
1792 return Err("Expected JOIN after LEFT".to_string());
1793 }
1794 self.advance();
1795 JoinType::Left
1796 }
1797 Token::Right => {
1798 self.advance();
1799 if matches!(self.current_token, Token::Outer) {
1801 self.advance();
1802 }
1803 if !matches!(self.current_token, Token::Join) {
1804 return Err("Expected JOIN after RIGHT".to_string());
1805 }
1806 self.advance();
1807 JoinType::Right
1808 }
1809 Token::Full => {
1810 self.advance();
1811 if matches!(self.current_token, Token::Outer) {
1813 self.advance();
1814 }
1815 if !matches!(self.current_token, Token::Join) {
1816 return Err("Expected JOIN after FULL".to_string());
1817 }
1818 self.advance();
1819 JoinType::Full
1820 }
1821 Token::Cross => {
1822 self.advance();
1823 if !matches!(self.current_token, Token::Join) {
1824 return Err("Expected JOIN after CROSS".to_string());
1825 }
1826 self.advance();
1827 JoinType::Cross
1828 }
1829 _ => return Err("Expected JOIN keyword".to_string()),
1830 };
1831
1832 let (table, alias) = self.parse_join_table_source()?;
1834
1835 let condition = if join_type == JoinType::Cross {
1837 JoinCondition { conditions: vec![] }
1839 } else {
1840 if !matches!(self.current_token, Token::On) {
1841 return Err("Expected ON keyword after JOIN table".to_string());
1842 }
1843 self.advance();
1844 self.parse_join_condition()?
1845 };
1846
1847 Ok(JoinClause {
1848 join_type,
1849 table,
1850 alias,
1851 condition,
1852 })
1853 }
1854
1855 fn parse_join_table_source(&mut self) -> Result<(TableSource, Option<String>), String> {
1856 let table = match &self.current_token {
1857 Token::Identifier(name) => {
1858 let table_name = name.clone();
1859 self.advance();
1860 TableSource::Table(table_name)
1861 }
1862 Token::LeftParen => {
1863 self.advance();
1865 let subquery = self.parse_select_statement_inner()?;
1866 if !matches!(self.current_token, Token::RightParen) {
1867 return Err("Expected ')' after subquery".to_string());
1868 }
1869 self.advance();
1870
1871 let alias = match &self.current_token {
1873 Token::Identifier(alias_name) => {
1874 let alias = alias_name.clone();
1875 self.advance();
1876 alias
1877 }
1878 Token::As => {
1879 self.advance();
1880 match &self.current_token {
1881 Token::Identifier(alias_name) => {
1882 let alias = alias_name.clone();
1883 self.advance();
1884 alias
1885 }
1886 _ => return Err("Expected alias after AS keyword".to_string()),
1887 }
1888 }
1889 _ => return Err("Subqueries must have an alias".to_string()),
1890 };
1891
1892 return Ok((
1893 TableSource::DerivedTable {
1894 query: Box::new(subquery),
1895 alias: alias.clone(),
1896 },
1897 Some(alias),
1898 ));
1899 }
1900 _ => return Err("Expected table name or subquery in JOIN clause".to_string()),
1901 };
1902
1903 let alias = match &self.current_token {
1905 Token::Identifier(alias_name) => {
1906 let alias = alias_name.clone();
1907 self.advance();
1908 Some(alias)
1909 }
1910 Token::As => {
1911 self.advance();
1912 match &self.current_token {
1913 Token::Identifier(alias_name) => {
1914 let alias = alias_name.clone();
1915 self.advance();
1916 Some(alias)
1917 }
1918 _ => return Err("Expected alias after AS keyword".to_string()),
1919 }
1920 }
1921 _ => None,
1922 };
1923
1924 Ok((table, alias))
1925 }
1926
1927 fn parse_join_condition(&mut self) -> Result<JoinCondition, String> {
1928 let mut conditions = Vec::new();
1929
1930 conditions.push(self.parse_single_join_condition()?);
1932
1933 while matches!(self.current_token, Token::And) {
1935 self.advance(); conditions.push(self.parse_single_join_condition()?);
1937 }
1938
1939 Ok(JoinCondition { conditions })
1940 }
1941
1942 fn parse_single_join_condition(&mut self) -> Result<SingleJoinCondition, String> {
1943 let left_expr = self.parse_additive()?;
1946
1947 let operator = match &self.current_token {
1949 Token::Equal => JoinOperator::Equal,
1950 Token::NotEqual => JoinOperator::NotEqual,
1951 Token::LessThan => JoinOperator::LessThan,
1952 Token::LessThanOrEqual => JoinOperator::LessThanOrEqual,
1953 Token::GreaterThan => JoinOperator::GreaterThan,
1954 Token::GreaterThanOrEqual => JoinOperator::GreaterThanOrEqual,
1955 _ => return Err("Expected comparison operator in JOIN condition".to_string()),
1956 };
1957 self.advance();
1958
1959 let right_expr = self.parse_additive()?;
1961
1962 Ok(SingleJoinCondition {
1963 left_expr,
1964 operator,
1965 right_expr,
1966 })
1967 }
1968
1969 fn parse_column_reference(&mut self) -> Result<String, String> {
1970 match &self.current_token {
1971 Token::Identifier(name) => {
1972 let mut column_ref = name.clone();
1973 self.advance();
1974
1975 if matches!(self.current_token, Token::Dot) {
1977 self.advance();
1978 match &self.current_token {
1979 Token::Identifier(col_name) => {
1980 column_ref.push('.');
1981 column_ref.push_str(col_name);
1982 self.advance();
1983 }
1984 _ => return Err("Expected column name after '.'".to_string()),
1985 }
1986 }
1987
1988 Ok(column_ref)
1989 }
1990 _ => Err("Expected column reference".to_string()),
1991 }
1992 }
1993
1994 fn parse_pivot_clause(&mut self, source: TableSource) -> Result<TableSource, String> {
1999 self.consume(Token::Pivot)?;
2001
2002 self.consume(Token::LeftParen)?;
2004
2005 let aggregate = self.parse_pivot_aggregate()?;
2007
2008 self.consume(Token::For)?;
2010
2011 let pivot_column = match &self.current_token {
2013 Token::Identifier(col) => {
2014 let column = col.clone();
2015 self.advance();
2016 column
2017 }
2018 Token::QuotedIdentifier(col) => {
2019 let column = col.clone();
2020 self.advance();
2021 column
2022 }
2023 _ => return Err("Expected column name after FOR in PIVOT".to_string()),
2024 };
2025
2026 if !matches!(self.current_token, Token::In) {
2028 return Err("Expected IN keyword in PIVOT clause".to_string());
2029 }
2030 self.advance();
2031
2032 let pivot_values = self.parse_pivot_in_clause()?;
2034
2035 self.consume(Token::RightParen)?;
2037
2038 let alias = self.parse_optional_alias()?;
2040
2041 Ok(TableSource::Pivot {
2042 source: Box::new(source),
2043 aggregate,
2044 pivot_column,
2045 pivot_values,
2046 alias,
2047 })
2048 }
2049
2050 fn parse_pivot_aggregate(&mut self) -> Result<PivotAggregate, String> {
2053 let function = match &self.current_token {
2055 Token::Identifier(name) => {
2056 let func_name = name.to_uppercase();
2057 match func_name.as_str() {
2059 "MAX" | "MIN" | "SUM" | "AVG" | "COUNT" => {
2060 self.advance();
2061 func_name
2062 }
2063 _ => {
2064 return Err(format!(
2065 "Expected aggregate function (MAX, MIN, SUM, AVG, COUNT), got {}",
2066 func_name
2067 ))
2068 }
2069 }
2070 }
2071 _ => return Err("Expected aggregate function in PIVOT".to_string()),
2072 };
2073
2074 self.consume(Token::LeftParen)?;
2076
2077 let column = match &self.current_token {
2079 Token::Identifier(col) => {
2080 let column = col.clone();
2081 self.advance();
2082 column
2083 }
2084 Token::QuotedIdentifier(col) => {
2085 let column = col.clone();
2086 self.advance();
2087 column
2088 }
2089 Token::Star => {
2090 if function == "COUNT" {
2092 self.advance();
2093 "*".to_string()
2094 } else {
2095 return Err(format!("Only COUNT can use *, not {}", function));
2096 }
2097 }
2098 _ => return Err("Expected column name in aggregate function".to_string()),
2099 };
2100
2101 self.consume(Token::RightParen)?;
2103
2104 Ok(PivotAggregate { function, column })
2105 }
2106
2107 fn parse_pivot_in_clause(&mut self) -> Result<Vec<String>, String> {
2111 self.consume(Token::LeftParen)?;
2113
2114 let mut values = Vec::new();
2115
2116 match &self.current_token {
2118 Token::StringLiteral(val) => {
2119 values.push(val.clone());
2120 self.advance();
2121 }
2122 Token::Identifier(val) => {
2123 values.push(val.clone());
2125 self.advance();
2126 }
2127 Token::NumberLiteral(val) => {
2128 values.push(val.clone());
2130 self.advance();
2131 }
2132 _ => return Err("Expected value in PIVOT IN clause".to_string()),
2133 }
2134
2135 while matches!(self.current_token, Token::Comma) {
2137 self.advance(); match &self.current_token {
2140 Token::StringLiteral(val) => {
2141 values.push(val.clone());
2142 self.advance();
2143 }
2144 Token::Identifier(val) => {
2145 values.push(val.clone());
2146 self.advance();
2147 }
2148 Token::NumberLiteral(val) => {
2149 values.push(val.clone());
2150 self.advance();
2151 }
2152 _ => return Err("Expected value after comma in PIVOT IN clause".to_string()),
2153 }
2154 }
2155
2156 self.consume(Token::RightParen)?;
2158
2159 if values.is_empty() {
2160 return Err("PIVOT IN clause must have at least one value".to_string());
2161 }
2162
2163 Ok(values)
2164 }
2165}
2166
2167#[derive(Debug, Clone)]
2169pub enum CursorContext {
2170 SelectClause,
2171 FromClause,
2172 WhereClause,
2173 OrderByClause,
2174 AfterColumn(String),
2175 AfterLogicalOp(LogicalOp),
2176 AfterComparisonOp(String, String), InMethodCall(String, String), InExpression,
2179 Unknown,
2180}
2181
/// Returns the prefix of `s` ending at byte offset `pos`, moved back to the
/// nearest UTF-8 character boundary at or before `pos`.
///
/// Offsets at or past the end return the whole string.
fn safe_slice_to(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return s;
    }

    // Offset 0 is always a boundary, so the search always succeeds.
    let end = (0..=pos)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    &s[..end]
}
2196
/// Returns the suffix of `s` starting at byte offset `pos`, moved forward to
/// the nearest UTF-8 character boundary at or after `pos`.
///
/// Offsets at or past the end return the empty string.
fn safe_slice_from(s: &str, pos: usize) -> &str {
    if pos >= s.len() {
        return "";
    }

    let start = (pos..s.len())
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(s.len());

    &s[start..]
}
2211
2212#[must_use]
2213pub fn detect_cursor_context(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2214 let truncated = safe_slice_to(query, cursor_pos);
2215 let mut parser = Parser::new(truncated);
2216
2217 if let Ok(stmt) = parser.parse() {
2219 let (ctx, partial) = analyze_statement(&stmt, truncated, cursor_pos);
2220 #[cfg(test)]
2221 println!("analyze_statement returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2222 (ctx, partial)
2223 } else {
2224 let (ctx, partial) = analyze_partial(truncated, cursor_pos);
2226 #[cfg(test)]
2227 println!("analyze_partial returned: {ctx:?}, {partial:?} for query: '{truncated}'");
2228 (ctx, partial)
2229 }
2230}
2231
2232#[must_use]
2233pub fn tokenize_query(query: &str) -> Vec<String> {
2234 let mut lexer = Lexer::new(query);
2235 let tokens = lexer.tokenize_all();
2236 tokens.iter().map(|t| format!("{t:?}")).collect()
2237}
2238
/// Searches `bytes` backwards, starting just before index `pos`, for a `"`
/// that is not preceded by a backslash, and returns its index.
///
/// A quote at index 0 always matches, since there is no preceding byte.
/// Returns `None` when `pos` is 0 or no such quote exists. Callers pass the
/// index of a closing quote to locate the matching opening quote.
#[must_use]
fn find_quote_start(bytes: &[u8], pos: usize) -> Option<usize> {
    (0..pos)
        .rev()
        .find(|&idx| bytes[idx] == b'"' && (idx == 0 || bytes[idx - 1] != b'\\'))
}
2261
2262fn handle_method_call_context(col_name: &str, after_dot: &str) -> (CursorContext, Option<String>) {
2264 let partial_method = if after_dot.is_empty() {
2266 None
2267 } else if after_dot.chars().all(|c| c.is_alphanumeric() || c == '_') {
2268 Some(after_dot.to_string())
2269 } else {
2270 None
2271 };
2272
2273 let col_name_for_context =
2275 if col_name.starts_with('"') && col_name.ends_with('"') && col_name.len() > 2 {
2276 col_name[1..col_name.len() - 1].to_string()
2277 } else {
2278 col_name.to_string()
2279 };
2280
2281 (
2282 CursorContext::AfterColumn(col_name_for_context),
2283 partial_method,
2284 )
2285}
2286
2287fn check_after_comparison_operator(query: &str) -> Option<(CursorContext, Option<String>)> {
2289 for op in &Parser::COMPARISON_OPERATORS {
2290 if let Some(op_pos) = query.rfind(op) {
2291 let before_op = safe_slice_to(query, op_pos);
2292 let after_op_start = op_pos + op.len();
2293 let after_op = if after_op_start < query.len() {
2294 &query[after_op_start..]
2295 } else {
2296 ""
2297 };
2298
2299 if let Some(col_name) = before_op.split_whitespace().last() {
2301 if col_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
2302 let after_op_trimmed = after_op.trim();
2304 if after_op_trimmed.is_empty()
2305 || (after_op_trimmed
2306 .chars()
2307 .all(|c| c.is_alphanumeric() || c == '_')
2308 && !after_op_trimmed.contains('('))
2309 {
2310 let partial = if after_op_trimmed.is_empty() {
2311 None
2312 } else {
2313 Some(after_op_trimmed.to_string())
2314 };
2315 return Some((
2316 CursorContext::AfterComparisonOp(
2317 col_name.to_string(),
2318 op.trim().to_string(),
2319 ),
2320 partial,
2321 ));
2322 }
2323 }
2324 }
2325 }
2326 }
2327 None
2328}
2329
2330fn analyze_statement(
2331 stmt: &SelectStatement,
2332 query: &str,
2333 _cursor_pos: usize,
2334) -> (CursorContext, Option<String>) {
2335 let trimmed = query.trim();
2337
2338 if let Some(result) = check_after_comparison_operator(query) {
2340 return result;
2341 }
2342
2343 let ends_with_logical_op = |s: &str| -> bool {
2346 let s_upper = s.to_uppercase();
2347 s_upper.ends_with(" AND") || s_upper.ends_with(" OR")
2348 };
2349
2350 if ends_with_logical_op(trimmed) {
2351 } else {
2353 if let Some(dot_pos) = trimmed.rfind('.') {
2355 let before_dot = safe_slice_to(trimmed, dot_pos);
2357 let after_dot_start = dot_pos + 1;
2358 let after_dot = if after_dot_start < trimmed.len() {
2359 &trimmed[after_dot_start..]
2360 } else {
2361 ""
2362 };
2363
2364 if !after_dot.contains('(') {
2367 let col_name = if before_dot.ends_with('"') {
2369 let bytes = before_dot.as_bytes();
2371 let pos = before_dot.len() - 1; find_quote_start(bytes, pos).map(|start| safe_slice_from(before_dot, start))
2374 } else {
2375 before_dot
2378 .split_whitespace()
2379 .last()
2380 .map(|word| word.trim_start_matches('('))
2381 };
2382
2383 if let Some(col_name) = col_name {
2384 let is_valid = Parser::is_valid_identifier(col_name);
2386
2387 if is_valid {
2388 return handle_method_call_context(col_name, after_dot);
2389 }
2390 }
2391 }
2392 }
2393 }
2394
2395 if let Some(where_clause) = &stmt.where_clause {
2397 let trimmed_upper = trimmed.to_uppercase();
2399 if trimmed_upper.ends_with(" AND") || trimmed_upper.ends_with(" OR") {
2400 let op = if trimmed_upper.ends_with(" AND") {
2401 LogicalOp::And
2402 } else {
2403 LogicalOp::Or
2404 };
2405 return (CursorContext::AfterLogicalOp(op), None);
2406 }
2407
2408 let query_upper = query.to_uppercase();
2410 if let Some(and_pos) = query_upper.rfind(" AND ") {
2411 let after_and = safe_slice_from(query, and_pos + 5);
2412 let partial = extract_partial_at_end(after_and);
2413 if partial.is_some() {
2414 return (CursorContext::AfterLogicalOp(LogicalOp::And), partial);
2415 }
2416 }
2417
2418 if let Some(or_pos) = query_upper.rfind(" OR ") {
2419 let after_or = safe_slice_from(query, or_pos + 4);
2420 let partial = extract_partial_at_end(after_or);
2421 if partial.is_some() {
2422 return (CursorContext::AfterLogicalOp(LogicalOp::Or), partial);
2423 }
2424 }
2425
2426 if let Some(last_condition) = where_clause.conditions.last() {
2427 if let Some(connector) = &last_condition.connector {
2428 return (
2430 CursorContext::AfterLogicalOp(connector.clone()),
2431 extract_partial_at_end(query),
2432 );
2433 }
2434 }
2435 return (CursorContext::WhereClause, extract_partial_at_end(query));
2437 }
2438
2439 let query_upper = query.to_uppercase();
2441 if query_upper.ends_with(" ORDER BY") {
2442 return (CursorContext::OrderByClause, None);
2443 }
2444
2445 if stmt.order_by.is_some() {
2447 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2448 }
2449
2450 if stmt.from_table.is_some() && stmt.where_clause.is_none() && stmt.order_by.is_none() {
2451 return (CursorContext::FromClause, extract_partial_at_end(query));
2452 }
2453
2454 if !stmt.columns.is_empty() && stmt.from_table.is_none() {
2455 return (CursorContext::SelectClause, extract_partial_at_end(query));
2456 }
2457
2458 (CursorContext::Unknown, None)
2459}
2460
2461fn find_last_token(tokens: &[(usize, usize, Token)], target: &Token) -> Option<usize> {
2463 tokens
2464 .iter()
2465 .rposition(|(_, _, t)| t == target)
2466 .map(|idx| tokens[idx].0)
2467}
2468
2469fn find_last_matching_token<F>(
2471 tokens: &[(usize, usize, Token)],
2472 predicate: F,
2473) -> Option<(usize, &Token)>
2474where
2475 F: Fn(&Token) -> bool,
2476{
2477 tokens
2478 .iter()
2479 .rposition(|(_, _, t)| predicate(t))
2480 .map(|idx| (tokens[idx].0, &tokens[idx].2))
2481}
2482
2483fn is_in_clause(
2485 tokens: &[(usize, usize, Token)],
2486 clause_token: Token,
2487 exclude_tokens: &[Token],
2488) -> bool {
2489 if let Some(clause_pos) = find_last_token(tokens, &clause_token) {
2491 for (pos, _, token) in tokens.iter() {
2493 if *pos > clause_pos && exclude_tokens.contains(token) {
2494 return false;
2495 }
2496 }
2497 return true;
2498 }
2499 false
2500}
2501
2502fn analyze_partial(query: &str, cursor_pos: usize) -> (CursorContext, Option<String>) {
2503 let mut lexer = Lexer::new(query);
2505 let tokens = lexer.tokenize_all_with_positions();
2506
2507 let trimmed = query.trim();
2508
2509 #[cfg(test)]
2510 {
2511 if trimmed.contains("\"Last Name\"") {
2512 eprintln!("DEBUG analyze_partial: query='{query}', trimmed='{trimmed}'");
2513 }
2514 }
2515
2516 if let Some(result) = check_after_comparison_operator(query) {
2518 return result;
2519 }
2520
2521 if let Some(dot_pos) = trimmed.rfind('.') {
2524 #[cfg(test)]
2525 {
2526 if trimmed.contains("\"Last Name\"") {
2527 eprintln!("DEBUG: Found dot at position {dot_pos}");
2528 }
2529 }
2530 let before_dot = &trimmed[..dot_pos];
2532 let after_dot = &trimmed[dot_pos + 1..];
2533
2534 if !after_dot.contains('(') {
2537 let col_name = if before_dot.ends_with('"') {
2540 let bytes = before_dot.as_bytes();
2542 let pos = before_dot.len() - 1; #[cfg(test)]
2545 {
2546 if trimmed.contains("\"Last Name\"") {
2547 eprintln!("DEBUG: before_dot='{before_dot}', looking for opening quote");
2548 }
2549 }
2550
2551 let found_start = find_quote_start(bytes, pos);
2552
2553 if let Some(start) = found_start {
2554 let result = safe_slice_from(before_dot, start);
2556 #[cfg(test)]
2557 {
2558 if trimmed.contains("\"Last Name\"") {
2559 eprintln!("DEBUG: Extracted quoted identifier: '{result}'");
2560 }
2561 }
2562 Some(result)
2563 } else {
2564 #[cfg(test)]
2565 {
2566 if trimmed.contains("\"Last Name\"") {
2567 eprintln!("DEBUG: No opening quote found!");
2568 }
2569 }
2570 None
2571 }
2572 } else {
2573 before_dot
2576 .split_whitespace()
2577 .last()
2578 .map(|word| word.trim_start_matches('('))
2579 };
2580
2581 if let Some(col_name) = col_name {
2582 #[cfg(test)]
2583 {
2584 if trimmed.contains("\"Last Name\"") {
2585 eprintln!("DEBUG: col_name = '{col_name}'");
2586 }
2587 }
2588
2589 let is_valid = Parser::is_valid_identifier(col_name);
2591
2592 #[cfg(test)]
2593 {
2594 if trimmed.contains("\"Last Name\"") {
2595 eprintln!("DEBUG: is_valid = {is_valid}");
2596 }
2597 }
2598
2599 if is_valid {
2600 return handle_method_call_context(col_name, after_dot);
2601 }
2602 }
2603 }
2604 }
2605
2606 if let Some((pos, token)) =
2608 find_last_matching_token(&tokens, |t| matches!(t, Token::And | Token::Or))
2609 {
2610 let token_end_pos = if matches!(token, Token::And) {
2612 pos + 3 } else {
2614 pos + 2 };
2616
2617 if cursor_pos > token_end_pos {
2618 let after_op = safe_slice_from(query, token_end_pos + 1); let partial = extract_partial_at_end(after_op);
2621 let op = if matches!(token, Token::And) {
2622 LogicalOp::And
2623 } else {
2624 LogicalOp::Or
2625 };
2626 return (CursorContext::AfterLogicalOp(op), partial);
2627 }
2628 }
2629
2630 if let Some((_, _, last_token)) = tokens.last() {
2632 if matches!(last_token, Token::And | Token::Or) {
2633 let op = if matches!(last_token, Token::And) {
2634 LogicalOp::And
2635 } else {
2636 LogicalOp::Or
2637 };
2638 return (CursorContext::AfterLogicalOp(op), None);
2639 }
2640 }
2641
2642 if let Some(order_pos) = find_last_token(&tokens, &Token::OrderBy) {
2644 let has_by = tokens
2646 .iter()
2647 .any(|(pos, _, t)| *pos > order_pos && matches!(t, Token::By));
2648 if has_by
2649 || tokens
2650 .last()
2651 .map_or(false, |(_, _, t)| matches!(t, Token::OrderBy))
2652 {
2653 return (CursorContext::OrderByClause, extract_partial_at_end(query));
2654 }
2655 }
2656
2657 if is_in_clause(&tokens, Token::Where, &[Token::OrderBy, Token::GroupBy]) {
2659 return (CursorContext::WhereClause, extract_partial_at_end(query));
2660 }
2661
2662 if is_in_clause(
2664 &tokens,
2665 Token::From,
2666 &[Token::Where, Token::OrderBy, Token::GroupBy],
2667 ) {
2668 return (CursorContext::FromClause, extract_partial_at_end(query));
2669 }
2670
2671 if find_last_token(&tokens, &Token::Select).is_some()
2673 && find_last_token(&tokens, &Token::From).is_none()
2674 {
2675 return (CursorContext::SelectClause, extract_partial_at_end(query));
2676 }
2677
2678 (CursorContext::Unknown, None)
2679}
2680
2681fn extract_partial_at_end(query: &str) -> Option<String> {
2682 let trimmed = query.trim();
2683
2684 if let Some(last_word) = trimmed.split_whitespace().last() {
2686 if last_word.starts_with('"') && !last_word.ends_with('"') {
2687 return Some(last_word.to_string());
2689 }
2690 }
2691
2692 let last_word = trimmed.split_whitespace().last()?;
2694
2695 if last_word.chars().all(|c| c.is_alphanumeric() || c == '_') {
2698 if !is_sql_keyword(last_word) {
2700 Some(last_word.to_string())
2701 } else {
2702 None
2703 }
2704 } else {
2705 None
2706 }
2707}
2708
2709impl ParsePrimary for Parser {
2711 fn current_token(&self) -> &Token {
2712 &self.current_token
2713 }
2714
2715 fn advance(&mut self) {
2716 self.advance();
2717 }
2718
2719 fn consume(&mut self, expected: Token) -> Result<(), String> {
2720 self.consume(expected)
2721 }
2722
2723 fn parse_case_expression(&mut self) -> Result<SqlExpression, String> {
2724 self.parse_case_expression()
2725 }
2726
2727 fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String> {
2728 self.parse_function_args()
2729 }
2730
2731 fn parse_window_spec(&mut self) -> Result<WindowSpec, String> {
2732 self.parse_window_spec()
2733 }
2734
2735 fn parse_logical_or(&mut self) -> Result<SqlExpression, String> {
2736 self.parse_logical_or()
2737 }
2738
2739 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
2740 self.parse_comparison()
2741 }
2742
2743 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2744 self.parse_expression_list()
2745 }
2746
2747 fn parse_subquery(&mut self) -> Result<SelectStatement, String> {
2748 if matches!(self.current_token, Token::With) {
2750 self.parse_with_clause_inner()
2751 } else {
2752 self.parse_select_statement_inner()
2753 }
2754 }
2755}
2756
2757impl ExpressionParser for Parser {
2759 fn current_token(&self) -> &Token {
2760 &self.current_token
2761 }
2762
2763 fn advance(&mut self) {
2764 match &self.current_token {
2766 Token::LeftParen => self.paren_depth += 1,
2767 Token::RightParen => {
2768 self.paren_depth -= 1;
2769 }
2770 _ => {}
2771 }
2772 self.current_token = self.lexer.next_token();
2773 }
2774
2775 fn peek(&self) -> Option<&Token> {
2776 None }
2783
2784 fn is_at_end(&self) -> bool {
2785 matches!(self.current_token, Token::Eof)
2786 }
2787
2788 fn consume(&mut self, expected: Token) -> Result<(), String> {
2789 if std::mem::discriminant(&self.current_token) == std::mem::discriminant(&expected) {
2791 self.update_paren_depth(&expected)?;
2792 self.current_token = self.lexer.next_token();
2793 Ok(())
2794 } else {
2795 Err(format!(
2796 "Expected {:?}, found {:?}",
2797 expected, self.current_token
2798 ))
2799 }
2800 }
2801
2802 fn parse_identifier(&mut self) -> Result<String, String> {
2803 if let Token::Identifier(id) = &self.current_token {
2804 let id = id.clone();
2805 self.advance();
2806 Ok(id)
2807 } else {
2808 Err(format!(
2809 "Expected identifier, found {:?}",
2810 self.current_token
2811 ))
2812 }
2813 }
2814}
2815
2816impl ParseArithmetic for Parser {
2818 fn current_token(&self) -> &Token {
2819 &self.current_token
2820 }
2821
2822 fn advance(&mut self) {
2823 self.advance();
2824 }
2825
2826 fn consume(&mut self, expected: Token) -> Result<(), String> {
2827 self.consume(expected)
2828 }
2829
2830 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
2831 self.parse_primary()
2832 }
2833
2834 fn parse_multiplicative(&mut self) -> Result<SqlExpression, String> {
2835 self.parse_multiplicative()
2836 }
2837
2838 fn parse_method_args(&mut self) -> Result<Vec<SqlExpression>, String> {
2839 self.parse_method_args()
2840 }
2841}
2842
2843impl ParseComparison for Parser {
2845 fn current_token(&self) -> &Token {
2846 &self.current_token
2847 }
2848
2849 fn advance(&mut self) {
2850 self.advance();
2851 }
2852
2853 fn consume(&mut self, expected: Token) -> Result<(), String> {
2854 self.consume(expected)
2855 }
2856
2857 fn parse_primary(&mut self) -> Result<SqlExpression, String> {
2858 self.parse_primary()
2859 }
2860
2861 fn parse_additive(&mut self) -> Result<SqlExpression, String> {
2862 self.parse_additive()
2863 }
2864
2865 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2866 self.parse_expression_list()
2867 }
2868
2869 fn parse_subquery(&mut self) -> Result<SelectStatement, String> {
2870 if matches!(self.current_token, Token::With) {
2872 self.parse_with_clause_inner()
2873 } else {
2874 self.parse_select_statement_inner()
2875 }
2876 }
2877}
2878
2879impl ParseLogical for Parser {
2881 fn current_token(&self) -> &Token {
2882 &self.current_token
2883 }
2884
2885 fn advance(&mut self) {
2886 self.advance();
2887 }
2888
2889 fn consume(&mut self, expected: Token) -> Result<(), String> {
2890 self.consume(expected)
2891 }
2892
2893 fn parse_logical_and(&mut self) -> Result<SqlExpression, String> {
2894 self.parse_logical_and()
2895 }
2896
2897 fn parse_base_logical_expression(&mut self) -> Result<SqlExpression, String> {
2898 self.parse_comparison()
2901 }
2902
2903 fn parse_comparison(&mut self) -> Result<SqlExpression, String> {
2904 self.parse_comparison()
2905 }
2906
2907 fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String> {
2908 self.parse_expression_list()
2909 }
2910}
2911
2912impl ParseCase for Parser {
2914 fn current_token(&self) -> &Token {
2915 &self.current_token
2916 }
2917
2918 fn advance(&mut self) {
2919 self.advance();
2920 }
2921
2922 fn consume(&mut self, expected: Token) -> Result<(), String> {
2923 self.consume(expected)
2924 }
2925
2926 fn parse_expression(&mut self) -> Result<SqlExpression, String> {
2927 self.parse_expression()
2928 }
2929}
2930
2931fn is_sql_keyword(word: &str) -> bool {
2932 let mut lexer = Lexer::new(word);
2934 let token = lexer.next_token();
2935
2936 !matches!(token, Token::Identifier(_) | Token::Eof)
2938}
2939
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parser_mode_default_is_standard() {
        let sql = "-- Leading comment\nSELECT * FROM users";
        let mut parser = Parser::new(sql);
        let stmt = parser.parse().unwrap();

        // The default constructor must not keep any comments.
        assert!(stmt.leading_comments.is_empty());
        assert!(stmt.trailing_comment.is_none());
    }

    #[test]
    fn test_parser_mode_preserve_leading_comments() {
        let sql = "-- Important query\n-- Author: Alice\nSELECT id, name FROM users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.leading_comments.len(), 2);
        assert!(stmt.leading_comments[0].is_line_comment);
        assert!(stmt.leading_comments[0].text.contains("Important query"));
        assert!(stmt.leading_comments[1].text.contains("Author: Alice"));
    }

    #[test]
    fn test_parser_mode_preserve_trailing_comment() {
        let sql = "SELECT * FROM users -- Fetch all users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser.parse().unwrap();

        let comment = stmt
            .trailing_comment
            .expect("trailing comment should be preserved");
        assert!(comment.is_line_comment);
        assert!(comment.text.contains("Fetch all users"));
    }

    #[test]
    fn test_parser_mode_preserve_block_comments() {
        let sql = "/* Query explanation */\nSELECT * FROM users";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.leading_comments.len(), 1);
        assert!(!stmt.leading_comments[0].is_line_comment);
        assert!(stmt.leading_comments[0].text.contains("Query explanation"));
    }

    #[test]
    fn test_parser_mode_preserve_both_comments() {
        let sql = "-- Leading\nSELECT * FROM users -- Trailing";
        let mut parser = Parser::with_mode(sql, ParserMode::PreserveComments);
        let stmt = parser.parse().unwrap();

        assert_eq!(stmt.leading_comments.len(), 1);
        assert!(stmt.leading_comments[0].text.contains("Leading"));
        let trailing = stmt
            .trailing_comment
            .expect("trailing comment should be preserved");
        assert!(trailing.text.contains("Trailing"));
    }

    #[test]
    fn test_parser_mode_standard_ignores_comments() {
        let sql = "-- Comment 1\n/* Comment 2 */\nSELECT * FROM users -- Comment 3";
        let mut parser = Parser::with_mode(sql, ParserMode::Standard);
        let stmt = parser.parse().unwrap();

        // Comments are dropped...
        assert!(stmt.leading_comments.is_empty());
        assert!(stmt.trailing_comment.is_none());

        // ...but the query itself is still parsed normally.
        assert_eq!(stmt.select_items.len(), 1);
        assert_eq!(stmt.from_table, Some("users".to_string()));
    }

    #[test]
    fn test_parser_backward_compatibility() {
        let sql = "SELECT id, name FROM users WHERE active = true";

        // `Parser::new` must behave exactly like an explicit Standard mode.
        let mut parser1 = Parser::new(sql);
        let stmt1 = parser1.parse().unwrap();

        let mut parser2 = Parser::with_mode(sql, ParserMode::Standard);
        let stmt2 = parser2.parse().unwrap();

        assert_eq!(stmt1.select_items.len(), stmt2.select_items.len());
        assert_eq!(stmt1.from_table, stmt2.from_table);
        assert_eq!(stmt1.where_clause.is_some(), stmt2.where_clause.is_some());
        assert!(stmt1.leading_comments.is_empty());
        assert!(stmt2.leading_comments.is_empty());
    }

    /// PIVOT is supported: the FROM clause must be represented as a `Pivot` table source.
    #[test]
    fn test_pivot_parsing_produces_pivot_source() {
        let sql = "SELECT * FROM food_eaten PIVOT (MAX(AmountEaten) FOR FoodName IN ('Sammich', 'Pickle', 'Apple'))";
        let mut parser = Parser::new(sql);
        let stmt = parser
            .parse()
            .unwrap_or_else(|e| panic!("PIVOT query failed to parse: {:?}", e));

        assert!(
            matches!(
                stmt.from_source,
                Some(crate::sql::parser::ast::TableSource::Pivot { .. })
            ),
            "Expected from_source to be a Pivot variant"
        );
    }

    #[test]
    fn test_pivot_aggregate_functions() {
        let queries = [
            "SELECT * FROM sales PIVOT (SUM(amount) FOR month IN ('Jan', 'Feb', 'Mar'))",
            "SELECT * FROM sales PIVOT (COUNT(*) FOR month IN ('Jan', 'Feb'))",
            "SELECT * FROM sales PIVOT (AVG(price) FOR category IN ('A', 'B'))",
        ];

        for sql in queries {
            let mut parser = Parser::new(sql);
            if let Err(e) = parser.parse() {
                panic!("failed to parse {:?}: {:?}", sql, e);
            }
        }
    }

    #[test]
    fn test_pivot_with_subquery() {
        let sql = "SELECT * FROM (SELECT * FROM food_eaten WHERE Id > 5) AS t \
                   PIVOT (MAX(AmountEaten) FOR FoodName IN ('Sammich', 'Pickle'))";
        let mut parser = Parser::new(sql);
        let stmt = parser
            .parse()
            .unwrap_or_else(|e| panic!("PIVOT over subquery failed to parse: {:?}", e));

        assert!(stmt.from_source.is_some());
    }

    #[test]
    fn test_pivot_with_alias() {
        let sql =
            "SELECT * FROM sales PIVOT (SUM(amount) FOR month IN ('Jan', 'Feb')) AS pivot_table";
        let mut parser = Parser::new(sql);
        let stmt = parser
            .parse()
            .unwrap_or_else(|e| panic!("aliased PIVOT failed to parse: {:?}", e));

        assert!(stmt.from_source.is_some());
    }
}