1use super::ast::*;
5use super::lexer::{Lexer, LexerError, Token};
6use std::fmt;
7
8pub struct Parser {
10 tokens: Vec<Token>,
11 position: usize,
12}
13
14impl Parser {
15 pub fn new(input: &str) -> Result<Self, ParseError> {
17 let mut lexer = Lexer::new(input);
18 let tokens = lexer.tokenize().map_err(ParseError::LexerError)?;
19 Ok(Self {
20 tokens,
21 position: 0,
22 })
23 }
24
25 pub fn parse(&mut self) -> Result<Query, ParseError> {
27 let select = self.parse_select()?;
28 let from = self.parse_from()?;
29 let where_clause = self.parse_where()?;
30 let order_by = self.parse_order_by()?;
31 let limit = self.parse_limit()?;
32
33 self.expect_token(Token::Eof)?;
34
35 Ok(Query {
36 select,
37 from,
38 where_clause,
39 order_by,
40 limit,
41 })
42 }
43
44 fn parse_select(&mut self) -> Result<SelectClause, ParseError> {
45 self.expect_token(Token::Select)?;
46
47 let mut columns = Vec::new();
48
49 loop {
50 if self.current_token() == &Token::Asterisk {
51 self.advance();
52 columns.push(SelectColumn::Wildcard);
53 } else if matches!(
54 self.current_token(),
55 Token::Count | Token::Sum | Token::Avg | Token::Min | Token::Max
56 ) {
57 let function = match self.current_token() {
59 Token::Count => AggregateFunction::Count,
60 Token::Sum => AggregateFunction::Sum,
61 Token::Avg => AggregateFunction::Avg,
62 Token::Min => AggregateFunction::Min,
63 Token::Max => AggregateFunction::Max,
64 _ => unreachable!(),
65 };
66 self.advance();
67
68 self.expect_token(Token::LeftParen)?;
69
70 let column = if self.current_token() == &Token::Asterisk {
71 self.advance();
72 Box::new(SelectColumn::Wildcard)
73 } else if let Token::Identifier(name) = self.current_token().clone() {
74 self.advance();
75 Box::new(SelectColumn::Column { name, alias: None })
76 } else {
77 return Err(ParseError::UnexpectedToken {
78 expected: "column name or *".to_string(),
79 found: self.current_token().clone(),
80 });
81 };
82
83 self.expect_token(Token::RightParen)?;
84
85 let alias = if self.current_token() == &Token::As {
86 self.advance();
87 if let Token::Identifier(name) = self.current_token().clone() {
88 self.advance();
89 Some(name)
90 } else {
91 None
92 }
93 } else {
94 None
95 };
96
97 columns.push(SelectColumn::Aggregate {
98 function,
99 column,
100 alias,
101 });
102 } else if let Token::Identifier(name) = self.current_token().clone() {
103 self.advance();
104
105 let alias = if self.current_token() == &Token::As {
106 self.advance();
107 if let Token::Identifier(alias_name) = self.current_token().clone() {
108 self.advance();
109 Some(alias_name)
110 } else {
111 None
112 }
113 } else {
114 None
115 };
116
117 columns.push(SelectColumn::Column { name, alias });
118 } else {
119 return Err(ParseError::UnexpectedToken {
120 expected: "column name or *".to_string(),
121 found: self.current_token().clone(),
122 });
123 }
124
125 if self.current_token() == &Token::Comma {
126 self.advance();
127 } else {
128 break;
129 }
130 }
131
132 if columns.is_empty() {
133 return Err(ParseError::EmptySelectList);
134 }
135
136 Ok(SelectClause { columns })
137 }
138
139 fn parse_from(&mut self) -> Result<FromClause, ParseError> {
140 self.expect_token(Token::From)?;
141
142 let table = if let Token::Identifier(name) = self.current_token().clone() {
143 self.advance();
144 name
145 } else {
146 return Err(ParseError::UnexpectedToken {
147 expected: "table name".to_string(),
148 found: self.current_token().clone(),
149 });
150 };
151
152 let mut joins = Vec::new();
153
154 while matches!(
156 self.current_token(),
157 Token::Inner | Token::Left | Token::Right | Token::Full | Token::Join
158 ) {
159 let join_type = match self.current_token() {
160 Token::Inner => {
161 self.advance();
162 self.expect_token(Token::Join)?;
163 JoinType::Inner
164 }
165 Token::Left => {
166 self.advance();
167 self.expect_token(Token::Join)?;
168 JoinType::Left
169 }
170 Token::Right => {
171 self.advance();
172 self.expect_token(Token::Join)?;
173 JoinType::Right
174 }
175 Token::Full => {
176 self.advance();
177 self.expect_token(Token::Join)?;
178 JoinType::Full
179 }
180 Token::Join => {
181 self.advance();
182 JoinType::Inner }
184 _ => break,
185 };
186
187 let join_table = if let Token::Identifier(name) = self.current_token().clone() {
188 self.advance();
189 name
190 } else {
191 return Err(ParseError::UnexpectedToken {
192 expected: "table name".to_string(),
193 found: self.current_token().clone(),
194 });
195 };
196
197 self.expect_token(Token::On)?;
198 let condition = self.parse_expression()?;
199
200 joins.push(Join {
201 join_type,
202 table: join_table,
203 condition,
204 });
205 }
206
207 Ok(FromClause { table, joins })
208 }
209
210 fn parse_where(&mut self) -> Result<Option<WhereClause>, ParseError> {
211 if self.current_token() != &Token::Where {
212 return Ok(None);
213 }
214
215 self.advance();
216 let condition = self.parse_expression()?;
217
218 Ok(Some(WhereClause { condition }))
219 }
220
221 fn parse_expression(&mut self) -> Result<Expression, ParseError> {
222 self.parse_logical_or()
223 }
224
225 fn parse_logical_or(&mut self) -> Result<Expression, ParseError> {
226 let mut left = self.parse_logical_and()?;
227
228 while self.current_token() == &Token::Or {
229 self.advance();
230 let right = self.parse_logical_and()?;
231 left = Expression::LogicalOp {
232 left: Box::new(left),
233 op: LogicalOperator::Or,
234 right: Box::new(right),
235 };
236 }
237
238 Ok(left)
239 }
240
241 fn parse_logical_and(&mut self) -> Result<Expression, ParseError> {
242 let mut left = self.parse_not()?;
243
244 while self.current_token() == &Token::And {
245 self.advance();
246 let right = self.parse_not()?;
247 left = Expression::LogicalOp {
248 left: Box::new(left),
249 op: LogicalOperator::And,
250 right: Box::new(right),
251 };
252 }
253
254 Ok(left)
255 }
256
257 fn parse_not(&mut self) -> Result<Expression, ParseError> {
258 if self.current_token() == &Token::Not {
259 self.advance();
260 let expr = self.parse_comparison()?;
261 return Ok(Expression::Not(Box::new(expr)));
262 }
263
264 self.parse_comparison()
265 }
266
267 fn parse_comparison(&mut self) -> Result<Expression, ParseError> {
268 let left = self.parse_primary()?;
269
270 if self.current_token() == &Token::Like {
272 self.advance();
273 if let Token::String(pattern) = self.current_token().clone() {
274 self.advance();
275 return Ok(Expression::Like {
276 expr: Box::new(left),
277 pattern,
278 });
279 } else {
280 return Err(ParseError::UnexpectedToken {
281 expected: "string pattern".to_string(),
282 found: self.current_token().clone(),
283 });
284 }
285 }
286
287 if self.current_token() == &Token::In {
289 self.advance();
290 self.expect_token(Token::LeftParen)?;
291
292 let mut values = Vec::new();
293 loop {
294 let value = self.parse_literal()?;
295 values.push(value);
296
297 if self.current_token() == &Token::Comma {
298 self.advance();
299 } else {
300 break;
301 }
302 }
303
304 self.expect_token(Token::RightParen)?;
305
306 return Ok(Expression::In {
307 expr: Box::new(left),
308 values,
309 });
310 }
311
312 if self.current_token() == &Token::Between {
314 self.advance();
315 let min = self.parse_primary()?;
316 self.expect_token(Token::And)?;
317 let max = self.parse_primary()?;
318
319 return Ok(Expression::Between {
320 expr: Box::new(left),
321 min: Box::new(min),
322 max: Box::new(max),
323 });
324 }
325
326 let op = match self.current_token() {
328 Token::Eq => BinaryOperator::Eq,
329 Token::Ne => BinaryOperator::Ne,
330 Token::Lt => BinaryOperator::Lt,
331 Token::Le => BinaryOperator::Le,
332 Token::Gt => BinaryOperator::Gt,
333 Token::Ge => BinaryOperator::Ge,
334 _ => return Ok(left),
335 };
336
337 self.advance();
338 let right = self.parse_primary()?;
339
340 Ok(Expression::BinaryOp {
341 left: Box::new(left),
342 op,
343 right: Box::new(right),
344 })
345 }
346
347 fn parse_primary(&mut self) -> Result<Expression, ParseError> {
348 match self.current_token().clone() {
349 Token::Identifier(name) => {
350 self.advance();
351 Ok(Expression::Column(name))
352 }
353 Token::Integer(i) => {
354 self.advance();
355 Ok(Expression::Literal(Literal::Integer(i)))
356 }
357 Token::Float(f) => {
358 self.advance();
359 Ok(Expression::Literal(Literal::Float(f)))
360 }
361 Token::String(s) => {
362 self.advance();
363 Ok(Expression::Literal(Literal::String(s)))
364 }
365 Token::Boolean(b) => {
366 self.advance();
367 Ok(Expression::Literal(Literal::Boolean(b)))
368 }
369 Token::Null => {
370 self.advance();
371 Ok(Expression::Literal(Literal::Null))
372 }
373 Token::LeftParen => {
374 self.advance();
375 let expr = self.parse_expression()?;
376 self.expect_token(Token::RightParen)?;
377 Ok(expr)
378 }
379 token => Err(ParseError::UnexpectedToken {
380 expected: "expression".to_string(),
381 found: token,
382 }),
383 }
384 }
385
386 fn parse_literal(&mut self) -> Result<Literal, ParseError> {
387 match self.current_token().clone() {
388 Token::Integer(i) => {
389 self.advance();
390 Ok(Literal::Integer(i))
391 }
392 Token::Float(f) => {
393 self.advance();
394 Ok(Literal::Float(f))
395 }
396 Token::String(s) => {
397 self.advance();
398 Ok(Literal::String(s))
399 }
400 Token::Boolean(b) => {
401 self.advance();
402 Ok(Literal::Boolean(b))
403 }
404 Token::Null => {
405 self.advance();
406 Ok(Literal::Null)
407 }
408 token => Err(ParseError::UnexpectedToken {
409 expected: "literal value".to_string(),
410 found: token,
411 }),
412 }
413 }
414
415 fn parse_order_by(&mut self) -> Result<Option<OrderByClause>, ParseError> {
416 if self.current_token() != &Token::OrderBy {
417 return Ok(None);
418 }
419
420 self.advance();
421
422 let mut columns = Vec::new();
423
424 loop {
425 let column = if let Token::Identifier(name) = self.current_token().clone() {
426 self.advance();
427 name
428 } else {
429 return Err(ParseError::UnexpectedToken {
430 expected: "column name".to_string(),
431 found: self.current_token().clone(),
432 });
433 };
434
435 let direction = if self.current_token() == &Token::Desc {
436 self.advance();
437 OrderDirection::Desc
438 } else {
439 if self.current_token() == &Token::Asc {
440 self.advance();
441 }
442 OrderDirection::Asc
443 };
444
445 columns.push(OrderByColumn { column, direction });
446
447 if self.current_token() == &Token::Comma {
448 self.advance();
449 } else {
450 break;
451 }
452 }
453
454 Ok(Some(OrderByClause { columns }))
455 }
456
457 fn parse_limit(&mut self) -> Result<Option<LimitClause>, ParseError> {
458 if self.current_token() != &Token::Limit {
459 return Ok(None);
460 }
461
462 self.advance();
463
464 let count = if let Token::Integer(n) = self.current_token() {
465 if *n < 0 {
466 return Err(ParseError::InvalidLimitValue(*n));
467 }
468 let count = *n as usize;
469 self.advance();
470 count
471 } else {
472 return Err(ParseError::UnexpectedToken {
473 expected: "integer".to_string(),
474 found: self.current_token().clone(),
475 });
476 };
477
478 let offset = if self.current_token() == &Token::Offset {
479 self.advance();
480 if let Token::Integer(n) = self.current_token() {
481 if *n < 0 {
482 return Err(ParseError::InvalidOffsetValue(*n));
483 }
484 let offset = *n as usize;
485 self.advance();
486 Some(offset)
487 } else {
488 return Err(ParseError::UnexpectedToken {
489 expected: "integer".to_string(),
490 found: self.current_token().clone(),
491 });
492 }
493 } else {
494 None
495 };
496
497 Ok(Some(LimitClause { count, offset }))
498 }
499
500 fn current_token(&self) -> &Token {
501 &self.tokens[self.position]
502 }
503
504 fn advance(&mut self) {
505 if self.position < self.tokens.len() - 1 {
506 self.position += 1;
507 }
508 }
509
510 fn expect_token(&mut self, expected: Token) -> Result<(), ParseError> {
511 if self.current_token() == &expected {
512 self.advance();
513 Ok(())
514 } else {
515 Err(ParseError::UnexpectedToken {
516 expected: format!("{}", expected),
517 found: self.current_token().clone(),
518 })
519 }
520 }
521}
522
523#[derive(Debug, Clone)]
525pub enum ParseError {
526 LexerError(LexerError),
527 UnexpectedToken { expected: String, found: Token },
528 EmptySelectList,
529 InvalidLimitValue(i64),
530 InvalidOffsetValue(i64),
531}
532
533impl fmt::Display for ParseError {
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 match self {
536 ParseError::LexerError(e) => write!(f, "Lexer error: {}", e),
537 ParseError::UnexpectedToken { expected, found } => {
538 write!(f, "Expected {}, found {}", expected, found)
539 }
540 ParseError::EmptySelectList => write!(f, "SELECT list cannot be empty"),
541 ParseError::InvalidLimitValue(n) => {
542 write!(f, "Invalid LIMIT value: {} (must be non-negative)", n)
543 }
544 ParseError::InvalidOffsetValue(n) => {
545 write!(f, "Invalid OFFSET value: {} (must be non-negative)", n)
546 }
547 }
548 }
549}
550
551impl std::error::Error for ParseError {}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes and parses `sql`, panicking if either step fails.
    fn parse_sql(sql: &str) -> Query {
        let mut parser = Parser::new(sql).unwrap();
        parser.parse().unwrap()
    }

    #[test]
    fn test_simple_select() {
        let query = parse_sql("SELECT * FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(query.select.columns[0], SelectColumn::Wildcard));
        assert_eq!(query.from.table, "users");
    }

    #[test]
    fn test_select_with_columns() {
        let query = parse_sql("SELECT name, age FROM users");

        assert_eq!(query.select.columns.len(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let query = parse_sql("SELECT * FROM users WHERE age > 18");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_select_with_order_by() {
        let query = parse_sql("SELECT * FROM users ORDER BY name ASC");

        assert!(query.order_by.is_some());
        let ordering = query.order_by.unwrap();
        assert_eq!(ordering.columns.len(), 1);

        let first = &ordering.columns[0];
        assert_eq!(first.column, "name");
        assert_eq!(first.direction, OrderDirection::Asc);
    }

    #[test]
    fn test_select_with_limit() {
        let query = parse_sql("SELECT * FROM users LIMIT 10");

        assert!(query.limit.is_some());
        let limit = query.limit.unwrap();
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, None);
    }

    #[test]
    fn test_select_with_limit_offset() {
        let query = parse_sql("SELECT * FROM users LIMIT 10 OFFSET 5");

        let limit = query.limit.unwrap();
        assert_eq!(limit.count, 10);
        assert_eq!(limit.offset, Some(5));
    }

    #[test]
    fn test_complex_where() {
        let query = parse_sql("SELECT * FROM users WHERE age > 18 AND name = 'John'");

        assert!(query.where_clause.is_some());
    }

    #[test]
    fn test_aggregate_function() {
        let query = parse_sql("SELECT COUNT(*) FROM users");

        assert_eq!(query.select.columns.len(), 1);
        assert!(matches!(
            query.select.columns[0],
            SelectColumn::Aggregate { .. }
        ));
    }

    #[test]
    fn test_join() {
        let query =
            parse_sql("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id");

        assert_eq!(query.from.joins.len(), 1);

        let join = &query.from.joins[0];
        assert_eq!(join.join_type, JoinType::Inner);
        assert_eq!(join.table, "orders");
    }
}