1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::recursive_parser::SqlExpression;
3use anyhow::{anyhow, Result};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
5use tracing::debug;
6
/// Evaluates parsed SQL expressions (column references, literals, binary
/// operators, function and method calls, CASE expressions) against a single
/// row of a borrowed [`DataTable`].
pub struct ArithmeticEvaluator<'a> {
    /// Table that column lookups and row accesses are resolved against.
    table: &'a DataTable,
}
12
13impl<'a> ArithmeticEvaluator<'a> {
14 pub fn new(table: &'a DataTable) -> Self {
15 Self { table }
16 }
17
18 fn find_similar_column(&self, name: &str) -> Option<String> {
20 let columns = self.table.column_names();
21 let mut best_match: Option<(String, usize)> = None;
22
23 for col in columns {
24 let distance = self.edit_distance(&col.to_lowercase(), &name.to_lowercase());
25 let max_distance = if name.len() > 10 { 3 } else { 2 };
28 if distance <= max_distance {
29 match &best_match {
30 None => best_match = Some((col, distance)),
31 Some((_, best_dist)) if distance < *best_dist => {
32 best_match = Some((col, distance));
33 }
34 _ => {}
35 }
36 }
37 }
38
39 best_match.map(|(name, _)| name)
40 }
41
42 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
44 let len1 = s1.len();
45 let len2 = s2.len();
46 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
47
48 for i in 0..=len1 {
49 matrix[i][0] = i;
50 }
51 for j in 0..=len2 {
52 matrix[0][j] = j;
53 }
54
55 for (i, c1) in s1.chars().enumerate() {
56 for (j, c2) in s2.chars().enumerate() {
57 let cost = if c1 == c2 { 0 } else { 1 };
58 matrix[i + 1][j + 1] = std::cmp::min(
59 matrix[i][j + 1] + 1, std::cmp::min(
61 matrix[i + 1][j] + 1, matrix[i][j] + cost, ),
64 );
65 }
66 }
67
68 matrix[len1][len2]
69 }
70
71 pub fn evaluate(&self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
73 debug!(
74 "ArithmeticEvaluator: evaluating {:?} for row {}",
75 expr, row_index
76 );
77
78 match expr {
79 SqlExpression::Column(column_name) => self.evaluate_column(column_name, row_index),
80 SqlExpression::StringLiteral(s) => Ok(DataValue::String(s.clone())),
81 SqlExpression::NumberLiteral(n) => self.evaluate_number_literal(n),
82 SqlExpression::BinaryOp { left, op, right } => {
83 self.evaluate_binary_op(left, op, right, row_index)
84 }
85 SqlExpression::FunctionCall { name, args } => {
86 self.evaluate_function(name, args, row_index)
87 }
88 SqlExpression::MethodCall {
89 object,
90 method,
91 args,
92 } => self.evaluate_method_call(object, method, args, row_index),
93 SqlExpression::ChainedMethodCall { base, method, args } => {
94 let base_value = self.evaluate(base, row_index)?;
96 self.evaluate_method_on_value(&base_value, method, args, row_index)
97 }
98 SqlExpression::CaseExpression {
99 when_branches,
100 else_branch,
101 } => self.evaluate_case_expression(when_branches, else_branch, row_index),
102 _ => Err(anyhow!(
103 "Unsupported expression type for arithmetic evaluation: {:?}",
104 expr
105 )),
106 }
107 }
108
109 fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
111 let col_index = self.table.get_column_index(column_name).ok_or_else(|| {
112 let suggestion = self.find_similar_column(column_name);
113 match suggestion {
114 Some(similar) => anyhow!(
115 "Column '{}' not found. Did you mean '{}'?",
116 column_name,
117 similar
118 ),
119 None => anyhow!("Column '{}' not found", column_name),
120 }
121 })?;
122
123 if row_index >= self.table.row_count() {
124 return Err(anyhow!("Row index {} out of bounds", row_index));
125 }
126
127 let row = self
128 .table
129 .get_row(row_index)
130 .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
131
132 let value = row
133 .get(col_index)
134 .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
135
136 Ok(value.clone())
137 }
138
139 fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
141 if let Ok(int_val) = number_str.parse::<i64>() {
143 return Ok(DataValue::Integer(int_val));
144 }
145
146 if let Ok(float_val) = number_str.parse::<f64>() {
148 return Ok(DataValue::Float(float_val));
149 }
150
151 Err(anyhow!("Invalid number literal: {}", number_str))
152 }
153
154 fn evaluate_binary_op(
156 &self,
157 left: &SqlExpression,
158 op: &str,
159 right: &SqlExpression,
160 row_index: usize,
161 ) -> Result<DataValue> {
162 let left_val = self.evaluate(left, row_index)?;
163 let right_val = self.evaluate(right, row_index)?;
164
165 debug!(
166 "ArithmeticEvaluator: {} {} {}",
167 self.format_value(&left_val),
168 op,
169 self.format_value(&right_val)
170 );
171
172 match op {
173 "+" => self.add_values(&left_val, &right_val),
174 "-" => self.subtract_values(&left_val, &right_val),
175 "*" => self.multiply_values(&left_val, &right_val),
176 "/" => self.divide_values(&left_val, &right_val),
177 ">" => self.compare_values(&left_val, &right_val, |a, b| a > b),
179 "<" => self.compare_values(&left_val, &right_val, |a, b| a < b),
180 ">=" => self.compare_values(&left_val, &right_val, |a, b| a >= b),
181 "<=" => self.compare_values(&left_val, &right_val, |a, b| a <= b),
182 "=" => self.compare_values(&left_val, &right_val, |a, b| a == b),
183 "!=" | "<>" => self.compare_values(&left_val, &right_val, |a, b| a != b),
184 _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
185 }
186 }
187
188 fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
190 match (left, right) {
191 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a + b)),
192 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
193 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
194 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
195 _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
196 }
197 }
198
199 fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
201 match (left, right) {
202 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a - b)),
203 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
204 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
205 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
206 _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
207 }
208 }
209
210 fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
212 match (left, right) {
213 (DataValue::Integer(a), DataValue::Integer(b)) => Ok(DataValue::Integer(a * b)),
214 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
215 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
216 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
217 _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
218 }
219 }
220
221 fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
223 let is_zero = match right {
225 DataValue::Integer(0) => true,
226 DataValue::Float(f) if f.abs() < f64::EPSILON => true,
227 _ => false,
228 };
229
230 if is_zero {
231 return Err(anyhow!("Division by zero"));
232 }
233
234 match (left, right) {
235 (DataValue::Integer(a), DataValue::Integer(b)) => {
236 if a % b == 0 {
238 Ok(DataValue::Integer(a / b))
239 } else {
240 Ok(DataValue::Float(*a as f64 / *b as f64))
241 }
242 }
243 (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
244 (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
245 (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
246 _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
247 }
248 }
249
250 fn format_value(&self, value: &DataValue) -> String {
252 match value {
253 DataValue::Integer(i) => i.to_string(),
254 DataValue::Float(f) => f.to_string(),
255 DataValue::String(s) => format!("'{}'", s),
256 _ => format!("{:?}", value),
257 }
258 }
259
260 fn compare_values<F>(&self, left: &DataValue, right: &DataValue, op: F) -> Result<DataValue>
262 where
263 F: Fn(f64, f64) -> bool,
264 {
265 debug!(
266 "ArithmeticEvaluator: comparing values {:?} and {:?}",
267 left, right
268 );
269
270 let result = match (left, right) {
271 (DataValue::Integer(a), DataValue::Integer(b)) => op(*a as f64, *b as f64),
273 (DataValue::Integer(a), DataValue::Float(b)) => op(*a as f64, *b),
274 (DataValue::Float(a), DataValue::Integer(b)) => op(*a, *b as f64),
275 (DataValue::Float(a), DataValue::Float(b)) => op(*a, *b),
276
277 (DataValue::String(a), DataValue::String(b)) => {
279 let a_num = a.parse::<f64>();
280 let b_num = b.parse::<f64>();
281 match (a_num, b_num) {
282 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
285 }
286 (DataValue::InternedString(a), DataValue::InternedString(b)) => {
287 let a_num = a.parse::<f64>();
288 let b_num = b.parse::<f64>();
289 match (a_num, b_num) {
290 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
293 }
294 (DataValue::String(a), DataValue::InternedString(b)) => {
295 let a_num = a.parse::<f64>();
296 let b_num = b.parse::<f64>();
297 match (a_num, b_num) {
298 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
301 }
302 (DataValue::InternedString(a), DataValue::String(b)) => {
303 let a_num = a.parse::<f64>();
304 let b_num = b.parse::<f64>();
305 match (a_num, b_num) {
306 (Ok(a_val), Ok(b_val)) => op(a_val, b_val), _ => op(a.len() as f64, b.len() as f64), }
309 }
310
311 (DataValue::String(a), DataValue::Integer(b)) => {
313 match a.parse::<f64>() {
314 Ok(a_val) => op(a_val, *b as f64),
315 Err(_) => false, }
317 }
318 (DataValue::Integer(a), DataValue::String(b)) => {
319 match b.parse::<f64>() {
320 Ok(b_val) => op(*a as f64, b_val),
321 Err(_) => false, }
323 }
324 (DataValue::String(a), DataValue::Float(b)) => match a.parse::<f64>() {
325 Ok(a_val) => op(a_val, *b),
326 Err(_) => false,
327 },
328 (DataValue::Float(a), DataValue::String(b)) => match b.parse::<f64>() {
329 Ok(b_val) => op(*a, b_val),
330 Err(_) => false,
331 },
332
333 (DataValue::Null, _) | (_, DataValue::Null) => false,
335
336 (DataValue::Boolean(a), DataValue::Boolean(b)) => {
338 op(if *a { 1.0 } else { 0.0 }, if *b { 1.0 } else { 0.0 })
339 }
340
341 _ => {
342 debug!(
343 "ArithmeticEvaluator: unsupported comparison between {:?} and {:?}",
344 left, right
345 );
346 false
347 }
348 };
349
350 debug!("ArithmeticEvaluator: comparison result: {}", result);
351 Ok(DataValue::Boolean(result))
352 }
353
354 fn evaluate_function(
356 &self,
357 name: &str,
358 args: &[SqlExpression],
359 row_index: usize,
360 ) -> Result<DataValue> {
361 match name {
362 "ROUND" => {
363 if args.is_empty() || args.len() > 2 {
364 return Err(anyhow!("ROUND requires 1 or 2 arguments"));
365 }
366
367 let value = self.evaluate(&args[0], row_index)?;
369
370 let decimals = if args.len() == 2 {
372 match self.evaluate(&args[1], row_index)? {
373 DataValue::Integer(n) => n as i32,
374 DataValue::Float(f) => f as i32,
375 _ => return Err(anyhow!("ROUND precision must be a number")),
376 }
377 } else {
378 0
379 };
380
381 match value {
383 DataValue::Integer(n) => Ok(DataValue::Integer(n)), DataValue::Float(f) => {
385 if decimals >= 0 {
386 let multiplier = 10_f64.powi(decimals);
387 let rounded = (f * multiplier).round() / multiplier;
388 if decimals == 0 {
389 Ok(DataValue::Integer(rounded as i64))
391 } else {
392 Ok(DataValue::Float(rounded))
393 }
394 } else {
395 let divisor = 10_f64.powi(-decimals);
397 let rounded = (f / divisor).round() * divisor;
398 Ok(DataValue::Float(rounded))
399 }
400 }
401 _ => Err(anyhow!("ROUND can only be applied to numeric values")),
402 }
403 }
404 "ABS" => {
405 if args.len() != 1 {
406 return Err(anyhow!("ABS requires exactly 1 argument"));
407 }
408
409 let value = self.evaluate(&args[0], row_index)?;
410 match value {
411 DataValue::Integer(n) => Ok(DataValue::Integer(n.abs())),
412 DataValue::Float(f) => Ok(DataValue::Float(f.abs())),
413 _ => Err(anyhow!("ABS can only be applied to numeric values")),
414 }
415 }
416 "FLOOR" => {
417 if args.len() != 1 {
418 return Err(anyhow!("FLOOR requires exactly 1 argument"));
419 }
420
421 let value = self.evaluate(&args[0], row_index)?;
422 match value {
423 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
424 DataValue::Float(f) => Ok(DataValue::Integer(f.floor() as i64)),
425 _ => Err(anyhow!("FLOOR can only be applied to numeric values")),
426 }
427 }
428 "CEILING" | "CEIL" => {
429 if args.len() != 1 {
430 return Err(anyhow!("CEILING requires exactly 1 argument"));
431 }
432
433 let value = self.evaluate(&args[0], row_index)?;
434 match value {
435 DataValue::Integer(n) => Ok(DataValue::Integer(n)),
436 DataValue::Float(f) => Ok(DataValue::Integer(f.ceil() as i64)),
437 _ => Err(anyhow!("CEILING can only be applied to numeric values")),
438 }
439 }
440 "MOD" => {
441 if args.len() != 2 {
442 return Err(anyhow!("MOD requires exactly 2 arguments"));
443 }
444
445 let dividend = self.evaluate(&args[0], row_index)?;
446 let divisor = self.evaluate(&args[1], row_index)?;
447
448 match (÷nd, &divisor) {
449 (DataValue::Integer(n), DataValue::Integer(d)) => {
450 if *d == 0 {
451 return Err(anyhow!("Division by zero in MOD"));
452 }
453 Ok(DataValue::Integer(n % d))
454 }
455 _ => {
456 let n = match dividend {
458 DataValue::Integer(i) => i as f64,
459 DataValue::Float(f) => f,
460 _ => return Err(anyhow!("MOD requires numeric arguments")),
461 };
462 let d = match divisor {
463 DataValue::Integer(i) => i as f64,
464 DataValue::Float(f) => f,
465 _ => return Err(anyhow!("MOD requires numeric arguments")),
466 };
467 if d == 0.0 {
468 return Err(anyhow!("Division by zero in MOD"));
469 }
470 Ok(DataValue::Float(n % d))
471 }
472 }
473 }
474 "QUOTIENT" => {
475 if args.len() != 2 {
476 return Err(anyhow!("QUOTIENT requires exactly 2 arguments"));
477 }
478
479 let numerator = self.evaluate(&args[0], row_index)?;
480 let denominator = self.evaluate(&args[1], row_index)?;
481
482 match (&numerator, &denominator) {
483 (DataValue::Integer(n), DataValue::Integer(d)) => {
484 if *d == 0 {
485 return Err(anyhow!("Division by zero in QUOTIENT"));
486 }
487 Ok(DataValue::Integer(n / d))
488 }
489 _ => {
490 let n = match numerator {
492 DataValue::Integer(i) => i as f64,
493 DataValue::Float(f) => f,
494 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
495 };
496 let d = match denominator {
497 DataValue::Integer(i) => i as f64,
498 DataValue::Float(f) => f,
499 _ => return Err(anyhow!("QUOTIENT requires numeric arguments")),
500 };
501 if d == 0.0 {
502 return Err(anyhow!("Division by zero in QUOTIENT"));
503 }
504 Ok(DataValue::Integer((n / d).trunc() as i64))
505 }
506 }
507 }
508 "POWER" | "POW" => {
509 if args.len() != 2 {
510 return Err(anyhow!("POWER requires exactly 2 arguments"));
511 }
512
513 let base = self.evaluate(&args[0], row_index)?;
514 let exponent = self.evaluate(&args[1], row_index)?;
515
516 match (&base, &exponent) {
517 (DataValue::Integer(b), DataValue::Integer(e)) => {
518 if *e >= 0 && *e <= i32::MAX as i64 {
519 Ok(DataValue::Float((*b as f64).powi(*e as i32)))
520 } else {
521 Ok(DataValue::Float((*b as f64).powf(*e as f64)))
522 }
523 }
524 _ => {
525 let b = match base {
527 DataValue::Integer(i) => i as f64,
528 DataValue::Float(f) => f,
529 _ => return Err(anyhow!("POWER requires numeric arguments")),
530 };
531 let e = match exponent {
532 DataValue::Integer(i) => i as f64,
533 DataValue::Float(f) => f,
534 _ => return Err(anyhow!("POWER requires numeric arguments")),
535 };
536 Ok(DataValue::Float(b.powf(e)))
537 }
538 }
539 }
540 "SQRT" => {
541 if args.len() != 1 {
542 return Err(anyhow!("SQRT requires exactly 1 argument"));
543 }
544
545 let value = self.evaluate(&args[0], row_index)?;
546 match value {
547 DataValue::Integer(n) => {
548 if n < 0 {
549 return Err(anyhow!("SQRT of negative number"));
550 }
551 Ok(DataValue::Float((n as f64).sqrt()))
552 }
553 DataValue::Float(f) => {
554 if f < 0.0 {
555 return Err(anyhow!("SQRT of negative number"));
556 }
557 Ok(DataValue::Float(f.sqrt()))
558 }
559 _ => Err(anyhow!("SQRT can only be applied to numeric values")),
560 }
561 }
562 "EXP" => {
563 if args.len() != 1 {
564 return Err(anyhow!("EXP requires exactly 1 argument"));
565 }
566
567 let value = self.evaluate(&args[0], row_index)?;
568 match value {
569 DataValue::Integer(n) => Ok(DataValue::Float((n as f64).exp())),
570 DataValue::Float(f) => Ok(DataValue::Float(f.exp())),
571 _ => Err(anyhow!("EXP can only be applied to numeric values")),
572 }
573 }
574 "LN" => {
575 if args.len() != 1 {
576 return Err(anyhow!("LN requires exactly 1 argument"));
577 }
578
579 let value = self.evaluate(&args[0], row_index)?;
580 match value {
581 DataValue::Integer(n) => {
582 if n <= 0 {
583 return Err(anyhow!("LN of non-positive number"));
584 }
585 Ok(DataValue::Float((n as f64).ln()))
586 }
587 DataValue::Float(f) => {
588 if f <= 0.0 {
589 return Err(anyhow!("LN of non-positive number"));
590 }
591 Ok(DataValue::Float(f.ln()))
592 }
593 _ => Err(anyhow!("LN can only be applied to numeric values")),
594 }
595 }
596 "LOG" | "LOG10" => {
597 if name == "LOG" && args.len() == 2 {
598 let value = self.evaluate(&args[0], row_index)?;
600 let base = self.evaluate(&args[1], row_index)?;
601
602 let n = match value {
603 DataValue::Integer(i) => i as f64,
604 DataValue::Float(f) => f,
605 _ => return Err(anyhow!("LOG requires numeric arguments")),
606 };
607 let b = match base {
608 DataValue::Integer(i) => i as f64,
609 DataValue::Float(f) => f,
610 _ => return Err(anyhow!("LOG requires numeric arguments")),
611 };
612
613 if n <= 0.0 {
614 return Err(anyhow!("LOG of non-positive number"));
615 }
616 if b <= 0.0 || b == 1.0 {
617 return Err(anyhow!("Invalid LOG base"));
618 }
619 Ok(DataValue::Float(n.log(b)))
620 } else if (name == "LOG" && args.len() == 1) || name == "LOG10" {
621 if args.len() != 1 {
623 return Err(anyhow!("{} requires exactly 1 argument", name));
624 }
625
626 let value = self.evaluate(&args[0], row_index)?;
627 match value {
628 DataValue::Integer(n) => {
629 if n <= 0 {
630 return Err(anyhow!("LOG10 of non-positive number"));
631 }
632 Ok(DataValue::Float((n as f64).log10()))
633 }
634 DataValue::Float(f) => {
635 if f <= 0.0 {
636 return Err(anyhow!("LOG10 of non-positive number"));
637 }
638 Ok(DataValue::Float(f.log10()))
639 }
640 _ => Err(anyhow!("LOG10 can only be applied to numeric values")),
641 }
642 } else {
643 Err(anyhow!("LOG requires 1 or 2 arguments"))
644 }
645 }
646 "PI" => {
647 if !args.is_empty() {
648 return Err(anyhow!("PI takes no arguments"));
649 }
650 Ok(DataValue::Float(std::f64::consts::PI))
651 }
652 "DATEDIFF" => {
653 if args.len() != 3 {
654 return Err(anyhow!(
655 "DATEDIFF requires exactly 3 arguments: unit, date1, date2"
656 ));
657 }
658
659 let unit = match self.evaluate(&args[0], row_index)? {
661 DataValue::String(s) => s.to_lowercase(),
662 DataValue::InternedString(s) => s.to_lowercase(),
663 _ => return Err(anyhow!("DATEDIFF unit must be a string")),
664 };
665
666 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
668 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
669 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
673 return Ok(Utc.from_utc_datetime(&dt));
674 }
675 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
676 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
677 }
678
679 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
681 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
682 }
683 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
684 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
685 }
686
687 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
689 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
690 }
691 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
692 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
693 }
694
695 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
697 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
698 }
699
700 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
702 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
703 }
704 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
705 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
706 }
707
708 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
710 return Ok(Utc.from_utc_datetime(&dt));
711 }
712 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
713 return Ok(Utc.from_utc_datetime(&dt));
714 }
715
716 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
718 return Ok(dt);
719 }
720
721 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
722 };
723
724 match value {
725 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
726 DataValue::InternedString(s) => parse_string(s.as_str()),
727 _ => Err(anyhow!("DATEDIFF requires date/datetime values")),
728 }
729 };
730
731 let date1 = parse_datetime(self.evaluate(&args[1], row_index)?)?;
733 let date2 = parse_datetime(self.evaluate(&args[2], row_index)?)?;
734
735 let diff = match unit.as_str() {
737 "day" | "days" => {
738 let duration = date2.signed_duration_since(date1);
739 duration.num_days()
740 }
741 "month" | "months" => {
742 let duration = date2.signed_duration_since(date1);
744 duration.num_days() / 30
745 }
746 "year" | "years" => {
747 let duration = date2.signed_duration_since(date1);
749 duration.num_days() / 365
750 }
751 "hour" | "hours" => {
752 let duration = date2.signed_duration_since(date1);
753 duration.num_hours()
754 }
755 "minute" | "minutes" => {
756 let duration = date2.signed_duration_since(date1);
757 duration.num_minutes()
758 }
759 "second" | "seconds" => {
760 let duration = date2.signed_duration_since(date1);
761 duration.num_seconds()
762 }
763 _ => {
764 return Err(anyhow!(
765 "Unknown DATEDIFF unit: {}. Use: day, month, year, hour, minute, second",
766 unit
767 ))
768 }
769 };
770
771 Ok(DataValue::Integer(diff))
772 }
773 "NOW" => {
774 if !args.is_empty() {
775 return Err(anyhow!("NOW takes no arguments"));
776 }
777 let now = Utc::now();
778 Ok(DataValue::DateTime(
779 now.format("%Y-%m-%d %H:%M:%S").to_string(),
780 ))
781 }
782 "TODAY" => {
783 if !args.is_empty() {
784 return Err(anyhow!("TODAY takes no arguments"));
785 }
786 let today = Utc::now().date_naive();
787 Ok(DataValue::String(today.format("%Y-%m-%d").to_string()))
788 }
789 "DATEADD" => {
790 if args.len() != 3 {
791 return Err(anyhow!(
792 "DATEADD requires exactly 3 arguments: unit, number, date"
793 ));
794 }
795
796 let unit = match self.evaluate(&args[0], row_index)? {
798 DataValue::String(s) => s.to_lowercase(),
799 DataValue::InternedString(s) => s.to_lowercase(),
800 _ => return Err(anyhow!("DATEADD unit must be a string")),
801 };
802
803 let amount = match self.evaluate(&args[1], row_index)? {
805 DataValue::Integer(i) => i,
806 DataValue::Float(f) => f as i64,
807 _ => return Err(anyhow!("DATEADD amount must be a number")),
808 };
809
810 let base_date_value = self.evaluate(&args[2], row_index)?;
812
813 let parse_datetime = |value: DataValue| -> Result<DateTime<Utc>> {
815 let parse_string = |s: &str| -> Result<DateTime<Utc>> {
816 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
820 return Ok(Utc.from_utc_datetime(&dt));
821 }
822 if let Ok(dt) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
823 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
824 }
825
826 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m/%d/%Y") {
828 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
829 }
830 if let Ok(dt) = NaiveDate::parse_from_str(s, "%m-%d-%Y") {
831 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
832 }
833
834 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d/%m/%Y") {
836 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
837 }
838 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%m-%Y") {
839 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
840 }
841
842 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d-%b-%Y") {
844 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
845 }
846
847 if let Ok(dt) = NaiveDate::parse_from_str(s, "%B %d, %Y") {
849 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
850 }
851 if let Ok(dt) = NaiveDate::parse_from_str(s, "%d %B %Y") {
852 return Ok(Utc.from_utc_datetime(&dt.and_hms_opt(0, 0, 0).unwrap()));
853 }
854
855 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%m/%d/%Y %H:%M:%S") {
857 return Ok(Utc.from_utc_datetime(&dt));
858 }
859 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%d/%m/%Y %H:%M:%S") {
860 return Ok(Utc.from_utc_datetime(&dt));
861 }
862
863 if let Ok(dt) = s.parse::<DateTime<Utc>>() {
865 return Ok(dt);
866 }
867
868 Err(anyhow!("Could not parse date: {}. Supported formats: YYYY-MM-DD, MM/DD/YYYY, DD/MM/YYYY, DD-MMM-YYYY", s))
869 };
870
871 match value {
872 DataValue::String(s) | DataValue::DateTime(s) => parse_string(&s),
873 DataValue::InternedString(s) => parse_string(s.as_str()),
874 _ => Err(anyhow!("DATEADD requires date/datetime values")),
875 }
876 };
877
878 let base_date = parse_datetime(base_date_value)?;
880
881 let result_date = match unit.as_str() {
883 "day" | "days" => base_date + chrono::Duration::days(amount),
884 "month" | "months" => {
885 let mut year = base_date.year();
887 let mut month = base_date.month() as i32;
888 let day = base_date.day();
889
890 month += amount as i32;
891
892 while month > 12 {
894 month -= 12;
895 year += 1;
896 }
897 while month < 1 {
898 month += 12;
899 year -= 1;
900 }
901
902 let target_date = NaiveDate::from_ymd_opt(year, month as u32, day)
904 .unwrap_or_else(|| {
905 for test_day in (1..=day).rev() {
908 if let Some(date) =
909 NaiveDate::from_ymd_opt(year, month as u32, test_day)
910 {
911 return date;
912 }
913 }
914 NaiveDate::from_ymd_opt(year, month as u32, 28).unwrap()
916 });
917
918 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
919 }
920 "year" | "years" => {
921 let new_year = base_date.year() + amount as i32;
922 let target_date =
923 NaiveDate::from_ymd_opt(new_year, base_date.month(), base_date.day())
924 .unwrap_or_else(|| {
925 NaiveDate::from_ymd_opt(new_year, base_date.month(), 28)
927 .unwrap()
928 });
929 Utc.from_utc_datetime(&target_date.and_time(base_date.time()))
930 }
931 "hour" | "hours" => base_date + chrono::Duration::hours(amount),
932 "minute" | "minutes" => base_date + chrono::Duration::minutes(amount),
933 "second" | "seconds" => base_date + chrono::Duration::seconds(amount),
934 _ => {
935 return Err(anyhow!(
936 "Unknown DATEADD unit: {}. Use: day, month, year, hour, minute, second",
937 unit
938 ))
939 }
940 };
941
942 Ok(DataValue::DateTime(
944 result_date.format("%Y-%m-%d %H:%M:%S").to_string(),
945 ))
946 }
947 "TEXTJOIN" => {
948 if args.len() < 3 {
949 return Err(anyhow!("TEXTJOIN requires at least 3 arguments: delimiter, ignore_empty, text1, [text2, ...]"));
950 }
951
952 let delimiter = match self.evaluate(&args[0], row_index)? {
954 DataValue::String(s) => s,
955 DataValue::InternedString(s) => s.to_string(),
956 DataValue::Integer(n) => n.to_string(),
957 DataValue::Float(f) => f.to_string(),
958 DataValue::Boolean(b) => b.to_string(),
959 DataValue::Null => String::new(),
960 _ => String::new(),
961 };
962
963 let ignore_empty = match self.evaluate(&args[1], row_index)? {
965 DataValue::Integer(n) => n != 0,
966 DataValue::Float(f) => f != 0.0,
967 DataValue::Boolean(b) => b,
968 DataValue::String(s) => {
969 !s.is_empty() && s != "0" && s.to_lowercase() != "false"
970 }
971 DataValue::InternedString(s) => {
972 !s.is_empty() && s.as_str() != "0" && s.to_lowercase() != "false"
973 }
974 DataValue::Null => false,
975 _ => true,
976 };
977
978 let mut values = Vec::new();
980 for i in 2..args.len() {
981 let value = self.evaluate(&args[i], row_index)?;
982 let string_value = match value {
983 DataValue::String(s) => Some(s),
984 DataValue::InternedString(s) => Some(s.to_string()),
985 DataValue::Integer(n) => Some(n.to_string()),
986 DataValue::Float(f) => Some(f.to_string()),
987 DataValue::Boolean(b) => Some(b.to_string()),
988 DataValue::DateTime(dt) => Some(dt),
989 DataValue::Null => {
990 if ignore_empty {
991 None
992 } else {
993 Some(String::new())
994 }
995 }
996 _ => {
997 if ignore_empty {
998 None
999 } else {
1000 Some(String::new())
1001 }
1002 }
1003 };
1004
1005 if let Some(s) = string_value {
1006 if !ignore_empty || !s.is_empty() {
1007 values.push(s);
1008 }
1009 }
1010 }
1011
1012 Ok(DataValue::String(values.join(&delimiter)))
1013 }
1014 _ => Err(anyhow!("Unknown function: {}", name)),
1015 }
1016 }
1017
1018 fn evaluate_method_call(
1020 &self,
1021 object: &str,
1022 method: &str,
1023 args: &[SqlExpression],
1024 row_index: usize,
1025 ) -> Result<DataValue> {
1026 let col_index = self.table.get_column_index(object).ok_or_else(|| {
1028 let suggestion = self.find_similar_column(object);
1029 match suggestion {
1030 Some(similar) => {
1031 anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
1032 }
1033 None => anyhow!("Column '{}' not found", object),
1034 }
1035 })?;
1036
1037 let cell_value = self.table.get_value(row_index, col_index).cloned();
1038
1039 self.evaluate_method_on_value(
1040 &cell_value.unwrap_or(DataValue::Null),
1041 method,
1042 args,
1043 row_index,
1044 )
1045 }
1046
1047 fn evaluate_method_on_value(
1049 &self,
1050 value: &DataValue,
1051 method: &str,
1052 args: &[SqlExpression],
1053 row_index: usize,
1054 ) -> Result<DataValue> {
1055 match method.to_lowercase().as_str() {
1056 "trim" | "trimstart" | "trimend" => {
1057 if !args.is_empty() {
1058 return Err(anyhow!("{} takes no arguments", method));
1059 }
1060
1061 let str_val = match value {
1063 DataValue::String(s) => s.clone(),
1064 DataValue::InternedString(s) => s.to_string(),
1065 DataValue::Integer(n) => n.to_string(),
1066 DataValue::Float(f) => f.to_string(),
1067 DataValue::Boolean(b) => b.to_string(),
1068 DataValue::DateTime(dt) => dt.clone(),
1069 DataValue::Null => return Ok(DataValue::Null),
1070 };
1071
1072 let result = match method.to_lowercase().as_str() {
1073 "trim" => str_val.trim().to_string(),
1074 "trimstart" => str_val.trim_start().to_string(),
1075 "trimend" => str_val.trim_end().to_string(),
1076 _ => unreachable!(),
1077 };
1078
1079 Ok(DataValue::String(result))
1080 }
1081 "length" => {
1082 if !args.is_empty() {
1083 return Err(anyhow!("Length takes no arguments"));
1084 }
1085
1086 let len = match value {
1088 DataValue::String(s) => s.len(),
1089 DataValue::InternedString(s) => s.len(),
1090 DataValue::Integer(n) => n.to_string().len(),
1091 DataValue::Float(f) => f.to_string().len(),
1092 DataValue::Boolean(b) => b.to_string().len(),
1093 DataValue::DateTime(dt) => dt.len(),
1094 DataValue::Null => return Ok(DataValue::Integer(0)),
1095 };
1096
1097 Ok(DataValue::Integer(len as i64))
1098 }
1099 "indexof" => {
1100 if args.len() != 1 {
1101 return Err(anyhow!("IndexOf requires exactly 1 argument"));
1102 }
1103
1104 let search_str = match self.evaluate(&args[0], row_index)? {
1106 DataValue::String(s) => s,
1107 DataValue::InternedString(s) => s.to_string(),
1108 DataValue::Integer(n) => n.to_string(),
1109 DataValue::Float(f) => f.to_string(),
1110 _ => return Err(anyhow!("IndexOf argument must be a string")),
1111 };
1112
1113 let str_val = match value {
1115 DataValue::String(s) => s.clone(),
1116 DataValue::InternedString(s) => s.to_string(),
1117 DataValue::Integer(n) => n.to_string(),
1118 DataValue::Float(f) => f.to_string(),
1119 DataValue::Boolean(b) => b.to_string(),
1120 DataValue::DateTime(dt) => dt.clone(),
1121 DataValue::Null => return Ok(DataValue::Integer(-1)),
1122 };
1123
1124 let index = str_val.find(&search_str).map(|i| i as i64).unwrap_or(-1);
1125
1126 Ok(DataValue::Integer(index))
1127 }
1128 "contains" => {
1129 if args.len() != 1 {
1130 return Err(anyhow!("Contains requires exactly 1 argument"));
1131 }
1132
1133 let search_str = match self.evaluate(&args[0], row_index)? {
1135 DataValue::String(s) => s,
1136 DataValue::InternedString(s) => s.to_string(),
1137 DataValue::Integer(n) => n.to_string(),
1138 DataValue::Float(f) => f.to_string(),
1139 _ => return Err(anyhow!("Contains argument must be a string")),
1140 };
1141
1142 let str_val = match value {
1144 DataValue::String(s) => s.clone(),
1145 DataValue::InternedString(s) => s.to_string(),
1146 DataValue::Integer(n) => n.to_string(),
1147 DataValue::Float(f) => f.to_string(),
1148 DataValue::Boolean(b) => b.to_string(),
1149 DataValue::DateTime(dt) => dt.clone(),
1150 DataValue::Null => return Ok(DataValue::Boolean(false)),
1151 };
1152
1153 let result = str_val.to_lowercase().contains(&search_str.to_lowercase());
1155 Ok(DataValue::Boolean(result))
1156 }
1157 "startswith" => {
1158 if args.len() != 1 {
1159 return Err(anyhow!("StartsWith requires exactly 1 argument"));
1160 }
1161
1162 let prefix = match self.evaluate(&args[0], row_index)? {
1164 DataValue::String(s) => s,
1165 DataValue::InternedString(s) => s.to_string(),
1166 DataValue::Integer(n) => n.to_string(),
1167 DataValue::Float(f) => f.to_string(),
1168 _ => return Err(anyhow!("StartsWith argument must be a string")),
1169 };
1170
1171 let str_val = match value {
1173 DataValue::String(s) => s.clone(),
1174 DataValue::InternedString(s) => s.to_string(),
1175 DataValue::Integer(n) => n.to_string(),
1176 DataValue::Float(f) => f.to_string(),
1177 DataValue::Boolean(b) => b.to_string(),
1178 DataValue::DateTime(dt) => dt.clone(),
1179 DataValue::Null => return Ok(DataValue::Boolean(false)),
1180 };
1181
1182 let result = str_val.to_lowercase().starts_with(&prefix.to_lowercase());
1184 Ok(DataValue::Boolean(result))
1185 }
1186 "endswith" => {
1187 if args.len() != 1 {
1188 return Err(anyhow!("EndsWith requires exactly 1 argument"));
1189 }
1190
1191 let suffix = match self.evaluate(&args[0], row_index)? {
1193 DataValue::String(s) => s,
1194 DataValue::InternedString(s) => s.to_string(),
1195 DataValue::Integer(n) => n.to_string(),
1196 DataValue::Float(f) => f.to_string(),
1197 _ => return Err(anyhow!("EndsWith argument must be a string")),
1198 };
1199
1200 let str_val = match value {
1202 DataValue::String(s) => s.clone(),
1203 DataValue::InternedString(s) => s.to_string(),
1204 DataValue::Integer(n) => n.to_string(),
1205 DataValue::Float(f) => f.to_string(),
1206 DataValue::Boolean(b) => b.to_string(),
1207 DataValue::DateTime(dt) => dt.clone(),
1208 DataValue::Null => return Ok(DataValue::Boolean(false)),
1209 };
1210
1211 let result = str_val.to_lowercase().ends_with(&suffix.to_lowercase());
1213 Ok(DataValue::Boolean(result))
1214 }
1215 _ => Err(anyhow!("Unsupported method: {}", method)),
1216 }
1217 }
1218
1219 fn evaluate_case_expression(
1221 &self,
1222 when_branches: &[crate::sql::recursive_parser::WhenBranch],
1223 else_branch: &Option<Box<SqlExpression>>,
1224 row_index: usize,
1225 ) -> Result<DataValue> {
1226 debug!(
1227 "ArithmeticEvaluator: evaluating CASE expression for row {}",
1228 row_index
1229 );
1230
1231 for branch in when_branches {
1233 let condition_result = self.evaluate_condition_as_bool(&branch.condition, row_index)?;
1235
1236 if condition_result {
1237 debug!("CASE: WHEN condition matched, evaluating result expression");
1238 return self.evaluate(&branch.result, row_index);
1239 }
1240 }
1241
1242 match else_branch {
1244 Some(else_expr) => {
1245 debug!("CASE: No WHEN matched, evaluating ELSE expression");
1246 self.evaluate(else_expr, row_index)
1247 }
1248 None => {
1249 debug!("CASE: No WHEN matched and no ELSE, returning NULL");
1250 Ok(DataValue::Null)
1251 }
1252 }
1253 }
1254
1255 fn evaluate_condition_as_bool(&self, expr: &SqlExpression, row_index: usize) -> Result<bool> {
1257 let value = self.evaluate(expr, row_index)?;
1258
1259 match value {
1260 DataValue::Boolean(b) => Ok(b),
1261 DataValue::Integer(i) => Ok(i != 0),
1262 DataValue::Float(f) => Ok(f != 0.0),
1263 DataValue::Null => Ok(false),
1264 DataValue::String(s) => Ok(!s.is_empty()),
1265 DataValue::InternedString(s) => Ok(!s.is_empty()),
1266 _ => Ok(true), }
1268 }
1269}
1270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table named "test" with columns `a = 10` (integer),
    /// `b = 2.5` (float) and `c = 4` (integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        for column_name in ["a", "b", "c"] {
            table.add_column(DataColumn::new(column_name));
        }

        let row = DataRow::new(vec![
            DataValue::Integer(10),
            DataValue::Float(2.5),
            DataValue::Integer(4),
        ]);
        table.add_row(row).unwrap();

        table
    }

    /// A column reference resolves to the cell value in the requested row.
    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let column_a = SqlExpression::Column(String::from("a"));
        assert_eq!(
            evaluator.evaluate(&column_a, 0).unwrap(),
            DataValue::Integer(10)
        );
    }

    /// Number literals become integers when whole and floats otherwise.
    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let cases = [
            ("42", DataValue::Integer(42)),
            ("3.14", DataValue::Float(3.14)),
        ];
        for (literal, expected) in cases {
            let expr = SqlExpression::NumberLiteral(literal.to_string());
            assert_eq!(evaluator.evaluate(&expr, 0).unwrap(), expected);
        }
    }

    /// Integer + integer stays integer; mixing in a float yields a float.
    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let int_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(int_sum, DataValue::Integer(8));

        let mixed_sum = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(mixed_sum, DataValue::Float(7.5));
    }

    /// Integer * float yields a float product.
    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let product = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(product, DataValue::Float(10.0));
    }

    /// Exact integer division stays integer; inexact division yields a float.
    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let exact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(exact, DataValue::Integer(5));

        let inexact = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(inexact, DataValue::Float(10.0 / 3.0));
    }

    /// Dividing by integer zero is reported as an error, not a panic.
    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        let outcome = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Division by zero"));
    }

    /// A binary operation over two column references is evaluated per row.
    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);

        // a * b = 10 * 2.5
        let lhs = SqlExpression::Column(String::from("a"));
        let rhs = SqlExpression::Column(String::from("b"));
        let product_expr = SqlExpression::BinaryOp {
            left: Box::new(lhs),
            op: String::from("*"),
            right: Box::new(rhs),
        };

        assert_eq!(
            evaluator.evaluate(&product_expr, 0).unwrap(),
            DataValue::Float(25.0)
        );
    }
}