rustlite_core/query/
parser.rs

//! Parser for SQL-like queries.
//!
//! Converts a stream of tokens into an Abstract Syntax Tree (AST).
4use super::ast::*;
5use super::lexer::{Lexer, LexerError, Token};
6use std::fmt;
7
8/// Parser for SQL-like queries
9pub struct Parser {
10    tokens: Vec<Token>,
11    position: usize,
12}
13
14impl Parser {
15    /// Create a new parser from SQL text
16    pub fn new(input: &str) -> Result<Self, ParseError> {
17        let mut lexer = Lexer::new(input);
18        let tokens = lexer.tokenize().map_err(ParseError::LexerError)?;
19        Ok(Self {
20            tokens,
21            position: 0,
22        })
23    }
24
25    /// Parse the query into an AST
26    pub fn parse(&mut self) -> Result<Query, ParseError> {
27        let select = self.parse_select()?;
28        let from = self.parse_from()?;
29        let where_clause = self.parse_where()?;
30        let order_by = self.parse_order_by()?;
31        let limit = self.parse_limit()?;
32
33        self.expect_token(Token::Eof)?;
34
35        Ok(Query {
36            select,
37            from,
38            where_clause,
39            order_by,
40            limit,
41        })
42    }
43
44    fn parse_select(&mut self) -> Result<SelectClause, ParseError> {
45        self.expect_token(Token::Select)?;
46
47        let mut columns = Vec::new();
48
49        loop {
50            if self.current_token() == &Token::Asterisk {
51                self.advance();
52                columns.push(SelectColumn::Wildcard);
53            } else if matches!(
54                self.current_token(),
55                Token::Count | Token::Sum | Token::Avg | Token::Min | Token::Max
56            ) {
57                // Aggregate function
58                let function = match self.current_token() {
59                    Token::Count => AggregateFunction::Count,
60                    Token::Sum => AggregateFunction::Sum,
61                    Token::Avg => AggregateFunction::Avg,
62                    Token::Min => AggregateFunction::Min,
63                    Token::Max => AggregateFunction::Max,
64                    _ => unreachable!(),
65                };
66                self.advance();
67
68                self.expect_token(Token::LeftParen)?;
69
70                let column = if self.current_token() == &Token::Asterisk {
71                    self.advance();
72                    Box::new(SelectColumn::Wildcard)
73                } else if let Token::Identifier(name) = self.current_token().clone() {
74                    self.advance();
75                    Box::new(SelectColumn::Column { name, alias: None })
76                } else {
77                    return Err(ParseError::UnexpectedToken {
78                        expected: "column name or *".to_string(),
79                        found: self.current_token().clone(),
80                    });
81                };
82
83                self.expect_token(Token::RightParen)?;
84
85                let alias = if self.current_token() == &Token::As {
86                    self.advance();
87                    if let Token::Identifier(name) = self.current_token().clone() {
88                        self.advance();
89                        Some(name)
90                    } else {
91                        None
92                    }
93                } else {
94                    None
95                };
96
97                columns.push(SelectColumn::Aggregate {
98                    function,
99                    column,
100                    alias,
101                });
102            } else if let Token::Identifier(name) = self.current_token().clone() {
103                self.advance();
104
105                let alias = if self.current_token() == &Token::As {
106                    self.advance();
107                    if let Token::Identifier(alias_name) = self.current_token().clone() {
108                        self.advance();
109                        Some(alias_name)
110                    } else {
111                        None
112                    }
113                } else {
114                    None
115                };
116
117                columns.push(SelectColumn::Column { name, alias });
118            } else {
119                return Err(ParseError::UnexpectedToken {
120                    expected: "column name or *".to_string(),
121                    found: self.current_token().clone(),
122                });
123            }
124
125            if self.current_token() == &Token::Comma {
126                self.advance();
127            } else {
128                break;
129            }
130        }
131
132        if columns.is_empty() {
133            return Err(ParseError::EmptySelectList);
134        }
135
136        Ok(SelectClause { columns })
137    }
138
139    fn parse_from(&mut self) -> Result<FromClause, ParseError> {
140        self.expect_token(Token::From)?;
141
142        let table = if let Token::Identifier(name) = self.current_token().clone() {
143            self.advance();
144            name
145        } else {
146            return Err(ParseError::UnexpectedToken {
147                expected: "table name".to_string(),
148                found: self.current_token().clone(),
149            });
150        };
151
152        let mut joins = Vec::new();
153
154        // Parse JOINs
155        while matches!(
156            self.current_token(),
157            Token::Inner | Token::Left | Token::Right | Token::Full | Token::Join
158        ) {
159            let join_type = match self.current_token() {
160                Token::Inner => {
161                    self.advance();
162                    self.expect_token(Token::Join)?;
163                    JoinType::Inner
164                }
165                Token::Left => {
166                    self.advance();
167                    self.expect_token(Token::Join)?;
168                    JoinType::Left
169                }
170                Token::Right => {
171                    self.advance();
172                    self.expect_token(Token::Join)?;
173                    JoinType::Right
174                }
175                Token::Full => {
176                    self.advance();
177                    self.expect_token(Token::Join)?;
178                    JoinType::Full
179                }
180                Token::Join => {
181                    self.advance();
182                    JoinType::Inner // Default to INNER JOIN
183                }
184                _ => break,
185            };
186
187            let join_table = if let Token::Identifier(name) = self.current_token().clone() {
188                self.advance();
189                name
190            } else {
191                return Err(ParseError::UnexpectedToken {
192                    expected: "table name".to_string(),
193                    found: self.current_token().clone(),
194                });
195            };
196
197            self.expect_token(Token::On)?;
198            let condition = self.parse_expression()?;
199
200            joins.push(Join {
201                join_type,
202                table: join_table,
203                condition,
204            });
205        }
206
207        Ok(FromClause { table, joins })
208    }
209
210    fn parse_where(&mut self) -> Result<Option<WhereClause>, ParseError> {
211        if self.current_token() != &Token::Where {
212            return Ok(None);
213        }
214
215        self.advance();
216        let condition = self.parse_expression()?;
217
218        Ok(Some(WhereClause { condition }))
219    }
220
221    fn parse_expression(&mut self) -> Result<Expression, ParseError> {
222        self.parse_logical_or()
223    }
224
225    fn parse_logical_or(&mut self) -> Result<Expression, ParseError> {
226        let mut left = self.parse_logical_and()?;
227
228        while self.current_token() == &Token::Or {
229            self.advance();
230            let right = self.parse_logical_and()?;
231            left = Expression::LogicalOp {
232                left: Box::new(left),
233                op: LogicalOperator::Or,
234                right: Box::new(right),
235            };
236        }
237
238        Ok(left)
239    }
240
241    fn parse_logical_and(&mut self) -> Result<Expression, ParseError> {
242        let mut left = self.parse_not()?;
243
244        while self.current_token() == &Token::And {
245            self.advance();
246            let right = self.parse_not()?;
247            left = Expression::LogicalOp {
248                left: Box::new(left),
249                op: LogicalOperator::And,
250                right: Box::new(right),
251            };
252        }
253
254        Ok(left)
255    }
256
257    fn parse_not(&mut self) -> Result<Expression, ParseError> {
258        if self.current_token() == &Token::Not {
259            self.advance();
260            let expr = self.parse_comparison()?;
261            return Ok(Expression::Not(Box::new(expr)));
262        }
263
264        self.parse_comparison()
265    }
266
267    fn parse_comparison(&mut self) -> Result<Expression, ParseError> {
268        let left = self.parse_primary()?;
269
270        // Handle LIKE
271        if self.current_token() == &Token::Like {
272            self.advance();
273            if let Token::String(pattern) = self.current_token().clone() {
274                self.advance();
275                return Ok(Expression::Like {
276                    expr: Box::new(left),
277                    pattern,
278                });
279            } else {
280                return Err(ParseError::UnexpectedToken {
281                    expected: "string pattern".to_string(),
282                    found: self.current_token().clone(),
283                });
284            }
285        }
286
287        // Handle IN
288        if self.current_token() == &Token::In {
289            self.advance();
290            self.expect_token(Token::LeftParen)?;
291
292            let mut values = Vec::new();
293            loop {
294                let value = self.parse_literal()?;
295                values.push(value);
296
297                if self.current_token() == &Token::Comma {
298                    self.advance();
299                } else {
300                    break;
301                }
302            }
303
304            self.expect_token(Token::RightParen)?;
305
306            return Ok(Expression::In {
307                expr: Box::new(left),
308                values,
309            });
310        }
311
312        // Handle BETWEEN
313        if self.current_token() == &Token::Between {
314            self.advance();
315            let min = self.parse_primary()?;
316            self.expect_token(Token::And)?;
317            let max = self.parse_primary()?;
318
319            return Ok(Expression::Between {
320                expr: Box::new(left),
321                min: Box::new(min),
322                max: Box::new(max),
323            });
324        }
325
326        // Handle comparison operators
327        let op = match self.current_token() {
328            Token::Eq => BinaryOperator::Eq,
329            Token::Ne => BinaryOperator::Ne,
330            Token::Lt => BinaryOperator::Lt,
331            Token::Le => BinaryOperator::Le,
332            Token::Gt => BinaryOperator::Gt,
333            Token::Ge => BinaryOperator::Ge,
334            _ => return Ok(left),
335        };
336
337        self.advance();
338        let right = self.parse_primary()?;
339
340        Ok(Expression::BinaryOp {
341            left: Box::new(left),
342            op,
343            right: Box::new(right),
344        })
345    }
346
347    fn parse_primary(&mut self) -> Result<Expression, ParseError> {
348        match self.current_token().clone() {
349            Token::Identifier(name) => {
350                self.advance();
351                Ok(Expression::Column(name))
352            }
353            Token::Integer(i) => {
354                self.advance();
355                Ok(Expression::Literal(Literal::Integer(i)))
356            }
357            Token::Float(f) => {
358                self.advance();
359                Ok(Expression::Literal(Literal::Float(f)))
360            }
361            Token::String(s) => {
362                self.advance();
363                Ok(Expression::Literal(Literal::String(s)))
364            }
365            Token::Boolean(b) => {
366                self.advance();
367                Ok(Expression::Literal(Literal::Boolean(b)))
368            }
369            Token::Null => {
370                self.advance();
371                Ok(Expression::Literal(Literal::Null))
372            }
373            Token::LeftParen => {
374                self.advance();
375                let expr = self.parse_expression()?;
376                self.expect_token(Token::RightParen)?;
377                Ok(expr)
378            }
379            token => Err(ParseError::UnexpectedToken {
380                expected: "expression".to_string(),
381                found: token,
382            }),
383        }
384    }
385
386    fn parse_literal(&mut self) -> Result<Literal, ParseError> {
387        match self.current_token().clone() {
388            Token::Integer(i) => {
389                self.advance();
390                Ok(Literal::Integer(i))
391            }
392            Token::Float(f) => {
393                self.advance();
394                Ok(Literal::Float(f))
395            }
396            Token::String(s) => {
397                self.advance();
398                Ok(Literal::String(s))
399            }
400            Token::Boolean(b) => {
401                self.advance();
402                Ok(Literal::Boolean(b))
403            }
404            Token::Null => {
405                self.advance();
406                Ok(Literal::Null)
407            }
408            token => Err(ParseError::UnexpectedToken {
409                expected: "literal value".to_string(),
410                found: token,
411            }),
412        }
413    }
414
415    fn parse_order_by(&mut self) -> Result<Option<OrderByClause>, ParseError> {
416        if self.current_token() != &Token::OrderBy {
417            return Ok(None);
418        }
419
420        self.advance();
421
422        let mut columns = Vec::new();
423
424        loop {
425            let column = if let Token::Identifier(name) = self.current_token().clone() {
426                self.advance();
427                name
428            } else {
429                return Err(ParseError::UnexpectedToken {
430                    expected: "column name".to_string(),
431                    found: self.current_token().clone(),
432                });
433            };
434
435            let direction = if self.current_token() == &Token::Desc {
436                self.advance();
437                OrderDirection::Desc
438            } else {
439                if self.current_token() == &Token::Asc {
440                    self.advance();
441                }
442                OrderDirection::Asc
443            };
444
445            columns.push(OrderByColumn { column, direction });
446
447            if self.current_token() == &Token::Comma {
448                self.advance();
449            } else {
450                break;
451            }
452        }
453
454        Ok(Some(OrderByClause { columns }))
455    }
456
457    fn parse_limit(&mut self) -> Result<Option<LimitClause>, ParseError> {
458        if self.current_token() != &Token::Limit {
459            return Ok(None);
460        }
461
462        self.advance();
463
464        let count = if let Token::Integer(n) = self.current_token() {
465            if *n < 0 {
466                return Err(ParseError::InvalidLimitValue(*n));
467            }
468            let count = *n as usize;
469            self.advance();
470            count
471        } else {
472            return Err(ParseError::UnexpectedToken {
473                expected: "integer".to_string(),
474                found: self.current_token().clone(),
475            });
476        };
477
478        let offset = if self.current_token() == &Token::Offset {
479            self.advance();
480            if let Token::Integer(n) = self.current_token() {
481                if *n < 0 {
482                    return Err(ParseError::InvalidOffsetValue(*n));
483                }
484                let offset = *n as usize;
485                self.advance();
486                Some(offset)
487            } else {
488                return Err(ParseError::UnexpectedToken {
489                    expected: "integer".to_string(),
490                    found: self.current_token().clone(),
491                });
492            }
493        } else {
494            None
495        };
496
497        Ok(Some(LimitClause { count, offset }))
498    }
499
500    fn current_token(&self) -> &Token {
501        &self.tokens[self.position]
502    }
503
504    fn advance(&mut self) {
505        if self.position < self.tokens.len() - 1 {
506            self.position += 1;
507        }
508    }
509
510    fn expect_token(&mut self, expected: Token) -> Result<(), ParseError> {
511        if self.current_token() == &expected {
512            self.advance();
513            Ok(())
514        } else {
515            Err(ParseError::UnexpectedToken {
516                expected: format!("{}", expected),
517                found: self.current_token().clone(),
518            })
519        }
520    }
521}
522
523/// Parser errors
524#[derive(Debug, Clone)]
525pub enum ParseError {
526    LexerError(LexerError),
527    UnexpectedToken { expected: String, found: Token },
528    EmptySelectList,
529    InvalidLimitValue(i64),
530    InvalidOffsetValue(i64),
531}
532
533impl fmt::Display for ParseError {
534    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535        match self {
536            ParseError::LexerError(e) => write!(f, "Lexer error: {}", e),
537            ParseError::UnexpectedToken { expected, found } => {
538                write!(f, "Expected {}, found {}", expected, found)
539            }
540            ParseError::EmptySelectList => write!(f, "SELECT list cannot be empty"),
541            ParseError::InvalidLimitValue(n) => {
542                write!(f, "Invalid LIMIT value: {} (must be non-negative)", n)
543            }
544            ParseError::InvalidOffsetValue(n) => {
545                write!(f, "Invalid OFFSET value: {} (must be non-negative)", n)
546            }
547        }
548    }
549}
550
551impl std::error::Error for ParseError {}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenize and parse `sql`, panicking if either step fails.
    fn parse_sql(sql: &str) -> Query {
        Parser::new(sql).unwrap().parse().unwrap()
    }

    #[test]
    fn test_simple_select() {
        let query = parse_sql("SELECT * FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(query.select.columns[0], SelectColumn::Wildcard));
        assert_eq!(query.from.table, "users");
    }

    #[test]
    fn test_select_with_columns() {
        let query = parse_sql("SELECT name, age FROM users");

        assert_eq!(query.select.columns.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let query = parse_sql("SELECT * FROM users WHERE age > 18");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_select_with_order_by() {
        let query = parse_sql("SELECT * FROM users ORDER BY name ASC");

        assert!(query.order_by.is_some());
        let ordering = query.order_by.unwrap();
        assert_eq!(ordering.columns.len(), 1);
        assert_eq!(ordering.columns[0].column, "name");
        assert_eq!(ordering.columns[0].direction, OrderDirection::Asc);
    }

    #[test]
    fn test_select_with_limit() {
        let query = parse_sql("SELECT * FROM users LIMIT 10");

        assert!(query.limit.is_some());
        let limit_clause = query.limit.unwrap();
        assert_eq!(limit_clause.count, 10);
        assert_eq!(limit_clause.offset, None);
    }

    #[test]
    fn test_select_with_limit_offset() {
        let query = parse_sql("SELECT * FROM users LIMIT 10 OFFSET 5");

        let limit_clause = query.limit.unwrap();
        assert_eq!(limit_clause.count, 10);
        assert_eq!(limit_clause.offset, Some(5));
    }

    #[test]
    fn test_complex_where() {
        let query = parse_sql("SELECT * FROM users WHERE age > 18 AND name = 'John'");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_aggregate_function() {
        let query = parse_sql("SELECT COUNT(*) FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(
            query.select.columns[0],
            SelectColumn::Aggregate { .. }
        ));
    }

    #[test]
    fn test_join() {
        let query =
            parse_sql("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id");

        let joins = &query.from.joins;
        assert_eq!(joins.len(), 1);
        assert_eq!(joins[0].join_type, JoinType::Inner);
        assert_eq!(joins[0].table, "orders");
    }
}