rustlite_core/query/
parser.rs

//! Parser for SQL-like queries.
//!
//! Converts a stream of tokens into an Abstract Syntax Tree (AST).
4use super::ast::*;
5use super::lexer::{Lexer, LexerError, Token};
6use std::fmt;
7
8/// Parser for SQL-like queries
9pub struct Parser {
10    tokens: Vec<Token>,
11    position: usize,
12}
13
14impl Parser {
15    /// Create a new parser from SQL text
16    pub fn new(input: &str) -> Result<Self, ParseError> {
17        let mut lexer = Lexer::new(input);
18        let tokens = lexer.tokenize().map_err(ParseError::LexerError)?;
19        Ok(Self {
20            tokens,
21            position: 0,
22        })
23    }
24
25    /// Parse the query into an AST
26    pub fn parse(&mut self) -> Result<Query, ParseError> {
27        let select = self.parse_select()?;
28        let from = self.parse_from()?;
29        let where_clause = self.parse_where()?;
30        let group_by = self.parse_group_by()?;
31        let having = self.parse_having()?;
32        let order_by = self.parse_order_by()?;
33        let limit = self.parse_limit()?;
34
35        self.expect_token(Token::Eof)?;
36
37        Ok(Query {
38            select,
39            from,
40            where_clause,
41            group_by,
42            having,
43            order_by,
44            limit,
45        })
46    }
47
48    fn parse_select(&mut self) -> Result<SelectClause, ParseError> {
49        self.expect_token(Token::Select)?;
50
51        let mut columns = Vec::new();
52
53        loop {
54            if self.current_token() == &Token::Asterisk {
55                self.advance();
56                columns.push(SelectColumn::Wildcard);
57            } else if matches!(
58                self.current_token(),
59                Token::Count | Token::Sum | Token::Avg | Token::Min | Token::Max
60            ) {
61                // Aggregate function
62                let function = match self.current_token() {
63                    Token::Count => AggregateFunction::Count,
64                    Token::Sum => AggregateFunction::Sum,
65                    Token::Avg => AggregateFunction::Avg,
66                    Token::Min => AggregateFunction::Min,
67                    Token::Max => AggregateFunction::Max,
68                    _ => unreachable!(),
69                };
70                self.advance();
71
72                self.expect_token(Token::LeftParen)?;
73
74                let column = if self.current_token() == &Token::Asterisk {
75                    self.advance();
76                    Box::new(SelectColumn::Wildcard)
77                } else if let Token::Identifier(name) = self.current_token().clone() {
78                    self.advance();
79                    Box::new(SelectColumn::Column { name, alias: None })
80                } else {
81                    return Err(ParseError::UnexpectedToken {
82                        expected: "column name or *".to_string(),
83                        found: self.current_token().clone(),
84                    });
85                };
86
87                self.expect_token(Token::RightParen)?;
88
89                let alias = if self.current_token() == &Token::As {
90                    self.advance();
91                    if let Token::Identifier(name) = self.current_token().clone() {
92                        self.advance();
93                        Some(name)
94                    } else {
95                        None
96                    }
97                } else {
98                    None
99                };
100
101                columns.push(SelectColumn::Aggregate {
102                    function,
103                    column,
104                    alias,
105                });
106            } else if let Token::Identifier(name) = self.current_token().clone() {
107                self.advance();
108
109                let alias = if self.current_token() == &Token::As {
110                    self.advance();
111                    if let Token::Identifier(alias_name) = self.current_token().clone() {
112                        self.advance();
113                        Some(alias_name)
114                    } else {
115                        None
116                    }
117                } else {
118                    None
119                };
120
121                columns.push(SelectColumn::Column { name, alias });
122            } else {
123                return Err(ParseError::UnexpectedToken {
124                    expected: "column name or *".to_string(),
125                    found: self.current_token().clone(),
126                });
127            }
128
129            if self.current_token() == &Token::Comma {
130                self.advance();
131            } else {
132                break;
133            }
134        }
135
136        if columns.is_empty() {
137            return Err(ParseError::EmptySelectList);
138        }
139
140        Ok(SelectClause { columns })
141    }
142
143    fn parse_from(&mut self) -> Result<FromClause, ParseError> {
144        self.expect_token(Token::From)?;
145
146        let table = if let Token::Identifier(name) = self.current_token().clone() {
147            self.advance();
148            name
149        } else {
150            return Err(ParseError::UnexpectedToken {
151                expected: "table name".to_string(),
152                found: self.current_token().clone(),
153            });
154        };
155
156        let mut joins = Vec::new();
157
158        // Parse JOINs
159        while matches!(
160            self.current_token(),
161            Token::Inner | Token::Left | Token::Right | Token::Full | Token::Join
162        ) {
163            let join_type = match self.current_token() {
164                Token::Inner => {
165                    self.advance();
166                    self.expect_token(Token::Join)?;
167                    JoinType::Inner
168                }
169                Token::Left => {
170                    self.advance();
171                    self.expect_token(Token::Join)?;
172                    JoinType::Left
173                }
174                Token::Right => {
175                    self.advance();
176                    self.expect_token(Token::Join)?;
177                    JoinType::Right
178                }
179                Token::Full => {
180                    self.advance();
181                    self.expect_token(Token::Join)?;
182                    JoinType::Full
183                }
184                Token::Join => {
185                    self.advance();
186                    JoinType::Inner // Default to INNER JOIN
187                }
188                _ => break,
189            };
190
191            let join_table = if let Token::Identifier(name) = self.current_token().clone() {
192                self.advance();
193                name
194            } else {
195                return Err(ParseError::UnexpectedToken {
196                    expected: "table name".to_string(),
197                    found: self.current_token().clone(),
198                });
199            };
200
201            self.expect_token(Token::On)?;
202            let condition = self.parse_expression()?;
203
204            joins.push(Join {
205                join_type,
206                table: join_table,
207                condition,
208            });
209        }
210
211        Ok(FromClause { table, joins })
212    }
213
214    fn parse_where(&mut self) -> Result<Option<WhereClause>, ParseError> {
215        if self.current_token() != &Token::Where {
216            return Ok(None);
217        }
218
219        self.advance();
220        let condition = self.parse_expression()?;
221
222        Ok(Some(WhereClause { condition }))
223    }
224
225    fn parse_group_by(&mut self) -> Result<Option<GroupByClause>, ParseError> {
226        if self.current_token() != &Token::Group {
227            return Ok(None);
228        }
229
230        self.advance();
231        self.expect_token(Token::By)?;
232
233        let mut columns = Vec::new();
234
235        loop {
236            if let Token::Identifier(name) = self.current_token().clone() {
237                self.advance();
238                columns.push(name);
239
240                if self.current_token() == &Token::Comma {
241                    self.advance();
242                    continue;
243                } else {
244                    break;
245                }
246            } else {
247                return Err(ParseError::UnexpectedToken {
248                    expected: "column name".to_string(),
249                    found: self.current_token().clone(),
250                });
251            }
252        }
253
254        if columns.is_empty() {
255            return Err(ParseError::UnexpectedToken {
256                expected: "at least one column for GROUP BY".to_string(),
257                found: self.current_token().clone(),
258            });
259        }
260
261        Ok(Some(GroupByClause { columns }))
262    }
263
264    fn parse_having(&mut self) -> Result<Option<HavingClause>, ParseError> {
265        if self.current_token() != &Token::Having {
266            return Ok(None);
267        }
268
269        self.advance();
270        let condition = self.parse_expression()?;
271
272        Ok(Some(HavingClause { condition }))
273    }
274
275    fn parse_expression(&mut self) -> Result<Expression, ParseError> {
276        self.parse_logical_or()
277    }
278
279    fn parse_logical_or(&mut self) -> Result<Expression, ParseError> {
280        let mut left = self.parse_logical_and()?;
281
282        while self.current_token() == &Token::Or {
283            self.advance();
284            let right = self.parse_logical_and()?;
285            left = Expression::LogicalOp {
286                left: Box::new(left),
287                op: LogicalOperator::Or,
288                right: Box::new(right),
289            };
290        }
291
292        Ok(left)
293    }
294
295    fn parse_logical_and(&mut self) -> Result<Expression, ParseError> {
296        let mut left = self.parse_not()?;
297
298        while self.current_token() == &Token::And {
299            self.advance();
300            let right = self.parse_not()?;
301            left = Expression::LogicalOp {
302                left: Box::new(left),
303                op: LogicalOperator::And,
304                right: Box::new(right),
305            };
306        }
307
308        Ok(left)
309    }
310
311    fn parse_not(&mut self) -> Result<Expression, ParseError> {
312        if self.current_token() == &Token::Not {
313            self.advance();
314            let expr = self.parse_comparison()?;
315            return Ok(Expression::Not(Box::new(expr)));
316        }
317
318        self.parse_comparison()
319    }
320
321    fn parse_comparison(&mut self) -> Result<Expression, ParseError> {
322        let left = self.parse_primary()?;
323
324        // Handle LIKE
325        if self.current_token() == &Token::Like {
326            self.advance();
327            if let Token::String(pattern) = self.current_token().clone() {
328                self.advance();
329                return Ok(Expression::Like {
330                    expr: Box::new(left),
331                    pattern,
332                });
333            } else {
334                return Err(ParseError::UnexpectedToken {
335                    expected: "string pattern".to_string(),
336                    found: self.current_token().clone(),
337                });
338            }
339        }
340
341        // Handle IN
342        if self.current_token() == &Token::In {
343            self.advance();
344            self.expect_token(Token::LeftParen)?;
345
346            let mut values = Vec::new();
347            loop {
348                let value = self.parse_literal()?;
349                values.push(value);
350
351                if self.current_token() == &Token::Comma {
352                    self.advance();
353                } else {
354                    break;
355                }
356            }
357
358            self.expect_token(Token::RightParen)?;
359
360            return Ok(Expression::In {
361                expr: Box::new(left),
362                values,
363            });
364        }
365
366        // Handle BETWEEN
367        if self.current_token() == &Token::Between {
368            self.advance();
369            let min = self.parse_primary()?;
370            self.expect_token(Token::And)?;
371            let max = self.parse_primary()?;
372
373            return Ok(Expression::Between {
374                expr: Box::new(left),
375                min: Box::new(min),
376                max: Box::new(max),
377            });
378        }
379
380        // Handle comparison operators
381        let op = match self.current_token() {
382            Token::Eq => BinaryOperator::Eq,
383            Token::Ne => BinaryOperator::Ne,
384            Token::Lt => BinaryOperator::Lt,
385            Token::Le => BinaryOperator::Le,
386            Token::Gt => BinaryOperator::Gt,
387            Token::Ge => BinaryOperator::Ge,
388            _ => return Ok(left),
389        };
390
391        self.advance();
392        let right = self.parse_primary()?;
393
394        Ok(Expression::BinaryOp {
395            left: Box::new(left),
396            op,
397            right: Box::new(right),
398        })
399    }
400
401    fn parse_primary(&mut self) -> Result<Expression, ParseError> {
402        match self.current_token().clone() {
403            Token::Identifier(name) => {
404                self.advance();
405                Ok(Expression::Column(name))
406            }
407            Token::Integer(i) => {
408                self.advance();
409                Ok(Expression::Literal(Literal::Integer(i)))
410            }
411            Token::Float(f) => {
412                self.advance();
413                Ok(Expression::Literal(Literal::Float(f)))
414            }
415            Token::String(s) => {
416                self.advance();
417                Ok(Expression::Literal(Literal::String(s)))
418            }
419            Token::Boolean(b) => {
420                self.advance();
421                Ok(Expression::Literal(Literal::Boolean(b)))
422            }
423            Token::Null => {
424                self.advance();
425                Ok(Expression::Literal(Literal::Null))
426            }
427            Token::LeftParen => {
428                self.advance();
429                let expr = self.parse_expression()?;
430                self.expect_token(Token::RightParen)?;
431                Ok(expr)
432            }
433            token => Err(ParseError::UnexpectedToken {
434                expected: "expression".to_string(),
435                found: token,
436            }),
437        }
438    }
439
440    fn parse_literal(&mut self) -> Result<Literal, ParseError> {
441        match self.current_token().clone() {
442            Token::Integer(i) => {
443                self.advance();
444                Ok(Literal::Integer(i))
445            }
446            Token::Float(f) => {
447                self.advance();
448                Ok(Literal::Float(f))
449            }
450            Token::String(s) => {
451                self.advance();
452                Ok(Literal::String(s))
453            }
454            Token::Boolean(b) => {
455                self.advance();
456                Ok(Literal::Boolean(b))
457            }
458            Token::Null => {
459                self.advance();
460                Ok(Literal::Null)
461            }
462            token => Err(ParseError::UnexpectedToken {
463                expected: "literal value".to_string(),
464                found: token,
465            }),
466        }
467    }
468
469    fn parse_order_by(&mut self) -> Result<Option<OrderByClause>, ParseError> {
470        if self.current_token() != &Token::OrderBy {
471            return Ok(None);
472        }
473
474        self.advance();
475
476        let mut columns = Vec::new();
477
478        loop {
479            let column = if let Token::Identifier(name) = self.current_token().clone() {
480                self.advance();
481                name
482            } else {
483                return Err(ParseError::UnexpectedToken {
484                    expected: "column name".to_string(),
485                    found: self.current_token().clone(),
486                });
487            };
488
489            let direction = if self.current_token() == &Token::Desc {
490                self.advance();
491                OrderDirection::Desc
492            } else {
493                if self.current_token() == &Token::Asc {
494                    self.advance();
495                }
496                OrderDirection::Asc
497            };
498
499            columns.push(OrderByColumn { column, direction });
500
501            if self.current_token() == &Token::Comma {
502                self.advance();
503            } else {
504                break;
505            }
506        }
507
508        Ok(Some(OrderByClause { columns }))
509    }
510
511    fn parse_limit(&mut self) -> Result<Option<LimitClause>, ParseError> {
512        if self.current_token() != &Token::Limit {
513            return Ok(None);
514        }
515
516        self.advance();
517
518        let count = if let Token::Integer(n) = self.current_token() {
519            if *n < 0 {
520                return Err(ParseError::InvalidLimitValue(*n));
521            }
522            let count = *n as usize;
523            self.advance();
524            count
525        } else {
526            return Err(ParseError::UnexpectedToken {
527                expected: "integer".to_string(),
528                found: self.current_token().clone(),
529            });
530        };
531
532        let offset = if self.current_token() == &Token::Offset {
533            self.advance();
534            if let Token::Integer(n) = self.current_token() {
535                if *n < 0 {
536                    return Err(ParseError::InvalidOffsetValue(*n));
537                }
538                let offset = *n as usize;
539                self.advance();
540                Some(offset)
541            } else {
542                return Err(ParseError::UnexpectedToken {
543                    expected: "integer".to_string(),
544                    found: self.current_token().clone(),
545                });
546            }
547        } else {
548            None
549        };
550
551        Ok(Some(LimitClause { count, offset }))
552    }
553
554    fn current_token(&self) -> &Token {
555        &self.tokens[self.position]
556    }
557
558    fn advance(&mut self) {
559        if self.position < self.tokens.len() - 1 {
560            self.position += 1;
561        }
562    }
563
564    fn expect_token(&mut self, expected: Token) -> Result<(), ParseError> {
565        if self.current_token() == &expected {
566            self.advance();
567            Ok(())
568        } else {
569            Err(ParseError::UnexpectedToken {
570                expected: format!("{}", expected),
571                found: self.current_token().clone(),
572            })
573        }
574    }
575}
576
/// Errors produced while constructing a [`Parser`] or parsing a query.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The lexer failed to tokenize the input text.
    LexerError(LexerError),
    /// The parser met a token other than the one the grammar requires.
    /// `expected` is a human-readable description; `found` is the actual token.
    UnexpectedToken { expected: String, found: Token },
    /// The `SELECT` list contained no columns.
    EmptySelectList,
    /// The `LIMIT` value is not a valid row count (e.g. it is negative).
    /// Carries the offending value.
    InvalidLimitValue(i64),
    /// The `OFFSET` value is not a valid row count (e.g. it is negative).
    /// Carries the offending value.
    InvalidOffsetValue(i64),
}
586
587impl fmt::Display for ParseError {
588    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589        match self {
590            ParseError::LexerError(e) => write!(f, "Lexer error: {}", e),
591            ParseError::UnexpectedToken { expected, found } => {
592                write!(f, "Expected {}, found {}", expected, found)
593            }
594            ParseError::EmptySelectList => write!(f, "SELECT list cannot be empty"),
595            ParseError::InvalidLimitValue(n) => {
596                write!(f, "Invalid LIMIT value: {} (must be non-negative)", n)
597            }
598            ParseError::InvalidOffsetValue(n) => {
599                write!(f, "Invalid OFFSET value: {} (must be non-negative)", n)
600            }
601        }
602    }
603}
604
// Marker impl so `ParseError` works with `?` and `Box<dyn Error>`; the
// default `source()` is kept, so no underlying cause is exposed.
impl std::error::Error for ParseError {}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// Lex and parse `sql`, panicking if either step fails.
    fn parse_query(sql: &str) -> Query {
        Parser::new(sql).unwrap().parse().unwrap()
    }

    #[test]
    fn test_simple_select() {
        let query = parse_query("SELECT * FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(query.select.columns[0], SelectColumn::Wildcard));
        assert_eq!(query.from.table, "users");
    }

    #[test]
    fn test_select_with_columns() {
        let query = parse_query("SELECT name, age FROM users");

        assert_eq!(query.select.columns.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let query = parse_query("SELECT * FROM users WHERE age > 18");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_select_with_order_by() {
        let query = parse_query("SELECT * FROM users ORDER BY name ASC");

        let order_by = query
            .order_by
            .expect("ORDER BY clause should be present");
        assert_eq!(order_by.columns.len(), 1);

        let first = &order_by.columns[0];
        assert_eq!(first.column, "name");
        assert_eq!(first.direction, OrderDirection::Asc);
    }

    #[test]
    fn test_select_with_limit() {
        let query = parse_query("SELECT * FROM users LIMIT 10");

        let limit = query.limit.expect("LIMIT clause should be present");
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, None);
    }

    #[test]
    fn test_select_with_limit_offset() {
        let query = parse_query("SELECT * FROM users LIMIT 10 OFFSET 5");

        let limit = query.limit.unwrap();
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, Some(5));
    }

    #[test]
    fn test_complex_where() {
        let query = parse_query("SELECT * FROM users WHERE age > 18 AND name = 'John'");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_aggregate_function() {
        let query = parse_query("SELECT COUNT(*) FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(
            query.select.columns[0],
            SelectColumn::Aggregate { .. }
        ));
    }

    #[test]
    fn test_join() {
        let query =
            parse_query("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id");

        assert_eq!(query.from.joins.len(), 1);

        let join = &query.from.joins[0];
        assert_eq!(join.join_type, JoinType::Inner);
        assert_eq!(join.table, "orders");
    }
}