1use crate::recursive_parser::{Lexer, Token};
2use crate::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::{anyhow, Result};
4use chrono::{Datelike, Local};
5
6pub struct WhereParser {
7 tokens: Vec<Token>,
8 current: usize,
9 columns: Vec<String>,
10 case_insensitive: bool,
11}
12
13impl WhereParser {
14 pub fn parse(where_clause: &str) -> Result<WhereExpr> {
15 Self::parse_with_columns(where_clause, vec![])
16 }
17
18 pub fn parse_with_columns(where_clause: &str, columns: Vec<String>) -> Result<WhereExpr> {
19 Self::parse_with_options(where_clause, columns, false)
20 }
21
22 pub fn parse_with_options(
23 where_clause: &str,
24 columns: Vec<String>,
25 case_insensitive: bool,
26 ) -> Result<WhereExpr> {
27 let mut lexer = Lexer::new(where_clause);
28 let mut tokens = Vec::new();
29
30 loop {
31 let token = lexer.next_token();
32 if matches!(token, Token::Eof) {
33 break;
34 }
35 tokens.push(token);
36 }
37
38 let mut parser = WhereParser {
39 tokens,
40 current: 0,
41 columns,
42 case_insensitive,
43 };
44 parser.parse_or_expr()
45 }
46
47 fn current_token(&self) -> Option<&Token> {
48 self.tokens.get(self.current)
49 }
50
51 fn peek_token(&self) -> Option<&Token> {
52 self.tokens.get(self.current + 1)
53 }
54
55 fn advance(&mut self) -> Option<&Token> {
56 let token = self.tokens.get(self.current);
57 self.current += 1;
58 token
59 }
60
61 fn expect_identifier(&mut self) -> Result<String> {
62 let is_numeric_column = if let Some(Token::NumberLiteral(num)) = self.current_token() {
64 self.columns.iter().any(|col| col == num)
65 } else {
66 false
67 };
68
69 match self.advance() {
70 Some(Token::Identifier(name)) => Ok(name.clone()),
71 Some(Token::QuotedIdentifier(name)) => Ok(name.clone()), Some(Token::NumberLiteral(num)) => {
73 if is_numeric_column {
75 Ok(num.clone())
76 } else {
77 Err(anyhow!("Expected identifier, got number {}", num))
78 }
79 }
80 _ => Err(anyhow!("Expected identifier")),
81 }
82 }
83
84 fn parse_value(&mut self) -> Result<WhereValue> {
85 match self.current_token() {
86 Some(Token::StringLiteral(_)) => {
87 if let Some(Token::StringLiteral(s)) = self.advance() {
88 Ok(WhereValue::String(s.clone()))
89 } else {
90 unreachable!()
91 }
92 }
93 Some(Token::QuotedIdentifier(_)) => {
94 if let Some(Token::QuotedIdentifier(s)) = self.advance() {
96 Ok(WhereValue::String(s.clone()))
97 } else {
98 unreachable!()
99 }
100 }
101 Some(Token::NumberLiteral(_)) => {
102 if let Some(Token::NumberLiteral(n)) = self.advance() {
103 Ok(WhereValue::Number(n.parse::<f64>().unwrap_or(0.0)))
104 } else {
105 unreachable!()
106 }
107 }
108 Some(Token::Null) => {
109 self.advance();
110 Ok(WhereValue::Null)
111 }
112 Some(Token::DateTime) => {
113 self.advance(); self.expect_token(Token::LeftParen)?;
115
116 if matches!(self.current_token(), Some(Token::RightParen)) {
118 self.advance(); let today = Local::now();
120 let date_str = format!(
121 "{:04}-{:02}-{:02} 00:00:00",
122 today.year(),
123 today.month(),
124 today.day()
125 );
126 Ok(WhereValue::String(date_str))
127 } else {
128 let year = self.parse_number_value()? as i32;
130 self.expect_token(Token::Comma)?;
131 let month = self.parse_number_value()? as u32;
132 self.expect_token(Token::Comma)?;
133 let day = self.parse_number_value()? as u32;
134
135 let mut hour = 0u32;
136 let mut minute = 0u32;
137 let mut second = 0u32;
138
139 if matches!(self.current_token(), Some(Token::Comma)) {
141 self.advance(); hour = self.parse_number_value()? as u32;
143
144 if matches!(self.current_token(), Some(Token::Comma)) {
145 self.advance(); minute = self.parse_number_value()? as u32;
147
148 if matches!(self.current_token(), Some(Token::Comma)) {
149 self.advance(); second = self.parse_number_value()? as u32;
151 }
152 }
153 }
154
155 self.expect_token(Token::RightParen)?;
156
157 let date_str = format!(
158 "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
159 year, month, day, hour, minute, second
160 );
161 Ok(WhereValue::String(date_str))
162 }
163 }
164 _ => Err(anyhow!("Expected value")),
165 }
166 }
167
168 fn parse_or_expr(&mut self) -> Result<WhereExpr> {
170 let mut left = self.parse_and_expr()?;
171
172 while let Some(Token::Or) = self.current_token() {
173 self.advance(); let right = self.parse_and_expr()?;
175 left = WhereExpr::Or(Box::new(left), Box::new(right));
176 }
177
178 Ok(left)
179 }
180
181 fn parse_and_expr(&mut self) -> Result<WhereExpr> {
183 let mut left = self.parse_not_expr()?;
184
185 while let Some(Token::And) = self.current_token() {
186 self.advance(); let right = self.parse_not_expr()?;
188 left = WhereExpr::And(Box::new(left), Box::new(right));
189 }
190
191 Ok(left)
192 }
193
194 fn parse_not_expr(&mut self) -> Result<WhereExpr> {
196 if let Some(Token::Not) = self.current_token() {
197 self.advance(); let expr = self.parse_comparison_expr()?;
199 Ok(WhereExpr::Not(Box::new(expr)))
200 } else {
201 self.parse_comparison_expr()
202 }
203 }
204
205 fn parse_comparison_expr(&mut self) -> Result<WhereExpr> {
207 if let Some(Token::LeftParen) = self.current_token() {
209 self.advance(); let expr = self.parse_or_expr()?;
211 match self.advance() {
212 Some(Token::RightParen) => Ok(expr),
213 _ => Err(anyhow!("Expected closing parenthesis")),
214 }
215 } else {
216 self.parse_primary_expr()
217 }
218 }
219
220 fn parse_primary_expr(&mut self) -> Result<WhereExpr> {
222 let column = self.expect_identifier()?;
223
224 if let Some(Token::Dot) = self.current_token() {
226 self.advance(); let method = self.expect_identifier()?;
228
229 match method.as_str() {
230 "Contains" => {
231 self.expect_token(Token::LeftParen)?;
232 let value = self.parse_string_value()?;
233 self.expect_token(Token::RightParen)?;
234 if self.case_insensitive {
235 Ok(WhereExpr::ContainsIgnoreCase(column, value))
236 } else {
237 Ok(WhereExpr::Contains(column, value))
238 }
239 }
240 "StartsWith" => {
241 self.expect_token(Token::LeftParen)?;
242 let value = self.parse_string_value()?;
243 self.expect_token(Token::RightParen)?;
244 if self.case_insensitive {
245 Ok(WhereExpr::StartsWithIgnoreCase(column, value))
246 } else {
247 Ok(WhereExpr::StartsWith(column, value))
248 }
249 }
250 "EndsWith" => {
251 self.expect_token(Token::LeftParen)?;
252 let value = self.parse_string_value()?;
253 self.expect_token(Token::RightParen)?;
254 if self.case_insensitive {
255 Ok(WhereExpr::EndsWithIgnoreCase(column, value))
256 } else {
257 Ok(WhereExpr::EndsWith(column, value))
258 }
259 }
260 "Length" => {
261 self.expect_token(Token::LeftParen)?;
262 self.expect_token(Token::RightParen)?;
263
264 let op = self.parse_comparison_op()?;
266 let value = self.parse_number_value()?;
267 Ok(WhereExpr::Length(column, op, value as i64))
268 }
269 "ToLower" => {
270 self.expect_token(Token::LeftParen)?;
271 self.expect_token(Token::RightParen)?;
272
273 let op = self.parse_comparison_op()?;
275 let value = self.parse_string_value()?;
276 Ok(WhereExpr::ToLower(column, op, value))
277 }
278 "ToUpper" => {
279 self.expect_token(Token::LeftParen)?;
280 self.expect_token(Token::RightParen)?;
281
282 let op = self.parse_comparison_op()?;
284 let value = self.parse_string_value()?;
285 Ok(WhereExpr::ToUpper(column, op, value))
286 }
287 "IsNullOrEmpty" => {
288 self.expect_token(Token::LeftParen)?;
289 self.expect_token(Token::RightParen)?;
290 Ok(WhereExpr::IsNullOrEmpty(column))
291 }
292 _ => Err(anyhow!("Unknown method: {}", method)),
293 }
294 } else {
295 match self.current_token() {
297 Some(Token::Equal) => {
298 self.advance();
299 let value = self.parse_value()?;
300 Ok(WhereExpr::Equal(column, value))
301 }
302 Some(Token::NotEqual) => {
303 self.advance();
304 let value = self.parse_value()?;
305 Ok(WhereExpr::NotEqual(column, value))
306 }
307 Some(Token::GreaterThan) => {
308 self.advance();
309 let value = self.parse_value()?;
310 Ok(WhereExpr::GreaterThan(column, value))
311 }
312 Some(Token::GreaterThanOrEqual) => {
313 self.advance();
314 let value = self.parse_value()?;
315 Ok(WhereExpr::GreaterThanOrEqual(column, value))
316 }
317 Some(Token::LessThan) => {
318 self.advance();
319 let value = self.parse_value()?;
320 Ok(WhereExpr::LessThan(column, value))
321 }
322 Some(Token::LessThanOrEqual) => {
323 self.advance();
324 let value = self.parse_value()?;
325 Ok(WhereExpr::LessThanOrEqual(column, value))
326 }
327 Some(Token::Between) => {
328 self.advance();
329 let lower = self.parse_value()?;
330 self.expect_token(Token::And)?;
331 let upper = self.parse_value()?;
332 Ok(WhereExpr::Between(column, lower, upper))
333 }
334 Some(Token::In) => {
335 self.advance();
336 self.expect_token(Token::LeftParen)?;
337 let values = self.parse_value_list()?;
338 self.expect_token(Token::RightParen)?;
339 if self.case_insensitive {
340 Ok(WhereExpr::InIgnoreCase(column, values))
341 } else {
342 Ok(WhereExpr::In(column, values))
343 }
344 }
345 Some(Token::Not) if matches!(self.peek_token(), Some(Token::In)) => {
346 self.advance(); self.advance(); self.expect_token(Token::LeftParen)?;
349 let values = self.parse_value_list()?;
350 self.expect_token(Token::RightParen)?;
351 if self.case_insensitive {
352 Ok(WhereExpr::NotInIgnoreCase(column, values))
353 } else {
354 Ok(WhereExpr::NotIn(column, values))
355 }
356 }
357 Some(Token::Like) => {
358 self.advance();
359 let pattern = self.parse_string_value()?;
360 Ok(WhereExpr::Like(column, pattern))
361 }
362 Some(Token::Is) => {
363 self.advance();
364 match self.current_token() {
365 Some(Token::Null) => {
366 self.advance();
367 Ok(WhereExpr::IsNull(column))
368 }
369 Some(Token::Not) if matches!(self.peek_token(), Some(Token::Null)) => {
370 self.advance(); self.advance(); Ok(WhereExpr::IsNotNull(column))
373 }
374 _ => Err(anyhow!("Expected NULL or NOT NULL after IS")),
375 }
376 }
377 _ => Err(anyhow!("Expected operator after column")),
378 }
379 }
380 }
381
382 fn parse_comparison_op(&mut self) -> Result<ComparisonOp> {
383 match self.advance() {
384 Some(Token::Equal) => Ok(ComparisonOp::Equal),
385 Some(Token::NotEqual) => Ok(ComparisonOp::NotEqual),
386 Some(Token::GreaterThan) => Ok(ComparisonOp::GreaterThan),
387 Some(Token::GreaterThanOrEqual) => Ok(ComparisonOp::GreaterThanOrEqual),
388 Some(Token::LessThan) => Ok(ComparisonOp::LessThan),
389 Some(Token::LessThanOrEqual) => Ok(ComparisonOp::LessThanOrEqual),
390 _ => Err(anyhow!("Expected comparison operator")),
391 }
392 }
393
394 fn parse_string_value(&mut self) -> Result<String> {
395 match self.advance() {
396 Some(Token::StringLiteral(s)) => Ok(s.clone()),
397 Some(Token::QuotedIdentifier(s)) => Ok(s.clone()), _ => Err(anyhow!("Expected string literal")),
399 }
400 }
401
402 fn parse_number_value(&mut self) -> Result<f64> {
403 match self.advance() {
404 Some(Token::NumberLiteral(n)) => {
405 n.parse::<f64>().map_err(|_| anyhow!("Invalid number"))
406 }
407 _ => Err(anyhow!("Expected number literal")),
408 }
409 }
410
411 fn parse_value_list(&mut self) -> Result<Vec<WhereValue>> {
412 let mut values = vec![self.parse_value()?];
413
414 while let Some(Token::Comma) = self.current_token() {
415 self.advance(); values.push(self.parse_value()?);
417 }
418
419 Ok(values)
420 }
421
422 fn expect_token(&mut self, expected: Token) -> Result<()> {
423 match self.advance() {
424 Some(token) if std::mem::discriminant(token) == std::mem::discriminant(&expected) => {
425 Ok(())
426 }
427 Some(token) => Err(anyhow!("Expected {:?}, got {:?}", expected, token)),
428 None => Err(anyhow!("Unexpected end of input")),
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_comparison() {
        let expr = WhereParser::parse("price > 100").unwrap();
        match expr {
            WhereExpr::GreaterThan(col, val) => {
                assert_eq!(col, "price");
                assert_eq!(val, WhereValue::Number(100.0));
            }
            _ => panic!("Wrong expression type"),
        }
    }

    #[test]
    fn test_and_expression() {
        let expr = WhereParser::parse("price > 100 AND category = \"Electronics\"").unwrap();
        match expr {
            WhereExpr::And(left, right) => {
                match left.as_ref() {
                    WhereExpr::GreaterThan(col, val) => {
                        assert_eq!(col, "price");
                        assert_eq!(val, &WhereValue::Number(100.0));
                    }
                    _ => panic!("Wrong left expression"),
                }
                match right.as_ref() {
                    WhereExpr::Equal(col, val) => {
                        assert_eq!(col, "category");
                        assert_eq!(val, &WhereValue::String("Electronics".to_string()));
                    }
                    _ => panic!("Wrong right expression"),
                }
            }
            _ => panic!("Wrong expression type"),
        }
    }

    #[test]
    fn test_between_with_and() {
        let expr = WhereParser::parse(
            "category = \"Electronics\" AND price BETWEEN 100 AND 500 AND quantity > 0",
        )
        .unwrap();
        // Expected shape: ((category = ...) AND (price BETWEEN ...)) AND (quantity > 0)
        match expr {
            WhereExpr::And(left, right) => {
                match left.as_ref() {
                    WhereExpr::And(ll, lr) => {
                        match ll.as_ref() {
                            WhereExpr::Equal(col, val) => {
                                assert_eq!(col, "category");
                                assert_eq!(val, &WhereValue::String("Electronics".to_string()));
                            }
                            _ => panic!("Wrong leftmost expression"),
                        }
                        match lr.as_ref() {
                            WhereExpr::Between(col, lower, upper) => {
                                assert_eq!(col, "price");
                                assert_eq!(lower, &WhereValue::Number(100.0));
                                assert_eq!(upper, &WhereValue::Number(500.0));
                            }
                            _ => panic!("Wrong middle expression"),
                        }
                    }
                    _ => panic!("Wrong left structure"),
                }
                match right.as_ref() {
                    WhereExpr::GreaterThan(col, val) => {
                        assert_eq!(col, "quantity");
                        assert_eq!(val, &WhereValue::Number(0.0));
                    }
                    _ => panic!("Wrong right expression"),
                }
            }
            _ => panic!("Wrong expression type"),
        }
    }

    #[test]
    fn test_parentheses_precedence() {
        // Without parentheses AND binds tighter: a = 1 OR (b = 2 AND c = 3).
        let expr1 = WhereParser::parse("a = 1 OR b = 2 AND c = 3").unwrap();
        match expr1 {
            WhereExpr::Or(left, right) => {
                match left.as_ref() {
                    WhereExpr::Equal(col, val) => {
                        assert_eq!(col, "a");
                        assert_eq!(val, &WhereValue::Number(1.0));
                    }
                    _ => panic!("Wrong left expression"),
                }
                match right.as_ref() {
                    WhereExpr::And(l, r) => {
                        match l.as_ref() {
                            WhereExpr::Equal(col, val) => {
                                assert_eq!(col, "b");
                                assert_eq!(val, &WhereValue::Number(2.0));
                            }
                            _ => panic!("Wrong AND left"),
                        }
                        match r.as_ref() {
                            WhereExpr::Equal(col, val) => {
                                assert_eq!(col, "c");
                                assert_eq!(val, &WhereValue::Number(3.0));
                            }
                            _ => panic!("Wrong AND right"),
                        }
                    }
                    _ => panic!("Wrong right expression"),
                }
            }
            _ => panic!("Wrong top-level expression"),
        }

        // Parentheses override precedence: (a = 1 OR b = 2) AND c = 3.
        let expr2 = WhereParser::parse("(a = 1 OR b = 2) AND c = 3").unwrap();
        match expr2 {
            WhereExpr::And(left, right) => {
                match left.as_ref() {
                    WhereExpr::Or(l, r) => {
                        match l.as_ref() {
                            WhereExpr::Equal(col, val) => {
                                assert_eq!(col, "a");
                                assert_eq!(val, &WhereValue::Number(1.0));
                            }
                            _ => panic!("Wrong OR left"),
                        }
                        match r.as_ref() {
                            WhereExpr::Equal(col, val) => {
                                assert_eq!(col, "b");
                                assert_eq!(val, &WhereValue::Number(2.0));
                            }
                            _ => panic!("Wrong OR right"),
                        }
                    }
                    _ => panic!("Wrong left expression"),
                }
                match right.as_ref() {
                    WhereExpr::Equal(col, val) => {
                        assert_eq!(col, "c");
                        assert_eq!(val, &WhereValue::Number(3.0));
                    }
                    _ => panic!("Wrong right expression"),
                }
            }
            _ => panic!("Wrong top-level expression"),
        }
    }

    #[test]
    fn test_case_conversion_methods() {
        let expr = WhereParser::parse("executionSide.ToLower() = \"buy\"").unwrap();
        match expr {
            WhereExpr::ToLower(col, op, val) => {
                assert_eq!(col, "executionSide");
                assert_eq!(op, ComparisonOp::Equal);
                assert_eq!(val, "buy");
            }
            _ => panic!("Wrong expression type for ToLower"),
        }

        let expr = WhereParser::parse("status.ToUpper() != \"PENDING\"").unwrap();
        match expr {
            WhereExpr::ToUpper(col, op, val) => {
                assert_eq!(col, "status");
                assert_eq!(op, ComparisonOp::NotEqual);
                assert_eq!(val, "PENDING");
            }
            _ => panic!("Wrong expression type for ToUpper"),
        }
    }

    #[test]
    fn test_is_null_or_empty() {
        let expr = WhereParser::parse("name.IsNullOrEmpty()").unwrap();
        match expr {
            WhereExpr::IsNullOrEmpty(col) => {
                assert_eq!(col, "name");
            }
            _ => panic!("Wrong expression type for IsNullOrEmpty"),
        }

        let expr2 = WhereParser::parse("\"Customer Name\".IsNullOrEmpty()").unwrap();
        match expr2 {
            WhereExpr::IsNullOrEmpty(col) => {
                assert_eq!(col, "Customer Name");
            }
            _ => panic!("Wrong expression type for IsNullOrEmpty with quoted identifier"),
        }
    }

    #[test]
    fn test_is_null_or_empty_in_complex_expression() {
        let expr = WhereParser::parse("name.IsNullOrEmpty() OR age > 18").unwrap();
        match expr {
            WhereExpr::Or(left, right) => {
                match *left {
                    WhereExpr::IsNullOrEmpty(col) => {
                        assert_eq!(col, "name");
                    }
                    _ => panic!("Left side should be IsNullOrEmpty"),
                }
                match *right {
                    WhereExpr::GreaterThan(col, val) => {
                        assert_eq!(col, "age");
                        assert_eq!(val, WhereValue::Number(18.0));
                    }
                    _ => panic!("Right side should be GreaterThan"),
                }
            }
            _ => panic!("Should be an OR expression"),
        }
    }

    #[test]
    fn test_numeric_column_names() {
        let columns = vec![
            "Borough".to_string(),
            "202202".to_string(),
            "202203".to_string(),
            "202204".to_string(),
            "202205".to_string(),
        ];

        let expr = WhereParser::parse_with_columns("202204 > 2.0", columns.clone()).unwrap();
        match expr {
            WhereExpr::GreaterThan(col, val) => {
                assert_eq!(col, "202204");
                assert_eq!(val, WhereValue::Number(2.0));
            }
            _ => panic!("Expected GreaterThan with numeric column name"),
        }

        let expr2 = WhereParser::parse_with_columns(
            "Borough = \"London\" AND 202204 > 1.0",
            columns.clone(),
        )
        .unwrap();
        match expr2 {
            WhereExpr::And(left, right) => {
                match &*left {
                    WhereExpr::Equal(col, val) => {
                        assert_eq!(col, "Borough");
                        assert_eq!(val, &WhereValue::String("London".to_string()));
                    }
                    _ => panic!("Expected Equal on left"),
                }
                match &*right {
                    WhereExpr::GreaterThan(col, val) => {
                        assert_eq!(col, "202204");
                        assert_eq!(val, &WhereValue::Number(1.0));
                    }
                    _ => panic!("Expected GreaterThan on right"),
                }
            }
            _ => panic!("Expected And expression"),
        }

        let limited_columns = vec!["price".to_string(), "quantity".to_string()];
        let expr3 = WhereParser::parse_with_columns("price > 100", limited_columns).unwrap();
        match expr3 {
            WhereExpr::GreaterThan(col, val) => {
                assert_eq!(col, "price");
                assert_eq!(val, WhereValue::Number(100.0));
            }
            _ => panic!("Expected GreaterThan"),
        }
    }

    #[test]
    fn test_numeric_literal_rejected_when_not_a_known_column() {
        // With no column list, a leading number cannot be treated as a column.
        assert!(WhereParser::parse("202204 > 2.0").is_err());
    }

    #[test]
    fn test_not_prefix() {
        let expr = WhereParser::parse("NOT price > 100").unwrap();
        match expr {
            WhereExpr::Not(inner) => match *inner {
                WhereExpr::GreaterThan(col, val) => {
                    assert_eq!(col, "price");
                    assert_eq!(val, WhereValue::Number(100.0));
                }
                _ => panic!("Expected GreaterThan inside NOT"),
            },
            _ => panic!("Expected NOT expression"),
        }
    }

    #[test]
    fn test_in_and_not_in() {
        let expr = WhereParser::parse("qty IN (1, 2)").unwrap();
        match expr {
            WhereExpr::In(col, values) => {
                assert_eq!(col, "qty");
                assert_eq!(
                    values,
                    vec![WhereValue::Number(1.0), WhereValue::Number(2.0)]
                );
            }
            _ => panic!("Expected IN expression"),
        }

        let expr = WhereParser::parse("qty NOT IN (3)").unwrap();
        match expr {
            WhereExpr::NotIn(col, values) => {
                assert_eq!(col, "qty");
                assert_eq!(values, vec![WhereValue::Number(3.0)]);
            }
            _ => panic!("Expected NOT IN expression"),
        }
    }

    #[test]
    fn test_is_null_and_is_not_null() {
        match WhereParser::parse("email IS NULL").unwrap() {
            WhereExpr::IsNull(col) => assert_eq!(col, "email"),
            _ => panic!("Expected IS NULL expression"),
        }

        match WhereParser::parse("email IS NOT NULL").unwrap() {
            WhereExpr::IsNotNull(col) => assert_eq!(col, "email"),
            _ => panic!("Expected IS NOT NULL expression"),
        }
    }

    #[test]
    fn test_case_insensitive_contains() {
        let expr =
            WhereParser::parse_with_options("name.Contains(\"abc\")", vec![], true).unwrap();
        match expr {
            WhereExpr::ContainsIgnoreCase(col, val) => {
                assert_eq!(col, "name");
                assert_eq!(val, "abc");
            }
            _ => panic!("Expected ContainsIgnoreCase"),
        }

        // The default options keep string methods case-sensitive.
        let expr = WhereParser::parse("name.Contains(\"abc\")").unwrap();
        match expr {
            WhereExpr::Contains(col, val) => {
                assert_eq!(col, "name");
                assert_eq!(val, "abc");
            }
            _ => panic!("Expected Contains"),
        }
    }

    #[test]
    fn test_unbalanced_parenthesis_is_error() {
        assert!(WhereParser::parse("(a = 1").is_err());
    }
}