1use super::ast::*;
5use super::lexer::{Lexer, LexerError, Token};
6use std::fmt;
7
/// Parser that turns the token stream produced by [`Lexer`] into a [`Query`].
pub struct Parser {
    /// Tokens returned by the lexer for the whole input.
    tokens: Vec<Token>,
    /// Index into `tokens` of the token currently being examined.
    position: usize,
}
13
14impl Parser {
15 pub fn new(input: &str) -> Result<Self, ParseError> {
17 let mut lexer = Lexer::new(input);
18 let tokens = lexer.tokenize().map_err(ParseError::LexerError)?;
19 Ok(Self {
20 tokens,
21 position: 0,
22 })
23 }
24
25 pub fn parse(&mut self) -> Result<Query, ParseError> {
27 let select = self.parse_select()?;
28 let from = self.parse_from()?;
29 let where_clause = self.parse_where()?;
30 let group_by = self.parse_group_by()?;
31 let having = self.parse_having()?;
32 let order_by = self.parse_order_by()?;
33 let limit = self.parse_limit()?;
34
35 self.expect_token(Token::Eof)?;
36
37 Ok(Query {
38 select,
39 from,
40 where_clause,
41 group_by,
42 having,
43 order_by,
44 limit,
45 })
46 }
47
48 fn parse_select(&mut self) -> Result<SelectClause, ParseError> {
49 self.expect_token(Token::Select)?;
50
51 let mut columns = Vec::new();
52
53 loop {
54 if self.current_token() == &Token::Asterisk {
55 self.advance();
56 columns.push(SelectColumn::Wildcard);
57 } else if matches!(
58 self.current_token(),
59 Token::Count | Token::Sum | Token::Avg | Token::Min | Token::Max
60 ) {
61 let function = match self.current_token() {
63 Token::Count => AggregateFunction::Count,
64 Token::Sum => AggregateFunction::Sum,
65 Token::Avg => AggregateFunction::Avg,
66 Token::Min => AggregateFunction::Min,
67 Token::Max => AggregateFunction::Max,
68 _ => unreachable!(),
69 };
70 self.advance();
71
72 self.expect_token(Token::LeftParen)?;
73
74 let column = if self.current_token() == &Token::Asterisk {
75 self.advance();
76 Box::new(SelectColumn::Wildcard)
77 } else if let Token::Identifier(name) = self.current_token().clone() {
78 self.advance();
79 Box::new(SelectColumn::Column { name, alias: None })
80 } else {
81 return Err(ParseError::UnexpectedToken {
82 expected: "column name or *".to_string(),
83 found: self.current_token().clone(),
84 });
85 };
86
87 self.expect_token(Token::RightParen)?;
88
89 let alias = if self.current_token() == &Token::As {
90 self.advance();
91 if let Token::Identifier(name) = self.current_token().clone() {
92 self.advance();
93 Some(name)
94 } else {
95 None
96 }
97 } else {
98 None
99 };
100
101 columns.push(SelectColumn::Aggregate {
102 function,
103 column,
104 alias,
105 });
106 } else if let Token::Identifier(name) = self.current_token().clone() {
107 self.advance();
108
109 let alias = if self.current_token() == &Token::As {
110 self.advance();
111 if let Token::Identifier(alias_name) = self.current_token().clone() {
112 self.advance();
113 Some(alias_name)
114 } else {
115 None
116 }
117 } else {
118 None
119 };
120
121 columns.push(SelectColumn::Column { name, alias });
122 } else {
123 return Err(ParseError::UnexpectedToken {
124 expected: "column name or *".to_string(),
125 found: self.current_token().clone(),
126 });
127 }
128
129 if self.current_token() == &Token::Comma {
130 self.advance();
131 } else {
132 break;
133 }
134 }
135
136 if columns.is_empty() {
137 return Err(ParseError::EmptySelectList);
138 }
139
140 Ok(SelectClause { columns })
141 }
142
143 fn parse_from(&mut self) -> Result<FromClause, ParseError> {
144 self.expect_token(Token::From)?;
145
146 let table = if let Token::Identifier(name) = self.current_token().clone() {
147 self.advance();
148 name
149 } else {
150 return Err(ParseError::UnexpectedToken {
151 expected: "table name".to_string(),
152 found: self.current_token().clone(),
153 });
154 };
155
156 let mut joins = Vec::new();
157
158 while matches!(
160 self.current_token(),
161 Token::Inner | Token::Left | Token::Right | Token::Full | Token::Join
162 ) {
163 let join_type = match self.current_token() {
164 Token::Inner => {
165 self.advance();
166 self.expect_token(Token::Join)?;
167 JoinType::Inner
168 }
169 Token::Left => {
170 self.advance();
171 self.expect_token(Token::Join)?;
172 JoinType::Left
173 }
174 Token::Right => {
175 self.advance();
176 self.expect_token(Token::Join)?;
177 JoinType::Right
178 }
179 Token::Full => {
180 self.advance();
181 self.expect_token(Token::Join)?;
182 JoinType::Full
183 }
184 Token::Join => {
185 self.advance();
186 JoinType::Inner }
188 _ => break,
189 };
190
191 let join_table = if let Token::Identifier(name) = self.current_token().clone() {
192 self.advance();
193 name
194 } else {
195 return Err(ParseError::UnexpectedToken {
196 expected: "table name".to_string(),
197 found: self.current_token().clone(),
198 });
199 };
200
201 self.expect_token(Token::On)?;
202 let condition = self.parse_expression()?;
203
204 joins.push(Join {
205 join_type,
206 table: join_table,
207 condition,
208 });
209 }
210
211 Ok(FromClause { table, joins })
212 }
213
214 fn parse_where(&mut self) -> Result<Option<WhereClause>, ParseError> {
215 if self.current_token() != &Token::Where {
216 return Ok(None);
217 }
218
219 self.advance();
220 let condition = self.parse_expression()?;
221
222 Ok(Some(WhereClause { condition }))
223 }
224
225 fn parse_group_by(&mut self) -> Result<Option<GroupByClause>, ParseError> {
226 if self.current_token() != &Token::Group {
227 return Ok(None);
228 }
229
230 self.advance();
231 self.expect_token(Token::By)?;
232
233 let mut columns = Vec::new();
234
235 loop {
236 if let Token::Identifier(name) = self.current_token().clone() {
237 self.advance();
238 columns.push(name);
239
240 if self.current_token() == &Token::Comma {
241 self.advance();
242 continue;
243 } else {
244 break;
245 }
246 } else {
247 return Err(ParseError::UnexpectedToken {
248 expected: "column name".to_string(),
249 found: self.current_token().clone(),
250 });
251 }
252 }
253
254 if columns.is_empty() {
255 return Err(ParseError::UnexpectedToken {
256 expected: "at least one column for GROUP BY".to_string(),
257 found: self.current_token().clone(),
258 });
259 }
260
261 Ok(Some(GroupByClause { columns }))
262 }
263
264 fn parse_having(&mut self) -> Result<Option<HavingClause>, ParseError> {
265 if self.current_token() != &Token::Having {
266 return Ok(None);
267 }
268
269 self.advance();
270 let condition = self.parse_expression()?;
271
272 Ok(Some(HavingClause { condition }))
273 }
274
275 fn parse_expression(&mut self) -> Result<Expression, ParseError> {
276 self.parse_logical_or()
277 }
278
279 fn parse_logical_or(&mut self) -> Result<Expression, ParseError> {
280 let mut left = self.parse_logical_and()?;
281
282 while self.current_token() == &Token::Or {
283 self.advance();
284 let right = self.parse_logical_and()?;
285 left = Expression::LogicalOp {
286 left: Box::new(left),
287 op: LogicalOperator::Or,
288 right: Box::new(right),
289 };
290 }
291
292 Ok(left)
293 }
294
295 fn parse_logical_and(&mut self) -> Result<Expression, ParseError> {
296 let mut left = self.parse_not()?;
297
298 while self.current_token() == &Token::And {
299 self.advance();
300 let right = self.parse_not()?;
301 left = Expression::LogicalOp {
302 left: Box::new(left),
303 op: LogicalOperator::And,
304 right: Box::new(right),
305 };
306 }
307
308 Ok(left)
309 }
310
311 fn parse_not(&mut self) -> Result<Expression, ParseError> {
312 if self.current_token() == &Token::Not {
313 self.advance();
314 let expr = self.parse_comparison()?;
315 return Ok(Expression::Not(Box::new(expr)));
316 }
317
318 self.parse_comparison()
319 }
320
321 fn parse_comparison(&mut self) -> Result<Expression, ParseError> {
322 let left = self.parse_primary()?;
323
324 if self.current_token() == &Token::Like {
326 self.advance();
327 if let Token::String(pattern) = self.current_token().clone() {
328 self.advance();
329 return Ok(Expression::Like {
330 expr: Box::new(left),
331 pattern,
332 });
333 } else {
334 return Err(ParseError::UnexpectedToken {
335 expected: "string pattern".to_string(),
336 found: self.current_token().clone(),
337 });
338 }
339 }
340
341 if self.current_token() == &Token::In {
343 self.advance();
344 self.expect_token(Token::LeftParen)?;
345
346 let mut values = Vec::new();
347 loop {
348 let value = self.parse_literal()?;
349 values.push(value);
350
351 if self.current_token() == &Token::Comma {
352 self.advance();
353 } else {
354 break;
355 }
356 }
357
358 self.expect_token(Token::RightParen)?;
359
360 return Ok(Expression::In {
361 expr: Box::new(left),
362 values,
363 });
364 }
365
366 if self.current_token() == &Token::Between {
368 self.advance();
369 let min = self.parse_primary()?;
370 self.expect_token(Token::And)?;
371 let max = self.parse_primary()?;
372
373 return Ok(Expression::Between {
374 expr: Box::new(left),
375 min: Box::new(min),
376 max: Box::new(max),
377 });
378 }
379
380 let op = match self.current_token() {
382 Token::Eq => BinaryOperator::Eq,
383 Token::Ne => BinaryOperator::Ne,
384 Token::Lt => BinaryOperator::Lt,
385 Token::Le => BinaryOperator::Le,
386 Token::Gt => BinaryOperator::Gt,
387 Token::Ge => BinaryOperator::Ge,
388 _ => return Ok(left),
389 };
390
391 self.advance();
392 let right = self.parse_primary()?;
393
394 Ok(Expression::BinaryOp {
395 left: Box::new(left),
396 op,
397 right: Box::new(right),
398 })
399 }
400
401 fn parse_primary(&mut self) -> Result<Expression, ParseError> {
402 match self.current_token().clone() {
403 Token::Identifier(name) => {
404 self.advance();
405 Ok(Expression::Column(name))
406 }
407 Token::Integer(i) => {
408 self.advance();
409 Ok(Expression::Literal(Literal::Integer(i)))
410 }
411 Token::Float(f) => {
412 self.advance();
413 Ok(Expression::Literal(Literal::Float(f)))
414 }
415 Token::String(s) => {
416 self.advance();
417 Ok(Expression::Literal(Literal::String(s)))
418 }
419 Token::Boolean(b) => {
420 self.advance();
421 Ok(Expression::Literal(Literal::Boolean(b)))
422 }
423 Token::Null => {
424 self.advance();
425 Ok(Expression::Literal(Literal::Null))
426 }
427 Token::LeftParen => {
428 self.advance();
429 let expr = self.parse_expression()?;
430 self.expect_token(Token::RightParen)?;
431 Ok(expr)
432 }
433 token => Err(ParseError::UnexpectedToken {
434 expected: "expression".to_string(),
435 found: token,
436 }),
437 }
438 }
439
440 fn parse_literal(&mut self) -> Result<Literal, ParseError> {
441 match self.current_token().clone() {
442 Token::Integer(i) => {
443 self.advance();
444 Ok(Literal::Integer(i))
445 }
446 Token::Float(f) => {
447 self.advance();
448 Ok(Literal::Float(f))
449 }
450 Token::String(s) => {
451 self.advance();
452 Ok(Literal::String(s))
453 }
454 Token::Boolean(b) => {
455 self.advance();
456 Ok(Literal::Boolean(b))
457 }
458 Token::Null => {
459 self.advance();
460 Ok(Literal::Null)
461 }
462 token => Err(ParseError::UnexpectedToken {
463 expected: "literal value".to_string(),
464 found: token,
465 }),
466 }
467 }
468
469 fn parse_order_by(&mut self) -> Result<Option<OrderByClause>, ParseError> {
470 if self.current_token() != &Token::OrderBy {
471 return Ok(None);
472 }
473
474 self.advance();
475
476 let mut columns = Vec::new();
477
478 loop {
479 let column = if let Token::Identifier(name) = self.current_token().clone() {
480 self.advance();
481 name
482 } else {
483 return Err(ParseError::UnexpectedToken {
484 expected: "column name".to_string(),
485 found: self.current_token().clone(),
486 });
487 };
488
489 let direction = if self.current_token() == &Token::Desc {
490 self.advance();
491 OrderDirection::Desc
492 } else {
493 if self.current_token() == &Token::Asc {
494 self.advance();
495 }
496 OrderDirection::Asc
497 };
498
499 columns.push(OrderByColumn { column, direction });
500
501 if self.current_token() == &Token::Comma {
502 self.advance();
503 } else {
504 break;
505 }
506 }
507
508 Ok(Some(OrderByClause { columns }))
509 }
510
511 fn parse_limit(&mut self) -> Result<Option<LimitClause>, ParseError> {
512 if self.current_token() != &Token::Limit {
513 return Ok(None);
514 }
515
516 self.advance();
517
518 let count = if let Token::Integer(n) = self.current_token() {
519 if *n < 0 {
520 return Err(ParseError::InvalidLimitValue(*n));
521 }
522 let count = *n as usize;
523 self.advance();
524 count
525 } else {
526 return Err(ParseError::UnexpectedToken {
527 expected: "integer".to_string(),
528 found: self.current_token().clone(),
529 });
530 };
531
532 let offset = if self.current_token() == &Token::Offset {
533 self.advance();
534 if let Token::Integer(n) = self.current_token() {
535 if *n < 0 {
536 return Err(ParseError::InvalidOffsetValue(*n));
537 }
538 let offset = *n as usize;
539 self.advance();
540 Some(offset)
541 } else {
542 return Err(ParseError::UnexpectedToken {
543 expected: "integer".to_string(),
544 found: self.current_token().clone(),
545 });
546 }
547 } else {
548 None
549 };
550
551 Ok(Some(LimitClause { count, offset }))
552 }
553
554 fn current_token(&self) -> &Token {
555 &self.tokens[self.position]
556 }
557
558 fn advance(&mut self) {
559 if self.position < self.tokens.len() - 1 {
560 self.position += 1;
561 }
562 }
563
564 fn expect_token(&mut self, expected: Token) -> Result<(), ParseError> {
565 if self.current_token() == &expected {
566 self.advance();
567 Ok(())
568 } else {
569 Err(ParseError::UnexpectedToken {
570 expected: format!("{}", expected),
571 found: self.current_token().clone(),
572 })
573 }
574 }
575}
576
/// Errors produced while tokenizing or parsing a query.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The lexer failed; wraps the underlying lexer error.
    LexerError(LexerError),
    /// The parser found `found` where `expected` (a human-readable description) was required.
    UnexpectedToken { expected: String, found: Token },
    /// The `SELECT` list contained no columns.
    EmptySelectList,
    /// `LIMIT` was given a negative integer (the offending value is carried along).
    InvalidLimitValue(i64),
    /// `OFFSET` was given a negative integer (the offending value is carried along).
    InvalidOffsetValue(i64),
}
586
587impl fmt::Display for ParseError {
588 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589 match self {
590 ParseError::LexerError(e) => write!(f, "Lexer error: {}", e),
591 ParseError::UnexpectedToken { expected, found } => {
592 write!(f, "Expected {}, found {}", expected, found)
593 }
594 ParseError::EmptySelectList => write!(f, "SELECT list cannot be empty"),
595 ParseError::InvalidLimitValue(n) => {
596 write!(f, "Invalid LIMIT value: {} (must be non-negative)", n)
597 }
598 ParseError::InvalidOffsetValue(n) => {
599 write!(f, "Invalid OFFSET value: {} (must be non-negative)", n)
600 }
601 }
602 }
603}
604
// Marker impl: `ParseError` relies on the default `Error` methods, so `source()` returns `None`
// even for the `LexerError` variant.
impl std::error::Error for ParseError {}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizes and parses `sql`, panicking if either step fails.
    fn parse_ok(sql: &str) -> Query {
        Parser::new(sql)
            .expect("input should tokenize")
            .parse()
            .expect("input should parse")
    }

    #[test]
    fn test_simple_select() {
        let query = parse_ok("SELECT * FROM users");
        let columns = &query.select.columns;

        assert_eq!(columns.len(), 1);
        assert!(matches!(columns[0], SelectColumn::Wildcard));
        assert_eq!(query.from.table, "users");
    }

    #[test]
    fn test_select_with_columns() {
        let query = parse_ok("SELECT name, age FROM users");

        assert_eq!(query.select.columns.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let query = parse_ok("SELECT * FROM users WHERE age > 18");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_select_with_order_by() {
        let query = parse_ok("SELECT * FROM users ORDER BY name ASC");

        let order_by = query.order_by.expect("ORDER BY clause should be present");
        assert_eq!(order_by.columns.len(), 1);

        let first = &order_by.columns[0];
        assert_eq!(first.column, "name");
        assert_eq!(first.direction, OrderDirection::Asc);
    }

    #[test]
    fn test_select_with_limit() {
        let query = parse_ok("SELECT * FROM users LIMIT 10");

        let limit = query.limit.expect("LIMIT clause should be present");
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, None);
    }

    #[test]
    fn test_select_with_limit_offset() {
        let query = parse_ok("SELECT * FROM users LIMIT 10 OFFSET 5");

        let limit = query.limit.expect("LIMIT clause should be present");
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, Some(5));
    }

    #[test]
    fn test_complex_where() {
        let query = parse_ok("SELECT * FROM users WHERE age > 18 AND name = 'John'");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_aggregate_function() {
        let query = parse_ok("SELECT COUNT(*) FROM users");
        let columns = &query.select.columns;

        assert_eq!(columns.len(), 1);
        assert!(matches!(columns[0], SelectColumn::Aggregate { .. }));
    }

    #[test]
    fn test_join() {
        let query = parse_ok("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id");
        let joins = &query.from.joins;

        assert_eq!(joins.len(), 1);
        assert_eq!(joins[0].join_type, JoinType::Inner);
        assert_eq!(joins[0].table, "orders");
    }
}