1use crate::recursive_parser::{Lexer, Token};
2use crate::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::{anyhow, Result};
4use chrono::{Datelike, Local};
5
6pub struct WhereParser {
7 tokens: Vec<Token>,
8 current: usize,
9 columns: Vec<String>,
10 case_insensitive: bool,
11}
12
13impl WhereParser {
14 pub fn parse(where_clause: &str) -> Result<WhereExpr> {
15 Self::parse_with_columns(where_clause, vec![])
16 }
17
18 pub fn parse_with_columns(where_clause: &str, columns: Vec<String>) -> Result<WhereExpr> {
19 Self::parse_with_options(where_clause, columns, false)
20 }
21
22 pub fn parse_with_options(
23 where_clause: &str,
24 columns: Vec<String>,
25 case_insensitive: bool,
26 ) -> Result<WhereExpr> {
27 let mut lexer = Lexer::new(where_clause);
28 let mut tokens = Vec::new();
29
30 loop {
31 let token = lexer.next_token();
32 if matches!(token, Token::Eof) {
33 break;
34 }
35 tokens.push(token);
36 }
37
38 let mut parser = WhereParser {
39 tokens,
40 current: 0,
41 columns,
42 case_insensitive,
43 };
44 parser.parse_or_expr()
45 }
46
47 fn current_token(&self) -> Option<&Token> {
48 self.tokens.get(self.current)
49 }
50
51 fn peek_token(&self) -> Option<&Token> {
52 self.tokens.get(self.current + 1)
53 }
54
55 fn advance(&mut self) -> Option<&Token> {
56 let token = self.tokens.get(self.current);
57 self.current += 1;
58 token
59 }
60
61 fn expect_identifier(&mut self) -> Result<String> {
62 let is_numeric_column = if let Some(Token::NumberLiteral(num)) = self.current_token() {
64 self.columns.iter().any(|col| col == num)
65 } else {
66 false
67 };
68
69 match self.advance() {
70 Some(Token::Identifier(name)) => Ok(name.clone()),
71 Some(Token::QuotedIdentifier(name)) => Ok(name.clone()), Some(Token::NumberLiteral(num)) => {
73 if is_numeric_column {
75 Ok(num.clone())
76 } else {
77 Err(anyhow!("Expected identifier, got number {}", num))
78 }
79 }
80 _ => Err(anyhow!("Expected identifier")),
81 }
82 }
83
84 fn parse_value(&mut self) -> Result<WhereValue> {
85 match self.current_token() {
86 Some(Token::StringLiteral(_)) => {
87 if let Some(Token::StringLiteral(s)) = self.advance() {
88 Ok(WhereValue::String(s.clone()))
89 } else {
90 unreachable!()
91 }
92 }
93 Some(Token::QuotedIdentifier(_)) => {
94 if let Some(Token::QuotedIdentifier(s)) = self.advance() {
96 Ok(WhereValue::String(s.clone()))
97 } else {
98 unreachable!()
99 }
100 }
101 Some(Token::NumberLiteral(_)) => {
102 if let Some(Token::NumberLiteral(n)) = self.advance() {
103 Ok(WhereValue::Number(n.parse::<f64>().unwrap_or(0.0)))
104 } else {
105 unreachable!()
106 }
107 }
108 Some(Token::Null) => {
109 self.advance();
110 Ok(WhereValue::Null)
111 }
112 Some(Token::DateTime) => {
113 self.advance(); self.expect_token(Token::LeftParen)?;
115
116 if matches!(self.current_token(), Some(Token::RightParen)) {
118 self.advance(); let today = Local::now();
120 let date_str = format!(
121 "{:04}-{:02}-{:02} 00:00:00",
122 today.year(),
123 today.month(),
124 today.day()
125 );
126 Ok(WhereValue::String(date_str))
127 } else {
128 let year = self.parse_number_value()? as i32;
130 self.expect_token(Token::Comma)?;
131 let month = self.parse_number_value()? as u32;
132 self.expect_token(Token::Comma)?;
133 let day = self.parse_number_value()? as u32;
134
135 let mut hour = 0u32;
136 let mut minute = 0u32;
137 let mut second = 0u32;
138
139 if matches!(self.current_token(), Some(Token::Comma)) {
141 self.advance(); hour = self.parse_number_value()? as u32;
143
144 if matches!(self.current_token(), Some(Token::Comma)) {
145 self.advance(); minute = self.parse_number_value()? as u32;
147
148 if matches!(self.current_token(), Some(Token::Comma)) {
149 self.advance(); second = self.parse_number_value()? as u32;
151 }
152 }
153 }
154
155 self.expect_token(Token::RightParen)?;
156
157 let date_str =
158 format!("{year:04}-{month:02}-{day:02} {hour:02}:{minute:02}:{second:02}");
159 Ok(WhereValue::String(date_str))
160 }
161 }
162 _ => Err(anyhow!("Expected value")),
163 }
164 }
165
166 fn parse_or_expr(&mut self) -> Result<WhereExpr> {
168 let mut left = self.parse_and_expr()?;
169
170 while let Some(Token::Or) = self.current_token() {
171 self.advance(); let right = self.parse_and_expr()?;
173 left = WhereExpr::Or(Box::new(left), Box::new(right));
174 }
175
176 Ok(left)
177 }
178
179 fn parse_and_expr(&mut self) -> Result<WhereExpr> {
181 let mut left = self.parse_not_expr()?;
182
183 while let Some(Token::And) = self.current_token() {
184 self.advance(); let right = self.parse_not_expr()?;
186 left = WhereExpr::And(Box::new(left), Box::new(right));
187 }
188
189 Ok(left)
190 }
191
192 fn parse_not_expr(&mut self) -> Result<WhereExpr> {
194 if let Some(Token::Not) = self.current_token() {
195 self.advance(); let expr = self.parse_comparison_expr()?;
197 Ok(WhereExpr::Not(Box::new(expr)))
198 } else {
199 self.parse_comparison_expr()
200 }
201 }
202
203 fn parse_comparison_expr(&mut self) -> Result<WhereExpr> {
205 if let Some(Token::LeftParen) = self.current_token() {
207 self.advance(); let expr = self.parse_or_expr()?;
209 match self.advance() {
210 Some(Token::RightParen) => Ok(expr),
211 _ => Err(anyhow!("Expected closing parenthesis")),
212 }
213 } else {
214 self.parse_primary_expr()
215 }
216 }
217
218 fn parse_primary_expr(&mut self) -> Result<WhereExpr> {
220 let column = self.expect_identifier()?;
221
222 if let Some(Token::Dot) = self.current_token() {
224 self.advance(); let method = self.expect_identifier()?;
226
227 match method.as_str() {
228 "Contains" => {
229 self.expect_token(Token::LeftParen)?;
230 let value = self.parse_string_value()?;
231 self.expect_token(Token::RightParen)?;
232 if self.case_insensitive {
233 Ok(WhereExpr::ContainsIgnoreCase(column, value))
234 } else {
235 Ok(WhereExpr::Contains(column, value))
236 }
237 }
238 "StartsWith" => {
239 self.expect_token(Token::LeftParen)?;
240 let value = self.parse_string_value()?;
241 self.expect_token(Token::RightParen)?;
242 if self.case_insensitive {
243 Ok(WhereExpr::StartsWithIgnoreCase(column, value))
244 } else {
245 Ok(WhereExpr::StartsWith(column, value))
246 }
247 }
248 "EndsWith" => {
249 self.expect_token(Token::LeftParen)?;
250 let value = self.parse_string_value()?;
251 self.expect_token(Token::RightParen)?;
252 if self.case_insensitive {
253 Ok(WhereExpr::EndsWithIgnoreCase(column, value))
254 } else {
255 Ok(WhereExpr::EndsWith(column, value))
256 }
257 }
258 "Length" => {
259 self.expect_token(Token::LeftParen)?;
260 self.expect_token(Token::RightParen)?;
261
262 let op = self.parse_comparison_op()?;
264 let value = self.parse_number_value()?;
265 Ok(WhereExpr::Length(column, op, value as i64))
266 }
267 "ToLower" => {
268 self.expect_token(Token::LeftParen)?;
269 self.expect_token(Token::RightParen)?;
270
271 let op = self.parse_comparison_op()?;
273 let value = self.parse_string_value()?;
274 Ok(WhereExpr::ToLower(column, op, value))
275 }
276 "ToUpper" => {
277 self.expect_token(Token::LeftParen)?;
278 self.expect_token(Token::RightParen)?;
279
280 let op = self.parse_comparison_op()?;
282 let value = self.parse_string_value()?;
283 Ok(WhereExpr::ToUpper(column, op, value))
284 }
285 "IsNullOrEmpty" => {
286 self.expect_token(Token::LeftParen)?;
287 self.expect_token(Token::RightParen)?;
288 Ok(WhereExpr::IsNullOrEmpty(column))
289 }
290 _ => Err(anyhow!("Unknown method: {}", method)),
291 }
292 } else {
293 match self.current_token() {
295 Some(Token::Equal) => {
296 self.advance();
297 let value = self.parse_value()?;
298 Ok(WhereExpr::Equal(column, value))
299 }
300 Some(Token::NotEqual) => {
301 self.advance();
302 let value = self.parse_value()?;
303 Ok(WhereExpr::NotEqual(column, value))
304 }
305 Some(Token::GreaterThan) => {
306 self.advance();
307 let value = self.parse_value()?;
308 Ok(WhereExpr::GreaterThan(column, value))
309 }
310 Some(Token::GreaterThanOrEqual) => {
311 self.advance();
312 let value = self.parse_value()?;
313 Ok(WhereExpr::GreaterThanOrEqual(column, value))
314 }
315 Some(Token::LessThan) => {
316 self.advance();
317 let value = self.parse_value()?;
318 Ok(WhereExpr::LessThan(column, value))
319 }
320 Some(Token::LessThanOrEqual) => {
321 self.advance();
322 let value = self.parse_value()?;
323 Ok(WhereExpr::LessThanOrEqual(column, value))
324 }
325 Some(Token::Between) => {
326 self.advance();
327 let lower = self.parse_value()?;
328 self.expect_token(Token::And)?;
329 let upper = self.parse_value()?;
330 Ok(WhereExpr::Between(column, lower, upper))
331 }
332 Some(Token::In) => {
333 self.advance();
334 self.expect_token(Token::LeftParen)?;
335 let values = self.parse_value_list()?;
336 self.expect_token(Token::RightParen)?;
337 if self.case_insensitive {
338 Ok(WhereExpr::InIgnoreCase(column, values))
339 } else {
340 Ok(WhereExpr::In(column, values))
341 }
342 }
343 Some(Token::Not) if matches!(self.peek_token(), Some(Token::In)) => {
344 self.advance(); self.advance(); self.expect_token(Token::LeftParen)?;
347 let values = self.parse_value_list()?;
348 self.expect_token(Token::RightParen)?;
349 if self.case_insensitive {
350 Ok(WhereExpr::NotInIgnoreCase(column, values))
351 } else {
352 Ok(WhereExpr::NotIn(column, values))
353 }
354 }
355 Some(Token::Like) => {
356 self.advance();
357 let pattern = self.parse_string_value()?;
358 Ok(WhereExpr::Like(column, pattern))
359 }
360 Some(Token::Is) => {
361 self.advance();
362 match self.current_token() {
363 Some(Token::Null) => {
364 self.advance();
365 Ok(WhereExpr::IsNull(column))
366 }
367 Some(Token::Not) if matches!(self.peek_token(), Some(Token::Null)) => {
368 self.advance(); self.advance(); Ok(WhereExpr::IsNotNull(column))
371 }
372 _ => Err(anyhow!("Expected NULL or NOT NULL after IS")),
373 }
374 }
375 _ => Err(anyhow!("Expected operator after column")),
376 }
377 }
378 }
379
380 fn parse_comparison_op(&mut self) -> Result<ComparisonOp> {
381 match self.advance() {
382 Some(Token::Equal) => Ok(ComparisonOp::Equal),
383 Some(Token::NotEqual) => Ok(ComparisonOp::NotEqual),
384 Some(Token::GreaterThan) => Ok(ComparisonOp::GreaterThan),
385 Some(Token::GreaterThanOrEqual) => Ok(ComparisonOp::GreaterThanOrEqual),
386 Some(Token::LessThan) => Ok(ComparisonOp::LessThan),
387 Some(Token::LessThanOrEqual) => Ok(ComparisonOp::LessThanOrEqual),
388 _ => Err(anyhow!("Expected comparison operator")),
389 }
390 }
391
392 fn parse_string_value(&mut self) -> Result<String> {
393 match self.advance() {
394 Some(Token::StringLiteral(s)) => Ok(s.clone()),
395 Some(Token::QuotedIdentifier(s)) => Ok(s.clone()), _ => Err(anyhow!("Expected string literal")),
397 }
398 }
399
400 fn parse_number_value(&mut self) -> Result<f64> {
401 match self.advance() {
402 Some(Token::NumberLiteral(n)) => {
403 n.parse::<f64>().map_err(|_| anyhow!("Invalid number"))
404 }
405 _ => Err(anyhow!("Expected number literal")),
406 }
407 }
408
409 fn parse_value_list(&mut self) -> Result<Vec<WhereValue>> {
410 let mut values = vec![self.parse_value()?];
411
412 while let Some(Token::Comma) = self.current_token() {
413 self.advance(); values.push(self.parse_value()?);
415 }
416
417 Ok(values)
418 }
419
420 fn expect_token(&mut self, expected: Token) -> Result<()> {
421 match self.advance() {
422 Some(token) if std::mem::discriminant(token) == std::mem::discriminant(&expected) => {
423 Ok(())
424 }
425 Some(token) => Err(anyhow!("Expected {:?}, got {:?}", expected, token)),
426 None => Err(anyhow!("Unexpected end of input")),
427 }
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the two operands of an `AND` node, or panics with `context`.
    fn and_operands<'a>(expr: &'a WhereExpr, context: &str) -> (&'a WhereExpr, &'a WhereExpr) {
        match expr {
            WhereExpr::And(left, right) => (&**left, &**right),
            _ => panic!("{}", context),
        }
    }

    /// Returns the two operands of an `OR` node, or panics with `context`.
    fn or_operands<'a>(expr: &'a WhereExpr, context: &str) -> (&'a WhereExpr, &'a WhereExpr) {
        match expr {
            WhereExpr::Or(left, right) => (&**left, &**right),
            _ => panic!("{}", context),
        }
    }

    /// Asserts `expr` is `column = expected`; panics with `context` if it is
    /// a different kind of expression.
    fn check_equal(expr: &WhereExpr, column: &str, expected: &WhereValue, context: &str) {
        match expr {
            WhereExpr::Equal(col, val) => {
                assert_eq!(col, column);
                assert_eq!(val, expected);
            }
            _ => panic!("{}", context),
        }
    }

    /// Asserts `expr` is `column > expected`; panics with `context` if it is
    /// a different kind of expression.
    fn check_greater_than(expr: &WhereExpr, column: &str, expected: &WhereValue, context: &str) {
        match expr {
            WhereExpr::GreaterThan(col, val) => {
                assert_eq!(col, column);
                assert_eq!(val, expected);
            }
            _ => panic!("{}", context),
        }
    }

    #[test]
    fn test_simple_comparison() {
        let parsed = WhereParser::parse("price > 100").unwrap();
        check_greater_than(
            &parsed,
            "price",
            &WhereValue::Number(100.0),
            "Wrong expression type",
        );
    }

    #[test]
    fn test_and_expression() {
        let parsed = WhereParser::parse("price > 100 AND category = \"Electronics\"").unwrap();
        let (lhs, rhs) = and_operands(&parsed, "Wrong expression type");
        check_greater_than(
            lhs,
            "price",
            &WhereValue::Number(100.0),
            "Wrong left expression",
        );
        check_equal(
            rhs,
            "category",
            &WhereValue::String("Electronics".to_string()),
            "Wrong right expression",
        );
    }

    #[test]
    fn test_between_with_and() {
        let parsed = WhereParser::parse(
            "category = \"Electronics\" AND price BETWEEN 100 AND 500 AND quantity > 0",
        )
        .unwrap();

        // AND folds left: ((category = ..) AND (price BETWEEN ..)) AND (quantity > 0).
        let (head, tail) = and_operands(&parsed, "Wrong expression type");
        let (first, middle) = and_operands(head, "Wrong left structure");
        check_equal(
            first,
            "category",
            &WhereValue::String("Electronics".to_string()),
            "Wrong leftmost expression",
        );
        if let WhereExpr::Between(col, lower, upper) = middle {
            assert_eq!(col, "price");
            assert_eq!(lower, &WhereValue::Number(100.0));
            assert_eq!(upper, &WhereValue::Number(500.0));
        } else {
            panic!("Wrong middle expression");
        }
        check_greater_than(
            tail,
            "quantity",
            &WhereValue::Number(0.0),
            "Wrong right expression",
        );
    }

    #[test]
    fn test_parentheses_precedence() {
        // Without parentheses AND binds tighter: a = 1 OR (b = 2 AND c = 3).
        let ungrouped = WhereParser::parse("a = 1 OR b = 2 AND c = 3").unwrap();
        let (alt_a, alt_bc) = or_operands(&ungrouped, "Wrong top-level expression");
        check_equal(alt_a, "a", &WhereValue::Number(1.0), "Wrong left expression");
        let (b_side, c_side) = and_operands(alt_bc, "Wrong right expression");
        check_equal(b_side, "b", &WhereValue::Number(2.0), "Wrong AND left");
        check_equal(c_side, "c", &WhereValue::Number(3.0), "Wrong AND right");

        // Parentheses force the OR to be evaluated first.
        let grouped = WhereParser::parse("(a = 1 OR b = 2) AND c = 3").unwrap();
        let (either, c_side) = and_operands(&grouped, "Wrong top-level expression");
        let (a_side, b_side) = or_operands(either, "Wrong left expression");
        check_equal(a_side, "a", &WhereValue::Number(1.0), "Wrong OR left");
        check_equal(b_side, "b", &WhereValue::Number(2.0), "Wrong OR right");
        check_equal(c_side, "c", &WhereValue::Number(3.0), "Wrong right expression");
    }

    #[test]
    fn test_case_conversion_methods() {
        let lowered = WhereParser::parse("executionSide.ToLower() = \"buy\"").unwrap();
        if let WhereExpr::ToLower(col, op, val) = lowered {
            assert_eq!(col, "executionSide");
            assert_eq!(op, ComparisonOp::Equal);
            assert_eq!(val, "buy");
        } else {
            panic!("Wrong expression type for ToLower");
        }

        let raised = WhereParser::parse("status.ToUpper() != \"PENDING\"").unwrap();
        if let WhereExpr::ToUpper(col, op, val) = raised {
            assert_eq!(col, "status");
            assert_eq!(op, ComparisonOp::NotEqual);
            assert_eq!(val, "PENDING");
        } else {
            panic!("Wrong expression type for ToUpper");
        }
    }

    #[test]
    fn test_is_null_or_empty() {
        let cases = [
            (
                "name.IsNullOrEmpty()",
                "name",
                "Wrong expression type for IsNullOrEmpty",
            ),
            (
                "\"Customer Name\".IsNullOrEmpty()",
                "Customer Name",
                "Wrong expression type for IsNullOrEmpty with quoted identifier",
            ),
        ];
        for &(input, column, context) in cases.iter() {
            match WhereParser::parse(input).unwrap() {
                WhereExpr::IsNullOrEmpty(col) => assert_eq!(col, column),
                _ => panic!("{}", context),
            }
        }
    }

    #[test]
    fn test_is_null_or_empty_in_complex_expression() {
        let parsed = WhereParser::parse("name.IsNullOrEmpty() OR age > 18").unwrap();
        let (lhs, rhs) = or_operands(&parsed, "Should be an OR expression");
        match lhs {
            WhereExpr::IsNullOrEmpty(col) => assert_eq!(col, "name"),
            _ => panic!("Left side should be IsNullOrEmpty"),
        }
        check_greater_than(
            rhs,
            "age",
            &WhereValue::Number(18.0),
            "Right side should be GreaterThan",
        );
    }

    #[test]
    fn test_numeric_column_names() {
        let columns: Vec<String> = ["Borough", "202202", "202203", "202204", "202205"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let single = WhereParser::parse_with_columns("202204 > 2.0", columns.clone()).unwrap();
        check_greater_than(
            &single,
            "202204",
            &WhereValue::Number(2.0),
            "Expected GreaterThan with numeric column name",
        );

        let combined =
            WhereParser::parse_with_columns("Borough = \"London\" AND 202204 > 1.0", columns)
                .unwrap();
        let (lhs, rhs) = and_operands(&combined, "Expected And expression");
        check_equal(
            lhs,
            "Borough",
            &WhereValue::String("London".to_string()),
            "Expected Equal on left",
        );
        check_greater_than(
            rhs,
            "202204",
            &WhereValue::Number(1.0),
            "Expected GreaterThan on right",
        );

        // Ordinary identifiers still work when the column list is unrelated.
        let limited_columns = vec!["price".to_string(), "quantity".to_string()];
        let plain = WhereParser::parse_with_columns("price > 100", limited_columns).unwrap();
        check_greater_than(
            &plain,
            "price",
            &WhereValue::Number(100.0),
            "Expected GreaterThan",
        );
    }
}